dfir_lang/graph/ops/
fold_keyed.rs

1use quote::{ToTokens, quote_spanned};
2
3use super::{
4    DelayType, OpInstGenerics, OperatorCategory, OperatorConstraints, OperatorInstance,
5    OperatorWriteOutput, Persistence, RANGE_1, WriteContextArgs,
6};
7
8/// > 1 input stream of type `(K, V1)`, 1 output stream of type `(K, V2)`.
9/// > The output will have one tuple for each distinct `K`, with an accumulated value of type `V2`.
10///
11/// If the input and output value types are the same and do not require initialization then use
12/// [`reduce_keyed`](#reduce_keyed).
13///
14/// > Arguments: two Rust closures. The first generates an initial value per group. The second
15/// > itself takes two arguments: an 'accumulator', and an element. The second closure returns the
16/// > value that the accumulator should have for the next iteration.
17///
18/// A special case of `fold`, in the spirit of SQL's GROUP BY and aggregation constructs. The input
19/// is partitioned into groups by the first field ("keys"), and for each group the values in the second
20/// field are accumulated via the closures in the arguments.
21///
22/// > Note: The closures have access to the [`context` object](surface_flows.mdx#the-context-object).
23///
24/// ```dfir
25/// source_iter([("toy", 1), ("toy", 2), ("shoe", 11), ("shoe", 35), ("haberdashery", 7)])
26///     -> fold_keyed(|| 0, |old: &mut u32, val: u32| *old += val)
27///     -> assert_eq([("toy", 3), ("shoe", 46), ("haberdashery", 7)]);
28/// ```
29///
30/// `fold_keyed` can be provided with one generic lifetime persistence argument, either
31/// `'tick` or `'static`, to specify how data persists. With `'tick`, values will only be collected
32/// within the same tick. With `'static`, values will be remembered across ticks and will be
33/// aggregated with pairs arriving in later ticks. When not explicitly specified persistence
34/// defaults to `'tick`.
35///
36/// `fold_keyed` can also be provided with two type arguments, the key type `K` and aggregated
37/// output value type `V2`. This is required when using `'static` persistence if the compiler
38/// cannot infer the types.
39///
40/// ```dfir
41/// source_iter([("toy", 1), ("toy", 2), ("shoe", 11), ("shoe", 35), ("haberdashery", 7)])
42///     -> fold_keyed(|| 0, |old: &mut u32, val: u32| *old += val)
43///     -> assert_eq([("toy", 3), ("shoe", 46), ("haberdashery", 7)]);
44/// ```
45///
46/// Example using `'tick` persistence:
47/// ```rustbook
48/// let (input_send, input_recv) = dfir_rs::util::unbounded_channel::<(&str, &str)>();
49/// let mut flow = dfir_rs::dfir_syntax! {
50///     source_stream(input_recv)
51///         -> fold_keyed::<'tick, &str, String>(String::new, |old: &mut _, val| {
52///             *old += val;
53///             *old += ", ";
54///         })
55///         -> for_each(|(k, v)| println!("({:?}, {:?})", k, v));
56/// };
57///
58/// input_send.send(("hello", "oakland")).unwrap();
59/// input_send.send(("hello", "berkeley")).unwrap();
60/// input_send.send(("hello", "san francisco")).unwrap();
61/// flow.run_available();
62/// // ("hello", "oakland, berkeley, san francisco, ")
63///
64/// input_send.send(("hello", "palo alto")).unwrap();
65/// flow.run_available();
66/// // ("hello", "palo alto, ")
67/// ```
68pub const FOLD_KEYED: OperatorConstraints = OperatorConstraints {
69    name: "fold_keyed",
70    categories: &[OperatorCategory::KeyedFold],
71    hard_range_inn: RANGE_1,
72    soft_range_inn: RANGE_1,
73    hard_range_out: RANGE_1,
74    soft_range_out: RANGE_1,
75    num_args: 2,
76    persistence_args: &(0..=1),
77    type_args: &(0..=2),
78    is_external_input: false,
79    has_singleton_output: true,
80    flo_type: None,
81    ports_inn: None,
82    ports_out: None,
83    input_delaytype_fn: |_| Some(DelayType::Stratum),
84    write_fn: |wc @ &WriteContextArgs {
85                   df_ident,
86                   context,
87                   op_span,
88                   ident,
89                   inputs,
90                   singleton_output_ident,
91                   is_pull,
92                   work_fn,
93                   root,
94                   op_name,
95                   op_inst:
96                       OperatorInstance {
97                           generics:
98                               OpInstGenerics {
99                                   persistence_args,
100                                   type_args,
101                                   ..
102                               },
103                           ..
104                       },
105                   arguments,
106                   ..
107               },
108               _| {
109        assert!(is_pull, "TODO(mingwei): `{}` only supports pull.", op_name);
110
111        let persistence = match persistence_args[..] {
112            [] => Persistence::Tick,
113            [a] => a,
114            _ => unreachable!(),
115        };
116
117        let generic_type_args = [
118            type_args
119                .first()
120                .map(ToTokens::to_token_stream)
121                .unwrap_or(quote_spanned!(op_span=> _)),
122            type_args
123                .get(1)
124                .map(ToTokens::to_token_stream)
125                .unwrap_or(quote_spanned!(op_span=> _)),
126        ];
127
128        let input = &inputs[0];
129        let initfn = &arguments[0];
130        let aggfn = &arguments[1];
131
132        let hashtable_ident = wc.make_ident("hashtable");
133
134        let write_prologue = quote_spanned! {op_span=>
135            let #singleton_output_ident = #df_ident.add_state(::std::cell::RefCell::new(#root::rustc_hash::FxHashMap::<#( #generic_type_args ),*>::default()));
136        };
137        let write_prologue_after =wc
138            .persistence_as_state_lifespan(persistence)
139            .map(|lifespan| quote_spanned! {op_span=>
140                #[allow(clippy::redundant_closure_call)]
141                #df_ident.set_state_lifespan_hook(#singleton_output_ident, #lifespan, move |rcell| { rcell.take(); });
142            }).unwrap_or_default();
143
144        let assign_hashtable_ident = quote_spanned! {op_span=>
145            let mut #hashtable_ident = unsafe {
146                // SAFETY: handle from `#df_ident.add_state(..)`.
147                #context.state_ref_unchecked(#singleton_output_ident)
148            }.borrow_mut();
149        };
150
151        let write_iterator = if Persistence::Mutable == persistence {
152            quote_spanned! {op_span=>
153                #assign_hashtable_ident
154
155                #work_fn(|| {
156                    #[inline(always)]
157                    fn check_input<Iter, K, V>(iter: Iter) -> impl ::std::iter::Iterator<Item = #root::util::PersistenceKeyed::<K, V>>
158                    where
159                        Iter: ::std::iter::Iterator<Item = #root::util::PersistenceKeyed::<K, V>>,
160                        K: ::std::clone::Clone,
161                        V: ::std::clone::Clone,
162                    {
163                        iter
164                    }
165
166                    /// A: accumulator type
167                    /// T: iterator item type
168                    /// O: output type
169                    #[inline(always)]
170                    fn call_comb_type<A, T, O>(a: &mut A, t: T, f: impl Fn(&mut A, T) -> O) -> O {
171                        (f)(a, t)
172                    }
173
174                    for item in check_input(#input) {
175                        match item {
176                            Persist(k, v) => {
177                                let entry = #hashtable_ident.entry(k).or_insert_with(#initfn);
178                                call_comb_type(entry, v, #aggfn);
179                            },
180                            Delete(k) => {
181                                #hashtable_ident.remove(&k);
182                            },
183                        }
184                    }
185                });
186
187                let #ident = #hashtable_ident
188                    .iter()
189                    .map(#[allow(suspicious_double_ref_op, clippy::clone_on_copy)] |(k, v)| (k.clone(), v.clone()));
190            }
191        } else {
192            let iter_expr = match persistence {
193                Persistence::None | Persistence::Tick => quote_spanned! {op_span=>
194                    #hashtable_ident.drain()
195                },
196                Persistence::Loop => quote_spanned! {op_span=>
197                    #hashtable_ident.iter().map(
198                        #[allow(suspicious_double_ref_op, clippy::clone_on_copy)]
199                        |(k, v)| (
200                            ::std::clone::Clone::clone(k),
201                            ::std::clone::Clone::clone(v),
202                        )
203                    )
204                },
205                Persistence::Static => quote_spanned! {op_span=>
206                    // Play everything but only on the first run of this tick/stratum.
207                    // (We know we won't have any more inputs, so it is fine to only play once.
208                    // Because of the `DelayType::Stratum` or `DelayType::MonotoneAccum`).
209                    #context.is_first_run_this_tick()
210                        .then_some(#hashtable_ident.iter())
211                        .into_iter()
212                        .flatten()
213                        .map(
214                            #[allow(suspicious_double_ref_op, clippy::clone_on_copy)]
215                            |(k, v)| (
216                                ::std::clone::Clone::clone(k),
217                                ::std::clone::Clone::clone(v),
218                            )
219                        )
220                },
221                Persistence::Mutable => unreachable!(),
222            };
223
224            quote_spanned! {op_span=>
225                #assign_hashtable_ident
226
227                #work_fn(|| {
228                    #[inline(always)]
229                    fn check_input<Iter, K, V>(iter: Iter) -> impl ::std::iter::Iterator<Item = (K, V)>
230                    where
231                        Iter: std::iter::Iterator<Item = (K, V)>,
232                        K: ::std::clone::Clone,
233                        V: ::std::clone::Clone
234                    {
235                        iter
236                    }
237
238                    /// A: accumulator type
239                    /// T: iterator item type
240                    /// O: output type
241                    #[inline(always)]
242                    fn call_comb_type<A, T, O>(a: &mut A, t: T, f: impl Fn(&mut A, T) -> O) -> O {
243                        (f)(a, t)
244                    }
245
246                    for kv in check_input(#input) {
247                        // TODO(mingwei): remove `unknown_lints` when `clippy::unwrap_or_default` is stabilized.
248                        #[allow(unknown_lints, clippy::unwrap_or_default)]
249                        let entry = #hashtable_ident.entry(kv.0).or_insert_with(#initfn);
250                        call_comb_type(entry, kv.1, #aggfn);
251                    }
252                });
253
254                let #ident = #iter_expr;
255            }
256        };
257
258        let write_iterator_after = match persistence {
259            Persistence::None | Persistence::Tick | Persistence::Loop => Default::default(),
260            Persistence::Static | Persistence::Mutable => quote_spanned! {op_span=>
261                // Reschedule the subgraph lazily to ensure replay on later ticks.
262                #context.schedule_subgraph(#context.current_subgraph(), false);
263            },
264        };
265
266        Ok(OperatorWriteOutput {
267            write_prologue,
268            write_prologue_after,
269            write_iterator,
270            write_iterator_after,
271        })
272    },
273};