1#![warn(missing_docs)]
2pub mod clear;
5#[cfg(feature = "dfir_macro")]
6#[cfg_attr(docsrs, doc(cfg(feature = "dfir_macro")))]
7pub mod demux_enum;
8pub mod monotonic_map;
9pub mod multiset;
10pub mod priority_stack;
11pub mod slot_vec;
12pub mod sparse_vec;
13pub mod unsync;
14
15pub mod simulation;
16
17mod monotonic;
18pub use monotonic::*;
19
20mod udp;
21#[cfg(not(target_arch = "wasm32"))]
22pub use udp::*;
23
24mod tcp;
25#[cfg(not(target_arch = "wasm32"))]
26pub use tcp::*;
27
28#[cfg(unix)]
29mod socket;
30#[cfg(unix)]
31pub use socket::*;
32
33#[cfg(feature = "deploy_integration")]
34#[cfg_attr(docsrs, doc(cfg(feature = "deploy_integration")))]
35pub mod deploy;
36
37use std::io::Read;
38use std::net::SocketAddr;
39use std::num::NonZeroUsize;
40use std::process::{Child, ChildStdin, ChildStdout, Stdio};
41use std::task::{Context, Poll};
42
43use futures::Stream;
44use serde::de::DeserializeOwned;
45use serde::ser::Serialize;
46
/// A single value tagged as either to-be-persisted or to-be-deleted.
///
/// NOTE(review): how these tags are interpreted is decided by the code that
/// consumes them, which is not in this module — confirm against callers.
pub enum Persistence<T> {
    /// Tags the contained value for persistence.
    Persist(T),
    /// Tags the contained value for deletion.
    Delete(T),
}
54
/// Keyed counterpart of [`Persistence`]: either a key/value pair tagged for
/// persistence, or a bare key tagged for deletion.
pub enum PersistenceKeyed<K, V> {
    /// Tags the key/value pair for persistence.
    Persist(K, V),
    /// Tags the key for deletion (no value is carried).
    Delete(K),
}
62
63pub fn unbounded_channel<T>() -> (
65 tokio::sync::mpsc::UnboundedSender<T>,
66 tokio_stream::wrappers::UnboundedReceiverStream<T>,
67) {
68 let (send, recv) = tokio::sync::mpsc::unbounded_channel();
69 let recv = tokio_stream::wrappers::UnboundedReceiverStream::new(recv);
70 (send, recv)
71}
72
73pub fn unsync_channel<T>(
75 capacity: Option<NonZeroUsize>,
76) -> (unsync::mpsc::Sender<T>, unsync::mpsc::Receiver<T>) {
77 unsync::mpsc::channel(capacity)
78}
79
80pub fn ready_iter<S>(stream: S) -> impl Iterator<Item = S::Item>
82where
83 S: Stream,
84{
85 let mut stream = Box::pin(stream);
86 std::iter::from_fn(move || {
87 match stream
88 .as_mut()
89 .poll_next(&mut Context::from_waker(futures::task::noop_waker_ref()))
90 {
91 Poll::Ready(opt) => opt,
92 Poll::Pending => None,
93 }
94 })
95}
96
97pub fn collect_ready<C, S>(stream: S) -> C
102where
103 C: FromIterator<S::Item>,
104 S: Stream,
105{
106 assert!(
107 tokio::runtime::Handle::try_current().is_err(),
108 "Calling `collect_ready` from an async runtime may cause incorrect results, use `collect_ready_async` instead."
109 );
110 ready_iter(stream).collect()
111}
112
113pub async fn collect_ready_async<C, S>(stream: S) -> C
118where
119 C: Default + Extend<S::Item>,
120 S: Stream,
121{
122 use std::sync::atomic::Ordering;
123
124 tokio::task::yield_now().await;
126
127 let got_any_items = std::sync::atomic::AtomicBool::new(true);
128 let mut unfused_iter =
129 ready_iter(stream).inspect(|_| got_any_items.store(true, Ordering::Relaxed));
130 let mut out = C::default();
131 while got_any_items.swap(false, Ordering::Relaxed) {
132 out.extend(unfused_iter.by_ref());
133 tokio::task::yield_now().await;
136 }
137 out
138}
139
140pub fn serialize_to_bytes<T>(msg: T) -> bytes::Bytes
142where
143 T: Serialize,
144{
145 bytes::Bytes::from(bincode::serialize(&msg).unwrap())
146}
147
148pub fn deserialize_from_bytes<T>(msg: impl AsRef<[u8]>) -> bincode::Result<T>
150where
151 T: DeserializeOwned,
152{
153 bincode::deserialize(msg.as_ref())
154}
155
/// Resolves `addr` (e.g. `"localhost:8080"`) and returns the first IPv4 socket
/// address it resolves to.
///
/// # Errors
///
/// - Returns the resolution error if `addr` cannot be resolved.
/// - Returns an `ErrorKind::Other` error if resolution succeeds but yields no
///   IPv4 address.
pub fn ipv4_resolve(addr: &str) -> Result<SocketAddr, std::io::Error> {
    use std::net::ToSocketAddrs;

    addr.to_socket_addrs()?
        .find(SocketAddr::is_ipv4)
        .ok_or_else(|| std::io::Error::other("Unable to resolve IPv4 address"))
}
166
167#[cfg(not(target_arch = "wasm32"))]
170pub async fn bind_udp_bytes(addr: SocketAddr) -> (UdpSink, UdpStream, SocketAddr) {
171 let socket = tokio::net::UdpSocket::bind(addr).await.unwrap();
172 udp_bytes(socket)
173}
174
175#[cfg(not(target_arch = "wasm32"))]
178pub async fn bind_udp_lines(addr: SocketAddr) -> (UdpLinesSink, UdpLinesStream, SocketAddr) {
179 let socket = tokio::net::UdpSocket::bind(addr).await.unwrap();
180 udp_lines(socket)
181}
182
183#[cfg(not(target_arch = "wasm32"))]
190pub async fn bind_tcp_bytes(
191 addr: SocketAddr,
192) -> (
193 unsync::mpsc::Sender<(bytes::Bytes, SocketAddr)>,
194 unsync::mpsc::Receiver<Result<(bytes::BytesMut, SocketAddr), std::io::Error>>,
195 SocketAddr,
196) {
197 bind_tcp(addr, tokio_util::codec::LengthDelimitedCodec::new())
198 .await
199 .unwrap()
200}
201
202#[cfg(not(target_arch = "wasm32"))]
204pub async fn bind_tcp_lines(
205 addr: SocketAddr,
206) -> (
207 unsync::mpsc::Sender<(String, SocketAddr)>,
208 unsync::mpsc::Receiver<Result<(String, SocketAddr), tokio_util::codec::LinesCodecError>>,
209 SocketAddr,
210) {
211 bind_tcp(addr, tokio_util::codec::LinesCodec::new())
212 .await
213 .unwrap()
214}
215
216#[cfg(not(target_arch = "wasm32"))]
221pub fn connect_tcp_bytes() -> (
222 TcpFramedSink<bytes::Bytes>,
223 TcpFramedStream<tokio_util::codec::LengthDelimitedCodec>,
224) {
225 connect_tcp(tokio_util::codec::LengthDelimitedCodec::new())
226}
227
228#[cfg(not(target_arch = "wasm32"))]
230pub fn connect_tcp_lines() -> (
231 TcpFramedSink<String>,
232 TcpFramedStream<tokio_util::codec::LinesCodec>,
233) {
234 connect_tcp(tokio_util::codec::LinesCodec::new())
235}
236
/// Sorts `slice` (unstably) by a key that is *borrowed* from each element.
///
/// `slice::sort_unstable_by_key` requires its key function to return an owned
/// `K`. This variant accepts a key function returning `&K` with the same
/// lifetime as its argument, so keys such as `&String` fields need no cloning.
pub fn sort_unstable_by_key_hrtb<T, F, K>(slice: &mut [T], f: F)
where
    F: for<'a> Fn(&'a T) -> &'a K,
    K: Ord,
{
    slice.sort_unstable_by(|lhs, rhs| {
        let (lhs_key, rhs_key) = (f(lhs), f(rhs));
        lhs_key.cmp(rhs_key)
    });
}
248
249pub fn wait_for_process_output(
257 output_so_far: &mut String,
258 output: &mut ChildStdout,
259 wait_for: &str,
260) {
261 let re = regex::Regex::new(wait_for).unwrap();
262
263 while !re.is_match(output_so_far) {
264 println!("waiting: {}", output_so_far);
265 let mut buffer = [0u8; 1024];
266 let bytes_read = output.read(&mut buffer).unwrap();
267
268 if bytes_read == 0 {
269 panic!();
270 }
271
272 output_so_far.push_str(&String::from_utf8_lossy(&buffer[0..bytes_read]));
273
274 println!("XXX {}", output_so_far);
275 }
276}
277
/// Owns a [`Child`] process and kills and reaps it when dropped.
///
/// Returned by `run_cargo_example`.
pub struct DroppableChild(Child);

