dfir_lang/graph/ops/
join_fused_lhs.rs
1use quote::{quote_spanned, ToTokens};
2use syn::parse_quote;
3use syn::spanned::Spanned;
4
5use super::join_fused::{make_joindata, parse_argument, parse_persistences, JoinOptions};
6use super::{
7 DelayType, OpInstGenerics, OperatorCategory, OperatorConstraints, OperatorInstance,
8 OperatorWriteOutput, Persistence, PortIndexValue, WriteContextArgs, RANGE_0, RANGE_1,
9};
10use crate::diagnostic::{Diagnostic, Level};
11
12pub const JOIN_FUSED_LHS: OperatorConstraints = OperatorConstraints {
26 name: "join_fused_lhs",
27 categories: &[OperatorCategory::MultiIn],
28 hard_range_inn: &(2..=2),
29 soft_range_inn: &(2..=2),
30 hard_range_out: RANGE_1,
31 soft_range_out: RANGE_1,
32 num_args: 1,
33 persistence_args: &(0..=2),
34 type_args: RANGE_0,
35 is_external_input: false,
36 has_singleton_output: false,
37 flo_type: None,
38 ports_inn: Some(|| super::PortListSpec::Fixed(parse_quote! { 0, 1 })),
39 ports_out: None,
40 input_delaytype_fn: |idx| match idx {
41 PortIndexValue::Int(path) if "0" == path.to_token_stream().to_string() => {
42 Some(DelayType::Stratum)
43 }
44 _ => None,
45 },
46 write_fn: |wc @ &WriteContextArgs {
47 context,
48 df_ident,
49 op_span,
50 ident,
51 inputs,
52 is_pull,
53 op_inst:
54 OperatorInstance {
55 generics:
56 OpInstGenerics {
57 persistence_args, ..
58 },
59 ..
60 },
61 arguments,
62 ..
63 },
64 diagnostics| {
65 assert!(is_pull);
66
67 let arg0_span = arguments[0].span();
68
69 let persistences = parse_persistences(persistence_args);
70
71 let lhs_join_options =
72 parse_argument(&arguments[0]).map_err(|err| diagnostics.push(err))?;
73
74 let (lhs_prologue, lhs_pre_write_iter, lhs_borrow) =
75 make_joindata(wc, persistences[0], &lhs_join_options, "lhs")
76 .map_err(|err| diagnostics.push(err))?;
77
78 let rhs_joindata_ident = wc.make_ident("rhs_joindata");
79 let rhs_borrow_ident = wc.make_ident("rhs_joindata_borrow_ident");
80
81 let write_prologue_rhs = match persistences[1] {
82 Persistence::None | Persistence::Tick => quote_spanned! {op_span=>},
83 Persistence::Static => quote_spanned! {op_span=>
84 let #rhs_joindata_ident = #df_ident.add_state(::std::cell::RefCell::new(
85 ::std::vec::Vec::new()
86 ));
87 },
88 Persistence::Mutable => {
89 diagnostics.push(Diagnostic::spanned(
90 op_span,
91 Level::Error,
92 "An implementation of 'mutable does not exist",
93 ));
94 return Err(());
95 }
96 };
97
98 let write_prologue = quote_spanned! {op_span=>
99 #lhs_prologue
100 #write_prologue_rhs
101 };
102
103 let lhs = &inputs[0];
104 let rhs = &inputs[1];
105
106 let lhs_fold_or_reduce_into_from = match lhs_join_options {
107 JoinOptions::FoldFrom(lhs_from, lhs_fold) => quote_spanned! {arg0_span=>
108 #lhs_borrow.fold_into(#lhs, #lhs_fold, #lhs_from);
109 },
110 JoinOptions::Fold(lhs_default, lhs_fold) => quote_spanned! {arg0_span=>
111 #lhs_borrow.fold_into(#lhs, #lhs_fold, #lhs_default);
112 },
113 JoinOptions::Reduce(lhs_reduce) => quote_spanned! {arg0_span=>
114 #lhs_borrow.reduce_into(#lhs, #lhs_reduce);
115 },
116 };
117
118 let write_iterator = match persistences[1] {
119 Persistence::None | Persistence::Tick => quote_spanned! {op_span=>
120 #lhs_pre_write_iter
121
122 let #ident = {
123 #lhs_fold_or_reduce_into_from
124
125 #[allow(clippy::clone_on_copy)]
126 #rhs.filter_map(|(k, v2)| #lhs_borrow.table.get(&k).map(|v1| (k, (v1.clone(), v2.clone()))))
127 };
128 },
129 Persistence::Static => quote_spanned! {op_span=>
130 #lhs_pre_write_iter
131 let mut #rhs_borrow_ident = unsafe {
132 #context.state_ref_unchecked(#rhs_joindata_ident)
134 }.borrow_mut();
135
136 let #ident = {
137 #lhs_fold_or_reduce_into_from
138
139 #[allow(clippy::clone_on_copy)]
140 #[allow(suspicious_double_ref_op)]
141 if #context.is_first_run_this_tick() {
142 #rhs_borrow_ident.extend(#rhs);
143 #rhs_borrow_ident.iter()
144 } else {
145 let len = #rhs_borrow_ident.len();
146 #rhs_borrow_ident.extend(#rhs);
147 #rhs_borrow_ident[len..].iter()
148 }
149 .filter_map(|(k, v2)| #lhs_borrow.table.get(k).map(|v1| (k.clone(), (v1.clone(), v2.clone()))))
150 };
151 },
152 Persistence::Mutable => {
153 diagnostics.push(Diagnostic::spanned(
154 op_span,
155 Level::Error,
156 "An implementation of 'mutable does not exist",
157 ));
158 return Err(());
159 }
160 };
161
162 let write_iterator_after =
163 if persistences[0] == Persistence::Static || persistences[1] == Persistence::Static {
164 quote_spanned! {op_span=>
165 #context.schedule_subgraph(#context.current_subgraph(), false);
167 }
168 } else {
169 quote_spanned! {op_span=>}
170 };
171
172 Ok(OperatorWriteOutput {
173 write_prologue,
174 write_iterator,
175 write_iterator_after,
176 })
177 },
178};