hydro_lang/rewrites/
persist_pullup.rs

1use std::cell::RefCell;
2use std::collections::HashSet;
3
4use crate::ir::*;
5
6fn persist_pullup_node(
7    node: &mut HydroNode,
8    persist_pulled_tees: &mut HashSet<*const RefCell<HydroNode>>,
9) {
10    *node = match_box::match_box! {
11        match std::mem::replace(node, HydroNode::Placeholder) {
12            HydroNode::Unpersist { inner: mb!(* HydroNode::Persist { inner: mb!(* behind_persist), .. }), .. } => behind_persist,
13
14            HydroNode::Delta { inner: mb!(* HydroNode::Persist { inner: mb!(* behind_persist), .. }), .. } => behind_persist,
15
16            // TODO: Figure out if persist needs to copy its metadata or can just use original metadata here. If it can just use original, figure out where that is
17            HydroNode::Tee { inner, metadata } => {
18                if persist_pulled_tees.contains(&(inner.0.as_ref() as *const RefCell<HydroNode>)) {
19                    HydroNode::Persist {
20                        inner: Box::new(HydroNode::Tee {
21                            inner: TeeNode(inner.0.clone()),
22                            metadata: metadata.clone(),
23                        }),
24                        metadata: metadata.clone(),
25                    }
26                } else if matches!(*inner.0.borrow(), HydroNode::Persist { .. }) {
27                    persist_pulled_tees.insert(inner.0.as_ref() as *const RefCell<HydroNode>);
28                    if let HydroNode::Persist { inner: behind_persist, .. } =
29                        inner.0.replace(HydroNode::Placeholder)
30                    {
31                        *inner.0.borrow_mut() = *behind_persist;
32                    } else {
33                        unreachable!()
34                    }
35
36                    HydroNode::Persist {
37                        inner: Box::new(HydroNode::Tee {
38                            inner: TeeNode(inner.0.clone()),
39                            metadata: metadata.clone(),
40                        }),
41                        metadata: metadata.clone(),
42                    }
43                } else {
44                    HydroNode::Tee { inner, metadata }
45                }
46            }
47
48            HydroNode::Map {
49                f,
50                input: mb!(* HydroNode::Persist { inner: behind_persist, .. }),
51                metadata,
52            } => HydroNode::Persist {
53                inner: Box::new(HydroNode::Map {
54                    f,
55                    input: behind_persist,
56                    metadata: metadata.clone(),
57                }),
58                metadata: metadata.clone(),
59            },
60
61            HydroNode::FilterMap {
62                f,
63                input: mb!(* HydroNode::Persist { inner: behind_persist, .. }),
64                metadata,
65            } => HydroNode::Persist {
66                inner: Box::new(HydroNode::FilterMap {
67                    f,
68                    input: behind_persist,
69                    metadata: metadata.clone(),
70                }),
71                metadata: metadata.clone()
72            },
73
74            HydroNode::FlatMap {
75                f,
76                input: mb!(* HydroNode::Persist { inner: behind_persist, .. }),
77                metadata,
78            } => HydroNode::Persist {
79                inner: Box::new(HydroNode::FlatMap {
80                    f,
81                    input: behind_persist,
82                    metadata: metadata.clone(),
83                }),
84                metadata: metadata.clone()
85            },
86
87            HydroNode::Filter {
88                f,
89                input: mb!(* HydroNode::Persist { inner: behind_persist, .. }),
90                metadata,
91            } => HydroNode::Persist {
92                inner: Box::new(HydroNode::Filter {
93                    f,
94                    input: behind_persist,
95                    metadata: metadata.clone(),
96                }),
97                metadata: metadata.clone()
98            },
99
100            HydroNode::Network {
101                from_key,
102                to_location,
103                to_key,
104                serialize_fn,
105                instantiate_fn,
106                deserialize_fn,
107                input: mb!(* HydroNode::Persist { inner: behind_persist, .. }),
108                metadata,
109            } => HydroNode::Persist {
110                inner: Box::new(HydroNode::Network {
111                    from_key,
112                    to_location,
113                    to_key,
114                    serialize_fn,
115                    instantiate_fn,
116                    deserialize_fn,
117                    input: behind_persist,
118                    metadata: metadata.clone()
119                }),
120                metadata: metadata.clone(),
121            },
122
123            HydroNode::Chain {
124                first: mb!(* HydroNode::Persist { inner: first, metadata: persist_metadata }),
125                second: mb!(* HydroNode::Persist { inner: second, .. }),
126                metadata
127            } => HydroNode::Persist {
128                inner: Box::new(HydroNode::Chain { first, second, metadata }),
129                metadata: persist_metadata
130            },
131
132            HydroNode::CrossProduct {
133                left: mb!(* HydroNode::Persist { inner: left, metadata: left_metadata }),
134                right: mb!(* HydroNode::Persist { inner: right, metadata: right_metadata }),
135                metadata
136            } => HydroNode::Persist {
137                inner: Box::new(HydroNode::Delta {
138                    inner: Box::new(HydroNode::CrossProduct {
139                        left: Box::new(HydroNode::Persist { inner: left, metadata: left_metadata }),
140                        right: Box::new(HydroNode::Persist { inner: right, metadata: right_metadata }),
141                        metadata: metadata.clone()
142                    }),
143                    metadata: metadata.clone(),
144                }),
145                metadata: metadata.clone(),
146            },
147            HydroNode::Join {
148                left: mb!(* HydroNode::Persist { inner: left, metadata: left_metadata }),
149                right: mb!(* HydroNode::Persist { inner: right, metadata: right_metadata }),
150                metadata
151             } => HydroNode::Persist {
152                inner: Box::new(HydroNode::Delta {
153                    inner: Box::new(HydroNode::Join {
154                        left: Box::new(HydroNode::Persist { inner: left, metadata: left_metadata }),
155                        right: Box::new(HydroNode::Persist { inner: right, metadata: right_metadata }),
156                        metadata: metadata.clone()
157                    }),
158                    metadata: metadata.clone(),
159                }),
160                metadata: metadata.clone(),
161            },
162
163            HydroNode::Unique { input: mb!(* HydroNode::Persist {inner, metadata: persist_metadata } ), metadata } => HydroNode::Persist {
164                inner: Box::new(HydroNode::Delta {
165                    inner: Box::new(HydroNode::Unique {
166                        input: Box::new(HydroNode::Persist { inner, metadata: persist_metadata }),
167                        metadata: metadata.clone()
168                    }),
169                    metadata: metadata.clone(),
170                }),
171                metadata: metadata.clone()
172            },
173
174            node => node,
175        }
176    };
177}
178
179pub fn persist_pullup(ir: &mut [HydroLeaf]) {
180    let mut persist_pulled_tees = Default::default();
181    transform_bottom_up(ir, &mut |_| (), &mut |node| {
182        persist_pullup_node(node, &mut persist_pulled_tees)
183    });
184}
185
#[cfg(test)]
mod tests {
    use stageleft::*;

