hydro_lang/rewrites/persist_pullup.rs

1use std::cell::RefCell;
2use std::collections::HashSet;
3
4use crate::ir::*;
5
6fn persist_pullup_node(
7    node: &mut HydroNode,
8    persist_pulled_tees: &mut HashSet<*const RefCell<HydroNode>>,
9) {
10    *node = match_box::match_box! {
11        match std::mem::replace(node, HydroNode::Placeholder) {
12            HydroNode::Unpersist { inner: mb!(* HydroNode::Persist { inner: mb!(* behind_persist), .. }), .. } => behind_persist,
13
14            HydroNode::Delta { inner: mb!(* HydroNode::Persist { inner: mb!(* behind_persist), .. }), .. } => behind_persist,
15
16            // TODO: Figure out if persist needs to copy its metadata or can just use original metadata here. If it can just use original, figure out where that is
17            HydroNode::Tee { inner, metadata } => {
18                if persist_pulled_tees.contains(&(inner.0.as_ref() as *const RefCell<HydroNode>)) {
19                    HydroNode::Persist {
20                        inner: Box::new(HydroNode::Tee {
21                            inner: TeeNode(inner.0.clone()),
22                            metadata: metadata.clone(),
23                        }),
24                        metadata: metadata.clone(),
25                    }
26                } else if matches!(*inner.0.borrow(), HydroNode::Persist { .. }) {
27                    persist_pulled_tees.insert(inner.0.as_ref() as *const RefCell<HydroNode>);
28                    if let HydroNode::Persist { inner: behind_persist, .. } =
29                        inner.0.replace(HydroNode::Placeholder)
30                    {
31                        *inner.0.borrow_mut() = *behind_persist;
32                    } else {
33                        unreachable!()
34                    }
35
36                    HydroNode::Persist {
37                        inner: Box::new(HydroNode::Tee {
38                            inner: TeeNode(inner.0.clone()),
39                            metadata: metadata.clone(),
40                        }),
41                        metadata: metadata.clone(),
42                    }
43                } else {
44                    HydroNode::Tee { inner, metadata }
45                }
46            }
47
48            HydroNode::ResolveFutures {
49                input: mb!(* HydroNode::Persist { inner: behind_persist, .. }),
50                metadata,
51            } => HydroNode::Persist {
52                inner: Box::new(HydroNode::ResolveFutures {
53                    input: behind_persist,
54                    metadata: metadata.clone(),
55                }),
56                metadata: metadata.clone(),
57            },
58
59            HydroNode::ResolveFuturesOrdered {
60                input: mb!(* HydroNode::Persist { inner: behind_persist, .. }),
61                metadata,
62            } => HydroNode::Persist {
63                inner: Box::new(HydroNode::ResolveFuturesOrdered {
64                    input: behind_persist,
65                    metadata: metadata.clone(),
66                }),
67                metadata: metadata.clone(),
68            },
69
70            HydroNode::Map {
71                f,
72                input: mb!(* HydroNode::Persist { inner: behind_persist, .. }),
73                metadata,
74            } => HydroNode::Persist {
75                inner: Box::new(HydroNode::Map {
76                    f,
77                    input: behind_persist,
78                    metadata: metadata.clone(),
79                }),
80                metadata: metadata.clone(),
81            },
82
83            HydroNode::FilterMap {
84                f,
85                input: mb!(* HydroNode::Persist { inner: behind_persist, .. }),
86                metadata,
87            } => HydroNode::Persist {
88                inner: Box::new(HydroNode::FilterMap {
89                    f,
90                    input: behind_persist,
91                    metadata: metadata.clone(),
92                }),
93                metadata: metadata.clone()
94            },
95
96            HydroNode::FlatMap {
97                f,
98                input: mb!(* HydroNode::Persist { inner: behind_persist, .. }),
99                metadata,
100            } => HydroNode::Persist {
101                inner: Box::new(HydroNode::FlatMap {
102                    f,
103                    input: behind_persist,
104                    metadata: metadata.clone(),
105                }),
106                metadata: metadata.clone()
107            },
108
109            HydroNode::Filter {
110                f,
111                input: mb!(* HydroNode::Persist { inner: behind_persist, .. }),
112                metadata,
113            } => HydroNode::Persist {
114                inner: Box::new(HydroNode::Filter {
115                    f,
116                    input: behind_persist,
117                    metadata: metadata.clone(),
118                }),
119                metadata: metadata.clone()
120            },
121
122            HydroNode::Network {
123                serialize_fn,
124                instantiate_fn,
125                deserialize_fn,
126                input: mb!(* HydroNode::Persist { inner: behind_persist, .. }),
127                metadata,
128            } => HydroNode::Persist {
129                inner: Box::new(HydroNode::Network {
130                    serialize_fn,
131                    instantiate_fn,
132                    deserialize_fn,
133                    input: behind_persist,
134                    metadata: metadata.clone()
135                }),
136                metadata: metadata.clone(),
137            },
138
139            HydroNode::Chain {
140                first: mb!(* HydroNode::Persist { inner: first, metadata: persist_metadata }),
141                second: mb!(* HydroNode::Persist { inner: second, .. }),
142                metadata
143            } => HydroNode::Persist {
144                inner: Box::new(HydroNode::Chain { first, second, metadata }),
145                metadata: persist_metadata
146            },
147
148            HydroNode::CrossProduct {
149                left: mb!(* HydroNode::Persist { inner: left, metadata: left_metadata }),
150                right: mb!(* HydroNode::Persist { inner: right, metadata: right_metadata }),
151                metadata
152            } => HydroNode::Persist {
153                inner: Box::new(HydroNode::Delta {
154                    inner: Box::new(HydroNode::CrossProduct {
155                        left: Box::new(HydroNode::Persist { inner: left, metadata: left_metadata }),
156                        right: Box::new(HydroNode::Persist { inner: right, metadata: right_metadata }),
157                        metadata: metadata.clone()
158                    }),
159                    metadata: metadata.clone(),
160                }),
161                metadata: metadata.clone(),
162            },
163            HydroNode::Join {
164                left: mb!(* HydroNode::Persist { inner: left, metadata: left_metadata }),
165                right: mb!(* HydroNode::Persist { inner: right, metadata: right_metadata }),
166                metadata
167             } => HydroNode::Persist {
168                inner: Box::new(HydroNode::Delta {
169                    inner: Box::new(HydroNode::Join {
170                        left: Box::new(HydroNode::Persist { inner: left, metadata: left_metadata }),
171                        right: Box::new(HydroNode::Persist { inner: right, metadata: right_metadata }),
172                        metadata: metadata.clone()
173                    }),
174                    metadata: metadata.clone(),
175                }),
176                metadata: metadata.clone(),
177            },
178
179            HydroNode::Unique { input: mb!(* HydroNode::Persist {inner, metadata: persist_metadata } ), metadata } => HydroNode::Persist {
180                inner: Box::new(HydroNode::Delta {
181                    inner: Box::new(HydroNode::Unique {
182                        input: Box::new(HydroNode::Persist { inner, metadata: persist_metadata }),
183                        metadata: metadata.clone()
184                    }),
185                    metadata: metadata.clone(),
186                }),
187                metadata: metadata.clone()
188            },
189
190            node => node,
191        }
192    };
193}
194
195pub fn persist_pullup(ir: &mut [HydroRoot]) {
196    let mut persist_pulled_tees = Default::default();
197    transform_bottom_up(
198        ir,
199        &mut |_| (),
200        &mut |node| persist_pullup_node(node, &mut persist_pulled_tees),
201        false,
202    );
203}
204
#[cfg(stageleft_runtime)]
#[cfg(test)]
mod tests {
    use stageleft::*;

