// dfir_lang/graph/ops/join_fused_lhs.rs

1use quote::{ToTokens, quote_spanned};
2use syn::parse_quote;
3use syn::spanned::Spanned;
4
5use super::join_fused::{JoinOptions, make_joindata, parse_argument};
6use super::{
7    DelayType, OperatorCategory, OperatorConstraints, OperatorWriteOutput, Persistence,
8    PortIndexValue, RANGE_0, RANGE_1, WriteContextArgs,
9};
10
11/// See `join_fused`
12///
13/// This operator is identical to `join_fused` except that the right hand side input `1` is a regular `join_multiset` input.
14///
15/// This means that `join_fused_lhs` only takes one argument input, which is the reducing/folding operation for the left hand side only.
16///
17/// For example:
18/// ```dfir
19/// source_iter(vec![("key", 0), ("key", 1), ("key", 2)]) -> [0]my_join;
20/// source_iter(vec![("key", 2), ("key", 3)]) -> [1]my_join;
21/// my_join = join_fused_lhs(Reduce(|x, y| *x += y))
22///     -> assert_eq([("key", (3, 2)), ("key", (3, 3))]);
23/// ```
24pub const JOIN_FUSED_LHS: OperatorConstraints = OperatorConstraints {
25    name: "join_fused_lhs",
26    categories: &[OperatorCategory::MultiIn],
27    hard_range_inn: &(2..=2),
28    soft_range_inn: &(2..=2),
29    hard_range_out: RANGE_1,
30    soft_range_out: RANGE_1,
31    num_args: 1,
32    persistence_args: &(0..=2),
33    type_args: RANGE_0,
34    is_external_input: false,
35    has_singleton_output: false,
36    flo_type: None,
37    ports_inn: Some(|| super::PortListSpec::Fixed(parse_quote! { 0, 1 })),
38    ports_out: None,
39    input_delaytype_fn: |idx| match idx {
40        PortIndexValue::Int(path) if "0" == path.to_token_stream().to_string() => {
41            Some(DelayType::Stratum)
42        }
43        _ => None,
44    },
45    write_fn: |wc @ &WriteContextArgs {
46                   context,
47                   df_ident,
48                   op_span,
49                   ident,
50                   inputs,
51                   is_pull,
52                   arguments,
53                   ..
54               },
55               diagnostics| {
56        assert!(is_pull);
57
58        let arg0_span = arguments[0].span();
59
60        let persistences: [_; 2] = wc.persistence_args_disallow_mutable(diagnostics);
61
62        let lhs_join_options =
63            parse_argument(&arguments[0]).map_err(|err| diagnostics.push(err))?;
64
65        let (lhs_prologue, lhs_prologue_after, lhs_pre_write_iter, lhs_borrow) =
66            make_joindata(wc, persistences[0], &lhs_join_options, "lhs")
67                .map_err(|err| diagnostics.push(err))?;
68
69        let rhs_joindata_ident = wc.make_ident("rhs_joindata");
70        let rhs_borrow_ident = wc.make_ident("rhs_joindata_borrow_ident");
71
72        let rhs_prologue = match persistences[1] {
73            Persistence::None | Persistence::Loop | Persistence::Tick => quote_spanned! {op_span=>},
74            Persistence::Static => quote_spanned! {op_span=>
75                let #rhs_joindata_ident = #df_ident.add_state(::std::cell::RefCell::new(
76                    ::std::vec::Vec::new()
77                ));
78            },
79            Persistence::Mutable => unreachable!(),
80        };
81
82        let lhs = &inputs[0];
83        let rhs = &inputs[1];
84
85        let lhs_fold_or_reduce_into_from = match lhs_join_options {
86            JoinOptions::FoldFrom(lhs_from, lhs_fold) => quote_spanned! {arg0_span=>
87                #lhs_borrow.fold_into(#lhs, #lhs_fold, #lhs_from);
88            },
89            JoinOptions::Fold(lhs_default, lhs_fold) => quote_spanned! {arg0_span=>
90                #lhs_borrow.fold_into(#lhs, #lhs_fold, #lhs_default);
91            },
92            JoinOptions::Reduce(lhs_reduce) => quote_spanned! {arg0_span=>
93                #lhs_borrow.reduce_into(#lhs, #lhs_reduce);
94            },
95        };
96
97        let write_iterator = match persistences[1] {
98            Persistence::None | Persistence::Loop | Persistence::Tick => quote_spanned! {op_span=>
99                #lhs_pre_write_iter
100
101                let #ident = {
102                    #lhs_fold_or_reduce_into_from
103
104                    #[allow(clippy::clone_on_copy)]
105                    #rhs.filter_map(|(k, v2)| #lhs_borrow.table.get(&k).map(|v1| (k, (v1.clone(), v2.clone()))))
106                };
107            },
108            Persistence::Static => quote_spanned! {op_span=>
109                #lhs_pre_write_iter
110                let mut #rhs_borrow_ident = unsafe {
111                    // SAFETY: handle from `#df_ident.add_state(..)`.
112                    #context.state_ref_unchecked(#rhs_joindata_ident)
113                }.borrow_mut();
114
115                let #ident = {
116                    #lhs_fold_or_reduce_into_from
117
118                    #[allow(clippy::clone_on_copy)]
119                    #[allow(suspicious_double_ref_op)]
120                    if #context.is_first_run_this_tick() {
121                        #rhs_borrow_ident.extend(#rhs);
122                        #rhs_borrow_ident.iter()
123                    } else {
124                        let len = #rhs_borrow_ident.len();
125                        #rhs_borrow_ident.extend(#rhs);
126                        #rhs_borrow_ident[len..].iter()
127                    }
128                    .filter_map(|(k, v2)| #lhs_borrow.table.get(k).map(|v1| (k.clone(), (v1.clone(), v2.clone()))))
129                };
130            },
131            Persistence::Mutable => unreachable!(),
132        };
133
134        let write_iterator_after =
135            if persistences[0] == Persistence::Static || persistences[1] == Persistence::Static {
136                quote_spanned! {op_span=>
137                    // TODO: Probably only need to schedule if #*_borrow.len() > 0?
138                    #context.schedule_subgraph(#context.current_subgraph(), false);
139                }
140            } else {
141                quote_spanned! {op_span=>}
142            };
143
144        Ok(OperatorWriteOutput {
145            write_prologue: quote_spanned! {op_span=>
146                #lhs_prologue
147                #rhs_prologue
148            },
149            write_prologue_after: lhs_prologue_after,
150            write_iterator,
151            write_iterator_after,
152        })
153    },
154};