hydro_lang/builder/deploy.rs

1use std::cell::UnsafeCell;
2use std::collections::{BTreeMap, HashMap};
3use std::io::Error;
4use std::marker::PhantomData;
5use std::pin::Pin;
6
7use bytes::Bytes;
8use futures::{Sink, Stream};
9use proc_macro2::Span;
10use serde::Serialize;
11use serde::de::DeserializeOwned;
12use stageleft::QuotedWithContext;
13
14use super::built::build_inner;
15use super::compiled::CompiledFlow;
16use crate::deploy::{
17    ClusterSpec, Deploy, ExternalSpec, IntoProcessSpec, LocalDeploy, Node, ProcessSpec,
18    RegisterPort,
19};
20use crate::ir::HydroLeaf;
21use crate::location::external_process::{
22    ExternalBincodeSink, ExternalBincodeStream, ExternalBytesPort,
23};
24use crate::location::{Cluster, ExternalProcess, Location, LocationId, Process};
25use crate::staging_util::Invariant;
26
27pub struct DeployFlow<'a, D>
28where
29    D: LocalDeploy<'a>,
30{
31    // We need to grab an `&mut` reference to the IR in `preview_compile` even though
32    // that function does not modify the IR. Using an `UnsafeCell` allows us to do this
33    // while still being able to lend out immutable references to the IR.
34    pub(super) ir: UnsafeCell<Vec<HydroLeaf>>,
35
36    /// Deployed instances of each process in the flow
37    pub(super) processes: HashMap<usize, D::Process>,
38
39    /// Lists all the processes that were created in the flow, same ID as `processes`
40    /// but with the type name of the tag.
41    pub(super) process_id_name: Vec<(usize, String)>,
42
43    pub(super) externals: HashMap<usize, D::ExternalProcess>,
44    pub(super) external_id_name: Vec<(usize, String)>,
45
46    pub(super) clusters: HashMap<usize, D::Cluster>,
47    pub(super) cluster_id_name: Vec<(usize, String)>,
48    pub(super) used: bool,
49
50    pub(super) _phantom: Invariant<'a, D>,
51}
52
53impl<'a, D: LocalDeploy<'a>> Drop for DeployFlow<'a, D> {
54    fn drop(&mut self) {
55        if !self.used {
56            panic!(
57                "Dropped DeployFlow without instantiating, you may have forgotten to call `compile` or `deploy`."
58            );
59        }
60    }
61}
62
63impl<'a, D: LocalDeploy<'a>> DeployFlow<'a, D> {
64    pub fn ir(&self) -> &Vec<HydroLeaf> {
65        unsafe {
66            // SAFETY: even when we grab this as mutable in `preview_compile`, we do not modify it
67            &*self.ir.get()
68        }
69    }
70
71    pub fn with_process<P>(
72        mut self,
73        process: &Process<P>,
74        spec: impl IntoProcessSpec<'a, D>,
75    ) -> Self {
76        let tag_name = std::any::type_name::<P>().to_string();
77        self.processes.insert(
78            process.id,
79            spec.into_process_spec().build(process.id, &tag_name),
80        );
81        self
82    }
83
84    pub fn with_remaining_processes<S: IntoProcessSpec<'a, D> + 'a>(
85        mut self,
86        spec: impl Fn() -> S,
87    ) -> Self {
88        for (id, name) in &self.process_id_name {
89            self.processes
90                .insert(*id, spec().into_process_spec().build(*id, name));
91        }
92
93        self
94    }
95
96    pub fn with_external<P>(
97        mut self,
98        process: &ExternalProcess<P>,
99        spec: impl ExternalSpec<'a, D>,
100    ) -> Self {
101        let tag_name = std::any::type_name::<P>().to_string();
102        self.externals
103            .insert(process.id, spec.build(process.id, &tag_name));
104        self
105    }
106
107    pub fn with_remaining_externals<S: ExternalSpec<'a, D> + 'a>(
108        mut self,
109        spec: impl Fn() -> S,
110    ) -> Self {
111        for (id, name) in &self.external_id_name {
112            self.externals.insert(*id, spec().build(*id, name));
113        }
114
115        self
116    }
117
118    pub fn with_cluster<C>(mut self, cluster: &Cluster<C>, spec: impl ClusterSpec<'a, D>) -> Self {
119        let tag_name = std::any::type_name::<C>().to_string();
120        self.clusters
121            .insert(cluster.id, spec.build(cluster.id, &tag_name));
122        self
123    }
124
125    pub fn with_remaining_clusters<S: ClusterSpec<'a, D> + 'a>(
126        mut self,
127        spec: impl Fn() -> S,
128    ) -> Self {
129        for (id, name) in &self.cluster_id_name {
130            self.clusters.insert(*id, spec().build(*id, name));
131        }
132
133        self
134    }
135
136    /// Compiles the flow into DFIR using placeholders for the network.
137    /// Useful for generating Mermaid diagrams of the DFIR.
138    pub fn preview_compile(&self) -> CompiledFlow<'a, ()> {
139        CompiledFlow {
140            dfir: build_inner(unsafe {
141                // SAFETY: `build_inner` does not mutate the IR, &mut is required
142                // only because the shared traversal logic requires it
143                &mut *self.ir.get()
144            }),
145            #[cfg(feature = "staged_macro")]
146            extra_stmts: BTreeMap::new(),
147            _phantom: PhantomData,
148        }
149    }
150
151    pub fn compile_no_network(mut self) -> CompiledFlow<'a, D::GraphId> {
152        self.used = true;
153
154        CompiledFlow {
155            dfir: build_inner(self.ir.get_mut()),
156            #[cfg(feature = "staged_macro")]
157            extra_stmts: BTreeMap::new(),
158            _phantom: PhantomData,
159        }
160    }
161}
162
163impl<'a, D: Deploy<'a>> DeployFlow<'a, D> {
164    pub fn compile(mut self, env: &D::CompileEnv) -> CompiledFlow<'a, D::GraphId> {
165        self.used = true;
166
167        let mut seen_tees: HashMap<_, _> = HashMap::new();
168        let mut seen_tee_locations: HashMap<_, _> = HashMap::new();
169        self.ir.get_mut().iter_mut().for_each(|leaf| {
170            leaf.compile_network::<D>(
171                env,
172                &mut seen_tees,
173                &mut seen_tee_locations,
174                &self.processes,
175                &self.clusters,
176                &self.externals,
177            );
178        });
179
180        #[cfg(feature = "staged_macro")]
181        let extra_stmts = self.extra_stmts(env);
182
183        CompiledFlow {
184            dfir: build_inner(self.ir.get_mut()),
185            #[cfg(feature = "staged_macro")]
186            extra_stmts,
187            _phantom: PhantomData,
188        }
189    }
190
191    fn extra_stmts(&self, env: &<D as Deploy<'a>>::CompileEnv) -> BTreeMap<usize, Vec<syn::Stmt>> {
192        let mut extra_stmts: BTreeMap<usize, Vec<syn::Stmt>> = BTreeMap::new();
193
194        let mut all_clusters_sorted = self.clusters.keys().collect::<Vec<_>>();
195        all_clusters_sorted.sort();
196
197        for &c_id in all_clusters_sorted {
198            let self_id_ident = syn::Ident::new(
199                &format!("__hydro_lang_cluster_self_id_{}", c_id),
200                Span::call_site(),
201            );
202            let self_id_expr = D::cluster_self_id(env).splice_untyped();
203            extra_stmts
204                .entry(c_id)
205                .or_default()
206                .push(syn::parse_quote! {
207                    let #self_id_ident = #self_id_expr;
208                });
209
210            for other_location in self.processes.keys().chain(self.clusters.keys()) {
211                let other_id_ident = syn::Ident::new(
212                    &format!("__hydro_lang_cluster_ids_{}", c_id),
213                    Span::call_site(),
214                );
215                let other_id_expr = D::cluster_ids(env, c_id).splice_untyped();
216                extra_stmts
217                    .entry(*other_location)
218                    .or_default()
219                    .push(syn::parse_quote! {
220                        let #other_id_ident = #other_id_expr;
221                    });
222            }
223        }
224        extra_stmts
225    }
226}
227
228impl<'a, D: Deploy<'a, CompileEnv = ()>> DeployFlow<'a, D> {
229    #[must_use]
230    pub fn deploy(mut self, env: &mut D::InstantiateEnv) -> DeployResult<'a, D> {
231        self.used = true;
232
233        let mut seen_tees_instantiate: HashMap<_, _> = HashMap::new();
234        let mut seen_tee_locations: HashMap<_, _> = HashMap::new();
235        self.ir.get_mut().iter_mut().for_each(|leaf| {
236            leaf.compile_network::<D>(
237                &(),
238                &mut seen_tees_instantiate,
239                &mut seen_tee_locations,
240                &self.processes,
241                &self.clusters,
242                &self.externals,
243            );
244        });
245
246        let mut compiled = build_inner(self.ir.get_mut());
247        let mut extra_stmts = self.extra_stmts(&());
248        let mut meta = D::Meta::default();
249
250        let (mut processes, mut clusters, mut externals) = (
251            std::mem::take(&mut self.processes)
252                .into_iter()
253                .filter_map(|(node_id, node)| {
254                    if let Some(ir) = compiled.remove(&node_id) {
255                        node.instantiate(
256                            env,
257                            &mut meta,
258                            ir,
259                            extra_stmts.remove(&node_id).unwrap_or_default(),
260                        );
261                        Some((node_id, node))
262                    } else {
263                        None
264                    }
265                })
266                .collect::<HashMap<_, _>>(),
267            std::mem::take(&mut self.clusters)
268                .into_iter()
269                .filter_map(|(cluster_id, cluster)| {
270                    if let Some(ir) = compiled.remove(&cluster_id) {
271                        cluster.instantiate(
272                            env,
273                            &mut meta,
274                            ir,
275                            extra_stmts.remove(&cluster_id).unwrap_or_default(),
276                        );
277                        Some((cluster_id, cluster))
278                    } else {
279                        None
280                    }
281                })
282                .collect::<HashMap<_, _>>(),
283            std::mem::take(&mut self.externals)
284                .into_iter()
285                .map(|(external_id, external)| {
286                    external.instantiate(
287                        env,
288                        &mut meta,
289                        compiled.remove(&external_id).unwrap(),
290                        extra_stmts.remove(&external_id).unwrap_or_default(),
291                    );
292                    (external_id, external)
293                })
294                .collect::<HashMap<_, _>>(),
295        );
296
297        for node in processes.values_mut() {
298            node.update_meta(&meta);
299        }
300
301        for cluster in clusters.values_mut() {
302            cluster.update_meta(&meta);
303        }
304
305        for external in externals.values_mut() {
306            external.update_meta(&meta);
307        }
308
309        let mut seen_tees_connect = HashMap::new();
310        self.ir.get_mut().iter_mut().for_each(|leaf| {
311            leaf.connect_network(&mut seen_tees_connect);
312        });
313
314        DeployResult {
315            processes,
316            clusters,
317            externals,
318            cluster_id_name: std::mem::take(&mut self.cluster_id_name)
319                .into_iter()
320                .collect(),
321        }
322    }
323}
324
325pub struct DeployResult<'a, D: Deploy<'a>> {
326    processes: HashMap<usize, D::Process>,
327    clusters: HashMap<usize, D::Cluster>,
328    externals: HashMap<usize, D::ExternalProcess>,
329    cluster_id_name: HashMap<usize, String>,
330}
331
332impl<'a, D: Deploy<'a>> DeployResult<'a, D> {
333    pub fn get_process<P>(&self, p: &Process<P>) -> &D::Process {
334        let id = match p.id() {
335            LocationId::Process(id) => id,
336            _ => panic!("Process ID expected"),
337        };
338
339        self.processes.get(&id).unwrap()
340    }
341
342    pub fn get_cluster<C>(&self, c: &Cluster<'a, C>) -> &D::Cluster {
343        let id = match c.id() {
344            LocationId::Cluster(id) => id,
345            _ => panic!("Cluster ID expected"),
346        };
347
348        self.clusters.get(&id).unwrap()
349    }
350
351    pub fn get_all_clusters(&self) -> impl Iterator<Item = (LocationId, String, &D::Cluster)> {
352        self.clusters.iter().map(|(&id, c)| {
353            (
354                LocationId::Cluster(id),
355                self.cluster_id_name.get(&id).unwrap().clone(),
356                c,
357            )
358        })
359    }
360
361    pub fn get_external<P>(&self, p: &ExternalProcess<P>) -> &D::ExternalProcess {
362        self.externals.get(&p.id).unwrap()
363    }
364
365    pub fn raw_port(&self, port: ExternalBytesPort) -> D::ExternalRawPort {
366        self.externals
367            .get(&port.process_id)
368            .unwrap()
369            .raw_port(port.port_id)
370    }
371
372    pub async fn connect_sink_bytes(
373        &self,
374        port: ExternalBytesPort,
375    ) -> Pin<Box<dyn Sink<Bytes, Error = Error>>> {
376        self.externals
377            .get(&port.process_id)
378            .unwrap()
379            .as_bytes_sink(port.port_id)
380            .await
381    }
382
383    pub async fn connect_sink_bincode<T: Serialize + DeserializeOwned + 'static>(
384        &self,
385        port: ExternalBincodeSink<T>,
386    ) -> Pin<Box<dyn Sink<T, Error = Error>>> {
387        self.externals
388            .get(&port.process_id)
389            .unwrap()
390            .as_bincode_sink(port.port_id)
391            .await
392    }
393
394    pub async fn connect_source_bytes(
395        &self,
396        port: ExternalBytesPort,
397    ) -> Pin<Box<dyn Stream<Item = Bytes>>> {
398        self.externals
399            .get(&port.process_id)
400            .unwrap()
401            .as_bytes_source(port.port_id)
402            .await
403    }
404
405    pub async fn connect_source_bincode<T: Serialize + DeserializeOwned + 'static>(
406        &self,
407        port: ExternalBincodeStream<T>,
408    ) -> Pin<Box<dyn Stream<Item = T>>> {
409        self.externals
410            .get(&port.process_id)
411            .unwrap()
412            .as_bincode_source(port.port_id)
413            .await
414    }
415}