    use crate::builder::FlowBuilder;
    use crate::deploy::HydroDeploy;
    use crate::location::Location;
    use crate::nondet;

    /// Persist pull-up through a single `map` on a process stream.
    ///
    /// Snapshots the IR before and after the rewrite, then the DFIR surface
    /// syntax generated for every location.
    #[test]
    fn persist_pullup_through_map() {
        let flow = FlowBuilder::new();
        let process = flow.process::<()>();

        let incremented = process.source_iter(q!(0..10)).map(q!(|v| v + 1));
        incremented.for_each(q!(|n| println!("{}", n)));

        let finalized = flow.finalize();
        hydro_build_utils::assert_debug_snapshot!(finalized.ir());

        let rewritten = finalized.optimize_with(super::persist_pullup);
        hydro_build_utils::assert_debug_snapshot!(rewritten.ir());

        for (id, graph) in rewritten
            .into_deploy::<HydroDeploy>()
            .preview_compile()
            .all_dfir()
        {
            hydro_build_utils::insta::with_settings!({
                snapshot_suffix => format!("surface_graph_{id}"),
            }, {
                hydro_build_utils::assert_snapshot!(graph.surface_syntax_string());
            });
        }
    }

    /// Persist pull-up when the persisted stream is cloned into two identical
    /// `map` -> `all_ticks` -> `for_each` consumers, i.e. sits behind a tee.
    #[test]
    fn persist_pullup_behind_tee() {
        let flow = FlowBuilder::new();
        let process = flow.process::<()>();
        let tick = process.tick();

        let batched = process
            .source_iter(q!(0..10))
            .batch(&tick, nondet!(/** test */));
        let persisted = batched.persist();

        let first_consumer = persisted.clone().map(q!(|v| v + 1)).all_ticks();
        first_consumer.for_each(q!(|n| println!("{}", n)));

        let second_consumer = persisted.clone().map(q!(|v| v + 1)).all_ticks();
        second_consumer.for_each(q!(|n| println!("{}", n)));

        let finalized = flow.finalize();
        hydro_build_utils::assert_debug_snapshot!(finalized.ir());

        let rewritten = finalized.optimize_with(super::persist_pullup);
        hydro_build_utils::assert_debug_snapshot!(rewritten.ir());

        for (id, graph) in rewritten
            .into_deploy::<HydroDeploy>()
            .preview_compile()
            .all_dfir()
        {
            hydro_build_utils::insta::with_settings!({
                snapshot_suffix => format!("surface_graph_{id}"),
            }, {
                hydro_build_utils::assert_snapshot!(graph.surface_syntax_string());
            });
        }
    }
}