dfir_lang/graph/ops/fold_keyed.rs
1use quote::{ToTokens, quote_spanned};
2
3use super::{
4 DelayType, OpInstGenerics, OperatorCategory, OperatorConstraints, OperatorInstance,
5 OperatorWriteOutput, Persistence, RANGE_1, WriteContextArgs,
6};
7
8/// > 1 input stream of type `(K, V1)`, 1 output stream of type `(K, V2)`.
9/// > The output will have one tuple for each distinct `K`, with an accumulated value of type `V2`.
10///
11/// If the input and output value types are the same and do not require initialization then use
12/// [`reduce_keyed`](#reduce_keyed).
13///
14/// > Arguments: two Rust closures. The first generates an initial value per group. The second
15/// > itself takes two arguments: an 'accumulator', and an element. The second closure returns the
16/// > value that the accumulator should have for the next iteration.
17///
18/// A special case of `fold`, in the spirit of SQL's GROUP BY and aggregation constructs. The input
19/// is partitioned into groups by the first field ("keys"), and for each group the values in the second
20/// field are accumulated via the closures in the arguments.
21///
22/// > Note: The closures have access to the [`context` object](surface_flows.mdx#the-context-object).
23///
24/// ```dfir
25/// source_iter([("toy", 1), ("toy", 2), ("shoe", 11), ("shoe", 35), ("haberdashery", 7)])
26/// -> fold_keyed(|| 0, |old: &mut u32, val: u32| *old += val)
27/// -> assert_eq([("toy", 3), ("shoe", 46), ("haberdashery", 7)]);
28/// ```
29///
30/// `fold_keyed` can be provided with one generic lifetime persistence argument, either
31/// `'tick` or `'static`, to specify how data persists. With `'tick`, values will only be collected
32/// within the same tick. With `'static`, values will be remembered across ticks and will be
33/// aggregated with pairs arriving in later ticks. When not explicitly specified persistence
34/// defaults to `'tick`.
35///
36/// `fold_keyed` can also be provided with two type arguments, the key type `K` and aggregated
37/// output value type `V2`. This is required when using `'static` persistence if the compiler
38/// cannot infer the types.
39///
40/// ```dfir
41/// source_iter([("toy", 1), ("toy", 2), ("shoe", 11), ("shoe", 35), ("haberdashery", 7)])
42/// -> fold_keyed(|| 0, |old: &mut u32, val: u32| *old += val)
43/// -> assert_eq([("toy", 3), ("shoe", 46), ("haberdashery", 7)]);
44/// ```
45///
46/// Example using `'tick` persistence:
47/// ```rustbook
48/// let (input_send, input_recv) = dfir_rs::util::unbounded_channel::<(&str, &str)>();
49/// let mut flow = dfir_rs::dfir_syntax! {
50/// source_stream(input_recv)
51/// -> fold_keyed::<'tick, &str, String>(String::new, |old: &mut _, val| {
52/// *old += val;
53/// *old += ", ";
54/// })
55/// -> for_each(|(k, v)| println!("({:?}, {:?})", k, v));
56/// };
57///
58/// input_send.send(("hello", "oakland")).unwrap();
59/// input_send.send(("hello", "berkeley")).unwrap();
60/// input_send.send(("hello", "san francisco")).unwrap();
61/// flow.run_available();
62/// // ("hello", "oakland, berkeley, san francisco, ")
63///
64/// input_send.send(("hello", "palo alto")).unwrap();
65/// flow.run_available();
66/// // ("hello", "palo alto, ")
67/// ```
68pub const FOLD_KEYED: OperatorConstraints = OperatorConstraints {
69 name: "fold_keyed",
70 categories: &[OperatorCategory::KeyedFold],
71 hard_range_inn: RANGE_1,
72 soft_range_inn: RANGE_1,
73 hard_range_out: RANGE_1,
74 soft_range_out: RANGE_1,
75 num_args: 2,
76 persistence_args: &(0..=1),
77 type_args: &(0..=2),
78 is_external_input: false,
79 has_singleton_output: true,
80 flo_type: None,
81 ports_inn: None,
82 ports_out: None,
83 input_delaytype_fn: |_| Some(DelayType::Stratum),
84 write_fn: |wc @ &WriteContextArgs {
85 df_ident,
86 context,
87 op_span,
88 ident,
89 inputs,
90 singleton_output_ident,
91 is_pull,
92 work_fn,
93 root,
94 op_name,
95 op_inst:
96 OperatorInstance {
97 generics:
98 OpInstGenerics {
99 persistence_args,
100 type_args,
101 ..
102 },
103 ..
104 },
105 arguments,
106 ..
107 },
108 _| {
109 assert!(is_pull, "TODO(mingwei): `{}` only supports pull.", op_name);
110
111 let persistence = match persistence_args[..] {
112 [] => Persistence::Tick,
113 [a] => a,
114 _ => unreachable!(),
115 };
116
117 let generic_type_args = [
118 type_args
119 .first()
120 .map(ToTokens::to_token_stream)
121 .unwrap_or(quote_spanned!(op_span=> _)),
122 type_args
123 .get(1)
124 .map(ToTokens::to_token_stream)
125 .unwrap_or(quote_spanned!(op_span=> _)),
126 ];
127
128 let input = &inputs[0];
129 let initfn = &arguments[0];
130 let aggfn = &arguments[1];
131
132 let hashtable_ident = wc.make_ident("hashtable");
133
134 let write_prologue = quote_spanned! {op_span=>
135 let #singleton_output_ident = #df_ident.add_state(::std::cell::RefCell::new(#root::rustc_hash::FxHashMap::<#( #generic_type_args ),*>::default()));
136 };
137 let write_prologue_after =wc
138 .persistence_as_state_lifespan(persistence)
139 .map(|lifespan| quote_spanned! {op_span=>
140 #[allow(clippy::redundant_closure_call)]
141 #df_ident.set_state_lifespan_hook(#singleton_output_ident, #lifespan, move |rcell| { rcell.take(); });
142 }).unwrap_or_default();
143
144 let assign_hashtable_ident = quote_spanned! {op_span=>
145 let mut #hashtable_ident = unsafe {
146 // SAFETY: handle from `#df_ident.add_state(..)`.
147 #context.state_ref_unchecked(#singleton_output_ident)
148 }.borrow_mut();
149 };
150
151 let write_iterator = if Persistence::Mutable == persistence {
152 quote_spanned! {op_span=>
153 #assign_hashtable_ident
154
155 #work_fn(|| {
156 #[inline(always)]
157 fn check_input<Iter, K, V>(iter: Iter) -> impl ::std::iter::Iterator<Item = #root::util::PersistenceKeyed::<K, V>>
158 where
159 Iter: ::std::iter::Iterator<Item = #root::util::PersistenceKeyed::<K, V>>,
160 K: ::std::clone::Clone,
161 V: ::std::clone::Clone,
162 {
163 iter
164 }
165
166 /// A: accumulator type
167 /// T: iterator item type
168 /// O: output type
169 #[inline(always)]
170 fn call_comb_type<A, T, O>(a: &mut A, t: T, f: impl Fn(&mut A, T) -> O) -> O {
171 (f)(a, t)
172 }
173
174 for item in check_input(#input) {
175 match item {
176 Persist(k, v) => {
177 let entry = #hashtable_ident.entry(k).or_insert_with(#initfn);
178 call_comb_type(entry, v, #aggfn);
179 },
180 Delete(k) => {
181 #hashtable_ident.remove(&k);
182 },
183 }
184 }
185 });
186
187 let #ident = #hashtable_ident
188 .iter()
189 .map(#[allow(suspicious_double_ref_op, clippy::clone_on_copy)] |(k, v)| (k.clone(), v.clone()));
190 }
191 } else {
192 let iter_expr = match persistence {
193 Persistence::None | Persistence::Tick => quote_spanned! {op_span=>
194 #hashtable_ident.drain()
195 },
196 Persistence::Loop => quote_spanned! {op_span=>
197 #hashtable_ident.iter().map(
198 #[allow(suspicious_double_ref_op, clippy::clone_on_copy)]
199 |(k, v)| (
200 ::std::clone::Clone::clone(k),
201 ::std::clone::Clone::clone(v),
202 )
203 )
204 },
205 Persistence::Static => quote_spanned! {op_span=>
206 // Play everything but only on the first run of this tick/stratum.
207 // (We know we won't have any more inputs, so it is fine to only play once.
208 // Because of the `DelayType::Stratum` or `DelayType::MonotoneAccum`).
209 #context.is_first_run_this_tick()
210 .then_some(#hashtable_ident.iter())
211 .into_iter()
212 .flatten()
213 .map(
214 #[allow(suspicious_double_ref_op, clippy::clone_on_copy)]
215 |(k, v)| (
216 ::std::clone::Clone::clone(k),
217 ::std::clone::Clone::clone(v),
218 )
219 )
220 },
221 Persistence::Mutable => unreachable!(),
222 };
223
224 quote_spanned! {op_span=>
225 #assign_hashtable_ident
226
227 #work_fn(|| {
228 #[inline(always)]
229 fn check_input<Iter, K, V>(iter: Iter) -> impl ::std::iter::Iterator<Item = (K, V)>
230 where
231 Iter: std::iter::Iterator<Item = (K, V)>,
232 K: ::std::clone::Clone,
233 V: ::std::clone::Clone
234 {
235 iter
236 }
237
238 /// A: accumulator type
239 /// T: iterator item type
240 /// O: output type
241 #[inline(always)]
242 fn call_comb_type<A, T, O>(a: &mut A, t: T, f: impl Fn(&mut A, T) -> O) -> O {
243 (f)(a, t)
244 }
245
246 for kv in check_input(#input) {
247 // TODO(mingwei): remove `unknown_lints` when `clippy::unwrap_or_default` is stabilized.
248 #[allow(unknown_lints, clippy::unwrap_or_default)]
249 let entry = #hashtable_ident.entry(kv.0).or_insert_with(#initfn);
250 call_comb_type(entry, kv.1, #aggfn);
251 }
252 });
253
254 let #ident = #iter_expr;
255 }
256 };
257
258 let write_iterator_after = match persistence {
259 Persistence::None | Persistence::Tick | Persistence::Loop => Default::default(),
260 Persistence::Static | Persistence::Mutable => quote_spanned! {op_span=>
261 // Reschedule the subgraph lazily to ensure replay on later ticks.
262 #context.schedule_subgraph(#context.current_subgraph(), false);
263 },
264 };
265
266 Ok(OperatorWriteOutput {
267 write_prologue,
268 write_prologue_after,
269 write_iterator,
270 write_iterator_after,
271 })
272 },
273};