dfir_rs/util/
mod.rs

1#![warn(missing_docs)]
2//! Helper utilities for the DFIR syntax.
3
4pub mod clear;
5#[cfg(feature = "dfir_macro")]
6#[cfg_attr(docsrs, doc(cfg(feature = "dfir_macro")))]
7pub mod demux_enum;
8pub mod monotonic_map;
9pub mod multiset;
10pub mod priority_stack;
11pub mod slot_vec;
12pub mod sparse_vec;
13pub mod unsync;
14
15pub mod simulation;
16
17mod monotonic;
18pub use monotonic::*;
19
20mod udp;
21#[cfg(not(target_arch = "wasm32"))]
22pub use udp::*;
23
24mod tcp;
25#[cfg(not(target_arch = "wasm32"))]
26pub use tcp::*;
27
28#[cfg(unix)]
29mod socket;
30#[cfg(unix)]
31pub use socket::*;
32
33#[cfg(feature = "deploy_integration")]
34#[cfg_attr(docsrs, doc(cfg(feature = "deploy_integration")))]
35pub mod deploy;
36
37use std::io::Read;
38use std::net::SocketAddr;
39use std::num::NonZeroUsize;
40use std::process::{Child, ChildStdin, ChildStdout, Stdio};
41use std::task::{Context, Poll};
42
43use futures::Stream;
44use serde::de::DeserializeOwned;
45use serde::ser::Serialize;
46
/// Persist or delete tuples.
///
/// Derives `Debug`, `Clone`, `PartialEq` and `Eq` (each conditional on `T`) so that values can
/// be printed, duplicated and compared.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Persistence<T> {
    /// Persist T values
    Persist(T),
    /// Delete all values that exactly match
    Delete(T),
}
54
/// Persist or delete key-value pairs.
///
/// Derives `Debug`, `Clone`, `PartialEq` and `Eq` (each conditional on `K` and `V`) so that
/// values can be printed, duplicated and compared.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PersistenceKeyed<K, V> {
    /// Persist key-value pairs
    Persist(K, V),
    /// Delete all tuples that have the key K
    Delete(K),
}
62
63/// Returns a channel as a (1) unbounded sender and (2) unbounded receiver `Stream` for use in DFIR.
64pub fn unbounded_channel<T>() -> (
65    tokio::sync::mpsc::UnboundedSender<T>,
66    tokio_stream::wrappers::UnboundedReceiverStream<T>,
67) {
68    let (send, recv) = tokio::sync::mpsc::unbounded_channel();
69    let recv = tokio_stream::wrappers::UnboundedReceiverStream::new(recv);
70    (send, recv)
71}
72
73/// Returns an unsync channel as a (1) sender and (2) receiver `Stream` for use in DFIR.
74pub fn unsync_channel<T>(
75    capacity: Option<NonZeroUsize>,
76) -> (unsync::mpsc::Sender<T>, unsync::mpsc::Receiver<T>) {
77    unsync::mpsc::channel(capacity)
78}
79
80/// Returns an [`Iterator`] of any immediately available items from the [`Stream`].
81pub fn ready_iter<S>(stream: S) -> impl Iterator<Item = S::Item>
82where
83    S: Stream,
84{
85    let mut stream = Box::pin(stream);
86    std::iter::from_fn(move || {
87        match stream
88            .as_mut()
89            .poll_next(&mut Context::from_waker(futures::task::noop_waker_ref()))
90        {
91            Poll::Ready(opt) => opt,
92            Poll::Pending => None,
93        }
94    })
95}
96
97/// Collects the immediately available items from the `Stream` into a `FromIterator` collection.
98///
99/// This consumes the stream, use [`futures::StreamExt::by_ref()`] (or just `&mut ...`) if you want
100/// to retain ownership of your stream.
101pub fn collect_ready<C, S>(stream: S) -> C
102where
103    C: FromIterator<S::Item>,
104    S: Stream,
105{
106    assert!(
107        tokio::runtime::Handle::try_current().is_err(),
108        "Calling `collect_ready` from an async runtime may cause incorrect results, use `collect_ready_async` instead."
109    );
110    ready_iter(stream).collect()
111}
112
113/// Collects the immediately available items from the `Stream` into a collection (`Default` + `Extend`).
114///
115/// This consumes the stream, use [`futures::StreamExt::by_ref()`] (or just `&mut ...`) if you want
116/// to retain ownership of your stream.
117pub async fn collect_ready_async<C, S>(stream: S) -> C
118where
119    C: Default + Extend<S::Item>,
120    S: Stream,
121{
122    use std::sync::atomic::Ordering;
123
124    // Yield to let any background async tasks send to the stream.
125    tokio::task::yield_now().await;
126
127    let got_any_items = std::sync::atomic::AtomicBool::new(true);
128    let mut unfused_iter =
129        ready_iter(stream).inspect(|_| got_any_items.store(true, Ordering::Relaxed));
130    let mut out = C::default();
131    while got_any_items.swap(false, Ordering::Relaxed) {
132        out.extend(unfused_iter.by_ref());
133        // Tokio unbounded channel returns items in lenght-128 chunks, so we have to be careful
134        // that everything gets returned. That is why we yield here and loop.
135        tokio::task::yield_now().await;
136    }
137    out
138}
139
140/// Serialize a message to bytes using bincode.
141pub fn serialize_to_bytes<T>(msg: T) -> bytes::Bytes
142where
143    T: Serialize,
144{
145    bytes::Bytes::from(bincode::serialize(&msg).unwrap())
146}
147
148/// Serialize a message from bytes using bincode.
149pub fn deserialize_from_bytes<T>(msg: impl AsRef<[u8]>) -> bincode::Result<T>
150where
151    T: DeserializeOwned,
152{
153    bincode::deserialize(msg.as_ref())
154}
155
/// Resolve the `ipv4` [`SocketAddr`] from an IP or hostname string.
///
/// The first IPv4 candidate produced by [`std::net::ToSocketAddrs`] is returned.
///
/// # Errors
///
/// Propagates any error from address resolution, and returns an error with the message
/// `Unable to resolve IPv4 address` when none of the candidates is an IPv4 address.
pub fn ipv4_resolve(addr: &str) -> Result<SocketAddr, std::io::Error> {
    use std::net::ToSocketAddrs;

    addr.to_socket_addrs()?
        .find(SocketAddr::is_ipv4)
        .ok_or_else(|| std::io::Error::other("Unable to resolve IPv4 address"))
}
166
167/// Returns a length-delimited bytes `Sink`, `Stream`, and `SocketAddr` bound to the given address.
168/// The input `addr` may have a port of `0`, the returned `SocketAddr` will have the chosen port.
169#[cfg(not(target_arch = "wasm32"))]
170pub async fn bind_udp_bytes(addr: SocketAddr) -> (UdpSink, UdpStream, SocketAddr) {
171    let socket = tokio::net::UdpSocket::bind(addr).await.unwrap();
172    udp_bytes(socket)
173}
174
175/// Returns a newline-delimited bytes `Sink`, `Stream`, and `SocketAddr` bound to the given address.
176/// The input `addr` may have a port of `0`, the returned `SocketAddr` will have the chosen port.
177#[cfg(not(target_arch = "wasm32"))]
178pub async fn bind_udp_lines(addr: SocketAddr) -> (UdpLinesSink, UdpLinesStream, SocketAddr) {
179    let socket = tokio::net::UdpSocket::bind(addr).await.unwrap();
180    udp_lines(socket)
181}
182
183/// Returns a newline-delimited bytes `Sender`, `Receiver`, and `SocketAddr` bound to the given address.
184///
185/// The input `addr` may have a port of `0`, the returned `SocketAddr` will be the address of the newly bound endpoint.
186/// The inbound connections can be used in full duplex mode. When a `(T, SocketAddr)` pair is fed to the `Sender`
187/// returned by this function, the `SocketAddr` will be looked up against the currently existing connections.
188/// If a match is found then the data will be sent on that connection. If no match is found then the data is silently dropped.
189#[cfg(not(target_arch = "wasm32"))]
190pub async fn bind_tcp_bytes(
191    addr: SocketAddr,
192) -> (
193    unsync::mpsc::Sender<(bytes::Bytes, SocketAddr)>,
194    unsync::mpsc::Receiver<Result<(bytes::BytesMut, SocketAddr), std::io::Error>>,
195    SocketAddr,
196) {
197    bind_tcp(addr, tokio_util::codec::LengthDelimitedCodec::new())
198        .await
199        .unwrap()
200}
201
202/// This is the same thing as `bind_tcp_bytes` except instead of using a length-delimited encoding scheme it uses new lines to separate frames.
203#[cfg(not(target_arch = "wasm32"))]
204pub async fn bind_tcp_lines(
205    addr: SocketAddr,
206) -> (
207    unsync::mpsc::Sender<(String, SocketAddr)>,
208    unsync::mpsc::Receiver<Result<(String, SocketAddr), tokio_util::codec::LinesCodecError>>,
209    SocketAddr,
210) {
211    bind_tcp(addr, tokio_util::codec::LinesCodec::new())
212        .await
213        .unwrap()
214}
215
216/// The inverse of [`bind_tcp_bytes`].
217///
218/// `(Bytes, SocketAddr)` pairs fed to the returned `Sender` will initiate new tcp connections to the specified `SocketAddr`.
219/// These connections will be cached and reused, so that there will only be one connection per destination endpoint. When the endpoint sends data back it will be available via the returned `Receiver`
220#[cfg(not(target_arch = "wasm32"))]
221pub fn connect_tcp_bytes() -> (
222    TcpFramedSink<bytes::Bytes>,
223    TcpFramedStream<tokio_util::codec::LengthDelimitedCodec>,
224) {
225    connect_tcp(tokio_util::codec::LengthDelimitedCodec::new())
226}
227
228/// This is the same thing as `connect_tcp_bytes` except instead of using a length-delimited encoding scheme it uses new lines to separate frames.
229#[cfg(not(target_arch = "wasm32"))]
230pub fn connect_tcp_lines() -> (
231    TcpFramedSink<String>,
232    TcpFramedStream<tokio_util::codec::LinesCodec>,
233) {
234    connect_tcp(tokio_util::codec::LinesCodec::new())
235}
236
/// Sorts a slice in place (unstable sort) by a key that is *borrowed* from each element.
///
/// The higher-ranked bound on `f` lets the key function return a reference into the element,
/// which `slice::sort_unstable_by_key` does not allow.
///
/// From addendum in
/// <https://stackoverflow.com/questions/56105305/how-to-sort-a-vec-of-structs-by-a-string-field>
pub fn sort_unstable_by_key_hrtb<T, F, K>(slice: &mut [T], f: F)
where
    F: for<'a> Fn(&'a T) -> &'a K,
    K: Ord,
{
    slice.sort_unstable_by(|lhs, rhs| {
        let (left_key, right_key) = (f(lhs), f(rhs));
        Ord::cmp(left_key, right_key)
    })
}
248
249/// Waits for a specific process output before returning.
250///
251/// When a child process is spawned often you want to wait until the child process is ready before
252/// moving on. One way to do that synchronization is by waiting for the child process to output
253/// something and match regex against that output. For example, you could wait until the child
254/// process outputs "Client live!" which would indicate that it is ready to receive input now on
255/// stdin.
256pub fn wait_for_process_output(
257    output_so_far: &mut String,
258    output: &mut ChildStdout,
259    wait_for: &str,
260) {
261    let re = regex::Regex::new(wait_for).unwrap();
262
263    while !re.is_match(output_so_far) {
264        println!("waiting: {}", output_so_far);
265        let mut buffer = [0u8; 1024];
266        let bytes_read = output.read(&mut buffer).unwrap();
267
268        if bytes_read == 0 {
269            panic!();
270        }
271
272        output_so_far.push_str(&String::from_utf8_lossy(&buffer[0..bytes_read]));
273
274        println!("XXX {}", output_so_far);
275    }
276}
277
/// Terminates the inner [`Child`] process when dropped.
///
/// When a `Child` is dropped normally nothing happens but in unit tests you usually want to
/// terminate the child and wait for it to terminate. `DroppableChild` does that for us.
///
/// # Panics
///
/// Dropping panics if killing (non-Windows only) or waiting on the child fails, unless the
/// thread is already panicking, in which case such failures are ignored.
pub struct DroppableChild(Child);

