1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
#![warn(missing_docs)]
//! Helper utilities for the Hydroflow surface syntax.

pub mod clear;
#[cfg(feature = "dfir_macro")]
pub mod demux_enum;
pub mod monotonic_map;
pub mod multiset;
pub mod slot_vec;
pub mod sparse_vec;
pub mod unsync;

pub mod simulation;

mod monotonic;
pub use monotonic::*;

mod udp;
#[cfg(not(target_arch = "wasm32"))]
pub use udp::*;

mod tcp;
#[cfg(not(target_arch = "wasm32"))]
pub use tcp::*;

#[cfg(unix)]
mod socket;
#[cfg(unix)]
pub use socket::*;

#[cfg(feature = "deploy_integration")]
pub mod deploy;

use std::io::Read;
use std::net::SocketAddr;
use std::num::NonZeroUsize;
use std::process::{Child, ChildStdin, ChildStdout, Stdio};
use std::task::{Context, Poll};

use futures::Stream;
use serde::de::DeserializeOwned;
use serde::ser::Serialize;

/// Persist or delete tuples.
///
/// Each value carries a `T` together with the action to apply to it.
pub enum Persistence<T> {
    /// Persist T values
    Persist(T),
    /// Delete all values that exactly match
    Delete(T),
}

/// Persist or delete key-value pairs.
///
/// Persisting carries both the key and the value; deleting only needs the key.
pub enum PersistenceKeyed<K, V> {
    /// Persist key-value pairs
    Persist(K, V),
    /// Delete all tuples that have the key K
    Delete(K),
}

/// Creates an unbounded Tokio channel for use in Hydroflow.
///
/// Returns the pair of (1) the unbounded sender and (2) the receiver wrapped as an
/// `UnboundedReceiverStream`, so that the receiving half can be consumed as a `Stream`.
pub fn unbounded_channel<T>() -> (
    tokio::sync::mpsc::UnboundedSender<T>,
    tokio_stream::wrappers::UnboundedReceiverStream<T>,
) {
    let (sender, receiver) = tokio::sync::mpsc::unbounded_channel();
    (
        sender,
        tokio_stream::wrappers::UnboundedReceiverStream::new(receiver),
    )
}

/// Creates an unsync channel for use in Hydroflow.
///
/// `capacity` is forwarded unchanged to [`unsync::mpsc::channel`]. Returns the pair of
/// (1) the sender and (2) the receiver `Stream`.
pub fn unsync_channel<T>(
    capacity: Option<NonZeroUsize>,
) -> (unsync::mpsc::Sender<T>, unsync::mpsc::Receiver<T>) {
    let (sender, receiver) = unsync::mpsc::channel(capacity);
    (sender, receiver)
}

/// Returns an [`Iterator`] over the items of the [`Stream`] that are available right now.
///
/// The stream is boxed and pinned, then polled once per call to `next()` with a no-op waker.
/// A `Pending` poll is reported as `None`, exactly like the end of the stream, so the returned
/// iterator is not fused: it may yield items again after having returned `None`.
pub fn ready_iter<S>(stream: S) -> impl Iterator<Item = S::Item>
where
    S: Stream,
{
    let mut pinned = Box::pin(stream);
    std::iter::from_fn(move || {
        // Nothing is waiting to be woken, so a waker that does nothing is sufficient.
        let mut cx = Context::from_waker(futures::task::noop_waker_ref());
        if let Poll::Ready(next_item) = pinned.as_mut().poll_next(&mut cx) {
            next_item
        } else {
            None
        }
    })
}

/// Gathers every immediately available item of the `Stream` into a `FromIterator` collection.
///
/// This consumes the stream, use [`futures::StreamExt::by_ref()`] (or just `&mut ...`) if you want
/// to retain ownership of your stream.
///
/// # Panics
///
/// Panics when called while a Tokio runtime handle is available on the current thread; use
/// [`collect_ready_async`] there instead.
pub fn collect_ready<C, S>(stream: S) -> C
where
    C: FromIterator<S::Item>,
    S: Stream,
{
    let inside_runtime = tokio::runtime::Handle::try_current().is_ok();
    assert!(!inside_runtime, "Calling `collect_ready` from an async runtime may cause incorrect results, use `collect_ready_async` instead.");
    C::from_iter(ready_iter(stream))
}

