1use std::cell::UnsafeCell;
2use std::collections::{BTreeMap, HashMap};
3use std::io::Error;
4use std::marker::PhantomData;
5use std::pin::Pin;
6
7use bytes::{Bytes, BytesMut};
8use futures::{Sink, Stream};
9use proc_macro2::Span;
10use serde::Serialize;
11use serde::de::DeserializeOwned;
12use stageleft::QuotedWithContext;
13
14use super::built::build_inner;
15use super::compiled::CompiledFlow;
16use super::deploy_provider::{
17 ClusterSpec, Deploy, ExternalSpec, IntoProcessSpec, Node, ProcessSpec, RegisterPort,
18};
19use super::ir::HydroRoot;
20use crate::live_collections::stream::{Ordering, Retries};
21use crate::location::dynamic::LocationId;
22use crate::location::external_process::{
23 ExternalBincodeBidi, ExternalBincodeSink, ExternalBincodeStream, ExternalBytesPort,
24};
25use crate::location::{Cluster, External, Location, Process};
26use crate::staging_util::Invariant;
27
28pub struct DeployFlow<'a, D>
29where
30 D: Deploy<'a>,
31{
32 pub(super) ir: UnsafeCell<Vec<HydroRoot>>,
36
37 pub(super) processes: HashMap<usize, D::Process>,
39
40 pub(super) process_id_name: Vec<(usize, String)>,
43
44 pub(super) externals: HashMap<usize, D::External>,
45 pub(super) external_id_name: Vec<(usize, String)>,
46
47 pub(super) clusters: HashMap<usize, D::Cluster>,
48 pub(super) cluster_id_name: Vec<(usize, String)>,
49
50 pub(super) _phantom: Invariant<'a, D>,
51}
52
53impl<'a, D: Deploy<'a>> DeployFlow<'a, D> {
54 pub fn ir(&self) -> &Vec<HydroRoot> {
55 unsafe {
56 &*self.ir.get()
58 }
59 }
60
61 pub fn with_process_id_name(
62 mut self,
63 process_id: usize,
64 process_name: String,
65 spec: impl IntoProcessSpec<'a, D>,
66 ) -> Self {
67 self.processes.insert(
68 process_id,
69 spec.into_process_spec().build(process_id, &process_name),
70 );
71 self
72 }
73
74 pub fn with_process<P>(self, process: &Process<P>, spec: impl IntoProcessSpec<'a, D>) -> Self {
75 self.with_process_id_name(process.id, std::any::type_name::<P>().to_string(), spec)
76 }
77
78 pub fn with_remaining_processes<S: IntoProcessSpec<'a, D> + 'a>(
79 mut self,
80 spec: impl Fn() -> S,
81 ) -> Self {
82 for (id, name) in &self.process_id_name {
83 self.processes
84 .insert(*id, spec().into_process_spec().build(*id, name));
85 }
86
87 self
88 }
89
90 pub fn with_external<P>(
91 mut self,
92 process: &External<P>,
93 spec: impl ExternalSpec<'a, D>,
94 ) -> Self {
95 let tag_name = std::any::type_name::<P>().to_string();
96 self.externals
97 .insert(process.id, spec.build(process.id, &tag_name));
98 self
99 }
100
101 pub fn with_remaining_externals<S: ExternalSpec<'a, D> + 'a>(
102 mut self,
103 spec: impl Fn() -> S,
104 ) -> Self {
105 for (id, name) in &self.external_id_name {
106 self.externals.insert(*id, spec().build(*id, name));
107 }
108
109 self
110 }
111
112 pub fn with_cluster_id_name(
113 mut self,
114 cluster_id: usize,
115 cluster_name: String,
116 spec: impl ClusterSpec<'a, D>,
117 ) -> Self {
118 self.clusters
119 .insert(cluster_id, spec.build(cluster_id, &cluster_name));
120 self
121 }
122
123 pub fn with_cluster<C>(self, cluster: &Cluster<C>, spec: impl ClusterSpec<'a, D>) -> Self {
124 self.with_cluster_id_name(cluster.id, std::any::type_name::<C>().to_string(), spec)
125 }
126
127 pub fn with_remaining_clusters<S: ClusterSpec<'a, D> + 'a>(
128 mut self,
129 spec: impl Fn() -> S,
130 ) -> Self {
131 for (id, name) in &self.cluster_id_name {
132 self.clusters.insert(*id, spec().build(*id, name));
133 }
134
135 self
136 }
137
138 pub fn preview_compile(&self) -> CompiledFlow<'a, ()> {
141 CompiledFlow {
142 dfir: build_inner::<D>(unsafe {
143 &mut *self.ir.get()
146 }),
147 _phantom: PhantomData,
148 }
149 }
150
151 pub fn compile_no_network(mut self) -> CompiledFlow<'a, D::GraphId> {
152 CompiledFlow {
153 dfir: build_inner::<D>(self.ir.get_mut()),
154 _phantom: PhantomData,
155 }
156 }
157}
158
159impl<'a, D: Deploy<'a>> DeployFlow<'a, D> {
160 pub fn compile(mut self) -> CompiledFlow<'a, D::GraphId> {
161 let mut seen_tees: HashMap<_, _> = HashMap::new();
162 let mut extra_stmts = BTreeMap::new();
163 self.ir.get_mut().iter_mut().for_each(|leaf| {
164 leaf.compile_network::<D>(
165 &mut extra_stmts,
166 &mut seen_tees,
167 &self.processes,
168 &self.clusters,
169 &self.externals,
170 );
171 });
172
173 CompiledFlow {
174 dfir: build_inner::<D>(self.ir.get_mut()),
175 _phantom: PhantomData,
176 }
177 }
178
179 fn cluster_id_stmts(&self, extra_stmts: &mut BTreeMap<usize, Vec<syn::Stmt>>) {
180 let mut all_clusters_sorted = self.clusters.keys().collect::<Vec<_>>();
181 all_clusters_sorted.sort();
182
183 for &c_id in all_clusters_sorted {
184 let self_id_ident = syn::Ident::new(
185 &format!("__hydro_lang_cluster_self_id_{}", c_id),
186 Span::call_site(),
187 );
188 let self_id_expr = D::cluster_self_id().splice_untyped();
189 extra_stmts
190 .entry(c_id)
191 .or_default()
192 .push(syn::parse_quote! {
193 let #self_id_ident = &*Box::leak(Box::new(#self_id_expr));
194 });
195
196 for other_location in self.processes.keys().chain(self.clusters.keys()) {
197 let other_id_ident = syn::Ident::new(
198 &format!("__hydro_lang_cluster_ids_{}", c_id),
199 Span::call_site(),
200 );
201 let other_id_expr = D::cluster_ids(c_id).splice_untyped();
202 extra_stmts
203 .entry(*other_location)
204 .or_default()
205 .push(syn::parse_quote! {
206 let #other_id_ident = #other_id_expr;
207 });
208 }
209 }
210 }
211}
212
213impl<'a, D: Deploy<'a>> DeployFlow<'a, D> {
214 #[must_use]
215 pub fn deploy(mut self, env: &mut D::InstantiateEnv) -> DeployResult<'a, D> {
216 let mut seen_tees_instantiate: HashMap<_, _> = HashMap::new();
217 let mut extra_stmts = BTreeMap::new();
218 self.ir.get_mut().iter_mut().for_each(|leaf| {
219 leaf.compile_network::<D>(
220 &mut extra_stmts,
221 &mut seen_tees_instantiate,
222 &self.processes,
223 &self.clusters,
224 &self.externals,
225 );
226 });
227
228 let mut compiled = build_inner::<D>(self.ir.get_mut());
229 self.cluster_id_stmts(&mut extra_stmts);
230 let mut meta = D::Meta::default();
231
232 let (mut processes, mut clusters, mut externals) = (
233 std::mem::take(&mut self.processes)
234 .into_iter()
235 .filter_map(|(node_id, node)| {
236 if let Some(ir) = compiled.remove(&node_id) {
237 node.instantiate(
238 env,
239 &mut meta,
240 ir,
241 extra_stmts.remove(&node_id).unwrap_or_default(),
242 );
243 Some((node_id, node))
244 } else {
245 None
246 }
247 })
248 .collect::<HashMap<_, _>>(),
249 std::mem::take(&mut self.clusters)
250 .into_iter()
251 .filter_map(|(cluster_id, cluster)| {
252 if let Some(ir) = compiled.remove(&cluster_id) {
253 cluster.instantiate(
254 env,
255 &mut meta,
256 ir,
257 extra_stmts.remove(&cluster_id).unwrap_or_default(),
258 );
259 Some((cluster_id, cluster))
260 } else {
261 None
262 }
263 })
264 .collect::<HashMap<_, _>>(),
265 std::mem::take(&mut self.externals)
266 .into_iter()
267 .map(|(external_id, external)| {
268 external.instantiate(
269 env,
270 &mut meta,
271 Default::default(),
272 extra_stmts.remove(&external_id).unwrap_or_default(),
273 );
274 (external_id, external)
275 })
276 .collect::<HashMap<_, _>>(),
277 );
278
279 for node in processes.values_mut() {
280 node.update_meta(&meta);
281 }
282
283 for cluster in clusters.values_mut() {
284 cluster.update_meta(&meta);
285 }
286
287 for external in externals.values_mut() {
288 external.update_meta(&meta);
289 }
290
291 let mut seen_tees_connect = HashMap::new();
292 self.ir.get_mut().iter_mut().for_each(|leaf| {
293 leaf.connect_network(&mut seen_tees_connect);
294 });
295
296 DeployResult {
297 processes,
298 clusters,
299 externals,
300 cluster_id_name: std::mem::take(&mut self.cluster_id_name)
301 .into_iter()
302 .collect(),
303 process_id_name: std::mem::take(&mut self.process_id_name)
304 .into_iter()
305 .collect(),
306 }
307 }
308}
309
/// The nodes produced by `DeployFlow::deploy`, together with the names recorded for
/// processes and clusters.
pub struct DeployResult<'a, D: Deploy<'a>> {
    /// Instantiated process nodes, keyed by process ID.
    processes: HashMap<usize, D::Process>,
    /// Instantiated cluster nodes, keyed by cluster ID.
    clusters: HashMap<usize, D::Cluster>,
    /// Instantiated external nodes, keyed by external ID.
    externals: HashMap<usize, D::External>,
    /// Name recorded for each cluster ID.
    cluster_id_name: HashMap<usize, String>,
    /// Name recorded for each process ID.
    process_id_name: HashMap<usize, String>,
}
317
318impl<'a, D: Deploy<'a>> DeployResult<'a, D> {
319 pub fn get_process<P>(&self, p: &Process<P>) -> &D::Process {
320 let id = match p.id() {
321 LocationId::Process(id) => id,
322 _ => panic!("Process ID expected"),
323 };
324
325 self.processes.get(&id).unwrap()
326 }
327
328 pub fn get_cluster<C>(&self, c: &Cluster<'a, C>) -> &D::Cluster {
329 let id = match c.id() {
330 LocationId::Cluster(id) => id,
331 _ => panic!("Cluster ID expected"),
332 };
333
334 self.clusters.get(&id).unwrap()
335 }
336
337 pub fn get_all_clusters(&self) -> impl Iterator<Item = (LocationId, String, &D::Cluster)> {
338 self.clusters.iter().map(|(&id, c)| {
339 (
340 LocationId::Cluster(id),
341 self.cluster_id_name.get(&id).unwrap().clone(),
342 c,
343 )
344 })
345 }
346
347 pub fn get_all_processes(&self) -> impl Iterator<Item = (LocationId, String, &D::Process)> {
348 self.processes.iter().map(|(&id, p)| {
349 (
350 LocationId::Process(id),
351 self.process_id_name.get(&id).unwrap().clone(),
352 p,
353 )
354 })
355 }
356
357 pub fn get_external<P>(&self, p: &External<P>) -> &D::External {
358 self.externals.get(&p.id).unwrap()
359 }
360
361 pub fn raw_port<M>(&self, port: ExternalBytesPort<M>) -> D::ExternalRawPort {
362 self.externals
363 .get(&port.process_id)
364 .unwrap()
365 .raw_port(port.port_id)
366 }
367
368 #[deprecated(note = "use `connect` instead")]
369 pub async fn connect_bytes<M>(
370 &self,
371 port: ExternalBytesPort<M>,
372 ) -> (
373 Pin<Box<dyn Stream<Item = Result<BytesMut, Error>>>>,
374 Pin<Box<dyn Sink<Bytes, Error = Error>>>,
375 ) {
376 self.connect(port).await
377 }
378
379 #[deprecated(note = "use `connect` instead")]
380 pub async fn connect_sink_bytes<M>(
381 &self,
382 port: ExternalBytesPort<M>,
383 ) -> Pin<Box<dyn Sink<Bytes, Error = Error>>> {
384 self.connect(port).await.1
385 }
386
387 pub async fn connect_bincode<
388 InT: Serialize + 'static,
389 OutT: DeserializeOwned + 'static,
390 Many,
391 >(
392 &self,
393 port: ExternalBincodeBidi<InT, OutT, Many>,
394 ) -> (
395 Pin<Box<dyn Stream<Item = OutT>>>,
396 Pin<Box<dyn Sink<InT, Error = Error>>>,
397 ) {
398 self.externals
399 .get(&port.process_id)
400 .unwrap()
401 .as_bincode_bidi(port.port_id)
402 .await
403 }
404
405 #[deprecated(note = "use `connect` instead")]
406 pub async fn connect_sink_bincode<T: Serialize + DeserializeOwned + 'static, Many>(
407 &self,
408 port: ExternalBincodeSink<T, Many>,
409 ) -> Pin<Box<dyn Sink<T, Error = Error>>> {
410 self.connect(port).await
411 }
412
413 #[deprecated(note = "use `connect` instead")]
414 pub async fn connect_source_bytes(
415 &self,
416 port: ExternalBytesPort,
417 ) -> Pin<Box<dyn Stream<Item = Result<BytesMut, Error>>>> {
418 self.connect(port).await.0
419 }
420
421 #[deprecated(note = "use `connect` instead")]
422 pub async fn connect_source_bincode<
423 T: Serialize + DeserializeOwned + 'static,
424 O: Ordering,
425 R: Retries,
426 >(
427 &self,
428 port: ExternalBincodeStream<T, O, R>,
429 ) -> Pin<Box<dyn Stream<Item = T>>> {
430 self.connect(port).await
431 }
432
433 pub async fn connect<'b, P: ConnectableAsync<&'b Self>>(
434 &'b self,
435 port: P,
436 ) -> <P as ConnectableAsync<&'b Self>>::Output {
437 port.connect(self).await
438 }
439}
440
/// A port that can be connected asynchronously, given a context of type `Ctx`.
pub trait ConnectableAsync<Ctx> {
    /// The value produced once the connection has been established.
    type Output;

