dfir_lang/graph/ops/
join_fused.rs

1use proc_macro2::TokenStream;
2use quote::{ToTokens, quote_spanned};
3use syn::spanned::Spanned;
4use syn::{Expr, ExprCall, parse_quote};
5
6use super::{
7    DelayType, OpInstGenerics, OperatorCategory, OperatorConstraints, OperatorInstance,
8    OperatorWriteOutput, Persistence, RANGE_0, RANGE_1, WriteContextArgs,
9};
10use crate::diagnostic::{Diagnostic, Level};
11
12/// > 2 input streams of type `<(K, V1)>` and `<(K, V2)>`, 1 output stream of type `<(K, (V1, V2))>`
13///
14/// `join_fused` takes two arguments, they are the configuration options for the left hand side and right hand side inputs respectively.
15/// There are three available configuration options, they are `Reduce`: if the input type is the same as the accumulator type,
16/// `Fold`: if the input type is different from the accumulator type, and the accumulator type has a sensible default value, and
17/// `FoldFrom`: if the input type is different from the accumulator type, and the accumulator needs to be derived from the first input value.
18/// Examples of all three configuration options are below:
19/// ```dfir,ignore
20/// // Left hand side input will use fold, right hand side input will use reduce,
21/// join_fused(Fold(|| "default value", |x, y| *x += y), Reduce(|x, y| *x -= y))
22///
23/// // Left hand side input will use FoldFrom, and the right hand side input will use Reduce again
24/// join_fused(FoldFrom(|x| "conversion function", |x, y| *x += y), Reduce(|x, y| *x *= y))
25/// ```
26/// The three currently supported fused operator types are `Fold(Fn() -> A, Fn(A, T) -> A)`, `Reduce(Fn(A, A) -> A)`, and `FoldFrom(Fn(T) -> A, Fn(A, T) -> A)`
27///
28/// `join_fused` first performs a fold_keyed/reduce_keyed operation on each input stream before performing joining. See `join()`. There is currently no equivalent for `FoldFrom` in dfir operators.
29///
30/// For example, the following two dfir programs are equivalent, the former would optimize into the latter:
31///
32/// ```dfir
33/// source_iter(vec![("key", 0), ("key", 1), ("key", 2)])
34///     -> reduce_keyed(|x: &mut _, y| *x += y)
35///     -> [0]my_join;
36/// source_iter(vec![("key", 2), ("key", 3)])
37///     -> fold_keyed(|| 1, |x: &mut _, y| *x *= y)
38///     -> [1]my_join;
39/// my_join = join_multiset()
40///     -> assert_eq([("key", (3, 6))]);
41/// ```
42///
43/// ```dfir
44/// source_iter(vec![("key", 0), ("key", 1), ("key", 2)])
45///     -> [0]my_join;
46/// source_iter(vec![("key", 2), ("key", 3)])
47///     -> [1]my_join;
48/// my_join = join_fused(Reduce(|x, y| *x += y), Fold(|| 1, |x, y| *x *= y))
49///     -> assert_eq([("key", (3, 6))]);
50/// ```
51///
52/// Here is an example of using FoldFrom to derive the accumulator from the first value:
53///
54/// ```dfir
55/// source_iter(vec![("key", 0), ("key", 1), ("key", 2)])
56///     -> [0]my_join;
57/// source_iter(vec![("key", 2), ("key", 3)])
58///     -> [1]my_join;
59/// my_join = join_fused(FoldFrom(|x| x + 3, |x, y| *x += y), Fold(|| 1, |x, y| *x *= y))
60///     -> assert_eq([("key", (6, 6))]);
61/// ```
62///
63/// The benefit of this is that the state between the reducing/folding operator and the join is merged together.
64///
65/// `join_fused` follows the same persistence rules as `join` and all other operators. By default, both the left hand side and right hand side are `'tick` persistence. They can be set to `'static` persistence
66/// by specifying `'static` in the type arguments of the operator.
67///
68/// for `join_fused::<'static>`, the operator will replay all _keys_ that the join has ever seen each tick, and not only the new matches from that specific tick.
69/// This means that it behaves identically to if `persist::<'static>()` were placed before the inputs and the persistence of
70/// for example, the two following examples have identical behavior:
71///
72/// ```dfir
73/// source_iter(vec![("key", 0), ("key", 1), ("key", 2)]) -> persist::<'static>() -> [0]my_join;
74/// source_iter(vec![("key", 2)]) -> my_union;
75/// source_iter(vec![("key", 3)]) -> defer_tick() -> my_union;
76/// my_union = union() -> persist::<'static>() -> [1]my_join;
77///
78/// my_join = join_fused(Reduce(|x, y| *x += y), Fold(|| 1, |x, y| *x *= y))
79///     -> assert_eq([("key", (3, 2)), ("key", (3, 6))]);
80/// ```
81///
82/// ```dfir
83/// source_iter(vec![("key", 0), ("key", 1), ("key", 2)]) -> [0]my_join;
84/// source_iter(vec![("key", 2)]) -> my_union;
85/// source_iter(vec![("key", 3)]) -> defer_tick() -> my_union;
86/// my_union = union() -> [1]my_join;
87///
88/// my_join = join_fused::<'static>(Reduce(|x, y| *x += y), Fold(|| 1, |x, y| *x *= y))
89///     -> assert_eq([("key", (3, 2)), ("key", (3, 6))]);
90/// ```
91pub const JOIN_FUSED: OperatorConstraints = OperatorConstraints {
92    name: "join_fused",
93    categories: &[OperatorCategory::MultiIn],
94    hard_range_inn: &(2..=2),
95    soft_range_inn: &(2..=2),
96    hard_range_out: RANGE_1,
97    soft_range_out: RANGE_1,
98    num_args: 2,
99    persistence_args: &(0..=2),
100    type_args: RANGE_0,
101    is_external_input: false,
102    has_singleton_output: false,
103    flo_type: None,
104    ports_inn: Some(|| super::PortListSpec::Fixed(parse_quote! { 0, 1 })),
105    ports_out: None,
106    input_delaytype_fn: |_| Some(DelayType::Stratum),
107    write_fn: |wc @ &WriteContextArgs {
108                   context,
109                   op_span,
110                   ident,
111                   inputs,
112                   is_pull,
113                   op_inst:
114                       OperatorInstance {
115                           generics:
116                               OpInstGenerics {
117                                   persistence_args, ..
118                               },
119                           ..
120                       },
121                   arguments,
122                   ..
123               },
124               diagnostics| {
125        assert!(is_pull);
126
127        let persistences = parse_persistences(persistence_args);
128
129        let lhs_join_options =
130            parse_argument(&arguments[0]).map_err(|err| diagnostics.push(err))?;
131        let rhs_join_options =
132            parse_argument(&arguments[1]).map_err(|err| diagnostics.push(err))?;
133
134        let (lhs_prologue, lhs_pre_write_iter, lhs_borrow) =
135            make_joindata(wc, persistences[0], &lhs_join_options, "lhs")
136                .map_err(|err| diagnostics.push(err))?;
137
138        let (rhs_prologue, rhs_pre_write_iter, rhs_borrow) =
139            make_joindata(wc, persistences[1], &rhs_join_options, "rhs")
140                .map_err(|err| diagnostics.push(err))?;
141
142        let write_prologue = quote_spanned! {op_span=>
143            #lhs_prologue
144            #rhs_prologue
145        };
146
147        let lhs = &inputs[0];
148        let rhs = &inputs[1];
149
150        let arg0_span = arguments[0].span();
151        let arg1_span = arguments[1].span();
152
153        let lhs_tokens = match lhs_join_options {
154            JoinOptions::FoldFrom(lhs_from, lhs_fold) => quote_spanned! {arg0_span=>
155                #lhs_borrow.fold_into(#lhs, #lhs_fold, #lhs_from);
156            },
157            JoinOptions::Fold(lhs_default, lhs_fold) => quote_spanned! {arg0_span=>
158                #lhs_borrow.fold_into(#lhs, #lhs_fold, #lhs_default);
159            },
160            JoinOptions::Reduce(lhs_reduce) => quote_spanned! {arg0_span=>
161                #lhs_borrow.reduce_into(#lhs, #lhs_reduce);
162            },
163        };
164
165        let rhs_tokens = match rhs_join_options {
166            JoinOptions::FoldFrom(rhs_from, rhs_fold) => quote_spanned! {arg0_span=>
167                #rhs_borrow.fold_into(#rhs, #rhs_fold, #rhs_from);
168            },
169            JoinOptions::Fold(rhs_default, rhs_fold) => quote_spanned! {arg1_span=>
170                #rhs_borrow.fold_into(#rhs, #rhs_fold, #rhs_default);
171            },
172            JoinOptions::Reduce(rhs_reduce) => quote_spanned! {arg1_span=>
173                #rhs_borrow.reduce_into(#rhs, #rhs_reduce);
174            },
175        };
176
177        // Since both input arguments are stratum blocking then we don't need to keep track of ticks to avoid emitting the same thing twice in the same tick.
178        let write_iterator = quote_spanned! {op_span=>
179            #lhs_pre_write_iter
180            #rhs_pre_write_iter
181
182            let #ident = {
183                #lhs_tokens
184                #rhs_tokens
185
186                // TODO: start the iterator with the smallest len() table rather than always picking rhs.
187                #[allow(clippy::clone_on_copy)]
188                #[allow(suspicious_double_ref_op)]
189                #rhs_borrow
190                    .table
191                    .iter()
192                    .filter_map(|(k, v2)| #lhs_borrow.table.get(k).map(|v1| (k.clone(), (v1.clone(), v2.clone()))))
193            };
194        };
195
196        let write_iterator_after =
197            if persistences[0] == Persistence::Static || persistences[1] == Persistence::Static {
198                quote_spanned! {op_span=>
199                    // TODO: Probably only need to schedule if #*_borrow.len() > 0?
200                    #context.schedule_subgraph(#context.current_subgraph(), false);
201                }
202            } else {
203                quote_spanned! {op_span=>}
204            };
205
206        Ok(OperatorWriteOutput {
207            write_prologue,
208            write_iterator,
209            write_iterator_after,
210        })
211    },
212};
213
214pub(crate) enum JoinOptions<'a> {
215    FoldFrom(&'a Expr, &'a Expr),
216    Fold(&'a Expr, &'a Expr),
217    Reduce(&'a Expr),
218}
219
220pub(crate) fn parse_argument(arg: &Expr) -> Result<JoinOptions, Diagnostic> {
221    let Expr::Call(ExprCall {
222        attrs: _,
223        func,
224        paren_token: _,
225        args,
226    }) = arg
227    else {
228        return Err(Diagnostic::spanned(
229            arg.span(),
230            Level::Error,
231            format!("Argument must be a function call: {arg:?}"),
232        ));
233    };
234
235    let mut elems = args.iter();
236    let func_name = func.to_token_stream().to_string();
237
238    match func_name.as_str() {
239        "Fold" => match (elems.next(), elems.next()) {
240            (Some(default), Some(fold)) => Ok(JoinOptions::Fold(default, fold)),
241            _ => Err(Diagnostic::spanned(
242                args.span(),
243                Level::Error,
244                format!(
245                    "Fold requires two arguments, first is the default function, second is the folding function: {func:?}"
246                ),
247            )),
248        },
249        "FoldFrom" => match (elems.next(), elems.next()) {
250            (Some(from), Some(fold)) => Ok(JoinOptions::FoldFrom(from, fold)),
251            _ => Err(Diagnostic::spanned(
252                args.span(),
253                Level::Error,
254                format!(
255                    "FoldFrom requires two arguments, first is the From function, second is the folding function: {func:?}"
256                ),
257            )),
258        },
259        "Reduce" => match elems.next() {
260            Some(reduce) => Ok(JoinOptions::Reduce(reduce)),
261            _ => Err(Diagnostic::spanned(
262                args.span(),
263                Level::Error,
264                format!("Reduce requires one argument, the reducing function: {func:?}"),
265            )),
266        },
267        _ => Err(Diagnostic::spanned(
268            func.span(),
269            Level::Error,
270            format!("Unknown summarizing function: {func:?}"),
271        )),
272    }
273}
274
275pub(crate) fn make_joindata(
276    wc: &WriteContextArgs,
277    persistence: Persistence,
278    join_options: &JoinOptions<'_>,
279    side: &str,
280) -> Result<(TokenStream, TokenStream, TokenStream), Diagnostic> {
281    let joindata_ident = wc.make_ident(format!("joindata_{}", side));
282    let borrow_ident = wc.make_ident(format!("joindata_{}_borrow", side));
283
284    let &WriteContextArgs {
285        context,
286        df_ident,
287        root,
288        op_span,
289        ..
290    } = wc;
291
292    let join_type = match *join_options {
293        JoinOptions::FoldFrom(_, _) => {
294            quote_spanned!(op_span=> #root::compiled::pull::HalfJoinStateFoldFrom)
295        }
296        JoinOptions::Fold(_, _) => {
297            quote_spanned!(op_span=> #root::compiled::pull::HalfJoinStateFold)
298        }
299        JoinOptions::Reduce(_) => {
300            quote_spanned!(op_span=> #root::compiled::pull::HalfJoinStateReduce)
301        }
302    };
303
304    let (prologue, pre_write_iter, borrow) = match persistence {
305        Persistence::None => (
306            Default::default(),
307            quote_spanned! {op_span=>
308                let mut #borrow_ident = #join_type::default();
309            },
310            quote_spanned! {op_span=>
311                #borrow_ident
312            },
313        ),
314        Persistence::Tick => (
315            quote_spanned! {op_span=>
316                let #joindata_ident = #df_ident.add_state(std::cell::RefCell::new(
317                    #root::util::monotonic_map::MonotonicMap::new_init(
318                        #join_type::default()
319                    )
320                ));
321            },
322            quote_spanned! {op_span=>
323                let mut #borrow_ident = unsafe {
324                    // SAFETY: handles from `#df_ident`.
325                    #context.state_ref_unchecked(#joindata_ident)
326                }.borrow_mut();
327            },
328            quote_spanned! {op_span=>
329                #borrow_ident.get_mut_clear(#context.current_tick())
330            },
331        ),
332        Persistence::Static => (
333            quote_spanned! {op_span=>
334                let #joindata_ident = #df_ident.add_state(std::cell::RefCell::new(
335                    #join_type::default()
336                ));
337            },
338            quote_spanned! {op_span=>
339                let mut #borrow_ident = unsafe {
340                    // SAFETY: handles from `#df_ident`.
341                    #context.state_ref_unchecked(#joindata_ident)
342                }.borrow_mut();
343            },
344            quote_spanned! {op_span=>
345                #borrow_ident
346            },
347        ),
348        Persistence::Mutable => {
349            return Err(Diagnostic::spanned(
350                op_span,
351                Level::Error,
352                "An implementation of 'mutable does not exist",
353            ));
354        }
355    };
356    Ok((prologue, pre_write_iter, borrow))
357}
358
359pub(crate) fn parse_persistences(persistences: &[Persistence]) -> [Persistence; 2] {
360    match persistences {
361        [] => [Persistence::Tick, Persistence::Tick],
362        [a] => [*a, *a],
363        [a, b] => [*a, *b],
364        _ => panic!("Too many persistences: {persistences:?}"),
365    }
366}