1#![warn(missing_docs)]
3
4#[cfg(feature = "dfir_macro")]
5#[cfg_attr(docsrs, doc(cfg(feature = "dfir_macro")))]
6pub mod demux_enum;
7pub mod multiset;
8pub mod priority_stack;
9pub mod slot_vec;
10pub mod sparse_vec;
11pub mod unsync;
12
13pub mod simulation;
14
15mod monotonic;
16pub use monotonic::*;
17
18mod udp;
19#[cfg(not(target_arch = "wasm32"))]
20pub use udp::*;
21
22mod tcp;
23#[cfg(not(target_arch = "wasm32"))]
24pub use tcp::*;
25
26#[cfg(unix)]
27mod socket;
28use std::net::SocketAddr;
29use std::num::NonZeroUsize;
30use std::task::{Context, Poll};
31
32use futures::Stream;
33use serde::de::DeserializeOwned;
34use serde::ser::Serialize;
35#[cfg(unix)]
36pub use socket::*;
37
/// A persistence action for a single value: either persist it or delete it.
///
/// How each action is interpreted (for example, what a `Delete` matches against) is decided by
/// the code that consumes this enum, which is not part of this module.
pub enum Persistence<T> {
    /// Request that the contained value be persisted.
    Persist(T),
    /// Request that the contained value be deleted.
    Delete(T),
}
45
/// A keyed persistence action: either persist a key-value pair or delete by key.
///
/// `Delete` carries only the key, so a deletion is addressed by key rather than by the full
/// key-value pair. The precise matching rules are defined by the code that consumes this enum.
pub enum PersistenceKeyed<K, V> {
    /// Request that the `(key, value)` pair be persisted.
    Persist(K, V),
    /// Request deletion of the entry (or entries) stored under this key.
    Delete(K),
}
53
54pub fn unbounded_channel<T>() -> (
56 tokio::sync::mpsc::UnboundedSender<T>,
57 tokio_stream::wrappers::UnboundedReceiverStream<T>,
58) {
59 let (send, recv) = tokio::sync::mpsc::unbounded_channel();
60 let recv = tokio_stream::wrappers::UnboundedReceiverStream::new(recv);
61 (send, recv)
62}
63
64pub fn unsync_channel<T>(
66 capacity: Option<NonZeroUsize>,
67) -> (unsync::mpsc::Sender<T>, unsync::mpsc::Receiver<T>) {
68 unsync::mpsc::channel(capacity)
69}
70
71pub fn ready_iter<S>(stream: S) -> impl Iterator<Item = S::Item>
73where
74 S: Stream,
75{
76 let mut stream = Box::pin(stream);
77 std::iter::from_fn(move || {
78 match stream
79 .as_mut()
80 .poll_next(&mut Context::from_waker(futures::task::noop_waker_ref()))
81 {
82 Poll::Ready(opt) => opt,
83 Poll::Pending => None,
84 }
85 })
86}
87
88pub fn collect_ready<C, S>(stream: S) -> C
93where
94 C: FromIterator<S::Item>,
95 S: Stream,
96{
97 assert!(
98 tokio::runtime::Handle::try_current().is_err(),
99 "Calling `collect_ready` from an async runtime may cause incorrect results, use `collect_ready_async` instead."
100 );
101 ready_iter(stream).collect()
102}
103
104pub async fn collect_ready_async<C, S>(stream: S) -> C
109where
110 C: Default + Extend<S::Item>,
111 S: Stream,
112{
113 use std::sync::atomic::Ordering;
114
115 tokio::task::yield_now().await;
117
118 let got_any_items = std::sync::atomic::AtomicBool::new(true);
119 let mut unfused_iter =
120 ready_iter(stream).inspect(|_| got_any_items.store(true, Ordering::Relaxed));
121 let mut out = C::default();
122 while got_any_items.swap(false, Ordering::Relaxed) {
123 out.extend(unfused_iter.by_ref());
124 tokio::task::yield_now().await;
127 }
128 out
129}
130
131pub fn serialize_to_bytes<T>(msg: T) -> bytes::Bytes
133where
134 T: Serialize,
135{
136 bytes::Bytes::from(bincode::serialize(&msg).unwrap())
137}
138
139pub fn deserialize_from_bytes<T>(msg: impl AsRef<[u8]>) -> bincode::Result<T>
141where
142 T: DeserializeOwned,
143{
144 bincode::deserialize(msg.as_ref())
145}
146
/// Resolves `addr` (for example `"127.0.0.1:8080"` or `"localhost:8080"`) and returns the
/// first IPv4 socket address it yields.
///
/// Resolution goes through [`std::net::ToSocketAddrs`]. For host names this may perform a
/// blocking DNS lookup.
///
/// # Errors
///
/// - Any error from `to_socket_addrs` is propagated unchanged.
/// - If resolution succeeds but yields no IPv4 address, an `ErrorKind::Other` error with the
///   message `"Unable to resolve IPv4 address"` is returned.
pub fn ipv4_resolve(addr: &str) -> Result<SocketAddr, std::io::Error> {
    use std::net::ToSocketAddrs;
    addr.to_socket_addrs()?
        .find(SocketAddr::is_ipv4)
        .ok_or_else(|| std::io::Error::other("Unable to resolve IPv4 address"))
}
157
158#[cfg(not(target_arch = "wasm32"))]
161pub async fn bind_udp_bytes(addr: SocketAddr) -> (UdpSink, UdpStream, SocketAddr) {
162 let socket = tokio::net::UdpSocket::bind(addr).await.unwrap();
163 udp_bytes(socket)
164}
165
166#[cfg(not(target_arch = "wasm32"))]
169pub async fn bind_udp_lines(addr: SocketAddr) -> (UdpLinesSink, UdpLinesStream, SocketAddr) {
170 let socket = tokio::net::UdpSocket::bind(addr).await.unwrap();
171 udp_lines(socket)
172}
173
174#[cfg(not(target_arch = "wasm32"))]
181pub async fn bind_tcp_bytes(
182 addr: SocketAddr,
183) -> (
184 unsync::mpsc::Sender<(bytes::Bytes, SocketAddr)>,
185 unsync::mpsc::Receiver<Result<(bytes::BytesMut, SocketAddr), std::io::Error>>,
186 SocketAddr,
187) {
188 bind_tcp(addr, tokio_util::codec::LengthDelimitedCodec::new())
189 .await
190 .unwrap()
191}
192
193#[cfg(not(target_arch = "wasm32"))]
195pub async fn bind_tcp_lines(
196 addr: SocketAddr,
197) -> (
198 unsync::mpsc::Sender<(String, SocketAddr)>,
199 unsync::mpsc::Receiver<Result<(String, SocketAddr), tokio_util::codec::LinesCodecError>>,
200 SocketAddr,
201) {
202 bind_tcp(addr, tokio_util::codec::LinesCodec::new())
203 .await
204 .unwrap()
205}
206
207#[cfg(not(target_arch = "wasm32"))]
212pub fn connect_tcp_bytes() -> (
213 TcpFramedSink<bytes::Bytes>,
214 TcpFramedStream<tokio_util::codec::LengthDelimitedCodec>,
215) {
216 connect_tcp(tokio_util::codec::LengthDelimitedCodec::new())
217}
218
219#[cfg(not(target_arch = "wasm32"))]
221pub fn connect_tcp_lines() -> (
222 TcpFramedSink<String>,
223 TcpFramedStream<tokio_util::codec::LinesCodec>,
224) {
225 connect_tcp(tokio_util::codec::LinesCodec::new())
226}
227
/// Sorts `slice` in place (unstable) by a key *borrowed* from each element.
///
/// Unlike [`slice::sort_unstable_by_key`], the key function may return a reference whose
/// lifetime is tied to the element (`for<'a> Fn(&'a T) -> &'a K`). This avoids cloning keys
/// such as `String` fields.
///
/// `K` may be unsized. For example, `|p| p.name.as_str()` sorts by a borrowed `str` key.
///
/// The key function is called twice per comparison, so keep it cheap.
pub fn sort_unstable_by_key_hrtb<T, F, K>(slice: &mut [T], f: F)
where
    F: for<'a> Fn(&'a T) -> &'a K,
    K: Ord + ?Sized,
{
    slice.sort_unstable_by(|a, b| Ord::cmp(f(a), f(b)));
}
239
240pub fn iter_batches_stream<I>(
247 iter: I,
248 n: usize,
249) -> futures::stream::PollFn<impl FnMut(&mut Context<'_>) -> Poll<Option<I::Item>>>
250where
251 I: IntoIterator + Unpin,
252{
253 let mut count = 0;
254 let mut iter = iter.into_iter();
255 futures::stream::poll_fn(move |ctx| {
256 count += 1;
257 if n < count {
258 count = 0;
259 ctx.waker().wake_by_ref();
260 Poll::Pending
261 } else {
262 Poll::Ready(iter.next())
263 }
264 })
265}
266
#[cfg(test)]
mod test {
    use super::*;

    /// Outside any async runtime, `collect_ready` drains everything already buffered.
    #[test]
    pub fn test_collect_ready() {
        let (tx, mut rx) = unbounded_channel::<usize>();
        (0..1000).for_each(|value| tx.send(value).unwrap());
        let drained = collect_ready::<Vec<_>, _>(&mut rx);
        assert_eq!(1000, drained.len());
    }

    /// Inside a runtime, `collect_ready_async` drains everything already buffered.
    #[crate::test]
    pub async fn test_collect_ready_async() {
        let (tx, mut rx) = unbounded_channel::<usize>();
        (0..1000).for_each(|value| tx.send(value).unwrap());
        let drained = collect_ready_async::<Vec<_>, _>(&mut rx).await;
        assert_eq!(1000, drained.len());
    }
}