dfir_lang/graph/ops/
fold.rs
1use quote::quote_spanned;
2
3use super::{
4 DelayType, OpInstGenerics, OperatorCategory, OperatorConstraints, OperatorInstance,
5 OperatorWriteOutput, Persistence, RANGE_0, RANGE_1, WriteContextArgs,
6};
7use crate::diagnostic::{Diagnostic, Level};
8
9pub const FOLD: OperatorConstraints = OperatorConstraints {
37 name: "fold",
38 categories: &[OperatorCategory::Fold],
39 hard_range_inn: RANGE_1,
40 soft_range_inn: RANGE_1,
41 hard_range_out: &(0..=1),
42 soft_range_out: &(0..=1),
43 num_args: 2,
44 persistence_args: &(0..=1),
45 type_args: RANGE_0,
46 is_external_input: false,
47 has_singleton_output: true,
48 flo_type: None,
49 ports_inn: None,
50 ports_out: None,
51 input_delaytype_fn: |_| Some(DelayType::Stratum),
52 write_fn: |wc @ &WriteContextArgs {
53 root,
54 context,
55 df_ident,
56 loop_id,
57 op_span,
58 ident,
59 is_pull,
60 inputs,
61 singleton_output_ident,
62 work_fn,
63 op_inst:
64 OperatorInstance {
65 generics:
66 OpInstGenerics {
67 persistence_args, ..
68 },
69 ..
70 },
71 arguments,
72 ..
73 },
74 diagnostics| {
75
76 let persistence = persistence_args.first().copied().unwrap_or_else(|| {
77 if loop_id.is_some() {
78 Persistence::None
79 } else {
80 Persistence::Tick
81 }
82 });
83 if Persistence::Mutable == persistence {
84 diagnostics.push(Diagnostic::spanned(
85 op_span,
86 Level::Error,
87 "An implementation of 'mutable does not exist",
88 ));
89 return Err(());
90 }
91
92 let input = &inputs[0];
93 let init = &arguments[0];
94 let func = &arguments[1];
95 let initializer_func_ident = wc.make_ident("initializer_func");
96 let accumulator_ident = wc.make_ident("accumulator");
97 let iterator_item_ident = wc.make_ident("iterator_item");
98
99 let iterator_foreach = quote_spanned! {op_span=>
100 #[inline(always)]
101 fn call_comb_type<Accum, Item>(
102 accum: &mut Accum,
103 item: Item,
104 func: impl Fn(&mut Accum, Item),
105 ) {
106 (func)(accum, item);
107 }
108 #[allow(clippy::redundant_closure_call)]
109 call_comb_type(&mut *#accumulator_ident, #iterator_item_ident, #func);
110 };
111
112 let mut write_prologue = quote_spanned! {op_span=>
113 #[allow(unused_mut)]
114 let mut #initializer_func_ident = #init;
115
116 #[allow(clippy::redundant_closure_call)]
117 let #singleton_output_ident = #df_ident.add_state(
118 ::std::cell::RefCell::new((#initializer_func_ident)())
119 );
120 };
121 if Persistence::Tick == persistence {
122 write_prologue.extend(quote_spanned! {op_span=>
123 #df_ident.set_state_tick_hook(#singleton_output_ident, move |rcell| { rcell.replace((#initializer_func_ident)()); });
125 });
126 }
127 let write_iterator = if is_pull {
128 quote_spanned! {op_span=>
129 let #ident = {
130 let mut #accumulator_ident = unsafe {
131 #context.state_ref_unchecked(#singleton_output_ident)
133 }.borrow_mut();
134
135 #work_fn(|| #input.for_each(|#iterator_item_ident| {
136 #iterator_foreach
137 }));
138
139 #[allow(clippy::clone_on_copy)]
140 {
141 ::std::iter::once(#work_fn(|| ::std::clone::Clone::clone(&*#accumulator_ident)))
142 }
143 };
144 }
145 } else {
146 quote_spanned! {op_span=>
147 let #ident = {
148 #root::pusherator::for_each::ForEach::new(|#iterator_item_ident| {
149 let mut #accumulator_ident = unsafe {
150 #context.state_ref_unchecked(#singleton_output_ident)
152 }.borrow_mut();
153 #iterator_foreach
154 })
155 };
156 }
157 };
158 let write_iterator_after = quote_spanned! {op_span=>
159 #context.schedule_subgraph(#context.current_subgraph(), false);
160 };
161
162 Ok(OperatorWriteOutput {
163 write_prologue,
164 write_iterator,
165 write_iterator_after,
166 })
167 },
168};