Skip to main content

dfir_rs/util/
mod.rs

1//! Helper utilities for the DFIR syntax.
2#![warn(missing_docs)]
3
4#[cfg(feature = "dfir_macro")]
5#[cfg_attr(docsrs, doc(cfg(feature = "dfir_macro")))]
6pub mod demux_enum;
7pub mod multiset;
8pub mod priority_stack;
9pub mod sparse_vec;
10pub mod unsync;
11
12mod monotonic;
13pub use monotonic::*;
14
15mod udp;
16#[cfg(not(target_arch = "wasm32"))]
17pub use udp::*;
18
19mod tcp;
20#[cfg(not(target_arch = "wasm32"))]
21pub use tcp::*;
22
23#[cfg(unix)]
24mod socket;
25use std::net::SocketAddr;
26use std::num::NonZeroUsize;
27use std::task::{Context, Poll};
28
29use futures::Stream;
30use serde::de::DeserializeOwned;
31use serde::ser::Serialize;
32#[cfg(unix)]
33pub use socket::*;
34
/// A command to persist or delete a tuple.
///
/// Each variant wraps the tuple value `T` the command applies to.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Persistence<T> {
    /// Persist the wrapped `T` value.
    Persist(T),
    /// Delete all values that exactly match the wrapped `T`.
    Delete(T),
}
42
/// A command to persist or delete key-value pairs.
///
/// `Persist` carries both the key `K` and the value `V`; `Delete` carries only the key.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PersistenceKeyed<K, V> {
    /// Persist the given key-value pair.
    Persist(K, V),
    /// Delete all tuples that have the key `K`.
    Delete(K),
}
50
51/// Returns a channel as a (1) unbounded sender and (2) unbounded receiver `Stream` for use in DFIR.
52pub fn unbounded_channel<T>() -> (
53    tokio::sync::mpsc::UnboundedSender<T>,
54    tokio_stream::wrappers::UnboundedReceiverStream<T>,
55) {
56    let (send, recv) = tokio::sync::mpsc::unbounded_channel();
57    let recv = tokio_stream::wrappers::UnboundedReceiverStream::new(recv);
58    (send, recv)
59}
60
61/// Returns an unsync channel as a (1) sender and (2) receiver `Stream` for use in DFIR.
62pub fn unsync_channel<T>(
63    capacity: Option<NonZeroUsize>,
64) -> (unsync::mpsc::Sender<T>, unsync::mpsc::Receiver<T>) {
65    unsync::mpsc::channel(capacity)
66}
67
68/// Returns an [`Iterator`] of any immediately available items from the [`Stream`].
69pub fn ready_iter<S>(stream: S) -> impl Iterator<Item = S::Item>
70where
71    S: Stream,
72{
73    let mut stream = Box::pin(stream);
74    std::iter::from_fn(move || {
75        match stream
76            .as_mut()
77            .poll_next(&mut Context::from_waker(futures::task::noop_waker_ref()))
78        {
79            Poll::Ready(opt) => opt,
80            Poll::Pending => None,
81        }
82    })
83}
84
85/// Collects the immediately available items from the `Stream` into a `FromIterator` collection.
86///
87/// This consumes the stream, use [`futures::StreamExt::by_ref()`] (or just `&mut ...`) if you want
88/// to retain ownership of your stream.
89pub fn collect_ready<C, S>(stream: S) -> C
90where
91    C: FromIterator<S::Item>,
92    S: Stream,
93{
94    assert!(
95        tokio::runtime::Handle::try_current().is_err(),
96        "Calling `collect_ready` from an async runtime may cause incorrect results, use `collect_ready_async` instead."
97    );
98    ready_iter(stream).collect()
99}
100
101/// Collects the immediately available items from the `Stream` into a collection (`Default` + `Extend`).
102///
103/// This consumes the stream, use [`futures::StreamExt::by_ref()`] (or just `&mut ...`) if you want
104/// to retain ownership of your stream.
105pub async fn collect_ready_async<C, S>(stream: S) -> C
106where
107    C: Default + Extend<S::Item>,
108    S: Stream,
109{
110    use std::sync::atomic::Ordering;
111
112    // Yield to let any background async tasks send to the stream.
113    tokio::task::yield_now().await;
114
115    let got_any_items = std::sync::atomic::AtomicBool::new(true);
116    let mut unfused_iter =
117        ready_iter(stream).inspect(|_| got_any_items.store(true, Ordering::Relaxed));
118    let mut out = C::default();
119    while got_any_items.swap(false, Ordering::Relaxed) {
120        out.extend(unfused_iter.by_ref());
121        // Tokio unbounded channel returns items in lenght-128 chunks, so we have to be careful
122        // that everything gets returned. That is why we yield here and loop.
123        tokio::task::yield_now().await;
124    }
125    out
126}
127
128/// Serialize a message to bytes using bincode.
129pub fn serialize_to_bytes<T>(msg: T) -> bytes::Bytes
130where
131    T: Serialize,
132{
133    bytes::Bytes::from(bincode::serialize(&msg).unwrap())
134}
135
136/// Serialize a message from bytes using bincode.
137pub fn deserialize_from_bytes<T>(msg: impl AsRef<[u8]>) -> bincode::Result<T>
138where
139    T: DeserializeOwned,
140{
141    bincode::deserialize(msg.as_ref())
142}
143
/// Resolve the `ipv4` [`SocketAddr`] from an IP or hostname string.
///
/// `addr` is resolved through [`std::net::ToSocketAddrs`], and the first IPv4 result is
/// returned.
///
/// # Errors
///
/// Returns the underlying I/O error if `addr` cannot be parsed or resolved, or an
/// [`std::io::ErrorKind::Other`] error with the message `"Unable to resolve IPv4 address"`
/// when resolution succeeds but yields no IPv4 address.
pub fn ipv4_resolve(addr: &str) -> Result<SocketAddr, std::io::Error> {
    use std::net::ToSocketAddrs;

    addr.to_socket_addrs()?
        .find(SocketAddr::is_ipv4)
        .ok_or_else(|| std::io::Error::other("Unable to resolve IPv4 address"))
}
154
155/// Returns a length-delimited bytes `Sink`, `Stream`, and `SocketAddr` bound to the given address.
156/// The input `addr` may have a port of `0`, the returned `SocketAddr` will have the chosen port.
157#[cfg(not(target_arch = "wasm32"))]
158pub async fn bind_udp_bytes(addr: SocketAddr) -> (UdpSink, UdpStream, SocketAddr) {
159    let socket = tokio::net::UdpSocket::bind(addr).await.unwrap();
160    udp_bytes(socket)
161}
162
163/// Returns a newline-delimited bytes `Sink`, `Stream`, and `SocketAddr` bound to the given address.
164/// The input `addr` may have a port of `0`, the returned `SocketAddr` will have the chosen port.
165#[cfg(not(target_arch = "wasm32"))]
166pub async fn bind_udp_lines(addr: SocketAddr) -> (UdpLinesSink, UdpLinesStream, SocketAddr) {
167    let socket = tokio::net::UdpSocket::bind(addr).await.unwrap();
168    udp_lines(socket)
169}
170
171/// Returns a newline-delimited bytes `Sender`, `Receiver`, and `SocketAddr` bound to the given address.
172///
173/// The input `addr` may have a port of `0`, the returned `SocketAddr` will be the address of the newly bound endpoint.
174/// The inbound connections can be used in full duplex mode. When a `(T, SocketAddr)` pair is fed to the `Sender`
175/// returned by this function, the `SocketAddr` will be looked up against the currently existing connections.
176/// If a match is found then the data will be sent on that connection. If no match is found then the data is silently dropped.
177#[cfg(not(target_arch = "wasm32"))]
178pub async fn bind_tcp_bytes(
179    addr: SocketAddr,
180) -> (
181    unsync::mpsc::Sender<(bytes::Bytes, SocketAddr)>,
182    unsync::mpsc::Receiver<Result<(bytes::BytesMut, SocketAddr), std::io::Error>>,
183    SocketAddr,
184) {
185    bind_tcp(addr, tokio_util::codec::LengthDelimitedCodec::new())
186        .await
187        .unwrap()
188}
189
190/// This is the same thing as `bind_tcp_bytes` except instead of using a length-delimited encoding scheme it uses new lines to separate frames.
191#[cfg(not(target_arch = "wasm32"))]
192pub async fn bind_tcp_lines(
193    addr: SocketAddr,
194) -> (
195    unsync::mpsc::Sender<(String, SocketAddr)>,
196    unsync::mpsc::Receiver<Result<(String, SocketAddr), tokio_util::codec::LinesCodecError>>,
197    SocketAddr,
198) {
199    bind_tcp(addr, tokio_util::codec::LinesCodec::new())
200        .await
201        .unwrap()
202}
203
204/// The inverse of [`bind_tcp_bytes`].
205///
206/// `(Bytes, SocketAddr)` pairs fed to the returned `Sender` will initiate new tcp connections to the specified `SocketAddr`.
207/// These connections will be cached and reused, so that there will only be one connection per destination endpoint. When the endpoint sends data back it will be available via the returned `Receiver`
208#[cfg(not(target_arch = "wasm32"))]
209pub fn connect_tcp_bytes() -> (
210    TcpFramedSink<bytes::Bytes>,
211    TcpFramedStream<tokio_util::codec::LengthDelimitedCodec>,
212) {
213    connect_tcp(tokio_util::codec::LengthDelimitedCodec::new())
214}
215
216/// This is the same thing as `connect_tcp_bytes` except instead of using a length-delimited encoding scheme it uses new lines to separate frames.
217#[cfg(not(target_arch = "wasm32"))]
218pub fn connect_tcp_lines() -> (
219    TcpFramedSink<String>,
220    TcpFramedStream<tokio_util::codec::LinesCodec>,
221) {
222    connect_tcp(tokio_util::codec::LinesCodec::new())
223}
224
/// Sort a slice using a key fn which returns references.
///
/// Unlike [`slice::sort_unstable_by_key`], the key function may borrow from the element it is
/// given (hence the higher-ranked trait bound on `F`). The key type may be unsized, so keys
/// such as `&str` or `&[u8]` are accepted. The sort is unstable: equal elements may be
/// reordered.
///
/// From addendum in
/// <https://stackoverflow.com/questions/56105305/how-to-sort-a-vec-of-structs-by-a-string-field>
pub fn sort_unstable_by_key_hrtb<T, F, K>(slice: &mut [T], f: F)
where
    F: for<'a> Fn(&'a T) -> &'a K,
    K: Ord + ?Sized,
{
    // Fully-qualified call so `K`'s own `Ord` impl is selected for the borrowed keys.
    slice.sort_unstable_by(|a, b| Ord::cmp(f(a), f(b)));
}
236
237/// Converts an iterator into a stream that emits `n` items at a time, yielding between each batch.
238///
239/// This is useful for breaking up a large iterator across several ticks: `source_iter(...)` always
240/// releases all items in the first tick. However using `iter_batches_stream` with `source_stream(...)`
241/// will cause `n` items to be released each tick. (Although more than that may be emitted if there
242/// are loops in the stratum).
243pub fn iter_batches_stream<I>(
244    iter: I,
245    n: usize,
246) -> futures::stream::PollFn<impl FnMut(&mut Context<'_>) -> Poll<Option<I::Item>>>
247where
248    I: IntoIterator + Unpin,
249{
250    let mut count = 0;
251    let mut iter = iter.into_iter();
252    futures::stream::poll_fn(move |ctx| {
253        count += 1;
254        if n < count {
255            count = 0;
256            ctx.waker().wake_by_ref();
257            Poll::Pending
258        } else {
259            Poll::Ready(iter.next())
260        }
261    })
262}
263
#[cfg(test)]
mod test {
    use super::*;

    /// `collect_ready` (outside any runtime) must return every item already queued.
    #[test]
    pub fn test_collect_ready() {
        let (send, mut recv) = unbounded_channel::<usize>();
        (0..1000).for_each(|x| send.send(x).unwrap());

        let collected = collect_ready::<Vec<_>, _>(&mut recv);
        assert_eq!(1000, collected.len());
    }

    /// `collect_ready_async` must return every queued item.
    #[crate::test]
    pub async fn test_collect_ready_async() {
        // Tokio unbounded channel returns items in 128 item long chunks, so we have to be careful that everything gets returned.
        let (send, mut recv) = unbounded_channel::<usize>();
        (0..1000).for_each(|x| send.send(x).unwrap());

        let collected = collect_ready_async::<Vec<_>, _>(&mut recv).await;
        assert_eq!(1000, collected.len());
    }
}