dfir_lang/graph/ops/
reduce_keyed.rs

1use quote::{ToTokens, quote_spanned};
2
3use super::{
4    DelayType, OpInstGenerics, OperatorCategory, OperatorConstraints, OperatorInstance,
5    OperatorWriteOutput, Persistence, RANGE_1, WriteContextArgs,
6};
7use crate::diagnostic::{Diagnostic, Level};
8
9/// > 1 input stream of type `(K, V)`, 1 output stream of type `(K, V)`.
10/// > The output will have one tuple for each distinct `K`, with an accumulated (reduced) value of
11/// > type `V`.
12///
13/// If you need the accumulated value to have a different type than the input, use [`fold_keyed`](#fold_keyed).
14///
15/// > Arguments: one Rust closures. The closure takes two arguments: an `&mut` 'accumulator', and
16/// > an element. Accumulator should be updated based on the element.
17///
18/// A special case of `reduce`, in the spirit of SQL's GROUP BY and aggregation constructs. The input
19/// is partitioned into groups by the first field, and for each group the values in the second
20/// field are accumulated via the closures in the arguments.
21///
22/// > Note: The closures have access to the [`context` object](surface_flows.mdx#the-context-object).
23///
24/// `reduce_keyed` can also be provided with one generic lifetime persistence argument, either
25/// `'tick` or `'static`, to specify how data persists. With `'tick`, values will only be collected
26/// within the same tick. With `'static`, values will be remembered across ticks and will be
27/// aggregated with pairs arriving in later ticks. When not explicitly specified persistence
28/// defaults to `'tick`.
29///
30/// `reduce_keyed` can also be provided with two type arguments, the key and value type. This is
31/// required when using `'static` persistence if the compiler cannot infer the types.
32///
33/// ```dfir
34/// source_iter([("toy", 1), ("toy", 2), ("shoe", 11), ("shoe", 35), ("haberdashery", 7)])
35///     -> reduce_keyed(|old: &mut u32, val: u32| *old += val)
36///     -> assert_eq([("toy", 3), ("shoe", 46), ("haberdashery", 7)]);
37/// ```
38///
39/// Example using `'tick` persistence and type arguments:
40/// ```rustbook
41/// let (input_send, input_recv) = dfir_rs::util::unbounded_channel::<(&str, &str)>();
42/// let mut flow = dfir_rs::dfir_syntax! {
43///     source_stream(input_recv)
44///         -> reduce_keyed::<'tick, &str>(|old: &mut _, val| *old = std::cmp::max(*old, val))
45///         -> for_each(|(k, v)| println!("({:?}, {:?})", k, v));
46/// };
47///
48/// input_send.send(("hello", "oakland")).unwrap();
49/// input_send.send(("hello", "berkeley")).unwrap();
50/// input_send.send(("hello", "san francisco")).unwrap();
51/// flow.run_available();
52/// // ("hello", "oakland, berkeley, san francisco, ")
53///
54/// input_send.send(("hello", "palo alto")).unwrap();
55/// flow.run_available();
56/// // ("hello", "palo alto, ")
57/// ```
58pub const REDUCE_KEYED: OperatorConstraints = OperatorConstraints {
59    name: "reduce_keyed",
60    categories: &[OperatorCategory::KeyedFold],
61    hard_range_inn: RANGE_1,
62    soft_range_inn: RANGE_1,
63    hard_range_out: RANGE_1,
64    soft_range_out: RANGE_1,
65    num_args: 1,
66    persistence_args: &(0..=1),
67    type_args: &(0..=2),
68    is_external_input: false,
69    // If this is set to true, the state will need to be cleared using `#context.set_state_tick_hook`
70    // to prevent reading uncleared data if this subgraph doesn't run.
71    // https://github.com/hydro-project/hydro/issues/1298
72    has_singleton_output: false,
73    flo_type: None,
74    ports_inn: None,
75    ports_out: None,
76    input_delaytype_fn: |_| Some(DelayType::Stratum),
77    write_fn: |wc @ &WriteContextArgs {
78                   df_ident,
79                   context,
80                   op_span,
81                   ident,
82                   inputs,
83                   is_pull,
84                   work_fn,
85                   root,
86                   op_inst:
87                       OperatorInstance {
88                           generics:
89                               OpInstGenerics {
90                                   persistence_args,
91                                   type_args,
92                                   ..
93                               },
94                           ..
95                       },
96                   arguments,
97                   ..
98               },
99               diagnostics| {
100        assert!(is_pull);
101
102        let persistence = match persistence_args[..] {
103            [] => Persistence::Tick,
104            [a] => a,
105            _ => unreachable!(),
106        };
107
108        let generic_type_args = [
109            type_args
110                .first()
111                .map(ToTokens::to_token_stream)
112                .unwrap_or(quote_spanned!(op_span=> _)),
113            type_args
114                .get(1)
115                .map(ToTokens::to_token_stream)
116                .unwrap_or(quote_spanned!(op_span=> _)),
117        ];
118
119        let input = &inputs[0];
120        let aggfn = &arguments[0];
121
122        let hashtable_ident = wc.make_ident("hashtable");
123        let (write_prologue, write_iterator, write_iterator_after) = match persistence {
124            Persistence::None => {
125                (
126                    Default::default(),
127                    // TODO(mingwei): deduplicate with other persistence cases below.
128                    quote_spanned! {op_span=>
129                        let mut #hashtable_ident = #root::rustc_hash::FxHashMap::<#( #generic_type_args ),*>::default();
130
131                        #work_fn(|| {
132                            #[inline(always)]
133                            fn check_input<Iter: ::std::iter::Iterator<Item = (A, B)>, A: ::std::clone::Clone, B: ::std::clone::Clone>(iter: Iter)
134                                -> impl ::std::iter::Iterator<Item = (A, B)> { iter }
135
136                            #[inline(always)]
137                            /// A: accumulator type
138                            /// O: output type
139                            fn call_comb_type<A, O>(acc: &mut A, item: A, f: impl Fn(&mut A, A) -> O) -> O {
140                                f(acc, item)
141                            }
142
143                            for kv in check_input(#input) {
144                                match #hashtable_ident.entry(kv.0) {
145                                    ::std::collections::hash_map::Entry::Vacant(vacant) => {
146                                        vacant.insert(kv.1);
147                                    }
148                                    ::std::collections::hash_map::Entry::Occupied(mut occupied) => {
149                                        #[allow(clippy::redundant_closure_call)] call_comb_type(occupied.get_mut(), kv.1, #aggfn);
150                                    }
151                                }
152                            }
153                        });
154
155                        let #ident = #hashtable_ident.drain();
156                    },
157                    Default::default(),
158                )
159            }
160            Persistence::Tick => {
161                let groupbydata_ident = wc.make_ident("groupbydata");
162
163                (
164                    quote_spanned! {op_span=>
165                        let #groupbydata_ident = #df_ident.add_state(::std::cell::RefCell::new(#root::rustc_hash::FxHashMap::<#( #generic_type_args ),*>::default()));
166                    },
167                    quote_spanned! {op_span=>
168                        let mut #hashtable_ident = unsafe {
169                            // SAFETY: handle from `#df_ident.add_state(..)`.
170                            #context.state_ref_unchecked(#groupbydata_ident)
171                        }.borrow_mut();
172
173                        #work_fn(|| {
174                            #[inline(always)]
175                            fn check_input<Iter: ::std::iter::Iterator<Item = (A, B)>, A: ::std::clone::Clone, B: ::std::clone::Clone>(iter: Iter)
176                                -> impl ::std::iter::Iterator<Item = (A, B)> { iter }
177
178                            #[inline(always)]
179                            /// A: accumulator type
180                            /// O: output type
181                            fn call_comb_type<A, O>(acc: &mut A, item: A, f: impl Fn(&mut A, A) -> O) -> O {
182                                f(acc, item)
183                            }
184
185                            for kv in check_input(#input) {
186                                match #hashtable_ident.entry(kv.0) {
187                                    ::std::collections::hash_map::Entry::Vacant(vacant) => {
188                                        vacant.insert(kv.1);
189                                    }
190                                    ::std::collections::hash_map::Entry::Occupied(mut occupied) => {
191                                        #[allow(clippy::redundant_closure_call)] call_comb_type(occupied.get_mut(), kv.1, #aggfn);
192                                    }
193                                }
194                            }
195                        });
196
197                        let #ident = #hashtable_ident.drain();
198                    },
199                    Default::default(),
200                )
201            }
202            Persistence::Static => {
203                let groupbydata_ident = wc.make_ident("groupbydata");
204
205                (
206                    quote_spanned! {op_span=>
207                        let #groupbydata_ident = #df_ident.add_state(::std::cell::RefCell::new(#root::rustc_hash::FxHashMap::<#( #generic_type_args ),*>::default()));
208                    },
209                    quote_spanned! {op_span=>
210                        let mut #hashtable_ident = unsafe {
211                            // SAFETY: handle from `#df_ident.add_state(..)`.
212                            #context.state_ref_unchecked(#groupbydata_ident)
213                        }.borrow_mut();
214
215                        #work_fn(|| {
216                            #[inline(always)]
217                            fn check_input<Iter: ::std::iter::Iterator<Item = (A, B)>, A: ::std::clone::Clone, B: ::std::clone::Clone>(iter: Iter)
218                                -> impl ::std::iter::Iterator<Item = (A, B)> { iter }
219
220                            #[inline(always)]
221                            /// A: accumulator type
222                            /// O: output type
223                            fn call_comb_type<A, O>(acc: &mut A, item: A, f: impl Fn(&mut A, A) -> O) -> O {
224                                f(acc, item)
225                            }
226
227                            for kv in check_input(#input) {
228                                match #hashtable_ident.entry(kv.0) {
229                                    ::std::collections::hash_map::Entry::Vacant(vacant) => {
230                                        vacant.insert(kv.1);
231                                    }
232                                    ::std::collections::hash_map::Entry::Occupied(mut occupied) => {
233                                        #[allow(clippy::redundant_closure_call)] call_comb_type(occupied.get_mut(), kv.1, #aggfn);
234                                    }
235                                }
236                            }
237                        });
238
239                        let #ident = #context.is_first_run_this_tick()
240                            .then_some(#hashtable_ident.iter())
241                            .into_iter()
242                            .flatten()
243                            .map(
244                                // TODO(mingwei): remove `unknown_lints` when `suspicious_double_ref_op` is stabilized.
245                                #[allow(unknown_lints, suspicious_double_ref_op, clippy::clone_on_copy)]
246                                |(k, v)| (
247                                    ::std::clone::Clone::clone(k),
248                                    ::std::clone::Clone::clone(v),
249                                )
250                            );
251                    },
252                    quote_spanned! {op_span=>
253                        #context.schedule_subgraph(#context.current_subgraph(), false);
254                    },
255                )
256            }
257
258            Persistence::Mutable => {
259                diagnostics.push(Diagnostic::spanned(
260                    op_span,
261                    Level::Error,
262                    "An implementation of 'mutable does not exist",
263                ));
264                return Err(());
265            }
266        };
267
268        Ok(OperatorWriteOutput {
269            write_prologue,
270            write_iterator,
271            write_iterator_after,
272        })
273    },
274};