impl Drop for DroppableChild {
    fn drop(&mut self) {
        // Always attempt both steps so the child is reaped even when `kill` reports an error.
        let kill_result = self.0.kill();
        let wait_result = self.0.wait();

        // A panic while the thread is already unwinding (e.g. after a failed test assertion)
        // aborts the whole process, so cleanup is best-effort in that case.
        if std::thread::panicking() {
            return;
        }

        #[cfg(target_family = "windows")]
        let _ = kill_result; // Windows throws `PermissionDenied` if the process has already exited.
        #[cfg(not(target_family = "windows"))]
        kill_result.unwrap();

        wait_result.unwrap();
    }
}
294
295/// Run a rust example as a test.
296///
297/// Rust examples are meant to be run by people and have a natural interface for that. This makes
298/// unit testing them cumbersome. This function wraps calling cargo run and piping the stdin/stdout
299/// of the example to easy to handle returned objects. The function also returns a `DroppableChild`
300/// which will ensure that the child processes will be cleaned up appropriately.
301pub fn run_cargo_example(test_name: &str, args: &str) -> (DroppableChild, ChildStdin, ChildStdout) {
302    let mut server = if args.is_empty() {
303        std::process::Command::new("cargo")
304            .args(["run", "-p", "dfir_rs", "--example"])
305            .arg(test_name)
306            .stdin(Stdio::piped())
307            .stdout(Stdio::piped())
308            .spawn()
309            .unwrap()
310    } else {
311        std::process::Command::new("cargo")
312            .args(["run", "-p", "dfir_rs", "--example"])
313            .arg(test_name)
314            .arg("--")
315            .args(args.split(' '))
316            .stdin(Stdio::piped())
317            .stdout(Stdio::piped())
318            .spawn()
319            .unwrap()
320    };
321
322    let stdin = server.stdin.take().unwrap();
323    let stdout = server.stdout.take().unwrap();
324
325    (DroppableChild(server), stdin, stdout)
326}
327
328/// Converts an iterator into a stream that emits `n` items at a time, yielding between each batch.
329///
330/// This is useful for breaking up a large iterator across several ticks: `source_iter(...)` always
331/// releases all items in the first tick. However using `iter_batches_stream` with `source_stream(...)`
332/// will cause `n` items to be released each tick. (Although more than that may be emitted if there
333/// are loops in the stratum).
334pub fn iter_batches_stream<I>(
335    iter: I,
336    n: usize,
337) -> futures::stream::PollFn<impl FnMut(&mut Context<'_>) -> Poll<Option<I::Item>>>
338where
339    I: IntoIterator + Unpin,
340{
341    let mut count = 0;
342    let mut iter = iter.into_iter();
343    futures::stream::poll_fn(move |ctx| {
344        count += 1;
345        if n < count {
346            count = 0;
347            ctx.waker().wake_by_ref();
348            Poll::Pending
349        } else {
350            Poll::Ready(iter.next())
351        }
352    })
353}
354
#[cfg(test)]
mod test {
    use super::*;

    /// All items already sent into the channel are collected by the synchronous helper.
    #[test]
    pub fn test_collect_ready() {
        let (tx, mut rx) = unbounded_channel::<usize>();
        (0..1000).for_each(|item| tx.send(item).unwrap());
        let received: Vec<usize> = collect_ready(&mut rx);
        assert_eq!(1000, received.len());
    }

    /// All items already sent into the channel are collected by the async helper.
    #[crate::test]
    pub async fn test_collect_ready_async() {
        // Tokio unbounded channel returns items in 128 item long chunks, so we have to be careful that everything gets returned.
        let (tx, mut rx) = unbounded_channel::<usize>();
        (0..1000).for_each(|item| tx.send(item).unwrap());
        let received: Vec<usize> = collect_ready_async(&mut rx).await;
        assert_eq!(1000, received.len());
    }
}