1#![warn(missing_docs)]
3
4pub mod accumulator;
5pub mod clear;
6#[cfg(feature = "dfir_macro")]
7#[cfg_attr(docsrs, doc(cfg(feature = "dfir_macro")))]
8pub mod demux_enum;
9pub mod multiset;
10pub mod priority_stack;
11pub mod slot_vec;
12pub mod sparse_vec;
13pub mod unsync;
14
15pub mod simulation;
16
17mod monotonic;
18pub use monotonic::*;
19
20mod udp;
21#[cfg(not(target_arch = "wasm32"))]
22pub use udp::*;
23
24mod tcp;
25#[cfg(not(target_arch = "wasm32"))]
26pub use tcp::*;
27
28#[cfg(unix)]
29mod socket;
30#[cfg(unix)]
31pub use socket::*;
32
33#[cfg(feature = "deploy_integration")]
34#[cfg_attr(docsrs, doc(cfg(feature = "deploy_integration")))]
35pub mod deploy;
36
37use std::net::SocketAddr;
38use std::num::NonZeroUsize;
39use std::task::{Context, Poll};
40
41use futures::Stream;
42use serde::de::DeserializeOwned;
43use serde::ser::Serialize;
44
/// A value wrapped in either a [`Persist`](Persistence::Persist) or a
/// [`Delete`](Persistence::Delete) tag.
///
/// NOTE(review): presumably consumed by persistence-aware operators that add or
/// remove the wrapped value from persisted state — confirm against callers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Persistence<T> {
    /// The wrapped value is tagged for persisting.
    Persist(T),
    /// The wrapped value is tagged for deletion.
    Delete(T),
}
52
/// A keyed counterpart of `Persistence`: either a key/value pair tagged
/// [`Persist`](PersistenceKeyed::Persist), or a bare key tagged
/// [`Delete`](PersistenceKeyed::Delete).
///
/// NOTE(review): presumably a `Delete(key)` removes whatever value is stored
/// under `key` — confirm against callers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PersistenceKeyed<K, V> {
    /// A key and its value, tagged for persisting.
    Persist(K, V),
    /// A key, tagged for deletion.
    Delete(K),
}
60
61pub fn unbounded_channel<T>() -> (
63 tokio::sync::mpsc::UnboundedSender<T>,
64 tokio_stream::wrappers::UnboundedReceiverStream<T>,
65) {
66 let (send, recv) = tokio::sync::mpsc::unbounded_channel();
67 let recv = tokio_stream::wrappers::UnboundedReceiverStream::new(recv);
68 (send, recv)
69}
70
71pub fn unsync_channel<T>(
73 capacity: Option<NonZeroUsize>,
74) -> (unsync::mpsc::Sender<T>, unsync::mpsc::Receiver<T>) {
75 unsync::mpsc::channel(capacity)
76}
77
78pub fn ready_iter<S>(stream: S) -> impl Iterator<Item = S::Item>
80where
81 S: Stream,
82{
83 let mut stream = Box::pin(stream);
84 std::iter::from_fn(move || {
85 match stream
86 .as_mut()
87 .poll_next(&mut Context::from_waker(futures::task::noop_waker_ref()))
88 {
89 Poll::Ready(opt) => opt,
90 Poll::Pending => None,
91 }
92 })
93}
94
95pub fn collect_ready<C, S>(stream: S) -> C
100where
101 C: FromIterator<S::Item>,
102 S: Stream,
103{
104 assert!(
105 tokio::runtime::Handle::try_current().is_err(),
106 "Calling `collect_ready` from an async runtime may cause incorrect results, use `collect_ready_async` instead."
107 );
108 ready_iter(stream).collect()
109}
110
111pub async fn collect_ready_async<C, S>(stream: S) -> C
116where
117 C: Default + Extend<S::Item>,
118 S: Stream,
119{
120 use std::sync::atomic::Ordering;
121
122 tokio::task::yield_now().await;
124
125 let got_any_items = std::sync::atomic::AtomicBool::new(true);
126 let mut unfused_iter =
127 ready_iter(stream).inspect(|_| got_any_items.store(true, Ordering::Relaxed));
128 let mut out = C::default();
129 while got_any_items.swap(false, Ordering::Relaxed) {
130 out.extend(unfused_iter.by_ref());
131 tokio::task::yield_now().await;
134 }
135 out
136}
137
138pub fn serialize_to_bytes<T>(msg: T) -> bytes::Bytes
140where
141 T: Serialize,
142{
143 bytes::Bytes::from(bincode::serialize(&msg).unwrap())
144}
145
146pub fn deserialize_from_bytes<T>(msg: impl AsRef<[u8]>) -> bincode::Result<T>
148where
149 T: DeserializeOwned,
150{
151 bincode::deserialize(msg.as_ref())
152}
153
/// Resolves `addr` (e.g. `"127.0.0.1:8080"`) and returns the first IPv4 socket
/// address it produces.
///
/// # Errors
///
/// - Returns the resolution error if `addr` cannot be resolved.
/// - Returns an error of kind [`std::io::ErrorKind::Other`] if `addr` resolves
///   only to non-IPv4 addresses.
pub fn ipv4_resolve(addr: &str) -> Result<SocketAddr, std::io::Error> {
    use std::net::ToSocketAddrs;
    addr.to_socket_addrs()?
        .find(SocketAddr::is_ipv4)
        .ok_or_else(|| std::io::Error::other("Unable to resolve IPv4 address"))
}
164
165#[cfg(not(target_arch = "wasm32"))]
168pub async fn bind_udp_bytes(addr: SocketAddr) -> (UdpSink, UdpStream, SocketAddr) {
169 let socket = tokio::net::UdpSocket::bind(addr).await.unwrap();
170 udp_bytes(socket)
171}
172
173#[cfg(not(target_arch = "wasm32"))]
176pub async fn bind_udp_lines(addr: SocketAddr) -> (UdpLinesSink, UdpLinesStream, SocketAddr) {
177 let socket = tokio::net::UdpSocket::bind(addr).await.unwrap();
178 udp_lines(socket)
179}
180
181#[cfg(not(target_arch = "wasm32"))]
188pub async fn bind_tcp_bytes(
189 addr: SocketAddr,
190) -> (
191 unsync::mpsc::Sender<(bytes::Bytes, SocketAddr)>,
192 unsync::mpsc::Receiver<Result<(bytes::BytesMut, SocketAddr), std::io::Error>>,
193 SocketAddr,
194) {
195 bind_tcp(addr, tokio_util::codec::LengthDelimitedCodec::new())
196 .await
197 .unwrap()
198}
199
200#[cfg(not(target_arch = "wasm32"))]
202pub async fn bind_tcp_lines(
203 addr: SocketAddr,
204) -> (
205 unsync::mpsc::Sender<(String, SocketAddr)>,
206 unsync::mpsc::Receiver<Result<(String, SocketAddr), tokio_util::codec::LinesCodecError>>,
207 SocketAddr,
208) {
209 bind_tcp(addr, tokio_util::codec::LinesCodec::new())
210 .await
211 .unwrap()
212}
213
214#[cfg(not(target_arch = "wasm32"))]
219pub fn connect_tcp_bytes() -> (
220 TcpFramedSink<bytes::Bytes>,
221 TcpFramedStream<tokio_util::codec::LengthDelimitedCodec>,
222) {
223 connect_tcp(tokio_util::codec::LengthDelimitedCodec::new())
224}
225
226#[cfg(not(target_arch = "wasm32"))]
228pub fn connect_tcp_lines() -> (
229 TcpFramedSink<String>,
230 TcpFramedStream<tokio_util::codec::LinesCodec>,
231) {
232 connect_tcp(tokio_util::codec::LinesCodec::new())
233}
234
/// Sorts `slice` in place (unstably) by a key that `f` borrows from each element.
///
/// [`slice::sort_unstable_by_key`] needs a key function returning one fixed
/// key type, so the key cannot borrow from the element. The higher-ranked
/// `for<'a>` bound here lets `f` return a reference into the element it is
/// given. `f` is called twice per comparison.
pub fn sort_unstable_by_key_hrtb<T, F, K>(slice: &mut [T], f: F)
where
    F: for<'a> Fn(&'a T) -> &'a K,
    K: Ord,
{
    slice.sort_unstable_by(|lhs, rhs| {
        let (lhs_key, rhs_key) = (f(lhs), f(rhs));
        lhs_key.cmp(rhs_key)
    });
}
246
247pub fn iter_batches_stream<I>(
254 iter: I,
255 n: usize,
256) -> futures::stream::PollFn<impl FnMut(&mut Context<'_>) -> Poll<Option<I::Item>>>
257where
258 I: IntoIterator + Unpin,
259{
260 let mut count = 0;
261 let mut iter = iter.into_iter();
262 futures::stream::poll_fn(move |ctx| {
263 count += 1;
264 if n < count {
265 count = 0;
266 ctx.waker().wake_by_ref();
267 Poll::Pending
268 } else {
269 Poll::Ready(iter.next())
270 }
271 })
272}
273
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    pub fn test_collect_ready() {
        let (sender, mut receiver) = unbounded_channel::<usize>();
        (0..1000).for_each(|value| sender.send(value).unwrap());
        let collected: Vec<usize> = collect_ready(&mut receiver);
        assert_eq!(1000, collected.len());
    }

    #[crate::test]
    pub async fn test_collect_ready_async() {
        // Queue everything up front so all items are already ready when collecting.
        let (sender, mut receiver) = unbounded_channel::<usize>();
        (0..1000).for_each(|value| sender.send(value).unwrap());
        let collected: Vec<usize> = collect_ready_async(&mut receiver).await;
        assert_eq!(1000, collected.len());
    }
}