dfir_lang/graph/ops/
anti_join.rs

1use quote::{ToTokens, quote_spanned};
2use syn::parse_quote;
3
4use super::{
5    DelayType, OpInstGenerics, OperatorCategory, OperatorConstraints, OperatorInstance,
6    OperatorWriteOutput, Persistence, PortIndexValue, RANGE_0, RANGE_1, WriteContextArgs,
7};
8use crate::diagnostic::{Diagnostic, Level};
9
10/// > 2 input streams the first of type (K, T), the second of type K,
11/// > with output type (K, T)
12///
13/// For a given tick, computes the anti-join of the items in the input
14/// streams, returning unique items in the `pos` input that do not have matching keys
15/// in the `neg` input. Note this is set semantics only for the `neg element`. Order
16/// is preserved for new elements in a given tick, but not for elements processed
17/// in a previous tick with `'static`.
18///
19/// ```dfir
20/// source_iter(vec![("dog", 1), ("cat", 2), ("elephant", 3)]) -> [pos]diff;
21/// source_iter(vec!["dog", "cat", "gorilla"]) -> [neg]diff;
22/// diff = anti_join() -> assert_eq([("elephant", 3)]);
23/// ```
24pub const ANTI_JOIN: OperatorConstraints = OperatorConstraints {
25    name: "anti_join",
26    categories: &[OperatorCategory::MultiIn],
27    hard_range_inn: &(2..=2),
28    soft_range_inn: &(2..=2),
29    hard_range_out: RANGE_1,
30    soft_range_out: RANGE_1,
31    num_args: 0,
32    persistence_args: &(0..=2),
33    type_args: RANGE_0,
34    is_external_input: false,
35    // If this is set to true, the state will need to be cleared using `#context.set_state_tick_hook`
36    // to prevent reading uncleared data if this subgraph doesn't run.
37    // https://github.com/hydro-project/hydro/issues/1298
38    has_singleton_output: false,
39    flo_type: None,
40    ports_inn: Some(|| super::PortListSpec::Fixed(parse_quote! { pos, neg })),
41    ports_out: None,
42    input_delaytype_fn: |idx| match idx {
43        PortIndexValue::Path(path) if "neg" == path.to_token_stream().to_string() => {
44            Some(DelayType::Stratum)
45        }
46        _else => None,
47    },
48    write_fn: |wc @ &WriteContextArgs {
49                   root,
50                   context,
51                   df_ident,
52                   loop_id,
53                   op_span,
54                   ident,
55                   inputs,
56                   work_fn,
57                   op_inst:
58                       OperatorInstance {
59                           generics:
60                               OpInstGenerics {
61                                   persistence_args, ..
62                               },
63                           ..
64                       },
65                   ..
66               },
67               diagnostics| {
68        let persistences = match persistence_args[..] {
69            [] => {
70                let p = if loop_id.is_some() {
71                    Persistence::None
72                } else {
73                    Persistence::Tick
74                };
75                [p, p]
76            }
77            [a] => [a, a],
78            [a, b] => [a, b],
79            _ => unreachable!(),
80        };
81
82        let mut make_antijoindata = |persistence, side| {
83            let antijoindata_ident = wc.make_ident(format!("antijoindata_{}", side));
84            let borrow_ident = wc.make_ident(format!("antijoindata_{}_borrow", side));
85            let (init, pre_write_iter, borrow) = match persistence {
86                Persistence::None => (
87                    Default::default(),
88                    quote_spanned! {op_span=>
89                        let #borrow_ident = &mut #root::rustc_hash::FxHashSet::default();
90                    },
91                    quote_spanned! {op_span=>
92                        &mut *#borrow_ident
93                    },
94                ),
95                Persistence::Tick => (
96                    quote_spanned! {op_span=>
97                        let #antijoindata_ident = #df_ident.add_state(std::cell::RefCell::new(
98                            #root::util::monotonic_map::MonotonicMap::<_, #root::rustc_hash::FxHashSet<_>>::default()
99                        ));
100                    },
101                    quote_spanned! {op_span=>
102                        let mut #borrow_ident = unsafe {
103                            // SAFETY: handle from `#df_ident.add_state(..)`.
104                            #context.state_ref_unchecked(#antijoindata_ident)
105                        }.borrow_mut();
106                    },
107                    quote_spanned! {op_span=>
108                        &mut *#borrow_ident.get_mut_clear(#context.current_tick())
109                    },
110                ),
111                Persistence::Static => (
112                    quote_spanned! {op_span=>
113                        let #antijoindata_ident = #df_ident.add_state(std::cell::RefCell::new(
114                            #root::rustc_hash::FxHashSet::default()
115                        ));
116                    },
117                    quote_spanned! {op_span=>
118                        let mut #borrow_ident = unsafe {
119                            // SAFETY: handle from `#df_ident.add_state(..)`.
120                            #context.state_ref_unchecked(#antijoindata_ident)
121                        }.borrow_mut();
122                    },
123                    quote_spanned! {op_span=>
124                        &mut *#borrow_ident
125                    },
126                ),
127                Persistence::Mutable => {
128                    diagnostics.push(Diagnostic::spanned(
129                        op_span,
130                        Level::Error,
131                        "An implementation of 'mutable does not exist",
132                    ));
133                    return Err(());
134                }
135            };
136            Ok((init, pre_write_iter, borrow))
137        };
138
139        let (pos_init, pos_pre_write_iter, pos_borrow) = make_antijoindata(persistences[0], "pos")?;
140        let (neg_init, neg_pre_write_iter, neg_borrow) = make_antijoindata(persistences[1], "neg")?;
141
142        let write_prologue = quote_spanned! {op_span=>
143            #pos_init
144            #neg_init
145        };
146
147        let input_neg = &inputs[0]; // N before P
148        let input_pos = &inputs[1];
149        let write_iterator = {
150            quote_spanned! {op_span=>
151                #pos_pre_write_iter
152                #neg_pre_write_iter
153                let #ident = {
154                    /// Limit error propagation by bounding locally, erasing output iterator type.
155                    #[inline(always)]
156                    fn check_inputs<'a, K, I1, V, I2>(
157                        input_neg: I1,
158                        input_pos: I2,
159                        neg_state: &'a mut #root::rustc_hash::FxHashSet<K>,
160                        pos_state: &'a mut #root::rustc_hash::FxHashSet<(K, V)>,
161                        is_new_tick: bool,
162                    ) -> impl 'a + Iterator<Item = (K, V)>
163                    where
164                        K: Eq + ::std::hash::Hash + Clone,
165                        V: Eq + ::std::hash::Hash + Clone,
166                        I1: 'a + Iterator<Item = K>,
167                        I2: 'a + Iterator<Item = (K, V)>,
168                    {
169                        #work_fn(|| neg_state.extend(input_neg));
170
171                        #root::compiled::pull::anti_join_into_iter(input_pos, neg_state, pos_state, is_new_tick)
172                    }
173
174                    check_inputs(
175                        #input_neg,
176                        #input_pos,
177                        #neg_borrow,
178                        #pos_borrow,
179                        context.is_first_run_this_tick(),
180                    )
181                };
182            }
183        };
184
185        Ok(OperatorWriteOutput {
186            write_prologue,
187            write_iterator,
188            ..Default::default()
189        })
190    },
191};