dfir_lang/graph/ops/
fold.rs

1use quote::quote_spanned;
2
3use super::{
4    DelayType, OpInstGenerics, OperatorCategory, OperatorConstraints, OperatorInstance,
5    OperatorWriteOutput, Persistence, RANGE_0, RANGE_1, WriteContextArgs,
6};
7use crate::diagnostic::{Diagnostic, Level};
8
9/// > 1 input stream, 1 output stream
10///
11/// > Arguments: two arguments, both closures. The first closure is used to create the initial
12/// > value for the accumulator, and the second is used to combine new items with the existing
13/// > accumulator value. The second closure takes two two arguments: an `&mut Accum` accumulated
14/// > value, and an `Item`.
15///
16/// Akin to Rust's built-in [`fold`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.fold)
17/// operator, except that it takes the accumulator by `&mut` instead of by value. Folds every item
18/// into an accumulator by applying a closure, returning the final result.
19///
20/// > Note: The closures have access to the [`context` object](surface_flows.mdx#the-context-object).
21///
22/// `fold` can also be provided with one generic lifetime persistence argument, either
23/// `'tick` or `'static`, to specify how data persists. With `'tick`, Items will only be collected
24/// within the same tick. With `'static`, the accumulated value will be remembered across ticks and
25/// will be aggregated with items arriving in later ticks. When not explicitly specified
26/// persistence defaults to `'tick`.
27///
28/// ```dfir
29/// // should print `Reassembled vector [1,2,3,4,5]`
30/// source_iter([1,2,3,4,5])
31///     -> fold::<'tick>(Vec::new, |accum: &mut Vec<_>, elem| {
32///         accum.push(elem);
33///     })
34///     -> assert_eq([vec![1, 2, 3, 4, 5]]);
35/// ```
36pub const FOLD: OperatorConstraints = OperatorConstraints {
37    name: "fold",
38    categories: &[OperatorCategory::Fold],
39    hard_range_inn: RANGE_1,
40    soft_range_inn: RANGE_1,
41    hard_range_out: &(0..=1),
42    soft_range_out: &(0..=1),
43    num_args: 2,
44    persistence_args: &(0..=1),
45    type_args: RANGE_0,
46    is_external_input: false,
47    has_singleton_output: true,
48    flo_type: None,
49    ports_inn: None,
50    ports_out: None,
51    input_delaytype_fn: |_| Some(DelayType::Stratum),
52    write_fn: |wc @ &WriteContextArgs {
53                   root,
54                   context,
55                   df_ident,
56                   loop_id,
57                   op_span,
58                   ident,
59                   is_pull,
60                   inputs,
61                   singleton_output_ident,
62                   work_fn,
63                   op_inst:
64                       OperatorInstance {
65                           generics:
66                               OpInstGenerics {
67                                   persistence_args, ..
68                               },
69                           ..
70                       },
71                   arguments,
72                   ..
73               },
74               diagnostics| {
75
76        let persistence = persistence_args.first().copied().unwrap_or_else(|| {
77            if loop_id.is_some() {
78                Persistence::None
79            } else {
80                Persistence::Tick
81            }
82        });
83        if Persistence::Mutable == persistence {
84            diagnostics.push(Diagnostic::spanned(
85                op_span,
86                Level::Error,
87                "An implementation of 'mutable does not exist",
88            ));
89            return Err(());
90        }
91
92        let input = &inputs[0];
93        let init = &arguments[0];
94        let func = &arguments[1];
95        let initializer_func_ident = wc.make_ident("initializer_func");
96        let accumulator_ident = wc.make_ident("accumulator");
97        let iterator_item_ident = wc.make_ident("iterator_item");
98
99        let iterator_foreach = quote_spanned! {op_span=>
100            #[inline(always)]
101            fn call_comb_type<Accum, Item>(
102                accum: &mut Accum,
103                item: Item,
104                func: impl Fn(&mut Accum, Item),
105            ) {
106                (func)(accum, item);
107            }
108            #[allow(clippy::redundant_closure_call)]
109            call_comb_type(&mut *#accumulator_ident, #iterator_item_ident, #func);
110        };
111
112        let mut write_prologue = quote_spanned! {op_span=>
113            #[allow(unused_mut)]
114            let mut #initializer_func_ident = #init;
115
116            #[allow(clippy::redundant_closure_call)]
117            let #singleton_output_ident = #df_ident.add_state(
118                ::std::cell::RefCell::new((#initializer_func_ident)())
119            );
120        };
121        if Persistence::Tick == persistence {
122            write_prologue.extend(quote_spanned! {op_span=>
123                // Reset the value to the initializer fn if it is a new tick.
124                #df_ident.set_state_tick_hook(#singleton_output_ident, move |rcell| { rcell.replace((#initializer_func_ident)()); });
125            });
126        }
127        let write_iterator = if is_pull {
128            quote_spanned! {op_span=>
129                let #ident = {
130                    let mut #accumulator_ident = unsafe {
131                        // SAFETY: handle from `#df_ident.add_state(..)`.
132                        #context.state_ref_unchecked(#singleton_output_ident)
133                    }.borrow_mut();
134
135                    #work_fn(|| #input.for_each(|#iterator_item_ident| {
136                        #iterator_foreach
137                    }));
138
139                    #[allow(clippy::clone_on_copy)]
140                    {
141                        ::std::iter::once(#work_fn(|| ::std::clone::Clone::clone(&*#accumulator_ident)))
142                    }
143                };
144            }
145        } else {
146            quote_spanned! {op_span=>
147                let #ident = {
148                    #root::pusherator::for_each::ForEach::new(|#iterator_item_ident| {
149                        let mut #accumulator_ident = unsafe {
150                            // SAFETY: handle from `#df_ident.add_state(..)`.
151                            #context.state_ref_unchecked(#singleton_output_ident)
152                        }.borrow_mut();
153                        #iterator_foreach
154                    })
155                };
156            }
157        };
158        let write_iterator_after = quote_spanned! {op_span=>
159            #context.schedule_subgraph(#context.current_subgraph(), false);
160        };
161
162        Ok(OperatorWriteOutput {
163            write_prologue,
164            write_iterator,
165            write_iterator_after,
166        })
167    },
168};