dfir_lang/graph/ops/
state_by.rs

1use quote::{quote_spanned, ToTokens};
2
3use super::{
4    OpInstGenerics, OperatorCategory, OperatorConstraints, OperatorInstance, OperatorWriteOutput,
5    Persistence, WriteContextArgs, RANGE_1,
6};
7use crate::diagnostic::{Diagnostic, Level};
8
9/// List state operator, but with a closure to map the input to the state lattice and a factory
10/// function to initialize the internal data structure.
11///
12/// The emitted outputs (both the referencable singleton and the optional pass-through stream) are
13/// of the same type as the inputs to the state_by operator and are not required to be a lattice
14/// type. This is useful receiving pass-through context information on the output side.
15///
16/// ```dfir
17/// use std::collections::HashSet;
18///
19///
20/// use lattices::set_union::{CartesianProductBimorphism, SetUnionHashSet, SetUnionSingletonSet};
21///
22/// my_state = source_iter(0..3)
23///     -> state_by::<SetUnionHashSet<usize>>(SetUnionSingletonSet::new_from, std::default::Default::default);
24/// ```
25/// The 2nd argument into `state_by` is a factory function that can be used to supply a custom
26/// initial value for the backing state. The initial value is still expected to be bottom (and will
27/// be checked). This is useful for doing things like pre-allocating buffers, etc. In the above
28/// example, it is just using `Default::default()`
29///
30/// An example of preallocating the capacity in a hashmap:
31///
32///```dfir
33/// use std::collections::HashSet;
34/// use lattices::set_union::{SetUnion, CartesianProductBimorphism, SetUnionHashSet, SetUnionSingletonSet};
35///
36/// my_state = source_iter(0..3)
37///     -> state_by::<SetUnionHashSet<usize>>(SetUnionSingletonSet::new_from, {|| SetUnion::new(HashSet::<usize>::with_capacity(1_000)) });
38///```
39///
40/// The `state` operator is equivalent to `state_by` used with an identity mapping operator with
41/// `Default::default` providing the factory function.
42pub const STATE_BY: OperatorConstraints = OperatorConstraints {
43    name: "state_by",
44    categories: &[OperatorCategory::Persistence],
45    hard_range_inn: RANGE_1,
46    soft_range_inn: RANGE_1,
47    hard_range_out: &(0..=1),
48    soft_range_out: &(0..=1),
49    num_args: 2,
50    persistence_args: &(0..=1),
51    type_args: &(0..=1),
52    is_external_input: false,
53    has_singleton_output: true,
54    flo_type: None,
55    ports_inn: None,
56    ports_out: None,
57    input_delaytype_fn: |_| None,
58    write_fn: |&WriteContextArgs {
59                   root,
60                   context,
61                   df_ident,
62                   op_span,
63                   ident,
64                   inputs,
65                   outputs,
66                   is_pull,
67                   singleton_output_ident,
68                   op_name,
69                   op_inst:
70                       OperatorInstance {
71                           generics:
72                               OpInstGenerics {
73                                   type_args,
74                                   persistence_args,
75                                   ..
76                               },
77                           ..
78                       },
79                   arguments,
80                   ..
81               },
82               diagnostics| {
83        let lattice_type = type_args
84            .first()
85            .map(ToTokens::to_token_stream)
86            .unwrap_or(quote_spanned!(op_span=> _));
87
88        let persistence = match persistence_args[..] {
89            [] => Persistence::Tick,
90            [Persistence::Mutable] => {
91                diagnostics.push(Diagnostic::spanned(
92                    op_span,
93                    Level::Error,
94                    format!("{} does not support `'mut`.", op_name),
95                ));
96                Persistence::Tick
97            }
98            [a] => a,
99            _ => unreachable!(),
100        };
101
102
103        let state_ident = singleton_output_ident;
104        let factory_fn = &arguments[1];
105
106        let mut write_prologue = quote_spanned! { op_span=>
107                    let #state_ident = {
108                        let data_struct : #lattice_type = (#factory_fn)();
109                        ::std::debug_assert!(::lattices::IsBot::is_bot(&data_struct));
110                        #df_ident.add_state(::std::cell::RefCell::new(data_struct))
111                    };
112        };
113        if Persistence::Tick == persistence {
114            write_prologue.extend(quote_spanned! {op_span=>
115                #df_ident.set_state_tick_hook(#state_ident, |rcell| { rcell.take(); }); // Resets state to `Default::default()`.
116            });
117        }
118
119        let by_fn = &arguments[0];
120
121        // TODO(mingwei): deduplicate codegen
122        let write_iterator = if is_pull {
123            let input = &inputs[0];
124            quote_spanned! {op_span=>
125                let #ident = {
126                    fn check_input<'a, Item, MappingFn, MappedItem, Iter, Lat>(
127                        iter: Iter,
128                        mapfn: MappingFn,
129                        state_handle: #root::scheduled::state::StateHandle<::std::cell::RefCell<Lat>>,
130                        context: &'a #root::scheduled::context::Context,
131                    ) -> impl 'a + ::std::iter::Iterator<Item = Item>
132                    where
133                        Item: ::std::clone::Clone,
134                        MappingFn: 'a + Fn(Item) -> MappedItem,
135                        Iter: 'a + ::std::iter::Iterator<Item = Item>,
136                        Lat: 'static + #root::lattices::Merge<MappedItem>,
137                    {
138                        iter.filter(move |item| {
139                                let state = unsafe {
140                                    // SAFETY: handle from `#df_ident.add_state(..)`.
141                                    context.state_ref_unchecked(state_handle)
142                                };
143                                let mut state = state.borrow_mut();
144                                #root::lattices::Merge::merge(&mut *state, (mapfn)(::std::clone::Clone::clone(item)))
145                            })
146                    }
147                    check_input::<_, _, _, _, #lattice_type>(#input, #by_fn, #state_ident, #context)
148                };
149            }
150        } else if let Some(output) = outputs.first() {
151            quote_spanned! {op_span=>
152                let #ident = {
153                    fn check_output<'a, Item, MappingFn, MappedItem, Push, Lat>(
154                        push: Push,
155                        mapfn: MappingFn,
156                        state_handle: #root::scheduled::state::StateHandle<::std::cell::RefCell<Lat>>,
157                        context: &'a #root::scheduled::context::Context,
158                    ) -> impl 'a + #root::pusherator::Pusherator<Item = Item>
159                    where
160                        Item: 'a + ::std::clone::Clone,
161                        MappingFn: 'a + Fn(Item) -> MappedItem,
162                        Push: 'a + #root::pusherator::Pusherator<Item = Item>,
163                        Lat: 'static + #root::lattices::Merge<MappedItem>,
164                    {
165                        #root::pusherator::filter::Filter::new(move |item| {
166                            let state = unsafe {
167                                // SAFETY: handle from `#df_ident.add_state(..)`.
168                                context.state_ref_unchecked(state_handle)
169                            };
170                            let mut state = state.borrow_mut();
171                                #root::lattices::Merge::merge(&mut *state, (mapfn)(::std::clone::Clone::clone(item)))
172                        }, push)
173                    }
174                    check_output::<_, _, _, _, #lattice_type>(#output, #by_fn, #state_ident, #context)
175                };
176            }
177        } else {
178            quote_spanned! {op_span=>
179                let #ident = {
180                    fn check_output<'a, Item, MappingFn, MappedItem, Lat>(
181                        state_handle: #root::scheduled::state::StateHandle<::std::cell::RefCell<Lat>>,
182                        mapfn: MappingFn,
183                        context: &'a #root::scheduled::context::Context,
184                    ) -> impl 'a + #root::pusherator::Pusherator<Item = Item>
185                    where
186                        Item: 'a,
187                        MappedItem: 'a,
188                        MappingFn: 'a + Fn(Item) -> MappedItem,
189                        Lat: 'static + #root::lattices::Merge<MappedItem>,
190                    {
191                        #root::pusherator::for_each::ForEach::new(move |item| {
192                            let state = unsafe {
193                                // SAFETY: handle from `#df_ident.add_state(..)`.
194                                context.state_ref_unchecked(state_handle)
195                            };
196                            let mut state = state.borrow_mut();
197                            #root::lattices::Merge::merge(&mut *state, (mapfn)(item));
198                        })
199                    }
200                    check_output::<_, _, _, #lattice_type>(#state_ident, #by_fn, #context)
201                };
202            }
203        };
204        Ok(OperatorWriteOutput {
205            write_prologue,
206            write_iterator,
207            ..Default::default()
208        })
209    },
210};