    use crate::deploy::MultiGraph;
    use crate::location::Location;

    // NOTE: insta derives snapshot names from the test function name plus the
    // order of the assertions, and records the asserted expression text. Keep
    // test names, assertion order and the asserted expressions stable.

    /// Snapshots a `source_iter -> map -> for_each` pipeline before and after
    /// `persist_pullup`, plus the surface syntax of each compiled DFIR graph.
    #[test]
    fn persist_pullup_through_map() {
        let flow_builder = crate::builder::FlowBuilder::new();
        let location = flow_builder.process::<()>();

        location
            .source_iter(q!(0..10))
            .map(q!(|v| v + 1))
            .for_each(q!(|n| println!("{}", n)));

        let built = flow_builder.finalize();
        insta::assert_debug_snapshot!(built.ir());

        let optimized = built.optimize_with(super::persist_pullup);
        insta::assert_debug_snapshot!(optimized.ir());

        let compiled = optimized.compile_no_network::<MultiGraph>();
        for (id, graph) in compiled.all_dfir() {
            insta::with_settings!({snapshot_suffix => format!("surface_graph_{id}")}, {
                insta::assert_snapshot!(graph.surface_syntax_string());
            });
        }
    }

    /// A persisted, tick-batched stream is cloned into two `map` consumers, so
    /// the IR contains a shared (tee'd) node that `persist_pullup` must handle.
    #[test]
    fn persist_pullup_behind_tee() {
        let flow_builder = crate::builder::FlowBuilder::new();
        let location = flow_builder.process::<()>();

        let batch_tick = location.tick();
        // NOTE(review): `tick_batch` is an `unsafe` API; this test only builds,
        // compiles and snapshots the graph without running it — confirm that is
        // why the unsafe block is acceptable here.
        let persisted =
            unsafe { location.source_iter(q!(0..10)).tick_batch(&batch_tick).persist() };

        // First consumer of the shared persisted stream.
        persisted
            .clone()
            .map(q!(|v| v + 1))
            .all_ticks()
            .for_each(q!(|n| println!("{}", n)));

        // Second, identical consumer of the same shared stream.
        persisted
            .clone()
            .map(q!(|v| v + 1))
            .all_ticks()
            .for_each(q!(|n| println!("{}", n)));

        let built = flow_builder.finalize();
        insta::assert_debug_snapshot!(built.ir());

        let optimized = built.optimize_with(super::persist_pullup);
        insta::assert_debug_snapshot!(optimized.ir());

        let compiled = optimized.compile_no_network::<MultiGraph>();
        for (id, graph) in compiled.all_dfir() {
            insta::with_settings!({snapshot_suffix => format!("surface_graph_{id}")}, {
                insta::assert_snapshot!(graph.surface_syntax_string());
            });
        }
    }
}