/// Collects the immediately available items from the `Stream` into a collection (`Default` + `Extend`).
///
/// This is the async counterpart of [`collect_ready`]: it yields to the Tokio scheduler before
/// and between drain passes, and keeps draining until a full pass produces no items.
///
/// This consumes the stream, use [`futures::StreamExt::by_ref()`] (or just `&mut ...`) if you want
/// to retain ownership of your stream.
pub async fn collect_ready_async<C, S>(stream: S) -> C
where
    C: Default + Extend<S::Item>,
    S: Stream,
{
    use std::sync::atomic::Ordering;

    // Yield to let any background async tasks send to the stream.
    tokio::task::yield_now().await;

    // Starts out `true` so that the loop below runs at least one drain pass. The flag is shared
    // between the `inspect` closure (which sets it) and the loop (which clears it), hence the
    // interior mutability.
    let got_any_items = std::sync::atomic::AtomicBool::new(true);
    // `ready_iter` returns `None` when the stream is merely pending, so this iterator can yield
    // more items after it has returned `None`; it must not be fused.
    let mut unfused_iter =
        ready_iter(stream).inspect(|_| got_any_items.store(true, Ordering::Relaxed));
    let mut out = C::default();
    // Keep going for as long as the previous pass saw at least one item.
    while got_any_items.swap(false, Ordering::Relaxed) {
        out.extend(unfused_iter.by_ref());
        // Tokio unbounded channel returns items in length-128 chunks, so we have to be careful
        // that everything gets returned. That is why we yield here and loop.
        tokio::task::yield_now().await;
    }
    out
}

/// Encodes `msg` with bincode and returns the encoded buffer as [`bytes::Bytes`].
///
/// # Panics
///
/// Panics if bincode fails to serialize `msg`.
pub fn serialize_to_bytes<T>(msg: T) -> bytes::Bytes
where
    T: Serialize,
{
    let encoded: Vec<u8> = bincode::serialize(&msg).unwrap();
    bytes::Bytes::from(encoded)
}

/// Deserialize a message from bytes using bincode.
///
/// Accepts anything that can be viewed as a byte slice and returns the decoded `T`, or the
/// bincode error if decoding fails.
pub fn deserialize_from_bytes<T>(msg: impl AsRef<[u8]>) -> bincode::Result<T>
where
    T: DeserializeOwned,
{
    let raw: &[u8] = msg.as_ref();
    bincode::deserialize(raw)
}

/// Resolve the `ipv4` [`SocketAddr`] from an IP or hostname string.
///
/// The first IPv4 address among the resolved candidates is returned.
///
/// # Errors
///
/// Returns the underlying error if `addr` cannot be resolved at all, or an
/// [`std::io::ErrorKind::Other`] error if none of the resolved addresses is IPv4.
pub fn ipv4_resolve(addr: &str) -> Result<SocketAddr, std::io::Error> {
    use std::net::ToSocketAddrs;

    addr.to_socket_addrs()?
        .find(SocketAddr::is_ipv4)
        .ok_or_else(|| {
            std::io::Error::new(
                std::io::ErrorKind::Other,
                "Unable to resolve IPv4 address",
            )
        })
}

/// Binds a UDP socket to `addr` and returns a length-delimited bytes `Sink`, `Stream`, and the
/// bound `SocketAddr`.
///
/// The input `addr` may have a port of `0`, the returned `SocketAddr` will have the chosen port.
///
/// # Panics
///
/// Panics if the socket cannot be bound to `addr`.
#[cfg(not(target_arch = "wasm32"))]
pub async fn bind_udp_bytes(addr: SocketAddr) -> (UdpSink, UdpStream, SocketAddr) {
    udp_bytes(tokio::net::UdpSocket::bind(addr).await.unwrap())
}

/// Binds a UDP socket to `addr` and returns a newline-delimited bytes `Sink`, `Stream`, and the
/// bound `SocketAddr`.
///
/// The input `addr` may have a port of `0`, the returned `SocketAddr` will have the chosen port.
///
/// # Panics
///
/// Panics if the socket cannot be bound to `addr`.
#[cfg(not(target_arch = "wasm32"))]
pub async fn bind_udp_lines(addr: SocketAddr) -> (UdpLinesSink, UdpLinesStream, SocketAddr) {
    udp_lines(tokio::net::UdpSocket::bind(addr).await.unwrap())
}

/// Binds a TCP listener to `addr` and returns a length-delimited bytes `Sender`, `Receiver`, and
/// the bound `SocketAddr`.
///
/// The input `addr` may have a port of `0`, the returned `SocketAddr` will be the address of the newly bound endpoint.
/// The inbound connections can be used in full duplex mode. When a `(T, SocketAddr)` pair is fed to the `Sender`
/// returned by this function, the `SocketAddr` will be looked up against the currently existing connections.
/// If a match is found then the data will be sent on that connection. If no match is found then the data is silently dropped.
///
/// # Panics
///
/// Panics if `bind_tcp` returns an error.
#[cfg(not(target_arch = "wasm32"))]
pub async fn bind_tcp_bytes(
    addr: SocketAddr,
) -> (
    unsync::mpsc::Sender<(bytes::Bytes, SocketAddr)>,
    unsync::mpsc::Receiver<Result<(bytes::BytesMut, SocketAddr), std::io::Error>>,
    SocketAddr,
) {
    let codec = tokio_util::codec::LengthDelimitedCodec::new();
    bind_tcp(addr, codec).await.unwrap()
}