impl Drop for DroppableChild {
    fn drop(&mut self) {
        // Attempt both kill and wait before inspecting either result, so the
        // child is still reaped even if killing it reports an error.
        let kill_result = self.0.kill();
        let wait_result = self.0.wait();

        // Panicking in `drop` while this thread is already unwinding aborts the
        // process and hides the original panic. In that case, report nothing.
        if std::thread::panicking() {
            return;
        }

        // Errors from `kill` are deliberately ignored on Windows.
        #[cfg(target_family = "windows")]
        let _ = kill_result;
        #[cfg(not(target_family = "windows"))]
        kill_result.expect("failed to kill child process");

        wait_result.expect("failed to wait for child process");
    }
}
294
295pub fn run_cargo_example(test_name: &str, args: &str) -> (DroppableChild, ChildStdin, ChildStdout) {
302 let mut server = if args.is_empty() {
303 std::process::Command::new("cargo")
304 .args(["run", "-p", "dfir_rs", "--example"])
305 .arg(test_name)
306 .stdin(Stdio::piped())
307 .stdout(Stdio::piped())
308 .spawn()
309 .unwrap()
310 } else {
311 std::process::Command::new("cargo")
312 .args(["run", "-p", "dfir_rs", "--example"])
313 .arg(test_name)
314 .arg("--")
315 .args(args.split(' '))
316 .stdin(Stdio::piped())
317 .stdout(Stdio::piped())
318 .spawn()
319 .unwrap()
320 };
321
322 let stdin = server.stdin.take().unwrap();
323 let stdout = server.stdout.take().unwrap();
324
325 (DroppableChild(server), stdin, stdout)
326}
327
328pub fn iter_batches_stream<I>(
335 iter: I,
336 n: usize,
337) -> futures::stream::PollFn<impl FnMut(&mut Context<'_>) -> Poll<Option<I::Item>>>
338where
339 I: IntoIterator + Unpin,
340{
341 let mut count = 0;
342 let mut iter = iter.into_iter();
343 futures::stream::poll_fn(move |ctx| {
344 count += 1;
345 if n < count {
346 count = 0;
347 ctx.waker().wake_by_ref();
348 Poll::Pending
349 } else {
350 Poll::Ready(iter.next())
351 }
352 })
353}
354
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    pub fn test_collect_ready() {
        let (tx, mut rx) = unbounded_channel::<usize>();
        (0..1000).for_each(|i| tx.send(i).unwrap());
        let collected: Vec<usize> = collect_ready(&mut rx);
        assert_eq!(1000, collected.len());
    }

    #[crate::test]
    pub async fn test_collect_ready_async() {
        let (tx, mut rx) = unbounded_channel::<usize>();
        (0..1000).for_each(|i| tx.send(i).unwrap());
        let collected: Vec<usize> = collect_ready_async(&mut rx).await;
        assert_eq!(1000, collected.len());
    }
}