dfir_rs/util/
mod.rs

1//! Helper utilities for the DFIR syntax.
2#![warn(missing_docs)]
3
4pub mod accumulator;
5pub mod clear;
6#[cfg(feature = "dfir_macro")]
7#[cfg_attr(docsrs, doc(cfg(feature = "dfir_macro")))]
8pub mod demux_enum;
9pub mod multiset;
10pub mod priority_stack;
11pub mod slot_vec;
12pub mod sparse_vec;
13pub mod unsync;
14
15pub mod simulation;
16
17mod monotonic;
18pub use monotonic::*;
19
20mod udp;
21#[cfg(not(target_arch = "wasm32"))]
22pub use udp::*;
23
24mod tcp;
25#[cfg(not(target_arch = "wasm32"))]
26pub use tcp::*;
27
28#[cfg(unix)]
29mod socket;
30#[cfg(unix)]
31pub use socket::*;
32
33#[cfg(feature = "deploy_integration")]
34#[cfg_attr(docsrs, doc(cfg(feature = "deploy_integration")))]
35pub mod deploy;
36
37use std::net::SocketAddr;
38use std::num::NonZeroUsize;
39use std::task::{Context, Poll};
40
41use futures::Stream;
42use serde::de::DeserializeOwned;
43use serde::ser::Serialize;
44
/// Persist or delete tuples.
///
/// Each value is an instruction carrying the tuple `T` it applies to.
pub enum Persistence<T> {
    /// Persist the `T` value.
    Persist(T),
    /// Delete all values that exactly match the given `T`.
    Delete(T),
}
52
/// Persist or delete key-value pairs.
///
/// Persisting carries both the key and the value; deleting only needs the key.
pub enum PersistenceKeyed<K, V> {
    /// Persist the key-value pair `(K, V)`.
    Persist(K, V),
    /// Delete all tuples that have the key `K`.
    Delete(K),
}
60
61/// Returns a channel as a (1) unbounded sender and (2) unbounded receiver `Stream` for use in DFIR.
62pub fn unbounded_channel<T>() -> (
63    tokio::sync::mpsc::UnboundedSender<T>,
64    tokio_stream::wrappers::UnboundedReceiverStream<T>,
65) {
66    let (send, recv) = tokio::sync::mpsc::unbounded_channel();
67    let recv = tokio_stream::wrappers::UnboundedReceiverStream::new(recv);
68    (send, recv)
69}
70
71/// Returns an unsync channel as a (1) sender and (2) receiver `Stream` for use in DFIR.
72pub fn unsync_channel<T>(
73    capacity: Option<NonZeroUsize>,
74) -> (unsync::mpsc::Sender<T>, unsync::mpsc::Receiver<T>) {
75    unsync::mpsc::channel(capacity)
76}
77
78/// Returns an [`Iterator`] of any immediately available items from the [`Stream`].
79pub fn ready_iter<S>(stream: S) -> impl Iterator<Item = S::Item>
80where
81    S: Stream,
82{
83    let mut stream = Box::pin(stream);
84    std::iter::from_fn(move || {
85        match stream
86            .as_mut()
87            .poll_next(&mut Context::from_waker(futures::task::noop_waker_ref()))
88        {
89            Poll::Ready(opt) => opt,
90            Poll::Pending => None,
91        }
92    })
93}
94
95/// Collects the immediately available items from the `Stream` into a `FromIterator` collection.
96///
97/// This consumes the stream, use [`futures::StreamExt::by_ref()`] (or just `&mut ...`) if you want
98/// to retain ownership of your stream.
99pub fn collect_ready<C, S>(stream: S) -> C
100where
101    C: FromIterator<S::Item>,
102    S: Stream,
103{
104    assert!(
105        tokio::runtime::Handle::try_current().is_err(),
106        "Calling `collect_ready` from an async runtime may cause incorrect results, use `collect_ready_async` instead."
107    );
108    ready_iter(stream).collect()
109}
110
111/// Collects the immediately available items from the `Stream` into a collection (`Default` + `Extend`).
112///
113/// This consumes the stream, use [`futures::StreamExt::by_ref()`] (or just `&mut ...`) if you want
114/// to retain ownership of your stream.
115pub async fn collect_ready_async<C, S>(stream: S) -> C
116where
117    C: Default + Extend<S::Item>,
118    S: Stream,
119{
120    use std::sync::atomic::Ordering;
121
122    // Yield to let any background async tasks send to the stream.
123    tokio::task::yield_now().await;
124
125    let got_any_items = std::sync::atomic::AtomicBool::new(true);
126    let mut unfused_iter =
127        ready_iter(stream).inspect(|_| got_any_items.store(true, Ordering::Relaxed));
128    let mut out = C::default();
129    while got_any_items.swap(false, Ordering::Relaxed) {
130        out.extend(unfused_iter.by_ref());
131        // Tokio unbounded channel returns items in lenght-128 chunks, so we have to be careful
132        // that everything gets returned. That is why we yield here and loop.
133        tokio::task::yield_now().await;
134    }
135    out
136}
137
138/// Serialize a message to bytes using bincode.
139pub fn serialize_to_bytes<T>(msg: T) -> bytes::Bytes
140where
141    T: Serialize,
142{
143    bytes::Bytes::from(bincode::serialize(&msg).unwrap())
144}
145
146/// Serialize a message from bytes using bincode.
147pub fn deserialize_from_bytes<T>(msg: impl AsRef<[u8]>) -> bincode::Result<T>
148where
149    T: DeserializeOwned,
150{
151    bincode::deserialize(msg.as_ref())
152}
153
/// Resolve the `ipv4` [`SocketAddr`] from an IP or hostname string.
///
/// `addr` is resolved through [`std::net::ToSocketAddrs`] and the first IPv4 result is returned.
///
/// # Errors
///
/// Returns the underlying I/O error if `addr` cannot be parsed or resolved at all, or an
/// "Unable to resolve IPv4 address" error if resolution succeeds but yields no IPv4 address.
pub fn ipv4_resolve(addr: &str) -> Result<SocketAddr, std::io::Error> {
    use std::net::ToSocketAddrs;

    addr.to_socket_addrs()?
        .find(SocketAddr::is_ipv4)
        .ok_or_else(|| std::io::Error::other("Unable to resolve IPv4 address"))
}
164
165/// Returns a length-delimited bytes `Sink`, `Stream`, and `SocketAddr` bound to the given address.
166/// The input `addr` may have a port of `0`, the returned `SocketAddr` will have the chosen port.
167#[cfg(not(target_arch = "wasm32"))]
168pub async fn bind_udp_bytes(addr: SocketAddr) -> (UdpSink, UdpStream, SocketAddr) {
169    let socket = tokio::net::UdpSocket::bind(addr).await.unwrap();
170    udp_bytes(socket)
171}
172
173/// Returns a newline-delimited bytes `Sink`, `Stream`, and `SocketAddr` bound to the given address.
174/// The input `addr` may have a port of `0`, the returned `SocketAddr` will have the chosen port.
175#[cfg(not(target_arch = "wasm32"))]
176pub async fn bind_udp_lines(addr: SocketAddr) -> (UdpLinesSink, UdpLinesStream, SocketAddr) {
177    let socket = tokio::net::UdpSocket::bind(addr).await.unwrap();
178    udp_lines(socket)
179}
180
181/// Returns a newline-delimited bytes `Sender`, `Receiver`, and `SocketAddr` bound to the given address.
182///
183/// The input `addr` may have a port of `0`, the returned `SocketAddr` will be the address of the newly bound endpoint.
184/// The inbound connections can be used in full duplex mode. When a `(T, SocketAddr)` pair is fed to the `Sender`
185/// returned by this function, the `SocketAddr` will be looked up against the currently existing connections.
186/// If a match is found then the data will be sent on that connection. If no match is found then the data is silently dropped.
187#[cfg(not(target_arch = "wasm32"))]
188pub async fn bind_tcp_bytes(
189    addr: SocketAddr,
190) -> (
191    unsync::mpsc::Sender<(bytes::Bytes, SocketAddr)>,
192    unsync::mpsc::Receiver<Result<(bytes::BytesMut, SocketAddr), std::io::Error>>,
193    SocketAddr,
194) {
195    bind_tcp(addr, tokio_util::codec::LengthDelimitedCodec::new())
196        .await
197        .unwrap()
198}
199
200/// This is the same thing as `bind_tcp_bytes` except instead of using a length-delimited encoding scheme it uses new lines to separate frames.
201#[cfg(not(target_arch = "wasm32"))]
202pub async fn bind_tcp_lines(
203    addr: SocketAddr,
204) -> (
205    unsync::mpsc::Sender<(String, SocketAddr)>,
206    unsync::mpsc::Receiver<Result<(String, SocketAddr), tokio_util::codec::LinesCodecError>>,
207    SocketAddr,
208) {
209    bind_tcp(addr, tokio_util::codec::LinesCodec::new())
210        .await
211        .unwrap()
212}
213
214/// The inverse of [`bind_tcp_bytes`].
215///
216/// `(Bytes, SocketAddr)` pairs fed to the returned `Sender` will initiate new tcp connections to the specified `SocketAddr`.
217/// These connections will be cached and reused, so that there will only be one connection per destination endpoint. When the endpoint sends data back it will be available via the returned `Receiver`
218#[cfg(not(target_arch = "wasm32"))]
219pub fn connect_tcp_bytes() -> (
220    TcpFramedSink<bytes::Bytes>,
221    TcpFramedStream<tokio_util::codec::LengthDelimitedCodec>,
222) {
223    connect_tcp(tokio_util::codec::LengthDelimitedCodec::new())
224}
225
226/// This is the same thing as `connect_tcp_bytes` except instead of using a length-delimited encoding scheme it uses new lines to separate frames.
227#[cfg(not(target_arch = "wasm32"))]
228pub fn connect_tcp_lines() -> (
229    TcpFramedSink<String>,
230    TcpFramedStream<tokio_util::codec::LinesCodec>,
231) {
232    connect_tcp(tokio_util::codec::LinesCodec::new())
233}
234
/// Sort a slice using a key fn which returns references.
///
/// Unlike [`slice::sort_unstable_by_key`], the key function may borrow from the element it is
/// given (e.g. `|person| &person.name`), so no key needs to be cloned. The key type may be
/// unsized, so `|person| person.name.as_str()` works as well. The sort is unstable: equal
/// elements may be reordered.
///
/// From addendum in
/// <https://stackoverflow.com/questions/56105305/how-to-sort-a-vec-of-structs-by-a-string-field>
pub fn sort_unstable_by_key_hrtb<T, F, K>(slice: &mut [T], f: F)
where
    F: for<'a> Fn(&'a T) -> &'a K,
    // `?Sized` lets callers sort by `str`, `[u8]`, and other unsized keys.
    K: Ord + ?Sized,
{
    slice.sort_unstable_by(|a, b| f(a).cmp(f(b)));
}
246
247/// Converts an iterator into a stream that emits `n` items at a time, yielding between each batch.
248///
249/// This is useful for breaking up a large iterator across several ticks: `source_iter(...)` always
250/// releases all items in the first tick. However using `iter_batches_stream` with `source_stream(...)`
251/// will cause `n` items to be released each tick. (Although more than that may be emitted if there
252/// are loops in the stratum).
253pub fn iter_batches_stream<I>(
254    iter: I,
255    n: usize,
256) -> futures::stream::PollFn<impl FnMut(&mut Context<'_>) -> Poll<Option<I::Item>>>
257where
258    I: IntoIterator + Unpin,
259{
260    let mut count = 0;
261    let mut iter = iter.into_iter();
262    futures::stream::poll_fn(move |ctx| {
263        count += 1;
264        if n < count {
265            count = 0;
266            ctx.waker().wake_by_ref();
267            Poll::Pending
268        } else {
269            Poll::Ready(iter.next())
270        }
271    })
272}
273
#[cfg(test)]
mod test {
    use super::*;

    /// All 1000 queued items must come back from the synchronous collector.
    #[test]
    pub fn test_collect_ready() {
        let (send, mut recv) = unbounded_channel::<usize>();
        (0..1000).for_each(|x| send.send(x).unwrap());
        let collected: Vec<usize> = collect_ready(&mut recv);
        assert_eq!(1000, collected.len());
    }

    /// All 1000 queued items must come back from the async collector.
    #[crate::test]
    pub async fn test_collect_ready_async() {
        // Tokio's unbounded channel returns items in 128-item chunks, so this checks that the
        // collector keeps going until everything has been returned.
        let (send, mut recv) = unbounded_channel::<usize>();
        (0..1000).for_each(|x| send.send(x).unwrap());
        let collected: Vec<usize> = collect_ready_async(&mut recv).await;
        assert_eq!(1000, collected.len());
    }
}