// dfir_lang/graph/ops/fold_keyed.rs

1use quote::{ToTokens, quote_spanned};
2
3use super::{
4    DelayType, OpInstGenerics, OperatorCategory, OperatorConstraints, OperatorInstance,
5    OperatorWriteOutput, Persistence, RANGE_1, WriteContextArgs,
6};
7
8/// > 1 input stream of type `(K, V1)`, 1 output stream of type `(K, V2)`.
9/// > The output will have one tuple for each distinct `K`, with an accumulated value of type `V2`.
10///
11/// If the input and output value types are the same and do not require initialization then use
12/// [`reduce_keyed`](#reduce_keyed).
13///
14/// > Arguments: two Rust closures. The first generates an initial value per group. The second
15/// > itself takes two arguments: an 'accumulator', and an element. The second closure returns the
16/// > value that the accumulator should have for the next iteration.
17///
18/// A special case of `fold`, in the spirit of SQL's GROUP BY and aggregation constructs. The input
19/// is partitioned into groups by the first field ("keys"), and for each group the values in the second
20/// field are accumulated via the closures in the arguments.
21///
22/// > Note: The closures have access to the [`context` object](surface_flows.mdx#the-context-object).
23///
24/// ```dfir
25/// source_iter([("toy", 1), ("toy", 2), ("shoe", 11), ("shoe", 35), ("haberdashery", 7)])
26///     -> fold_keyed(|| 0, |old: &mut u32, val: u32| *old += val)
27///     -> assert_eq([("toy", 3), ("shoe", 46), ("haberdashery", 7)]);
28/// ```
29///
30/// `fold_keyed` can be provided with one generic lifetime persistence argument, either
31/// `'tick` or `'static`, to specify how data persists. With `'tick`, values will only be collected
32/// within the same tick. With `'static`, values will be remembered across ticks and will be
33/// aggregated with pairs arriving in later ticks. When not explicitly specified persistence
34/// defaults to `'tick`.
35///
36/// `fold_keyed` can also be provided with two type arguments, the key type `K` and aggregated
37/// output value type `V2`. This is required when using `'static` persistence if the compiler
38/// cannot infer the types.
39///
40/// ```dfir
41/// source_iter([("toy", 1), ("toy", 2), ("shoe", 11), ("shoe", 35), ("haberdashery", 7)])
42///     -> fold_keyed(|| 0, |old: &mut u32, val: u32| *old += val)
43///     -> assert_eq([("toy", 3), ("shoe", 46), ("haberdashery", 7)]);
44/// ```
45///
46/// Example using `'tick` persistence:
47/// ```rustbook
48/// let (input_send, input_recv) = dfir_rs::util::unbounded_channel::<(&str, &str)>();
49/// let mut flow = dfir_rs::dfir_syntax! {
50///     source_stream(input_recv)
51///         -> fold_keyed::<'tick, &str, String>(String::new, |old: &mut _, val| {
52///             *old += val;
53///             *old += ", ";
54///         })
55///         -> for_each(|(k, v)| println!("({:?}, {:?})", k, v));
56/// };
57///
58/// input_send.send(("hello", "oakland")).unwrap();
59/// input_send.send(("hello", "berkeley")).unwrap();
60/// input_send.send(("hello", "san francisco")).unwrap();
61/// flow.run_available();
62/// // ("hello", "oakland, berkeley, san francisco, ")
63///
64/// input_send.send(("hello", "palo alto")).unwrap();
65/// flow.run_available();
66/// // ("hello", "palo alto, ")
67/// ```
68pub const FOLD_KEYED: OperatorConstraints = OperatorConstraints {
69    name: "fold_keyed",
70    categories: &[OperatorCategory::KeyedFold],
71    hard_range_inn: RANGE_1,
72    soft_range_inn: RANGE_1,
73    hard_range_out: RANGE_1,
74    soft_range_out: RANGE_1,
75    num_args: 2,
76    persistence_args: &(0..=1),
77    type_args: &(0..=2),
78    is_external_input: false,
79    // If this is set to true, the state will need to be cleared using `#context.set_state_tick_hook`
80    // to prevent reading uncleared data if this subgraph doesn't run.
81    // https://github.com/hydro-project/hydro/issues/1298
82    has_singleton_output: false,
83    flo_type: None,
84    ports_inn: None,
85    ports_out: None,
86    input_delaytype_fn: |_| Some(DelayType::Stratum),
87    write_fn: |wc @ &WriteContextArgs {
88                   df_ident,
89                   context,
90                   op_span,
91                   ident,
92                   inputs,
93                   is_pull,
94                   work_fn,
95                   root,
96                   op_name,
97                   op_inst:
98                       OperatorInstance {
99                           generics:
100                               OpInstGenerics {
101                                   persistence_args,
102                                   type_args,
103                                   ..
104                               },
105                           ..
106                       },
107                   arguments,
108                   ..
109               },
110               _| {
111        assert!(is_pull, "TODO(mingwei): `{}` only supports pull.", op_name);
112
113        let persistence = match persistence_args[..] {
114            [] => Persistence::Tick,
115            [a] => a,
116            _ => unreachable!(),
117        };
118
119        let generic_type_args = [
120            type_args
121                .first()
122                .map(ToTokens::to_token_stream)
123                .unwrap_or(quote_spanned!(op_span=> _)),
124            type_args
125                .get(1)
126                .map(ToTokens::to_token_stream)
127                .unwrap_or(quote_spanned!(op_span=> _)),
128        ];
129
130        let input = &inputs[0];
131        let initfn = &arguments[0];
132        let aggfn = &arguments[1];
133
134        let groupbydata_ident = wc.make_ident("groupbydata");
135        let hashtable_ident = wc.make_ident("hashtable");
136
137        let (write_prologue, write_iterator, write_iterator_after) = match persistence {
138            Persistence::None => {
139                (
140                    Default::default(),
141                    // TODO(mingwei): deduplicate this code with the other persistence cases.
142                    quote_spanned! {op_span=>
143                        let mut #hashtable_ident = #root::rustc_hash::FxHashMap::<#( #generic_type_args ),*>::default();
144
145                        #work_fn(|| {
146                            #[inline(always)]
147                            fn check_input<Iter, A, B>(iter: Iter) -> impl ::std::iter::Iterator<Item = (A, B)>
148                            where
149                                Iter: std::iter::Iterator<Item = (A, B)>,
150                                A: ::std::clone::Clone,
151                                B: ::std::clone::Clone
152                            {
153                                iter
154                            }
155
156                            /// A: accumulator type
157                            /// T: iterator item type
158                            /// O: output type
159                            #[inline(always)]
160                            fn call_comb_type<A, T, O>(a: &mut A, t: T, f: impl Fn(&mut A, T) -> O) -> O {
161                                (f)(a, t)
162                            }
163
164                            for kv in check_input(#input) {
165                                // TODO(mingwei): remove `unknown_lints` when `clippy::unwrap_or_default` is stabilized.
166                                #[allow(unknown_lints, clippy::unwrap_or_default)]
167                                let entry = #hashtable_ident.entry(kv.0).or_insert_with(#initfn);
168                                #[allow(clippy::redundant_closure_call)] call_comb_type(entry, kv.1, #aggfn);
169                            }
170                        });
171
172                        let #ident = #hashtable_ident.drain();
173                    },
174                    Default::default(),
175                )
176            }
177            Persistence::Tick => {
178                (
179                    quote_spanned! {op_span=>
180                        let #groupbydata_ident = #df_ident.add_state(::std::cell::RefCell::new(#root::rustc_hash::FxHashMap::<#( #generic_type_args ),*>::default()));
181                    },
182                    quote_spanned! {op_span=>
183                        let mut #hashtable_ident = unsafe {
184                            // SAFETY: handle from `#df_ident.add_state(..)`.
185                            #context.state_ref_unchecked(#groupbydata_ident)
186                        }.borrow_mut();
187
188                        #work_fn(|| {
189                            #[inline(always)]
190                            fn check_input<Iter, A, B>(iter: Iter) -> impl ::std::iter::Iterator<Item = (A, B)>
191                            where
192                                Iter: std::iter::Iterator<Item = (A, B)>,
193                                A: ::std::clone::Clone,
194                                B: ::std::clone::Clone
195                            {
196                                iter
197                            }
198
199                            /// A: accumulator type
200                            /// T: iterator item type
201                            /// O: output type
202                            #[inline(always)]
203                            fn call_comb_type<A, T, O>(a: &mut A, t: T, f: impl Fn(&mut A, T) -> O) -> O {
204                                (f)(a, t)
205                            }
206
207                            for kv in check_input(#input) {
208                                // TODO(mingwei): remove `unknown_lints` when `clippy::unwrap_or_default` is stabilized.
209                                #[allow(unknown_lints, clippy::unwrap_or_default)]
210                                let entry = #hashtable_ident.entry(kv.0).or_insert_with(#initfn);
211                                #[allow(clippy::redundant_closure_call)] call_comb_type(entry, kv.1, #aggfn);
212                            }
213                        });
214
215                        let #ident = #hashtable_ident.drain();
216                    },
217                    Default::default(),
218                )
219            }
220            Persistence::Static => {
221                (
222                    quote_spanned! {op_span=>
223                        let #groupbydata_ident = #df_ident.add_state(::std::cell::RefCell::new(#root::rustc_hash::FxHashMap::<#( #generic_type_args ),*>::default()));
224                    },
225                    quote_spanned! {op_span=>
226                        let mut #hashtable_ident = unsafe {
227                            // SAFETY: handle from `#df_ident.add_state(..)`.
228                            #context.state_ref_unchecked(#groupbydata_ident)
229                        }.borrow_mut();
230
231                        #work_fn(|| {
232                            #[inline(always)]
233                            fn check_input<Iter, A, B>(iter: Iter) -> impl ::std::iter::Iterator<Item = (A, B)>
234                            where
235                                Iter: std::iter::Iterator<Item = (A, B)>,
236                                A: ::std::clone::Clone,
237                                B: ::std::clone::Clone
238                            {
239                                iter
240                            }
241
242                            /// A: accumulator type
243                            /// T: iterator item type
244                            /// O: output type
245                            #[inline(always)]
246                            fn call_comb_type<A, T, O>(a: &mut A, t: T, f: impl Fn(&mut A, T) -> O) -> O {
247                                (f)(a, t)
248                            }
249
250                            for kv in check_input(#input) {
251                                // TODO(mingwei): remove `unknown_lints` when `clippy::unwrap_or_default` is stabilized.
252                                #[allow(unknown_lints, clippy::unwrap_or_default)]
253                                let entry = #hashtable_ident.entry(kv.0).or_insert_with(#initfn);
254                                #[allow(clippy::redundant_closure_call)] call_comb_type(entry, kv.1, #aggfn);
255                            }
256                        });
257
258                        // Play everything but only on the first run of this tick/stratum.
259                        // (We know we won't have any more inputs, so it is fine to only play once.
260                        // Because of the `DelayType::Stratum` or `DelayType::MonotoneAccum`).
261                        let #ident = #context.is_first_run_this_tick()
262                            .then_some(#hashtable_ident.iter())
263                            .into_iter()
264                            .flatten()
265                            .map(
266                                // TODO(mingwei): remove `unknown_lints` when `suspicious_double_ref_op` is stabilized.
267                                #[allow(unknown_lints, suspicious_double_ref_op, clippy::clone_on_copy)]
268                                |(k, v)| (
269                                    ::std::clone::Clone::clone(k),
270                                    ::std::clone::Clone::clone(v),
271                                )
272                            );
273                    },
274                    quote_spanned! {op_span=>
275                        #context.schedule_subgraph(#context.current_subgraph(), false);
276                    },
277                )
278            }
279            Persistence::Mutable => {
280                (
281                    quote_spanned! {op_span=>
282                        let #groupbydata_ident = #df_ident.add_state(::std::cell::RefCell::new(#root::rustc_hash::FxHashMap::<#( #generic_type_args ),*>::default()));
283                    },
284                    quote_spanned! {op_span=>
285                        let mut #hashtable_ident = unsafe {
286                            // SAFETY: handle from `#df_ident.add_state(..)`.
287                            #context.state_ref_unchecked(#groupbydata_ident)
288                        }.borrow_mut();
289
290                        #work_fn(|| {
291                            #[inline(always)]
292                            fn check_input<Iter: ::std::iter::Iterator<Item = #root::util::PersistenceKeyed::<K, V>>, K: ::std::clone::Clone, V: ::std::clone::Clone>(iter: Iter)
293                                -> impl ::std::iter::Iterator<Item = #root::util::PersistenceKeyed::<K, V>> { iter }
294
295                            #[inline(always)]
296                            /// A: accumulator type
297                            /// T: iterator item type
298                            /// O: output type
299                            fn call_comb_type<A, T, O>(a: &mut A, t: T, f: impl Fn(&mut A, T) -> O) -> O {
300                                f(a, t)
301                            }
302
303                            for item in check_input(#input) {
304                                match item {
305                                    Persist(k, v) => {
306                                        let entry = #hashtable_ident.entry(k).or_insert_with(#initfn);
307                                        #[allow(clippy::redundant_closure_call)] call_comb_type(entry, v, #aggfn);
308                                    },
309                                    Delete(k) => {
310                                        #hashtable_ident.remove(&k);
311                                    },
312                                }
313                            }
314                        });
315
316                        let #ident = #hashtable_ident
317                            .iter()
318                            .map(#[allow(suspicious_double_ref_op, clippy::clone_on_copy)] |(k, v)| (k.clone(), v.clone()));
319                    },
320                    quote_spanned! {op_span=>
321                        #context.schedule_subgraph(#context.current_subgraph(), false);
322                    },
323                )
324            }
325        };
326
327        Ok(OperatorWriteOutput {
328            write_prologue,
329            write_iterator,
330            write_iterator_after,
331        })
332    },
333};