hydro_lang/rewrites/persist_pullup.rs

1use std::cell::RefCell;
2use std::collections::HashSet;
3
4use crate::ir::*;
5
6fn persist_pullup_node(
7    node: &mut HydroNode,
8    persist_pulled_tees: &mut HashSet<*const RefCell<HydroNode>>,
9) {
10    *node = match_box::match_box! {
11        match std::mem::replace(node, HydroNode::Placeholder) {
12            HydroNode::Unpersist { inner: mb!(* HydroNode::Persist { inner: mb!(* behind_persist), .. }), .. } => behind_persist,
13
14            HydroNode::Delta { inner: mb!(* HydroNode::Persist { inner: mb!(* behind_persist), .. }), .. } => behind_persist,
15
16            // TODO: Figure out if persist needs to copy its metadata or can just use original metadata here. If it can just use original, figure out where that is
17            HydroNode::Tee { inner, metadata } => {
18                if persist_pulled_tees.contains(&(inner.0.as_ref() as *const RefCell<HydroNode>)) {
19                    HydroNode::Persist {
20                        inner: Box::new(HydroNode::Tee {
21                            inner: TeeNode(inner.0.clone()),
22                            metadata: metadata.clone(),
23                        }),
24                        metadata: metadata.clone(),
25                    }
26                } else if matches!(*inner.0.borrow(), HydroNode::Persist { .. }) {
27                    persist_pulled_tees.insert(inner.0.as_ref() as *const RefCell<HydroNode>);
28                    if let HydroNode::Persist { inner: behind_persist, .. } =
29                        inner.0.replace(HydroNode::Placeholder)
30                    {
31                        *inner.0.borrow_mut() = *behind_persist;
32                    } else {
33                        unreachable!()
34                    }
35
36                    HydroNode::Persist {
37                        inner: Box::new(HydroNode::Tee {
38                            inner: TeeNode(inner.0.clone()),
39                            metadata: metadata.clone(),
40                        }),
41                        metadata: metadata.clone(),
42                    }
43                } else {
44                    HydroNode::Tee { inner, metadata }
45                }
46            }
47
48            HydroNode::ResolveFutures {
49                input: mb!(* HydroNode::Persist { inner: behind_persist, .. }),
50                metadata,
51            } => HydroNode::Persist {
52                inner: Box::new(HydroNode::ResolveFutures {
53                    input: behind_persist,
54                    metadata: metadata.clone(),
55                }),
56                metadata: metadata.clone(),
57            },
58
59            HydroNode::ResolveFuturesOrdered {
60                input: mb!(* HydroNode::Persist { inner: behind_persist, .. }),
61                metadata,
62            } => HydroNode::Persist {
63                inner: Box::new(HydroNode::ResolveFuturesOrdered {
64                    input: behind_persist,
65                    metadata: metadata.clone(),
66                }),
67                metadata: metadata.clone(),
68            },
69
70            HydroNode::Map {
71                f,
72                input: mb!(* HydroNode::Persist { inner: behind_persist, .. }),
73                metadata,
74            } => HydroNode::Persist {
75                inner: Box::new(HydroNode::Map {
76                    f,
77                    input: behind_persist,
78                    metadata: metadata.clone(),
79                }),
80                metadata: metadata.clone(),
81            },
82
83            HydroNode::FilterMap {
84                f,
85                input: mb!(* HydroNode::Persist { inner: behind_persist, .. }),
86                metadata,
87            } => HydroNode::Persist {
88                inner: Box::new(HydroNode::FilterMap {
89                    f,
90                    input: behind_persist,
91                    metadata: metadata.clone(),
92                }),
93                metadata: metadata.clone()
94            },
95
96            HydroNode::FlatMap {
97                f,
98                input: mb!(* HydroNode::Persist { inner: behind_persist, .. }),
99                metadata,
100            } => HydroNode::Persist {
101                inner: Box::new(HydroNode::FlatMap {
102                    f,
103                    input: behind_persist,
104                    metadata: metadata.clone(),
105                }),
106                metadata: metadata.clone()
107            },
108
109            HydroNode::Filter {
110                f,
111                input: mb!(* HydroNode::Persist { inner: behind_persist, .. }),
112                metadata,
113            } => HydroNode::Persist {
114                inner: Box::new(HydroNode::Filter {
115                    f,
116                    input: behind_persist,
117                    metadata: metadata.clone(),
118                }),
119                metadata: metadata.clone()
120            },
121
122            HydroNode::Network {
123                from_key,
124                to_location,
125                to_key,
126                serialize_fn,
127                instantiate_fn,
128                deserialize_fn,
129                input: mb!(* HydroNode::Persist { inner: behind_persist, .. }),
130                metadata,
131            } => HydroNode::Persist {
132                inner: Box::new(HydroNode::Network {
133                    from_key,
134                    to_location,
135                    to_key,
136                    serialize_fn,
137                    instantiate_fn,
138                    deserialize_fn,
139                    input: behind_persist,
140                    metadata: metadata.clone()
141                }),
142                metadata: metadata.clone(),
143            },
144
145            HydroNode::Chain {
146                first: mb!(* HydroNode::Persist { inner: first, metadata: persist_metadata }),
147                second: mb!(* HydroNode::Persist { inner: second, .. }),
148                metadata
149            } => HydroNode::Persist {
150                inner: Box::new(HydroNode::Chain { first, second, metadata }),
151                metadata: persist_metadata
152            },
153
154            HydroNode::CrossProduct {
155                left: mb!(* HydroNode::Persist { inner: left, metadata: left_metadata }),
156                right: mb!(* HydroNode::Persist { inner: right, metadata: right_metadata }),
157                metadata
158            } => HydroNode::Persist {
159                inner: Box::new(HydroNode::Delta {
160                    inner: Box::new(HydroNode::CrossProduct {
161                        left: Box::new(HydroNode::Persist { inner: left, metadata: left_metadata }),
162                        right: Box::new(HydroNode::Persist { inner: right, metadata: right_metadata }),
163                        metadata: metadata.clone()
164                    }),
165                    metadata: metadata.clone(),
166                }),
167                metadata: metadata.clone(),
168            },
169            HydroNode::Join {
170                left: mb!(* HydroNode::Persist { inner: left, metadata: left_metadata }),
171                right: mb!(* HydroNode::Persist { inner: right, metadata: right_metadata }),
172                metadata
173             } => HydroNode::Persist {
174                inner: Box::new(HydroNode::Delta {
175                    inner: Box::new(HydroNode::Join {
176                        left: Box::new(HydroNode::Persist { inner: left, metadata: left_metadata }),
177                        right: Box::new(HydroNode::Persist { inner: right, metadata: right_metadata }),
178                        metadata: metadata.clone()
179                    }),
180                    metadata: metadata.clone(),
181                }),
182                metadata: metadata.clone(),
183            },
184
185            HydroNode::Unique { input: mb!(* HydroNode::Persist {inner, metadata: persist_metadata } ), metadata } => HydroNode::Persist {
186                inner: Box::new(HydroNode::Delta {
187                    inner: Box::new(HydroNode::Unique {
188                        input: Box::new(HydroNode::Persist { inner, metadata: persist_metadata }),
189                        metadata: metadata.clone()
190                    }),
191                    metadata: metadata.clone(),
192                }),
193                metadata: metadata.clone()
194            },
195
196            node => node,
197        }
198    };
199}
200
201pub fn persist_pullup(ir: &mut [HydroLeaf]) {
202    let mut persist_pulled_tees = Default::default();
203    transform_bottom_up(ir, &mut |_| (), &mut |node| {
204        persist_pullup_node(node, &mut persist_pulled_tees)
205    });
206}
207
#[cfg(test)]
mod tests {
    use stageleft::*;

