1use proc_macro2::TokenStream;
2use quote::{ToTokens, quote_spanned};
3use syn::spanned::Spanned;
4use syn::{Expr, ExprCall, parse_quote};
5
6use super::{
7 DelayType, OpInstGenerics, OperatorCategory, OperatorConstraints, OperatorInstance,
8 OperatorWriteOutput, Persistence, RANGE_0, RANGE_1, WriteContextArgs,
9};
10use crate::diagnostic::{Diagnostic, Level};
11
12pub const JOIN_FUSED: OperatorConstraints = OperatorConstraints {
92 name: "join_fused",
93 categories: &[OperatorCategory::MultiIn],
94 hard_range_inn: &(2..=2),
95 soft_range_inn: &(2..=2),
96 hard_range_out: RANGE_1,
97 soft_range_out: RANGE_1,
98 num_args: 2,
99 persistence_args: &(0..=2),
100 type_args: RANGE_0,
101 is_external_input: false,
102 has_singleton_output: false,
103 flo_type: None,
104 ports_inn: Some(|| super::PortListSpec::Fixed(parse_quote! { 0, 1 })),
105 ports_out: None,
106 input_delaytype_fn: |_| Some(DelayType::Stratum),
107 write_fn: |wc @ &WriteContextArgs {
108 context,
109 op_span,
110 ident,
111 inputs,
112 is_pull,
113 op_inst:
114 OperatorInstance {
115 generics:
116 OpInstGenerics {
117 persistence_args, ..
118 },
119 ..
120 },
121 arguments,
122 ..
123 },
124 diagnostics| {
125 assert!(is_pull);
126
127 let persistences = parse_persistences(persistence_args);
128
129 let lhs_join_options =
130 parse_argument(&arguments[0]).map_err(|err| diagnostics.push(err))?;
131 let rhs_join_options =
132 parse_argument(&arguments[1]).map_err(|err| diagnostics.push(err))?;
133
134 let (lhs_prologue, lhs_pre_write_iter, lhs_borrow) =
135 make_joindata(wc, persistences[0], &lhs_join_options, "lhs")
136 .map_err(|err| diagnostics.push(err))?;
137
138 let (rhs_prologue, rhs_pre_write_iter, rhs_borrow) =
139 make_joindata(wc, persistences[1], &rhs_join_options, "rhs")
140 .map_err(|err| diagnostics.push(err))?;
141
142 let write_prologue = quote_spanned! {op_span=>
143 #lhs_prologue
144 #rhs_prologue
145 };
146
147 let lhs = &inputs[0];
148 let rhs = &inputs[1];
149
150 let arg0_span = arguments[0].span();
151 let arg1_span = arguments[1].span();
152
153 let lhs_tokens = match lhs_join_options {
154 JoinOptions::FoldFrom(lhs_from, lhs_fold) => quote_spanned! {arg0_span=>
155 #lhs_borrow.fold_into(#lhs, #lhs_fold, #lhs_from);
156 },
157 JoinOptions::Fold(lhs_default, lhs_fold) => quote_spanned! {arg0_span=>
158 #lhs_borrow.fold_into(#lhs, #lhs_fold, #lhs_default);
159 },
160 JoinOptions::Reduce(lhs_reduce) => quote_spanned! {arg0_span=>
161 #lhs_borrow.reduce_into(#lhs, #lhs_reduce);
162 },
163 };
164
165 let rhs_tokens = match rhs_join_options {
166 JoinOptions::FoldFrom(rhs_from, rhs_fold) => quote_spanned! {arg0_span=>
167 #rhs_borrow.fold_into(#rhs, #rhs_fold, #rhs_from);
168 },
169 JoinOptions::Fold(rhs_default, rhs_fold) => quote_spanned! {arg1_span=>
170 #rhs_borrow.fold_into(#rhs, #rhs_fold, #rhs_default);
171 },
172 JoinOptions::Reduce(rhs_reduce) => quote_spanned! {arg1_span=>
173 #rhs_borrow.reduce_into(#rhs, #rhs_reduce);
174 },
175 };
176
177 let write_iterator = quote_spanned! {op_span=>
179 #lhs_pre_write_iter
180 #rhs_pre_write_iter
181
182 let #ident = {
183 #lhs_tokens
184 #rhs_tokens
185
186 #[allow(clippy::clone_on_copy)]
188 #[allow(suspicious_double_ref_op)]
189 #rhs_borrow
190 .table
191 .iter()
192 .filter_map(|(k, v2)| #lhs_borrow.table.get(k).map(|v1| (k.clone(), (v1.clone(), v2.clone()))))
193 };
194 };
195
196 let write_iterator_after =
197 if persistences[0] == Persistence::Static || persistences[1] == Persistence::Static {
198 quote_spanned! {op_span=>
199 #context.schedule_subgraph(#context.current_subgraph(), false);
201 }
202 } else {
203 quote_spanned! {op_span=>}
204 };
205
206 Ok(OperatorWriteOutput {
207 write_prologue,
208 write_iterator,
209 write_iterator_after,
210 })
211 },
212};
213
214pub(crate) enum JoinOptions<'a> {
215 FoldFrom(&'a Expr, &'a Expr),
216 Fold(&'a Expr, &'a Expr),
217 Reduce(&'a Expr),
218}
219
220pub(crate) fn parse_argument(arg: &Expr) -> Result<JoinOptions, Diagnostic> {
221 let Expr::Call(ExprCall {
222 attrs: _,
223 func,
224 paren_token: _,
225 args,
226 }) = arg
227 else {
228 return Err(Diagnostic::spanned(
229 arg.span(),
230 Level::Error,
231 format!("Argument must be a function call: {arg:?}"),
232 ));
233 };
234
235 let mut elems = args.iter();
236 let func_name = func.to_token_stream().to_string();
237
238 match func_name.as_str() {
239 "Fold" => match (elems.next(), elems.next()) {
240 (Some(default), Some(fold)) => Ok(JoinOptions::Fold(default, fold)),
241 _ => Err(Diagnostic::spanned(
242 args.span(),
243 Level::Error,
244 format!(
245 "Fold requires two arguments, first is the default function, second is the folding function: {func:?}"
246 ),
247 )),
248 },
249 "FoldFrom" => match (elems.next(), elems.next()) {
250 (Some(from), Some(fold)) => Ok(JoinOptions::FoldFrom(from, fold)),
251 _ => Err(Diagnostic::spanned(
252 args.span(),
253 Level::Error,
254 format!(
255 "FoldFrom requires two arguments, first is the From function, second is the folding function: {func:?}"
256 ),
257 )),
258 },
259 "Reduce" => match elems.next() {
260 Some(reduce) => Ok(JoinOptions::Reduce(reduce)),
261 _ => Err(Diagnostic::spanned(
262 args.span(),
263 Level::Error,
264 format!("Reduce requires one argument, the reducing function: {func:?}"),
265 )),
266 },
267 _ => Err(Diagnostic::spanned(
268 func.span(),
269 Level::Error,
270 format!("Unknown summarizing function: {func:?}"),
271 )),
272 }
273}
274
275pub(crate) fn make_joindata(
276 wc: &WriteContextArgs,
277 persistence: Persistence,
278 join_options: &JoinOptions<'_>,
279 side: &str,
280) -> Result<(TokenStream, TokenStream, TokenStream), Diagnostic> {
281 let joindata_ident = wc.make_ident(format!("joindata_{}", side));
282 let borrow_ident = wc.make_ident(format!("joindata_{}_borrow", side));
283
284 let &WriteContextArgs {
285 context,
286 df_ident,
287 root,
288 op_span,
289 ..
290 } = wc;
291
292 let join_type = match *join_options {
293 JoinOptions::FoldFrom(_, _) => {
294 quote_spanned!(op_span=> #root::compiled::pull::HalfJoinStateFoldFrom)
295 }
296 JoinOptions::Fold(_, _) => {
297 quote_spanned!(op_span=> #root::compiled::pull::HalfJoinStateFold)
298 }
299 JoinOptions::Reduce(_) => {
300 quote_spanned!(op_span=> #root::compiled::pull::HalfJoinStateReduce)
301 }
302 };
303
304 let (prologue, pre_write_iter, borrow) = match persistence {
305 Persistence::None => (
306 Default::default(),
307 quote_spanned! {op_span=>
308 let mut #borrow_ident = #join_type::default();
309 },
310 quote_spanned! {op_span=>
311 #borrow_ident
312 },
313 ),
314 Persistence::Tick => (
315 quote_spanned! {op_span=>
316 let #joindata_ident = #df_ident.add_state(std::cell::RefCell::new(
317 #root::util::monotonic_map::MonotonicMap::new_init(
318 #join_type::default()
319 )
320 ));
321 },
322 quote_spanned! {op_span=>
323 let mut #borrow_ident = unsafe {
324 #context.state_ref_unchecked(#joindata_ident)
326 }.borrow_mut();
327 },
328 quote_spanned! {op_span=>
329 #borrow_ident.get_mut_clear(#context.current_tick())
330 },
331 ),
332 Persistence::Static => (
333 quote_spanned! {op_span=>
334 let #joindata_ident = #df_ident.add_state(std::cell::RefCell::new(
335 #join_type::default()
336 ));
337 },
338 quote_spanned! {op_span=>
339 let mut #borrow_ident = unsafe {
340 #context.state_ref_unchecked(#joindata_ident)
342 }.borrow_mut();
343 },
344 quote_spanned! {op_span=>
345 #borrow_ident
346 },
347 ),
348 Persistence::Mutable => {
349 return Err(Diagnostic::spanned(
350 op_span,
351 Level::Error,
352 "An implementation of 'mutable does not exist",
353 ));
354 }
355 };
356 Ok((prologue, pre_write_iter, borrow))
357}
358
359pub(crate) fn parse_persistences(persistences: &[Persistence]) -> [Persistence; 2] {
360 match persistences {
361 [] => [Persistence::Tick, Persistence::Tick],
362 [a] => [*a, *a],
363 [a, b] => [*a, *b],
364 _ => panic!("Too many persistences: {persistences:?}"),
365 }
366}