dfir_rs/util/
mod.rs

1//! Helper utilities for the DFIR syntax.
2#![warn(missing_docs)]
3
4pub mod accumulator;
5pub mod clear;
6#[cfg(feature = "dfir_macro")]
7#[cfg_attr(docsrs, doc(cfg(feature = "dfir_macro")))]
8pub mod demux_enum;
9pub mod multiset;
10pub mod priority_stack;
11pub mod slot_vec;
12pub mod sparse_vec;
13pub mod unsync;
14
15pub mod simulation;
16
17mod monotonic;
18pub use monotonic::*;
19
20mod udp;
21#[cfg(not(target_arch = "wasm32"))]
22pub use udp::*;
23
24mod tcp;
25#[cfg(not(target_arch = "wasm32"))]
26pub use tcp::*;
27
28#[cfg(unix)]
29mod socket;
30use std::net::SocketAddr;
31use std::num::NonZeroUsize;
32use std::task::{Context, Poll};
33
34use futures::Stream;
35use serde::de::DeserializeOwned;
36use serde::ser::Serialize;
37#[cfg(unix)]
38pub use socket::*;
39
/// Persist or delete tuples.
pub enum Persistence<T> {
    /// Persist a `T` value.
    Persist(T),
    /// Delete all values that exactly match the given `T`.
    Delete(T),
}
47
/// Persist or delete key-value pairs.
pub enum PersistenceKeyed<K, V> {
    /// Persist a key-value pair.
    Persist(K, V),
    /// Delete all tuples that have the key `K`.
    Delete(K),
}
55
56/// Returns a channel as a (1) unbounded sender and (2) unbounded receiver `Stream` for use in DFIR.
57pub fn unbounded_channel<T>() -> (
58    tokio::sync::mpsc::UnboundedSender<T>,
59    tokio_stream::wrappers::UnboundedReceiverStream<T>,
60) {
61    let (send, recv) = tokio::sync::mpsc::unbounded_channel();
62    let recv = tokio_stream::wrappers::UnboundedReceiverStream::new(recv);
63    (send, recv)
64}
65
66/// Returns an unsync channel as a (1) sender and (2) receiver `Stream` for use in DFIR.
67pub fn unsync_channel<T>(
68    capacity: Option<NonZeroUsize>,
69) -> (unsync::mpsc::Sender<T>, unsync::mpsc::Receiver<T>) {
70    unsync::mpsc::channel(capacity)
71}
72
73/// Returns an [`Iterator`] of any immediately available items from the [`Stream`].
74pub fn ready_iter<S>(stream: S) -> impl Iterator<Item = S::Item>
75where
76    S: Stream,
77{
78    let mut stream = Box::pin(stream);
79    std::iter::from_fn(move || {
80        match stream
81            .as_mut()
82            .poll_next(&mut Context::from_waker(futures::task::noop_waker_ref()))
83        {
84            Poll::Ready(opt) => opt,
85            Poll::Pending => None,
86        }
87    })
88}
89
90/// Collects the immediately available items from the `Stream` into a `FromIterator` collection.
91///
92/// This consumes the stream, use [`futures::StreamExt::by_ref()`] (or just `&mut ...`) if you want
93/// to retain ownership of your stream.
94pub fn collect_ready<C, S>(stream: S) -> C
95where
96    C: FromIterator<S::Item>,
97    S: Stream,
98{
99    assert!(
100        tokio::runtime::Handle::try_current().is_err(),
101        "Calling `collect_ready` from an async runtime may cause incorrect results, use `collect_ready_async` instead."
102    );
103    ready_iter(stream).collect()
104}
105
106/// Collects the immediately available items from the `Stream` into a collection (`Default` + `Extend`).
107///
108/// This consumes the stream, use [`futures::StreamExt::by_ref()`] (or just `&mut ...`) if you want
109/// to retain ownership of your stream.
110pub async fn collect_ready_async<C, S>(stream: S) -> C
111where
112    C: Default + Extend<S::Item>,
113    S: Stream,
114{
115    use std::sync::atomic::Ordering;
116
117    // Yield to let any background async tasks send to the stream.
118    tokio::task::yield_now().await;
119
120    let got_any_items = std::sync::atomic::AtomicBool::new(true);
121    let mut unfused_iter =
122        ready_iter(stream).inspect(|_| got_any_items.store(true, Ordering::Relaxed));
123    let mut out = C::default();
124    while got_any_items.swap(false, Ordering::Relaxed) {
125        out.extend(unfused_iter.by_ref());
126        // Tokio unbounded channel returns items in lenght-128 chunks, so we have to be careful
127        // that everything gets returned. That is why we yield here and loop.
128        tokio::task::yield_now().await;
129    }
130    out
131}
132
133/// Serialize a message to bytes using bincode.
134pub fn serialize_to_bytes<T>(msg: T) -> bytes::Bytes
135where
136    T: Serialize,
137{
138    bytes::Bytes::from(bincode::serialize(&msg).unwrap())
139}
140
141/// Serialize a message from bytes using bincode.
142pub fn deserialize_from_bytes<T>(msg: impl AsRef<[u8]>) -> bincode::Result<T>
143where
144    T: DeserializeOwned,
145{
146    bincode::deserialize(msg.as_ref())
147}
148
/// Resolves `addr` (an `ip:port` or `host:port` string) and returns the first IPv4 [`SocketAddr`]
/// among the results.
///
/// # Errors
///
/// Returns the underlying I/O error if resolution fails. Returns an
/// `"Unable to resolve IPv4 address"` error if resolution succeeds but yields no IPv4 address.
pub fn ipv4_resolve(addr: &str) -> Result<SocketAddr, std::io::Error> {
    use std::net::ToSocketAddrs;
    addr.to_socket_addrs()?
        .find(SocketAddr::is_ipv4)
        .ok_or_else(|| std::io::Error::other("Unable to resolve IPv4 address"))
}
159
160/// Returns a length-delimited bytes `Sink`, `Stream`, and `SocketAddr` bound to the given address.
161/// The input `addr` may have a port of `0`, the returned `SocketAddr` will have the chosen port.
162#[cfg(not(target_arch = "wasm32"))]
163pub async fn bind_udp_bytes(addr: SocketAddr) -> (UdpSink, UdpStream, SocketAddr) {
164    let socket = tokio::net::UdpSocket::bind(addr).await.unwrap();
165    udp_bytes(socket)
166}
167
168/// Returns a newline-delimited bytes `Sink`, `Stream`, and `SocketAddr` bound to the given address.
169/// The input `addr` may have a port of `0`, the returned `SocketAddr` will have the chosen port.
170#[cfg(not(target_arch = "wasm32"))]
171pub async fn bind_udp_lines(addr: SocketAddr) -> (UdpLinesSink, UdpLinesStream, SocketAddr) {
172    let socket = tokio::net::UdpSocket::bind(addr).await.unwrap();
173    udp_lines(socket)
174}
175
176/// Returns a newline-delimited bytes `Sender`, `Receiver`, and `SocketAddr` bound to the given address.
177///
178/// The input `addr` may have a port of `0`, the returned `SocketAddr` will be the address of the newly bound endpoint.
179/// The inbound connections can be used in full duplex mode. When a `(T, SocketAddr)` pair is fed to the `Sender`
180/// returned by this function, the `SocketAddr` will be looked up against the currently existing connections.
181/// If a match is found then the data will be sent on that connection. If no match is found then the data is silently dropped.
182#[cfg(not(target_arch = "wasm32"))]
183pub async fn bind_tcp_bytes(
184    addr: SocketAddr,
185) -> (
186    unsync::mpsc::Sender<(bytes::Bytes, SocketAddr)>,
187    unsync::mpsc::Receiver<Result<(bytes::BytesMut, SocketAddr), std::io::Error>>,
188    SocketAddr,
189) {
190    bind_tcp(addr, tokio_util::codec::LengthDelimitedCodec::new())
191        .await
192        .unwrap()
193}
194
195/// This is the same thing as `bind_tcp_bytes` except instead of using a length-delimited encoding scheme it uses new lines to separate frames.
196#[cfg(not(target_arch = "wasm32"))]
197pub async fn bind_tcp_lines(
198    addr: SocketAddr,
199) -> (
200    unsync::mpsc::Sender<(String, SocketAddr)>,
201    unsync::mpsc::Receiver<Result<(String, SocketAddr), tokio_util::codec::LinesCodecError>>,
202    SocketAddr,
203) {
204    bind_tcp(addr, tokio_util::codec::LinesCodec::new())
205        .await
206        .unwrap()
207}
208
209/// The inverse of [`bind_tcp_bytes`].
210///
211/// `(Bytes, SocketAddr)` pairs fed to the returned `Sender` will initiate new tcp connections to the specified `SocketAddr`.
212/// These connections will be cached and reused, so that there will only be one connection per destination endpoint. When the endpoint sends data back it will be available via the returned `Receiver`
213#[cfg(not(target_arch = "wasm32"))]
214pub fn connect_tcp_bytes() -> (
215    TcpFramedSink<bytes::Bytes>,
216    TcpFramedStream<tokio_util::codec::LengthDelimitedCodec>,
217) {
218    connect_tcp(tokio_util::codec::LengthDelimitedCodec::new())
219}
220
221/// This is the same thing as `connect_tcp_bytes` except instead of using a length-delimited encoding scheme it uses new lines to separate frames.
222#[cfg(not(target_arch = "wasm32"))]
223pub fn connect_tcp_lines() -> (
224    TcpFramedSink<String>,
225    TcpFramedStream<tokio_util::codec::LinesCodec>,
226) {
227    connect_tcp(tokio_util::codec::LinesCodec::new())
228}
229
/// Sorts `slice` (unstably) by a key function that returns a reference into each element.
///
/// `slice::sort_unstable_by_key` cannot express this, because its key type `K` must be owned
/// rather than borrowed from the element. Here `f` is higher-ranked over the borrow lifetime.
///
/// From the addendum in
/// <https://stackoverflow.com/questions/56105305/how-to-sort-a-vec-of-structs-by-a-string-field>
pub fn sort_unstable_by_key_hrtb<T, F, K>(slice: &mut [T], f: F)
where
    F: for<'a> Fn(&'a T) -> &'a K,
    K: Ord,
{
    let compare = |lhs: &T, rhs: &T| K::cmp(f(lhs), f(rhs));
    slice.sort_unstable_by(compare);
}
241
242/// Converts an iterator into a stream that emits `n` items at a time, yielding between each batch.
243///
244/// This is useful for breaking up a large iterator across several ticks: `source_iter(...)` always
245/// releases all items in the first tick. However using `iter_batches_stream` with `source_stream(...)`
246/// will cause `n` items to be released each tick. (Although more than that may be emitted if there
247/// are loops in the stratum).
248pub fn iter_batches_stream<I>(
249    iter: I,
250    n: usize,
251) -> futures::stream::PollFn<impl FnMut(&mut Context<'_>) -> Poll<Option<I::Item>>>
252where
253    I: IntoIterator + Unpin,
254{
255    let mut count = 0;
256    let mut iter = iter.into_iter();
257    futures::stream::poll_fn(move |ctx| {
258        count += 1;
259        if n < count {
260            count = 0;
261            ctx.waker().wake_by_ref();
262            Poll::Pending
263        } else {
264            Poll::Ready(iter.next())
265        }
266    })
267}
268
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    pub fn test_collect_ready() {
        let (sender, mut receiver) = unbounded_channel::<usize>();
        (0..1000).for_each(|value| sender.send(value).unwrap());

        let collected: Vec<usize> = collect_ready(&mut receiver);
        assert_eq!(1000, collected.len());
    }

    #[crate::test]
    pub async fn test_collect_ready_async() {
        // The tokio unbounded channel hands items back in chunks of 128, so this checks that
        // `collect_ready_async` keeps going until every item has been returned.
        let (sender, mut receiver) = unbounded_channel::<usize>();
        (0..1000).for_each(|value| sender.send(value).unwrap());

        let collected: Vec<usize> = collect_ready_async(&mut receiver).await;
        assert_eq!(1000, collected.len());
    }
}