1use proc_macro2::TokenStream;
2use quote::{ToTokens, quote_spanned};
3use syn::spanned::Spanned;
4use syn::{Expr, ExprCall, parse_quote};
5
6use super::{
7 DelayType, OperatorCategory, OperatorConstraints, OperatorWriteOutput, Persistence, RANGE_0,
8 RANGE_1, WriteContextArgs,
9};
10use crate::diagnostic::{Diagnostic, Level};
11
12pub const JOIN_FUSED: OperatorConstraints = OperatorConstraints {
92 name: "join_fused",
93 categories: &[OperatorCategory::MultiIn],
94 hard_range_inn: &(2..=2),
95 soft_range_inn: &(2..=2),
96 hard_range_out: RANGE_1,
97 soft_range_out: RANGE_1,
98 num_args: 2,
99 persistence_args: &(0..=2),
100 type_args: RANGE_0,
101 is_external_input: false,
102 has_singleton_output: false,
103 flo_type: None,
104 ports_inn: Some(|| super::PortListSpec::Fixed(parse_quote! { 0, 1 })),
105 ports_out: None,
106 input_delaytype_fn: |_| Some(DelayType::Stratum),
107 write_fn: |wc @ &WriteContextArgs {
108 context,
109 op_span,
110 ident,
111 inputs,
112 is_pull,
113 arguments,
114 ..
115 },
116 diagnostics| {
117 assert!(is_pull);
118
119 let persistences: [_; 2] = wc.persistence_args_disallow_mutable(diagnostics);
120
121 let lhs_join_options =
122 parse_argument(&arguments[0]).map_err(|err| diagnostics.push(err))?;
123 let rhs_join_options =
124 parse_argument(&arguments[1]).map_err(|err| diagnostics.push(err))?;
125
126 let (lhs_prologue, lhs_prologue_after, lhs_pre_write_iter, lhs_borrow) =
127 make_joindata(wc, persistences[0], &lhs_join_options, "lhs")
128 .map_err(|err| diagnostics.push(err))?;
129
130 let (rhs_prologue, rhs_prologue_after, rhs_pre_write_iter, rhs_borrow) =
131 make_joindata(wc, persistences[1], &rhs_join_options, "rhs")
132 .map_err(|err| diagnostics.push(err))?;
133
134 let lhs = &inputs[0];
135 let rhs = &inputs[1];
136
137 let arg0_span = arguments[0].span();
138 let arg1_span = arguments[1].span();
139
140 let lhs_tokens = match lhs_join_options {
141 JoinOptions::FoldFrom(lhs_from, lhs_fold) => quote_spanned! {arg0_span=>
142 #lhs_borrow.fold_into(#lhs, #lhs_fold, #lhs_from);
143 },
144 JoinOptions::Fold(lhs_default, lhs_fold) => quote_spanned! {arg0_span=>
145 #lhs_borrow.fold_into(#lhs, #lhs_fold, #lhs_default);
146 },
147 JoinOptions::Reduce(lhs_reduce) => quote_spanned! {arg0_span=>
148 #lhs_borrow.reduce_into(#lhs, #lhs_reduce);
149 },
150 };
151
152 let rhs_tokens = match rhs_join_options {
153 JoinOptions::FoldFrom(rhs_from, rhs_fold) => quote_spanned! {arg0_span=>
154 #rhs_borrow.fold_into(#rhs, #rhs_fold, #rhs_from);
155 },
156 JoinOptions::Fold(rhs_default, rhs_fold) => quote_spanned! {arg1_span=>
157 #rhs_borrow.fold_into(#rhs, #rhs_fold, #rhs_default);
158 },
159 JoinOptions::Reduce(rhs_reduce) => quote_spanned! {arg1_span=>
160 #rhs_borrow.reduce_into(#rhs, #rhs_reduce);
161 },
162 };
163
164 let write_iterator = quote_spanned! {op_span=>
166 #lhs_pre_write_iter
167 #rhs_pre_write_iter
168
169 let #ident = {
170 #lhs_tokens
171 #rhs_tokens
172
173 #[allow(clippy::clone_on_copy)]
175 #[allow(suspicious_double_ref_op)]
176 #rhs_borrow
177 .table
178 .iter()
179 .filter_map(|(k, v2)| #lhs_borrow.table.get(k).map(|v1| (k.clone(), (v1.clone(), v2.clone()))))
180 };
181 };
182
183 let write_iterator_after =
184 if persistences[0] == Persistence::Static || persistences[1] == Persistence::Static {
185 quote_spanned! {op_span=>
186 #context.schedule_subgraph(#context.current_subgraph(), false);
188 }
189 } else {
190 quote_spanned! {op_span=>}
191 };
192
193 Ok(OperatorWriteOutput {
194 write_prologue: quote_spanned! {op_span=>
195 #lhs_prologue
196 #rhs_prologue
197 },
198 write_prologue_after: quote_spanned! {op_span=>
199 #lhs_prologue_after
200 #rhs_prologue_after
201 },
202 write_iterator,
203 write_iterator_after,
204 })
205 },
206};
207
208pub(crate) enum JoinOptions<'a> {
209 FoldFrom(&'a Expr, &'a Expr),
210 Fold(&'a Expr, &'a Expr),
211 Reduce(&'a Expr),
212}
213
214pub(crate) fn parse_argument(arg: &Expr) -> Result<JoinOptions, Diagnostic> {
215 let Expr::Call(ExprCall {
216 attrs: _,
217 func,
218 paren_token: _,
219 args,
220 }) = arg
221 else {
222 return Err(Diagnostic::spanned(
223 arg.span(),
224 Level::Error,
225 format!("Argument must be a function call: {arg:?}"),
226 ));
227 };
228
229 let mut elems = args.iter();
230 let func_name = func.to_token_stream().to_string();
231
232 match func_name.as_str() {
233 "Fold" => match (elems.next(), elems.next()) {
234 (Some(default), Some(fold)) => Ok(JoinOptions::Fold(default, fold)),
235 _ => Err(Diagnostic::spanned(
236 args.span(),
237 Level::Error,
238 format!(
239 "Fold requires two arguments, first is the default function, second is the folding function: {func:?}"
240 ),
241 )),
242 },
243 "FoldFrom" => match (elems.next(), elems.next()) {
244 (Some(from), Some(fold)) => Ok(JoinOptions::FoldFrom(from, fold)),
245 _ => Err(Diagnostic::spanned(
246 args.span(),
247 Level::Error,
248 format!(
249 "FoldFrom requires two arguments, first is the From function, second is the folding function: {func:?}"
250 ),
251 )),
252 },
253 "Reduce" => match elems.next() {
254 Some(reduce) => Ok(JoinOptions::Reduce(reduce)),
255 _ => Err(Diagnostic::spanned(
256 args.span(),
257 Level::Error,
258 format!("Reduce requires one argument, the reducing function: {func:?}"),
259 )),
260 },
261 _ => Err(Diagnostic::spanned(
262 func.span(),
263 Level::Error,
264 format!("Unknown summarizing function: {func:?}"),
265 )),
266 }
267}
268
269pub(crate) fn make_joindata(
271 wc: &WriteContextArgs,
272 persistence: Persistence,
273 join_options: &JoinOptions<'_>,
274 side: &str,
275) -> Result<(TokenStream, TokenStream, TokenStream, TokenStream), Diagnostic> {
276 let joindata_ident = wc.make_ident(format!("joindata_{}", side));
277 let borrow_ident = wc.make_ident(format!("joindata_{}_borrow", side));
278
279 let &WriteContextArgs {
280 context,
281 df_ident,
282 root,
283 op_span,
284 ..
285 } = wc;
286
287 let join_type = match *join_options {
288 JoinOptions::FoldFrom(_, _) => {
289 quote_spanned!(op_span=> #root::compiled::pull::HalfJoinStateFoldFrom)
290 }
291 JoinOptions::Fold(_, _) => {
292 quote_spanned!(op_span=> #root::compiled::pull::HalfJoinStateFold)
293 }
294 JoinOptions::Reduce(_) => {
295 quote_spanned!(op_span=> #root::compiled::pull::HalfJoinStateReduce)
296 }
297 };
298
299 Ok(match persistence {
300 Persistence::None => (
301 Default::default(),
302 Default::default(),
303 quote_spanned! {op_span=>
304 let mut #borrow_ident = #join_type::default();
305 },
306 quote_spanned! {op_span=>
307 #borrow_ident
308 },
309 ),
310 Persistence::Tick | Persistence::Loop | Persistence::Static => {
311 let lifespan = wc.persistence_as_state_lifespan(persistence);
312 (
313 quote_spanned! {op_span=>
314 let #joindata_ident = #df_ident.add_state(::std::cell::RefCell::new(#join_type::default()));
315 },
316 lifespan.map(|lifespan| quote_spanned! {op_span=>
317 #df_ident.set_state_lifespan_hook(#joindata_ident, #lifespan, |rcell| { rcell.take(); });
319 }).unwrap_or_default(),
320 quote_spanned! {op_span=>
321 let mut #borrow_ident = unsafe {
322 #context.state_ref_unchecked(#joindata_ident)
324 }.borrow_mut();
325 },
326 quote_spanned! {op_span=>
327 #borrow_ident
328 },
329 )
330 }
331 Persistence::Mutable => panic!(),
332 })
333}