dfir_lang/graph/ops/anti_join_multiset.rs

1use quote::{ToTokens, quote_spanned};
2use syn::parse_quote;
3
4use super::{
5    DelayType, OpInstGenerics, OperatorCategory, OperatorConstraints, OperatorInstance,
6    OperatorWriteOutput, Persistence, PortIndexValue, RANGE_0, RANGE_1, WriteContextArgs,
7};
8use crate::diagnostic::{Diagnostic, Level};
9
10// This implementation is largely redundant to ANTI_JOIN and should be DRY'ed
11/// > 2 input streams the first of type (K, T), the second of type K,
12/// > with output type (K, T)
13///
14/// For a given tick, computes the anti-join of the items in the input
15/// streams, returning items in the `pos` input --that do not have matching keys
16/// in the `neg` input. NOTE this uses multiset semantics on the positive side,
17/// so duplicated positive inputs will appear in the output either 0 times (if matched in `neg`)
18/// or as many times as they appear in the input (if not matched in `neg`)
19///
20/// ```dfir
21/// source_iter(vec![("cat", 2), ("cat", 2), ("elephant", 3), ("elephant", 3)]) -> [pos]diff;
22/// source_iter(vec!["dog", "cat", "gorilla"]) -> [neg]diff;
23/// diff = anti_join_multiset() -> assert_eq([("elephant", 3), ("elephant", 3)]);
24/// ```
25pub const ANTI_JOIN_MULTISET: OperatorConstraints = OperatorConstraints {
26    name: "anti_join_multiset",
27    categories: &[OperatorCategory::MultiIn],
28    hard_range_inn: &(2..=2),
29    soft_range_inn: &(2..=2),
30    hard_range_out: RANGE_1,
31    soft_range_out: RANGE_1,
32    num_args: 0,
33    persistence_args: &(0..=2),
34    type_args: RANGE_0,
35    is_external_input: false,
36    // If this is set to true, the state will need to be cleared using `#context.set_state_tick_hook`
37    // to prevent reading uncleared data if this subgraph doesn't run.
38    // https://github.com/hydro-project/hydro/issues/1298
39    has_singleton_output: false,
40    flo_type: None,
41    ports_inn: Some(|| super::PortListSpec::Fixed(parse_quote! { pos, neg })),
42    ports_out: None,
43    input_delaytype_fn: |idx| match idx {
44        PortIndexValue::Path(path) if "neg" == path.to_token_stream().to_string() => {
45            Some(DelayType::Stratum)
46        }
47        _else => None,
48    },
49    write_fn: |wc @ &WriteContextArgs {
50                   root,
51                   context,
52                   df_ident,
53                   op_span,
54                   ident,
55                   inputs,
56                   work_fn,
57                   op_inst:
58                       OperatorInstance {
59                           generics:
60                               OpInstGenerics {
61                                   persistence_args, ..
62                               },
63                           ..
64                       },
65                   ..
66               },
67               diagnostics| {
68        let persistences = match persistence_args[..] {
69            [] => [Persistence::Tick, Persistence::Tick],
70            [a] => [a, a],
71            [a, b] => [a, b],
72            _ => unreachable!(),
73        };
74
75        let mut make_antijoindata = |persistence, side| {
76            let antijoindata_ident = wc.make_ident(format!("antijoindata_{}", side));
77            let borrow_ident = wc.make_ident(format!("antijoindata_{}_borrow", side));
78            let (init, borrow) = match persistence {
79                Persistence::None | Persistence::Tick => (
80                    quote_spanned! {op_span=>
81                        #root::util::monotonic_map::MonotonicMap::<_, #root::rustc_hash::FxHashSet<_>>::default()
82                    },
83                    quote_spanned! {op_span=>
84                        (&mut *#borrow_ident).get_mut_clear(#context.current_tick())
85                    },
86                ),
87                Persistence::Static => (
88                    quote_spanned! {op_span=>
89                        #root::rustc_hash::FxHashSet::default()
90                    },
91                    quote_spanned! {op_span=>
92                        (&mut *#borrow_ident)
93                    },
94                ),
95                Persistence::Mutable => {
96                    diagnostics.push(Diagnostic::spanned(
97                        op_span,
98                        Level::Error,
99                        "An implementation of 'mutable does not exist",
100                    ));
101                    return Err(());
102                }
103            };
104            Ok((antijoindata_ident, borrow_ident, init, borrow))
105        };
106
107        let (neg_antijoindata_ident, neg_borrow_ident, neg_init, neg_borrow) =
108            make_antijoindata(persistences[1], "neg")?;
109
110        // let vec_ident = wc.make_ident("persistvec");
111        let pos_antijoindata_ident = wc.make_ident("antijoindata_pos_ident");
112        let pos_borrow_ident = wc.make_ident("antijoindata_pos_borrow_ident");
113
114        let write_prologue = match persistences[0] {
115            Persistence::None => Default::default(),
116            Persistence::Tick => quote_spanned! {op_span=>
117                let #neg_antijoindata_ident = #df_ident.add_state(std::cell::RefCell::new(
118                    #neg_init
119                ));
120            },
121            Persistence::Static => quote_spanned! {op_span=>
122                let #pos_antijoindata_ident = #df_ident.add_state(std::cell::RefCell::new(
123                    ::std::vec::Vec::new()
124                ));
125                let #neg_antijoindata_ident = #df_ident.add_state(std::cell::RefCell::new(
126                    #neg_init
127                ));
128            },
129            Persistence::Mutable => {
130                diagnostics.push(Diagnostic::spanned(
131                    op_span,
132                    Level::Error,
133                    "An implementation of 'mutable does not exist",
134                ));
135                return Err(());
136            }
137        };
138
139        let input_neg = &inputs[0]; // N before P
140        let input_pos = &inputs[1];
141        let write_iterator = match persistences[0] {
142            Persistence::None => quote_spanned! {op_span=>
143                let mut #neg_borrow_ident = #neg_init;
144
145                #[allow(clippy::needless_borrow)]
146                #work_fn(|| #neg_borrow.extend(#input_neg));
147
148                let #ident = #input_pos.filter(|x: &(_,_)| {
149                    #[allow(clippy::needless_borrow)]
150                    #[allow(clippy::unnecessary_mut_passed)]
151                    !#neg_borrow.contains(&x.0)
152                });
153            },
154            Persistence::Tick => quote_spanned! {op_span=>
155                let mut #neg_borrow_ident = unsafe {
156                    // SAFETY: handle from `#df_ident.add_state(..)`.
157                    #context.state_ref_unchecked(#neg_antijoindata_ident)
158                }.borrow_mut();
159
160                #[allow(clippy::needless_borrow)]
161                #work_fn(|| #neg_borrow.extend(#input_neg));
162
163                let #ident = #input_pos.filter(|x: &(_,_)| {
164                    #[allow(clippy::needless_borrow)]
165                    #[allow(clippy::unnecessary_mut_passed)]
166                    !#neg_borrow.contains(&x.0)
167                });
168            },
169            Persistence::Static => quote_spanned! {op_span =>
170                let (mut #neg_borrow_ident, mut #pos_borrow_ident) = unsafe {
171                    // SAFETY: handles from `#df_ident`.
172                    (
173                        #context.state_ref_unchecked(#neg_antijoindata_ident).borrow_mut(),
174                        #context.state_ref_unchecked(#pos_antijoindata_ident).borrow_mut(),
175                    )
176                };
177
178                #[allow(clippy::needless_borrow)]
179                let #ident = {
180                    #[allow(clippy::clone_on_copy)]
181                    #[allow(suspicious_double_ref_op)]
182                    if context.is_first_run_this_tick() {
183                        // Start of new tick
184                        #work_fn(|| #neg_borrow.extend(#input_neg));
185
186                        #work_fn(|| #pos_borrow_ident.extend(#input_pos));
187                        #pos_borrow_ident.iter()
188                    } else {
189                        // Called second or later times on the same tick.
190                        let len = #pos_borrow_ident.len();
191                        #work_fn(|| #pos_borrow_ident.extend(#input_pos));
192                        #pos_borrow_ident[len..].iter()
193                    }
194                    .filter(|x: &&(_,_)| {
195                        #[allow(clippy::unnecessary_mut_passed)]
196                        !#neg_borrow.contains(&x.0)
197                    })
198                    .map(|(k, v)| (k.clone(), v.clone()))
199                };
200            },
201            Persistence::Mutable => quote_spanned! {op_span =>
202                diagnostics.push(Diagnostic::spanned(
203                    op_span,
204                    Level::Error,
205                    "An implementation of 'mutable does not exist",
206                ));
207                return Err(());
208            },
209        };
210
211        Ok(OperatorWriteOutput {
212            write_prologue,
213            write_iterator,
214            ..Default::default()
215        })
216    },
217};