dfir_lang/graph/ops/
reduce.rs

1use quote::quote_spanned;
2
3use super::{
4    DelayType, OpInstGenerics, OperatorCategory, OperatorConstraints, OperatorInstance,
5    OperatorWriteOutput, Persistence, RANGE_0, RANGE_1, WriteContextArgs,
6};
7use crate::diagnostic::{Diagnostic, Level};
8
9/// > 1 input stream, 1 output stream
10///
11/// > Arguments: a closure which itself takes two arguments:
12/// > an `&mut Accum` accumulator mutable reference, and an `Item`. The closure should merge the item
13/// > into the accumulator.
14///
15/// Akin to Rust's built-in [`reduce`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.reduce)
16/// operator, except that it takes the accumulator by `&mut` instead of by value. Reduces every
17/// item into an accumulator by applying a closure, returning the final result.
18///
19/// > Note: The closure has access to the [`context` object](surface_flows.mdx#the-context-object).
20///
21/// `reduce` can also be provided with one generic lifetime persistence argument, either
22/// `'tick` or `'static`, to specify how data persists. With `'tick`, values will only be collected
23/// within the same tick. With `'static`, the accumulated value will be remembered across ticks and
24/// items are aggregated with items arriving in later ticks. When not explicitly specified
25/// persistence defaults to `'tick`.
26///
27/// ```dfir
28/// source_iter([1,2,3,4,5])
29///     -> reduce::<'tick>(|accum: &mut _, elem| {
30///         *accum *= elem;
31///     })
32///     -> assert_eq([120]);
33/// ```
34pub const REDUCE: OperatorConstraints = OperatorConstraints {
35    name: "reduce",
36    categories: &[OperatorCategory::Fold],
37    hard_range_inn: RANGE_1,
38    soft_range_inn: RANGE_1,
39    hard_range_out: &(0..=1),
40    soft_range_out: &(0..=1),
41    num_args: 1,
42    persistence_args: &(0..=1),
43    type_args: RANGE_0,
44    is_external_input: false,
45    has_singleton_output: true,
46    flo_type: None,
47    ports_inn: None,
48    ports_out: None,
49    input_delaytype_fn: |_| Some(DelayType::Stratum),
50    write_fn: |wc @ &WriteContextArgs {
51                   root,
52                   context,
53                   df_ident,
54                   op_span,
55                   ident,
56                   inputs,
57                   is_pull,
58                   singleton_output_ident,
59                   work_fn,
60                   op_inst:
61                       OperatorInstance {
62                           generics:
63                               OpInstGenerics {
64                                   persistence_args, ..
65                               },
66                           ..
67                       },
68                   arguments,
69                   ..
70               },
71               diagnostics| {
72        let persistence = match persistence_args[..] {
73            [] => Persistence::Tick,
74            [a] => a,
75            _ => unreachable!(),
76        };
77        if Persistence::Mutable == persistence {
78            diagnostics.push(Diagnostic::spanned(
79                op_span,
80                Level::Error,
81                "An implementation of 'mutable does not exist",
82            ));
83            return Err(());
84        }
85
86        let func = &arguments[0];
87        let accumulator_ident = wc.make_ident("accumulator");
88        let iterator_item_ident = wc.make_ident("iterator_item");
89
90        let iterator_foreach = quote_spanned! {op_span=>
91            #[inline(always)]
92            fn call_comb_type<Item>(
93                accum: &mut Option<Item>,
94                item: Item,
95                func: impl Fn(&mut Item, Item),
96            ) {
97                match accum {
98                    accum @ None => *accum = Some(item),
99                    Some(accum) => (func)(accum, item),
100                }
101            }
102            #[allow(clippy::redundant_closure_call)]
103            call_comb_type(&mut *#accumulator_ident, #iterator_item_ident, #func);
104        };
105
106        let mut write_prologue = quote_spanned! {op_span=>
107            #[allow(clippy::redundant_closure_call)]
108            let #singleton_output_ident = #df_ident.add_state(
109                ::std::cell::RefCell::new(::std::option::Option::None)
110            );
111        };
112        if Persistence::Tick == persistence {
113            write_prologue.extend(quote_spanned! {op_span=>
114                // Reset the value to the initializer fn at the end of each tick.
115                #df_ident.set_state_tick_hook(#singleton_output_ident, |rcell| { rcell.take(); });
116            });
117        }
118
119        let write_iterator = if is_pull {
120            let input = &inputs[0];
121            quote_spanned! {op_span=>
122                let #ident = {
123                    let mut #accumulator_ident = unsafe {
124                        // SAFETY: handle from `#df_ident.add_state(..)`.
125                        #context.state_ref_unchecked(#singleton_output_ident)
126                    }.borrow_mut();
127
128                    #work_fn(|| #input.for_each(|#iterator_item_ident| {
129                        #iterator_foreach
130                    }));
131
132                    #[allow(clippy::clone_on_copy)]
133                    {
134                        ::std::iter::IntoIterator::into_iter(#work_fn(|| ::std::clone::Clone::clone(&*#accumulator_ident)))
135                    }
136                };
137            }
138        } else {
139            // Is only push when used as a singleton, so no need to push to `outputs[0]`.
140            quote_spanned! {op_span=>
141                let #ident = {
142                    #root::pusherator::for_each::ForEach::new(|#iterator_item_ident| {
143                        let mut #accumulator_ident = unsafe {
144                            // SAFETY: handle from `#df_ident.add_state(..)`.
145                            #context.state_ref_unchecked(#singleton_output_ident)
146                        }.borrow_mut();
147                        #iterator_foreach
148                    })
149                };
150            }
151        };
152        let write_iterator_after = if Persistence::Static == persistence {
153            quote_spanned! {op_span=>
154                #context.schedule_subgraph(#context.current_subgraph(), false);
155            }
156        } else {
157            Default::default()
158        };
159
160        Ok(OperatorWriteOutput {
161            write_prologue,
162            write_iterator,
163            write_iterator_after,
164        })
165    },
166};