hydro_lang/compile/
deploy.rs

1use std::cell::UnsafeCell;
2use std::collections::{BTreeMap, HashMap};
3use std::io::Error;
4use std::marker::PhantomData;
5use std::pin::Pin;
6
7use bytes::{Bytes, BytesMut};
8use futures::{Sink, Stream};
9use proc_macro2::Span;
10use serde::Serialize;
11use serde::de::DeserializeOwned;
12use stageleft::QuotedWithContext;
13
14use super::built::build_inner;
15use super::compiled::CompiledFlow;
16use super::deploy_provider::{
17    ClusterSpec, Deploy, ExternalSpec, IntoProcessSpec, Node, ProcessSpec, RegisterPort,
18};
19use super::ir::HydroRoot;
20use crate::location::dynamic::LocationId;
21use crate::location::external_process::{
22    ExternalBincodeBidi, ExternalBincodeSink, ExternalBincodeStream, ExternalBytesPort,
23};
24use crate::location::{Cluster, External, Location, Process};
25use crate::staging_util::Invariant;
26
27pub struct DeployFlow<'a, D>
28where
29    D: Deploy<'a>,
30{
31    // We need to grab an `&mut` reference to the IR in `preview_compile` even though
32    // that function does not modify the IR. Using an `UnsafeCell` allows us to do this
33    // while still being able to lend out immutable references to the IR.
34    pub(super) ir: UnsafeCell<Vec<HydroRoot>>,
35
36    /// Deployed instances of each process in the flow
37    pub(super) processes: HashMap<usize, D::Process>,
38
39    /// Lists all the processes that were created in the flow, same ID as `processes`
40    /// but with the type name of the tag.
41    pub(super) process_id_name: Vec<(usize, String)>,
42
43    pub(super) externals: HashMap<usize, D::External>,
44    pub(super) external_id_name: Vec<(usize, String)>,
45
46    pub(super) clusters: HashMap<usize, D::Cluster>,
47    pub(super) cluster_id_name: Vec<(usize, String)>,
48
49    pub(super) _phantom: Invariant<'a, D>,
50}
51
52impl<'a, D: Deploy<'a>> DeployFlow<'a, D> {
53    pub fn ir(&self) -> &Vec<HydroRoot> {
54        unsafe {
55            // SAFETY: even when we grab this as mutable in `preview_compile`, we do not modify it
56            &*self.ir.get()
57        }
58    }
59
60    pub fn with_process_id_name(
61        mut self,
62        process_id: usize,
63        process_name: String,
64        spec: impl IntoProcessSpec<'a, D>,
65    ) -> Self {
66        self.processes.insert(
67            process_id,
68            spec.into_process_spec().build(process_id, &process_name),
69        );
70        self
71    }
72
73    pub fn with_process<P>(self, process: &Process<P>, spec: impl IntoProcessSpec<'a, D>) -> Self {
74        self.with_process_id_name(process.id, std::any::type_name::<P>().to_string(), spec)
75    }
76
77    pub fn with_remaining_processes<S: IntoProcessSpec<'a, D> + 'a>(
78        mut self,
79        spec: impl Fn() -> S,
80    ) -> Self {
81        for (id, name) in &self.process_id_name {
82            self.processes
83                .insert(*id, spec().into_process_spec().build(*id, name));
84        }
85
86        self
87    }
88
89    pub fn with_external<P>(
90        mut self,
91        process: &External<P>,
92        spec: impl ExternalSpec<'a, D>,
93    ) -> Self {
94        let tag_name = std::any::type_name::<P>().to_string();
95        self.externals
96            .insert(process.id, spec.build(process.id, &tag_name));
97        self
98    }
99
100    pub fn with_remaining_externals<S: ExternalSpec<'a, D> + 'a>(
101        mut self,
102        spec: impl Fn() -> S,
103    ) -> Self {
104        for (id, name) in &self.external_id_name {
105            self.externals.insert(*id, spec().build(*id, name));
106        }
107
108        self
109    }
110
111    pub fn with_cluster_id_name(
112        mut self,
113        cluster_id: usize,
114        cluster_name: String,
115        spec: impl ClusterSpec<'a, D>,
116    ) -> Self {
117        self.clusters
118            .insert(cluster_id, spec.build(cluster_id, &cluster_name));
119        self
120    }
121
122    pub fn with_cluster<C>(self, cluster: &Cluster<C>, spec: impl ClusterSpec<'a, D>) -> Self {
123        self.with_cluster_id_name(cluster.id, std::any::type_name::<C>().to_string(), spec)
124    }
125
126    pub fn with_remaining_clusters<S: ClusterSpec<'a, D> + 'a>(
127        mut self,
128        spec: impl Fn() -> S,
129    ) -> Self {
130        for (id, name) in &self.cluster_id_name {
131            self.clusters.insert(*id, spec().build(*id, name));
132        }
133
134        self
135    }
136
137    /// Compiles the flow into DFIR using placeholders for the network.
138    /// Useful for generating Mermaid diagrams of the DFIR.
139    pub fn preview_compile(&self) -> CompiledFlow<'a, ()> {
140        CompiledFlow {
141            dfir: build_inner(unsafe {
142                // SAFETY: `build_inner` does not mutate the IR, &mut is required
143                // only because the shared traversal logic requires it
144                &mut *self.ir.get()
145            }),
146            _phantom: PhantomData,
147        }
148    }
149
150    pub fn compile_no_network(mut self) -> CompiledFlow<'a, D::GraphId> {
151        CompiledFlow {
152            dfir: build_inner(self.ir.get_mut()),
153            _phantom: PhantomData,
154        }
155    }
156}
157
158impl<'a, D: Deploy<'a>> DeployFlow<'a, D> {
159    pub fn compile(mut self, env: &D::CompileEnv) -> CompiledFlow<'a, D::GraphId> {
160        let mut seen_tees: HashMap<_, _> = HashMap::new();
161        let mut extra_stmts = BTreeMap::new();
162        self.ir.get_mut().iter_mut().for_each(|leaf| {
163            leaf.compile_network::<D>(
164                env,
165                &mut extra_stmts,
166                &mut seen_tees,
167                &self.processes,
168                &self.clusters,
169                &self.externals,
170            );
171        });
172
173        CompiledFlow {
174            dfir: build_inner(self.ir.get_mut()),
175            _phantom: PhantomData,
176        }
177    }
178
179    fn cluster_id_stmts(
180        &self,
181        extra_stmts: &mut BTreeMap<usize, Vec<syn::Stmt>>,
182        env: &<D as Deploy<'a>>::CompileEnv,
183    ) {
184        let mut all_clusters_sorted = self.clusters.keys().collect::<Vec<_>>();
185        all_clusters_sorted.sort();
186
187        for &c_id in all_clusters_sorted {
188            let self_id_ident = syn::Ident::new(
189                &format!("__hydro_lang_cluster_self_id_{}", c_id),
190                Span::call_site(),
191            );
192            let self_id_expr = D::cluster_self_id(env).splice_untyped();
193            extra_stmts
194                .entry(c_id)
195                .or_default()
196                .push(syn::parse_quote! {
197                    let #self_id_ident = #self_id_expr;
198                });
199
200            for other_location in self.processes.keys().chain(self.clusters.keys()) {
201                let other_id_ident = syn::Ident::new(
202                    &format!("__hydro_lang_cluster_ids_{}", c_id),
203                    Span::call_site(),
204                );
205                let other_id_expr = D::cluster_ids(env, c_id).splice_untyped();
206                extra_stmts
207                    .entry(*other_location)
208                    .or_default()
209                    .push(syn::parse_quote! {
210                        let #other_id_ident = #other_id_expr;
211                    });
212            }
213        }
214    }
215}
216
217impl<'a, D: Deploy<'a, CompileEnv = ()>> DeployFlow<'a, D> {
218    #[must_use]
219    pub fn deploy(mut self, env: &mut D::InstantiateEnv) -> DeployResult<'a, D> {
220        let mut seen_tees_instantiate: HashMap<_, _> = HashMap::new();
221        let mut extra_stmts = BTreeMap::new();
222        self.ir.get_mut().iter_mut().for_each(|leaf| {
223            leaf.compile_network::<D>(
224                &(),
225                &mut extra_stmts,
226                &mut seen_tees_instantiate,
227                &self.processes,
228                &self.clusters,
229                &self.externals,
230            );
231        });
232
233        let mut compiled = build_inner(self.ir.get_mut());
234        self.cluster_id_stmts(&mut extra_stmts, &());
235        let mut meta = D::Meta::default();
236
237        let (mut processes, mut clusters, mut externals) = (
238            std::mem::take(&mut self.processes)
239                .into_iter()
240                .filter_map(|(node_id, node)| {
241                    if let Some(ir) = compiled.remove(&node_id) {
242                        node.instantiate(
243                            env,
244                            &mut meta,
245                            ir,
246                            extra_stmts.remove(&node_id).unwrap_or_default(),
247                        );
248                        Some((node_id, node))
249                    } else {
250                        None
251                    }
252                })
253                .collect::<HashMap<_, _>>(),
254            std::mem::take(&mut self.clusters)
255                .into_iter()
256                .filter_map(|(cluster_id, cluster)| {
257                    if let Some(ir) = compiled.remove(&cluster_id) {
258                        cluster.instantiate(
259                            env,
260                            &mut meta,
261                            ir,
262                            extra_stmts.remove(&cluster_id).unwrap_or_default(),
263                        );
264                        Some((cluster_id, cluster))
265                    } else {
266                        None
267                    }
268                })
269                .collect::<HashMap<_, _>>(),
270            std::mem::take(&mut self.externals)
271                .into_iter()
272                .map(|(external_id, external)| {
273                    external.instantiate(
274                        env,
275                        &mut meta,
276                        Default::default(),
277                        extra_stmts.remove(&external_id).unwrap_or_default(),
278                    );
279                    (external_id, external)
280                })
281                .collect::<HashMap<_, _>>(),
282        );
283
284        for node in processes.values_mut() {
285            node.update_meta(&meta);
286        }
287
288        for cluster in clusters.values_mut() {
289            cluster.update_meta(&meta);
290        }
291
292        for external in externals.values_mut() {
293            external.update_meta(&meta);
294        }
295
296        let mut seen_tees_connect = HashMap::new();
297        self.ir.get_mut().iter_mut().for_each(|leaf| {
298            leaf.connect_network(&mut seen_tees_connect);
299        });
300
301        DeployResult {
302            processes,
303            clusters,
304            externals,
305            cluster_id_name: std::mem::take(&mut self.cluster_id_name)
306                .into_iter()
307                .collect(),
308            process_id_name: std::mem::take(&mut self.process_id_name)
309                .into_iter()
310                .collect(),
311        }
312    }
313}
314
315pub struct DeployResult<'a, D: Deploy<'a>> {
316    processes: HashMap<usize, D::Process>,
317    clusters: HashMap<usize, D::Cluster>,
318    externals: HashMap<usize, D::External>,
319    cluster_id_name: HashMap<usize, String>,
320    process_id_name: HashMap<usize, String>,
321}
322
323impl<'a, D: Deploy<'a>> DeployResult<'a, D> {
324    pub fn get_process<P>(&self, p: &Process<P>) -> &D::Process {
325        let id = match p.id() {
326            LocationId::Process(id) => id,
327            _ => panic!("Process ID expected"),
328        };
329
330        self.processes.get(&id).unwrap()
331    }
332
333    pub fn get_cluster<C>(&self, c: &Cluster<'a, C>) -> &D::Cluster {
334        let id = match c.id() {
335            LocationId::Cluster(id) => id,
336            _ => panic!("Cluster ID expected"),
337        };
338
339        self.clusters.get(&id).unwrap()
340    }
341
342    pub fn get_all_clusters(&self) -> impl Iterator<Item = (LocationId, String, &D::Cluster)> {
343        self.clusters.iter().map(|(&id, c)| {
344            (
345                LocationId::Cluster(id),
346                self.cluster_id_name.get(&id).unwrap().clone(),
347                c,
348            )
349        })
350    }
351
352    pub fn get_all_processes(&self) -> impl Iterator<Item = (LocationId, String, &D::Process)> {
353        self.processes.iter().map(|(&id, p)| {
354            (
355                LocationId::Process(id),
356                self.process_id_name.get(&id).unwrap().clone(),
357                p,
358            )
359        })
360    }
361
362    pub fn get_external<P>(&self, p: &External<P>) -> &D::External {
363        self.externals.get(&p.id).unwrap()
364    }
365
366    pub fn raw_port<M>(&self, port: ExternalBytesPort<M>) -> D::ExternalRawPort {
367        self.externals
368            .get(&port.process_id)
369            .unwrap()
370            .raw_port(port.port_id)
371    }
372
373    pub async fn connect_bytes<M>(
374        &self,
375        port: ExternalBytesPort<M>,
376    ) -> (
377        Pin<Box<dyn Stream<Item = Result<BytesMut, Error>>>>,
378        Pin<Box<dyn Sink<Bytes, Error = Error>>>,
379    ) {
380        self.externals
381            .get(&port.process_id)
382            .unwrap()
383            .as_bytes_bidi(port.port_id)
384            .await
385    }
386
387    pub async fn connect_sink_bytes<M>(
388        &self,
389        port: ExternalBytesPort<M>,
390    ) -> Pin<Box<dyn Sink<Bytes, Error = Error>>> {
391        self.connect_bytes(port).await.1
392    }
393
394    pub async fn connect_bincode<
395        InT: Serialize + 'static,
396        OutT: DeserializeOwned + 'static,
397        Many,
398    >(
399        &self,
400        port: ExternalBincodeBidi<InT, OutT, Many>,
401    ) -> (
402        Pin<Box<dyn Stream<Item = OutT>>>,
403        Pin<Box<dyn Sink<InT, Error = Error>>>,
404    ) {
405        self.externals
406            .get(&port.process_id)
407            .unwrap()
408            .as_bincode_bidi(port.port_id)
409            .await
410    }
411
412    pub async fn connect_sink_bincode<T: Serialize + DeserializeOwned + 'static, Many>(
413        &self,
414        port: ExternalBincodeSink<T, Many>,
415    ) -> Pin<Box<dyn Sink<T, Error = Error>>> {
416        self.externals
417            .get(&port.process_id)
418            .unwrap()
419            .as_bincode_sink(port.port_id)
420            .await
421    }
422
423    pub async fn connect_source_bytes(
424        &self,
425        port: ExternalBytesPort,
426    ) -> Pin<Box<dyn Stream<Item = Result<BytesMut, Error>>>> {
427        self.connect_bytes(port).await.0
428    }
429
430    pub async fn connect_source_bincode<T: Serialize + DeserializeOwned + 'static>(
431        &self,
432        port: ExternalBincodeStream<T>,
433    ) -> Pin<Box<dyn Stream<Item = T>>> {
434        self.externals
435            .get(&port.process_id)
436            .unwrap()
437            .as_bincode_source(port.port_id)
438            .await
439    }
440}