1use std::cell::UnsafeCell;
2use std::collections::{BTreeMap, HashMap};
3use std::io::Error;
4use std::marker::PhantomData;
5use std::pin::Pin;
6
7use bytes::{Bytes, BytesMut};
8use futures::{Sink, Stream};
9use proc_macro2::Span;
10use serde::Serialize;
11use serde::de::DeserializeOwned;
12use stageleft::QuotedWithContext;
13
14use super::built::build_inner;
15use super::compiled::CompiledFlow;
16use crate::deploy::{
17 ClusterSpec, Deploy, ExternalSpec, IntoProcessSpec, Node, ProcessSpec, RegisterPort,
18};
19use crate::ir::HydroRoot;
20use crate::location::external_process::{
21 ExternalBincodeBidi, ExternalBincodeSink, ExternalBincodeStream, ExternalBytesPort,
22};
23use crate::location::{Cluster, External, Location, LocationId, Process};
24use crate::staging_util::Invariant;
25
26pub struct DeployFlow<'a, D>
27where
28 D: Deploy<'a>,
29{
30 pub(super) ir: UnsafeCell<Vec<HydroRoot>>,
34
35 pub(super) processes: HashMap<usize, D::Process>,
37
38 pub(super) process_id_name: Vec<(usize, String)>,
41
42 pub(super) externals: HashMap<usize, D::External>,
43 pub(super) external_id_name: Vec<(usize, String)>,
44
45 pub(super) clusters: HashMap<usize, D::Cluster>,
46 pub(super) cluster_id_name: Vec<(usize, String)>,
47
48 pub(super) _phantom: Invariant<'a, D>,
49}
50
51impl<'a, D: Deploy<'a>> DeployFlow<'a, D> {
52 pub fn ir(&self) -> &Vec<HydroRoot> {
53 unsafe {
54 &*self.ir.get()
56 }
57 }
58
59 pub fn with_process_id_name(
60 mut self,
61 process_id: usize,
62 process_name: String,
63 spec: impl IntoProcessSpec<'a, D>,
64 ) -> Self {
65 self.processes.insert(
66 process_id,
67 spec.into_process_spec().build(process_id, &process_name),
68 );
69 self
70 }
71
72 pub fn with_process<P>(self, process: &Process<P>, spec: impl IntoProcessSpec<'a, D>) -> Self {
73 self.with_process_id_name(process.id, std::any::type_name::<P>().to_string(), spec)
74 }
75
76 pub fn with_remaining_processes<S: IntoProcessSpec<'a, D> + 'a>(
77 mut self,
78 spec: impl Fn() -> S,
79 ) -> Self {
80 for (id, name) in &self.process_id_name {
81 self.processes
82 .insert(*id, spec().into_process_spec().build(*id, name));
83 }
84
85 self
86 }
87
88 pub fn with_external<P>(
89 mut self,
90 process: &External<P>,
91 spec: impl ExternalSpec<'a, D>,
92 ) -> Self {
93 let tag_name = std::any::type_name::<P>().to_string();
94 self.externals
95 .insert(process.id, spec.build(process.id, &tag_name));
96 self
97 }
98
99 pub fn with_remaining_externals<S: ExternalSpec<'a, D> + 'a>(
100 mut self,
101 spec: impl Fn() -> S,
102 ) -> Self {
103 for (id, name) in &self.external_id_name {
104 self.externals.insert(*id, spec().build(*id, name));
105 }
106
107 self
108 }
109
110 pub fn with_cluster_id_name(
111 mut self,
112 cluster_id: usize,
113 cluster_name: String,
114 spec: impl ClusterSpec<'a, D>,
115 ) -> Self {
116 self.clusters
117 .insert(cluster_id, spec.build(cluster_id, &cluster_name));
118 self
119 }
120
121 pub fn with_cluster<C>(self, cluster: &Cluster<C>, spec: impl ClusterSpec<'a, D>) -> Self {
122 self.with_cluster_id_name(cluster.id, std::any::type_name::<C>().to_string(), spec)
123 }
124
125 pub fn with_remaining_clusters<S: ClusterSpec<'a, D> + 'a>(
126 mut self,
127 spec: impl Fn() -> S,
128 ) -> Self {
129 for (id, name) in &self.cluster_id_name {
130 self.clusters.insert(*id, spec().build(*id, name));
131 }
132
133 self
134 }
135
136 pub fn preview_compile(&self) -> CompiledFlow<'a, ()> {
139 CompiledFlow {
140 dfir: build_inner(unsafe {
141 &mut *self.ir.get()
144 }),
145 _phantom: PhantomData,
146 }
147 }
148
149 pub fn compile_no_network(mut self) -> CompiledFlow<'a, D::GraphId> {
150 CompiledFlow {
151 dfir: build_inner(self.ir.get_mut()),
152 _phantom: PhantomData,
153 }
154 }
155}
156
157impl<'a, D: Deploy<'a>> DeployFlow<'a, D> {
158 pub fn compile(mut self, env: &D::CompileEnv) -> CompiledFlow<'a, D::GraphId> {
159 let mut seen_tees: HashMap<_, _> = HashMap::new();
160 let mut extra_stmts = BTreeMap::new();
161 self.ir.get_mut().iter_mut().for_each(|leaf| {
162 leaf.compile_network::<D>(
163 env,
164 &mut extra_stmts,
165 &mut seen_tees,
166 &self.processes,
167 &self.clusters,
168 &self.externals,
169 );
170 });
171
172 CompiledFlow {
173 dfir: build_inner(self.ir.get_mut()),
174 _phantom: PhantomData,
175 }
176 }
177
178 fn cluster_id_stmts(
179 &self,
180 extra_stmts: &mut BTreeMap<usize, Vec<syn::Stmt>>,
181 env: &<D as Deploy<'a>>::CompileEnv,
182 ) {
183 let mut all_clusters_sorted = self.clusters.keys().collect::<Vec<_>>();
184 all_clusters_sorted.sort();
185
186 for &c_id in all_clusters_sorted {
187 let self_id_ident = syn::Ident::new(
188 &format!("__hydro_lang_cluster_self_id_{}", c_id),
189 Span::call_site(),
190 );
191 let self_id_expr = D::cluster_self_id(env).splice_untyped();
192 extra_stmts
193 .entry(c_id)
194 .or_default()
195 .push(syn::parse_quote! {
196 let #self_id_ident = #self_id_expr;
197 });
198
199 for other_location in self.processes.keys().chain(self.clusters.keys()) {
200 let other_id_ident = syn::Ident::new(
201 &format!("__hydro_lang_cluster_ids_{}", c_id),
202 Span::call_site(),
203 );
204 let other_id_expr = D::cluster_ids(env, c_id).splice_untyped();
205 extra_stmts
206 .entry(*other_location)
207 .or_default()
208 .push(syn::parse_quote! {
209 let #other_id_ident = #other_id_expr;
210 });
211 }
212 }
213 }
214}
215
216impl<'a, D: Deploy<'a, CompileEnv = ()>> DeployFlow<'a, D> {
217 #[must_use]
218 pub fn deploy(mut self, env: &mut D::InstantiateEnv) -> DeployResult<'a, D> {
219 let mut seen_tees_instantiate: HashMap<_, _> = HashMap::new();
220 let mut extra_stmts = BTreeMap::new();
221 self.ir.get_mut().iter_mut().for_each(|leaf| {
222 leaf.compile_network::<D>(
223 &(),
224 &mut extra_stmts,
225 &mut seen_tees_instantiate,
226 &self.processes,
227 &self.clusters,
228 &self.externals,
229 );
230 });
231
232 let mut compiled = build_inner(self.ir.get_mut());
233 self.cluster_id_stmts(&mut extra_stmts, &());
234 let mut meta = D::Meta::default();
235
236 let (mut processes, mut clusters, mut externals) = (
237 std::mem::take(&mut self.processes)
238 .into_iter()
239 .filter_map(|(node_id, node)| {
240 if let Some(ir) = compiled.remove(&node_id) {
241 node.instantiate(
242 env,
243 &mut meta,
244 ir,
245 extra_stmts.remove(&node_id).unwrap_or_default(),
246 );
247 Some((node_id, node))
248 } else {
249 None
250 }
251 })
252 .collect::<HashMap<_, _>>(),
253 std::mem::take(&mut self.clusters)
254 .into_iter()
255 .filter_map(|(cluster_id, cluster)| {
256 if let Some(ir) = compiled.remove(&cluster_id) {
257 cluster.instantiate(
258 env,
259 &mut meta,
260 ir,
261 extra_stmts.remove(&cluster_id).unwrap_or_default(),
262 );
263 Some((cluster_id, cluster))
264 } else {
265 None
266 }
267 })
268 .collect::<HashMap<_, _>>(),
269 std::mem::take(&mut self.externals)
270 .into_iter()
271 .map(|(external_id, external)| {
272 external.instantiate(
273 env,
274 &mut meta,
275 Default::default(),
276 extra_stmts.remove(&external_id).unwrap_or_default(),
277 );
278 (external_id, external)
279 })
280 .collect::<HashMap<_, _>>(),
281 );
282
283 for node in processes.values_mut() {
284 node.update_meta(&meta);
285 }
286
287 for cluster in clusters.values_mut() {
288 cluster.update_meta(&meta);
289 }
290
291 for external in externals.values_mut() {
292 external.update_meta(&meta);
293 }
294
295 let mut seen_tees_connect = HashMap::new();
296 self.ir.get_mut().iter_mut().for_each(|leaf| {
297 leaf.connect_network(&mut seen_tees_connect);
298 });
299
300 DeployResult {
301 processes,
302 clusters,
303 externals,
304 cluster_id_name: std::mem::take(&mut self.cluster_id_name)
305 .into_iter()
306 .collect(),
307 process_id_name: std::mem::take(&mut self.process_id_name)
308 .into_iter()
309 .collect(),
310 }
311 }
312}
313
314pub struct DeployResult<'a, D: Deploy<'a>> {
315 processes: HashMap<usize, D::Process>,
316 clusters: HashMap<usize, D::Cluster>,
317 externals: HashMap<usize, D::External>,
318 cluster_id_name: HashMap<usize, String>,
319 process_id_name: HashMap<usize, String>,
320}
321
322impl<'a, D: Deploy<'a>> DeployResult<'a, D> {
323 pub fn get_process<P>(&self, p: &Process<P>) -> &D::Process {
324 let id = match p.id() {
325 LocationId::Process(id) => id,
326 _ => panic!("Process ID expected"),
327 };
328
329 self.processes.get(&id).unwrap()
330 }
331
332 pub fn get_cluster<C>(&self, c: &Cluster<'a, C>) -> &D::Cluster {
333 let id = match c.id() {
334 LocationId::Cluster(id) => id,
335 _ => panic!("Cluster ID expected"),
336 };
337
338 self.clusters.get(&id).unwrap()
339 }
340
341 pub fn get_all_clusters(&self) -> impl Iterator<Item = (LocationId, String, &D::Cluster)> {
342 self.clusters.iter().map(|(&id, c)| {
343 (
344 LocationId::Cluster(id),
345 self.cluster_id_name.get(&id).unwrap().clone(),
346 c,
347 )
348 })
349 }
350
351 pub fn get_all_processes(&self) -> impl Iterator<Item = (LocationId, String, &D::Process)> {
352 self.processes.iter().map(|(&id, p)| {
353 (
354 LocationId::Process(id),
355 self.process_id_name.get(&id).unwrap().clone(),
356 p,
357 )
358 })
359 }
360
361 pub fn get_external<P>(&self, p: &External<P>) -> &D::External {
362 self.externals.get(&p.id).unwrap()
363 }
364
365 pub fn raw_port<M>(&self, port: ExternalBytesPort<M>) -> D::ExternalRawPort {
366 self.externals
367 .get(&port.process_id)
368 .unwrap()
369 .raw_port(port.port_id)
370 }
371
372 pub async fn connect_bytes<M>(
373 &self,
374 port: ExternalBytesPort<M>,
375 ) -> (
376 Pin<Box<dyn Stream<Item = Result<BytesMut, Error>>>>,
377 Pin<Box<dyn Sink<Bytes, Error = Error>>>,
378 ) {
379 self.externals
380 .get(&port.process_id)
381 .unwrap()
382 .as_bytes_bidi(port.port_id)
383 .await
384 }
385
386 pub async fn connect_sink_bytes<M>(
387 &self,
388 port: ExternalBytesPort<M>,
389 ) -> Pin<Box<dyn Sink<Bytes, Error = Error>>> {
390 self.connect_bytes(port).await.1
391 }
392
393 pub async fn connect_bincode<
394 InT: Serialize + 'static,
395 OutT: DeserializeOwned + 'static,
396 Many,
397 >(
398 &self,
399 port: ExternalBincodeBidi<InT, OutT, Many>,
400 ) -> (
401 Pin<Box<dyn Stream<Item = OutT>>>,
402 Pin<Box<dyn Sink<InT, Error = Error>>>,
403 ) {
404 self.externals
405 .get(&port.process_id)
406 .unwrap()
407 .as_bincode_bidi(port.port_id)
408 .await
409 }
410
411 pub async fn connect_sink_bincode<T: Serialize + DeserializeOwned + 'static, Many>(
412 &self,
413 port: ExternalBincodeSink<T, Many>,
414 ) -> Pin<Box<dyn Sink<T, Error = Error>>> {
415 self.externals
416 .get(&port.process_id)
417 .unwrap()
418 .as_bincode_sink(port.port_id)
419 .await
420 }
421
422 pub async fn connect_source_bytes(
423 &self,
424 port: ExternalBytesPort,
425 ) -> Pin<Box<dyn Stream<Item = Result<BytesMut, Error>>>> {
426 self.connect_bytes(port).await.0
427 }
428
429 pub async fn connect_source_bincode<T: Serialize + DeserializeOwned + 'static>(
430 &self,
431 port: ExternalBincodeStream<T>,
432 ) -> Pin<Box<dyn Stream<Item = T>>> {
433 self.externals
434 .get(&port.process_id)
435 .unwrap()
436 .as_bincode_source(port.port_id)
437 .await
438 }
439}