dfir_lang/graph/ops/
_lattice_join_fused_join.rs

1use quote::quote_spanned;
2use syn::parse_quote;
3use syn::spanned::Spanned;
4
5use super::{
6    DelayType, OpInstGenerics, OperatorCategory, OperatorConstraints, OperatorInstance,
7    OperatorWriteOutput, RANGE_1, WriteContextArgs,
8};
9
10/// > 2 input streams of type `(K, V1)` and `(K, V2)`, 1 output stream of type `(K, (V1', V2'))` where `V1`, `V2`, `V1'`, `V2'` are lattice types
11///
12/// Performs a [`fold_keyed`](#fold_keyed) with lattice-merge aggregate function on each input and then forms the
13/// equijoin of the resulting key/value pairs in the input streams by their first (key) attribute.
14/// Unlike [`join`](#join), the result is not a stream of tuples, it's a stream of MapUnionSingletonMap
15/// lattices. You can (non-monotonically) "reveal" these as tuples if desired via [`map`](#map); see the examples below.
16///
17/// You must specify the accumulating lattice types; they cannot be inferred. The first type argument corresponds to the `[0]` input of the join, and the second to the `[1]` input.
18/// Type arguments are specified in dfir using the rust turbofish syntax `::<>`, for example `_lattice_join_fused_join::<Min<_>, Max<_>>()`
19/// The accumulating lattice type is not necessarily the same type as the input, see the below example involving SetUnion for such a case.
20///
21/// Like [`join`](#join), `_lattice_join_fused_join` can also be provided with one or two generic lifetime persistence arguments, either
22/// `'tick` or `'static`, to specify how join data persists. With `'tick`, pairs will only be
23/// joined with corresponding pairs within the same tick. With `'static`, pairs will be remembered
24/// across ticks and will be joined with pairs arriving in later ticks. When not explicitly
25/// specified, persistence defaults to `'tick`.
26///
27/// Like [`join`](#join), when two persistence arguments are supplied the first maps to port `0` and the second maps to
28/// port `1`.
29/// When a single persistence argument is supplied, it is applied to both input ports.
30/// When no persistence arguments are applied it defaults to `'tick` for both.
31/// It is important to specify all persistence arguments before any type arguments, otherwise the persistence arguments will be ignored.
32///
33/// The syntax is as follows:
34/// ```dfir,ignore
35/// _lattice_join_fused_join::<MaxRepr<usize>, MaxRepr<usize>>(); // Or
36/// _lattice_join_fused_join::<'static, MaxRepr<usize>, MaxRepr<usize>>();
37///
38/// _lattice_join_fused_join::<'tick, MaxRepr<usize>, MaxRepr<usize>>();
39///
40/// _lattice_join_fused_join::<'static, 'tick, MaxRepr<usize>, MaxRepr<usize>>();
41///
42/// _lattice_join_fused_join::<'tick, 'static, MaxRepr<usize>, MaxRepr<usize>>();
43/// // etc.
44/// ```
45///
46/// ### Examples
47///
48/// ```dfir
49/// use dfir_rs::lattices::Min;
50/// use dfir_rs::lattices::Max;
51///
52/// source_iter([("key", Min::new(1)), ("key", Min::new(2))]) -> [0]my_join;
53/// source_iter([("key", Max::new(1)), ("key", Max::new(2))]) -> [1]my_join;
54///
55/// my_join = _lattice_join_fused_join::<'tick, Min<usize>, Max<usize>>()
56///     -> map(|singleton_map| {
57///         let lattices::collections::SingletonMap(k, v) = singleton_map.into_reveal();
58///         (k, (v.into_reveal()))
59///     })
60///     -> assert_eq([("key", (Min::new(1), Max::new(2)))]);
61/// ```
62///
63/// ```dfir
64/// use dfir_rs::lattices::set_union::SetUnionSingletonSet;
65/// use dfir_rs::lattices::set_union::SetUnionHashSet;
66///
67/// source_iter([("key", SetUnionSingletonSet::new_from(0)), ("key", SetUnionSingletonSet::new_from(1))]) -> [0]my_join;
68/// source_iter([("key", SetUnionHashSet::new_from([0])), ("key", SetUnionHashSet::new_from([1]))]) -> [1]my_join;
69///
70/// my_join = _lattice_join_fused_join::<'tick, SetUnionHashSet<usize>, SetUnionHashSet<usize>>()
71///     -> map(|singleton_map| {
72///         let lattices::collections::SingletonMap(k, v) = singleton_map.into_reveal();
73///         (k, (v.into_reveal()))
74///     })
75///     -> assert_eq([("key", (SetUnionHashSet::new_from([0, 1]), SetUnionHashSet::new_from([0, 1])))]);
76/// ```
77pub const _LATTICE_JOIN_FUSED_JOIN: OperatorConstraints = OperatorConstraints {
78    name: "_lattice_join_fused_join",
79    categories: &[OperatorCategory::CompilerFusionOperator],
80    hard_range_inn: &(2..=2),
81    soft_range_inn: &(2..=2),
82    hard_range_out: RANGE_1,
83    soft_range_out: RANGE_1,
84    num_args: 0,
85    persistence_args: &(0..=2),
86    type_args: &(2..=2),
87    is_external_input: false,
88    has_singleton_output: false,
89    flo_type: None,
90    ports_inn: Some(|| super::PortListSpec::Fixed(parse_quote! { 0, 1 })),
91    ports_out: None,
92    input_delaytype_fn: |_| Some(DelayType::MonotoneAccum),
93    write_fn: |wc @ &WriteContextArgs {
94                   root,
95                   op_span,
96                   ident,
97                   inputs,
98                   is_pull,
99                   op_inst:
100                       OperatorInstance {
101                           generics:
102                               OpInstGenerics {
103                                   type_args,
104                                   persistence_args,
105                                   ..
106                               },
107                           ..
108                       },
109                   ..
110               },
111               diagnostics| {
112        let lhs_type = &type_args[0];
113        let rhs_type = &type_args[1];
114
115        let wc = WriteContextArgs {
116            arguments: &parse_quote! {
117                FoldFrom(<#lhs_type as #root::lattices::LatticeFrom::<_>>::lattice_from, #root::lattices::Merge::merge),
118                FoldFrom(<#rhs_type as #root::lattices::LatticeFrom::<_>>::lattice_from, #root::lattices::Merge::merge)
119            },
120            ..wc.clone()
121        };
122
123        // initialize write_prologue and write_iterator_after via join_fused, but specialize the write_iterator
124        let OperatorWriteOutput {
125            write_prologue,
126            write_iterator: _,
127            write_iterator_after,
128        } = (super::join_fused::JOIN_FUSED.write_fn)(&wc, diagnostics).unwrap();
129
130        assert!(is_pull);
131        let persistences = super::join_fused::parse_persistences(persistence_args);
132
133        let lhs_join_options = super::join_fused::parse_argument(&wc.arguments[0])
134            .map_err(|err| diagnostics.push(err))?;
135        let rhs_join_options = super::join_fused::parse_argument(&wc.arguments[1])
136            .map_err(|err| diagnostics.push(err))?;
137        let (_lhs_prologue, lhs_pre_write_iter, lhs_borrow) =
138            super::join_fused::make_joindata(&wc, persistences[0], &lhs_join_options, "lhs")
139                .map_err(|err| diagnostics.push(err))?;
140
141        let (_rhs_prologue, rhs_pre_write_iter, rhs_borrow) =
142            super::join_fused::make_joindata(&wc, persistences[1], &rhs_join_options, "rhs")
143                .map_err(|err| diagnostics.push(err))?;
144
145        let lhs = &inputs[0];
146        let rhs = &inputs[1];
147
148        let arg0_span = wc.arguments[0].span();
149        let arg1_span = wc.arguments[1].span();
150
151        let lhs_tokens = quote_spanned! {arg0_span=>
152            #lhs_borrow.fold_into(#lhs, #root::lattices::Merge::merge,
153                <#lhs_type as #root::lattices::LatticeFrom::<_>>::lattice_from);
154        };
155
156        let rhs_tokens = quote_spanned! {arg1_span=>
157            #rhs_borrow.fold_into(#rhs, #root::lattices::Merge::merge,
158                <#rhs_type as #root::lattices::LatticeFrom::<_>>::lattice_from);
159        };
160
161        let write_iterator = quote_spanned! {op_span=>
162            #lhs_pre_write_iter
163            #rhs_pre_write_iter
164
165            let #ident = {
166                #lhs_tokens
167                #rhs_tokens
168
169                // TODO: start the iterator with the smallest len() table rather than always picking rhs.
170                #[allow(clippy::clone_on_copy)]
171                #[allow(suspicious_double_ref_op)]
172                #rhs_borrow
173                    .table
174                    .iter()
175                    .filter_map(|(k, v2)| #lhs_borrow.table.get(k).map(|v1| (k.clone(), lattices::Pair::<#lhs_type, #rhs_type>::new_from(v1.clone(), v2.clone()))))
176                    .map(|(key, p)| #root::lattices::map_union::MapUnionSingletonMap::new_from((key, p)))
177            };
178        };
179
180        Ok(OperatorWriteOutput {
181            write_prologue,
182            write_iterator,
183            write_iterator_after,
184        })
185    },
186};