1#![warn(missing_docs)]
3
4#[cfg(feature = "dfir_macro")]
5#[cfg_attr(docsrs, doc(cfg(feature = "dfir_macro")))]
6pub mod demux_enum;
7pub mod multiset;
8pub mod sparse_vec;
9pub mod unsync;
10
11mod monotonic;
12pub use monotonic::*;
13
14mod udp;
15#[cfg(not(target_arch = "wasm32"))]
16pub use udp::*;
17
18mod tcp;
19#[cfg(not(target_arch = "wasm32"))]
20pub use tcp::*;
21
22#[cfg(unix)]
23mod socket;
24use std::net::SocketAddr;
25use std::num::NonZeroUsize;
26use std::task::{Context, Poll};
27
28use futures::Stream;
29use serde::de::DeserializeOwned;
30use serde::ser::Serialize;
31#[cfg(unix)]
32pub use socket::*;
33
/// Tags a value `T` as either to be persisted or to be deleted.
///
/// NOTE(review): the precise meaning of "persist"/"delete" (e.g. whether `Delete` removes
/// every equal value or just one) is decided by whatever consumes this enum — confirm there.
pub enum Persistence<T> {
    /// Persist (insert/retain) the carried value.
    Persist(T),
    /// Delete the carried value.
    Delete(T),
}
41
/// Keyed variant of a persist/delete tag: persists a key-value pair, or deletes by key alone.
///
/// NOTE(review): deletion semantics are defined by the consumer of this enum — confirm there.
pub enum PersistenceKeyed<K, V> {
    /// Persist the value `V` under the key `K`.
    Persist(K, V),
    /// Delete by key; no value is carried.
    Delete(K),
}
49
50pub fn unbounded_channel<T>() -> (
52 tokio::sync::mpsc::UnboundedSender<T>,
53 tokio_stream::wrappers::UnboundedReceiverStream<T>,
54) {
55 let (send, recv) = tokio::sync::mpsc::unbounded_channel();
56 let recv = tokio_stream::wrappers::UnboundedReceiverStream::new(recv);
57 (send, recv)
58}
59
60pub fn unsync_channel<T>(
62 capacity: Option<NonZeroUsize>,
63) -> (unsync::mpsc::Sender<T>, unsync::mpsc::Receiver<T>) {
64 unsync::mpsc::channel(capacity)
65}
66
67pub fn ready_iter<S>(stream: S) -> impl Iterator<Item = S::Item>
69where
70 S: Stream,
71{
72 let mut stream = Box::pin(stream);
73 std::iter::from_fn(move || {
74 match stream
75 .as_mut()
76 .poll_next(&mut Context::from_waker(futures::task::noop_waker_ref()))
77 {
78 Poll::Ready(opt) => opt,
79 Poll::Pending => None,
80 }
81 })
82}
83
84pub fn collect_ready<C, S>(stream: S) -> C
89where
90 C: FromIterator<S::Item>,
91 S: Stream,
92{
93 assert!(
94 tokio::runtime::Handle::try_current().is_err(),
95 "Calling `collect_ready` from an async runtime may cause incorrect results, use `collect_ready_async` instead."
96 );
97 ready_iter(stream).collect()
98}
99
100pub async fn collect_ready_async<C, S>(stream: S) -> C
105where
106 C: Default + Extend<S::Item>,
107 S: Stream,
108{
109 use std::sync::atomic::Ordering;
110
111 tokio::task::yield_now().await;
113
114 let got_any_items = std::sync::atomic::AtomicBool::new(true);
115 let mut unfused_iter =
116 ready_iter(stream).inspect(|_| got_any_items.store(true, Ordering::Relaxed));
117 let mut out = C::default();
118 while got_any_items.swap(false, Ordering::Relaxed) {
119 out.extend(unfused_iter.by_ref());
120 tokio::task::yield_now().await;
123 }
124 out
125}
126
127pub fn serialize_to_bytes<T>(msg: T) -> bytes::Bytes
129where
130 T: Serialize,
131{
132 bytes::Bytes::from(bincode::serialize(&msg).unwrap())
133}
134
135pub fn deserialize_from_bytes<T>(msg: impl AsRef<[u8]>) -> bincode::Result<T>
137where
138 T: DeserializeOwned,
139{
140 bincode::deserialize(msg.as_ref())
141}
142
/// Resolves `addr` (a `host:port` or `ip:port` string) and returns the first IPv4 result.
///
/// # Errors
///
/// - Propagates any error from [`std::net::ToSocketAddrs::to_socket_addrs`].
/// - Returns an [`std::io::ErrorKind::Other`] error if resolution succeeds but yields no IPv4
///   address.
pub fn ipv4_resolve(addr: &str) -> Result<SocketAddr, std::io::Error> {
    use std::net::ToSocketAddrs;
    addr.to_socket_addrs()?
        .find(SocketAddr::is_ipv4)
        .ok_or_else(|| std::io::Error::other("Unable to resolve IPv4 address"))
}
153
154#[cfg(not(target_arch = "wasm32"))]
157pub async fn bind_udp_bytes(addr: SocketAddr) -> (UdpSink, UdpStream, SocketAddr) {
158 let socket = tokio::net::UdpSocket::bind(addr).await.unwrap();
159 udp_bytes(socket)
160}
161
162#[cfg(not(target_arch = "wasm32"))]
165pub async fn bind_udp_lines(addr: SocketAddr) -> (UdpLinesSink, UdpLinesStream, SocketAddr) {
166 let socket = tokio::net::UdpSocket::bind(addr).await.unwrap();
167 udp_lines(socket)
168}
169
170#[cfg(not(target_arch = "wasm32"))]
177pub async fn bind_tcp_bytes(
178 addr: SocketAddr,
179) -> (
180 unsync::mpsc::Sender<(bytes::Bytes, SocketAddr)>,
181 unsync::mpsc::Receiver<Result<(bytes::BytesMut, SocketAddr), std::io::Error>>,
182 SocketAddr,
183) {
184 bind_tcp(addr, tokio_util::codec::LengthDelimitedCodec::new())
185 .await
186 .unwrap()
187}
188
189#[cfg(not(target_arch = "wasm32"))]
191pub async fn bind_tcp_lines(
192 addr: SocketAddr,
193) -> (
194 unsync::mpsc::Sender<(String, SocketAddr)>,
195 unsync::mpsc::Receiver<Result<(String, SocketAddr), tokio_util::codec::LinesCodecError>>,
196 SocketAddr,
197) {
198 bind_tcp(addr, tokio_util::codec::LinesCodec::new())
199 .await
200 .unwrap()
201}
202
203#[cfg(not(target_arch = "wasm32"))]
208pub fn connect_tcp_bytes() -> (
209 TcpFramedSink<bytes::Bytes>,
210 TcpFramedStream<tokio_util::codec::LengthDelimitedCodec>,
211) {
212 connect_tcp(tokio_util::codec::LengthDelimitedCodec::new())
213}
214
215#[cfg(not(target_arch = "wasm32"))]
217pub fn connect_tcp_lines() -> (
218 TcpFramedSink<String>,
219 TcpFramedStream<tokio_util::codec::LinesCodec>,
220) {
221 connect_tcp(tokio_util::codec::LinesCodec::new())
222}
223
/// Sorts `slice` (unstably) by a key that is *borrowed* from each element.
///
/// [`slice::sort_unstable_by_key`] requires `F: FnMut(&T) -> K`, so its key cannot borrow from
/// the element. This helper's higher-ranked bound `for<'a> Fn(&'a T) -> &'a K` allows
/// returning `&K` tied to the element's lifetime, with no cloning.
pub fn sort_unstable_by_key_hrtb<T, F, K>(slice: &mut [T], f: F)
where
    F: for<'a> Fn(&'a T) -> &'a K,
    K: Ord,
{
    slice.sort_unstable_by(|lhs, rhs| {
        let (key_lhs, key_rhs) = (f(lhs), f(rhs));
        key_lhs.cmp(key_rhs)
    });
}
235
236pub fn iter_batches_stream<I>(
243 iter: I,
244 n: usize,
245) -> futures::stream::PollFn<impl FnMut(&mut Context<'_>) -> Poll<Option<I::Item>>>
246where
247 I: IntoIterator + Unpin,
248{
249 let mut count = 0;
250 let mut iter = iter.into_iter();
251 futures::stream::poll_fn(move |ctx| {
252 count += 1;
253 if n < count {
254 count = 0;
255 ctx.waker().wake_by_ref();
256 Poll::Pending
257 } else {
258 Poll::Ready(iter.next())
259 }
260 })
261}
262
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    pub fn test_collect_ready() {
        let (send, mut recv) = unbounded_channel::<usize>();
        (0..1000).for_each(|x| send.send(x).unwrap());
        let received: Vec<usize> = collect_ready(&mut recv);
        assert_eq!(1000, received.len());
    }

    #[crate::test]
    pub async fn test_collect_ready_async() {
        let (send, mut recv) = unbounded_channel::<usize>();
        // Fill the channel without awaiting, so every item is ready before collection begins.
        (0..1000).for_each(|x| send.send(x).unwrap());
        let received: Vec<usize> = collect_ready_async(&mut recv).await;
        assert_eq!(1000, received.len());
    }
}