// hydro_test_local/local/teed_join.rs

use dfir_rs::futures::stream::Stream;
use dfir_rs::tokio::sync::mpsc::UnboundedSender;
use dfir_rs::tokio_stream::wrappers::UnboundedReceiverStream;
use hydro_lang::deploy::MultiGraph;
use hydro_lang::dfir_rs::scheduled::graph::Dfir;
use hydro_lang::*;
use stageleft::{Quoted, RuntimeData};

/// Fieldless marker type naming the first process location
/// (`node_zero`, created with `flow.process::<N0>()` in `teed_join`).
struct N0 {}
/// Fieldless marker type naming the second process location
/// (`node_one`, created with `flow.process::<N1>()` in `teed_join`).
struct N1 {}

12#[stageleft::entry(UnboundedReceiverStream<u32>)]
13pub fn teed_join<'a, S: Stream<Item = u32> + Unpin + 'a>(
14    flow: FlowBuilder<'a>,
15    input_stream: RuntimeData<S>,
16    output: RuntimeData<&'a UnboundedSender<u32>>,
17    send_twice: bool,
18    subgraph_id: RuntimeData<usize>,
19) -> impl Quoted<'a, Dfir<'a>> {
20    let node_zero = flow.process::<N0>();
21    let node_one = flow.process::<N1>();
22    let n0_tick = node_zero.tick();
23
24    let source = unsafe {
25        // SAFETY: intentionally using ticks
26        node_zero.source_stream(input_stream).tick_batch(&n0_tick)
27    };
28    let map1 = source.clone().map(q!(|v| (v + 1, ())));
29    let map2 = source.map(q!(|v| (v - 1, ())));
30
31    let joined = map1.join(map2).map(q!(|t| t.0));
32
33    joined.clone().all_ticks().for_each(q!(|v| {
34        output.send(v).unwrap();
35    }));
36
37    if send_twice {
38        joined.all_ticks().for_each(q!(|v| {
39            output.send(v).unwrap();
40        }));
41    }
42
43    let source_node_id_1 = node_one.source_iter(q!(0..5));
44    source_node_id_1.for_each(q!(|v| {
45        output.send(v).unwrap();
46    }));
47
48    flow.compile_no_network::<MultiGraph>()
49        .with_dynamic_id(subgraph_id)
50}
51
#[cfg(stageleft_runtime)]
#[cfg(test)]
mod tests {
    use dfir_rs::assert_graphvis_snapshots;
    use dfir_rs::util::collect_ready;

    /// Single sink: feeding `1..=4` in one tick yields every key `k` for which
    /// both `k - 1` and `k + 1` were present, i.e. `[2, 3]`.
    #[test]
    fn test_teed_join() {
        let (input_tx, input_rx) = dfir_rs::util::unbounded_channel();
        let (output_tx, mut output_rx) = dfir_rs::util::unbounded_channel();

        let mut joined = super::teed_join!(input_rx, &output_tx, false, 0);
        assert_graphvis_snapshots!(joined);

        for value in 1..=4 {
            input_tx.send(value).unwrap();
        }

        joined.run_tick();

        let received = collect_ready::<Vec<_>, _>(&mut output_rx);
        assert_eq!(&*received, &[2, 3]);
    }

    /// Two sinks on the same teed stream: every joined key arrives twice,
    /// producing `[2, 2, 3, 3]`.
    #[test]
    fn test_teed_join_twice() {
        let (input_tx, input_rx) = dfir_rs::util::unbounded_channel();
        let (output_tx, mut output_rx) = dfir_rs::util::unbounded_channel();

        let mut joined = super::teed_join!(input_rx, &output_tx, true, 0);
        assert_graphvis_snapshots!(joined);

        for value in 1..=4 {
            input_tx.send(value).unwrap();
        }

        joined.run_tick();

        let received = collect_ready::<Vec<_>, _>(&mut output_rx);
        assert_eq!(&*received, &[2, 2, 3, 3]);
    }

    /// Subgraph `1` runs only the second process, which emits `0..5` without
    /// reading any input. The input sender is discarded immediately.
    #[test]
    fn test_teed_join_multi_node() {
        let (_, input_rx) = dfir_rs::util::unbounded_channel();
        let (output_tx, mut output_rx) = dfir_rs::util::unbounded_channel();

        let mut joined = super::teed_join!(input_rx, &output_tx, true, 1);
        assert_graphvis_snapshots!(joined);

        joined.run_tick();

        let received = collect_ready::<Vec<_>, _>(&mut output_rx);
        let expected: Vec<u32> = (0..5).collect();
        assert_eq!(received, expected);
    }
}