    /// Consumes the port and connects it using `ctx`, resolving to [`Self::Output`].
    fn connect(self, ctx: Ctx) -> impl Future<Output = Self::Output>;
}
446
447impl<'a, D: Deploy<'a>, M> ConnectableAsync<&DeployResult<'a, D>> for ExternalBytesPort<M> {
448 type Output = (
449 Pin<Box<dyn Stream<Item = Result<BytesMut, Error>>>>,
450 Pin<Box<dyn Sink<Bytes, Error = Error>>>,
451 );
452
453 async fn connect(self, ctx: &DeployResult<'a, D>) -> Self::Output {
454 ctx.externals
455 .get(&self.process_id)
456 .unwrap()
457 .as_bytes_bidi(self.port_id)
458 .await
459 }
460}
461
462impl<'a, D: Deploy<'a>, T: DeserializeOwned + 'static, O: Ordering, R: Retries>
463 ConnectableAsync<&DeployResult<'a, D>> for ExternalBincodeStream<T, O, R>
464{
465 type Output = Pin<Box<dyn Stream<Item = T>>>;
466
467 async fn connect(self, ctx: &DeployResult<'a, D>) -> Self::Output {
468 ctx.externals
469 .get(&self.process_id)
470 .unwrap()
471 .as_bincode_source(self.port_id)
472 .await
473 }
474}
475
476impl<'a, D: Deploy<'a>, T: Serialize + 'static, Many> ConnectableAsync<&DeployResult<'a, D>>
477 for ExternalBincodeSink<T, Many>
478{
479 type Output = Pin<Box<dyn Sink<T, Error = Error>>>;
480
481 async fn connect(self, ctx: &DeployResult<'a, D>) -> Self::Output {
482 ctx.externals
483 .get(&self.process_id)
484 .unwrap()
485 .as_bincode_sink(self.port_id)
486 .await
487 }
488}