dfir_lang/graph/ops/
anti_join_multiset.rs

1use quote::{ToTokens, quote_spanned};
2use syn::parse_quote;
3
4use super::{
5    DelayType, OperatorCategory, OperatorConstraints, OperatorWriteOutput, PortIndexValue, RANGE_0,
6    RANGE_1, WriteContextArgs,
7};
8
9// This implementation is largely redundant to ANTI_JOIN and should be DRY'ed
10/// > 2 input streams the first of type (K, T), the second of type K,
11/// > with output type (K, T)
12///
13/// For a given tick, computes the anti-join of the items in the input
14/// streams, returning items in the `pos` input that do not have matching keys
15/// in the `neg` input. NOTE this uses multiset semantics only on the positive side,
16/// so duplicated positive inputs will appear in the output either 0 times (if matched in `neg`)
17/// or as many times as they appear in the input (if not matched in `neg`)
18///
19/// ```dfir
20/// source_iter(vec![("cat", 2), ("cat", 2), ("elephant", 3), ("elephant", 3)]) -> [pos]diff;
21/// source_iter(vec!["dog", "cat", "gorilla"]) -> [neg]diff;
22/// diff = anti_join_multiset() -> assert_eq([("elephant", 3), ("elephant", 3)]);
23/// ```
24pub const ANTI_JOIN_MULTISET: OperatorConstraints = OperatorConstraints {
25    name: "anti_join_multiset",
26    categories: &[OperatorCategory::MultiIn],
27    hard_range_inn: &(2..=2),
28    soft_range_inn: &(2..=2),
29    hard_range_out: RANGE_1,
30    soft_range_out: RANGE_1,
31    num_args: 0,
32    persistence_args: &(0..=2),
33    type_args: RANGE_0,
34    is_external_input: false,
35    // If this is set to true, the state will need to be cleared using `#context.set_state_tick_hook`
36    // to prevent reading uncleared data if this subgraph doesn't run.
37    // https://github.com/hydro-project/hydro/issues/1298
38    has_singleton_output: false,
39    flo_type: None,
40    ports_inn: Some(|| super::PortListSpec::Fixed(parse_quote! { pos, neg })),
41    ports_out: None,
42    input_delaytype_fn: |idx| match idx {
43        PortIndexValue::Path(path) if "neg" == path.to_token_stream().to_string() => {
44            Some(DelayType::Stratum)
45        }
46        _else => None,
47    },
48    write_fn: |wc @ &WriteContextArgs {
49                   root,
50                   context,
51                   df_ident,
52                   op_span,
53                   ident,
54                   inputs,
55                   work_fn,
56                   ..
57               },
58               diagnostics| {
59        let persistences: [_; 2] = wc.persistence_args_disallow_mutable(diagnostics);
60
61        let pos_antijoindata_ident = wc.make_ident("antijoindata_pos");
62        let neg_antijoindata_ident = wc.make_ident("antijoindata_neg");
63
64        let write_prologue_pos = quote_spanned! {op_span=>
65            let #pos_antijoindata_ident = #df_ident.add_state(std::cell::RefCell::new(
66                ::std::vec::Vec::new()
67            ));
68        };
69        let write_prologue_after_pos = wc
70            .persistence_as_state_lifespan(persistences[0])
71            .map(|lifespan| quote_spanned! {op_span=>
72                #[allow(clippy::redundant_closure_call)]
73                #df_ident.set_state_lifespan_hook(
74                    #pos_antijoindata_ident, #lifespan, move |rcell| { rcell.borrow_mut().clear(); },
75                );
76            }).unwrap_or_default();
77
78        let write_prologue_neg = quote_spanned! {op_span=>
79            let #neg_antijoindata_ident = #df_ident.add_state(std::cell::RefCell::new(
80                #root::rustc_hash::FxHashSet::default()
81            ));
82        };
83        let write_prologue_after_neg = wc
84            .persistence_as_state_lifespan(persistences[1])
85            .map(|lifespan| quote_spanned! {op_span=>
86                #[allow(clippy::redundant_closure_call)]
87                #df_ident.set_state_lifespan_hook(
88                    #neg_antijoindata_ident, #lifespan, move |rcell| { rcell.borrow_mut().clear(); },
89                );
90            }).unwrap_or_default();
91
92        let input_neg = &inputs[0]; // N before P
93        let input_pos = &inputs[1];
94        let write_iterator = quote_spanned! {op_span =>
95            let (mut neg_borrow, mut pos_borrow) = unsafe {
96                // SAFETY: handles from `#df_ident`.
97                (
98                    #context.state_ref_unchecked(#neg_antijoindata_ident).borrow_mut(),
99                    #context.state_ref_unchecked(#pos_antijoindata_ident).borrow_mut(),
100                )
101            };
102
103            #[allow(clippy::needless_borrow)]
104            let #ident = {
105                #[allow(clippy::clone_on_copy)]
106                #[allow(suspicious_double_ref_op)]
107                if context.is_first_run_this_tick() {
108                    // Start of new tick
109                    #work_fn(|| neg_borrow.extend(#input_neg));
110                    #work_fn(|| pos_borrow.extend(#input_pos));
111                    pos_borrow.iter()
112                } else {
113                    // Called second or later times on the same tick.
114                    let len = pos_borrow.len();
115                    #work_fn(|| pos_borrow.extend(#input_pos));
116                    pos_borrow[len..].iter()
117                }
118                .filter(|x: &&(_,_)| {
119                    #[allow(clippy::unnecessary_mut_passed)]
120                    !neg_borrow.contains(&x.0)
121                })
122                .map(|(k, v)| (k.clone(), v.clone()))
123            };
124        };
125
126        Ok(OperatorWriteOutput {
127            write_prologue: quote_spanned! {op_span=>
128                #write_prologue_pos
129                #write_prologue_neg
130            },
131            write_prologue_after: quote_spanned! {op_span=>
132                #write_prologue_after_pos
133                #write_prologue_after_neg
134            },
135            write_iterator,
136            ..Default::default()
137        })
138    },
139};