dfir_lang/graph/ops/join.rs
1use quote::{ToTokens, quote_spanned};
2use syn::parse_quote;
3
4use super::{
5 OpInstGenerics, OperatorCategory, OperatorConstraints, OperatorInstance, OperatorWriteOutput,
6 Persistence, RANGE_1, WriteContextArgs,
7};
8use crate::diagnostic::{Diagnostic, Level};
9
10/// > 2 input streams of type `<(K, V1)>` and `<(K, V2)>`, 1 output stream of type `<(K, (V1, V2))>`
11///
12/// Forms the equijoin of the tuples in the input streams by their first (key) attribute. Note that the result nests the 2nd input field (values) into a tuple in the 2nd output field.
13///
14/// ```dfir
15/// source_iter(vec![("hello", "world"), ("stay", "gold"), ("hello", "world")]) -> [0]my_join;
16/// source_iter(vec![("hello", "cleveland")]) -> [1]my_join;
17/// my_join = join()
18/// -> assert_eq([("hello", ("world", "cleveland"))]);
19/// ```
20///
/// `join` can also be provided with one or two generic lifetime persistence arguments, either
/// `'tick` or `'static`, to specify how join data persists. With `'tick`, pairs will only be
/// joined with corresponding pairs within the same tick. With `'static`, pairs will be remembered
/// across ticks and will be joined with pairs arriving in later ticks. When not explicitly
/// specified, persistence defaults to `'tick`. Inside a loop, `'none` is also accepted (and is
/// the default), while `'static` is rejected; outside of a loop, `'none` is rejected.
///
/// When two persistence arguments are supplied, the first maps to port `0` and the second maps to
/// port `1`.
/// When a single persistence argument is supplied, it is applied to both input ports.
/// When no persistence arguments are supplied, it defaults to `'tick` for both (`'none` for both
/// inside a loop).
31///
32/// The syntax is as follows:
33/// ```dfir,ignore
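/// join(); // Default persistence for both ports (`'tick`, or `'none` inside a loop).
/// join::<'static>(); // `'static` for both ports.
///
/// join::<'tick>(); // `'tick` for both ports.
///
/// join::<'static, 'tick>(); // `'static` for port `0`, `'tick` for port `1`.
///
/// join::<'tick, 'static>(); // `'tick` for port `0`, `'static` for port `1`.
/// // etc.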
43/// ```
44///
45/// `join` is defined to treat its inputs as *sets*, meaning that it
46/// eliminates duplicated values in its inputs. If you do not want
47/// duplicates eliminated, use the [`join_multiset`](#join_multiset) operator.
48///
49/// ### Examples
50///
51/// ```rustbook
52/// let (input_send, input_recv) = dfir_rs::util::unbounded_channel::<(&str, &str)>();
53/// let mut flow = dfir_rs::dfir_syntax! {
54/// source_iter([("hello", "world")]) -> [0]my_join;
55/// source_stream(input_recv) -> [1]my_join;
56/// my_join = join::<'tick>() -> for_each(|(k, (v1, v2))| println!("({}, ({}, {}))", k, v1, v2));
57/// };
58/// input_send.send(("hello", "oakland")).unwrap();
59/// flow.run_tick();
60/// input_send.send(("hello", "san francisco")).unwrap();
61/// flow.run_tick();
62/// ```
63/// Prints out `"(hello, (world, oakland))"` since `source_iter([("hello", "world")])` is only
64/// included in the first tick, then forgotten.
65///
66/// ---
67///
68/// ```rustbook
69/// let (input_send, input_recv) = dfir_rs::util::unbounded_channel::<(&str, &str)>();
70/// let mut flow = dfir_rs::dfir_syntax! {
71/// source_iter([("hello", "world")]) -> [0]my_join;
72/// source_stream(input_recv) -> [1]my_join;
73/// my_join = join::<'static>() -> for_each(|(k, (v1, v2))| println!("({}, ({}, {}))", k, v1, v2));
74/// };
75/// input_send.send(("hello", "oakland")).unwrap();
76/// flow.run_tick();
77/// input_send.send(("hello", "san francisco")).unwrap();
78/// flow.run_tick();
79/// ```
/// Prints out `"(hello, (world, oakland))"` and `"(hello, (world, san francisco))"` since the
/// inputs are persisted across ticks.
82pub const JOIN: OperatorConstraints = OperatorConstraints {
83 name: "join",
84 categories: &[OperatorCategory::MultiIn],
85 hard_range_inn: &(2..=2),
86 soft_range_inn: &(2..=2),
87 hard_range_out: RANGE_1,
88 soft_range_out: RANGE_1,
89 num_args: 0,
90 persistence_args: &(0..=2),
91 type_args: &(0..=1),
92 is_external_input: false,
93 has_singleton_output: false,
94 flo_type: None,
95 ports_inn: Some(|| super::PortListSpec::Fixed(parse_quote! { 0, 1 })),
96 ports_out: None,
97 input_delaytype_fn: |_| None,
98 write_fn: |wc @ &WriteContextArgs {
99 root,
100 context,
101 df_ident,
102 loop_id,
103 op_span,
104 ident,
105 inputs,
106 work_fn,
107 op_inst:
108 OperatorInstance {
109 generics:
110 OpInstGenerics {
111 persistence_args,
112 type_args,
113 ..
114 },
115 ..
116 },
117 ..
118 },
119 diagnostics| {
120 let join_type =
121 type_args
122 .first()
123 .map(ToTokens::to_token_stream)
124 .unwrap_or(quote_spanned!(op_span=>
125 #root::compiled::pull::HalfSetJoinState
126 ));
127
128 // TODO: This is really bad.
129 // This will break if the user aliases HalfSetJoinState to something else. Temporary hacky solution.
130 // Note that cross_join() depends on the implementation here as well.
131 let additional_trait_bounds = if join_type.to_string().contains("HalfSetJoinState") {
132 quote_spanned!(op_span=>
133 + ::std::cmp::Eq
134 )
135 } else {
136 quote_spanned!(op_span=>)
137 };
138
139 let mut make_joindata = |persistence, in_loop, side| {
140 let joindata_ident = wc.make_ident(format!("joindata_{}", side));
141 let borrow_ident = wc.make_ident(format!("joindata_{}_borrow", side));
142 let reset = match (in_loop, persistence) {
143 (false, Persistence::None) => {
144 diagnostics.push(Diagnostic::spanned(
145 op_span,
146 Level::Error,
147 "`'none` is not allowed outside of loops, use `'tick` instead.",
148 ));
149 return Err(());
150 }
151 (true, Persistence::None) => Default::default(),
152 (false, Persistence::Tick) => quote_spanned! {op_span=>
153 #df_ident.set_state_tick_hook(#joindata_ident, |rcell| #work_fn(|| #root::util::clear::Clear::clear(rcell.get_mut())));
154 },
155 (true, Persistence::Tick) => Default::default(),
156 (false, Persistence::Static) => Default::default(),
157 (true, Persistence::Static) => {
158 diagnostics.push(Diagnostic::spanned(
159 op_span,
160 Level::Error,
161 "`'static` is not allowed within loops.",
162 ));
163 return Err(());
164 }
165 (_, Persistence::Mutable) => {
166 diagnostics.push(Diagnostic::spanned(
167 op_span,
168 Level::Error,
169 "An implementation of `'mutable` does not exist.",
170 ));
171 return Err(());
172 }
173 };
174 let (borrow, init) = if !in_loop {
175 (
176 quote_spanned! {op_span=>
177 unsafe {
178 // SAFETY: handle from `#df_ident.add_state(..)`.
179 #context.state_ref_unchecked(#joindata_ident)
180 }.borrow_mut()
181 },
182 quote_spanned! {op_span=>
183 let #joindata_ident = #df_ident.add_state(::std::cell::RefCell::new(
184 #join_type::default()
185 ));
186 #reset
187 },
188 )
189 } else {
190 (
191 quote_spanned! {op_span=>
192 #join_type::default()
193 },
194 Default::default(),
195 )
196 };
197 Ok((borrow, borrow_ident, init))
198 };
199
200 let persistences = match persistence_args[..] {
201 [] => {
202 let p = if loop_id.is_some() {
203 Persistence::None
204 } else {
205 Persistence::Tick
206 };
207 [p, p]
208 }
209 [a] => [a, a],
210 [a, b] => [a, b],
211 _ => panic!(),
212 };
213
214 let (lhs_borrow, lhs_borrow_ident, lhs_init) =
215 make_joindata(persistences[0], loop_id.is_some(), "lhs")?;
216 let (rhs_borrow, rhs_borrow_ident, rhs_init) =
217 make_joindata(persistences[1], loop_id.is_some(), "rhs")?;
218
219 let write_prologue = quote_spanned! {op_span=>
220 #lhs_init
221 #rhs_init
222 };
223
224 let lhs = &inputs[0];
225 let rhs = &inputs[1];
226 let write_iterator = if loop_id.is_none() {
227 quote_spanned! {op_span=>
228 let mut #lhs_borrow_ident = #lhs_borrow;
229 let mut #rhs_borrow_ident = #rhs_borrow;
230 let #ident = {
231 // Limit error propagation by bounding locally, erasing output iterator type.
232 #[inline(always)]
233 fn check_inputs<'a, K, I1, V1, I2, V2>(
234 lhs: I1,
235 rhs: I2,
236 lhs_state: &'a mut #join_type<K, V1, V2>,
237 rhs_state: &'a mut #join_type<K, V2, V1>,
238 is_new_tick: bool,
239 ) -> impl 'a + Iterator<Item = (K, (V1, V2))>
240 where
241 K: Eq + std::hash::Hash + Clone,
242 V1: Clone #additional_trait_bounds,
243 V2: Clone #additional_trait_bounds,
244 I1: 'a + Iterator<Item = (K, V1)>,
245 I2: 'a + Iterator<Item = (K, V2)>,
246 {
247 #work_fn(|| #root::compiled::pull::symmetric_hash_join_into_iter(lhs, rhs, lhs_state, rhs_state, is_new_tick))
248 }
249
250 check_inputs(#lhs, #rhs, &mut *#lhs_borrow_ident, &mut *#rhs_borrow_ident, #context.is_first_run_this_tick())
251 };
252 }
253 } else {
254 // TODO(mingwei): deduplicate this code with the above.
255 quote_spanned! {op_span=>
256 let mut #lhs_borrow_ident = ::std::default::Default::default();
257 let mut #rhs_borrow_ident = ::std::default::Default::default();
258 let #ident = {
259 // Limit error propagation by bounding locally, erasing output iterator type.
260 #[inline(always)]
261 fn check_inputs<'a, K, I1, V1, I2, V2>(
262 lhs: I1,
263 rhs: I2,
264 lhs_state: &'a mut #join_type<K, V1, V2>,
265 rhs_state: &'a mut #join_type<K, V2, V1>,
266 is_new_tick: bool,
267 ) -> impl 'a + Iterator<Item = (K, (V1, V2))>
268 where
269 K: Eq + std::hash::Hash + Clone,
270 V1: Clone #additional_trait_bounds,
271 V2: Clone #additional_trait_bounds,
272 I1: 'a + Iterator<Item = (K, V1)>,
273 I2: 'a + Iterator<Item = (K, V2)>,
274 {
275 #work_fn(|| #root::compiled::pull::symmetric_hash_join_into_iter(lhs, rhs, lhs_state, rhs_state, is_new_tick))
276 }
277
278 check_inputs(#lhs, #rhs, &mut #lhs_borrow_ident, &mut #rhs_borrow_ident, true)
279 };
280 }
281 };
282
283 let write_iterator_after =
284 if persistences[0] == Persistence::Static || persistences[1] == Persistence::Static {
285 quote_spanned! {op_span=>
286 // TODO: Probably only need to schedule if #*_borrow.len() > 0?
287 #context.schedule_subgraph(#context.current_subgraph(), false);
288 }
289 } else {
290 quote_spanned! {op_span=>}
291 };
292
293 Ok(OperatorWriteOutput {
294 write_prologue,
295 write_iterator,
296 write_iterator_after,
297 })
298 },
299};