dfir_lang/graph/ops/fold_keyed.rs
1use quote::{ToTokens, quote_spanned};
2
3use super::{
4 DelayType, OpInstGenerics, OperatorCategory, OperatorConstraints, OperatorInstance,
5 OperatorWriteOutput, Persistence, RANGE_1, WriteContextArgs,
6};
7
8/// > 1 input stream of type `(K, V1)`, 1 output stream of type `(K, V2)`.
9/// > The output will have one tuple for each distinct `K`, with an accumulated value of type `V2`.
10///
11/// If the input and output value types are the same and do not require initialization then use
12/// [`reduce_keyed`](#reduce_keyed).
13///
14/// > Arguments: two Rust closures. The first generates an initial value per group. The second
15/// > itself takes two arguments: an 'accumulator', and an element. The second closure returns the
16/// > value that the accumulator should have for the next iteration.
17///
18/// A special case of `fold`, in the spirit of SQL's GROUP BY and aggregation constructs. The input
19/// is partitioned into groups by the first field ("keys"), and for each group the values in the second
20/// field are accumulated via the closures in the arguments.
21///
22/// > Note: The closures have access to the [`context` object](surface_flows.mdx#the-context-object).
23///
24/// ```dfir
25/// source_iter([("toy", 1), ("toy", 2), ("shoe", 11), ("shoe", 35), ("haberdashery", 7)])
26/// -> fold_keyed(|| 0, |old: &mut u32, val: u32| *old += val)
27/// -> assert_eq([("toy", 3), ("shoe", 46), ("haberdashery", 7)]);
28/// ```
29///
30/// `fold_keyed` can be provided with one generic lifetime persistence argument, either
31/// `'tick` or `'static`, to specify how data persists. With `'tick`, values will only be collected
32/// within the same tick. With `'static`, values will be remembered across ticks and will be
33/// aggregated with pairs arriving in later ticks. When not explicitly specified persistence
34/// defaults to `'tick`.
35///
36/// `fold_keyed` can also be provided with two type arguments, the key type `K` and aggregated
37/// output value type `V2`. This is required when using `'static` persistence if the compiler
38/// cannot infer the types.
39///
40/// ```dfir
41/// source_iter([("toy", 1), ("toy", 2), ("shoe", 11), ("shoe", 35), ("haberdashery", 7)])
42/// -> fold_keyed(|| 0, |old: &mut u32, val: u32| *old += val)
43/// -> assert_eq([("toy", 3), ("shoe", 46), ("haberdashery", 7)]);
44/// ```
45///
46/// Example using `'tick` persistence:
47/// ```rustbook
48/// let (input_send, input_recv) = dfir_rs::util::unbounded_channel::<(&str, &str)>();
49/// let mut flow = dfir_rs::dfir_syntax! {
50/// source_stream(input_recv)
51/// -> fold_keyed::<'tick, &str, String>(String::new, |old: &mut _, val| {
52/// *old += val;
53/// *old += ", ";
54/// })
55/// -> for_each(|(k, v)| println!("({:?}, {:?})", k, v));
56/// };
57///
58/// input_send.send(("hello", "oakland")).unwrap();
59/// input_send.send(("hello", "berkeley")).unwrap();
60/// input_send.send(("hello", "san francisco")).unwrap();
61/// flow.run_available();
62/// // ("hello", "oakland, berkeley, san francisco, ")
63///
64/// input_send.send(("hello", "palo alto")).unwrap();
65/// flow.run_available();
66/// // ("hello", "palo alto, ")
67/// ```
68pub const FOLD_KEYED: OperatorConstraints = OperatorConstraints {
69 name: "fold_keyed",
70 categories: &[OperatorCategory::KeyedFold],
71 hard_range_inn: RANGE_1,
72 soft_range_inn: RANGE_1,
73 hard_range_out: RANGE_1,
74 soft_range_out: RANGE_1,
75 num_args: 2,
76 persistence_args: &(0..=1),
77 type_args: &(0..=2),
78 is_external_input: false,
79 // If this is set to true, the state will need to be cleared using `#context.set_state_tick_hook`
80 // to prevent reading uncleared data if this subgraph doesn't run.
81 // https://github.com/hydro-project/hydro/issues/1298
82 has_singleton_output: false,
83 flo_type: None,
84 ports_inn: None,
85 ports_out: None,
86 input_delaytype_fn: |_| Some(DelayType::Stratum),
87 write_fn: |wc @ &WriteContextArgs {
88 df_ident,
89 context,
90 op_span,
91 ident,
92 inputs,
93 is_pull,
94 work_fn,
95 root,
96 op_name,
97 op_inst:
98 OperatorInstance {
99 generics:
100 OpInstGenerics {
101 persistence_args,
102 type_args,
103 ..
104 },
105 ..
106 },
107 arguments,
108 ..
109 },
110 _| {
111 assert!(is_pull, "TODO(mingwei): `{}` only supports pull.", op_name);
112
113 let persistence = match persistence_args[..] {
114 [] => Persistence::Tick,
115 [a] => a,
116 _ => unreachable!(),
117 };
118
119 let generic_type_args = [
120 type_args
121 .first()
122 .map(ToTokens::to_token_stream)
123 .unwrap_or(quote_spanned!(op_span=> _)),
124 type_args
125 .get(1)
126 .map(ToTokens::to_token_stream)
127 .unwrap_or(quote_spanned!(op_span=> _)),
128 ];
129
130 let input = &inputs[0];
131 let initfn = &arguments[0];
132 let aggfn = &arguments[1];
133
134 let groupbydata_ident = wc.make_ident("groupbydata");
135 let hashtable_ident = wc.make_ident("hashtable");
136
137 let (write_prologue, write_iterator, write_iterator_after) = match persistence {
138 Persistence::None => {
139 (
140 Default::default(),
141 // TODO(mingwei): deduplicate this code with the other persistence cases.
142 quote_spanned! {op_span=>
143 let mut #hashtable_ident = #root::rustc_hash::FxHashMap::<#( #generic_type_args ),*>::default();
144
145 #work_fn(|| {
146 #[inline(always)]
147 fn check_input<Iter, A, B>(iter: Iter) -> impl ::std::iter::Iterator<Item = (A, B)>
148 where
149 Iter: std::iter::Iterator<Item = (A, B)>,
150 A: ::std::clone::Clone,
151 B: ::std::clone::Clone
152 {
153 iter
154 }
155
156 /// A: accumulator type
157 /// T: iterator item type
158 /// O: output type
159 #[inline(always)]
160 fn call_comb_type<A, T, O>(a: &mut A, t: T, f: impl Fn(&mut A, T) -> O) -> O {
161 (f)(a, t)
162 }
163
164 for kv in check_input(#input) {
165 // TODO(mingwei): remove `unknown_lints` when `clippy::unwrap_or_default` is stabilized.
166 #[allow(unknown_lints, clippy::unwrap_or_default)]
167 let entry = #hashtable_ident.entry(kv.0).or_insert_with(#initfn);
168 #[allow(clippy::redundant_closure_call)] call_comb_type(entry, kv.1, #aggfn);
169 }
170 });
171
172 let #ident = #hashtable_ident.drain();
173 },
174 Default::default(),
175 )
176 }
177 Persistence::Tick => {
178 (
179 quote_spanned! {op_span=>
180 let #groupbydata_ident = #df_ident.add_state(::std::cell::RefCell::new(#root::rustc_hash::FxHashMap::<#( #generic_type_args ),*>::default()));
181 },
182 quote_spanned! {op_span=>
183 let mut #hashtable_ident = unsafe {
184 // SAFETY: handle from `#df_ident.add_state(..)`.
185 #context.state_ref_unchecked(#groupbydata_ident)
186 }.borrow_mut();
187
188 #work_fn(|| {
189 #[inline(always)]
190 fn check_input<Iter, A, B>(iter: Iter) -> impl ::std::iter::Iterator<Item = (A, B)>
191 where
192 Iter: std::iter::Iterator<Item = (A, B)>,
193 A: ::std::clone::Clone,
194 B: ::std::clone::Clone
195 {
196 iter
197 }
198
199 /// A: accumulator type
200 /// T: iterator item type
201 /// O: output type
202 #[inline(always)]
203 fn call_comb_type<A, T, O>(a: &mut A, t: T, f: impl Fn(&mut A, T) -> O) -> O {
204 (f)(a, t)
205 }
206
207 for kv in check_input(#input) {
208 // TODO(mingwei): remove `unknown_lints` when `clippy::unwrap_or_default` is stabilized.
209 #[allow(unknown_lints, clippy::unwrap_or_default)]
210 let entry = #hashtable_ident.entry(kv.0).or_insert_with(#initfn);
211 #[allow(clippy::redundant_closure_call)] call_comb_type(entry, kv.1, #aggfn);
212 }
213 });
214
215 let #ident = #hashtable_ident.drain();
216 },
217 Default::default(),
218 )
219 }
220 Persistence::Static => {
221 (
222 quote_spanned! {op_span=>
223 let #groupbydata_ident = #df_ident.add_state(::std::cell::RefCell::new(#root::rustc_hash::FxHashMap::<#( #generic_type_args ),*>::default()));
224 },
225 quote_spanned! {op_span=>
226 let mut #hashtable_ident = unsafe {
227 // SAFETY: handle from `#df_ident.add_state(..)`.
228 #context.state_ref_unchecked(#groupbydata_ident)
229 }.borrow_mut();
230
231 #work_fn(|| {
232 #[inline(always)]
233 fn check_input<Iter, A, B>(iter: Iter) -> impl ::std::iter::Iterator<Item = (A, B)>
234 where
235 Iter: std::iter::Iterator<Item = (A, B)>,
236 A: ::std::clone::Clone,
237 B: ::std::clone::Clone
238 {
239 iter
240 }
241
242 /// A: accumulator type
243 /// T: iterator item type
244 /// O: output type
245 #[inline(always)]
246 fn call_comb_type<A, T, O>(a: &mut A, t: T, f: impl Fn(&mut A, T) -> O) -> O {
247 (f)(a, t)
248 }
249
250 for kv in check_input(#input) {
251 // TODO(mingwei): remove `unknown_lints` when `clippy::unwrap_or_default` is stabilized.
252 #[allow(unknown_lints, clippy::unwrap_or_default)]
253 let entry = #hashtable_ident.entry(kv.0).or_insert_with(#initfn);
254 #[allow(clippy::redundant_closure_call)] call_comb_type(entry, kv.1, #aggfn);
255 }
256 });
257
258 // Play everything but only on the first run of this tick/stratum.
259 // (We know we won't have any more inputs, so it is fine to only play once.
260 // Because of the `DelayType::Stratum` or `DelayType::MonotoneAccum`).
261 let #ident = #context.is_first_run_this_tick()
262 .then_some(#hashtable_ident.iter())
263 .into_iter()
264 .flatten()
265 .map(
266 // TODO(mingwei): remove `unknown_lints` when `suspicious_double_ref_op` is stabilized.
267 #[allow(unknown_lints, suspicious_double_ref_op, clippy::clone_on_copy)]
268 |(k, v)| (
269 ::std::clone::Clone::clone(k),
270 ::std::clone::Clone::clone(v),
271 )
272 );
273 },
274 quote_spanned! {op_span=>
275 #context.schedule_subgraph(#context.current_subgraph(), false);
276 },
277 )
278 }
279 Persistence::Mutable => {
280 (
281 quote_spanned! {op_span=>
282 let #groupbydata_ident = #df_ident.add_state(::std::cell::RefCell::new(#root::rustc_hash::FxHashMap::<#( #generic_type_args ),*>::default()));
283 },
284 quote_spanned! {op_span=>
285 let mut #hashtable_ident = unsafe {
286 // SAFETY: handle from `#df_ident.add_state(..)`.
287 #context.state_ref_unchecked(#groupbydata_ident)
288 }.borrow_mut();
289
290 #work_fn(|| {
291 #[inline(always)]
292 fn check_input<Iter: ::std::iter::Iterator<Item = #root::util::PersistenceKeyed::<K, V>>, K: ::std::clone::Clone, V: ::std::clone::Clone>(iter: Iter)
293 -> impl ::std::iter::Iterator<Item = #root::util::PersistenceKeyed::<K, V>> { iter }
294
295 #[inline(always)]
296 /// A: accumulator type
297 /// T: iterator item type
298 /// O: output type
299 fn call_comb_type<A, T, O>(a: &mut A, t: T, f: impl Fn(&mut A, T) -> O) -> O {
300 f(a, t)
301 }
302
303 for item in check_input(#input) {
304 match item {
305 Persist(k, v) => {
306 let entry = #hashtable_ident.entry(k).or_insert_with(#initfn);
307 #[allow(clippy::redundant_closure_call)] call_comb_type(entry, v, #aggfn);
308 },
309 Delete(k) => {
310 #hashtable_ident.remove(&k);
311 },
312 }
313 }
314 });
315
316 let #ident = #hashtable_ident
317 .iter()
318 .map(#[allow(suspicious_double_ref_op, clippy::clone_on_copy)] |(k, v)| (k.clone(), v.clone()));
319 },
320 quote_spanned! {op_span=>
321 #context.schedule_subgraph(#context.current_subgraph(), false);
322 },
323 )
324 }
325 };
326
327 Ok(OperatorWriteOutput {
328 write_prologue,
329 write_iterator,
330 write_iterator_after,
331 })
332 },
333};