Skip to main content

dfir_rs/util/
mod.rs

1//! Helper utilities for the DFIR syntax.
2#![warn(missing_docs)]
3
4#[cfg(feature = "dfir_macro")]
5#[cfg_attr(docsrs, doc(cfg(feature = "dfir_macro")))]
6pub mod demux_enum;
7pub mod multiset;
8pub mod priority_stack;
9pub mod slot_vec;
10pub mod sparse_vec;
11pub mod unsync;
12
13mod monotonic;
14pub use monotonic::*;
15
16mod udp;
17#[cfg(not(target_arch = "wasm32"))]
18pub use udp::*;
19
20mod tcp;
21#[cfg(not(target_arch = "wasm32"))]
22pub use tcp::*;
23
24#[cfg(unix)]
25mod socket;
26use std::net::SocketAddr;
27use std::num::NonZeroUsize;
28use std::task::{Context, Poll};
29
30use futures::Stream;
31use serde::de::DeserializeOwned;
32use serde::ser::Serialize;
33#[cfg(unix)]
34pub use socket::*;
35
/// A persistence command for a single value: either persist it or delete it.
pub enum Persistence<T> {
    /// Persist the carried `T` value.
    Persist(T),
    /// Delete all values that exactly match the carried `T`.
    Delete(T),
}
43
/// A persistence command for key-value pairs: either persist a pair or delete by key.
pub enum PersistenceKeyed<K, V> {
    /// Persist the carried key-value pair.
    Persist(K, V),
    /// Delete all tuples that have the carried key `K`.
    Delete(K),
}
51
52/// Returns a channel as a (1) unbounded sender and (2) unbounded receiver `Stream` for use in DFIR.
53pub fn unbounded_channel<T>() -> (
54    tokio::sync::mpsc::UnboundedSender<T>,
55    tokio_stream::wrappers::UnboundedReceiverStream<T>,
56) {
57    let (send, recv) = tokio::sync::mpsc::unbounded_channel();
58    let recv = tokio_stream::wrappers::UnboundedReceiverStream::new(recv);
59    (send, recv)
60}
61
62/// Returns an unsync channel as a (1) sender and (2) receiver `Stream` for use in DFIR.
63pub fn unsync_channel<T>(
64    capacity: Option<NonZeroUsize>,
65) -> (unsync::mpsc::Sender<T>, unsync::mpsc::Receiver<T>) {
66    unsync::mpsc::channel(capacity)
67}
68
69/// Returns an [`Iterator`] of any immediately available items from the [`Stream`].
70pub fn ready_iter<S>(stream: S) -> impl Iterator<Item = S::Item>
71where
72    S: Stream,
73{
74    let mut stream = Box::pin(stream);
75    std::iter::from_fn(move || {
76        match stream
77            .as_mut()
78            .poll_next(&mut Context::from_waker(futures::task::noop_waker_ref()))
79        {
80            Poll::Ready(opt) => opt,
81            Poll::Pending => None,
82        }
83    })
84}
85
86/// Collects the immediately available items from the `Stream` into a `FromIterator` collection.
87///
88/// This consumes the stream, use [`futures::StreamExt::by_ref()`] (or just `&mut ...`) if you want
89/// to retain ownership of your stream.
90pub fn collect_ready<C, S>(stream: S) -> C
91where
92    C: FromIterator<S::Item>,
93    S: Stream,
94{
95    assert!(
96        tokio::runtime::Handle::try_current().is_err(),
97        "Calling `collect_ready` from an async runtime may cause incorrect results, use `collect_ready_async` instead."
98    );
99    ready_iter(stream).collect()
100}
101
102/// Collects the immediately available items from the `Stream` into a collection (`Default` + `Extend`).
103///
104/// This consumes the stream, use [`futures::StreamExt::by_ref()`] (or just `&mut ...`) if you want
105/// to retain ownership of your stream.
106pub async fn collect_ready_async<C, S>(stream: S) -> C
107where
108    C: Default + Extend<S::Item>,
109    S: Stream,
110{
111    use std::sync::atomic::Ordering;
112
113    // Yield to let any background async tasks send to the stream.
114    tokio::task::yield_now().await;
115
116    let got_any_items = std::sync::atomic::AtomicBool::new(true);
117    let mut unfused_iter =
118        ready_iter(stream).inspect(|_| got_any_items.store(true, Ordering::Relaxed));
119    let mut out = C::default();
120    while got_any_items.swap(false, Ordering::Relaxed) {
121        out.extend(unfused_iter.by_ref());
122        // Tokio unbounded channel returns items in lenght-128 chunks, so we have to be careful
123        // that everything gets returned. That is why we yield here and loop.
124        tokio::task::yield_now().await;
125    }
126    out
127}
128
129/// Serialize a message to bytes using bincode.
130pub fn serialize_to_bytes<T>(msg: T) -> bytes::Bytes
131where
132    T: Serialize,
133{
134    bytes::Bytes::from(bincode::serialize(&msg).unwrap())
135}
136
137/// Serialize a message from bytes using bincode.
138pub fn deserialize_from_bytes<T>(msg: impl AsRef<[u8]>) -> bincode::Result<T>
139where
140    T: DeserializeOwned,
141{
142    bincode::deserialize(msg.as_ref())
143}
144
/// Resolves an IP or hostname string to the first `ipv4` [`SocketAddr`] it maps to.
///
/// # Errors
///
/// Returns the error from `to_socket_addrs` if `addr` cannot be resolved at all, or an
/// "other" I/O error if none of the resolved addresses is IPv4.
pub fn ipv4_resolve(addr: &str) -> Result<SocketAddr, std::io::Error> {
    use std::net::ToSocketAddrs;
    addr.to_socket_addrs()?
        .find(|candidate| candidate.is_ipv4())
        .ok_or_else(|| std::io::Error::other("Unable to resolve IPv4 address"))
}
155
156/// Returns a length-delimited bytes `Sink`, `Stream`, and `SocketAddr` bound to the given address.
157/// The input `addr` may have a port of `0`, the returned `SocketAddr` will have the chosen port.
158#[cfg(not(target_arch = "wasm32"))]
159pub async fn bind_udp_bytes(addr: SocketAddr) -> (UdpSink, UdpStream, SocketAddr) {
160    let socket = tokio::net::UdpSocket::bind(addr).await.unwrap();
161    udp_bytes(socket)
162}
163
164/// Returns a newline-delimited bytes `Sink`, `Stream`, and `SocketAddr` bound to the given address.
165/// The input `addr` may have a port of `0`, the returned `SocketAddr` will have the chosen port.
166#[cfg(not(target_arch = "wasm32"))]
167pub async fn bind_udp_lines(addr: SocketAddr) -> (UdpLinesSink, UdpLinesStream, SocketAddr) {
168    let socket = tokio::net::UdpSocket::bind(addr).await.unwrap();
169    udp_lines(socket)
170}
171
172/// Returns a newline-delimited bytes `Sender`, `Receiver`, and `SocketAddr` bound to the given address.
173///
174/// The input `addr` may have a port of `0`, the returned `SocketAddr` will be the address of the newly bound endpoint.
175/// The inbound connections can be used in full duplex mode. When a `(T, SocketAddr)` pair is fed to the `Sender`
176/// returned by this function, the `SocketAddr` will be looked up against the currently existing connections.
177/// If a match is found then the data will be sent on that connection. If no match is found then the data is silently dropped.
178#[cfg(not(target_arch = "wasm32"))]
179pub async fn bind_tcp_bytes(
180    addr: SocketAddr,
181) -> (
182    unsync::mpsc::Sender<(bytes::Bytes, SocketAddr)>,
183    unsync::mpsc::Receiver<Result<(bytes::BytesMut, SocketAddr), std::io::Error>>,
184    SocketAddr,
185) {
186    bind_tcp(addr, tokio_util::codec::LengthDelimitedCodec::new())
187        .await
188        .unwrap()
189}
190
191/// This is the same thing as `bind_tcp_bytes` except instead of using a length-delimited encoding scheme it uses new lines to separate frames.
192#[cfg(not(target_arch = "wasm32"))]
193pub async fn bind_tcp_lines(
194    addr: SocketAddr,
195) -> (
196    unsync::mpsc::Sender<(String, SocketAddr)>,
197    unsync::mpsc::Receiver<Result<(String, SocketAddr), tokio_util::codec::LinesCodecError>>,
198    SocketAddr,
199) {
200    bind_tcp(addr, tokio_util::codec::LinesCodec::new())
201        .await
202        .unwrap()
203}
204
205/// The inverse of [`bind_tcp_bytes`].
206///
207/// `(Bytes, SocketAddr)` pairs fed to the returned `Sender` will initiate new tcp connections to the specified `SocketAddr`.
208/// These connections will be cached and reused, so that there will only be one connection per destination endpoint. When the endpoint sends data back it will be available via the returned `Receiver`
209#[cfg(not(target_arch = "wasm32"))]
210pub fn connect_tcp_bytes() -> (
211    TcpFramedSink<bytes::Bytes>,
212    TcpFramedStream<tokio_util::codec::LengthDelimitedCodec>,
213) {
214    connect_tcp(tokio_util::codec::LengthDelimitedCodec::new())
215}
216
217/// This is the same thing as `connect_tcp_bytes` except instead of using a length-delimited encoding scheme it uses new lines to separate frames.
218#[cfg(not(target_arch = "wasm32"))]
219pub fn connect_tcp_lines() -> (
220    TcpFramedSink<String>,
221    TcpFramedStream<tokio_util::codec::LinesCodec>,
222) {
223    connect_tcp(tokio_util::codec::LinesCodec::new())
224}
225
/// Sort a slice using a key fn which returns references.
///
/// Unlike [`slice::sort_unstable_by_key`], the key function may borrow from the element it is
/// given (e.g. `|person| &person.name`). The key type may be unsized, so key functions that
/// return `&str` or `&[U]` are accepted as well.
///
/// The sort is unstable: the relative order of elements with equal keys is not preserved.
/// `f` is called twice per comparison, so it should be cheap.
///
/// From addendum in
/// <https://stackoverflow.com/questions/56105305/how-to-sort-a-vec-of-structs-by-a-string-field>
pub fn sort_unstable_by_key_hrtb<T, F, K>(slice: &mut [T], f: F)
where
    F: for<'a> Fn(&'a T) -> &'a K,
    // `?Sized` because keys are only ever compared through references.
    K: Ord + ?Sized,
{
    slice.sort_unstable_by(|lhs, rhs| Ord::cmp(f(lhs), f(rhs)));
}
237
238/// Converts an iterator into a stream that emits `n` items at a time, yielding between each batch.
239///
240/// This is useful for breaking up a large iterator across several ticks: `source_iter(...)` always
241/// releases all items in the first tick. However using `iter_batches_stream` with `source_stream(...)`
242/// will cause `n` items to be released each tick. (Although more than that may be emitted if there
243/// are loops in the stratum).
244pub fn iter_batches_stream<I>(
245    iter: I,
246    n: usize,
247) -> futures::stream::PollFn<impl FnMut(&mut Context<'_>) -> Poll<Option<I::Item>>>
248where
249    I: IntoIterator + Unpin,
250{
251    let mut count = 0;
252    let mut iter = iter.into_iter();
253    futures::stream::poll_fn(move |ctx| {
254        count += 1;
255        if n < count {
256            count = 0;
257            ctx.waker().wake_by_ref();
258            Poll::Pending
259        } else {
260            Poll::Ready(iter.next())
261        }
262    })
263}
264
#[cfg(test)]
mod test {
    use super::*;

    /// `collect_ready` must return every item that was already queued on the channel.
    #[test]
    pub fn test_collect_ready() {
        let (tx, mut rx) = unbounded_channel::<usize>();
        (0..1000).for_each(|item| tx.send(item).unwrap());
        let drained: Vec<usize> = collect_ready(&mut rx);
        assert_eq!(1000, drained.len());
    }

    /// `collect_ready_async` must return every queued item even though the Tokio unbounded
    /// channel hands items out in 128 item long chunks.
    #[crate::test]
    pub async fn test_collect_ready_async() {
        let (tx, mut rx) = unbounded_channel::<usize>();
        (0..1000).for_each(|item| tx.send(item).unwrap());
        let drained: Vec<usize> = collect_ready_async(&mut rx).await;
        assert_eq!(1000, drained.len());
    }
}