dfir_lang/graph/ops/join_fused_lhs.rs

1use quote::{quote_spanned, ToTokens};
2use syn::parse_quote;
3use syn::spanned::Spanned;
4
5use super::join_fused::{make_joindata, parse_argument, parse_persistences, JoinOptions};
6use super::{
7    DelayType, OpInstGenerics, OperatorCategory, OperatorConstraints, OperatorInstance,
8    OperatorWriteOutput, Persistence, PortIndexValue, WriteContextArgs, RANGE_0, RANGE_1,
9};
10use crate::diagnostic::{Diagnostic, Level};
11
12/// See `join_fused`
13///
14/// This operator is identical to `join_fused` except that the right hand side input `1` is a regular `join_multiset` input.
15///
16/// This means that `join_fused_lhs` only takes one argument input, which is the reducing/folding operation for the left hand side only.
17///
18/// For example:
19/// ```dfir
20/// source_iter(vec![("key", 0), ("key", 1), ("key", 2)]) -> [0]my_join;
21/// source_iter(vec![("key", 2), ("key", 3)]) -> [1]my_join;
22/// my_join = join_fused_lhs(Reduce(|x, y| *x += y))
23///     -> assert_eq([("key", (3, 2)), ("key", (3, 3))]);
24/// ```
25pub const JOIN_FUSED_LHS: OperatorConstraints = OperatorConstraints {
26    name: "join_fused_lhs",
27    categories: &[OperatorCategory::MultiIn],
28    hard_range_inn: &(2..=2),
29    soft_range_inn: &(2..=2),
30    hard_range_out: RANGE_1,
31    soft_range_out: RANGE_1,
32    num_args: 1,
33    persistence_args: &(0..=2),
34    type_args: RANGE_0,
35    is_external_input: false,
36    has_singleton_output: false,
37    flo_type: None,
38    ports_inn: Some(|| super::PortListSpec::Fixed(parse_quote! { 0, 1 })),
39    ports_out: None,
40    input_delaytype_fn: |idx| match idx {
41        PortIndexValue::Int(path) if "0" == path.to_token_stream().to_string() => {
42            Some(DelayType::Stratum)
43        }
44        _ => None,
45    },
46    write_fn: |wc @ &WriteContextArgs {
47                   context,
48                   df_ident,
49                   op_span,
50                   ident,
51                   inputs,
52                   is_pull,
53                   op_inst:
54                       OperatorInstance {
55                           generics:
56                               OpInstGenerics {
57                                   persistence_args, ..
58                               },
59                           ..
60                       },
61                   arguments,
62                   ..
63               },
64               diagnostics| {
65        assert!(is_pull);
66
67        let arg0_span = arguments[0].span();
68
69        let persistences = parse_persistences(persistence_args);
70
71        let lhs_join_options =
72            parse_argument(&arguments[0]).map_err(|err| diagnostics.push(err))?;
73
74        let (lhs_prologue, lhs_pre_write_iter, lhs_borrow) =
75            make_joindata(wc, persistences[0], &lhs_join_options, "lhs")
76                .map_err(|err| diagnostics.push(err))?;
77
78        let rhs_joindata_ident = wc.make_ident("rhs_joindata");
79        let rhs_borrow_ident = wc.make_ident("rhs_joindata_borrow_ident");
80
81        let write_prologue_rhs = match persistences[1] {
82            Persistence::None | Persistence::Tick => quote_spanned! {op_span=>},
83            Persistence::Static => quote_spanned! {op_span=>
84                let #rhs_joindata_ident = #df_ident.add_state(::std::cell::RefCell::new(
85                    ::std::vec::Vec::new()
86                ));
87            },
88            Persistence::Mutable => {
89                diagnostics.push(Diagnostic::spanned(
90                    op_span,
91                    Level::Error,
92                    "An implementation of 'mutable does not exist",
93                ));
94                return Err(());
95            }
96        };
97
98        let write_prologue = quote_spanned! {op_span=>
99            #lhs_prologue
100            #write_prologue_rhs
101        };
102
103        let lhs = &inputs[0];
104        let rhs = &inputs[1];
105
106        let lhs_fold_or_reduce_into_from = match lhs_join_options {
107            JoinOptions::FoldFrom(lhs_from, lhs_fold) => quote_spanned! {arg0_span=>
108                #lhs_borrow.fold_into(#lhs, #lhs_fold, #lhs_from);
109            },
110            JoinOptions::Fold(lhs_default, lhs_fold) => quote_spanned! {arg0_span=>
111                #lhs_borrow.fold_into(#lhs, #lhs_fold, #lhs_default);
112            },
113            JoinOptions::Reduce(lhs_reduce) => quote_spanned! {arg0_span=>
114                #lhs_borrow.reduce_into(#lhs, #lhs_reduce);
115            },
116        };
117
118        let write_iterator = match persistences[1] {
119            Persistence::None | Persistence::Tick => quote_spanned! {op_span=>
120                #lhs_pre_write_iter
121
122                let #ident = {
123                    #lhs_fold_or_reduce_into_from
124
125                    #[allow(clippy::clone_on_copy)]
126                    #rhs.filter_map(|(k, v2)| #lhs_borrow.table.get(&k).map(|v1| (k, (v1.clone(), v2.clone()))))
127                };
128            },
129            Persistence::Static => quote_spanned! {op_span=>
130                #lhs_pre_write_iter
131                let mut #rhs_borrow_ident = unsafe {
132                    // SAFETY: handle from `#df_ident.add_state(..)`.
133                    #context.state_ref_unchecked(#rhs_joindata_ident)
134                }.borrow_mut();
135
136                let #ident = {
137                    #lhs_fold_or_reduce_into_from
138
139                    #[allow(clippy::clone_on_copy)]
140                    #[allow(suspicious_double_ref_op)]
141                    if #context.is_first_run_this_tick() {
142                        #rhs_borrow_ident.extend(#rhs);
143                        #rhs_borrow_ident.iter()
144                    } else {
145                        let len = #rhs_borrow_ident.len();
146                        #rhs_borrow_ident.extend(#rhs);
147                        #rhs_borrow_ident[len..].iter()
148                    }
149                    .filter_map(|(k, v2)| #lhs_borrow.table.get(k).map(|v1| (k.clone(), (v1.clone(), v2.clone()))))
150                };
151            },
152            Persistence::Mutable => {
153                diagnostics.push(Diagnostic::spanned(
154                    op_span,
155                    Level::Error,
156                    "An implementation of 'mutable does not exist",
157                ));
158                return Err(());
159            }
160        };
161
162        let write_iterator_after =
163            if persistences[0] == Persistence::Static || persistences[1] == Persistence::Static {
164                quote_spanned! {op_span=>
165                    // TODO: Probably only need to schedule if #*_borrow.len() > 0?
166                    #context.schedule_subgraph(#context.current_subgraph(), false);
167                }
168            } else {
169                quote_spanned! {op_span=>}
170            };
171
172        Ok(OperatorWriteOutput {
173            write_prologue,
174            write_iterator,
175            write_iterator_after,
176        })
177    },
178};