    use crate::deploy::MultiGraph;
    // NOTE(review): presumably imported so `Location` trait methods such as `source_iter` and
    // `tick` resolve on `process` — confirm.
    use crate::location::Location;

    /// Exercises `persist_pullup` on a `source_iter -> map -> for_each` pipeline on one process.
    ///
    /// Snapshots the IR before and after the rewrite, then the DFIR surface syntax of every
    /// compiled graph (one snapshot per graph id via `snapshot_suffix`).
    #[test]
    fn persist_pullup_through_map() {
        let flow = crate::builder::FlowBuilder::new();
        let process = flow.process::<()>();

        process
            .source_iter(q!(0..10))
            .map(q!(|v| v + 1))
            .for_each(q!(|n| println!("{}", n)));

        let built = flow.finalize();

        // IR before the rewrite.
        insta::assert_debug_snapshot!(built.ir());

        let optimized = built.optimize_with(super::persist_pullup);

        // IR after the rewrite.
        insta::assert_debug_snapshot!(optimized.ir());
        for (id, graph) in optimized.compile_no_network::<MultiGraph>().all_dfir() {
            insta::with_settings!({snapshot_suffix => format!("surface_graph_{id}")}, {
                insta::assert_snapshot!(graph.surface_syntax_string());
            });
        }
    }

    /// Exercises `persist_pullup` on a persisted, ticked batch that is cloned into two identical
    /// `map -> all_ticks -> for_each` pipelines.
    ///
    /// The two `.clone()` uses of `before_tee` presumably become `Tee` nodes over one shared
    /// `Persist`, exercising the tee branch of the rewrite — confirm against the IR snapshot.
    /// Snapshots the IR before and after the rewrite and the DFIR surface syntax of each graph.
    #[test]
    fn persist_pullup_behind_tee() {
        let flow = crate::builder::FlowBuilder::new();
        let process = flow.process::<()>();

        let tick = process.tick();
        // NOTE(review): `unsafe` is required by one of these builder calls (presumably
        // `tick_batch`); no safety justification is stated here — confirm the contract.
        let before_tee = unsafe { process.source_iter(q!(0..10)).tick_batch(&tick).persist() };

        // First consumer of the shared stream.
        before_tee
            .clone()
            .map(q!(|v| v + 1))
            .all_ticks()
            .for_each(q!(|n| println!("{}", n)));

        // Second, identical consumer of the same shared stream.
        before_tee
            .clone()
            .map(q!(|v| v + 1))
            .all_ticks()
            .for_each(q!(|n| println!("{}", n)));

        let built = flow.finalize();

        // IR before the rewrite.
        insta::assert_debug_snapshot!(built.ir());

        let optimized = built.optimize_with(super::persist_pullup);

        // IR after the rewrite.
        insta::assert_debug_snapshot!(optimized.ir());

        for (id, graph) in optimized.compile_no_network::<MultiGraph>().all_dfir() {
            insta::with_settings!({snapshot_suffix => format!("surface_graph_{id}")}, {
                insta::assert_snapshot!(graph.surface_syntax_string());
            });
        }
    }
}