dfir_lang/graph/ops/
state_by.rs

1use quote::{ToTokens, quote_spanned};
2
3use super::{
4    OpInstGenerics, OperatorCategory, OperatorConstraints, OperatorInstance, OperatorWriteOutput,
5    Persistence, RANGE_1, WriteContextArgs,
6};
7use crate::diagnostic::{Diagnostic, Level};
8
9/// List state operator, but with a closure to map the input to the state lattice and a factory
10/// function to initialize the internal data structure.
11///
12/// The emitted outputs (both the referencable singleton and the optional pass-through stream) are
13/// of the same type as the inputs to the state_by operator and are not required to be a lattice
14/// type. This is useful receiving pass-through context information on the output side.
15///
16/// ```dfir
17/// use std::collections::HashSet;
18///
19///
20/// use lattices::set_union::{CartesianProductBimorphism, SetUnionHashSet, SetUnionSingletonSet};
21///
22/// my_state = source_iter(0..3)
23///     -> state_by::<SetUnionHashSet<usize>>(SetUnionSingletonSet::new_from, std::default::Default::default);
24/// ```
25/// The 2nd argument into `state_by` is a factory function that can be used to supply a custom
26/// initial value for the backing state. The initial value is still expected to be bottom (and will
27/// be checked). This is useful for doing things like pre-allocating buffers, etc. In the above
28/// example, it is just using `Default::default()`
29///
30/// An example of preallocating the capacity in a hashmap:
31///
32/// ```dfir
33/// use std::collections::HashSet;
34/// use lattices::set_union::{SetUnion, CartesianProductBimorphism, SetUnionHashSet, SetUnionSingletonSet};
35///
36/// my_state = source_iter(0..3)
37///     -> state_by::<SetUnionHashSet<usize>>(SetUnionSingletonSet::new_from, {|| SetUnion::new(HashSet::<usize>::with_capacity(1_000)) });
38/// ```
39///
40/// The `state` operator is equivalent to `state_by` used with an identity mapping operator with
41/// `Default::default` providing the factory function.
42pub const STATE_BY: OperatorConstraints = OperatorConstraints {
43    name: "state_by",
44    categories: &[OperatorCategory::Persistence],
45    hard_range_inn: RANGE_1,
46    soft_range_inn: RANGE_1,
47    hard_range_out: &(0..=1),
48    soft_range_out: &(0..=1),
49    num_args: 2,
50    persistence_args: &(0..=1),
51    type_args: &(0..=1),
52    is_external_input: false,
53    has_singleton_output: true,
54    flo_type: None,
55    ports_inn: None,
56    ports_out: None,
57    input_delaytype_fn: |_| None,
58    write_fn: |wc @ &WriteContextArgs {
59                   root,
60                   context,
61                   df_ident,
62                   op_span,
63                   ident,
64                   inputs,
65                   outputs,
66                   is_pull,
67                   singleton_output_ident,
68                   op_name,
69                   op_inst:
70                       OperatorInstance {
71                           generics:
72                               OpInstGenerics {
73                                   type_args,
74                                   persistence_args,
75                                   ..
76                               },
77                           ..
78                       },
79                   arguments,
80                   ..
81               },
82               diagnostics| {
83        let lattice_type = type_args
84            .first()
85            .map(ToTokens::to_token_stream)
86            .unwrap_or(quote_spanned!(op_span=> _));
87
88        let persistence = match persistence_args[..] {
89            [] => Persistence::Tick,
90            [Persistence::Mutable] => {
91                diagnostics.push(Diagnostic::spanned(
92                    op_span,
93                    Level::Error,
94                    format!("{} does not support `'mut`.", op_name),
95                ));
96                Persistence::Tick
97            }
98            [a] => a,
99            _ => unreachable!(),
100        };
101
102        let state_ident = singleton_output_ident;
103        let factory_fn = &arguments[1];
104
105        let write_prologue = quote_spanned! {op_span=>
106            let #state_ident = {
107                let data_struct: #lattice_type = (#factory_fn)();
108                ::std::debug_assert!(::lattices::IsBot::is_bot(&data_struct));
109                #df_ident.add_state(::std::cell::RefCell::new(data_struct))
110            };
111        };
112        let write_prologue_after = wc
113            .persistence_as_state_lifespan(persistence)
114            .map(|lifespan| quote_spanned! {op_span=>
115                #df_ident.set_state_lifespan_hook(#state_ident, #lifespan, |rcell| { rcell.take(); });
116            }).unwrap_or_default();
117
118        let by_fn = &arguments[0];
119
120        // TODO(mingwei): deduplicate codegen
121        let write_iterator = if is_pull {
122            let input = &inputs[0];
123            quote_spanned! {op_span=>
124                let #ident = {
125                    fn check_input<'a, Item, MappingFn, MappedItem, Iter, Lat>(
126                        iter: Iter,
127                        mapfn: MappingFn,
128                        state_handle: #root::scheduled::state::StateHandle<::std::cell::RefCell<Lat>>,
129                        context: &'a #root::scheduled::context::Context,
130                    ) -> impl 'a + ::std::iter::Iterator<Item = Item>
131                    where
132                        Item: ::std::clone::Clone,
133                        MappingFn: 'a + Fn(Item) -> MappedItem,
134                        Iter: 'a + ::std::iter::Iterator<Item = Item>,
135                        Lat: 'static + #root::lattices::Merge<MappedItem>,
136                    {
137                        iter.filter(move |item| {
138                                let state = unsafe {
139                                    // SAFETY: handle from `#df_ident.add_state(..)`.
140                                    context.state_ref_unchecked(state_handle)
141                                };
142                                let mut state = state.borrow_mut();
143                                #root::lattices::Merge::merge(&mut *state, (mapfn)(::std::clone::Clone::clone(item)))
144                            })
145                    }
146                    check_input::<_, _, _, _, #lattice_type>(#input, #by_fn, #state_ident, #context)
147                };
148            }
149        } else if let Some(output) = outputs.first() {
150            quote_spanned! {op_span=>
151                let #ident = {
152                    fn check_output<'a, Item, MappingFn, MappedItem, Push, Lat>(
153                        push: Push,
154                        mapfn: MappingFn,
155                        state_handle: #root::scheduled::state::StateHandle<::std::cell::RefCell<Lat>>,
156                        context: &'a #root::scheduled::context::Context,
157                    ) -> impl 'a + #root::pusherator::Pusherator<Item = Item>
158                    where
159                        Item: 'a + ::std::clone::Clone,
160                        MappingFn: 'a + Fn(Item) -> MappedItem,
161                        Push: 'a + #root::pusherator::Pusherator<Item = Item>,
162                        Lat: 'static + #root::lattices::Merge<MappedItem>,
163                    {
164                        #root::pusherator::filter::Filter::new(move |item| {
165                            let state = unsafe {
166                                // SAFETY: handle from `#df_ident.add_state(..)`.
167                                context.state_ref_unchecked(state_handle)
168                            };
169                            let mut state = state.borrow_mut();
170                                #root::lattices::Merge::merge(&mut *state, (mapfn)(::std::clone::Clone::clone(item)))
171                        }, push)
172                    }
173                    check_output::<_, _, _, _, #lattice_type>(#output, #by_fn, #state_ident, #context)
174                };
175            }
176        } else {
177            quote_spanned! {op_span=>
178                let #ident = {
179                    fn check_output<'a, Item, MappingFn, MappedItem, Lat>(
180                        state_handle: #root::scheduled::state::StateHandle<::std::cell::RefCell<Lat>>,
181                        mapfn: MappingFn,
182                        context: &'a #root::scheduled::context::Context,
183                    ) -> impl 'a + #root::pusherator::Pusherator<Item = Item>
184                    where
185                        Item: 'a,
186                        MappedItem: 'a,
187                        MappingFn: 'a + Fn(Item) -> MappedItem,
188                        Lat: 'static + #root::lattices::Merge<MappedItem>,
189                    {
190                        #root::pusherator::for_each::ForEach::new(move |item| {
191                            let state = unsafe {
192                                // SAFETY: handle from `#df_ident.add_state(..)`.
193                                context.state_ref_unchecked(state_handle)
194                            };
195                            let mut state = state.borrow_mut();
196                            #root::lattices::Merge::merge(&mut *state, (mapfn)(item));
197                        })
198                    }
199                    check_output::<_, _, _, #lattice_type>(#state_ident, #by_fn, #context)
200                };
201            }
202        };
203        Ok(OperatorWriteOutput {
204            write_prologue,
205            write_prologue_after,
206            write_iterator,
207            ..Default::default()
208        })
209    },
210};