1use std::cell::UnsafeCell;
2use std::collections::{BTreeMap, HashMap};
3use std::io::Error;
4use std::marker::PhantomData;
5use std::pin::Pin;
6
7use bytes::{Bytes, BytesMut};
8use futures::{Sink, Stream};
9use proc_macro2::Span;
10use serde::Serialize;
11use serde::de::DeserializeOwned;
12use stageleft::QuotedWithContext;
13
14use super::built::build_inner;
15use super::compiled::CompiledFlow;
16use super::deploy_provider::{
17 ClusterSpec, Deploy, ExternalSpec, IntoProcessSpec, Node, ProcessSpec, RegisterPort,
18};
19use super::ir::HydroRoot;
20use crate::location::dynamic::LocationId;
21use crate::location::external_process::{
22 ExternalBincodeBidi, ExternalBincodeSink, ExternalBincodeStream, ExternalBytesPort,
23};
24use crate::location::{Cluster, External, Location, Process};
25use crate::staging_util::Invariant;
26
27pub struct DeployFlow<'a, D>
28where
29 D: Deploy<'a>,
30{
31 pub(super) ir: UnsafeCell<Vec<HydroRoot>>,
35
36 pub(super) processes: HashMap<usize, D::Process>,
38
39 pub(super) process_id_name: Vec<(usize, String)>,
42
43 pub(super) externals: HashMap<usize, D::External>,
44 pub(super) external_id_name: Vec<(usize, String)>,
45
46 pub(super) clusters: HashMap<usize, D::Cluster>,
47 pub(super) cluster_id_name: Vec<(usize, String)>,
48
49 pub(super) _phantom: Invariant<'a, D>,
50}
51
52impl<'a, D: Deploy<'a>> DeployFlow<'a, D> {
53 pub fn ir(&self) -> &Vec<HydroRoot> {
54 unsafe {
55 &*self.ir.get()
57 }
58 }
59
60 pub fn with_process_id_name(
61 mut self,
62 process_id: usize,
63 process_name: String,
64 spec: impl IntoProcessSpec<'a, D>,
65 ) -> Self {
66 self.processes.insert(
67 process_id,
68 spec.into_process_spec().build(process_id, &process_name),
69 );
70 self
71 }
72
73 pub fn with_process<P>(self, process: &Process<P>, spec: impl IntoProcessSpec<'a, D>) -> Self {
74 self.with_process_id_name(process.id, std::any::type_name::<P>().to_string(), spec)
75 }
76
77 pub fn with_remaining_processes<S: IntoProcessSpec<'a, D> + 'a>(
78 mut self,
79 spec: impl Fn() -> S,
80 ) -> Self {
81 for (id, name) in &self.process_id_name {
82 self.processes
83 .insert(*id, spec().into_process_spec().build(*id, name));
84 }
85
86 self
87 }
88
89 pub fn with_external<P>(
90 mut self,
91 process: &External<P>,
92 spec: impl ExternalSpec<'a, D>,
93 ) -> Self {
94 let tag_name = std::any::type_name::<P>().to_string();
95 self.externals
96 .insert(process.id, spec.build(process.id, &tag_name));
97 self
98 }
99
100 pub fn with_remaining_externals<S: ExternalSpec<'a, D> + 'a>(
101 mut self,
102 spec: impl Fn() -> S,
103 ) -> Self {
104 for (id, name) in &self.external_id_name {
105 self.externals.insert(*id, spec().build(*id, name));
106 }
107
108 self
109 }
110
111 pub fn with_cluster_id_name(
112 mut self,
113 cluster_id: usize,
114 cluster_name: String,
115 spec: impl ClusterSpec<'a, D>,
116 ) -> Self {
117 self.clusters
118 .insert(cluster_id, spec.build(cluster_id, &cluster_name));
119 self
120 }
121
122 pub fn with_cluster<C>(self, cluster: &Cluster<C>, spec: impl ClusterSpec<'a, D>) -> Self {
123 self.with_cluster_id_name(cluster.id, std::any::type_name::<C>().to_string(), spec)
124 }
125
126 pub fn with_remaining_clusters<S: ClusterSpec<'a, D> + 'a>(
127 mut self,
128 spec: impl Fn() -> S,
129 ) -> Self {
130 for (id, name) in &self.cluster_id_name {
131 self.clusters.insert(*id, spec().build(*id, name));
132 }
133
134 self
135 }
136
137 pub fn preview_compile(&self) -> CompiledFlow<'a, ()> {
140 CompiledFlow {
141 dfir: build_inner(unsafe {
142 &mut *self.ir.get()
145 }),
146 _phantom: PhantomData,
147 }
148 }
149
150 pub fn compile_no_network(mut self) -> CompiledFlow<'a, D::GraphId> {
151 CompiledFlow {
152 dfir: build_inner(self.ir.get_mut()),
153 _phantom: PhantomData,
154 }
155 }
156}
157
158impl<'a, D: Deploy<'a>> DeployFlow<'a, D> {
159 pub fn compile(mut self, env: &D::CompileEnv) -> CompiledFlow<'a, D::GraphId> {
160 let mut seen_tees: HashMap<_, _> = HashMap::new();
161 let mut extra_stmts = BTreeMap::new();
162 self.ir.get_mut().iter_mut().for_each(|leaf| {
163 leaf.compile_network::<D>(
164 env,
165 &mut extra_stmts,
166 &mut seen_tees,
167 &self.processes,
168 &self.clusters,
169 &self.externals,
170 );
171 });
172
173 CompiledFlow {
174 dfir: build_inner(self.ir.get_mut()),
175 _phantom: PhantomData,
176 }
177 }
178
179 fn cluster_id_stmts(
180 &self,
181 extra_stmts: &mut BTreeMap<usize, Vec<syn::Stmt>>,
182 env: &<D as Deploy<'a>>::CompileEnv,
183 ) {
184 let mut all_clusters_sorted = self.clusters.keys().collect::<Vec<_>>();
185 all_clusters_sorted.sort();
186
187 for &c_id in all_clusters_sorted {
188 let self_id_ident = syn::Ident::new(
189 &format!("__hydro_lang_cluster_self_id_{}", c_id),
190 Span::call_site(),
191 );
192 let self_id_expr = D::cluster_self_id(env).splice_untyped();
193 extra_stmts
194 .entry(c_id)
195 .or_default()
196 .push(syn::parse_quote! {
197 let #self_id_ident = #self_id_expr;
198 });
199
200 for other_location in self.processes.keys().chain(self.clusters.keys()) {
201 let other_id_ident = syn::Ident::new(
202 &format!("__hydro_lang_cluster_ids_{}", c_id),
203 Span::call_site(),
204 );
205 let other_id_expr = D::cluster_ids(env, c_id).splice_untyped();
206 extra_stmts
207 .entry(*other_location)
208 .or_default()
209 .push(syn::parse_quote! {
210 let #other_id_ident = #other_id_expr;
211 });
212 }
213 }
214 }
215}
216
217impl<'a, D: Deploy<'a, CompileEnv = ()>> DeployFlow<'a, D> {
218 #[must_use]
219 pub fn deploy(mut self, env: &mut D::InstantiateEnv) -> DeployResult<'a, D> {
220 let mut seen_tees_instantiate: HashMap<_, _> = HashMap::new();
221 let mut extra_stmts = BTreeMap::new();
222 self.ir.get_mut().iter_mut().for_each(|leaf| {
223 leaf.compile_network::<D>(
224 &(),
225 &mut extra_stmts,
226 &mut seen_tees_instantiate,
227 &self.processes,
228 &self.clusters,
229 &self.externals,
230 );
231 });
232
233 let mut compiled = build_inner(self.ir.get_mut());
234 self.cluster_id_stmts(&mut extra_stmts, &());
235 let mut meta = D::Meta::default();
236
237 let (mut processes, mut clusters, mut externals) = (
238 std::mem::take(&mut self.processes)
239 .into_iter()
240 .filter_map(|(node_id, node)| {
241 if let Some(ir) = compiled.remove(&node_id) {
242 node.instantiate(
243 env,
244 &mut meta,
245 ir,
246 extra_stmts.remove(&node_id).unwrap_or_default(),
247 );
248 Some((node_id, node))
249 } else {
250 None
251 }
252 })
253 .collect::<HashMap<_, _>>(),
254 std::mem::take(&mut self.clusters)
255 .into_iter()
256 .filter_map(|(cluster_id, cluster)| {
257 if let Some(ir) = compiled.remove(&cluster_id) {
258 cluster.instantiate(
259 env,
260 &mut meta,
261 ir,
262 extra_stmts.remove(&cluster_id).unwrap_or_default(),
263 );
264 Some((cluster_id, cluster))
265 } else {
266 None
267 }
268 })
269 .collect::<HashMap<_, _>>(),
270 std::mem::take(&mut self.externals)
271 .into_iter()
272 .map(|(external_id, external)| {
273 external.instantiate(
274 env,
275 &mut meta,
276 Default::default(),
277 extra_stmts.remove(&external_id).unwrap_or_default(),
278 );
279 (external_id, external)
280 })
281 .collect::<HashMap<_, _>>(),
282 );
283
284 for node in processes.values_mut() {
285 node.update_meta(&meta);
286 }
287
288 for cluster in clusters.values_mut() {
289 cluster.update_meta(&meta);
290 }
291
292 for external in externals.values_mut() {
293 external.update_meta(&meta);
294 }
295
296 let mut seen_tees_connect = HashMap::new();
297 self.ir.get_mut().iter_mut().for_each(|leaf| {
298 leaf.connect_network(&mut seen_tees_connect);
299 });
300
301 DeployResult {
302 processes,
303 clusters,
304 externals,
305 cluster_id_name: std::mem::take(&mut self.cluster_id_name)
306 .into_iter()
307 .collect(),
308 process_id_name: std::mem::take(&mut self.process_id_name)
309 .into_iter()
310 .collect(),
311 }
312 }
313}
314
315pub struct DeployResult<'a, D: Deploy<'a>> {
316 processes: HashMap<usize, D::Process>,
317 clusters: HashMap<usize, D::Cluster>,
318 externals: HashMap<usize, D::External>,
319 cluster_id_name: HashMap<usize, String>,
320 process_id_name: HashMap<usize, String>,
321}
322
323impl<'a, D: Deploy<'a>> DeployResult<'a, D> {
324 pub fn get_process<P>(&self, p: &Process<P>) -> &D::Process {
325 let id = match p.id() {
326 LocationId::Process(id) => id,
327 _ => panic!("Process ID expected"),
328 };
329
330 self.processes.get(&id).unwrap()
331 }
332
333 pub fn get_cluster<C>(&self, c: &Cluster<'a, C>) -> &D::Cluster {
334 let id = match c.id() {
335 LocationId::Cluster(id) => id,
336 _ => panic!("Cluster ID expected"),
337 };
338
339 self.clusters.get(&id).unwrap()
340 }
341
342 pub fn get_all_clusters(&self) -> impl Iterator<Item = (LocationId, String, &D::Cluster)> {
343 self.clusters.iter().map(|(&id, c)| {
344 (
345 LocationId::Cluster(id),
346 self.cluster_id_name.get(&id).unwrap().clone(),
347 c,
348 )
349 })
350 }
351
352 pub fn get_all_processes(&self) -> impl Iterator<Item = (LocationId, String, &D::Process)> {
353 self.processes.iter().map(|(&id, p)| {
354 (
355 LocationId::Process(id),
356 self.process_id_name.get(&id).unwrap().clone(),
357 p,
358 )
359 })
360 }
361
362 pub fn get_external<P>(&self, p: &External<P>) -> &D::External {
363 self.externals.get(&p.id).unwrap()
364 }
365
366 pub fn raw_port<M>(&self, port: ExternalBytesPort<M>) -> D::ExternalRawPort {
367 self.externals
368 .get(&port.process_id)
369 .unwrap()
370 .raw_port(port.port_id)
371 }
372
373 pub async fn connect_bytes<M>(
374 &self,
375 port: ExternalBytesPort<M>,
376 ) -> (
377 Pin<Box<dyn Stream<Item = Result<BytesMut, Error>>>>,
378 Pin<Box<dyn Sink<Bytes, Error = Error>>>,
379 ) {
380 self.externals
381 .get(&port.process_id)
382 .unwrap()
383 .as_bytes_bidi(port.port_id)
384 .await
385 }
386
387 pub async fn connect_sink_bytes<M>(
388 &self,
389 port: ExternalBytesPort<M>,
390 ) -> Pin<Box<dyn Sink<Bytes, Error = Error>>> {
391 self.connect_bytes(port).await.1
392 }
393
394 pub async fn connect_bincode<
395 InT: Serialize + 'static,
396 OutT: DeserializeOwned + 'static,
397 Many,
398 >(
399 &self,
400 port: ExternalBincodeBidi<InT, OutT, Many>,
401 ) -> (
402 Pin<Box<dyn Stream<Item = OutT>>>,
403 Pin<Box<dyn Sink<InT, Error = Error>>>,
404 ) {
405 self.externals
406 .get(&port.process_id)
407 .unwrap()
408 .as_bincode_bidi(port.port_id)
409 .await
410 }
411
412 pub async fn connect_sink_bincode<T: Serialize + DeserializeOwned + 'static, Many>(
413 &self,
414 port: ExternalBincodeSink<T, Many>,
415 ) -> Pin<Box<dyn Sink<T, Error = Error>>> {
416 self.externals
417 .get(&port.process_id)
418 .unwrap()
419 .as_bincode_sink(port.port_id)
420 .await
421 }
422
423 pub async fn connect_source_bytes(
424 &self,
425 port: ExternalBytesPort,
426 ) -> Pin<Box<dyn Stream<Item = Result<BytesMut, Error>>>> {
427 self.connect_bytes(port).await.0
428 }
429
430 pub async fn connect_source_bincode<T: Serialize + DeserializeOwned + 'static>(
431 &self,
432 port: ExternalBincodeStream<T>,
433 ) -> Pin<Box<dyn Stream<Item = T>>> {
434 self.externals
435 .get(&port.process_id)
436 .unwrap()
437 .as_bincode_source(port.port_id)
438 .await
439 }
440}