dfir_lang/graph/ops/_lattice_join_fused_join.rs
1use quote::quote_spanned;
2use syn::parse_quote;
3use syn::spanned::Spanned;
4
5use super::{
6 DelayType, OpInstGenerics, OperatorCategory, OperatorConstraints, OperatorInstance,
7 OperatorWriteOutput, RANGE_1, WriteContextArgs,
8};
9
10/// > 2 input streams of type `(K, V1)` and `(K, V2)`, 1 output stream of type `(K, (V1', V2'))` where `V1`, `V2`, `V1'`, `V2'` are lattice types
11///
12/// Performs a [`fold_keyed`](#fold_keyed) with lattice-merge aggregate function on each input and then forms the
13/// equijoin of the resulting key/value pairs in the input streams by their first (key) attribute.
14/// Unlike [`join`](#join), the result is not a stream of tuples, it's a stream of MapUnionSingletonMap
15/// lattices. You can (non-monotonically) "reveal" these as tuples if desired via [`map`](#map); see the examples below.
16///
17/// You must specify the the accumulating lattice types, they cannot be inferred. The first type argument corresponds to the `[0]` input of the join, and the second to the `[1]` input.
18/// Type arguments are specified in dfir using the rust turbofish syntax `::<>`, for example `_lattice_join_fused_join::<Min<_>, Max<_>>()`
19/// The accumulating lattice type is not necessarily the same type as the input, see the below example involving SetUnion for such a case.
20///
21/// Like [`join`](#join), `_lattice_join_fused_join` can also be provided with one or two generic lifetime persistence arguments, either
22/// `'tick` or `'static`, to specify how join data persists. With `'tick`, pairs will only be
23/// joined with corresponding pairs within the same tick. With `'static`, pairs will be remembered
24/// across ticks and will be joined with pairs arriving in later ticks. When not explicitly
25/// specified persistence defaults to `tick.
26///
27/// Like [`join`](#join), when two persistence arguments are supplied the first maps to port `0` and the second maps to
28/// port `1`.
29/// When a single persistence argument is supplied, it is applied to both input ports.
30/// When no persistence arguments are applied it defaults to `'tick` for both.
31/// It is important to specify all persistence arguments before any type arguments, otherwise the persistence arguments will be ignored.
32///
33/// The syntax is as follows:
34/// ```dfir,ignore
35/// _lattice_join_fused_join::<MaxRepr<usize>, MaxRepr<usize>>(); // Or
36/// _lattice_join_fused_join::<'static, MaxRepr<usize>, MaxRepr<usize>>();
37///
38/// _lattice_join_fused_join::<'tick, MaxRepr<usize>, MaxRepr<usize>>();
39///
40/// _lattice_join_fused_join::<'static, 'tick, MaxRepr<usize>, MaxRepr<usize>>();
41///
42/// _lattice_join_fused_join::<'tick, 'static, MaxRepr<usize>, MaxRepr<usize>>();
43/// // etc.
44/// ```
45///
46/// ### Examples
47///
48/// ```dfir
49/// use dfir_rs::lattices::Min;
50/// use dfir_rs::lattices::Max;
51///
52/// source_iter([("key", Min::new(1)), ("key", Min::new(2))]) -> [0]my_join;
53/// source_iter([("key", Max::new(1)), ("key", Max::new(2))]) -> [1]my_join;
54///
55/// my_join = _lattice_join_fused_join::<'tick, Min<usize>, Max<usize>>()
56/// -> map(|singleton_map| {
57/// let lattices::collections::SingletonMap(k, v) = singleton_map.into_reveal();
58/// (k, (v.into_reveal()))
59/// })
60/// -> assert_eq([("key", (Min::new(1), Max::new(2)))]);
61/// ```
62///
63/// ```dfir
64/// use dfir_rs::lattices::set_union::SetUnionSingletonSet;
65/// use dfir_rs::lattices::set_union::SetUnionHashSet;
66///
67/// source_iter([("key", SetUnionSingletonSet::new_from(0)), ("key", SetUnionSingletonSet::new_from(1))]) -> [0]my_join;
68/// source_iter([("key", SetUnionHashSet::new_from([0])), ("key", SetUnionHashSet::new_from([1]))]) -> [1]my_join;
69///
70/// my_join = _lattice_join_fused_join::<'tick, SetUnionHashSet<usize>, SetUnionHashSet<usize>>()
71/// -> map(|singleton_map| {
72/// let lattices::collections::SingletonMap(k, v) = singleton_map.into_reveal();
73/// (k, (v.into_reveal()))
74/// })
75/// -> assert_eq([("key", (SetUnionHashSet::new_from([0, 1]), SetUnionHashSet::new_from([0, 1])))]);
76/// ```
77pub const _LATTICE_JOIN_FUSED_JOIN: OperatorConstraints = OperatorConstraints {
78 name: "_lattice_join_fused_join",
79 categories: &[OperatorCategory::CompilerFusionOperator],
80 hard_range_inn: &(2..=2),
81 soft_range_inn: &(2..=2),
82 hard_range_out: RANGE_1,
83 soft_range_out: RANGE_1,
84 num_args: 0,
85 persistence_args: &(0..=2),
86 type_args: &(2..=2),
87 is_external_input: false,
88 has_singleton_output: false,
89 flo_type: None,
90 ports_inn: Some(|| super::PortListSpec::Fixed(parse_quote! { 0, 1 })),
91 ports_out: None,
92 input_delaytype_fn: |_| Some(DelayType::MonotoneAccum),
93 write_fn: |wc @ &WriteContextArgs {
94 root,
95 op_span,
96 ident,
97 inputs,
98 is_pull,
99 op_inst:
100 OperatorInstance {
101 generics:
102 OpInstGenerics {
103 type_args,
104 persistence_args,
105 ..
106 },
107 ..
108 },
109 ..
110 },
111 diagnostics| {
112 let lhs_type = &type_args[0];
113 let rhs_type = &type_args[1];
114
115 let wc = WriteContextArgs {
116 arguments: &parse_quote! {
117 FoldFrom(<#lhs_type as #root::lattices::LatticeFrom::<_>>::lattice_from, #root::lattices::Merge::merge),
118 FoldFrom(<#rhs_type as #root::lattices::LatticeFrom::<_>>::lattice_from, #root::lattices::Merge::merge)
119 },
120 ..wc.clone()
121 };
122
123 // initialize write_prologue and write_iterator_after via join_fused, but specialize the write_iterator
124 let OperatorWriteOutput {
125 write_prologue,
126 write_iterator: _,
127 write_iterator_after,
128 } = (super::join_fused::JOIN_FUSED.write_fn)(&wc, diagnostics).unwrap();
129
130 assert!(is_pull);
131 let persistences = super::join_fused::parse_persistences(persistence_args);
132
133 let lhs_join_options = super::join_fused::parse_argument(&wc.arguments[0])
134 .map_err(|err| diagnostics.push(err))?;
135 let rhs_join_options = super::join_fused::parse_argument(&wc.arguments[1])
136 .map_err(|err| diagnostics.push(err))?;
137 let (_lhs_prologue, lhs_pre_write_iter, lhs_borrow) =
138 super::join_fused::make_joindata(&wc, persistences[0], &lhs_join_options, "lhs")
139 .map_err(|err| diagnostics.push(err))?;
140
141 let (_rhs_prologue, rhs_pre_write_iter, rhs_borrow) =
142 super::join_fused::make_joindata(&wc, persistences[1], &rhs_join_options, "rhs")
143 .map_err(|err| diagnostics.push(err))?;
144
145 let lhs = &inputs[0];
146 let rhs = &inputs[1];
147
148 let arg0_span = wc.arguments[0].span();
149 let arg1_span = wc.arguments[1].span();
150
151 let lhs_tokens = quote_spanned! {arg0_span=>
152 #lhs_borrow.fold_into(#lhs, #root::lattices::Merge::merge,
153 <#lhs_type as #root::lattices::LatticeFrom::<_>>::lattice_from);
154 };
155
156 let rhs_tokens = quote_spanned! {arg1_span=>
157 #rhs_borrow.fold_into(#rhs, #root::lattices::Merge::merge,
158 <#rhs_type as #root::lattices::LatticeFrom::<_>>::lattice_from);
159 };
160
161 let write_iterator = quote_spanned! {op_span=>
162 #lhs_pre_write_iter
163 #rhs_pre_write_iter
164
165 let #ident = {
166 #lhs_tokens
167 #rhs_tokens
168
169 // TODO: start the iterator with the smallest len() table rather than always picking rhs.
170 #[allow(clippy::clone_on_copy)]
171 #[allow(suspicious_double_ref_op)]
172 #rhs_borrow
173 .table
174 .iter()
175 .filter_map(|(k, v2)| #lhs_borrow.table.get(k).map(|v1| (k.clone(), lattices::Pair::<#lhs_type, #rhs_type>::new_from(v1.clone(), v2.clone()))))
176 .map(|(key, p)| #root::lattices::map_union::MapUnionSingletonMap::new_from((key, p)))
177 };
178 };
179
180 Ok(OperatorWriteOutput {
181 write_prologue,
182 write_iterator,
183 write_iterator_after,
184 })
185 },
186};