/// Line-oriented variant of `bind_tcp_bytes`: frames are separated by new lines instead of
/// being length-delimited.
///
/// # Panics
///
/// Panics if `bind_tcp` returns an error.
#[cfg(not(target_arch = "wasm32"))]
pub async fn bind_tcp_lines(
    addr: SocketAddr,
) -> (
    unsync::mpsc::Sender<(String, SocketAddr)>,
    unsync::mpsc::Receiver<Result<(String, SocketAddr), tokio_util::codec::LinesCodecError>>,
    SocketAddr,
) {
    let codec = tokio_util::codec::LinesCodec::new();
    bind_tcp(addr, codec).await.unwrap()
}

/// The inverse of [`bind_tcp_bytes`].
///
/// `(Bytes, SocketAddr)` pairs fed to the returned `Sender` will initiate new tcp connections to the specified `SocketAddr`.
/// These connections will be cached and reused, so that there will only be one connection per destination endpoint.
/// When the endpoint sends data back it will be available via the returned `Receiver`.
///
/// Frames are length-delimited.
#[cfg(not(target_arch = "wasm32"))]
pub fn connect_tcp_bytes() -> (
    TcpFramedSink<bytes::Bytes>,
    TcpFramedStream<tokio_util::codec::LengthDelimitedCodec>,
) {
    let codec = tokio_util::codec::LengthDelimitedCodec::new();
    connect_tcp(codec)
}

/// Line-oriented variant of `connect_tcp_bytes`: frames are separated by new lines instead of
/// being length-delimited.
#[cfg(not(target_arch = "wasm32"))]
pub fn connect_tcp_lines() -> (
    TcpFramedSink<String>,
    TcpFramedStream<tokio_util::codec::LinesCodec>,
) {
    let codec = tokio_util::codec::LinesCodec::new();
    connect_tcp(codec)
}

/// Sorts `slice` in place (unstably) by a key that is borrowed from each element.
///
/// The key function returns a reference tied to the element it is given, so keys such as a
/// `String` field can be compared without cloning them.
///
/// From addendum in
/// <https://stackoverflow.com/questions/56105305/how-to-sort-a-vec-of-structs-by-a-string-field>
pub fn sort_unstable_by_key_hrtb<T, F, K>(slice: &mut [T], f: F)
where
    F: for<'a> Fn(&'a T) -> &'a K,
    K: Ord,
{
    slice.sort_unstable_by(|lhs, rhs| {
        let (left_key, right_key) = (f(lhs), f(rhs));
        left_key.cmp(right_key)
    })
}

/// Waits for a specific process output before returning.
///
/// When a child process is spawned often you want to wait until the child process is ready before
/// moving on. One way to do that synchronization is by waiting for the child process to output
/// something and match regex against that output. For example, you could wait until the child
/// process outputs "Client live!" which would indicate that it is ready to receive input now on
/// stdin.
///
/// `output_so_far` accumulates everything read from `output`; it is matched against the regex
/// `wait_for` before each read, so text left over from an earlier call can satisfy the wait
/// without reading anything more.
///
/// # Panics
///
/// Panics if `wait_for` is not a valid regex, if reading from `output` fails, or if the child
/// closes its stdout before the accumulated output matches.
pub fn wait_for_process_output(
    output_so_far: &mut String,
    output: &mut ChildStdout,
    wait_for: &str,
) {
    let re = regex::Regex::new(wait_for)
        .unwrap_or_else(|err| panic!("invalid regex {:?}: {}", wait_for, err));

    // One scratch buffer reused for every read.
    let mut buffer = [0u8; 1024];

    while !re.is_match(output_so_far) {
        println!("waiting: {}", output_so_far);
        let bytes_read = output
            .read(&mut buffer)
            .expect("failed to read from child process stdout");

        // A zero-length read means the child closed stdout, so the pattern can never appear.
        if bytes_read == 0 {
            panic!(
                "child process closed stdout before its output matched {:?}; output so far: {:?}",
                wait_for, output_so_far
            );
        }

        // NOTE(review): each chunk is decoded on its own, so a multi-byte UTF-8 character that
        // straddles two reads is turned into replacement characters.
        output_so_far.push_str(&String::from_utf8_lossy(&buffer[..bytes_read]));

        println!("XXX {}", output_so_far);
    }
}

