dfir_rs/util/
mod.rs

1#![warn(missing_docs)]
2//! Helper utilities for the DFIR syntax.
3
4pub mod clear;
5#[cfg(feature = "dfir_macro")]
6#[cfg_attr(docsrs, doc(cfg(feature = "dfir_macro")))]
7pub mod demux_enum;
8pub mod multiset;
9pub mod priority_stack;
10pub mod slot_vec;
11pub mod sparse_vec;
12pub mod unsync;
13
14pub mod simulation;
15
16mod monotonic;
17pub use monotonic::*;
18
19mod udp;
20#[cfg(not(target_arch = "wasm32"))]
21pub use udp::*;
22
23mod tcp;
24#[cfg(not(target_arch = "wasm32"))]
25pub use tcp::*;
26
27#[cfg(unix)]
28mod socket;
29#[cfg(unix)]
30pub use socket::*;
31
32#[cfg(feature = "deploy_integration")]
33#[cfg_attr(docsrs, doc(cfg(feature = "deploy_integration")))]
34pub mod deploy;
35
36use std::net::SocketAddr;
37use std::num::NonZeroUsize;
38use std::task::{Context, Poll};
39
40use futures::Stream;
41use serde::de::DeserializeOwned;
42use serde::ser::Serialize;
43
/// Persist or delete tuples.
///
/// Each variant carries a value of type `T` and states whether that value should be
/// persisted or deleted.
pub enum Persistence<T> {
    /// Persist the given `T` value.
    Persist(T),
    /// Delete all values that exactly match the given `T`.
    Delete(T),
}
51
/// Persist or delete key-value pairs.
///
/// Persisting carries both a key `K` and a value `V`; deleting only needs the key.
pub enum PersistenceKeyed<K, V> {
    /// Persist the given key-value pair.
    Persist(K, V),
    /// Delete all tuples that have the key `K`.
    Delete(K),
}
59
60/// Returns a channel as a (1) unbounded sender and (2) unbounded receiver `Stream` for use in DFIR.
61pub fn unbounded_channel<T>() -> (
62    tokio::sync::mpsc::UnboundedSender<T>,
63    tokio_stream::wrappers::UnboundedReceiverStream<T>,
64) {
65    let (send, recv) = tokio::sync::mpsc::unbounded_channel();
66    let recv = tokio_stream::wrappers::UnboundedReceiverStream::new(recv);
67    (send, recv)
68}
69
70/// Returns an unsync channel as a (1) sender and (2) receiver `Stream` for use in DFIR.
71pub fn unsync_channel<T>(
72    capacity: Option<NonZeroUsize>,
73) -> (unsync::mpsc::Sender<T>, unsync::mpsc::Receiver<T>) {
74    unsync::mpsc::channel(capacity)
75}
76
77/// Returns an [`Iterator`] of any immediately available items from the [`Stream`].
78pub fn ready_iter<S>(stream: S) -> impl Iterator<Item = S::Item>
79where
80    S: Stream,
81{
82    let mut stream = Box::pin(stream);
83    std::iter::from_fn(move || {
84        match stream
85            .as_mut()
86            .poll_next(&mut Context::from_waker(futures::task::noop_waker_ref()))
87        {
88            Poll::Ready(opt) => opt,
89            Poll::Pending => None,
90        }
91    })
92}
93
94/// Collects the immediately available items from the `Stream` into a `FromIterator` collection.
95///
96/// This consumes the stream, use [`futures::StreamExt::by_ref()`] (or just `&mut ...`) if you want
97/// to retain ownership of your stream.
98pub fn collect_ready<C, S>(stream: S) -> C
99where
100    C: FromIterator<S::Item>,
101    S: Stream,
102{
103    assert!(
104        tokio::runtime::Handle::try_current().is_err(),
105        "Calling `collect_ready` from an async runtime may cause incorrect results, use `collect_ready_async` instead."
106    );
107    ready_iter(stream).collect()
108}
109
110/// Collects the immediately available items from the `Stream` into a collection (`Default` + `Extend`).
111///
112/// This consumes the stream, use [`futures::StreamExt::by_ref()`] (or just `&mut ...`) if you want
113/// to retain ownership of your stream.
114pub async fn collect_ready_async<C, S>(stream: S) -> C
115where
116    C: Default + Extend<S::Item>,
117    S: Stream,
118{
119    use std::sync::atomic::Ordering;
120
121    // Yield to let any background async tasks send to the stream.
122    tokio::task::yield_now().await;
123
124    let got_any_items = std::sync::atomic::AtomicBool::new(true);
125    let mut unfused_iter =
126        ready_iter(stream).inspect(|_| got_any_items.store(true, Ordering::Relaxed));
127    let mut out = C::default();
128    while got_any_items.swap(false, Ordering::Relaxed) {
129        out.extend(unfused_iter.by_ref());
130        // Tokio unbounded channel returns items in lenght-128 chunks, so we have to be careful
131        // that everything gets returned. That is why we yield here and loop.
132        tokio::task::yield_now().await;
133    }
134    out
135}
136
137/// Serialize a message to bytes using bincode.
138pub fn serialize_to_bytes<T>(msg: T) -> bytes::Bytes
139where
140    T: Serialize,
141{
142    bytes::Bytes::from(bincode::serialize(&msg).unwrap())
143}
144
145/// Serialize a message from bytes using bincode.
146pub fn deserialize_from_bytes<T>(msg: impl AsRef<[u8]>) -> bincode::Result<T>
147where
148    T: DeserializeOwned,
149{
150    bincode::deserialize(msg.as_ref())
151}
152
/// Resolves an IP or hostname string to its first `ipv4` [`SocketAddr`].
///
/// # Errors
///
/// Returns the underlying error if `addr` cannot be resolved at all, or an
/// [`std::io::Error`] of kind `Other` if none of the resolved addresses is IPv4.
pub fn ipv4_resolve(addr: &str) -> Result<SocketAddr, std::io::Error> {
    use std::net::ToSocketAddrs;

    addr.to_socket_addrs()?
        .find(SocketAddr::is_ipv4)
        .ok_or_else(|| std::io::Error::other("Unable to resolve IPv4 address"))
}
163
164/// Returns a length-delimited bytes `Sink`, `Stream`, and `SocketAddr` bound to the given address.
165/// The input `addr` may have a port of `0`, the returned `SocketAddr` will have the chosen port.
166#[cfg(not(target_arch = "wasm32"))]
167pub async fn bind_udp_bytes(addr: SocketAddr) -> (UdpSink, UdpStream, SocketAddr) {
168    let socket = tokio::net::UdpSocket::bind(addr).await.unwrap();
169    udp_bytes(socket)
170}
171
172/// Returns a newline-delimited bytes `Sink`, `Stream`, and `SocketAddr` bound to the given address.
173/// The input `addr` may have a port of `0`, the returned `SocketAddr` will have the chosen port.
174#[cfg(not(target_arch = "wasm32"))]
175pub async fn bind_udp_lines(addr: SocketAddr) -> (UdpLinesSink, UdpLinesStream, SocketAddr) {
176    let socket = tokio::net::UdpSocket::bind(addr).await.unwrap();
177    udp_lines(socket)
178}
179
180/// Returns a newline-delimited bytes `Sender`, `Receiver`, and `SocketAddr` bound to the given address.
181///
182/// The input `addr` may have a port of `0`, the returned `SocketAddr` will be the address of the newly bound endpoint.
183/// The inbound connections can be used in full duplex mode. When a `(T, SocketAddr)` pair is fed to the `Sender`
184/// returned by this function, the `SocketAddr` will be looked up against the currently existing connections.
185/// If a match is found then the data will be sent on that connection. If no match is found then the data is silently dropped.
186#[cfg(not(target_arch = "wasm32"))]
187pub async fn bind_tcp_bytes(
188    addr: SocketAddr,
189) -> (
190    unsync::mpsc::Sender<(bytes::Bytes, SocketAddr)>,
191    unsync::mpsc::Receiver<Result<(bytes::BytesMut, SocketAddr), std::io::Error>>,
192    SocketAddr,
193) {
194    bind_tcp(addr, tokio_util::codec::LengthDelimitedCodec::new())
195        .await
196        .unwrap()
197}
198
199/// This is the same thing as `bind_tcp_bytes` except instead of using a length-delimited encoding scheme it uses new lines to separate frames.
200#[cfg(not(target_arch = "wasm32"))]
201pub async fn bind_tcp_lines(
202    addr: SocketAddr,
203) -> (
204    unsync::mpsc::Sender<(String, SocketAddr)>,
205    unsync::mpsc::Receiver<Result<(String, SocketAddr), tokio_util::codec::LinesCodecError>>,
206    SocketAddr,
207) {
208    bind_tcp(addr, tokio_util::codec::LinesCodec::new())
209        .await
210        .unwrap()
211}
212
213/// The inverse of [`bind_tcp_bytes`].
214///
215/// `(Bytes, SocketAddr)` pairs fed to the returned `Sender` will initiate new tcp connections to the specified `SocketAddr`.
216/// These connections will be cached and reused, so that there will only be one connection per destination endpoint. When the endpoint sends data back it will be available via the returned `Receiver`
217#[cfg(not(target_arch = "wasm32"))]
218pub fn connect_tcp_bytes() -> (
219    TcpFramedSink<bytes::Bytes>,
220    TcpFramedStream<tokio_util::codec::LengthDelimitedCodec>,
221) {
222    connect_tcp(tokio_util::codec::LengthDelimitedCodec::new())
223}
224
225/// This is the same thing as `connect_tcp_bytes` except instead of using a length-delimited encoding scheme it uses new lines to separate frames.
226#[cfg(not(target_arch = "wasm32"))]
227pub fn connect_tcp_lines() -> (
228    TcpFramedSink<String>,
229    TcpFramedStream<tokio_util::codec::LinesCodec>,
230) {
231    connect_tcp(tokio_util::codec::LinesCodec::new())
232}
233
/// Sorts `slice` in place (unstable sort) by a key that `f` borrows out of each element.
///
/// Unlike `sort_unstable_by_key`, the key function may return a reference into the element,
/// which is what the higher-ranked bound on `F` expresses.
///
/// From addendum in
/// <https://stackoverflow.com/questions/56105305/how-to-sort-a-vec-of-structs-by-a-string-field>
pub fn sort_unstable_by_key_hrtb<T, F, K>(slice: &mut [T], f: F)
where
    F: for<'a> Fn(&'a T) -> &'a K,
    K: Ord,
{
    slice.sort_unstable_by(|lhs, rhs| {
        let (left_key, right_key) = (f(lhs), f(rhs));
        left_key.cmp(right_key)
    })
}
245
246/// Converts an iterator into a stream that emits `n` items at a time, yielding between each batch.
247///
248/// This is useful for breaking up a large iterator across several ticks: `source_iter(...)` always
249/// releases all items in the first tick. However using `iter_batches_stream` with `source_stream(...)`
250/// will cause `n` items to be released each tick. (Although more than that may be emitted if there
251/// are loops in the stratum).
252pub fn iter_batches_stream<I>(
253    iter: I,
254    n: usize,
255) -> futures::stream::PollFn<impl FnMut(&mut Context<'_>) -> Poll<Option<I::Item>>>
256where
257    I: IntoIterator + Unpin,
258{
259    let mut count = 0;
260    let mut iter = iter.into_iter();
261    futures::stream::poll_fn(move |ctx| {
262        count += 1;
263        if n < count {
264            count = 0;
265            ctx.waker().wake_by_ref();
266            Poll::Pending
267        } else {
268            Poll::Ready(iter.next())
269        }
270    })
271}
272
#[cfg(test)]
mod test {
    use super::*;

    /// Number of items sent in each test; larger than tokio's 128-item chunk size.
    const ITEM_COUNT: usize = 1000;

    /// All queued items are collected synchronously outside an async runtime.
    #[test]
    pub fn test_collect_ready() {
        let (send, mut recv) = unbounded_channel::<usize>();
        (0..ITEM_COUNT).for_each(|x| send.send(x).unwrap());

        let collected: Vec<usize> = collect_ready(&mut recv);
        assert_eq!(ITEM_COUNT, collected.len());
    }

    /// Tokio unbounded channel returns items in 128 item long chunks, so we have to be careful
    /// that everything gets returned.
    #[crate::test]
    pub async fn test_collect_ready_async() {
        let (send, mut recv) = unbounded_channel::<usize>();
        (0..ITEM_COUNT).for_each(|x| send.send(x).unwrap());

        let collected: Vec<usize> = collect_ready_async(&mut recv).await;
        assert_eq!(ITEM_COUNT, collected.len());
    }
}