1use std::cell::UnsafeCell;
2use std::collections::{BTreeMap, HashMap};
3use std::io::Error;
4use std::marker::PhantomData;
5use std::pin::Pin;
6
7use bytes::Bytes;
8use futures::{Sink, Stream};
9use proc_macro2::Span;
10use serde::Serialize;
11use serde::de::DeserializeOwned;
12use stageleft::QuotedWithContext;
13
14use super::built::build_inner;
15use super::compiled::CompiledFlow;
16use crate::deploy::{
17 ClusterSpec, Deploy, ExternalSpec, IntoProcessSpec, LocalDeploy, Node, ProcessSpec,
18 RegisterPort,
19};
20use crate::ir::HydroLeaf;
21use crate::location::external_process::{
22 ExternalBincodeSink, ExternalBincodeStream, ExternalBytesPort,
23};
24use crate::location::{Cluster, ExternalProcess, Location, LocationId, Process};
25use crate::staging_util::Invariant;
26
27pub struct DeployFlow<'a, D>
28where
29 D: LocalDeploy<'a>,
30{
31 pub(super) ir: UnsafeCell<Vec<HydroLeaf>>,
35
36 pub(super) processes: HashMap<usize, D::Process>,
38
39 pub(super) process_id_name: Vec<(usize, String)>,
42
43 pub(super) externals: HashMap<usize, D::ExternalProcess>,
44 pub(super) external_id_name: Vec<(usize, String)>,
45
46 pub(super) clusters: HashMap<usize, D::Cluster>,
47 pub(super) cluster_id_name: Vec<(usize, String)>,
48 pub(super) used: bool,
49
50 pub(super) _phantom: Invariant<'a, D>,
51}
52
53impl<'a, D: LocalDeploy<'a>> Drop for DeployFlow<'a, D> {
54 fn drop(&mut self) {
55 if !self.used {
56 panic!(
57 "Dropped DeployFlow without instantiating, you may have forgotten to call `compile` or `deploy`."
58 );
59 }
60 }
61}
62
63impl<'a, D: LocalDeploy<'a>> DeployFlow<'a, D> {
64 pub fn ir(&self) -> &Vec<HydroLeaf> {
65 unsafe {
66 &*self.ir.get()
68 }
69 }
70
71 pub fn with_process<P>(
72 mut self,
73 process: &Process<P>,
74 spec: impl IntoProcessSpec<'a, D>,
75 ) -> Self {
76 let tag_name = std::any::type_name::<P>().to_string();
77 self.processes.insert(
78 process.id,
79 spec.into_process_spec().build(process.id, &tag_name),
80 );
81 self
82 }
83
84 pub fn with_remaining_processes<S: IntoProcessSpec<'a, D> + 'a>(
85 mut self,
86 spec: impl Fn() -> S,
87 ) -> Self {
88 for (id, name) in &self.process_id_name {
89 self.processes
90 .insert(*id, spec().into_process_spec().build(*id, name));
91 }
92
93 self
94 }
95
96 pub fn with_external<P>(
97 mut self,
98 process: &ExternalProcess<P>,
99 spec: impl ExternalSpec<'a, D>,
100 ) -> Self {
101 let tag_name = std::any::type_name::<P>().to_string();
102 self.externals
103 .insert(process.id, spec.build(process.id, &tag_name));
104 self
105 }
106
107 pub fn with_remaining_externals<S: ExternalSpec<'a, D> + 'a>(
108 mut self,
109 spec: impl Fn() -> S,
110 ) -> Self {
111 for (id, name) in &self.external_id_name {
112 self.externals.insert(*id, spec().build(*id, name));
113 }
114
115 self
116 }
117
118 pub fn with_cluster<C>(mut self, cluster: &Cluster<C>, spec: impl ClusterSpec<'a, D>) -> Self {
119 let tag_name = std::any::type_name::<C>().to_string();
120 self.clusters
121 .insert(cluster.id, spec.build(cluster.id, &tag_name));
122 self
123 }
124
125 pub fn with_remaining_clusters<S: ClusterSpec<'a, D> + 'a>(
126 mut self,
127 spec: impl Fn() -> S,
128 ) -> Self {
129 for (id, name) in &self.cluster_id_name {
130 self.clusters.insert(*id, spec().build(*id, name));
131 }
132
133 self
134 }
135
136 pub fn preview_compile(&self) -> CompiledFlow<'a, ()> {
139 CompiledFlow {
140 dfir: build_inner(unsafe {
141 &mut *self.ir.get()
144 }),
145 #[cfg(feature = "staged_macro")]
146 extra_stmts: BTreeMap::new(),
147 _phantom: PhantomData,
148 }
149 }
150
151 pub fn compile_no_network(mut self) -> CompiledFlow<'a, D::GraphId> {
152 self.used = true;
153
154 CompiledFlow {
155 dfir: build_inner(self.ir.get_mut()),
156 #[cfg(feature = "staged_macro")]
157 extra_stmts: BTreeMap::new(),
158 _phantom: PhantomData,
159 }
160 }
161}
162
163impl<'a, D: Deploy<'a>> DeployFlow<'a, D> {
164 pub fn compile(mut self, env: &D::CompileEnv) -> CompiledFlow<'a, D::GraphId> {
165 self.used = true;
166
167 let mut seen_tees: HashMap<_, _> = HashMap::new();
168 let mut seen_tee_locations: HashMap<_, _> = HashMap::new();
169 self.ir.get_mut().iter_mut().for_each(|leaf| {
170 leaf.compile_network::<D>(
171 env,
172 &mut seen_tees,
173 &mut seen_tee_locations,
174 &self.processes,
175 &self.clusters,
176 &self.externals,
177 );
178 });
179
180 #[cfg(feature = "staged_macro")]
181 let extra_stmts = self.extra_stmts(env);
182
183 CompiledFlow {
184 dfir: build_inner(self.ir.get_mut()),
185 #[cfg(feature = "staged_macro")]
186 extra_stmts,
187 _phantom: PhantomData,
188 }
189 }
190
191 fn extra_stmts(&self, env: &<D as Deploy<'a>>::CompileEnv) -> BTreeMap<usize, Vec<syn::Stmt>> {
192 let mut extra_stmts: BTreeMap<usize, Vec<syn::Stmt>> = BTreeMap::new();
193
194 let mut all_clusters_sorted = self.clusters.keys().collect::<Vec<_>>();
195 all_clusters_sorted.sort();
196
197 for &c_id in all_clusters_sorted {
198 let self_id_ident = syn::Ident::new(
199 &format!("__hydro_lang_cluster_self_id_{}", c_id),
200 Span::call_site(),
201 );
202 let self_id_expr = D::cluster_self_id(env).splice_untyped();
203 extra_stmts
204 .entry(c_id)
205 .or_default()
206 .push(syn::parse_quote! {
207 let #self_id_ident = #self_id_expr;
208 });
209
210 for other_location in self.processes.keys().chain(self.clusters.keys()) {
211 let other_id_ident = syn::Ident::new(
212 &format!("__hydro_lang_cluster_ids_{}", c_id),
213 Span::call_site(),
214 );
215 let other_id_expr = D::cluster_ids(env, c_id).splice_untyped();
216 extra_stmts
217 .entry(*other_location)
218 .or_default()
219 .push(syn::parse_quote! {
220 let #other_id_ident = #other_id_expr;
221 });
222 }
223 }
224 extra_stmts
225 }
226}
227
228impl<'a, D: Deploy<'a, CompileEnv = ()>> DeployFlow<'a, D> {
229 #[must_use]
230 pub fn deploy(mut self, env: &mut D::InstantiateEnv) -> DeployResult<'a, D> {
231 self.used = true;
232
233 let mut seen_tees_instantiate: HashMap<_, _> = HashMap::new();
234 let mut seen_tee_locations: HashMap<_, _> = HashMap::new();
235 self.ir.get_mut().iter_mut().for_each(|leaf| {
236 leaf.compile_network::<D>(
237 &(),
238 &mut seen_tees_instantiate,
239 &mut seen_tee_locations,
240 &self.processes,
241 &self.clusters,
242 &self.externals,
243 );
244 });
245
246 let mut compiled = build_inner(self.ir.get_mut());
247 let mut extra_stmts = self.extra_stmts(&());
248 let mut meta = D::Meta::default();
249
250 let (mut processes, mut clusters, mut externals) = (
251 std::mem::take(&mut self.processes)
252 .into_iter()
253 .filter_map(|(node_id, node)| {
254 if let Some(ir) = compiled.remove(&node_id) {
255 node.instantiate(
256 env,
257 &mut meta,
258 ir,
259 extra_stmts.remove(&node_id).unwrap_or_default(),
260 );
261 Some((node_id, node))
262 } else {
263 None
264 }
265 })
266 .collect::<HashMap<_, _>>(),
267 std::mem::take(&mut self.clusters)
268 .into_iter()
269 .filter_map(|(cluster_id, cluster)| {
270 if let Some(ir) = compiled.remove(&cluster_id) {
271 cluster.instantiate(
272 env,
273 &mut meta,
274 ir,
275 extra_stmts.remove(&cluster_id).unwrap_or_default(),
276 );
277 Some((cluster_id, cluster))
278 } else {
279 None
280 }
281 })
282 .collect::<HashMap<_, _>>(),
283 std::mem::take(&mut self.externals)
284 .into_iter()
285 .map(|(external_id, external)| {
286 external.instantiate(
287 env,
288 &mut meta,
289 compiled.remove(&external_id).unwrap(),
290 extra_stmts.remove(&external_id).unwrap_or_default(),
291 );
292 (external_id, external)
293 })
294 .collect::<HashMap<_, _>>(),
295 );
296
297 for node in processes.values_mut() {
298 node.update_meta(&meta);
299 }
300
301 for cluster in clusters.values_mut() {
302 cluster.update_meta(&meta);
303 }
304
305 for external in externals.values_mut() {
306 external.update_meta(&meta);
307 }
308
309 let mut seen_tees_connect = HashMap::new();
310 self.ir.get_mut().iter_mut().for_each(|leaf| {
311 leaf.connect_network(&mut seen_tees_connect);
312 });
313
314 DeployResult {
315 processes,
316 clusters,
317 externals,
318 cluster_id_name: std::mem::take(&mut self.cluster_id_name)
319 .into_iter()
320 .collect(),
321 }
322 }
323}
324
325pub struct DeployResult<'a, D: Deploy<'a>> {
326 processes: HashMap<usize, D::Process>,
327 clusters: HashMap<usize, D::Cluster>,
328 externals: HashMap<usize, D::ExternalProcess>,
329 cluster_id_name: HashMap<usize, String>,
330}
331
332impl<'a, D: Deploy<'a>> DeployResult<'a, D> {
333 pub fn get_process<P>(&self, p: &Process<P>) -> &D::Process {
334 let id = match p.id() {
335 LocationId::Process(id) => id,
336 _ => panic!("Process ID expected"),
337 };
338
339 self.processes.get(&id).unwrap()
340 }
341
342 pub fn get_cluster<C>(&self, c: &Cluster<'a, C>) -> &D::Cluster {
343 let id = match c.id() {
344 LocationId::Cluster(id) => id,
345 _ => panic!("Cluster ID expected"),
346 };
347
348 self.clusters.get(&id).unwrap()
349 }
350
351 pub fn get_all_clusters(&self) -> impl Iterator<Item = (LocationId, String, &D::Cluster)> {
352 self.clusters.iter().map(|(&id, c)| {
353 (
354 LocationId::Cluster(id),
355 self.cluster_id_name.get(&id).unwrap().clone(),
356 c,
357 )
358 })
359 }
360
361 pub fn get_external<P>(&self, p: &ExternalProcess<P>) -> &D::ExternalProcess {
362 self.externals.get(&p.id).unwrap()
363 }
364
365 pub fn raw_port(&self, port: ExternalBytesPort) -> D::ExternalRawPort {
366 self.externals
367 .get(&port.process_id)
368 .unwrap()
369 .raw_port(port.port_id)
370 }
371
372 pub async fn connect_sink_bytes(
373 &self,
374 port: ExternalBytesPort,
375 ) -> Pin<Box<dyn Sink<Bytes, Error = Error>>> {
376 self.externals
377 .get(&port.process_id)
378 .unwrap()
379 .as_bytes_sink(port.port_id)
380 .await
381 }
382
383 pub async fn connect_sink_bincode<T: Serialize + DeserializeOwned + 'static>(
384 &self,
385 port: ExternalBincodeSink<T>,
386 ) -> Pin<Box<dyn Sink<T, Error = Error>>> {
387 self.externals
388 .get(&port.process_id)
389 .unwrap()
390 .as_bincode_sink(port.port_id)
391 .await
392 }
393
394 pub async fn connect_source_bytes(
395 &self,
396 port: ExternalBytesPort,
397 ) -> Pin<Box<dyn Stream<Item = Bytes>>> {
398 self.externals
399 .get(&port.process_id)
400 .unwrap()
401 .as_bytes_source(port.port_id)
402 .await
403 }
404
405 pub async fn connect_source_bincode<T: Serialize + DeserializeOwned + 'static>(
406 &self,
407 port: ExternalBincodeStream<T>,
408 ) -> Pin<Box<dyn Stream<Item = T>>> {
409 self.externals
410 .get(&port.process_id)
411 .unwrap()
412 .as_bincode_source(port.port_id)
413 .await
414 }
415}