dfir_lang/graph/ops/
anti_join.rs

1use quote::{ToTokens, quote_spanned};
2use syn::parse_quote;
3
4use super::{
5    DelayType, OperatorCategory, OperatorConstraints, OperatorWriteOutput, PortIndexValue, RANGE_0,
6    RANGE_1, WriteContextArgs,
7};
8
9/// > 2 input streams the first of type (K, T), the second of type K,
10/// > with output type (K, T)
11///
12/// For a given tick, computes the anti-join of the items in the input
13/// streams, returning unique items in the `pos` input that do not have matching keys
14/// in the `neg` input. Note this is set semantics only for the `neg element`. Order
15/// is preserved for new elements in a given tick, but not for elements processed
16/// in a previous tick with `'static`.
17///
18/// ```dfir
19/// source_iter(vec![("dog", 1), ("cat", 2), ("elephant", 3)]) -> [pos]diff;
20/// source_iter(vec!["dog", "cat", "gorilla"]) -> [neg]diff;
21/// diff = anti_join() -> assert_eq([("elephant", 3)]);
22/// ```
23pub const ANTI_JOIN: OperatorConstraints = OperatorConstraints {
24    name: "anti_join",
25    categories: &[OperatorCategory::MultiIn],
26    hard_range_inn: &(2..=2),
27    soft_range_inn: &(2..=2),
28    hard_range_out: RANGE_1,
29    soft_range_out: RANGE_1,
30    num_args: 0,
31    persistence_args: &(0..=2),
32    type_args: RANGE_0,
33    is_external_input: false,
34    // If this is set to true, the state will need to be cleared using `#context.set_state_lifespan_hook`
35    // to prevent reading uncleared data if this subgraph doesn't run.
36    // https://github.com/hydro-project/hydro/issues/1298
37    has_singleton_output: false,
38    flo_type: None,
39    ports_inn: Some(|| super::PortListSpec::Fixed(parse_quote! { pos, neg })),
40    ports_out: None,
41    input_delaytype_fn: |idx| match idx {
42        PortIndexValue::Path(path) if "neg" == path.to_token_stream().to_string() => {
43            Some(DelayType::Stratum)
44        }
45        _else => None,
46    },
47    write_fn: |wc @ &WriteContextArgs {
48                   root,
49                   context,
50                   df_ident,
51                   op_span,
52                   ident,
53                   inputs,
54                   work_fn,
55                   ..
56               },
57               diagnostics| {
58        let [pos_persistence, neg_persistence] = wc.persistence_args_disallow_mutable(diagnostics);
59
60        let make_antijoindata = |persistence, side| {
61            let antijoindata_ident = wc.make_ident(format!("antijoindata_{}", side));
62            let borrow_ident = wc.make_ident(format!("antijoindata_{}_borrow", side));
63            let lifespan = wc.persistence_as_state_lifespan(persistence);
64            (
65                quote_spanned! {op_span=>
66                    let #antijoindata_ident = #df_ident.add_state(std::cell::RefCell::new(#root::rustc_hash::FxHashSet::default()));
67                },
68                lifespan.map(|lifespan| quote_spanned! {op_span=>
69                    #df_ident.set_state_lifespan_hook(#antijoindata_ident, #lifespan, |rcell| { rcell.take(); });
70                }).unwrap_or_default(),
71                quote_spanned! {op_span=>
72                    let mut #borrow_ident = unsafe {
73                        // SAFETY: handle from `#df_ident.add_state(..)`.
74                        #context.state_ref_unchecked(#antijoindata_ident)
75                    }.borrow_mut();
76                },
77                quote_spanned! {op_span=>
78                    &mut *#borrow_ident
79                },
80            )
81        };
82
83        let (pos_prologue, pos_prologue_after, pos_pre_write_iter, pos_borrow) =
84            make_antijoindata(pos_persistence, "pos");
85        let (neg_prologue, neg_prologue_after, neg_pre_write_iter, neg_borrow) =
86            make_antijoindata(neg_persistence, "neg");
87
88        let input_neg = &inputs[0]; // N before P
89        let input_pos = &inputs[1];
90        let write_iterator = {
91            quote_spanned! {op_span=>
92                #pos_pre_write_iter
93                #neg_pre_write_iter
94                let #ident = {
95                    /// Limit error propagation by bounding locally, erasing output iterator type.
96                    #[inline(always)]
97                    fn check_inputs<'a, K, I1, V, I2>(
98                        input_neg: I1,
99                        input_pos: I2,
100                        neg_state: &'a mut #root::rustc_hash::FxHashSet<K>,
101                        pos_state: &'a mut #root::rustc_hash::FxHashSet<(K, V)>,
102                        is_new_tick: bool,
103                    ) -> impl 'a + Iterator<Item = (K, V)>
104                    where
105                        K: Eq + ::std::hash::Hash + Clone,
106                        V: Eq + ::std::hash::Hash + Clone,
107                        I1: 'a + Iterator<Item = K>,
108                        I2: 'a + Iterator<Item = (K, V)>,
109                    {
110                        #work_fn(|| neg_state.extend(input_neg));
111
112                        #root::compiled::pull::anti_join_into_iter(input_pos, neg_state, pos_state, is_new_tick)
113                    }
114
115                    check_inputs(
116                        #input_neg,
117                        #input_pos,
118                        #neg_borrow,
119                        #pos_borrow,
120                        context.is_first_run_this_tick(),
121                    )
122                };
123            }
124        };
125
126        Ok(OperatorWriteOutput {
127            write_prologue: quote_spanned! {op_span=>
128                #pos_prologue
129                #neg_prologue
130            },
131            write_prologue_after: quote_spanned! {op_span=>
132                #pos_prologue_after
133                #neg_prologue_after
134            },
135            write_iterator,
136            ..Default::default()
137        })
138    },
139};