1#![warn(missing_docs)]
2pub mod clear;
5#[cfg(feature = "dfir_macro")]
6#[cfg_attr(docsrs, doc(cfg(feature = "dfir_macro")))]
7pub mod demux_enum;
8pub mod multiset;
9pub mod priority_stack;
10pub mod slot_vec;
11pub mod sparse_vec;
12pub mod unsync;
13
14pub mod simulation;
15
16mod monotonic;
17pub use monotonic::*;
18
19mod udp;
20#[cfg(not(target_arch = "wasm32"))]
21pub use udp::*;
22
23mod tcp;
24#[cfg(not(target_arch = "wasm32"))]
25pub use tcp::*;
26
27#[cfg(unix)]
28mod socket;
29#[cfg(unix)]
30pub use socket::*;
31
32#[cfg(feature = "deploy_integration")]
33#[cfg_attr(docsrs, doc(cfg(feature = "deploy_integration")))]
34pub mod deploy;
35
36use std::io::Read;
37use std::net::SocketAddr;
38use std::num::NonZeroUsize;
39use std::process::{Child, ChildStdin, ChildStdout, Stdio};
40use std::task::{Context, Poll};
41
42use futures::Stream;
43use serde::de::DeserializeOwned;
44use serde::ser::Serialize;
45
/// A value tagged with whether it should be persisted or deleted.
///
/// This type carries no behavior of its own; how each variant is acted upon is decided by
/// whatever consumes it.
pub enum Persistence<T> {
    /// Request that the contained value be persisted.
    Persist(T),
    /// Request that the contained value be deleted.
    Delete(T),
}
53
/// Keyed counterpart of `Persistence`: `Persist` carries a key and a value, while `Delete`
/// carries only the key.
///
/// This type carries no behavior of its own; how each variant is acted upon is decided by
/// whatever consumes it.
pub enum PersistenceKeyed<K, V> {
    /// Request that value `V` be persisted under key `K`.
    Persist(K, V),
    /// Request that the entry for key `K` be deleted.
    Delete(K),
}
61
62pub fn unbounded_channel<T>() -> (
64 tokio::sync::mpsc::UnboundedSender<T>,
65 tokio_stream::wrappers::UnboundedReceiverStream<T>,
66) {
67 let (send, recv) = tokio::sync::mpsc::unbounded_channel();
68 let recv = tokio_stream::wrappers::UnboundedReceiverStream::new(recv);
69 (send, recv)
70}
71
72pub fn unsync_channel<T>(
74 capacity: Option<NonZeroUsize>,
75) -> (unsync::mpsc::Sender<T>, unsync::mpsc::Receiver<T>) {
76 unsync::mpsc::channel(capacity)
77}
78
79pub fn ready_iter<S>(stream: S) -> impl Iterator<Item = S::Item>
81where
82 S: Stream,
83{
84 let mut stream = Box::pin(stream);
85 std::iter::from_fn(move || {
86 match stream
87 .as_mut()
88 .poll_next(&mut Context::from_waker(futures::task::noop_waker_ref()))
89 {
90 Poll::Ready(opt) => opt,
91 Poll::Pending => None,
92 }
93 })
94}
95
96pub fn collect_ready<C, S>(stream: S) -> C
101where
102 C: FromIterator<S::Item>,
103 S: Stream,
104{
105 assert!(
106 tokio::runtime::Handle::try_current().is_err(),
107 "Calling `collect_ready` from an async runtime may cause incorrect results, use `collect_ready_async` instead."
108 );
109 ready_iter(stream).collect()
110}
111
112pub async fn collect_ready_async<C, S>(stream: S) -> C
117where
118 C: Default + Extend<S::Item>,
119 S: Stream,
120{
121 use std::sync::atomic::Ordering;
122
123 tokio::task::yield_now().await;
125
126 let got_any_items = std::sync::atomic::AtomicBool::new(true);
127 let mut unfused_iter =
128 ready_iter(stream).inspect(|_| got_any_items.store(true, Ordering::Relaxed));
129 let mut out = C::default();
130 while got_any_items.swap(false, Ordering::Relaxed) {
131 out.extend(unfused_iter.by_ref());
132 tokio::task::yield_now().await;
135 }
136 out
137}
138
139pub fn serialize_to_bytes<T>(msg: T) -> bytes::Bytes
141where
142 T: Serialize,
143{
144 bytes::Bytes::from(bincode::serialize(&msg).unwrap())
145}
146
147pub fn deserialize_from_bytes<T>(msg: impl AsRef<[u8]>) -> bincode::Result<T>
149where
150 T: DeserializeOwned,
151{
152 bincode::deserialize(msg.as_ref())
153}
154
/// Resolves `addr` (for example `"localhost:8080"`) and returns the first IPv4 result.
///
/// # Errors
///
/// Returns the resolver's error if resolution fails. Returns an `io::Error` with the message
/// "Unable to resolve IPv4 address" if resolution succeeds but yields no IPv4 address.
pub fn ipv4_resolve(addr: &str) -> Result<SocketAddr, std::io::Error> {
    use std::net::ToSocketAddrs;
    addr.to_socket_addrs()?
        .find(SocketAddr::is_ipv4)
        .ok_or_else(|| std::io::Error::other("Unable to resolve IPv4 address"))
}
165
166#[cfg(not(target_arch = "wasm32"))]
169pub async fn bind_udp_bytes(addr: SocketAddr) -> (UdpSink, UdpStream, SocketAddr) {
170 let socket = tokio::net::UdpSocket::bind(addr).await.unwrap();
171 udp_bytes(socket)
172}
173
174#[cfg(not(target_arch = "wasm32"))]
177pub async fn bind_udp_lines(addr: SocketAddr) -> (UdpLinesSink, UdpLinesStream, SocketAddr) {
178 let socket = tokio::net::UdpSocket::bind(addr).await.unwrap();
179 udp_lines(socket)
180}
181
182#[cfg(not(target_arch = "wasm32"))]
189pub async fn bind_tcp_bytes(
190 addr: SocketAddr,
191) -> (
192 unsync::mpsc::Sender<(bytes::Bytes, SocketAddr)>,
193 unsync::mpsc::Receiver<Result<(bytes::BytesMut, SocketAddr), std::io::Error>>,
194 SocketAddr,
195) {
196 bind_tcp(addr, tokio_util::codec::LengthDelimitedCodec::new())
197 .await
198 .unwrap()
199}
200
201#[cfg(not(target_arch = "wasm32"))]
203pub async fn bind_tcp_lines(
204 addr: SocketAddr,
205) -> (
206 unsync::mpsc::Sender<(String, SocketAddr)>,
207 unsync::mpsc::Receiver<Result<(String, SocketAddr), tokio_util::codec::LinesCodecError>>,
208 SocketAddr,
209) {
210 bind_tcp(addr, tokio_util::codec::LinesCodec::new())
211 .await
212 .unwrap()
213}
214
215#[cfg(not(target_arch = "wasm32"))]
220pub fn connect_tcp_bytes() -> (
221 TcpFramedSink<bytes::Bytes>,
222 TcpFramedStream<tokio_util::codec::LengthDelimitedCodec>,
223) {
224 connect_tcp(tokio_util::codec::LengthDelimitedCodec::new())
225}
226
227#[cfg(not(target_arch = "wasm32"))]
229pub fn connect_tcp_lines() -> (
230 TcpFramedSink<String>,
231 TcpFramedStream<tokio_util::codec::LinesCodec>,
232) {
233 connect_tcp(tokio_util::codec::LinesCodec::new())
234}
235
/// Sorts `slice` in place, without preserving equal-element order, by a key *borrowed* from
/// each element.
///
/// `slice::sort_unstable_by_key` requires an owned key. Here the key function may instead
/// return a reference whose lifetime is tied to the element it was given (the
/// `for<'a> Fn(&'a T) -> &'a K` bound).
pub fn sort_unstable_by_key_hrtb<T, F, K>(slice: &mut [T], f: F)
where
    F: for<'a> Fn(&'a T) -> &'a K,
    K: Ord,
{
    let compare = |lhs: &T, rhs: &T| Ord::cmp(f(lhs), f(rhs));
    slice.sort_unstable_by(compare);
}
247
248pub fn wait_for_process_output(
256 output_so_far: &mut String,
257 output: &mut ChildStdout,
258 wait_for: &str,
259) {
260 let re = regex::Regex::new(wait_for).unwrap();
261
262 while !re.is_match(output_so_far) {
263 println!("waiting: {}", output_so_far);
264 let mut buffer = [0u8; 1024];
265 let bytes_read = output.read(&mut buffer).unwrap();
266
267 if bytes_read == 0 {
268 panic!();
269 }
270
271 output_so_far.push_str(&String::from_utf8_lossy(&buffer[0..bytes_read]));
272
273 println!("XXX {}", output_so_far);
274 }
275}
276
/// Owns a spawned [`Child`] process and kills and reaps it when dropped.
///
/// # Panics
///
/// Dropping panics if killing the process fails (except on Windows, where kill errors are
/// ignored) or if waiting on it fails. The exception is when the thread is already
/// panicking: cleanup is then best-effort, so a second panic does not abort the process.
pub struct DroppableChild(Child);

