1use std::cell::RefCell;
2use std::collections::HashSet;
3
4use crate::ir::*;
5
6fn persist_pullup_node(
7 node: &mut HydroNode,
8 persist_pulled_tees: &mut HashSet<*const RefCell<HydroNode>>,
9) {
10 *node = match_box::match_box! {
11 match std::mem::replace(node, HydroNode::Placeholder) {
12 HydroNode::Unpersist { inner: mb!(* HydroNode::Persist { inner: mb!(* behind_persist), .. }), .. } => behind_persist,
13
14 HydroNode::Delta { inner: mb!(* HydroNode::Persist { inner: mb!(* behind_persist), .. }), .. } => behind_persist,
15
16 HydroNode::Tee { inner, metadata } => {
18 if persist_pulled_tees.contains(&(inner.0.as_ref() as *const RefCell<HydroNode>)) {
19 HydroNode::Persist {
20 inner: Box::new(HydroNode::Tee {
21 inner: TeeNode(inner.0.clone()),
22 metadata: metadata.clone(),
23 }),
24 metadata: metadata.clone(),
25 }
26 } else if matches!(*inner.0.borrow(), HydroNode::Persist { .. }) {
27 persist_pulled_tees.insert(inner.0.as_ref() as *const RefCell<HydroNode>);
28 if let HydroNode::Persist { inner: behind_persist, .. } =
29 inner.0.replace(HydroNode::Placeholder)
30 {
31 *inner.0.borrow_mut() = *behind_persist;
32 } else {
33 unreachable!()
34 }
35
36 HydroNode::Persist {
37 inner: Box::new(HydroNode::Tee {
38 inner: TeeNode(inner.0.clone()),
39 metadata: metadata.clone(),
40 }),
41 metadata: metadata.clone(),
42 }
43 } else {
44 HydroNode::Tee { inner, metadata }
45 }
46 }
47
48 HydroNode::Map {
49 f,
50 input: mb!(* HydroNode::Persist { inner: behind_persist, .. }),
51 metadata,
52 } => HydroNode::Persist {
53 inner: Box::new(HydroNode::Map {
54 f,
55 input: behind_persist,
56 metadata: metadata.clone(),
57 }),
58 metadata: metadata.clone(),
59 },
60
61 HydroNode::FilterMap {
62 f,
63 input: mb!(* HydroNode::Persist { inner: behind_persist, .. }),
64 metadata,
65 } => HydroNode::Persist {
66 inner: Box::new(HydroNode::FilterMap {
67 f,
68 input: behind_persist,
69 metadata: metadata.clone(),
70 }),
71 metadata: metadata.clone()
72 },
73
74 HydroNode::FlatMap {
75 f,
76 input: mb!(* HydroNode::Persist { inner: behind_persist, .. }),
77 metadata,
78 } => HydroNode::Persist {
79 inner: Box::new(HydroNode::FlatMap {
80 f,
81 input: behind_persist,
82 metadata: metadata.clone(),
83 }),
84 metadata: metadata.clone()
85 },
86
87 HydroNode::Filter {
88 f,
89 input: mb!(* HydroNode::Persist { inner: behind_persist, .. }),
90 metadata,
91 } => HydroNode::Persist {
92 inner: Box::new(HydroNode::Filter {
93 f,
94 input: behind_persist,
95 metadata: metadata.clone(),
96 }),
97 metadata: metadata.clone()
98 },
99
100 HydroNode::Network {
101 from_key,
102 to_location,
103 to_key,
104 serialize_fn,
105 instantiate_fn,
106 deserialize_fn,
107 input: mb!(* HydroNode::Persist { inner: behind_persist, .. }),
108 metadata,
109 } => HydroNode::Persist {
110 inner: Box::new(HydroNode::Network {
111 from_key,
112 to_location,
113 to_key,
114 serialize_fn,
115 instantiate_fn,
116 deserialize_fn,
117 input: behind_persist,
118 metadata: metadata.clone()
119 }),
120 metadata: metadata.clone(),
121 },
122
123 HydroNode::Chain {
124 first: mb!(* HydroNode::Persist { inner: first, metadata: persist_metadata }),
125 second: mb!(* HydroNode::Persist { inner: second, .. }),
126 metadata
127 } => HydroNode::Persist {
128 inner: Box::new(HydroNode::Chain { first, second, metadata }),
129 metadata: persist_metadata
130 },
131
132 HydroNode::CrossProduct {
133 left: mb!(* HydroNode::Persist { inner: left, metadata: left_metadata }),
134 right: mb!(* HydroNode::Persist { inner: right, metadata: right_metadata }),
135 metadata
136 } => HydroNode::Persist {
137 inner: Box::new(HydroNode::Delta {
138 inner: Box::new(HydroNode::CrossProduct {
139 left: Box::new(HydroNode::Persist { inner: left, metadata: left_metadata }),
140 right: Box::new(HydroNode::Persist { inner: right, metadata: right_metadata }),
141 metadata: metadata.clone()
142 }),
143 metadata: metadata.clone(),
144 }),
145 metadata: metadata.clone(),
146 },
147 HydroNode::Join {
148 left: mb!(* HydroNode::Persist { inner: left, metadata: left_metadata }),
149 right: mb!(* HydroNode::Persist { inner: right, metadata: right_metadata }),
150 metadata
151 } => HydroNode::Persist {
152 inner: Box::new(HydroNode::Delta {
153 inner: Box::new(HydroNode::Join {
154 left: Box::new(HydroNode::Persist { inner: left, metadata: left_metadata }),
155 right: Box::new(HydroNode::Persist { inner: right, metadata: right_metadata }),
156 metadata: metadata.clone()
157 }),
158 metadata: metadata.clone(),
159 }),
160 metadata: metadata.clone(),
161 },
162
163 HydroNode::Unique { input: mb!(* HydroNode::Persist {inner, metadata: persist_metadata } ), metadata } => HydroNode::Persist {
164 inner: Box::new(HydroNode::Delta {
165 inner: Box::new(HydroNode::Unique {
166 input: Box::new(HydroNode::Persist { inner, metadata: persist_metadata }),
167 metadata: metadata.clone()
168 }),
169 metadata: metadata.clone(),
170 }),
171 metadata: metadata.clone()
172 },
173
174 node => node,
175 }
176 };
177}
178
179pub fn persist_pullup(ir: &mut [HydroLeaf]) {
180 let mut persist_pulled_tees = Default::default();
181 transform_bottom_up(ir, &mut |_| (), &mut |node| {
182 persist_pullup_node(node, &mut persist_pulled_tees)
183 });
184}
185
#[cfg(test)]
mod tests {
    use stageleft::*;

