dfir_lang/graph/ops/
anti_join.rs
1use quote::{ToTokens, quote_spanned};
2use syn::parse_quote;
3
4use super::{
5 DelayType, OperatorCategory, OperatorConstraints, OperatorWriteOutput, PortIndexValue, RANGE_0,
6 RANGE_1, WriteContextArgs,
7};
8
9pub const ANTI_JOIN: OperatorConstraints = OperatorConstraints {
24 name: "anti_join",
25 categories: &[OperatorCategory::MultiIn],
26 hard_range_inn: &(2..=2),
27 soft_range_inn: &(2..=2),
28 hard_range_out: RANGE_1,
29 soft_range_out: RANGE_1,
30 num_args: 0,
31 persistence_args: &(0..=2),
32 type_args: RANGE_0,
33 is_external_input: false,
34 has_singleton_output: false,
38 flo_type: None,
39 ports_inn: Some(|| super::PortListSpec::Fixed(parse_quote! { pos, neg })),
40 ports_out: None,
41 input_delaytype_fn: |idx| match idx {
42 PortIndexValue::Path(path) if "neg" == path.to_token_stream().to_string() => {
43 Some(DelayType::Stratum)
44 }
45 _else => None,
46 },
47 write_fn: |wc @ &WriteContextArgs {
48 root,
49 context,
50 df_ident,
51 op_span,
52 ident,
53 inputs,
54 work_fn,
55 ..
56 },
57 diagnostics| {
58 let [pos_persistence, neg_persistence] = wc.persistence_args_disallow_mutable(diagnostics);
59
60 let make_antijoindata = |persistence, side| {
61 let antijoindata_ident = wc.make_ident(format!("antijoindata_{}", side));
62 let borrow_ident = wc.make_ident(format!("antijoindata_{}_borrow", side));
63 let lifespan = wc.persistence_as_state_lifespan(persistence);
64 (
65 quote_spanned! {op_span=>
66 let #antijoindata_ident = #df_ident.add_state(std::cell::RefCell::new(#root::rustc_hash::FxHashSet::default()));
67 },
68 lifespan.map(|lifespan| quote_spanned! {op_span=>
69 #df_ident.set_state_lifespan_hook(#antijoindata_ident, #lifespan, |rcell| { rcell.take(); });
70 }).unwrap_or_default(),
71 quote_spanned! {op_span=>
72 let mut #borrow_ident = unsafe {
73 #context.state_ref_unchecked(#antijoindata_ident)
75 }.borrow_mut();
76 },
77 quote_spanned! {op_span=>
78 &mut *#borrow_ident
79 },
80 )
81 };
82
83 let (pos_prologue, pos_prologue_after, pos_pre_write_iter, pos_borrow) =
84 make_antijoindata(pos_persistence, "pos");
85 let (neg_prologue, neg_prologue_after, neg_pre_write_iter, neg_borrow) =
86 make_antijoindata(neg_persistence, "neg");
87
88 let input_neg = &inputs[0]; let input_pos = &inputs[1];
90 let write_iterator = {
91 quote_spanned! {op_span=>
92 #pos_pre_write_iter
93 #neg_pre_write_iter
94 let #ident = {
95 #[inline(always)]
97 fn check_inputs<'a, K, I1, V, I2>(
98 input_neg: I1,
99 input_pos: I2,
100 neg_state: &'a mut #root::rustc_hash::FxHashSet<K>,
101 pos_state: &'a mut #root::rustc_hash::FxHashSet<(K, V)>,
102 is_new_tick: bool,
103 ) -> impl 'a + Iterator<Item = (K, V)>
104 where
105 K: Eq + ::std::hash::Hash + Clone,
106 V: Eq + ::std::hash::Hash + Clone,
107 I1: 'a + Iterator<Item = K>,
108 I2: 'a + Iterator<Item = (K, V)>,
109 {
110 #work_fn(|| neg_state.extend(input_neg));
111
112 #root::compiled::pull::anti_join_into_iter(input_pos, neg_state, pos_state, is_new_tick)
113 }
114
115 check_inputs(
116 #input_neg,
117 #input_pos,
118 #neg_borrow,
119 #pos_borrow,
120 context.is_first_run_this_tick(),
121 )
122 };
123 }
124 };
125
126 Ok(OperatorWriteOutput {
127 write_prologue: quote_spanned! {op_span=>
128 #pos_prologue
129 #neg_prologue
130 },
131 write_prologue_after: quote_spanned! {op_span=>
132 #pos_prologue_after
133 #neg_prologue_after
134 },
135 write_iterator,
136 ..Default::default()
137 })
138 },
139};