/// Terminates the inner [`Child`] process when dropped.
///
/// When a `Child` is dropped normally nothing happens but in unit tests you usually want to
/// terminate the child and wait for it to terminate. `DroppableChild` does that for us.
pub struct DroppableChild(Child);

impl Drop for DroppableChild {
    /// Kills the child process and then waits for it to exit.
    ///
    /// Panics if waiting fails, or (on non-Windows targets) if killing fails.
    fn drop(&mut self) {
        let kill_outcome = self.0.kill();
        // Windows throws `PermissionDenied` if the process has already exited, so a failed kill
        // is only treated as fatal on the other platforms.
        if cfg!(not(target_family = "windows")) {
            kill_outcome.unwrap();
        }

        self.0.wait().unwrap();
    }
}

/// Run a rust example as a test.
///
/// Rust examples are meant to be run by people and have a natural interface for that. This makes
/// unit testing them cumbersome. This function wraps calling cargo run and piping the stdin/stdout
/// of the example to easy to handle returned objects. The function also returns a `DroppableChild`
/// which will ensure that the child processes will be cleaned up appropriately.
///
/// `test_name` is the name of the example in the `dfir_rs` package. `args` is split on
/// whitespace and passed to the example after `--`; if it contains no arguments, nothing extra
/// is passed.
///
/// # Panics
///
/// Panics if the `cargo` process cannot be spawned.
pub fn run_cargo_example(test_name: &str, args: &str) -> (DroppableChild, ChildStdin, ChildStdout) {
    let mut command = std::process::Command::new("cargo");
    command
        .args(["run", "-p", "dfir_rs", "--example"])
        .arg(test_name);

    // Split on runs of whitespace so that doubled, leading, or trailing spaces do not become
    // empty-string arguments for the example.
    let mut example_args = args.split_whitespace().peekable();
    if example_args.peek().is_some() {
        command.arg("--").args(example_args);
    }

    let mut server = command
        .stdin(Stdio::piped())
        .stdout(Stdio::piped())
        .spawn()
        .expect("failed to spawn `cargo run` for the example");

    // Both handles exist because stdin and stdout were configured as piped above.
    let stdin = server.stdin.take().expect("child stdin was not piped");
    let stdout = server.stdout.take().expect("child stdout was not piped");

    (DroppableChild(server), stdin, stdout)
}

/// Turns an iterator into a stream that releases its items in batches of `n`.
///
/// This is useful for breaking up a large iterator across several ticks: `source_iter(...)` always
/// releases all items in the first tick. However using `iter_batches_stream` with `source_stream(...)`
/// will cause `n` items to be released each tick. (Although more than that may be emitted if there
/// are loops in the stratum).
///
/// After `n` consecutive items the stream reports `Pending` once and immediately wakes its own
/// waker, so it gets polled again for the next batch. Note that with `n == 0` every poll is
/// `Pending`, so no item is ever produced.
pub fn iter_batches_stream<I>(
    iter: I,
    n: usize,
) -> futures::stream::PollFn<impl FnMut(&mut Context<'_>) -> Poll<Option<I::Item>>>
where
    I: IntoIterator + Unpin,
{
    let mut items = iter.into_iter();
    // Number of items handed out since the last `Pending`.
    let mut emitted_in_batch = 0;
    futures::stream::poll_fn(move |cx| {
        if emitted_in_batch >= n {
            // Batch is full: start a new one and ask to be polled again right away.
            emitted_in_batch = 0;
            cx.waker().wake_by_ref();
            return Poll::Pending;
        }
        emitted_in_batch += 1;
        Poll::Ready(items.next())
    })
}

// Tests for the ready-item collection helpers defined in this module.
#[cfg(test)]
mod test {
    use super::*;

    /// Queues 1000 items on an unbounded channel and checks that the synchronous
    /// `collect_ready` returns all of them.
    #[test]
    pub fn test_collect_ready() {
        let (send, mut recv) = unbounded_channel::<usize>();
        for x in 0..1000 {
            send.send(x).unwrap();
        }
        // Pass `&mut recv` so the stream is borrowed rather than consumed.
        assert_eq!(1000, collect_ready::<Vec<_>, _>(&mut recv).len());
    }

    /// Same scenario as `test_collect_ready`, but through the async `collect_ready_async`.
    #[crate::test]
    pub async fn test_collect_ready_async() {
        // Tokio unbounded channel returns items in 128 item long chunks, so we have to be careful that everything gets returned.
        let (send, mut recv) = unbounded_channel::<usize>();
        for x in 0..1000 {
            send.send(x).unwrap();
        }
        assert_eq!(
            1000,
            collect_ready_async::<Vec<_>, _>(&mut recv).await.len()
        );
    }
}