dfir_lang/graph/ops/
reduce.rs

1use quote::quote_spanned;
2
3use super::{
4    DelayType, OperatorCategory, OperatorConstraints, OperatorWriteOutput, Persistence, RANGE_0,
5    RANGE_1, WriteContextArgs,
6};
7
8/// > 1 input stream, 1 output stream
9///
10/// > Arguments: a closure which itself takes two arguments:
11/// > an `&mut Accum` accumulator mutable reference, and an `Item`. The closure should merge the item
12/// > into the accumulator.
13///
14/// Akin to Rust's built-in [`reduce`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.reduce)
15/// operator, except that it takes the accumulator by `&mut` instead of by value. Reduces every
16/// item into an accumulator by applying a closure, returning the final result.
17///
18/// > Note: The closure has access to the [`context` object](surface_flows.mdx#the-context-object).
19///
20/// `reduce` can also be provided with one generic lifetime persistence argument, either
21/// `'tick` or `'static`, to specify how data persists. With `'tick`, values will only be collected
22/// within the same tick. With `'static`, the accumulated value will be remembered across ticks and
23/// items are aggregated with items arriving in later ticks. When not explicitly specified
24/// persistence defaults to `'tick`.
25///
26/// ```dfir
27/// source_iter([1,2,3,4,5])
28///     -> reduce::<'tick>(|accum: &mut _, elem| {
29///         *accum *= elem;
30///     })
31///     -> assert_eq([120]);
32/// ```
33pub const REDUCE: OperatorConstraints = OperatorConstraints {
34    name: "reduce",
35    categories: &[OperatorCategory::Fold],
36    hard_range_inn: RANGE_1,
37    soft_range_inn: RANGE_1,
38    hard_range_out: &(0..=1),
39    soft_range_out: &(0..=1),
40    num_args: 1,
41    persistence_args: &(0..=1),
42    type_args: RANGE_0,
43    is_external_input: false,
44    has_singleton_output: true,
45    flo_type: None,
46    ports_inn: None,
47    ports_out: None,
48    input_delaytype_fn: |_| Some(DelayType::Stratum),
49    write_fn: |wc @ &WriteContextArgs {
50                   root,
51                   context,
52                   df_ident,
53                   op_span,
54                   ident,
55                   inputs,
56                   is_pull,
57                   singleton_output_ident,
58                   work_fn,
59                   arguments,
60                   ..
61               },
62               diagnostics| {
63        let [persistence] = wc.persistence_args_disallow_mutable(diagnostics);
64
65        let write_prologue = quote_spanned! {op_span=>
66            let #singleton_output_ident = #df_ident.add_state(::std::cell::RefCell::new(::std::option::Option::None));
67        };
68        let write_prologue_after = wc
69            .persistence_as_state_lifespan(persistence)
70            .map(|lifespan| quote_spanned! {op_span=>
71                #df_ident.set_state_lifespan_hook(#singleton_output_ident, #lifespan, move |rcell| { rcell.replace(::std::option::Option::None); });
72            }).unwrap_or_default();
73
74        let func = &arguments[0];
75        let accumulator_ident = wc.make_ident("accumulator");
76        let iterator_item_ident = wc.make_ident("iterator_item");
77
78        let iterator_foreach = quote_spanned! {op_span=>
79            #[inline(always)]
80            fn call_comb_type<Item>(
81                accum: &mut Option<Item>,
82                item: Item,
83                func: impl Fn(&mut Item, Item),
84            ) {
85                match accum {
86                    accum @ None => *accum = Some(item),
87                    Some(accum) => (func)(accum, item),
88                }
89            }
90            #[allow(clippy::redundant_closure_call)]
91            call_comb_type(&mut *#accumulator_ident, #iterator_item_ident, #func);
92        };
93
94        let assign_accum_ident = quote_spanned! {op_span=>
95            #[allow(unused_mut)]
96            let mut #accumulator_ident = unsafe {
97                // SAFETY: handle from `#df_ident.add_state(..)`.
98                #context.state_ref_unchecked(#singleton_output_ident)
99            }.borrow_mut();
100        };
101
102        let write_iterator = if is_pull {
103            let input = &inputs[0];
104            quote_spanned! {op_span=>
105                let #ident = {
106                    #assign_accum_ident
107
108                    #work_fn(|| #input.for_each(|#iterator_item_ident| {
109                        #iterator_foreach
110                    }));
111
112                    #[allow(clippy::clone_on_copy)]
113                    {
114                        ::std::iter::IntoIterator::into_iter(#work_fn(|| ::std::clone::Clone::clone(&*#accumulator_ident)))
115                    }
116                };
117            }
118        } else {
119            // Is only push when used as a singleton, so no need to push to `outputs[0]`.
120            quote_spanned! {op_span=>
121                let #ident = #root::pusherator::for_each::ForEach::new(|#iterator_item_ident| {
122                    #assign_accum_ident
123
124                    #iterator_foreach
125                });
126            }
127        };
128
129        let write_iterator_after = if Persistence::Static == persistence {
130            quote_spanned! {op_span=>
131                #context.schedule_subgraph(#context.current_subgraph(), false);
132            }
133        } else {
134            Default::default()
135        };
136
137        Ok(OperatorWriteOutput {
138            write_prologue,
139            write_prologue_after,
140            write_iterator,
141            write_iterator_after,
142        })
143    },
144};