impl Drop for DroppableChild {
    fn drop(&mut self) {
        // Kill errors are ignored on Windows, as before. This is presumably because killing an
        // already-exited process can report an error there.
        #[cfg(target_family = "windows")]
        let kill_result: std::io::Result<()> = self.0.kill().or(Ok(()));
        #[cfg(not(target_family = "windows"))]
        let kill_result = self.0.kill();

        if std::thread::panicking() {
            // Panicking inside `drop` while already unwinding aborts the whole process and
            // hides the original failure, so errors are swallowed here. Skip `wait` if the
            // kill failed, because waiting on a still-running process could block forever.
            if kill_result.is_ok() {
                let _ = self.0.wait();
            }
            return;
        }

        kill_result.unwrap();
        self.0.wait().unwrap();
    }
}
293
294pub fn run_cargo_example(test_name: &str, args: &str) -> (DroppableChild, ChildStdin, ChildStdout) {
301 let mut server = if args.is_empty() {
302 std::process::Command::new("cargo")
303 .args(["run", "-p", "dfir_rs", "--example"])
304 .arg(test_name)
305 .stdin(Stdio::piped())
306 .stdout(Stdio::piped())
307 .spawn()
308 .unwrap()
309 } else {
310 std::process::Command::new("cargo")
311 .args(["run", "-p", "dfir_rs", "--example"])
312 .arg(test_name)
313 .arg("--")
314 .args(args.split(' '))
315 .stdin(Stdio::piped())
316 .stdout(Stdio::piped())
317 .spawn()
318 .unwrap()
319 };
320
321 let stdin = server.stdin.take().unwrap();
322 let stdout = server.stdout.take().unwrap();
323
324 (DroppableChild(server), stdin, stdout)
325}
326
327pub fn iter_batches_stream<I>(
334 iter: I,
335 n: usize,
336) -> futures::stream::PollFn<impl FnMut(&mut Context<'_>) -> Poll<Option<I::Item>>>
337where
338 I: IntoIterator + Unpin,
339{
340 let mut count = 0;
341 let mut iter = iter.into_iter();
342 futures::stream::poll_fn(move |ctx| {
343 count += 1;
344 if n < count {
345 count = 0;
346 ctx.waker().wake_by_ref();
347 Poll::Pending
348 } else {
349 Poll::Ready(iter.next())
350 }
351 })
352}
353
#[cfg(test)]
mod test {
    use super::*;

    /// `collect_ready` drains every already-buffered item outside a runtime.
    #[test]
    pub fn test_collect_ready() {
        let (sender, mut receiver) = unbounded_channel::<usize>();
        (0..1000).for_each(|value| sender.send(value).unwrap());
        let collected: Vec<usize> = collect_ready(&mut receiver);
        assert_eq!(1000, collected.len());
    }

    /// `collect_ready_async` drains every buffered item from inside a runtime.
    #[crate::test]
    pub async fn test_collect_ready_async() {
        let (sender, mut receiver) = unbounded_channel::<usize>();
        for value in 0..1000 {
            sender.send(value).unwrap();
        }
        let collected: Vec<usize> = collect_ready_async(&mut receiver).await;
        assert_eq!(1000, collected.len());
    }
}