dfir_rs/util/
mod.rs

1#![warn(missing_docs)]
2//! Helper utilities for the DFIR syntax.
3
4pub mod clear;
5#[cfg(feature = "dfir_macro")]
6#[cfg_attr(docsrs, doc(cfg(feature = "dfir_macro")))]
7pub mod demux_enum;
8pub mod multiset;
9pub mod priority_stack;
10pub mod slot_vec;
11pub mod sparse_vec;
12pub mod unsync;
13
14pub mod simulation;
15
16mod monotonic;
17pub use monotonic::*;
18
19mod udp;
20#[cfg(not(target_arch = "wasm32"))]
21pub use udp::*;
22
23mod tcp;
24#[cfg(not(target_arch = "wasm32"))]
25pub use tcp::*;
26
27#[cfg(unix)]
28mod socket;
29#[cfg(unix)]
30pub use socket::*;
31
32#[cfg(feature = "deploy_integration")]
33#[cfg_attr(docsrs, doc(cfg(feature = "deploy_integration")))]
34pub mod deploy;
35
36use std::io::Read;
37use std::net::SocketAddr;
38use std::num::NonZeroUsize;
39use std::process::{Child, ChildStdin, ChildStdout, Stdio};
40use std::task::{Context, Poll};
41
42use futures::Stream;
43use serde::de::DeserializeOwned;
44use serde::ser::Serialize;
45
/// Persist or delete tuples.
///
/// Each variant carries a value of type `T` together with the action to apply to it.
pub enum Persistence<T> {
    /// Persist T values
    Persist(T),
    /// Delete all values that exactly match
    Delete(T),
}
53
/// Persist or delete key-value pairs.
///
/// Persisting carries both the key `K` and the value `V`; deleting carries only the key.
pub enum PersistenceKeyed<K, V> {
    /// Persist key-value pairs
    Persist(K, V),
    /// Delete all tuples that have the key K
    Delete(K),
}
61
62/// Returns a channel as a (1) unbounded sender and (2) unbounded receiver `Stream` for use in DFIR.
63pub fn unbounded_channel<T>() -> (
64    tokio::sync::mpsc::UnboundedSender<T>,
65    tokio_stream::wrappers::UnboundedReceiverStream<T>,
66) {
67    let (send, recv) = tokio::sync::mpsc::unbounded_channel();
68    let recv = tokio_stream::wrappers::UnboundedReceiverStream::new(recv);
69    (send, recv)
70}
71
72/// Returns an unsync channel as a (1) sender and (2) receiver `Stream` for use in DFIR.
73pub fn unsync_channel<T>(
74    capacity: Option<NonZeroUsize>,
75) -> (unsync::mpsc::Sender<T>, unsync::mpsc::Receiver<T>) {
76    unsync::mpsc::channel(capacity)
77}
78
79/// Returns an [`Iterator`] of any immediately available items from the [`Stream`].
80pub fn ready_iter<S>(stream: S) -> impl Iterator<Item = S::Item>
81where
82    S: Stream,
83{
84    let mut stream = Box::pin(stream);
85    std::iter::from_fn(move || {
86        match stream
87            .as_mut()
88            .poll_next(&mut Context::from_waker(futures::task::noop_waker_ref()))
89        {
90            Poll::Ready(opt) => opt,
91            Poll::Pending => None,
92        }
93    })
94}
95
96/// Collects the immediately available items from the `Stream` into a `FromIterator` collection.
97///
98/// This consumes the stream, use [`futures::StreamExt::by_ref()`] (or just `&mut ...`) if you want
99/// to retain ownership of your stream.
100pub fn collect_ready<C, S>(stream: S) -> C
101where
102    C: FromIterator<S::Item>,
103    S: Stream,
104{
105    assert!(
106        tokio::runtime::Handle::try_current().is_err(),
107        "Calling `collect_ready` from an async runtime may cause incorrect results, use `collect_ready_async` instead."
108    );
109    ready_iter(stream).collect()
110}
111
112/// Collects the immediately available items from the `Stream` into a collection (`Default` + `Extend`).
113///
114/// This consumes the stream, use [`futures::StreamExt::by_ref()`] (or just `&mut ...`) if you want
115/// to retain ownership of your stream.
116pub async fn collect_ready_async<C, S>(stream: S) -> C
117where
118    C: Default + Extend<S::Item>,
119    S: Stream,
120{
121    use std::sync::atomic::Ordering;
122
123    // Yield to let any background async tasks send to the stream.
124    tokio::task::yield_now().await;
125
126    let got_any_items = std::sync::atomic::AtomicBool::new(true);
127    let mut unfused_iter =
128        ready_iter(stream).inspect(|_| got_any_items.store(true, Ordering::Relaxed));
129    let mut out = C::default();
130    while got_any_items.swap(false, Ordering::Relaxed) {
131        out.extend(unfused_iter.by_ref());
132        // Tokio unbounded channel returns items in lenght-128 chunks, so we have to be careful
133        // that everything gets returned. That is why we yield here and loop.
134        tokio::task::yield_now().await;
135    }
136    out
137}
138
139/// Serialize a message to bytes using bincode.
140pub fn serialize_to_bytes<T>(msg: T) -> bytes::Bytes
141where
142    T: Serialize,
143{
144    bytes::Bytes::from(bincode::serialize(&msg).unwrap())
145}
146
147/// Serialize a message from bytes using bincode.
148pub fn deserialize_from_bytes<T>(msg: impl AsRef<[u8]>) -> bincode::Result<T>
149where
150    T: DeserializeOwned,
151{
152    bincode::deserialize(msg.as_ref())
153}
154
/// Resolves an IP or hostname string to its first `ipv4` [`SocketAddr`].
///
/// # Errors
///
/// Propagates any error from [`std::net::ToSocketAddrs::to_socket_addrs`], and returns an
/// "other" I/O error when none of the resolved addresses is IPv4.
pub fn ipv4_resolve(addr: &str) -> Result<SocketAddr, std::io::Error> {
    use std::net::ToSocketAddrs;

    addr.to_socket_addrs()?
        .find(|candidate| candidate.is_ipv4())
        .ok_or_else(|| std::io::Error::other("Unable to resolve IPv4 address"))
}
165
166/// Returns a length-delimited bytes `Sink`, `Stream`, and `SocketAddr` bound to the given address.
167/// The input `addr` may have a port of `0`, the returned `SocketAddr` will have the chosen port.
168#[cfg(not(target_arch = "wasm32"))]
169pub async fn bind_udp_bytes(addr: SocketAddr) -> (UdpSink, UdpStream, SocketAddr) {
170    let socket = tokio::net::UdpSocket::bind(addr).await.unwrap();
171    udp_bytes(socket)
172}
173
174/// Returns a newline-delimited bytes `Sink`, `Stream`, and `SocketAddr` bound to the given address.
175/// The input `addr` may have a port of `0`, the returned `SocketAddr` will have the chosen port.
176#[cfg(not(target_arch = "wasm32"))]
177pub async fn bind_udp_lines(addr: SocketAddr) -> (UdpLinesSink, UdpLinesStream, SocketAddr) {
178    let socket = tokio::net::UdpSocket::bind(addr).await.unwrap();
179    udp_lines(socket)
180}
181
182/// Returns a newline-delimited bytes `Sender`, `Receiver`, and `SocketAddr` bound to the given address.
183///
184/// The input `addr` may have a port of `0`, the returned `SocketAddr` will be the address of the newly bound endpoint.
185/// The inbound connections can be used in full duplex mode. When a `(T, SocketAddr)` pair is fed to the `Sender`
186/// returned by this function, the `SocketAddr` will be looked up against the currently existing connections.
187/// If a match is found then the data will be sent on that connection. If no match is found then the data is silently dropped.
188#[cfg(not(target_arch = "wasm32"))]
189pub async fn bind_tcp_bytes(
190    addr: SocketAddr,
191) -> (
192    unsync::mpsc::Sender<(bytes::Bytes, SocketAddr)>,
193    unsync::mpsc::Receiver<Result<(bytes::BytesMut, SocketAddr), std::io::Error>>,
194    SocketAddr,
195) {
196    bind_tcp(addr, tokio_util::codec::LengthDelimitedCodec::new())
197        .await
198        .unwrap()
199}
200
201/// This is the same thing as `bind_tcp_bytes` except instead of using a length-delimited encoding scheme it uses new lines to separate frames.
202#[cfg(not(target_arch = "wasm32"))]
203pub async fn bind_tcp_lines(
204    addr: SocketAddr,
205) -> (
206    unsync::mpsc::Sender<(String, SocketAddr)>,
207    unsync::mpsc::Receiver<Result<(String, SocketAddr), tokio_util::codec::LinesCodecError>>,
208    SocketAddr,
209) {
210    bind_tcp(addr, tokio_util::codec::LinesCodec::new())
211        .await
212        .unwrap()
213}
214
215/// The inverse of [`bind_tcp_bytes`].
216///
217/// `(Bytes, SocketAddr)` pairs fed to the returned `Sender` will initiate new tcp connections to the specified `SocketAddr`.
218/// These connections will be cached and reused, so that there will only be one connection per destination endpoint. When the endpoint sends data back it will be available via the returned `Receiver`
219#[cfg(not(target_arch = "wasm32"))]
220pub fn connect_tcp_bytes() -> (
221    TcpFramedSink<bytes::Bytes>,
222    TcpFramedStream<tokio_util::codec::LengthDelimitedCodec>,
223) {
224    connect_tcp(tokio_util::codec::LengthDelimitedCodec::new())
225}
226
227/// This is the same thing as `connect_tcp_bytes` except instead of using a length-delimited encoding scheme it uses new lines to separate frames.
228#[cfg(not(target_arch = "wasm32"))]
229pub fn connect_tcp_lines() -> (
230    TcpFramedSink<String>,
231    TcpFramedStream<tokio_util::codec::LinesCodec>,
232) {
233    connect_tcp(tokio_util::codec::LinesCodec::new())
234}
235
/// Sorts a slice in place (unstably) by a key that is borrowed from each element.
///
/// The higher-ranked bound on `f` ties the returned key reference to the lifetime of the element
/// it was extracted from, which the standard `sort_unstable_by_key` signature cannot express.
///
/// From addendum in
/// <https://stackoverflow.com/questions/56105305/how-to-sort-a-vec-of-structs-by-a-string-field>
pub fn sort_unstable_by_key_hrtb<T, F, K>(slice: &mut [T], f: F)
where
    F: for<'a> Fn(&'a T) -> &'a K,
    K: Ord,
{
    slice.sort_unstable_by(|lhs, rhs| {
        let (left_key, right_key) = (f(lhs), f(rhs));
        left_key.cmp(right_key)
    });
}
247
248/// Waits for a specific process output before returning.
249///
250/// When a child process is spawned often you want to wait until the child process is ready before
251/// moving on. One way to do that synchronization is by waiting for the child process to output
252/// something and match regex against that output. For example, you could wait until the child
253/// process outputs "Client live!" which would indicate that it is ready to receive input now on
254/// stdin.
255pub fn wait_for_process_output(
256    output_so_far: &mut String,
257    output: &mut ChildStdout,
258    wait_for: &str,
259) {
260    let re = regex::Regex::new(wait_for).unwrap();
261
262    while !re.is_match(output_so_far) {
263        println!("waiting: {}", output_so_far);
264        let mut buffer = [0u8; 1024];
265        let bytes_read = output.read(&mut buffer).unwrap();
266
267        if bytes_read == 0 {
268            panic!();
269        }
270
271        output_so_far.push_str(&String::from_utf8_lossy(&buffer[0..bytes_read]));
272
273        println!("XXX {}", output_so_far);
274    }
275}
276
/// Terminates the inner [`Child`] process when dropped.
///
/// When a `Child` is dropped normally nothing happens but in unit tests you usually want to
/// terminate the child and wait for it to terminate. `DroppableChild` does that for us.
pub struct DroppableChild(Child);

