Skip to main content

dfir_rs/util/
mod.rs

1//! Helper utilities for the DFIR syntax.
2#![warn(missing_docs)]
3
4#[cfg(feature = "dfir_macro")]
5#[cfg_attr(docsrs, doc(cfg(feature = "dfir_macro")))]
6pub mod demux_enum;
7pub mod multiset;
8pub mod sparse_vec;
9pub mod unsync;
10
11mod monotonic;
12pub use monotonic::*;
13
14mod udp;
15#[cfg(not(target_arch = "wasm32"))]
16pub use udp::*;
17
18mod tcp;
19#[cfg(not(target_arch = "wasm32"))]
20pub use tcp::*;
21
22#[cfg(unix)]
23mod socket;
24use std::net::SocketAddr;
25use std::num::NonZeroUsize;
26use std::task::{Context, Poll};
27
28use futures::Stream;
29use serde::de::DeserializeOwned;
30use serde::ser::Serialize;
31#[cfg(unix)]
32pub use socket::*;
33
/// Persist or delete tuples.
///
/// Wraps a value of type `T` together with the action to take on it: either persist it, or
/// delete the values that exactly match it.
pub enum Persistence<T> {
    /// Persist T values
    Persist(T),
    /// Delete all values that exactly match
    Delete(T),
}
41
/// Persist or delete key-value pairs.
///
/// Persisting carries both the key `K` and the value `V`; deleting carries only the key.
pub enum PersistenceKeyed<K, V> {
    /// Persist key-value pairs
    Persist(K, V),
    /// Delete all tuples that have the key K
    Delete(K),
}
49
50/// Returns a channel as a (1) unbounded sender and (2) unbounded receiver `Stream` for use in DFIR.
51pub fn unbounded_channel<T>() -> (
52    tokio::sync::mpsc::UnboundedSender<T>,
53    tokio_stream::wrappers::UnboundedReceiverStream<T>,
54) {
55    let (send, recv) = tokio::sync::mpsc::unbounded_channel();
56    let recv = tokio_stream::wrappers::UnboundedReceiverStream::new(recv);
57    (send, recv)
58}
59
60/// Returns an unsync channel as a (1) sender and (2) receiver `Stream` for use in DFIR.
61pub fn unsync_channel<T>(
62    capacity: Option<NonZeroUsize>,
63) -> (unsync::mpsc::Sender<T>, unsync::mpsc::Receiver<T>) {
64    unsync::mpsc::channel(capacity)
65}
66
67/// Returns an [`Iterator`] of any immediately available items from the [`Stream`].
68pub fn ready_iter<S>(stream: S) -> impl Iterator<Item = S::Item>
69where
70    S: Stream,
71{
72    let mut stream = Box::pin(stream);
73    std::iter::from_fn(move || {
74        match stream
75            .as_mut()
76            .poll_next(&mut Context::from_waker(futures::task::noop_waker_ref()))
77        {
78            Poll::Ready(opt) => opt,
79            Poll::Pending => None,
80        }
81    })
82}
83
84/// Collects the immediately available items from the `Stream` into a `FromIterator` collection.
85///
86/// This consumes the stream, use [`futures::StreamExt::by_ref()`] (or just `&mut ...`) if you want
87/// to retain ownership of your stream.
88pub fn collect_ready<C, S>(stream: S) -> C
89where
90    C: FromIterator<S::Item>,
91    S: Stream,
92{
93    assert!(
94        tokio::runtime::Handle::try_current().is_err(),
95        "Calling `collect_ready` from an async runtime may cause incorrect results, use `collect_ready_async` instead."
96    );
97    ready_iter(stream).collect()
98}
99
100/// Collects the immediately available items from the `Stream` into a collection (`Default` + `Extend`).
101///
102/// This consumes the stream, use [`futures::StreamExt::by_ref()`] (or just `&mut ...`) if you want
103/// to retain ownership of your stream.
104pub async fn collect_ready_async<C, S>(stream: S) -> C
105where
106    C: Default + Extend<S::Item>,
107    S: Stream,
108{
109    use std::sync::atomic::Ordering;
110
111    // Yield to let any background async tasks send to the stream.
112    tokio::task::yield_now().await;
113
114    let got_any_items = std::sync::atomic::AtomicBool::new(true);
115    let mut unfused_iter =
116        ready_iter(stream).inspect(|_| got_any_items.store(true, Ordering::Relaxed));
117    let mut out = C::default();
118    while got_any_items.swap(false, Ordering::Relaxed) {
119        out.extend(unfused_iter.by_ref());
120        // Tokio unbounded channel returns items in lenght-128 chunks, so we have to be careful
121        // that everything gets returned. That is why we yield here and loop.
122        tokio::task::yield_now().await;
123    }
124    out
125}
126
127/// Serialize a message to bytes using bincode.
128pub fn serialize_to_bytes<T>(msg: T) -> bytes::Bytes
129where
130    T: Serialize,
131{
132    bytes::Bytes::from(bincode::serialize(&msg).unwrap())
133}
134
135/// Serialize a message from bytes using bincode.
136pub fn deserialize_from_bytes<T>(msg: impl AsRef<[u8]>) -> bincode::Result<T>
137where
138    T: DeserializeOwned,
139{
140    bincode::deserialize(msg.as_ref())
141}
142
/// Resolves an IP or hostname string to the first `ipv4` [`SocketAddr`] it maps to.
///
/// # Errors
///
/// Returns the underlying error if `addr` cannot be converted to socket addresses at all, and an
/// "other" [`std::io::Error`] if none of the resulting addresses is IPv4.
pub fn ipv4_resolve(addr: &str) -> Result<SocketAddr, std::io::Error> {
    use std::net::ToSocketAddrs;

    addr.to_socket_addrs()?
        .find(SocketAddr::is_ipv4)
        .ok_or_else(|| std::io::Error::other("Unable to resolve IPv4 address"))
}
153
154/// Returns a length-delimited bytes `Sink`, `Stream`, and `SocketAddr` bound to the given address.
155/// The input `addr` may have a port of `0`, the returned `SocketAddr` will have the chosen port.
156#[cfg(not(target_arch = "wasm32"))]
157pub async fn bind_udp_bytes(addr: SocketAddr) -> (UdpSink, UdpStream, SocketAddr) {
158    let socket = tokio::net::UdpSocket::bind(addr).await.unwrap();
159    udp_bytes(socket)
160}
161
162/// Returns a newline-delimited bytes `Sink`, `Stream`, and `SocketAddr` bound to the given address.
163/// The input `addr` may have a port of `0`, the returned `SocketAddr` will have the chosen port.
164#[cfg(not(target_arch = "wasm32"))]
165pub async fn bind_udp_lines(addr: SocketAddr) -> (UdpLinesSink, UdpLinesStream, SocketAddr) {
166    let socket = tokio::net::UdpSocket::bind(addr).await.unwrap();
167    udp_lines(socket)
168}
169
170/// Returns a newline-delimited bytes `Sender`, `Receiver`, and `SocketAddr` bound to the given address.
171///
172/// The input `addr` may have a port of `0`, the returned `SocketAddr` will be the address of the newly bound endpoint.
173/// The inbound connections can be used in full duplex mode. When a `(T, SocketAddr)` pair is fed to the `Sender`
174/// returned by this function, the `SocketAddr` will be looked up against the currently existing connections.
175/// If a match is found then the data will be sent on that connection. If no match is found then the data is silently dropped.
176#[cfg(not(target_arch = "wasm32"))]
177pub async fn bind_tcp_bytes(
178    addr: SocketAddr,
179) -> (
180    unsync::mpsc::Sender<(bytes::Bytes, SocketAddr)>,
181    unsync::mpsc::Receiver<Result<(bytes::BytesMut, SocketAddr), std::io::Error>>,
182    SocketAddr,
183) {
184    bind_tcp(addr, tokio_util::codec::LengthDelimitedCodec::new())
185        .await
186        .unwrap()
187}
188
189/// This is the same thing as `bind_tcp_bytes` except instead of using a length-delimited encoding scheme it uses new lines to separate frames.
190#[cfg(not(target_arch = "wasm32"))]
191pub async fn bind_tcp_lines(
192    addr: SocketAddr,
193) -> (
194    unsync::mpsc::Sender<(String, SocketAddr)>,
195    unsync::mpsc::Receiver<Result<(String, SocketAddr), tokio_util::codec::LinesCodecError>>,
196    SocketAddr,
197) {
198    bind_tcp(addr, tokio_util::codec::LinesCodec::new())
199        .await
200        .unwrap()
201}
202
203/// The inverse of [`bind_tcp_bytes`].
204///
205/// `(Bytes, SocketAddr)` pairs fed to the returned `Sender` will initiate new tcp connections to the specified `SocketAddr`.
206/// These connections will be cached and reused, so that there will only be one connection per destination endpoint. When the endpoint sends data back it will be available via the returned `Receiver`
207#[cfg(not(target_arch = "wasm32"))]
208pub fn connect_tcp_bytes() -> (
209    TcpFramedSink<bytes::Bytes>,
210    TcpFramedStream<tokio_util::codec::LengthDelimitedCodec>,
211) {
212    connect_tcp(tokio_util::codec::LengthDelimitedCodec::new())
213}
214
215/// This is the same thing as `connect_tcp_bytes` except instead of using a length-delimited encoding scheme it uses new lines to separate frames.
216#[cfg(not(target_arch = "wasm32"))]
217pub fn connect_tcp_lines() -> (
218    TcpFramedSink<String>,
219    TcpFramedStream<tokio_util::codec::LinesCodec>,
220) {
221    connect_tcp(tokio_util::codec::LinesCodec::new())
222}
223
/// Sorts a slice in place (unstably) by a key that the key fn returns as a reference into the
/// element.
///
/// The higher-ranked bound on `F` ties the lifetime of the returned key to the borrowed element,
/// which is what allows the key fn to hand out references.
///
/// From addendum in
/// <https://stackoverflow.com/questions/56105305/how-to-sort-a-vec-of-structs-by-a-string-field>
pub fn sort_unstable_by_key_hrtb<T, F, K>(slice: &mut [T], f: F)
where
    F: for<'a> Fn(&'a T) -> &'a K,
    K: Ord,
{
    slice.sort_unstable_by(|lhs, rhs| {
        let (left_key, right_key) = (f(lhs), f(rhs));
        left_key.cmp(right_key)
    });
}
235
236/// Converts an iterator into a stream that emits `n` items at a time, yielding between each batch.
237///
238/// This is useful for breaking up a large iterator across several ticks: `source_iter(...)` always
239/// releases all items in the first tick. However using `iter_batches_stream` with `source_stream(...)`
240/// will cause `n` items to be released each tick. (Although more than that may be emitted if there
241/// are loops in the stratum).
242pub fn iter_batches_stream<I>(
243    iter: I,
244    n: usize,
245) -> futures::stream::PollFn<impl FnMut(&mut Context<'_>) -> Poll<Option<I::Item>>>
246where
247    I: IntoIterator + Unpin,
248{
249    let mut count = 0;
250    let mut iter = iter.into_iter();
251    futures::stream::poll_fn(move |ctx| {
252        count += 1;
253        if n < count {
254            count = 0;
255            ctx.waker().wake_by_ref();
256            Poll::Pending
257        } else {
258            Poll::Ready(iter.next())
259        }
260    })
261}
262
#[cfg(test)]
mod test {
    use super::*;

    /// All 1000 queued items must be collected by the synchronous `collect_ready`.
    #[test]
    pub fn test_collect_ready() {
        let (tx, mut rx) = unbounded_channel::<usize>();
        (0..1000).for_each(|item| tx.send(item).unwrap());

        let collected: Vec<usize> = collect_ready(&mut rx);
        assert_eq!(1000, collected.len());
    }

    /// All 1000 queued items must be collected by `collect_ready_async`.
    #[crate::test]
    pub async fn test_collect_ready_async() {
        // Tokio unbounded channel returns items in 128 item long chunks, so we have to be careful that everything gets returned.
        let (tx, mut rx) = unbounded_channel::<usize>();
        (0..1000).for_each(|item| tx.send(item).unwrap());

        let collected: Vec<usize> = collect_ready_async(&mut rx).await;
        assert_eq!(1000, collected.len());
    }
}