1use std::cell::UnsafeCell;
2use std::collections::{BTreeMap, HashMap};
3use std::io::Error;
4use std::marker::PhantomData;
5use std::pin::Pin;
6
7use dfir_rs::bytes::Bytes;
8use dfir_rs::futures::{Sink, Stream};
9use proc_macro2::Span;
10use serde::Serialize;
11use serde::de::DeserializeOwned;
12use stageleft::QuotedWithContext;
13
14use super::built::build_inner;
15use super::compiled::CompiledFlow;
16use crate::deploy::{
17 ClusterSpec, Deploy, ExternalSpec, IntoProcessSpec, LocalDeploy, Node, ProcessSpec,
18 RegisterPort,
19};
20use crate::ir::HydroLeaf;
21use crate::location::external_process::{
22 ExternalBincodeSink, ExternalBincodeStream, ExternalBytesPort,
23};
24use crate::location::{Cluster, ExternalProcess, Location, LocationId, Process};
25use crate::staging_util::Invariant;
26
27pub struct DeployFlow<'a, D: LocalDeploy<'a>> {
28 pub(super) ir: UnsafeCell<Vec<HydroLeaf>>,
32
33 pub(super) processes: HashMap<usize, D::Process>,
35
36 pub(super) process_id_name: Vec<(usize, String)>,
39
40 pub(super) externals: HashMap<usize, D::ExternalProcess>,
41 pub(super) external_id_name: Vec<(usize, String)>,
42
43 pub(super) clusters: HashMap<usize, D::Cluster>,
44 pub(super) cluster_id_name: Vec<(usize, String)>,
45 pub(super) used: bool,
46
47 pub(super) _phantom: Invariant<'a, D>,
48}
49
50impl<'a, D: LocalDeploy<'a>> Drop for DeployFlow<'a, D> {
51 fn drop(&mut self) {
52 if !self.used {
53 panic!(
54 "Dropped DeployFlow without instantiating, you may have forgotten to call `compile` or `deploy`."
55 );
56 }
57 }
58}
59
60impl<'a, D: LocalDeploy<'a>> DeployFlow<'a, D> {
61 pub fn ir(&self) -> &Vec<HydroLeaf> {
62 unsafe {
63 &*self.ir.get()
65 }
66 }
67
68 pub fn with_process<P>(
69 mut self,
70 process: &Process<P>,
71 spec: impl IntoProcessSpec<'a, D>,
72 ) -> Self {
73 let tag_name = std::any::type_name::<P>().to_string();
74 self.processes.insert(
75 process.id,
76 spec.into_process_spec().build(process.id, &tag_name),
77 );
78 self
79 }
80
81 pub fn with_remaining_processes<S: IntoProcessSpec<'a, D> + 'a>(
82 mut self,
83 spec: impl Fn() -> S,
84 ) -> Self {
85 for (id, name) in &self.process_id_name {
86 self.processes
87 .insert(*id, spec().into_process_spec().build(*id, name));
88 }
89
90 self
91 }
92
93 pub fn with_external<P>(
94 mut self,
95 process: &ExternalProcess<P>,
96 spec: impl ExternalSpec<'a, D>,
97 ) -> Self {
98 let tag_name = std::any::type_name::<P>().to_string();
99 self.externals
100 .insert(process.id, spec.build(process.id, &tag_name));
101 self
102 }
103
104 pub fn with_remaining_externals<S: ExternalSpec<'a, D> + 'a>(
105 mut self,
106 spec: impl Fn() -> S,
107 ) -> Self {
108 for (id, name) in &self.external_id_name {
109 self.externals.insert(*id, spec().build(*id, name));
110 }
111
112 self
113 }
114
115 pub fn with_cluster<C>(mut self, cluster: &Cluster<C>, spec: impl ClusterSpec<'a, D>) -> Self {
116 let tag_name = std::any::type_name::<C>().to_string();
117 self.clusters
118 .insert(cluster.id, spec.build(cluster.id, &tag_name));
119 self
120 }
121
122 pub fn with_remaining_clusters<S: ClusterSpec<'a, D> + 'a>(
123 mut self,
124 spec: impl Fn() -> S,
125 ) -> Self {
126 for (id, name) in &self.cluster_id_name {
127 self.clusters.insert(*id, spec().build(*id, name));
128 }
129
130 self
131 }
132
133 pub fn preview_compile(&self) -> CompiledFlow<'a, ()> {
136 CompiledFlow {
137 dfir: build_inner(unsafe {
138 &mut *self.ir.get()
141 }),
142 extra_stmts: BTreeMap::new(),
143 _phantom: PhantomData,
144 }
145 }
146
147 pub fn compile_no_network(mut self) -> CompiledFlow<'a, D::GraphId> {
148 self.used = true;
149
150 CompiledFlow {
151 dfir: build_inner(self.ir.get_mut()),
152 extra_stmts: BTreeMap::new(),
153 _phantom: PhantomData,
154 }
155 }
156}
157
158impl<'a, D: Deploy<'a>> DeployFlow<'a, D> {
159 pub fn compile(mut self, env: &D::CompileEnv) -> CompiledFlow<'a, D::GraphId> {
160 self.used = true;
161
162 let mut seen_tees: HashMap<_, _> = HashMap::new();
163 let mut seen_tee_locations: HashMap<_, _> = HashMap::new();
164 self.ir.get_mut().iter_mut().for_each(|leaf| {
165 leaf.compile_network::<D>(
166 env,
167 &mut seen_tees,
168 &mut seen_tee_locations,
169 &self.processes,
170 &self.clusters,
171 &self.externals,
172 );
173 });
174
175 let extra_stmts = self.extra_stmts(env);
176
177 CompiledFlow {
178 dfir: build_inner(self.ir.get_mut()),
179 extra_stmts,
180 _phantom: PhantomData,
181 }
182 }
183
184 fn extra_stmts(&self, env: &<D as Deploy<'a>>::CompileEnv) -> BTreeMap<usize, Vec<syn::Stmt>> {
185 let mut extra_stmts: BTreeMap<usize, Vec<syn::Stmt>> = BTreeMap::new();
186
187 let mut all_clusters_sorted = self.clusters.keys().collect::<Vec<_>>();
188 all_clusters_sorted.sort();
189
190 for &c_id in all_clusters_sorted {
191 let self_id_ident = syn::Ident::new(
192 &format!("__hydro_lang_cluster_self_id_{}", c_id),
193 Span::call_site(),
194 );
195 let self_id_expr = D::cluster_self_id(env).splice_untyped();
196 extra_stmts
197 .entry(c_id)
198 .or_default()
199 .push(syn::parse_quote! {
200 let #self_id_ident = #self_id_expr;
201 });
202
203 for other_location in self.processes.keys().chain(self.clusters.keys()) {
204 let other_id_ident = syn::Ident::new(
205 &format!("__hydro_lang_cluster_ids_{}", c_id),
206 Span::call_site(),
207 );
208 let other_id_expr = D::cluster_ids(env, c_id).splice_untyped();
209 extra_stmts
210 .entry(*other_location)
211 .or_default()
212 .push(syn::parse_quote! {
213 let #other_id_ident = #other_id_expr;
214 });
215 }
216 }
217 extra_stmts
218 }
219}
220
221impl<'a, D: Deploy<'a, CompileEnv = ()>> DeployFlow<'a, D> {
222 #[must_use]
223 pub fn deploy(mut self, env: &mut D::InstantiateEnv) -> DeployResult<'a, D> {
224 self.used = true;
225
226 let mut seen_tees_instantiate: HashMap<_, _> = HashMap::new();
227 let mut seen_tee_locations: HashMap<_, _> = HashMap::new();
228 self.ir.get_mut().iter_mut().for_each(|leaf| {
229 leaf.compile_network::<D>(
230 &(),
231 &mut seen_tees_instantiate,
232 &mut seen_tee_locations,
233 &self.processes,
234 &self.clusters,
235 &self.externals,
236 );
237 });
238
239 let mut compiled = build_inner(self.ir.get_mut());
240 let mut extra_stmts = self.extra_stmts(&());
241 let mut meta = D::Meta::default();
242
243 let (mut processes, mut clusters, mut externals) = (
244 std::mem::take(&mut self.processes)
245 .into_iter()
246 .filter_map(|(node_id, node)| {
247 if let Some(ir) = compiled.remove(&node_id) {
248 node.instantiate(
249 env,
250 &mut meta,
251 ir,
252 extra_stmts.remove(&node_id).unwrap_or_default(),
253 );
254 Some((node_id, node))
255 } else {
256 None
257 }
258 })
259 .collect::<HashMap<_, _>>(),
260 std::mem::take(&mut self.clusters)
261 .into_iter()
262 .filter_map(|(cluster_id, cluster)| {
263 if let Some(ir) = compiled.remove(&cluster_id) {
264 cluster.instantiate(
265 env,
266 &mut meta,
267 ir,
268 extra_stmts.remove(&cluster_id).unwrap_or_default(),
269 );
270 Some((cluster_id, cluster))
271 } else {
272 None
273 }
274 })
275 .collect::<HashMap<_, _>>(),
276 std::mem::take(&mut self.externals)
277 .into_iter()
278 .map(|(external_id, external)| {
279 external.instantiate(
280 env,
281 &mut meta,
282 compiled.remove(&external_id).unwrap(),
283 extra_stmts.remove(&external_id).unwrap_or_default(),
284 );
285 (external_id, external)
286 })
287 .collect::<HashMap<_, _>>(),
288 );
289
290 for node in processes.values_mut() {
291 node.update_meta(&meta);
292 }
293
294 for cluster in clusters.values_mut() {
295 cluster.update_meta(&meta);
296 }
297
298 for external in externals.values_mut() {
299 external.update_meta(&meta);
300 }
301
302 let mut seen_tees_connect = HashMap::new();
303 self.ir.get_mut().iter_mut().for_each(|leaf| {
304 leaf.connect_network(&mut seen_tees_connect);
305 });
306
307 DeployResult {
308 processes,
309 clusters,
310 externals,
311 cluster_id_name: std::mem::take(&mut self.cluster_id_name)
312 .into_iter()
313 .collect(),
314 }
315 }
316}
317
318pub struct DeployResult<'a, D: Deploy<'a>> {
319 processes: HashMap<usize, D::Process>,
320 clusters: HashMap<usize, D::Cluster>,
321 externals: HashMap<usize, D::ExternalProcess>,
322 cluster_id_name: HashMap<usize, String>,
323}
324
325impl<'a, D: Deploy<'a>> DeployResult<'a, D> {
326 pub fn get_process<P>(&self, p: &Process<P>) -> &D::Process {
327 let id = match p.id() {
328 LocationId::Process(id) => id,
329 _ => panic!("Process ID expected"),
330 };
331
332 self.processes.get(&id).unwrap()
333 }
334
335 pub fn get_cluster<C>(&self, c: &Cluster<'a, C>) -> &D::Cluster {
336 let id = match c.id() {
337 LocationId::Cluster(id) => id,
338 _ => panic!("Cluster ID expected"),
339 };
340
341 self.clusters.get(&id).unwrap()
342 }
343
344 pub fn get_all_clusters(&self) -> impl Iterator<Item = (LocationId, String, &D::Cluster)> {
345 self.clusters.iter().map(|(&id, c)| {
346 (
347 LocationId::Cluster(id),
348 self.cluster_id_name.get(&id).unwrap().clone(),
349 c,
350 )
351 })
352 }
353
354 pub fn get_external<P>(&self, p: &ExternalProcess<P>) -> &D::ExternalProcess {
355 self.externals.get(&p.id).unwrap()
356 }
357
358 pub fn raw_port(&self, port: ExternalBytesPort) -> D::ExternalRawPort {
359 self.externals
360 .get(&port.process_id)
361 .unwrap()
362 .raw_port(port.port_id)
363 }
364
365 pub async fn connect_sink_bytes(
366 &self,
367 port: ExternalBytesPort,
368 ) -> Pin<Box<dyn Sink<Bytes, Error = Error>>> {
369 self.externals
370 .get(&port.process_id)
371 .unwrap()
372 .as_bytes_sink(port.port_id)
373 .await
374 }
375
376 pub async fn connect_sink_bincode<T: Serialize + DeserializeOwned + 'static>(
377 &self,
378 port: ExternalBincodeSink<T>,
379 ) -> Pin<Box<dyn Sink<T, Error = Error>>> {
380 self.externals
381 .get(&port.process_id)
382 .unwrap()
383 .as_bincode_sink(port.port_id)
384 .await
385 }
386
387 pub async fn connect_source_bytes(
388 &self,
389 port: ExternalBytesPort,
390 ) -> Pin<Box<dyn Stream<Item = Bytes>>> {
391 self.externals
392 .get(&port.process_id)
393 .unwrap()
394 .as_bytes_source(port.port_id)
395 .await
396 }
397
398 pub async fn connect_source_bincode<T: Serialize + DeserializeOwned + 'static>(
399 &self,
400 port: ExternalBincodeStream<T>,
401 ) -> Pin<Box<dyn Stream<Item = T>>> {
402 self.externals
403 .get(&port.process_id)
404 .unwrap()
405 .as_bincode_source(port.port_id)
406 .await
407 }
408}