dfir_lang/graph/ops/
reduce_keyed.rs

1use quote::{ToTokens, quote_spanned};
2
3use super::{
4    DelayType, OpInstGenerics, OperatorCategory, OperatorConstraints, OperatorInstance,
5    OperatorWriteOutput, Persistence, RANGE_1, WriteContextArgs,
6};
7
8/// > 1 input stream of type `(K, V)`, 1 output stream of type `(K, V)`.
9/// > The output will have one tuple for each distinct `K`, with an accumulated (reduced) value of
10/// > type `V`.
11///
12/// If you need the accumulated value to have a different type than the input, use [`fold_keyed`](#fold_keyed).
13///
14/// > Arguments: one Rust closures. The closure takes two arguments: an `&mut` 'accumulator', and
15/// > an element. Accumulator should be updated based on the element.
16///
17/// A special case of `reduce`, in the spirit of SQL's GROUP BY and aggregation constructs. The input
18/// is partitioned into groups by the first field, and for each group the values in the second
19/// field are accumulated via the closures in the arguments.
20///
21/// > Note: The closures have access to the [`context` object](surface_flows.mdx#the-context-object).
22///
23/// `reduce_keyed` can also be provided with one generic lifetime persistence argument, either
24/// `'tick` or `'static`, to specify how data persists. With `'tick`, values will only be collected
25/// within the same tick. With `'static`, values will be remembered across ticks and will be
26/// aggregated with pairs arriving in later ticks. When not explicitly specified persistence
27/// defaults to `'tick`.
28///
29/// `reduce_keyed` can also be provided with two type arguments, the key and value type. This is
30/// required when using `'static` persistence if the compiler cannot infer the types.
31///
32/// ```dfir
33/// source_iter([("toy", 1), ("toy", 2), ("shoe", 11), ("shoe", 35), ("haberdashery", 7)])
34///     -> reduce_keyed(|old: &mut u32, val: u32| *old += val)
35///     -> assert_eq([("toy", 3), ("shoe", 46), ("haberdashery", 7)]);
36/// ```
37///
38/// Example using `'tick` persistence and type arguments:
39/// ```rustbook
40/// let (input_send, input_recv) = dfir_rs::util::unbounded_channel::<(&str, &str)>();
41/// let mut flow = dfir_rs::dfir_syntax! {
42///     source_stream(input_recv)
43///         -> reduce_keyed::<'tick, &str>(|old: &mut _, val| *old = std::cmp::max(*old, val))
44///         -> for_each(|(k, v)| println!("({:?}, {:?})", k, v));
45/// };
46///
47/// input_send.send(("hello", "oakland")).unwrap();
48/// input_send.send(("hello", "berkeley")).unwrap();
49/// input_send.send(("hello", "san francisco")).unwrap();
50/// flow.run_available();
51/// // ("hello", "oakland, berkeley, san francisco, ")
52///
53/// input_send.send(("hello", "palo alto")).unwrap();
54/// flow.run_available();
55/// // ("hello", "palo alto, ")
56/// ```
57pub const REDUCE_KEYED: OperatorConstraints = OperatorConstraints {
58    name: "reduce_keyed",
59    categories: &[OperatorCategory::KeyedFold],
60    hard_range_inn: RANGE_1,
61    soft_range_inn: RANGE_1,
62    hard_range_out: RANGE_1,
63    soft_range_out: RANGE_1,
64    num_args: 1,
65    persistence_args: &(0..=1),
66    type_args: &(0..=2),
67    is_external_input: false,
68    has_singleton_output: true,
69    flo_type: None,
70    ports_inn: None,
71    ports_out: None,
72    input_delaytype_fn: |_| Some(DelayType::Stratum),
73    write_fn: |wc @ &WriteContextArgs {
74                   df_ident,
75                   context,
76                   op_span,
77                   ident,
78                   inputs,
79                   singleton_output_ident,
80                   is_pull,
81                   work_fn,
82                   root,
83                   op_name,
84                   op_inst:
85                       OperatorInstance {
86                           generics: OpInstGenerics { type_args, .. },
87                           ..
88                       },
89                   arguments,
90                   ..
91               },
92               diagnostics| {
93        assert!(is_pull, "TODO(mingwei): `{}` only supports pull.", op_name);
94
95        let [persistence] = wc.persistence_args_disallow_mutable(diagnostics);
96
97        let generic_type_args = [
98            type_args
99                .first()
100                .map(ToTokens::to_token_stream)
101                .unwrap_or(quote_spanned!(op_span=> _)),
102            type_args
103                .get(1)
104                .map(ToTokens::to_token_stream)
105                .unwrap_or(quote_spanned!(op_span=> _)),
106        ];
107
108        let input = &inputs[0];
109        let aggfn = &arguments[0];
110
111        let hashtable_ident = wc.make_ident("hashtable");
112
113        let write_prologue = quote_spanned! {op_span=>
114            let #singleton_output_ident = #df_ident.add_state(::std::cell::RefCell::new(#root::rustc_hash::FxHashMap::<#( #generic_type_args ),*>::default()));
115        };
116        let write_prologue_after = wc
117            .persistence_as_state_lifespan(persistence)
118            .map(|lifespan| quote_spanned! {op_span=>
119                #df_ident.set_state_lifespan_hook(#singleton_output_ident, #lifespan, |rcell| { rcell.take(); });
120            }).unwrap_or_default();
121
122        let write_iterator = {
123            let iter_expr = match persistence {
124                Persistence::None | Persistence::Tick => quote_spanned! {op_span=>
125                    #hashtable_ident.drain()
126                },
127                Persistence::Loop => quote_spanned! {op_span=>
128                    #hashtable_ident.iter().map(
129                        #[allow(suspicious_double_ref_op, clippy::clone_on_copy)]
130                        |(k, v)| (
131                            ::std::clone::Clone::clone(k),
132                            ::std::clone::Clone::clone(v),
133                        )
134                    )
135                },
136                Persistence::Static => quote_spanned! {op_span=>
137                    // Play everything but only on the first run of this tick/stratum.
138                    // (We know we won't have any more inputs, so it is fine to only play once.
139                    // Because of the `DelayType::Stratum` or `DelayType::MonotoneAccum`).
140                    #context.is_first_run_this_tick()
141                        .then_some(#hashtable_ident.iter())
142                        .into_iter()
143                        .flatten()
144                        .map(
145                            #[allow(suspicious_double_ref_op, clippy::clone_on_copy)]
146                            |(k, v)| (
147                                ::std::clone::Clone::clone(k),
148                                ::std::clone::Clone::clone(v),
149                            )
150                        )
151                },
152                Persistence::Mutable => unreachable!(),
153            };
154
155            quote_spanned! {op_span=>
156                let mut #hashtable_ident = unsafe {
157                    // SAFETY: handle from `#df_ident.add_state(..)`.
158                    #context.state_ref_unchecked(#singleton_output_ident)
159                }.borrow_mut();
160
161                #work_fn(|| {
162                    #[inline(always)]
163                    fn check_input<Iter, K, V>(iter: Iter) -> impl ::std::iter::Iterator<Item = (K, V)>
164                    where
165                        Iter: std::iter::Iterator<Item = (K, V)>,
166                        K: ::std::clone::Clone,
167                        V: ::std::clone::Clone
168                    {
169                        iter
170                    }
171
172                    /// A: accumulator/item type
173                    /// O: output type
174                    #[inline(always)]
175                    fn call_comb_type<A, O>(acc: &mut A, item: A, f: impl Fn(&mut A, A) -> O) -> O {
176                        (f)(acc, item)
177                    }
178
179                    for kv in check_input(#input) {
180                        match #hashtable_ident.entry(kv.0) {
181                            ::std::collections::hash_map::Entry::Vacant(vacant) => {
182                                vacant.insert(kv.1);
183                            }
184                            ::std::collections::hash_map::Entry::Occupied(mut occupied) => {
185                                call_comb_type(occupied.get_mut(), kv.1, #aggfn);
186                            }
187                        }
188                    }
189                });
190
191                let #ident = #iter_expr;
192            }
193        };
194
195        let write_iterator_after = match persistence {
196            Persistence::None | Persistence::Tick | Persistence::Loop => Default::default(),
197            Persistence::Static | Persistence::Mutable => quote_spanned! {op_span=>
198                // Reschedule the subgraph lazily to ensure replay on later ticks.
199                #context.schedule_subgraph(#context.current_subgraph(), false);
200            },
201        };
202
203        Ok(OperatorWriteOutput {
204            write_prologue,
205            write_prologue_after,
206            write_iterator,
207            write_iterator_after,
208        })
209    },
210};