dfir_lang/graph/ops/reduce_keyed.rs
1use quote::{ToTokens, quote_spanned};
2
3use super::{
4 DelayType, OpInstGenerics, OperatorCategory, OperatorConstraints, OperatorInstance,
5 OperatorWriteOutput, Persistence, RANGE_1, WriteContextArgs,
6};
7use crate::diagnostic::{Diagnostic, Level};
8
9/// > 1 input stream of type `(K, V)`, 1 output stream of type `(K, V)`.
10/// > The output will have one tuple for each distinct `K`, with an accumulated (reduced) value of
11/// > type `V`.
12///
13/// If you need the accumulated value to have a different type than the input, use [`fold_keyed`](#fold_keyed).
14///
15/// > Arguments: one Rust closures. The closure takes two arguments: an `&mut` 'accumulator', and
16/// > an element. Accumulator should be updated based on the element.
17///
18/// A special case of `reduce`, in the spirit of SQL's GROUP BY and aggregation constructs. The input
19/// is partitioned into groups by the first field, and for each group the values in the second
20/// field are accumulated via the closures in the arguments.
21///
22/// > Note: The closures have access to the [`context` object](surface_flows.mdx#the-context-object).
23///
24/// `reduce_keyed` can also be provided with one generic lifetime persistence argument, either
25/// `'tick` or `'static`, to specify how data persists. With `'tick`, values will only be collected
26/// within the same tick. With `'static`, values will be remembered across ticks and will be
27/// aggregated with pairs arriving in later ticks. When not explicitly specified persistence
28/// defaults to `'tick`.
29///
30/// `reduce_keyed` can also be provided with two type arguments, the key and value type. This is
31/// required when using `'static` persistence if the compiler cannot infer the types.
32///
33/// ```dfir
34/// source_iter([("toy", 1), ("toy", 2), ("shoe", 11), ("shoe", 35), ("haberdashery", 7)])
35/// -> reduce_keyed(|old: &mut u32, val: u32| *old += val)
36/// -> assert_eq([("toy", 3), ("shoe", 46), ("haberdashery", 7)]);
37/// ```
38///
39/// Example using `'tick` persistence and type arguments:
40/// ```rustbook
41/// let (input_send, input_recv) = dfir_rs::util::unbounded_channel::<(&str, &str)>();
42/// let mut flow = dfir_rs::dfir_syntax! {
43/// source_stream(input_recv)
44/// -> reduce_keyed::<'tick, &str>(|old: &mut _, val| *old = std::cmp::max(*old, val))
45/// -> for_each(|(k, v)| println!("({:?}, {:?})", k, v));
46/// };
47///
48/// input_send.send(("hello", "oakland")).unwrap();
49/// input_send.send(("hello", "berkeley")).unwrap();
50/// input_send.send(("hello", "san francisco")).unwrap();
51/// flow.run_available();
52/// // ("hello", "oakland, berkeley, san francisco, ")
53///
54/// input_send.send(("hello", "palo alto")).unwrap();
55/// flow.run_available();
56/// // ("hello", "palo alto, ")
57/// ```
58pub const REDUCE_KEYED: OperatorConstraints = OperatorConstraints {
59 name: "reduce_keyed",
60 categories: &[OperatorCategory::KeyedFold],
61 hard_range_inn: RANGE_1,
62 soft_range_inn: RANGE_1,
63 hard_range_out: RANGE_1,
64 soft_range_out: RANGE_1,
65 num_args: 1,
66 persistence_args: &(0..=1),
67 type_args: &(0..=2),
68 is_external_input: false,
69 // If this is set to true, the state will need to be cleared using `#context.set_state_tick_hook`
70 // to prevent reading uncleared data if this subgraph doesn't run.
71 // https://github.com/hydro-project/hydro/issues/1298
72 has_singleton_output: false,
73 flo_type: None,
74 ports_inn: None,
75 ports_out: None,
76 input_delaytype_fn: |_| Some(DelayType::Stratum),
77 write_fn: |wc @ &WriteContextArgs {
78 df_ident,
79 context,
80 op_span,
81 ident,
82 inputs,
83 is_pull,
84 work_fn,
85 root,
86 op_inst:
87 OperatorInstance {
88 generics:
89 OpInstGenerics {
90 persistence_args,
91 type_args,
92 ..
93 },
94 ..
95 },
96 arguments,
97 ..
98 },
99 diagnostics| {
100 assert!(is_pull);
101
102 let persistence = match persistence_args[..] {
103 [] => Persistence::Tick,
104 [a] => a,
105 _ => unreachable!(),
106 };
107
108 let generic_type_args = [
109 type_args
110 .first()
111 .map(ToTokens::to_token_stream)
112 .unwrap_or(quote_spanned!(op_span=> _)),
113 type_args
114 .get(1)
115 .map(ToTokens::to_token_stream)
116 .unwrap_or(quote_spanned!(op_span=> _)),
117 ];
118
119 let input = &inputs[0];
120 let aggfn = &arguments[0];
121
122 let hashtable_ident = wc.make_ident("hashtable");
123 let (write_prologue, write_iterator, write_iterator_after) = match persistence {
124 Persistence::None => {
125 (
126 Default::default(),
127 // TODO(mingwei): deduplicate with other persistence cases below.
128 quote_spanned! {op_span=>
129 let mut #hashtable_ident = #root::rustc_hash::FxHashMap::<#( #generic_type_args ),*>::default();
130
131 #work_fn(|| {
132 #[inline(always)]
133 fn check_input<Iter: ::std::iter::Iterator<Item = (A, B)>, A: ::std::clone::Clone, B: ::std::clone::Clone>(iter: Iter)
134 -> impl ::std::iter::Iterator<Item = (A, B)> { iter }
135
136 #[inline(always)]
137 /// A: accumulator type
138 /// O: output type
139 fn call_comb_type<A, O>(acc: &mut A, item: A, f: impl Fn(&mut A, A) -> O) -> O {
140 f(acc, item)
141 }
142
143 for kv in check_input(#input) {
144 match #hashtable_ident.entry(kv.0) {
145 ::std::collections::hash_map::Entry::Vacant(vacant) => {
146 vacant.insert(kv.1);
147 }
148 ::std::collections::hash_map::Entry::Occupied(mut occupied) => {
149 #[allow(clippy::redundant_closure_call)] call_comb_type(occupied.get_mut(), kv.1, #aggfn);
150 }
151 }
152 }
153 });
154
155 let #ident = #hashtable_ident.drain();
156 },
157 Default::default(),
158 )
159 }
160 Persistence::Tick => {
161 let groupbydata_ident = wc.make_ident("groupbydata");
162
163 (
164 quote_spanned! {op_span=>
165 let #groupbydata_ident = #df_ident.add_state(::std::cell::RefCell::new(#root::rustc_hash::FxHashMap::<#( #generic_type_args ),*>::default()));
166 },
167 quote_spanned! {op_span=>
168 let mut #hashtable_ident = unsafe {
169 // SAFETY: handle from `#df_ident.add_state(..)`.
170 #context.state_ref_unchecked(#groupbydata_ident)
171 }.borrow_mut();
172
173 #work_fn(|| {
174 #[inline(always)]
175 fn check_input<Iter: ::std::iter::Iterator<Item = (A, B)>, A: ::std::clone::Clone, B: ::std::clone::Clone>(iter: Iter)
176 -> impl ::std::iter::Iterator<Item = (A, B)> { iter }
177
178 #[inline(always)]
179 /// A: accumulator type
180 /// O: output type
181 fn call_comb_type<A, O>(acc: &mut A, item: A, f: impl Fn(&mut A, A) -> O) -> O {
182 f(acc, item)
183 }
184
185 for kv in check_input(#input) {
186 match #hashtable_ident.entry(kv.0) {
187 ::std::collections::hash_map::Entry::Vacant(vacant) => {
188 vacant.insert(kv.1);
189 }
190 ::std::collections::hash_map::Entry::Occupied(mut occupied) => {
191 #[allow(clippy::redundant_closure_call)] call_comb_type(occupied.get_mut(), kv.1, #aggfn);
192 }
193 }
194 }
195 });
196
197 let #ident = #hashtable_ident.drain();
198 },
199 Default::default(),
200 )
201 }
202 Persistence::Static => {
203 let groupbydata_ident = wc.make_ident("groupbydata");
204
205 (
206 quote_spanned! {op_span=>
207 let #groupbydata_ident = #df_ident.add_state(::std::cell::RefCell::new(#root::rustc_hash::FxHashMap::<#( #generic_type_args ),*>::default()));
208 },
209 quote_spanned! {op_span=>
210 let mut #hashtable_ident = unsafe {
211 // SAFETY: handle from `#df_ident.add_state(..)`.
212 #context.state_ref_unchecked(#groupbydata_ident)
213 }.borrow_mut();
214
215 #work_fn(|| {
216 #[inline(always)]
217 fn check_input<Iter: ::std::iter::Iterator<Item = (A, B)>, A: ::std::clone::Clone, B: ::std::clone::Clone>(iter: Iter)
218 -> impl ::std::iter::Iterator<Item = (A, B)> { iter }
219
220 #[inline(always)]
221 /// A: accumulator type
222 /// O: output type
223 fn call_comb_type<A, O>(acc: &mut A, item: A, f: impl Fn(&mut A, A) -> O) -> O {
224 f(acc, item)
225 }
226
227 for kv in check_input(#input) {
228 match #hashtable_ident.entry(kv.0) {
229 ::std::collections::hash_map::Entry::Vacant(vacant) => {
230 vacant.insert(kv.1);
231 }
232 ::std::collections::hash_map::Entry::Occupied(mut occupied) => {
233 #[allow(clippy::redundant_closure_call)] call_comb_type(occupied.get_mut(), kv.1, #aggfn);
234 }
235 }
236 }
237 });
238
239 let #ident = #context.is_first_run_this_tick()
240 .then_some(#hashtable_ident.iter())
241 .into_iter()
242 .flatten()
243 .map(
244 // TODO(mingwei): remove `unknown_lints` when `suspicious_double_ref_op` is stabilized.
245 #[allow(unknown_lints, suspicious_double_ref_op, clippy::clone_on_copy)]
246 |(k, v)| (
247 ::std::clone::Clone::clone(k),
248 ::std::clone::Clone::clone(v),
249 )
250 );
251 },
252 quote_spanned! {op_span=>
253 #context.schedule_subgraph(#context.current_subgraph(), false);
254 },
255 )
256 }
257
258 Persistence::Mutable => {
259 diagnostics.push(Diagnostic::spanned(
260 op_span,
261 Level::Error,
262 "An implementation of 'mutable does not exist",
263 ));
264 return Err(());
265 }
266 };
267
268 Ok(OperatorWriteOutput {
269 write_prologue,
270 write_iterator,
271 write_iterator_after,
272 })
273 },
274};