1#![warn(missing_docs)]
3
4#[cfg(feature = "dfir_macro")]
5#[cfg_attr(docsrs, doc(cfg(feature = "dfir_macro")))]
6pub mod demux_enum;
7pub mod multiset;
8pub mod priority_stack;
9pub mod sparse_vec;
10pub mod unsync;
11
12mod monotonic;
13pub use monotonic::*;
14
15mod udp;
16#[cfg(not(target_arch = "wasm32"))]
17pub use udp::*;
18
19mod tcp;
20#[cfg(not(target_arch = "wasm32"))]
21pub use tcp::*;
22
23#[cfg(unix)]
24mod socket;
25use std::net::SocketAddr;
26use std::num::NonZeroUsize;
27use std::task::{Context, Poll};
28
29use futures::Stream;
30use serde::de::DeserializeOwned;
31use serde::ser::Serialize;
32#[cfg(unix)]
33pub use socket::*;
34
/// A value tagged with the persistence action that should be applied to it.
///
/// NOTE(review): how consumers interpret each variant (for example, how a
/// `Delete` is matched against previously persisted values) is decided by the
/// code that consumes this enum, not by this definition — confirm at the use
/// sites.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Persistence<T> {
    /// The carried value should be persisted.
    Persist(T),
    /// The carried value should be deleted.
    Delete(T),
}
42
/// A keyed persistence action: either persist a key/value pair or delete by
/// key alone.
///
/// NOTE(review): the precise effect of each variant is decided by the code
/// that consumes this enum — confirm at the use sites.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PersistenceKeyed<K, V> {
    /// The value should be persisted under the given key.
    Persist(K, V),
    /// Whatever is associated with the given key should be deleted.
    Delete(K),
}
50
51pub fn unbounded_channel<T>() -> (
53 tokio::sync::mpsc::UnboundedSender<T>,
54 tokio_stream::wrappers::UnboundedReceiverStream<T>,
55) {
56 let (send, recv) = tokio::sync::mpsc::unbounded_channel();
57 let recv = tokio_stream::wrappers::UnboundedReceiverStream::new(recv);
58 (send, recv)
59}
60
61pub fn unsync_channel<T>(
63 capacity: Option<NonZeroUsize>,
64) -> (unsync::mpsc::Sender<T>, unsync::mpsc::Receiver<T>) {
65 unsync::mpsc::channel(capacity)
66}
67
68pub fn ready_iter<S>(stream: S) -> impl Iterator<Item = S::Item>
70where
71 S: Stream,
72{
73 let mut stream = Box::pin(stream);
74 std::iter::from_fn(move || {
75 match stream
76 .as_mut()
77 .poll_next(&mut Context::from_waker(futures::task::noop_waker_ref()))
78 {
79 Poll::Ready(opt) => opt,
80 Poll::Pending => None,
81 }
82 })
83}
84
85pub fn collect_ready<C, S>(stream: S) -> C
90where
91 C: FromIterator<S::Item>,
92 S: Stream,
93{
94 assert!(
95 tokio::runtime::Handle::try_current().is_err(),
96 "Calling `collect_ready` from an async runtime may cause incorrect results, use `collect_ready_async` instead."
97 );
98 ready_iter(stream).collect()
99}
100
101pub async fn collect_ready_async<C, S>(stream: S) -> C
106where
107 C: Default + Extend<S::Item>,
108 S: Stream,
109{
110 use std::sync::atomic::Ordering;
111
112 tokio::task::yield_now().await;
114
115 let got_any_items = std::sync::atomic::AtomicBool::new(true);
116 let mut unfused_iter =
117 ready_iter(stream).inspect(|_| got_any_items.store(true, Ordering::Relaxed));
118 let mut out = C::default();
119 while got_any_items.swap(false, Ordering::Relaxed) {
120 out.extend(unfused_iter.by_ref());
121 tokio::task::yield_now().await;
124 }
125 out
126}
127
128pub fn serialize_to_bytes<T>(msg: T) -> bytes::Bytes
130where
131 T: Serialize,
132{
133 bytes::Bytes::from(bincode::serialize(&msg).unwrap())
134}
135
136pub fn deserialize_from_bytes<T>(msg: impl AsRef<[u8]>) -> bincode::Result<T>
138where
139 T: DeserializeOwned,
140{
141 bincode::deserialize(msg.as_ref())
142}
143
/// Resolves `addr` (for example `"host:port"`) and returns the first IPv4
/// socket address among the results.
///
/// # Errors
///
/// Returns the error from `ToSocketAddrs::to_socket_addrs` if resolution
/// fails, or an `std::io::Error` with the message
/// `"Unable to resolve IPv4 address"` if none of the resolved addresses is
/// IPv4.
pub fn ipv4_resolve(addr: &str) -> Result<SocketAddr, std::io::Error> {
    use std::net::ToSocketAddrs;

    addr.to_socket_addrs()?
        .find(SocketAddr::is_ipv4)
        .ok_or_else(|| std::io::Error::other("Unable to resolve IPv4 address"))
}
154
155#[cfg(not(target_arch = "wasm32"))]
158pub async fn bind_udp_bytes(addr: SocketAddr) -> (UdpSink, UdpStream, SocketAddr) {
159 let socket = tokio::net::UdpSocket::bind(addr).await.unwrap();
160 udp_bytes(socket)
161}
162
163#[cfg(not(target_arch = "wasm32"))]
166pub async fn bind_udp_lines(addr: SocketAddr) -> (UdpLinesSink, UdpLinesStream, SocketAddr) {
167 let socket = tokio::net::UdpSocket::bind(addr).await.unwrap();
168 udp_lines(socket)
169}
170
171#[cfg(not(target_arch = "wasm32"))]
178pub async fn bind_tcp_bytes(
179 addr: SocketAddr,
180) -> (
181 unsync::mpsc::Sender<(bytes::Bytes, SocketAddr)>,
182 unsync::mpsc::Receiver<Result<(bytes::BytesMut, SocketAddr), std::io::Error>>,
183 SocketAddr,
184) {
185 bind_tcp(addr, tokio_util::codec::LengthDelimitedCodec::new())
186 .await
187 .unwrap()
188}
189
190#[cfg(not(target_arch = "wasm32"))]
192pub async fn bind_tcp_lines(
193 addr: SocketAddr,
194) -> (
195 unsync::mpsc::Sender<(String, SocketAddr)>,
196 unsync::mpsc::Receiver<Result<(String, SocketAddr), tokio_util::codec::LinesCodecError>>,
197 SocketAddr,
198) {
199 bind_tcp(addr, tokio_util::codec::LinesCodec::new())
200 .await
201 .unwrap()
202}
203
204#[cfg(not(target_arch = "wasm32"))]
209pub fn connect_tcp_bytes() -> (
210 TcpFramedSink<bytes::Bytes>,
211 TcpFramedStream<tokio_util::codec::LengthDelimitedCodec>,
212) {
213 connect_tcp(tokio_util::codec::LengthDelimitedCodec::new())
214}
215
216#[cfg(not(target_arch = "wasm32"))]
218pub fn connect_tcp_lines() -> (
219 TcpFramedSink<String>,
220 TcpFramedStream<tokio_util::codec::LinesCodec>,
221) {
222 connect_tcp(tokio_util::codec::LinesCodec::new())
223}
224
/// Sorts `slice` in place (unstably) by a key that `f` *borrows* from each
/// element.
///
/// The higher-ranked bound on `F` ties the lifetime of the returned key
/// reference to the element it was taken from, so keys such as `&String` can
/// be compared without cloning. Elements with equal keys may end up in any
/// relative order.
pub fn sort_unstable_by_key_hrtb<T, F, K>(slice: &mut [T], f: F)
where
    F: for<'a> Fn(&'a T) -> &'a K,
    K: Ord,
{
    slice.sort_unstable_by(|lhs, rhs| {
        let (left_key, right_key) = (f(lhs), f(rhs));
        left_key.cmp(right_key)
    });
}
236
237pub fn iter_batches_stream<I>(
244 iter: I,
245 n: usize,
246) -> futures::stream::PollFn<impl FnMut(&mut Context<'_>) -> Poll<Option<I::Item>>>
247where
248 I: IntoIterator + Unpin,
249{
250 let mut count = 0;
251 let mut iter = iter.into_iter();
252 futures::stream::poll_fn(move |ctx| {
253 count += 1;
254 if n < count {
255 count = 0;
256 ctx.waker().wake_by_ref();
257 Poll::Pending
258 } else {
259 Poll::Ready(iter.next())
260 }
261 })
262}
263
#[cfg(test)]
mod test {
    use super::*;

    /// All 1000 queued items are returned by the synchronous collector.
    #[test]
    pub fn test_collect_ready() {
        let (tx, mut rx) = unbounded_channel::<usize>();
        (0..1000).for_each(|item| tx.send(item).unwrap());

        let collected: Vec<usize> = collect_ready(&mut rx);
        assert_eq!(1000, collected.len());
    }

    /// All 1000 queued items are returned by the async collector, which has to
    /// keep draining until a pass comes back empty.
    #[crate::test]
    pub async fn test_collect_ready_async() {
        let (tx, mut rx) = unbounded_channel::<usize>();
        (0..1000).for_each(|item| tx.send(item).unwrap());

        let collected: Vec<usize> = collect_ready_async(&mut rx).await;
        assert_eq!(1000, collected.len());
    }
}