    use crate::deploy::MultiGraph;
    // NOTE(review): presumably imported so that location methods such as
    // `source_iter` / `tick` resolve (trait methods); confirm.
    use crate::location::Location;

    /// Builds a `source_iter -> map -> for_each` pipeline on one process.
    ///
    /// Snapshots the IR before and after `persist_pullup`, then the surface
    /// syntax of every DFIR graph compiled (without networking) from the
    /// optimized IR.
    #[test]
    fn persist_pullup_through_map() {
        let flow = crate::builder::FlowBuilder::new();
        let process = flow.process::<()>();

        process
            .source_iter(q!(0..10))
            .map(q!(|v| v + 1))
            .for_each(q!(|n| println!("{}", n)));

        let built = flow.finalize();

        // IR before the rewrite.
        insta::assert_debug_snapshot!(built.ir());

        let optimized = built.optimize_with(super::persist_pullup);

        // IR after the rewrite.
        insta::assert_debug_snapshot!(optimized.ir());
        // One snapshot per compiled graph, distinguished by a `surface_graph_{id}` suffix.
        for (id, graph) in optimized.compile_no_network::<MultiGraph>().all_dfir() {
            insta::with_settings!({snapshot_suffix => format!("surface_graph_{id}")}, {
                insta::assert_snapshot!(graph.surface_syntax_string());
            });
        }
    }

    /// Shares one ticked, batched and persisted stream between two identical
    /// `map -> all_ticks -> for_each` consumers via `.clone()`.
    ///
    /// Snapshots the IR before and after `persist_pullup`, and each compiled
    /// DFIR graph. NOTE(review): presumably `.clone()` lowers to `Tee` nodes over
    /// a shared `Persist`, exercising the tee handling in `persist_pullup_node`;
    /// confirm against the snapshots.
    #[test]
    fn persist_pullup_behind_tee() {
        let flow = crate::builder::FlowBuilder::new();
        let process = flow.process::<()>();

        let tick = process.tick();
        // NOTE(review): `tick_batch` is called inside `unsafe` with no stated
        // justification. Presumably this is because the batch boundaries it
        // picks are nondeterministic; add a SAFETY-style comment once confirmed.
        let before_tee = unsafe { process.source_iter(q!(0..10)).tick_batch(&tick).persist() };

        // First consumer of the shared persisted stream.
        before_tee
            .clone()
            .map(q!(|v| v + 1))
            .all_ticks()
            .for_each(q!(|n| println!("{}", n)));

        // Second, identical consumer of the same stream.
        before_tee
            .clone()
            .map(q!(|v| v + 1))
            .all_ticks()
            .for_each(q!(|n| println!("{}", n)));

        let built = flow.finalize();

        // IR before the rewrite.
        insta::assert_debug_snapshot!(built.ir());

        let optimized = built.optimize_with(super::persist_pullup);

        // IR after the rewrite.
        insta::assert_debug_snapshot!(optimized.ir());

        // One snapshot per compiled graph, distinguished by a `surface_graph_{id}` suffix.
        for (id, graph) in optimized.compile_no_network::<MultiGraph>().all_dfir() {
            insta::with_settings!({snapshot_suffix => format!("surface_graph_{id}")}, {
                insta::assert_snapshot!(graph.surface_syntax_string());
            });
        }
    }
}