impl Drop for DroppableChild {
    /// Kills the child process and waits for it to exit.
    ///
    /// Outside of Windows a failure to kill, and on every platform a failure to wait, panics,
    /// unless the thread is already unwinding from another panic.
    fn drop(&mut self) {
        // A panic raised from `drop` while the thread is already unwinding (for example after a
        // failed test assertion) aborts the whole process and hides the original failure. In that
        // situation clean up on a best-effort basis instead of panicking again.
        let unwinding = std::thread::panicking();

        let killed = self.0.kill();
        // Windows throws `PermissionDenied` if the process has already exited, so a failed kill
        // is always tolerated there.
        if cfg!(not(target_family = "windows")) && !unwinding {
            killed.unwrap();
        }

        // Always reap the child, even when the kill result was ignored.
        let waited = self.0.wait();
        if !unwinding {
            waited.unwrap();
        }
    }
}
293
294/// Run a rust example as a test.
295///
296/// Rust examples are meant to be run by people and have a natural interface for that. This makes
297/// unit testing them cumbersome. This function wraps calling cargo run and piping the stdin/stdout
298/// of the example to easy to handle returned objects. The function also returns a `DroppableChild`
299/// which will ensure that the child processes will be cleaned up appropriately.
300pub fn run_cargo_example(test_name: &str, args: &str) -> (DroppableChild, ChildStdin, ChildStdout) {
301    let mut server = if args.is_empty() {
302        std::process::Command::new("cargo")
303            .args(["run", "-p", "dfir_rs", "--example"])
304            .arg(test_name)
305            .stdin(Stdio::piped())
306            .stdout(Stdio::piped())
307            .spawn()
308            .unwrap()
309    } else {
310        std::process::Command::new("cargo")
311            .args(["run", "-p", "dfir_rs", "--example"])
312            .arg(test_name)
313            .arg("--")
314            .args(args.split(' '))
315            .stdin(Stdio::piped())
316            .stdout(Stdio::piped())
317            .spawn()
318            .unwrap()
319    };
320
321    let stdin = server.stdin.take().unwrap();
322    let stdout = server.stdout.take().unwrap();
323
324    (DroppableChild(server), stdin, stdout)
325}
326
327/// Converts an iterator into a stream that emits `n` items at a time, yielding between each batch.
328///
329/// This is useful for breaking up a large iterator across several ticks: `source_iter(...)` always
330/// releases all items in the first tick. However using `iter_batches_stream` with `source_stream(...)`
331/// will cause `n` items to be released each tick. (Although more than that may be emitted if there
332/// are loops in the stratum).
333pub fn iter_batches_stream<I>(
334    iter: I,
335    n: usize,
336) -> futures::stream::PollFn<impl FnMut(&mut Context<'_>) -> Poll<Option<I::Item>>>
337where
338    I: IntoIterator + Unpin,
339{
340    let mut count = 0;
341    let mut iter = iter.into_iter();
342    futures::stream::poll_fn(move |ctx| {
343        count += 1;
344        if n < count {
345            count = 0;
346            ctx.waker().wake_by_ref();
347            Poll::Pending
348        } else {
349            Poll::Ready(iter.next())
350        }
351    })
352}
353
#[cfg(test)]
mod test {
    use super::*;

    /// All 1000 queued items must come back from the synchronous collector.
    #[test]
    pub fn test_collect_ready() {
        let (send, mut recv) = unbounded_channel::<usize>();
        (0..1000).for_each(|x| send.send(x).unwrap());
        let collected: Vec<usize> = collect_ready(&mut recv);
        assert_eq!(1000, collected.len());
    }

    /// Tokio unbounded channel returns items in 128 item long chunks, so this checks that the
    /// async collector still returns everything.
    #[crate::test]
    pub async fn test_collect_ready_async() {
        let (send, mut recv) = unbounded_channel::<usize>();
        (0..1000).for_each(|x| send.send(x).unwrap());
        let collected: Vec<usize> = collect_ready_async(&mut recv).await;
        assert_eq!(1000, collected.len());
    }
}