1use std::cell::RefCell;
2use std::collections::HashSet;
3
4use crate::ir::*;
5
6fn persist_pullup_node(
7 node: &mut HydroNode,
8 persist_pulled_tees: &mut HashSet<*const RefCell<HydroNode>>,
9) {
10 *node = match_box::match_box! {
11 match std::mem::replace(node, HydroNode::Placeholder) {
12 HydroNode::Unpersist { inner: mb!(* HydroNode::Persist { inner: mb!(* behind_persist), .. }), .. } => behind_persist,
13
14 HydroNode::Delta { inner: mb!(* HydroNode::Persist { inner: mb!(* behind_persist), .. }), .. } => behind_persist,
15
16 HydroNode::Tee { inner, metadata } => {
18 if persist_pulled_tees.contains(&(inner.0.as_ref() as *const RefCell<HydroNode>)) {
19 HydroNode::Persist {
20 inner: Box::new(HydroNode::Tee {
21 inner: TeeNode(inner.0.clone()),
22 metadata: metadata.clone(),
23 }),
24 metadata: metadata.clone(),
25 }
26 } else if matches!(*inner.0.borrow(), HydroNode::Persist { .. }) {
27 persist_pulled_tees.insert(inner.0.as_ref() as *const RefCell<HydroNode>);
28 if let HydroNode::Persist { inner: behind_persist, .. } =
29 inner.0.replace(HydroNode::Placeholder)
30 {
31 *inner.0.borrow_mut() = *behind_persist;
32 } else {
33 unreachable!()
34 }
35
36 HydroNode::Persist {
37 inner: Box::new(HydroNode::Tee {
38 inner: TeeNode(inner.0.clone()),
39 metadata: metadata.clone(),
40 }),
41 metadata: metadata.clone(),
42 }
43 } else {
44 HydroNode::Tee { inner, metadata }
45 }
46 }
47
48 HydroNode::ResolveFutures {
49 input: mb!(* HydroNode::Persist { inner: behind_persist, .. }),
50 metadata,
51 } => HydroNode::Persist {
52 inner: Box::new(HydroNode::ResolveFutures {
53 input: behind_persist,
54 metadata: metadata.clone(),
55 }),
56 metadata: metadata.clone(),
57 },
58
59 HydroNode::ResolveFuturesOrdered {
60 input: mb!(* HydroNode::Persist { inner: behind_persist, .. }),
61 metadata,
62 } => HydroNode::Persist {
63 inner: Box::new(HydroNode::ResolveFuturesOrdered {
64 input: behind_persist,
65 metadata: metadata.clone(),
66 }),
67 metadata: metadata.clone(),
68 },
69
70 HydroNode::Map {
71 f,
72 input: mb!(* HydroNode::Persist { inner: behind_persist, .. }),
73 metadata,
74 } => HydroNode::Persist {
75 inner: Box::new(HydroNode::Map {
76 f,
77 input: behind_persist,
78 metadata: metadata.clone(),
79 }),
80 metadata: metadata.clone(),
81 },
82
83 HydroNode::FilterMap {
84 f,
85 input: mb!(* HydroNode::Persist { inner: behind_persist, .. }),
86 metadata,
87 } => HydroNode::Persist {
88 inner: Box::new(HydroNode::FilterMap {
89 f,
90 input: behind_persist,
91 metadata: metadata.clone(),
92 }),
93 metadata: metadata.clone()
94 },
95
96 HydroNode::FlatMap {
97 f,
98 input: mb!(* HydroNode::Persist { inner: behind_persist, .. }),
99 metadata,
100 } => HydroNode::Persist {
101 inner: Box::new(HydroNode::FlatMap {
102 f,
103 input: behind_persist,
104 metadata: metadata.clone(),
105 }),
106 metadata: metadata.clone()
107 },
108
109 HydroNode::Filter {
110 f,
111 input: mb!(* HydroNode::Persist { inner: behind_persist, .. }),
112 metadata,
113 } => HydroNode::Persist {
114 inner: Box::new(HydroNode::Filter {
115 f,
116 input: behind_persist,
117 metadata: metadata.clone(),
118 }),
119 metadata: metadata.clone()
120 },
121
122 HydroNode::Network {
123 serialize_fn,
124 instantiate_fn,
125 deserialize_fn,
126 input: mb!(* HydroNode::Persist { inner: behind_persist, .. }),
127 metadata,
128 } => HydroNode::Persist {
129 inner: Box::new(HydroNode::Network {
130 serialize_fn,
131 instantiate_fn,
132 deserialize_fn,
133 input: behind_persist,
134 metadata: metadata.clone()
135 }),
136 metadata: metadata.clone(),
137 },
138
139 HydroNode::Chain {
140 first: mb!(* HydroNode::Persist { inner: first, metadata: persist_metadata }),
141 second: mb!(* HydroNode::Persist { inner: second, .. }),
142 metadata
143 } => HydroNode::Persist {
144 inner: Box::new(HydroNode::Chain { first, second, metadata }),
145 metadata: persist_metadata
146 },
147
148 HydroNode::CrossProduct {
149 left: mb!(* HydroNode::Persist { inner: left, metadata: left_metadata }),
150 right: mb!(* HydroNode::Persist { inner: right, metadata: right_metadata }),
151 metadata
152 } => HydroNode::Persist {
153 inner: Box::new(HydroNode::Delta {
154 inner: Box::new(HydroNode::CrossProduct {
155 left: Box::new(HydroNode::Persist { inner: left, metadata: left_metadata }),
156 right: Box::new(HydroNode::Persist { inner: right, metadata: right_metadata }),
157 metadata: metadata.clone()
158 }),
159 metadata: metadata.clone(),
160 }),
161 metadata: metadata.clone(),
162 },
163 HydroNode::Join {
164 left: mb!(* HydroNode::Persist { inner: left, metadata: left_metadata }),
165 right: mb!(* HydroNode::Persist { inner: right, metadata: right_metadata }),
166 metadata
167 } => HydroNode::Persist {
168 inner: Box::new(HydroNode::Delta {
169 inner: Box::new(HydroNode::Join {
170 left: Box::new(HydroNode::Persist { inner: left, metadata: left_metadata }),
171 right: Box::new(HydroNode::Persist { inner: right, metadata: right_metadata }),
172 metadata: metadata.clone()
173 }),
174 metadata: metadata.clone(),
175 }),
176 metadata: metadata.clone(),
177 },
178
179 HydroNode::Unique { input: mb!(* HydroNode::Persist {inner, metadata: persist_metadata } ), metadata } => HydroNode::Persist {
180 inner: Box::new(HydroNode::Delta {
181 inner: Box::new(HydroNode::Unique {
182 input: Box::new(HydroNode::Persist { inner, metadata: persist_metadata }),
183 metadata: metadata.clone()
184 }),
185 metadata: metadata.clone(),
186 }),
187 metadata: metadata.clone()
188 },
189
190 node => node,
191 }
192 };
193}
194
195pub fn persist_pullup(ir: &mut [HydroRoot]) {
196 let mut persist_pulled_tees = Default::default();
197 transform_bottom_up(
198 ir,
199 &mut |_| (),
200 &mut |node| persist_pullup_node(node, &mut persist_pulled_tees),
201 false,
202 );
203}
204
205#[cfg(stageleft_runtime)]
206#[cfg(test)]
207mod tests {
208 use stageleft::*;
209
210 use crate::deploy::HydroDeploy;
211 use crate::location::Location;
212 use crate::nondet;
213
214 #[test]
215 fn persist_pullup_through_map() {
216 let flow = crate::builder::FlowBuilder::new();
217 let process = flow.process::<()>();
218
219 process
220 .source_iter(q!(0..10))
221 .map(q!(|v| v + 1))
222 .for_each(q!(|n| println!("{}", n)));
223
224 let built = flow.finalize();
225
226 hydro_build_utils::assert_debug_snapshot!(built.ir());
227
228 let optimized = built.optimize_with(super::persist_pullup);
229
230 hydro_build_utils::assert_debug_snapshot!(optimized.ir());
231 for (id, graph) in optimized
232 .into_deploy::<HydroDeploy>()
233 .preview_compile()
234 .all_dfir()
235 {
236 hydro_build_utils::insta::with_settings!({
237 snapshot_suffix => format!("surface_graph_{id}"),
238 }, {
239 hydro_build_utils::assert_snapshot!(graph.surface_syntax_string());
240 });
241 }
242 }
243
244 #[test]
245 fn persist_pullup_behind_tee() {
246 let flow = crate::builder::FlowBuilder::new();
247 let process = flow.process::<()>();
248
249 let tick = process.tick();
250 let before_tee = process
251 .source_iter(q!(0..10))
252 .batch(&tick, nondet!())
253 .persist();
254
255 before_tee
256 .clone()
257 .map(q!(|v| v + 1))
258 .all_ticks()
259 .for_each(q!(|n| println!("{}", n)));
260
261 before_tee
262 .clone()
263 .map(q!(|v| v + 1))
264 .all_ticks()
265 .for_each(q!(|n| println!("{}", n)));
266
267 let built = flow.finalize();
268
269 hydro_build_utils::assert_debug_snapshot!(built.ir());
270
271 let optimized = built.optimize_with(super::persist_pullup);
272
273 hydro_build_utils::assert_debug_snapshot!(optimized.ir());
274
275 for (id, graph) in optimized
276 .into_deploy::<HydroDeploy>()
277 .preview_compile()
278 .all_dfir()
279 {
280 hydro_build_utils::insta::with_settings!({
281 snapshot_suffix => format!("surface_graph_{id}"),
282 }, {
283 hydro_build_utils::assert_snapshot!(graph.surface_syntax_string());
284 });
285 }
286 }
287}