Skip to main content

dfir_rs/util/
mod.rs

1//! Helper utilities for the DFIR syntax.
2#![warn(missing_docs)]
3
4#[cfg(feature = "dfir_macro")]
5#[cfg_attr(docsrs, doc(cfg(feature = "dfir_macro")))]
6pub mod demux_enum;
7pub mod multiset;
8pub mod priority_stack;
9pub mod slot_vec;
10pub mod sparse_vec;
11pub mod unsync;
12
13pub mod simulation;
14
15mod monotonic;
16pub use monotonic::*;
17
18mod udp;
19#[cfg(not(target_arch = "wasm32"))]
20pub use udp::*;
21
22mod tcp;
23#[cfg(not(target_arch = "wasm32"))]
24pub use tcp::*;
25
26#[cfg(unix)]
27mod socket;
28use std::net::SocketAddr;
29use std::num::NonZeroUsize;
30use std::task::{Context, Poll};
31
32use futures::Stream;
33use serde::de::DeserializeOwned;
34use serde::ser::Serialize;
35#[cfg(unix)]
36pub use socket::*;
37
/// Persist or delete tuples.
pub enum Persistence<T> {
    /// Persist `T` values.
    Persist(T),
    /// Delete all values that exactly match.
    Delete(T),
}
45
/// Persist or delete key-value pairs.
pub enum PersistenceKeyed<K, V> {
    /// Persist a key-value pair.
    Persist(K, V),
    /// Delete all tuples that have the key `K`.
    Delete(K),
}
53
54/// Returns a channel as a (1) unbounded sender and (2) unbounded receiver `Stream` for use in DFIR.
55pub fn unbounded_channel<T>() -> (
56    tokio::sync::mpsc::UnboundedSender<T>,
57    tokio_stream::wrappers::UnboundedReceiverStream<T>,
58) {
59    let (send, recv) = tokio::sync::mpsc::unbounded_channel();
60    let recv = tokio_stream::wrappers::UnboundedReceiverStream::new(recv);
61    (send, recv)
62}
63
64/// Returns an unsync channel as a (1) sender and (2) receiver `Stream` for use in DFIR.
65pub fn unsync_channel<T>(
66    capacity: Option<NonZeroUsize>,
67) -> (unsync::mpsc::Sender<T>, unsync::mpsc::Receiver<T>) {
68    unsync::mpsc::channel(capacity)
69}
70
71/// Returns an [`Iterator`] of any immediately available items from the [`Stream`].
72pub fn ready_iter<S>(stream: S) -> impl Iterator<Item = S::Item>
73where
74    S: Stream,
75{
76    let mut stream = Box::pin(stream);
77    std::iter::from_fn(move || {
78        match stream
79            .as_mut()
80            .poll_next(&mut Context::from_waker(futures::task::noop_waker_ref()))
81        {
82            Poll::Ready(opt) => opt,
83            Poll::Pending => None,
84        }
85    })
86}
87
88/// Collects the immediately available items from the `Stream` into a `FromIterator` collection.
89///
90/// This consumes the stream, use [`futures::StreamExt::by_ref()`] (or just `&mut ...`) if you want
91/// to retain ownership of your stream.
92pub fn collect_ready<C, S>(stream: S) -> C
93where
94    C: FromIterator<S::Item>,
95    S: Stream,
96{
97    assert!(
98        tokio::runtime::Handle::try_current().is_err(),
99        "Calling `collect_ready` from an async runtime may cause incorrect results, use `collect_ready_async` instead."
100    );
101    ready_iter(stream).collect()
102}
103
104/// Collects the immediately available items from the `Stream` into a collection (`Default` + `Extend`).
105///
106/// This consumes the stream, use [`futures::StreamExt::by_ref()`] (or just `&mut ...`) if you want
107/// to retain ownership of your stream.
108pub async fn collect_ready_async<C, S>(stream: S) -> C
109where
110    C: Default + Extend<S::Item>,
111    S: Stream,
112{
113    use std::sync::atomic::Ordering;
114
115    // Yield to let any background async tasks send to the stream.
116    tokio::task::yield_now().await;
117
118    let got_any_items = std::sync::atomic::AtomicBool::new(true);
119    let mut unfused_iter =
120        ready_iter(stream).inspect(|_| got_any_items.store(true, Ordering::Relaxed));
121    let mut out = C::default();
122    while got_any_items.swap(false, Ordering::Relaxed) {
123        out.extend(unfused_iter.by_ref());
124        // Tokio unbounded channel returns items in lenght-128 chunks, so we have to be careful
125        // that everything gets returned. That is why we yield here and loop.
126        tokio::task::yield_now().await;
127    }
128    out
129}
130
131/// Serialize a message to bytes using bincode.
132pub fn serialize_to_bytes<T>(msg: T) -> bytes::Bytes
133where
134    T: Serialize,
135{
136    bytes::Bytes::from(bincode::serialize(&msg).unwrap())
137}
138
139/// Serialize a message from bytes using bincode.
140pub fn deserialize_from_bytes<T>(msg: impl AsRef<[u8]>) -> bincode::Result<T>
141where
142    T: DeserializeOwned,
143{
144    bincode::deserialize(msg.as_ref())
145}
146
/// Resolves `addr` to its first IPv4 [`SocketAddr`].
///
/// `addr` is an `ip:port` or `hostname:port` string, as accepted by
/// [`std::net::ToSocketAddrs`] for `str`.
///
/// # Errors
///
/// - Propagates any error from `to_socket_addrs` (e.g. a malformed address or a failed lookup).
/// - Returns an [`std::io::ErrorKind::Other`] error if none of the resolved addresses is IPv4.
pub fn ipv4_resolve(addr: &str) -> Result<SocketAddr, std::io::Error> {
    use std::net::ToSocketAddrs;
    addr.to_socket_addrs()?
        .find(SocketAddr::is_ipv4)
        .ok_or_else(|| std::io::Error::other("Unable to resolve IPv4 address"))
}
157
158/// Returns a length-delimited bytes `Sink`, `Stream`, and `SocketAddr` bound to the given address.
159/// The input `addr` may have a port of `0`, the returned `SocketAddr` will have the chosen port.
160#[cfg(not(target_arch = "wasm32"))]
161pub async fn bind_udp_bytes(addr: SocketAddr) -> (UdpSink, UdpStream, SocketAddr) {
162    let socket = tokio::net::UdpSocket::bind(addr).await.unwrap();
163    udp_bytes(socket)
164}
165
166/// Returns a newline-delimited bytes `Sink`, `Stream`, and `SocketAddr` bound to the given address.
167/// The input `addr` may have a port of `0`, the returned `SocketAddr` will have the chosen port.
168#[cfg(not(target_arch = "wasm32"))]
169pub async fn bind_udp_lines(addr: SocketAddr) -> (UdpLinesSink, UdpLinesStream, SocketAddr) {
170    let socket = tokio::net::UdpSocket::bind(addr).await.unwrap();
171    udp_lines(socket)
172}
173
174/// Returns a newline-delimited bytes `Sender`, `Receiver`, and `SocketAddr` bound to the given address.
175///
176/// The input `addr` may have a port of `0`, the returned `SocketAddr` will be the address of the newly bound endpoint.
177/// The inbound connections can be used in full duplex mode. When a `(T, SocketAddr)` pair is fed to the `Sender`
178/// returned by this function, the `SocketAddr` will be looked up against the currently existing connections.
179/// If a match is found then the data will be sent on that connection. If no match is found then the data is silently dropped.
180#[cfg(not(target_arch = "wasm32"))]
181pub async fn bind_tcp_bytes(
182    addr: SocketAddr,
183) -> (
184    unsync::mpsc::Sender<(bytes::Bytes, SocketAddr)>,
185    unsync::mpsc::Receiver<Result<(bytes::BytesMut, SocketAddr), std::io::Error>>,
186    SocketAddr,
187) {
188    bind_tcp(addr, tokio_util::codec::LengthDelimitedCodec::new())
189        .await
190        .unwrap()
191}
192
193/// This is the same thing as `bind_tcp_bytes` except instead of using a length-delimited encoding scheme it uses new lines to separate frames.
194#[cfg(not(target_arch = "wasm32"))]
195pub async fn bind_tcp_lines(
196    addr: SocketAddr,
197) -> (
198    unsync::mpsc::Sender<(String, SocketAddr)>,
199    unsync::mpsc::Receiver<Result<(String, SocketAddr), tokio_util::codec::LinesCodecError>>,
200    SocketAddr,
201) {
202    bind_tcp(addr, tokio_util::codec::LinesCodec::new())
203        .await
204        .unwrap()
205}
206
207/// The inverse of [`bind_tcp_bytes`].
208///
209/// `(Bytes, SocketAddr)` pairs fed to the returned `Sender` will initiate new tcp connections to the specified `SocketAddr`.
210/// These connections will be cached and reused, so that there will only be one connection per destination endpoint. When the endpoint sends data back it will be available via the returned `Receiver`
211#[cfg(not(target_arch = "wasm32"))]
212pub fn connect_tcp_bytes() -> (
213    TcpFramedSink<bytes::Bytes>,
214    TcpFramedStream<tokio_util::codec::LengthDelimitedCodec>,
215) {
216    connect_tcp(tokio_util::codec::LengthDelimitedCodec::new())
217}
218
219/// This is the same thing as `connect_tcp_bytes` except instead of using a length-delimited encoding scheme it uses new lines to separate frames.
220#[cfg(not(target_arch = "wasm32"))]
221pub fn connect_tcp_lines() -> (
222    TcpFramedSink<String>,
223    TcpFramedStream<tokio_util::codec::LinesCodec>,
224) {
225    connect_tcp(tokio_util::codec::LinesCodec::new())
226}
227
/// Sort a slice using a key fn which returns references.
///
/// Unlike `<[T]>::sort_unstable_by_key`, the key function may return a reference borrowed from
/// the element itself (e.g. `|person| &person.name`). The key type may also be unsized (e.g.
/// `|person| person.name.as_str()` with `K = str`).
///
/// The sort is unstable: equal elements may be reordered.
///
/// From addendum in
/// <https://stackoverflow.com/questions/56105305/how-to-sort-a-vec-of-structs-by-a-string-field>
pub fn sort_unstable_by_key_hrtb<T, F, K>(slice: &mut [T], f: F)
where
    F: for<'a> Fn(&'a T) -> &'a K,
    K: Ord + ?Sized,
{
    slice.sort_unstable_by(|a, b| f(a).cmp(f(b)))
}
239
240/// Converts an iterator into a stream that emits `n` items at a time, yielding between each batch.
241///
242/// This is useful for breaking up a large iterator across several ticks: `source_iter(...)` always
243/// releases all items in the first tick. However using `iter_batches_stream` with `source_stream(...)`
244/// will cause `n` items to be released each tick. (Although more than that may be emitted if there
245/// are loops in the stratum).
246pub fn iter_batches_stream<I>(
247    iter: I,
248    n: usize,
249) -> futures::stream::PollFn<impl FnMut(&mut Context<'_>) -> Poll<Option<I::Item>>>
250where
251    I: IntoIterator + Unpin,
252{
253    let mut count = 0;
254    let mut iter = iter.into_iter();
255    futures::stream::poll_fn(move |ctx| {
256        count += 1;
257        if n < count {
258            count = 0;
259            ctx.waker().wake_by_ref();
260            Poll::Pending
261        } else {
262            Poll::Ready(iter.next())
263        }
264    })
265}
266
#[cfg(test)]
mod test {
    use super::*;

    /// `collect_ready` must drain every buffered item when called outside any async runtime.
    #[test]
    pub fn test_collect_ready() {
        let (sender, mut receiver) = unbounded_channel::<usize>();
        (0..1000).try_for_each(|value| sender.send(value)).unwrap();
        let collected: Vec<usize> = collect_ready(&mut receiver);
        assert_eq!(1000, collected.len());
    }

    /// `collect_ready_async` must drain every buffered item, even though the Tokio unbounded
    /// channel hands them out in chunks of 128.
    #[crate::test]
    pub async fn test_collect_ready_async() {
        let (sender, mut receiver) = unbounded_channel::<usize>();
        (0..1000).try_for_each(|value| sender.send(value)).unwrap();
        let collected: Vec<usize> = collect_ready_async(&mut receiver).await;
        assert_eq!(1000, collected.len());
    }
}