1use std::cell::RefCell;
2use std::collections::HashSet;
3
4use crate::ir::*;
5
/// Rewrites a single IR node in place so that a `Persist` feeding it is moved
/// ("pulled up") above it, for the node kinds this pass supports.
///
/// Cases handled:
/// - `Unpersist(Persist(x))` and `Delta(Persist(x))` collapse to `x`.
/// - `Tee` whose shared node is a `Persist`: the `Persist` is stripped from the
///   shared node and re-added above this tee, and above every other tee that is
///   later visited with the same `persist_pulled_tees` and reads that shared node.
/// - `ResolveFutures`, `ResolveFuturesOrdered`, `Map`, `FilterMap`, `FlatMap`,
///   `Filter` and `Network` over `Persist(x)` become `Persist(op(x))`.
/// - `Chain(Persist(a), Persist(b))` becomes `Persist(Chain(a, b))`.
/// - `CrossProduct` / `Join` with both inputs persisted, and `Unique` over a
///   persisted input, keep their persisted inputs but are wrapped as
///   `Persist(Delta(op(..)))` so that a `Persist` appears above them.
///
/// Any other node is left unchanged.
///
/// `persist_pulled_tees` holds the addresses of the shared `RefCell`s behind
/// `Tee` nodes whose `Persist` has already been stripped by this pass. Within
/// this function the pointers are only used as identity keys (`contains` /
/// `insert`), never dereferenced.
fn persist_pullup_node(
    node: &mut HydroNode,
    persist_pulled_tees: &mut HashSet<*const RefCell<HydroNode>>,
) {
    // Take ownership of the node (leaving a `Placeholder` behind) so the match
    // arms can move its children out; the rewritten node is written back here.
    *node = match_box::match_box! {
        match std::mem::replace(node, HydroNode::Placeholder) {
            // `Unpersist(Persist(x))` cancels out to `x`.
            HydroNode::Unpersist { inner: mb!(* HydroNode::Persist { inner: mb!(* behind_persist), .. }), .. } => behind_persist,

            // `Delta(Persist(x))` is rewritten to `x`.
            HydroNode::Delta { inner: mb!(* HydroNode::Persist { inner: mb!(* behind_persist), .. }), .. } => behind_persist,

            HydroNode::Tee { inner, metadata } => {
                // The address of the shared `RefCell` identifies which tees
                // read the same underlying node.
                if persist_pulled_tees.contains(&(inner.0.as_ref() as *const RefCell<HydroNode>)) {
                    // A sibling tee already stripped the `Persist` from the
                    // shared node, so re-add it on this consumer's side.
                    HydroNode::Persist {
                        inner: Box::new(HydroNode::Tee {
                            inner: TeeNode(inner.0.clone()),
                            metadata: metadata.clone(),
                        }),
                        metadata: metadata.clone(),
                    }
                } else if matches!(*inner.0.borrow(), HydroNode::Persist { .. }) {
                    // First tee to find a `Persist` at the shared node: record
                    // the node so sibling tees re-add the `Persist` too, then
                    // replace the shared `Persist(x)` with `x`.
                    // The `borrow()` in the condition above is a temporary that
                    // is dropped before this block runs, so the `replace` /
                    // `borrow_mut` below do not conflict with it.
                    persist_pulled_tees.insert(inner.0.as_ref() as *const RefCell<HydroNode>);
                    if let HydroNode::Persist { inner: behind_persist, .. } =
                        inner.0.replace(HydroNode::Placeholder)
                    {
                        *inner.0.borrow_mut() = *behind_persist;
                    } else {
                        // Guaranteed by the `matches!` check above.
                        unreachable!()
                    }

                    HydroNode::Persist {
                        inner: Box::new(HydroNode::Tee {
                            inner: TeeNode(inner.0.clone()),
                            metadata: metadata.clone(),
                        }),
                        metadata: metadata.clone(),
                    }
                } else {
                    // Shared node is not persisted: leave the tee unchanged.
                    HydroNode::Tee { inner, metadata }
                }
            }

            // `op(Persist(x))` -> `Persist(op(x))` for the single-input operators
            // below: the operator is applied to the un-persisted input and its
            // output is persisted instead. Both new nodes reuse this node's
            // metadata.
            HydroNode::ResolveFutures {
                input: mb!(* HydroNode::Persist { inner: behind_persist, .. }),
                metadata,
            } => HydroNode::Persist {
                inner: Box::new(HydroNode::ResolveFutures {
                    input: behind_persist,
                    metadata: metadata.clone(),
                }),
                metadata: metadata.clone(),
            },

            HydroNode::ResolveFuturesOrdered {
                input: mb!(* HydroNode::Persist { inner: behind_persist, .. }),
                metadata,
            } => HydroNode::Persist {
                inner: Box::new(HydroNode::ResolveFuturesOrdered {
                    input: behind_persist,
                    metadata: metadata.clone(),
                }),
                metadata: metadata.clone(),
            },

            HydroNode::Map {
                f,
                input: mb!(* HydroNode::Persist { inner: behind_persist, .. }),
                metadata,
            } => HydroNode::Persist {
                inner: Box::new(HydroNode::Map {
                    f,
                    input: behind_persist,
                    metadata: metadata.clone(),
                }),
                metadata: metadata.clone(),
            },

            HydroNode::FilterMap {
                f,
                input: mb!(* HydroNode::Persist { inner: behind_persist, .. }),
                metadata,
            } => HydroNode::Persist {
                inner: Box::new(HydroNode::FilterMap {
                    f,
                    input: behind_persist,
                    metadata: metadata.clone(),
                }),
                metadata: metadata.clone()
            },

            HydroNode::FlatMap {
                f,
                input: mb!(* HydroNode::Persist { inner: behind_persist, .. }),
                metadata,
            } => HydroNode::Persist {
                inner: Box::new(HydroNode::FlatMap {
                    f,
                    input: behind_persist,
                    metadata: metadata.clone(),
                }),
                metadata: metadata.clone()
            },

            HydroNode::Filter {
                f,
                input: mb!(* HydroNode::Persist { inner: behind_persist, .. }),
                metadata,
            } => HydroNode::Persist {
                inner: Box::new(HydroNode::Filter {
                    f,
                    input: behind_persist,
                    metadata: metadata.clone(),
                }),
                metadata: metadata.clone()
            },

            // Same rewrite for a network hop: every field other than `input`
            // is carried over unchanged.
            HydroNode::Network {
                from_key,
                to_location,
                to_key,
                serialize_fn,
                instantiate_fn,
                deserialize_fn,
                input: mb!(* HydroNode::Persist { inner: behind_persist, .. }),
                metadata,
            } => HydroNode::Persist {
                inner: Box::new(HydroNode::Network {
                    from_key,
                    to_location,
                    to_key,
                    serialize_fn,
                    instantiate_fn,
                    deserialize_fn,
                    input: behind_persist,
                    metadata: metadata.clone()
                }),
                metadata: metadata.clone(),
            },

            // Both inputs persisted: chain the un-persisted inputs and persist
            // the result. The outer `Persist` reuses the first input's `Persist`
            // metadata; the second input's `Persist` metadata is discarded.
            HydroNode::Chain {
                first: mb!(* HydroNode::Persist { inner: first, metadata: persist_metadata }),
                second: mb!(* HydroNode::Persist { inner: second, .. }),
                metadata
            } => HydroNode::Persist {
                inner: Box::new(HydroNode::Chain { first, second, metadata }),
                metadata: persist_metadata
            },

            // For `CrossProduct`, `Join` and `Unique` the inputs keep their
            // `Persist`; the result is wrapped as `Persist(Delta(op))` so that a
            // `Persist` sits above this node and can be pulled further up by its
            // consumers. The inner `Persist` nodes are rebuilt from the matched
            // parts with their original metadata.
            // NOTE(review): this treats `Persist(Delta(op))` as equivalent to
            // `op`, which assumes the operator's output over persisted inputs
            // only grows from tick to tick — confirm.
            HydroNode::CrossProduct {
                left: mb!(* HydroNode::Persist { inner: left, metadata: left_metadata }),
                right: mb!(* HydroNode::Persist { inner: right, metadata: right_metadata }),
                metadata
            } => HydroNode::Persist {
                inner: Box::new(HydroNode::Delta {
                    inner: Box::new(HydroNode::CrossProduct {
                        left: Box::new(HydroNode::Persist { inner: left, metadata: left_metadata }),
                        right: Box::new(HydroNode::Persist { inner: right, metadata: right_metadata }),
                        metadata: metadata.clone()
                    }),
                    metadata: metadata.clone(),
                }),
                metadata: metadata.clone(),
            },
            HydroNode::Join {
                left: mb!(* HydroNode::Persist { inner: left, metadata: left_metadata }),
                right: mb!(* HydroNode::Persist { inner: right, metadata: right_metadata }),
                metadata
            } => HydroNode::Persist {
                inner: Box::new(HydroNode::Delta {
                    inner: Box::new(HydroNode::Join {
                        left: Box::new(HydroNode::Persist { inner: left, metadata: left_metadata }),
                        right: Box::new(HydroNode::Persist { inner: right, metadata: right_metadata }),
                        metadata: metadata.clone()
                    }),
                    metadata: metadata.clone(),
                }),
                metadata: metadata.clone(),
            },

            HydroNode::Unique { input: mb!(* HydroNode::Persist {inner, metadata: persist_metadata } ), metadata } => HydroNode::Persist {
                inner: Box::new(HydroNode::Delta {
                    inner: Box::new(HydroNode::Unique {
                        input: Box::new(HydroNode::Persist { inner, metadata: persist_metadata }),
                        metadata: metadata.clone()
                    }),
                    metadata: metadata.clone(),
                }),
                metadata: metadata.clone()
            },

            // Nothing to pull up: put the node back unchanged.
            node => node,
        }
    };
}
200
201pub fn persist_pullup(ir: &mut [HydroLeaf]) {
202 let mut persist_pulled_tees = Default::default();
203 transform_bottom_up(ir, &mut |_| (), &mut |node| {
204 persist_pullup_node(node, &mut persist_pulled_tees)
205 });
206}
207
// Snapshot tests for the persist pull-up pass. Each test snapshots the IR
// before and after `persist_pullup`. It then snapshots the surface syntax of
// every DFIR graph compiled from the optimized IR.
//
// NOTE: insta derives default snapshot names from the enclosing test function
// name, plus a counter for repeated assertions. Renaming a test or reordering
// its snapshot assertions therefore requires re-reviewing the stored snapshots.
#[cfg(test)]
mod tests {
    // Provides the `q!` quoting macro used below.
    use stageleft::*;

