// dfir_lang/graph/ops/lattice_bimorphism.rs

use quote::quote_spanned;
use syn::parse_quote;

use super::{
    OperatorCategory, OperatorConstraints, OperatorWriteOutput, WriteContextArgs,
    RANGE_0, RANGE_1,
};

9/// An operator representing a [lattice bimorphism](https://hydro.run/docs/dfir/lattices_crate/lattice_math#lattice-bimorphism).
10///
11/// > 2 input streams, of type `LhsItem` and `RhsItem`.
12///
13/// > Three argument, one `LatticeBimorphism` function `Func`, an `LhsState` singleton reference, and an `RhsState` singleton reference.
14///
15/// > 1 output stream of the output type of the `LatticeBimorphism` function.
16///
17/// The function must be a lattice bimorphism for both `(LhsState, RhsItem)` and `(RhsState, LhsItem)`.
18///
19/// ```dfir
20/// use std::collections::HashSet;
21/// use lattices::set_union::{CartesianProductBimorphism, SetUnionHashSet, SetUnionSingletonSet};
22///
23/// lhs = source_iter(0..3)
24///     -> map(SetUnionSingletonSet::new_from)
25///     -> state::<'static, SetUnionHashSet<u32>>();
26/// rhs = source_iter(3..5)
27///     -> map(SetUnionSingletonSet::new_from)
28///     -> state::<'static, SetUnionHashSet<u32>>();
29///
30/// lhs -> [0]my_join;
31/// rhs -> [1]my_join;
32///
33/// my_join = lattice_bimorphism(CartesianProductBimorphism::<HashSet<_>>::default(), #lhs, #rhs)
34///     -> assert_eq([SetUnionHashSet::new(HashSet::from_iter([
35///        (0, 3),
36///        (0, 4),
37///        (1, 3),
38///        (1, 4),
39///        (2, 3),
40///        (2, 4),
41///    ]))]);
42/// ```
43pub const LATTICE_BIMORPHISM: OperatorConstraints = OperatorConstraints {
44    name: "lattice_bimorphism",
45    categories: &[OperatorCategory::MultiIn],
46    hard_range_inn: &(2..=2),
47    soft_range_inn: &(2..=2),
48    hard_range_out: RANGE_1,
49    soft_range_out: RANGE_1,
50    num_args: 3,
51    persistence_args: RANGE_0,
52    type_args: RANGE_0,
53    is_external_input: false,
54    has_singleton_output: false,
55    flo_type: None,
56    ports_inn: Some(|| super::PortListSpec::Fixed(parse_quote! { 0, 1 })),
57    ports_out: None,
58    input_delaytype_fn: |_| None,
59    write_fn: |&WriteContextArgs {
60                   root,
61                   context,
62                   op_span,
63                   is_pull,
64                   ident,
65                   inputs,
66                   arguments,
67                   arguments_handles,
68                   ..
69               },
70               _| {
71        assert!(is_pull);
72
73        let func = &arguments[0];
74        let lhs_state_handle = &arguments_handles[1];
75        let rhs_state_handle = &arguments_handles[2];
76
77        let lhs_items = &inputs[0];
78        let rhs_items = &inputs[1];
79
80        let write_iterator = quote_spanned! {op_span=>
81            let #ident = {
82                #[inline(always)]
83                fn check_inputs<'a, Func, LhsIter, RhsIter, LhsState, RhsState, Output>(
84                    mut func: Func,
85                    mut lhs_iter: LhsIter,
86                    mut rhs_iter: RhsIter,
87                    lhs_state_handle: #root::scheduled::state::StateHandle<::std::cell::RefCell<LhsState>>,
88                    rhs_state_handle: #root::scheduled::state::StateHandle<::std::cell::RefCell<RhsState>>,
89                    context: &'a #root::scheduled::context::Context,
90                ) -> Option<Output>
91                where
92                    Func: 'a
93                        + #root::lattices::LatticeBimorphism<LhsState, RhsIter::Item, Output = Output>
94                        + #root::lattices::LatticeBimorphism<LhsIter::Item, RhsState, Output = Output>,
95                    LhsIter: 'a + ::std::iter::Iterator,
96                    RhsIter: 'a + ::std::iter::Iterator,
97                    LhsState: 'static + ::std::clone::Clone,
98                    RhsState: 'static + ::std::clone::Clone,
99                    Output: #root::lattices::Merge<Output>,
100                {
101                    let (lhs_state, rhs_state) = unsafe {
102                        // SAFETY: handle from `#df_ident.add_state(..)`.
103                        (
104                            context.state_ref_unchecked(lhs_state_handle),
105                            context.state_ref_unchecked(rhs_state_handle),
106                        )
107                    };
108
109                    let iter = ::std::iter::from_fn(move || {
110                        // Use `from_fn` instead of `chain` to dodge multiple ownership of `func`.
111                        if let Some(lhs_item) = lhs_iter.next() {
112                            Some(func.call(lhs_item, (*rhs_state.borrow()).clone()))
113                        } else {
114                            let rhs_item = rhs_iter.next()?;
115                            Some(func.call((*lhs_state.borrow()).clone(), rhs_item))
116                        }
117                    });
118                    iter.reduce(|a, b| #root::lattices::Merge::merge_owned(a, b))
119                }
120                check_inputs(
121                    #func,
122                    #lhs_items,
123                    #rhs_items,
124                    #lhs_state_handle,
125                    #rhs_state_handle,
126                    &#context,
127                ).into_iter()
128            };
129        };
130
131        Ok(OperatorWriteOutput {
132            write_iterator,
133            ..Default::default()
134        })
135    },
136};