hydro_lang/builder/
deploy.rs

1use std::cell::UnsafeCell;
2use std::collections::{BTreeMap, HashMap};
3use std::io::Error;
4use std::marker::PhantomData;
5use std::pin::Pin;
6
7use bytes::{Bytes, BytesMut};
8use futures::{Sink, Stream};
9use proc_macro2::Span;
10use serde::Serialize;
11use serde::de::DeserializeOwned;
12use stageleft::QuotedWithContext;
13
14use super::built::build_inner;
15use super::compiled::CompiledFlow;
16use crate::deploy::{
17    ClusterSpec, Deploy, ExternalSpec, IntoProcessSpec, Node, ProcessSpec, RegisterPort,
18};
19use crate::ir::HydroRoot;
20use crate::location::external_process::{
21    ExternalBincodeBidi, ExternalBincodeSink, ExternalBincodeStream, ExternalBytesPort,
22};
23use crate::location::{Cluster, External, Location, LocationId, Process};
24use crate::staging_util::Invariant;
25
26pub struct DeployFlow<'a, D>
27where
28    D: Deploy<'a>,
29{
30    // We need to grab an `&mut` reference to the IR in `preview_compile` even though
31    // that function does not modify the IR. Using an `UnsafeCell` allows us to do this
32    // while still being able to lend out immutable references to the IR.
33    pub(super) ir: UnsafeCell<Vec<HydroRoot>>,
34
35    /// Deployed instances of each process in the flow
36    pub(super) processes: HashMap<usize, D::Process>,
37
38    /// Lists all the processes that were created in the flow, same ID as `processes`
39    /// but with the type name of the tag.
40    pub(super) process_id_name: Vec<(usize, String)>,
41
42    pub(super) externals: HashMap<usize, D::External>,
43    pub(super) external_id_name: Vec<(usize, String)>,
44
45    pub(super) clusters: HashMap<usize, D::Cluster>,
46    pub(super) cluster_id_name: Vec<(usize, String)>,
47
48    pub(super) _phantom: Invariant<'a, D>,
49}
50
51impl<'a, D: Deploy<'a>> DeployFlow<'a, D> {
52    pub fn ir(&self) -> &Vec<HydroRoot> {
53        unsafe {
54            // SAFETY: even when we grab this as mutable in `preview_compile`, we do not modify it
55            &*self.ir.get()
56        }
57    }
58
59    pub fn with_process_id_name(
60        mut self,
61        process_id: usize,
62        process_name: String,
63        spec: impl IntoProcessSpec<'a, D>,
64    ) -> Self {
65        self.processes.insert(
66            process_id,
67            spec.into_process_spec().build(process_id, &process_name),
68        );
69        self
70    }
71
72    pub fn with_process<P>(self, process: &Process<P>, spec: impl IntoProcessSpec<'a, D>) -> Self {
73        self.with_process_id_name(process.id, std::any::type_name::<P>().to_string(), spec)
74    }
75
76    pub fn with_remaining_processes<S: IntoProcessSpec<'a, D> + 'a>(
77        mut self,
78        spec: impl Fn() -> S,
79    ) -> Self {
80        for (id, name) in &self.process_id_name {
81            self.processes
82                .insert(*id, spec().into_process_spec().build(*id, name));
83        }
84
85        self
86    }
87
88    pub fn with_external<P>(
89        mut self,
90        process: &External<P>,
91        spec: impl ExternalSpec<'a, D>,
92    ) -> Self {
93        let tag_name = std::any::type_name::<P>().to_string();
94        self.externals
95            .insert(process.id, spec.build(process.id, &tag_name));
96        self
97    }
98
99    pub fn with_remaining_externals<S: ExternalSpec<'a, D> + 'a>(
100        mut self,
101        spec: impl Fn() -> S,
102    ) -> Self {
103        for (id, name) in &self.external_id_name {
104            self.externals.insert(*id, spec().build(*id, name));
105        }
106
107        self
108    }
109
110    pub fn with_cluster_id_name(
111        mut self,
112        cluster_id: usize,
113        cluster_name: String,
114        spec: impl ClusterSpec<'a, D>,
115    ) -> Self {
116        self.clusters
117            .insert(cluster_id, spec.build(cluster_id, &cluster_name));
118        self
119    }
120
121    pub fn with_cluster<C>(self, cluster: &Cluster<C>, spec: impl ClusterSpec<'a, D>) -> Self {
122        self.with_cluster_id_name(cluster.id, std::any::type_name::<C>().to_string(), spec)
123    }
124
125    pub fn with_remaining_clusters<S: ClusterSpec<'a, D> + 'a>(
126        mut self,
127        spec: impl Fn() -> S,
128    ) -> Self {
129        for (id, name) in &self.cluster_id_name {
130            self.clusters.insert(*id, spec().build(*id, name));
131        }
132
133        self
134    }
135
136    /// Compiles the flow into DFIR using placeholders for the network.
137    /// Useful for generating Mermaid diagrams of the DFIR.
138    pub fn preview_compile(&self) -> CompiledFlow<'a, ()> {
139        CompiledFlow {
140            dfir: build_inner(unsafe {
141                // SAFETY: `build_inner` does not mutate the IR, &mut is required
142                // only because the shared traversal logic requires it
143                &mut *self.ir.get()
144            }),
145            _phantom: PhantomData,
146        }
147    }
148
149    pub fn compile_no_network(mut self) -> CompiledFlow<'a, D::GraphId> {
150        CompiledFlow {
151            dfir: build_inner(self.ir.get_mut()),
152            _phantom: PhantomData,
153        }
154    }
155}
156
157impl<'a, D: Deploy<'a>> DeployFlow<'a, D> {
158    pub fn compile(mut self, env: &D::CompileEnv) -> CompiledFlow<'a, D::GraphId> {
159        let mut seen_tees: HashMap<_, _> = HashMap::new();
160        let mut extra_stmts = BTreeMap::new();
161        self.ir.get_mut().iter_mut().for_each(|leaf| {
162            leaf.compile_network::<D>(
163                env,
164                &mut extra_stmts,
165                &mut seen_tees,
166                &self.processes,
167                &self.clusters,
168                &self.externals,
169            );
170        });
171
172        CompiledFlow {
173            dfir: build_inner(self.ir.get_mut()),
174            _phantom: PhantomData,
175        }
176    }
177
178    fn cluster_id_stmts(
179        &self,
180        extra_stmts: &mut BTreeMap<usize, Vec<syn::Stmt>>,
181        env: &<D as Deploy<'a>>::CompileEnv,
182    ) {
183        let mut all_clusters_sorted = self.clusters.keys().collect::<Vec<_>>();
184        all_clusters_sorted.sort();
185
186        for &c_id in all_clusters_sorted {
187            let self_id_ident = syn::Ident::new(
188                &format!("__hydro_lang_cluster_self_id_{}", c_id),
189                Span::call_site(),
190            );
191            let self_id_expr = D::cluster_self_id(env).splice_untyped();
192            extra_stmts
193                .entry(c_id)
194                .or_default()
195                .push(syn::parse_quote! {
196                    let #self_id_ident = #self_id_expr;
197                });
198
199            for other_location in self.processes.keys().chain(self.clusters.keys()) {
200                let other_id_ident = syn::Ident::new(
201                    &format!("__hydro_lang_cluster_ids_{}", c_id),
202                    Span::call_site(),
203                );
204                let other_id_expr = D::cluster_ids(env, c_id).splice_untyped();
205                extra_stmts
206                    .entry(*other_location)
207                    .or_default()
208                    .push(syn::parse_quote! {
209                        let #other_id_ident = #other_id_expr;
210                    });
211            }
212        }
213    }
214}
215
216impl<'a, D: Deploy<'a, CompileEnv = ()>> DeployFlow<'a, D> {
217    #[must_use]
218    pub fn deploy(mut self, env: &mut D::InstantiateEnv) -> DeployResult<'a, D> {
219        let mut seen_tees_instantiate: HashMap<_, _> = HashMap::new();
220        let mut extra_stmts = BTreeMap::new();
221        self.ir.get_mut().iter_mut().for_each(|leaf| {
222            leaf.compile_network::<D>(
223                &(),
224                &mut extra_stmts,
225                &mut seen_tees_instantiate,
226                &self.processes,
227                &self.clusters,
228                &self.externals,
229            );
230        });
231
232        let mut compiled = build_inner(self.ir.get_mut());
233        self.cluster_id_stmts(&mut extra_stmts, &());
234        let mut meta = D::Meta::default();
235
236        let (mut processes, mut clusters, mut externals) = (
237            std::mem::take(&mut self.processes)
238                .into_iter()
239                .filter_map(|(node_id, node)| {
240                    if let Some(ir) = compiled.remove(&node_id) {
241                        node.instantiate(
242                            env,
243                            &mut meta,
244                            ir,
245                            extra_stmts.remove(&node_id).unwrap_or_default(),
246                        );
247                        Some((node_id, node))
248                    } else {
249                        None
250                    }
251                })
252                .collect::<HashMap<_, _>>(),
253            std::mem::take(&mut self.clusters)
254                .into_iter()
255                .filter_map(|(cluster_id, cluster)| {
256                    if let Some(ir) = compiled.remove(&cluster_id) {
257                        cluster.instantiate(
258                            env,
259                            &mut meta,
260                            ir,
261                            extra_stmts.remove(&cluster_id).unwrap_or_default(),
262                        );
263                        Some((cluster_id, cluster))
264                    } else {
265                        None
266                    }
267                })
268                .collect::<HashMap<_, _>>(),
269            std::mem::take(&mut self.externals)
270                .into_iter()
271                .map(|(external_id, external)| {
272                    external.instantiate(
273                        env,
274                        &mut meta,
275                        Default::default(),
276                        extra_stmts.remove(&external_id).unwrap_or_default(),
277                    );
278                    (external_id, external)
279                })
280                .collect::<HashMap<_, _>>(),
281        );
282
283        for node in processes.values_mut() {
284            node.update_meta(&meta);
285        }
286
287        for cluster in clusters.values_mut() {
288            cluster.update_meta(&meta);
289        }
290
291        for external in externals.values_mut() {
292            external.update_meta(&meta);
293        }
294
295        let mut seen_tees_connect = HashMap::new();
296        self.ir.get_mut().iter_mut().for_each(|leaf| {
297            leaf.connect_network(&mut seen_tees_connect);
298        });
299
300        DeployResult {
301            processes,
302            clusters,
303            externals,
304            cluster_id_name: std::mem::take(&mut self.cluster_id_name)
305                .into_iter()
306                .collect(),
307            process_id_name: std::mem::take(&mut self.process_id_name)
308                .into_iter()
309                .collect(),
310        }
311    }
312}
313
314pub struct DeployResult<'a, D: Deploy<'a>> {
315    processes: HashMap<usize, D::Process>,
316    clusters: HashMap<usize, D::Cluster>,
317    externals: HashMap<usize, D::External>,
318    cluster_id_name: HashMap<usize, String>,
319    process_id_name: HashMap<usize, String>,
320}
321
322impl<'a, D: Deploy<'a>> DeployResult<'a, D> {
323    pub fn get_process<P>(&self, p: &Process<P>) -> &D::Process {
324        let id = match p.id() {
325            LocationId::Process(id) => id,
326            _ => panic!("Process ID expected"),
327        };
328
329        self.processes.get(&id).unwrap()
330    }
331
332    pub fn get_cluster<C>(&self, c: &Cluster<'a, C>) -> &D::Cluster {
333        let id = match c.id() {
334            LocationId::Cluster(id) => id,
335            _ => panic!("Cluster ID expected"),
336        };
337
338        self.clusters.get(&id).unwrap()
339    }
340
341    pub fn get_all_clusters(&self) -> impl Iterator<Item = (LocationId, String, &D::Cluster)> {
342        self.clusters.iter().map(|(&id, c)| {
343            (
344                LocationId::Cluster(id),
345                self.cluster_id_name.get(&id).unwrap().clone(),
346                c,
347            )
348        })
349    }
350
351    pub fn get_all_processes(&self) -> impl Iterator<Item = (LocationId, String, &D::Process)> {
352        self.processes.iter().map(|(&id, p)| {
353            (
354                LocationId::Process(id),
355                self.process_id_name.get(&id).unwrap().clone(),
356                p,
357            )
358        })
359    }
360
361    pub fn get_external<P>(&self, p: &External<P>) -> &D::External {
362        self.externals.get(&p.id).unwrap()
363    }
364
365    pub fn raw_port<M>(&self, port: ExternalBytesPort<M>) -> D::ExternalRawPort {
366        self.externals
367            .get(&port.process_id)
368            .unwrap()
369            .raw_port(port.port_id)
370    }
371
372    pub async fn connect_bytes<M>(
373        &self,
374        port: ExternalBytesPort<M>,
375    ) -> (
376        Pin<Box<dyn Stream<Item = Result<BytesMut, Error>>>>,
377        Pin<Box<dyn Sink<Bytes, Error = Error>>>,
378    ) {
379        self.externals
380            .get(&port.process_id)
381            .unwrap()
382            .as_bytes_bidi(port.port_id)
383            .await
384    }
385
386    pub async fn connect_sink_bytes<M>(
387        &self,
388        port: ExternalBytesPort<M>,
389    ) -> Pin<Box<dyn Sink<Bytes, Error = Error>>> {
390        self.connect_bytes(port).await.1
391    }
392
393    pub async fn connect_bincode<
394        InT: Serialize + 'static,
395        OutT: DeserializeOwned + 'static,
396        Many,
397    >(
398        &self,
399        port: ExternalBincodeBidi<InT, OutT, Many>,
400    ) -> (
401        Pin<Box<dyn Stream<Item = OutT>>>,
402        Pin<Box<dyn Sink<InT, Error = Error>>>,
403    ) {
404        self.externals
405            .get(&port.process_id)
406            .unwrap()
407            .as_bincode_bidi(port.port_id)
408            .await
409    }
410
411    pub async fn connect_sink_bincode<T: Serialize + DeserializeOwned + 'static, Many>(
412        &self,
413        port: ExternalBincodeSink<T, Many>,
414    ) -> Pin<Box<dyn Sink<T, Error = Error>>> {
415        self.externals
416            .get(&port.process_id)
417            .unwrap()
418            .as_bincode_sink(port.port_id)
419            .await
420    }
421
422    pub async fn connect_source_bytes(
423        &self,
424        port: ExternalBytesPort,
425    ) -> Pin<Box<dyn Stream<Item = Result<BytesMut, Error>>>> {
426        self.connect_bytes(port).await.0
427    }
428
429    pub async fn connect_source_bincode<T: Serialize + DeserializeOwned + 'static>(
430        &self,
431        port: ExternalBincodeStream<T>,
432    ) -> Pin<Box<dyn Stream<Item = T>>> {
433        self.externals
434            .get(&port.process_id)
435            .unwrap()
436            .as_bincode_source(port.port_id)
437            .await
438    }
439}