1use quote::{ToTokens, quote_spanned};
2
3use super::{
4 OpInstGenerics, OperatorCategory, OperatorConstraints, OperatorInstance, OperatorWriteOutput,
5 Persistence, RANGE_1, WriteContextArgs,
6};
7use crate::diagnostic::{Diagnostic, Level};
8
9pub const STATE_BY: OperatorConstraints = OperatorConstraints {
43 name: "state_by",
44 categories: &[OperatorCategory::Persistence],
45 hard_range_inn: RANGE_1,
46 soft_range_inn: RANGE_1,
47 hard_range_out: &(0..=1),
48 soft_range_out: &(0..=1),
49 num_args: 2,
50 persistence_args: &(0..=1),
51 type_args: &(0..=1),
52 is_external_input: false,
53 has_singleton_output: true,
54 flo_type: None,
55 ports_inn: None,
56 ports_out: None,
57 input_delaytype_fn: |_| None,
58 write_fn: |wc @ &WriteContextArgs {
59 root,
60 context,
61 df_ident,
62 op_span,
63 ident,
64 inputs,
65 outputs,
66 is_pull,
67 singleton_output_ident,
68 op_name,
69 op_inst:
70 OperatorInstance {
71 generics:
72 OpInstGenerics {
73 type_args,
74 persistence_args,
75 ..
76 },
77 ..
78 },
79 arguments,
80 ..
81 },
82 diagnostics| {
83 let lattice_type = type_args
84 .first()
85 .map(ToTokens::to_token_stream)
86 .unwrap_or(quote_spanned!(op_span=> _));
87
88 let persistence = match persistence_args[..] {
89 [] => Persistence::Tick,
90 [Persistence::Mutable] => {
91 diagnostics.push(Diagnostic::spanned(
92 op_span,
93 Level::Error,
94 format!("{} does not support `'mut`.", op_name),
95 ));
96 Persistence::Tick
97 }
98 [a] => a,
99 _ => unreachable!(),
100 };
101
102 let state_ident = singleton_output_ident;
103 let factory_fn = &arguments[1];
104
105 let write_prologue = quote_spanned! {op_span=>
106 let #state_ident = {
107 let data_struct: #lattice_type = (#factory_fn)();
108 ::std::debug_assert!(::lattices::IsBot::is_bot(&data_struct));
109 #df_ident.add_state(::std::cell::RefCell::new(data_struct))
110 };
111 };
112 let write_prologue_after = wc
113 .persistence_as_state_lifespan(persistence)
114 .map(|lifespan| quote_spanned! {op_span=>
115 #df_ident.set_state_lifespan_hook(#state_ident, #lifespan, |rcell| { rcell.take(); });
116 }).unwrap_or_default();
117
118 let by_fn = &arguments[0];
119
120 let write_iterator = if is_pull {
122 let input = &inputs[0];
123 quote_spanned! {op_span=>
124 let #ident = {
125 fn check_input<'a, Item, MappingFn, MappedItem, Iter, Lat>(
126 iter: Iter,
127 mapfn: MappingFn,
128 state_handle: #root::scheduled::state::StateHandle<::std::cell::RefCell<Lat>>,
129 context: &'a #root::scheduled::context::Context,
130 ) -> impl 'a + ::std::iter::Iterator<Item = Item>
131 where
132 Item: ::std::clone::Clone,
133 MappingFn: 'a + Fn(Item) -> MappedItem,
134 Iter: 'a + ::std::iter::Iterator<Item = Item>,
135 Lat: 'static + #root::lattices::Merge<MappedItem>,
136 {
137 iter.filter(move |item| {
138 let state = unsafe {
139 context.state_ref_unchecked(state_handle)
141 };
142 let mut state = state.borrow_mut();
143 #root::lattices::Merge::merge(&mut *state, (mapfn)(::std::clone::Clone::clone(item)))
144 })
145 }
146 check_input::<_, _, _, _, #lattice_type>(#input, #by_fn, #state_ident, #context)
147 };
148 }
149 } else if let Some(output) = outputs.first() {
150 quote_spanned! {op_span=>
151 let #ident = {
152 fn check_output<'a, Item, MappingFn, MappedItem, Push, Lat>(
153 push: Push,
154 mapfn: MappingFn,
155 state_handle: #root::scheduled::state::StateHandle<::std::cell::RefCell<Lat>>,
156 context: &'a #root::scheduled::context::Context,
157 ) -> impl 'a + #root::pusherator::Pusherator<Item = Item>
158 where
159 Item: 'a + ::std::clone::Clone,
160 MappingFn: 'a + Fn(Item) -> MappedItem,
161 Push: 'a + #root::pusherator::Pusherator<Item = Item>,
162 Lat: 'static + #root::lattices::Merge<MappedItem>,
163 {
164 #root::pusherator::filter::Filter::new(move |item| {
165 let state = unsafe {
166 context.state_ref_unchecked(state_handle)
168 };
169 let mut state = state.borrow_mut();
170 #root::lattices::Merge::merge(&mut *state, (mapfn)(::std::clone::Clone::clone(item)))
171 }, push)
172 }
173 check_output::<_, _, _, _, #lattice_type>(#output, #by_fn, #state_ident, #context)
174 };
175 }
176 } else {
177 quote_spanned! {op_span=>
178 let #ident = {
179 fn check_output<'a, Item, MappingFn, MappedItem, Lat>(
180 state_handle: #root::scheduled::state::StateHandle<::std::cell::RefCell<Lat>>,
181 mapfn: MappingFn,
182 context: &'a #root::scheduled::context::Context,
183 ) -> impl 'a + #root::pusherator::Pusherator<Item = Item>
184 where
185 Item: 'a,
186 MappedItem: 'a,
187 MappingFn: 'a + Fn(Item) -> MappedItem,
188 Lat: 'static + #root::lattices::Merge<MappedItem>,
189 {
190 #root::pusherator::for_each::ForEach::new(move |item| {
191 let state = unsafe {
192 context.state_ref_unchecked(state_handle)
194 };
195 let mut state = state.borrow_mut();
196 #root::lattices::Merge::merge(&mut *state, (mapfn)(item));
197 })
198 }
199 check_output::<_, _, _, #lattice_type>(#state_ident, #by_fn, #context)
200 };
201 }
202 };
203 Ok(OperatorWriteOutput {
204 write_prologue,
205 write_prologue_after,
206 write_iterator,
207 ..Default::default()
208 })
209 },
210};