dfir_lang/graph/ops/
multiset_delta.rs

1use quote::quote_spanned;
2
3use super::{
4    OperatorCategory, OperatorConstraints, OperatorWriteOutput, RANGE_0, RANGE_1, WriteContextArgs,
5};
6
7// TODO(mingwei): more doc
8/// Multiset delta from the previous tick.
9///
10/// ```rustbook
11/// let (input_send, input_recv) = dfir_rs::util::unbounded_channel::<u32>();
12/// let mut flow = dfir_rs::dfir_syntax! {
13///     source_stream(input_recv)
14///         -> multiset_delta()
15///         -> for_each(|n| println!("{}", n));
16/// };
17///
18/// input_send.send(3).unwrap();
19/// input_send.send(4).unwrap();
20/// input_send.send(3).unwrap();
21/// flow.run_tick();
22/// // 3, 4,
23///
24/// input_send.send(3).unwrap();
25/// input_send.send(5).unwrap();
26/// input_send.send(3).unwrap();
27/// input_send.send(3).unwrap();
28/// flow.run_tick();
29/// // 5, 3
30/// // First two "3"s are removed due to previous tick.
31/// ```
32pub const MULTISET_DELTA: OperatorConstraints = OperatorConstraints {
33    name: "multiset_delta",
34    categories: &[OperatorCategory::Persistence],
35    hard_range_inn: RANGE_1,
36    soft_range_inn: RANGE_1,
37    hard_range_out: RANGE_1,
38    soft_range_out: RANGE_1,
39    num_args: 0,
40    persistence_args: RANGE_0,
41    type_args: RANGE_0,
42    is_external_input: false,
43    // If this is set to true, the state will need to be cleared using `#context.set_state_tick_hook`
44    // to prevent reading uncleared data if this subgraph doesn't run.
45    // https://github.com/hydro-project/hydro/issues/1298
46    // If `'tick` lifetimes are added.
47    has_singleton_output: false,
48    flo_type: None,
49    ports_inn: None,
50    ports_out: None,
51    input_delaytype_fn: |_| None,
52    write_fn: |wc @ &WriteContextArgs {
53                   root,
54                   op_span,
55                   context,
56                   df_ident,
57                   ident,
58                   inputs,
59                   outputs,
60                   is_pull,
61                   work_fn,
62                   ..
63               },
64               _| {
65        let input = &inputs[0];
66        let output = &outputs[0];
67
68        let prev_data = wc.make_ident("prev_data");
69        let curr_data = wc.make_ident("curr_data");
70
71        let write_prologue = quote_spanned! {op_span=>
72            let #prev_data = #df_ident.add_state(::std::cell::RefCell::new(#root::rustc_hash::FxHashMap::default()));
73            let #curr_data = #df_ident.add_state(::std::cell::RefCell::new(#root::rustc_hash::FxHashMap::default()));
74        };
75
76        let tick_swap = quote_spanned! {op_span=>
77            {
78                if context.is_first_run_this_tick() {
79                    let (mut prev_map, mut curr_map) = unsafe {
80                        // SAFETY: handle from `#df_ident.add_state(..)`.
81                        (
82                            #context.state_ref_unchecked(#prev_data).borrow_mut(),
83                            #context.state_ref_unchecked(#curr_data).borrow_mut(),
84                        )
85                    };
86                    ::std::mem::swap(::std::ops::DerefMut::deref_mut(&mut prev_map), ::std::ops::DerefMut::deref_mut(&mut curr_map));
87                    curr_map.clear();
88                }
89            }
90        };
91
92        let filter_fn = quote_spanned! {op_span=>
93            |item| {
94                let (mut prev_map, mut curr_map) = unsafe {
95                    // SAFETY: handle from `#df_ident.add_state(..)`.
96                    (
97                        #context.state_ref_unchecked(#prev_data).borrow_mut(),
98                        #context.state_ref_unchecked(#curr_data).borrow_mut(),
99                    )
100                };
101
102                *curr_map.entry(#[allow(clippy::clone_on_copy)] item.clone()).or_insert(0_usize) += 1;
103                if let Some(old_count) = prev_map.get_mut(item) {
104                    #[allow(clippy::absurd_extreme_comparisons)] // Usize cannot be less than zero.
105                    if *old_count <= 0 {
106                        true
107                    } else {
108                        *old_count -= 1;
109                        false
110                    }
111                } else {
112                    true
113                }
114            }
115        };
116        let write_iterator = if is_pull {
117            quote_spanned! {op_span=>
118                #work_fn(|| #tick_swap);
119                let #ident = #input.filter(#filter_fn);
120            }
121        } else {
122            quote_spanned! {op_span=>
123                #work_fn(|| #tick_swap);
124                let #ident = #root::pusherator::filter::Filter::new(#filter_fn, #output);
125            }
126        };
127
128        Ok(OperatorWriteOutput {
129            write_prologue,
130            write_iterator,
131            ..Default::default()
132        })
133    },
134};