1use quote::{quote_spanned, ToTokens};
2
3use super::{
4 OpInstGenerics, OperatorCategory, OperatorConstraints, OperatorInstance, OperatorWriteOutput,
5 Persistence, WriteContextArgs, RANGE_1,
6};
7use crate::diagnostic::{Diagnostic, Level};
8
9pub const STATE_BY: OperatorConstraints = OperatorConstraints {
43 name: "state_by",
44 categories: &[OperatorCategory::Persistence],
45 hard_range_inn: RANGE_1,
46 soft_range_inn: RANGE_1,
47 hard_range_out: &(0..=1),
48 soft_range_out: &(0..=1),
49 num_args: 2,
50 persistence_args: &(0..=1),
51 type_args: &(0..=1),
52 is_external_input: false,
53 has_singleton_output: true,
54 flo_type: None,
55 ports_inn: None,
56 ports_out: None,
57 input_delaytype_fn: |_| None,
58 write_fn: |&WriteContextArgs {
59 root,
60 context,
61 df_ident,
62 op_span,
63 ident,
64 inputs,
65 outputs,
66 is_pull,
67 singleton_output_ident,
68 op_name,
69 op_inst:
70 OperatorInstance {
71 generics:
72 OpInstGenerics {
73 type_args,
74 persistence_args,
75 ..
76 },
77 ..
78 },
79 arguments,
80 ..
81 },
82 diagnostics| {
83 let lattice_type = type_args
84 .first()
85 .map(ToTokens::to_token_stream)
86 .unwrap_or(quote_spanned!(op_span=> _));
87
88 let persistence = match persistence_args[..] {
89 [] => Persistence::Tick,
90 [Persistence::Mutable] => {
91 diagnostics.push(Diagnostic::spanned(
92 op_span,
93 Level::Error,
94 format!("{} does not support `'mut`.", op_name),
95 ));
96 Persistence::Tick
97 }
98 [a] => a,
99 _ => unreachable!(),
100 };
101
102
103 let state_ident = singleton_output_ident;
104 let factory_fn = &arguments[1];
105
106 let mut write_prologue = quote_spanned! { op_span=>
107 let #state_ident = {
108 let data_struct : #lattice_type = (#factory_fn)();
109 ::std::debug_assert!(::lattices::IsBot::is_bot(&data_struct));
110 #df_ident.add_state(::std::cell::RefCell::new(data_struct))
111 };
112 };
113 if Persistence::Tick == persistence {
114 write_prologue.extend(quote_spanned! {op_span=>
115 #df_ident.set_state_tick_hook(#state_ident, |rcell| { rcell.take(); }); });
117 }
118
119 let by_fn = &arguments[0];
120
121 let write_iterator = if is_pull {
123 let input = &inputs[0];
124 quote_spanned! {op_span=>
125 let #ident = {
126 fn check_input<'a, Item, MappingFn, MappedItem, Iter, Lat>(
127 iter: Iter,
128 mapfn: MappingFn,
129 state_handle: #root::scheduled::state::StateHandle<::std::cell::RefCell<Lat>>,
130 context: &'a #root::scheduled::context::Context,
131 ) -> impl 'a + ::std::iter::Iterator<Item = Item>
132 where
133 Item: ::std::clone::Clone,
134 MappingFn: 'a + Fn(Item) -> MappedItem,
135 Iter: 'a + ::std::iter::Iterator<Item = Item>,
136 Lat: 'static + #root::lattices::Merge<MappedItem>,
137 {
138 iter.filter(move |item| {
139 let state = unsafe {
140 context.state_ref_unchecked(state_handle)
142 };
143 let mut state = state.borrow_mut();
144 #root::lattices::Merge::merge(&mut *state, (mapfn)(::std::clone::Clone::clone(item)))
145 })
146 }
147 check_input::<_, _, _, _, #lattice_type>(#input, #by_fn, #state_ident, #context)
148 };
149 }
150 } else if let Some(output) = outputs.first() {
151 quote_spanned! {op_span=>
152 let #ident = {
153 fn check_output<'a, Item, MappingFn, MappedItem, Push, Lat>(
154 push: Push,
155 mapfn: MappingFn,
156 state_handle: #root::scheduled::state::StateHandle<::std::cell::RefCell<Lat>>,
157 context: &'a #root::scheduled::context::Context,
158 ) -> impl 'a + #root::pusherator::Pusherator<Item = Item>
159 where
160 Item: 'a + ::std::clone::Clone,
161 MappingFn: 'a + Fn(Item) -> MappedItem,
162 Push: 'a + #root::pusherator::Pusherator<Item = Item>,
163 Lat: 'static + #root::lattices::Merge<MappedItem>,
164 {
165 #root::pusherator::filter::Filter::new(move |item| {
166 let state = unsafe {
167 context.state_ref_unchecked(state_handle)
169 };
170 let mut state = state.borrow_mut();
171 #root::lattices::Merge::merge(&mut *state, (mapfn)(::std::clone::Clone::clone(item)))
172 }, push)
173 }
174 check_output::<_, _, _, _, #lattice_type>(#output, #by_fn, #state_ident, #context)
175 };
176 }
177 } else {
178 quote_spanned! {op_span=>
179 let #ident = {
180 fn check_output<'a, Item, MappingFn, MappedItem, Lat>(
181 state_handle: #root::scheduled::state::StateHandle<::std::cell::RefCell<Lat>>,
182 mapfn: MappingFn,
183 context: &'a #root::scheduled::context::Context,
184 ) -> impl 'a + #root::pusherator::Pusherator<Item = Item>
185 where
186 Item: 'a,
187 MappedItem: 'a,
188 MappingFn: 'a + Fn(Item) -> MappedItem,
189 Lat: 'static + #root::lattices::Merge<MappedItem>,
190 {
191 #root::pusherator::for_each::ForEach::new(move |item| {
192 let state = unsafe {
193 context.state_ref_unchecked(state_handle)
195 };
196 let mut state = state.borrow_mut();
197 #root::lattices::Merge::merge(&mut *state, (mapfn)(item));
198 })
199 }
200 check_output::<_, _, _, #lattice_type>(#state_ident, #by_fn, #context)
201 };
202 }
203 };
204 Ok(OperatorWriteOutput {
205 write_prologue,
206 write_iterator,
207 ..Default::default()
208 })
209 },
210};