    use crate::deploy::MultiGraph;
    // Presumably needed for trait methods called on `process` — confirm.
    use crate::location::Location;

    // A straight-line `source_iter -> map -> for_each` pipeline on a single
    // process, with no explicit `persist()` in the user code.
    #[test]
    fn persist_pullup_through_map() {
        let flow = crate::builder::FlowBuilder::new();
        let process = flow.process::<()>();

        process
            .source_iter(q!(0..10))
            .map(q!(|v| v + 1))
            .for_each(q!(|n| println!("{}", n)));

        let built = flow.finalize();

        insta::assert_debug_snapshot!(built.ir());

        let optimized = built.optimize_with(super::persist_pullup);

        insta::assert_debug_snapshot!(optimized.ir());
        // One snapshot per compiled DFIR graph, suffixed with its id so
        // snapshots for different graphs do not collide.
        for (id, graph) in optimized.compile_no_network::<MultiGraph>().all_dfir() {
            insta::with_settings!({snapshot_suffix => format!("surface_graph_{id}")}, {
                insta::assert_snapshot!(graph.surface_syntax_string());
            });
        }
    }

    // Two consumers share one persisted tick batch through `clone()`, which
    // presumably lowers to `Tee` nodes over a shared `Persist` in the IR. This
    // exercises the `Tee` handling of the pass: the `Persist` is stripped once
    // from the shared node and re-added above each consumer.
    #[test]
    fn persist_pullup_behind_tee() {
        let flow = crate::builder::FlowBuilder::new();
        let process = flow.process::<()>();

        let tick = process.tick();
        // `tick_batch` is an `unsafe` API here, presumably because batch
        // boundaries are nondeterministic. This test never runs the flow; it
        // only snapshots the IR and compiled graph structure.
        let before_tee = unsafe { process.source_iter(q!(0..10)).tick_batch(&tick).persist() };

        before_tee
            .clone()
            .map(q!(|v| v + 1))
            .all_ticks()
            .for_each(q!(|n| println!("{}", n)));

        before_tee
            .clone()
            .map(q!(|v| v + 1))
            .all_ticks()
            .for_each(q!(|n| println!("{}", n)));

        let built = flow.finalize();

        insta::assert_debug_snapshot!(built.ir());

        let optimized = built.optimize_with(super::persist_pullup);

        insta::assert_debug_snapshot!(optimized.ir());

        // One snapshot per compiled DFIR graph, suffixed with its id.
        for (id, graph) in optimized.compile_no_network::<MultiGraph>().all_dfir() {
            insta::with_settings!({snapshot_suffix => format!("surface_graph_{id}")}, {
                insta::assert_snapshot!(graph.surface_syntax_string());
            });
        }
    }
}