dfir_lang/graph/ops/
join_fused.rs

1use proc_macro2::TokenStream;
2use quote::{ToTokens, quote_spanned};
3use syn::spanned::Spanned;
4use syn::{Expr, ExprCall, parse_quote};
5
6use super::{
7    DelayType, OperatorCategory, OperatorConstraints, OperatorWriteOutput, Persistence, RANGE_0,
8    RANGE_1, WriteContextArgs,
9};
10use crate::diagnostic::{Diagnostic, Level};
11
12/// > 2 input streams of type `<(K, V1)>` and `<(K, V2)>`, 1 output stream of type `<(K, (V1, V2))>`
13///
14/// `join_fused` takes two arguments, they are the configuration options for the left hand side and right hand side inputs respectively.
15/// There are three available configuration options, they are `Reduce`: if the input type is the same as the accumulator type,
16/// `Fold`: if the input type is different from the accumulator type, and the accumulator type has a sensible default value, and
17/// `FoldFrom`: if the input type is different from the accumulator type, and the accumulator needs to be derived from the first input value.
18/// Examples of all three configuration options are below:
19/// ```dfir,ignore
20/// // Left hand side input will use fold, right hand side input will use reduce,
21/// join_fused(Fold(|| "default value", |x, y| *x += y), Reduce(|x, y| *x -= y))
22///
23/// // Left hand side input will use FoldFrom, and the right hand side input will use Reduce again
24/// join_fused(FoldFrom(|x| "conversion function", |x, y| *x += y), Reduce(|x, y| *x *= y))
25/// ```
26/// The three currently supported fused operator types are `Fold(Fn() -> A, Fn(A, T) -> A)`, `Reduce(Fn(A, A) -> A)`, and `FoldFrom(Fn(T) -> A, Fn(A, T) -> A)`
27///
28/// `join_fused` first performs a fold_keyed/reduce_keyed operation on each input stream before performing joining. See `join()`. There is currently no equivalent for `FoldFrom` in dfir operators.
29///
30/// For example, the following two dfir programs are equivalent, the former would optimize into the latter:
31///
32/// ```dfir
33/// source_iter(vec![("key", 0), ("key", 1), ("key", 2)])
34///     -> reduce_keyed(|x: &mut _, y| *x += y)
35///     -> [0]my_join;
36/// source_iter(vec![("key", 2), ("key", 3)])
37///     -> fold_keyed(|| 1, |x: &mut _, y| *x *= y)
38///     -> [1]my_join;
39/// my_join = join_multiset()
40///     -> assert_eq([("key", (3, 6))]);
41/// ```
42///
43/// ```dfir
44/// source_iter(vec![("key", 0), ("key", 1), ("key", 2)])
45///     -> [0]my_join;
46/// source_iter(vec![("key", 2), ("key", 3)])
47///     -> [1]my_join;
48/// my_join = join_fused(Reduce(|x, y| *x += y), Fold(|| 1, |x, y| *x *= y))
49///     -> assert_eq([("key", (3, 6))]);
50/// ```
51///
52/// Here is an example of using FoldFrom to derive the accumulator from the first value:
53///
54/// ```dfir
55/// source_iter(vec![("key", 0), ("key", 1), ("key", 2)])
56///     -> [0]my_join;
57/// source_iter(vec![("key", 2), ("key", 3)])
58///     -> [1]my_join;
59/// my_join = join_fused(FoldFrom(|x| x + 3, |x, y| *x += y), Fold(|| 1, |x, y| *x *= y))
60///     -> assert_eq([("key", (6, 6))]);
61/// ```
62///
63/// The benefit of this is that the state between the reducing/folding operator and the join is merged together.
64///
65/// `join_fused` follows the same persistence rules as `join` and all other operators. By default, both the left hand side and right hand side are `'tick` persistence. They can be set to `'static` persistence
66/// by specifying `'static` in the type arguments of the operator.
67///
68/// for `join_fused::<'static>`, the operator will replay all _keys_ that the join has ever seen each tick, and not only the new matches from that specific tick.
69/// This means that it behaves identically to if `persist::<'static>()` were placed before the inputs and the persistence of
70/// for example, the two following examples have identical behavior:
71///
72/// ```dfir
73/// source_iter(vec![("key", 0), ("key", 1), ("key", 2)]) -> persist::<'static>() -> [0]my_join;
74/// source_iter(vec![("key", 2)]) -> my_union;
75/// source_iter(vec![("key", 3)]) -> defer_tick() -> my_union;
76/// my_union = union() -> persist::<'static>() -> [1]my_join;
77///
78/// my_join = join_fused(Reduce(|x, y| *x += y), Fold(|| 1, |x, y| *x *= y))
79///     -> assert_eq([("key", (3, 2)), ("key", (3, 6))]);
80/// ```
81///
82/// ```dfir
83/// source_iter(vec![("key", 0), ("key", 1), ("key", 2)]) -> [0]my_join;
84/// source_iter(vec![("key", 2)]) -> my_union;
85/// source_iter(vec![("key", 3)]) -> defer_tick() -> my_union;
86/// my_union = union() -> [1]my_join;
87///
88/// my_join = join_fused::<'static>(Reduce(|x, y| *x += y), Fold(|| 1, |x, y| *x *= y))
89///     -> assert_eq([("key", (3, 2)), ("key", (3, 6))]);
90/// ```
91pub const JOIN_FUSED: OperatorConstraints = OperatorConstraints {
92    name: "join_fused",
93    categories: &[OperatorCategory::MultiIn],
94    hard_range_inn: &(2..=2),
95    soft_range_inn: &(2..=2),
96    hard_range_out: RANGE_1,
97    soft_range_out: RANGE_1,
98    num_args: 2,
99    persistence_args: &(0..=2),
100    type_args: RANGE_0,
101    is_external_input: false,
102    has_singleton_output: false,
103    flo_type: None,
104    ports_inn: Some(|| super::PortListSpec::Fixed(parse_quote! { 0, 1 })),
105    ports_out: None,
106    input_delaytype_fn: |_| Some(DelayType::Stratum),
107    write_fn: |wc @ &WriteContextArgs {
108                   context,
109                   op_span,
110                   ident,
111                   inputs,
112                   is_pull,
113                   arguments,
114                   ..
115               },
116               diagnostics| {
117        assert!(is_pull);
118
119        let persistences: [_; 2] = wc.persistence_args_disallow_mutable(diagnostics);
120
121        let lhs_join_options =
122            parse_argument(&arguments[0]).map_err(|err| diagnostics.push(err))?;
123        let rhs_join_options =
124            parse_argument(&arguments[1]).map_err(|err| diagnostics.push(err))?;
125
126        let (lhs_prologue, lhs_prologue_after, lhs_pre_write_iter, lhs_borrow) =
127            make_joindata(wc, persistences[0], &lhs_join_options, "lhs")
128                .map_err(|err| diagnostics.push(err))?;
129
130        let (rhs_prologue, rhs_prologue_after, rhs_pre_write_iter, rhs_borrow) =
131            make_joindata(wc, persistences[1], &rhs_join_options, "rhs")
132                .map_err(|err| diagnostics.push(err))?;
133
134        let lhs = &inputs[0];
135        let rhs = &inputs[1];
136
137        let arg0_span = arguments[0].span();
138        let arg1_span = arguments[1].span();
139
140        let lhs_tokens = match lhs_join_options {
141            JoinOptions::FoldFrom(lhs_from, lhs_fold) => quote_spanned! {arg0_span=>
142                #lhs_borrow.fold_into(#lhs, #lhs_fold, #lhs_from);
143            },
144            JoinOptions::Fold(lhs_default, lhs_fold) => quote_spanned! {arg0_span=>
145                #lhs_borrow.fold_into(#lhs, #lhs_fold, #lhs_default);
146            },
147            JoinOptions::Reduce(lhs_reduce) => quote_spanned! {arg0_span=>
148                #lhs_borrow.reduce_into(#lhs, #lhs_reduce);
149            },
150        };
151
152        let rhs_tokens = match rhs_join_options {
153            JoinOptions::FoldFrom(rhs_from, rhs_fold) => quote_spanned! {arg0_span=>
154                #rhs_borrow.fold_into(#rhs, #rhs_fold, #rhs_from);
155            },
156            JoinOptions::Fold(rhs_default, rhs_fold) => quote_spanned! {arg1_span=>
157                #rhs_borrow.fold_into(#rhs, #rhs_fold, #rhs_default);
158            },
159            JoinOptions::Reduce(rhs_reduce) => quote_spanned! {arg1_span=>
160                #rhs_borrow.reduce_into(#rhs, #rhs_reduce);
161            },
162        };
163
164        // Since both input arguments are stratum blocking then we don't need to keep track of ticks to avoid emitting the same thing twice in the same tick.
165        let write_iterator = quote_spanned! {op_span=>
166            #lhs_pre_write_iter
167            #rhs_pre_write_iter
168
169            let #ident = {
170                #lhs_tokens
171                #rhs_tokens
172
173                // TODO: start the iterator with the smallest len() table rather than always picking rhs.
174                #[allow(clippy::clone_on_copy)]
175                #[allow(suspicious_double_ref_op)]
176                #rhs_borrow
177                    .table
178                    .iter()
179                    .filter_map(|(k, v2)| #lhs_borrow.table.get(k).map(|v1| (k.clone(), (v1.clone(), v2.clone()))))
180            };
181        };
182
183        let write_iterator_after =
184            if persistences[0] == Persistence::Static || persistences[1] == Persistence::Static {
185                quote_spanned! {op_span=>
186                    // TODO: Probably only need to schedule if #*_borrow.len() > 0?
187                    #context.schedule_subgraph(#context.current_subgraph(), false);
188                }
189            } else {
190                quote_spanned! {op_span=>}
191            };
192
193        Ok(OperatorWriteOutput {
194            write_prologue: quote_spanned! {op_span=>
195                #lhs_prologue
196                #rhs_prologue
197            },
198            write_prologue_after: quote_spanned! {op_span=>
199                #lhs_prologue_after
200                #rhs_prologue_after
201            },
202            write_iterator,
203            write_iterator_after,
204        })
205    },
206};
207
208pub(crate) enum JoinOptions<'a> {
209    FoldFrom(&'a Expr, &'a Expr),
210    Fold(&'a Expr, &'a Expr),
211    Reduce(&'a Expr),
212}
213
214pub(crate) fn parse_argument(arg: &Expr) -> Result<JoinOptions, Diagnostic> {
215    let Expr::Call(ExprCall {
216        attrs: _,
217        func,
218        paren_token: _,
219        args,
220    }) = arg
221    else {
222        return Err(Diagnostic::spanned(
223            arg.span(),
224            Level::Error,
225            format!("Argument must be a function call: {arg:?}"),
226        ));
227    };
228
229    let mut elems = args.iter();
230    let func_name = func.to_token_stream().to_string();
231
232    match func_name.as_str() {
233        "Fold" => match (elems.next(), elems.next()) {
234            (Some(default), Some(fold)) => Ok(JoinOptions::Fold(default, fold)),
235            _ => Err(Diagnostic::spanned(
236                args.span(),
237                Level::Error,
238                format!(
239                    "Fold requires two arguments, first is the default function, second is the folding function: {func:?}"
240                ),
241            )),
242        },
243        "FoldFrom" => match (elems.next(), elems.next()) {
244            (Some(from), Some(fold)) => Ok(JoinOptions::FoldFrom(from, fold)),
245            _ => Err(Diagnostic::spanned(
246                args.span(),
247                Level::Error,
248                format!(
249                    "FoldFrom requires two arguments, first is the From function, second is the folding function: {func:?}"
250                ),
251            )),
252        },
253        "Reduce" => match elems.next() {
254            Some(reduce) => Ok(JoinOptions::Reduce(reduce)),
255            _ => Err(Diagnostic::spanned(
256                args.span(),
257                Level::Error,
258                format!("Reduce requires one argument, the reducing function: {func:?}"),
259            )),
260        },
261        _ => Err(Diagnostic::spanned(
262            func.span(),
263            Level::Error,
264            format!("Unknown summarizing function: {func:?}"),
265        )),
266    }
267}
268
269/// Returns `(prologue, prologue_after, pre_write_iter, borrow)`.
270pub(crate) fn make_joindata(
271    wc: &WriteContextArgs,
272    persistence: Persistence,
273    join_options: &JoinOptions<'_>,
274    side: &str,
275) -> Result<(TokenStream, TokenStream, TokenStream, TokenStream), Diagnostic> {
276    let joindata_ident = wc.make_ident(format!("joindata_{}", side));
277    let borrow_ident = wc.make_ident(format!("joindata_{}_borrow", side));
278
279    let &WriteContextArgs {
280        context,
281        df_ident,
282        root,
283        op_span,
284        ..
285    } = wc;
286
287    let join_type = match *join_options {
288        JoinOptions::FoldFrom(_, _) => {
289            quote_spanned!(op_span=> #root::compiled::pull::HalfJoinStateFoldFrom)
290        }
291        JoinOptions::Fold(_, _) => {
292            quote_spanned!(op_span=> #root::compiled::pull::HalfJoinStateFold)
293        }
294        JoinOptions::Reduce(_) => {
295            quote_spanned!(op_span=> #root::compiled::pull::HalfJoinStateReduce)
296        }
297    };
298
299    Ok(match persistence {
300        Persistence::None => (
301            Default::default(),
302            Default::default(),
303            quote_spanned! {op_span=>
304                let mut #borrow_ident = #join_type::default();
305            },
306            quote_spanned! {op_span=>
307                #borrow_ident
308            },
309        ),
310        Persistence::Tick | Persistence::Loop | Persistence::Static => {
311            let lifespan = wc.persistence_as_state_lifespan(persistence);
312            (
313                quote_spanned! {op_span=>
314                    let #joindata_ident = #df_ident.add_state(::std::cell::RefCell::new(#join_type::default()));
315                },
316                lifespan.map(|lifespan| quote_spanned! {op_span=>
317                    // Reset the value to the initializer fn at the end of each tick/loop execution.
318                    #df_ident.set_state_lifespan_hook(#joindata_ident, #lifespan, |rcell| { rcell.take(); });
319                }).unwrap_or_default(),
320                quote_spanned! {op_span=>
321                    let mut #borrow_ident = unsafe {
322                        // SAFETY: handles from `#df_ident`.
323                        #context.state_ref_unchecked(#joindata_ident)
324                    }.borrow_mut();
325                },
326                quote_spanned! {op_span=>
327                    #borrow_ident
328                },
329            )
330        }
331        Persistence::Mutable => panic!(),
332    })
333}