1#![allow(clippy::allow_attributes, missing_docs, reason = "// TODO(mingwei)")]
3
4use std::cell::RefCell;
5use std::collections::HashMap;
6
7use futures::StreamExt;
8use futures::stream::FuturesUnordered;
9pub use hydro_deploy_integration::*;
10use serde::de::DeserializeOwned;
11
12use crate::scheduled::graph::Dfir;
13
/// Runs the standard deployment startup sequence around a flow-building expression.
///
/// `$f` is an expression that is called with a reference to the deploy ports and must
/// evaluate to the flow to run. The macro expands to an `async` block that:
///
/// 1. awaits `init_no_ack_start()` to obtain the ports,
/// 2. calls `$f(&ports)` to build the flow,
/// 3. prints `ack start` to stdout, and
/// 4. awaits `launch_flow` on the flow that was built.
#[macro_export]
macro_rules! launch {
    ($f:expr) => {
        async {
            let deploy_ports = $crate::util::deploy::init_no_ack_start().await;
            let built_flow = $f(&deploy_ports);

            // The acknowledgement is only printed once the flow has been constructed.
            println!("ack start");

            $crate::util::deploy::launch_flow(built_flow).await
        }
    };
}
27
28pub use crate::launch;
29
30pub async fn launch_flow(mut flow: Dfir<'_>) {
31 let stop = tokio::sync::oneshot::channel();
32 tokio::task::spawn_blocking(|| {
33 let mut line = String::new();
34 std::io::stdin().read_line(&mut line).unwrap();
35 if line.starts_with("stop") {
36 stop.0.send(()).unwrap();
37 } else {
38 eprintln!("Unexpected stdin input: {:?}", line);
39 }
40 });
41
42 let local_set = tokio::task::LocalSet::new();
43 let flow = local_set.run_until(flow.run_async());
44
45 tokio::select! {
46 _ = stop.1 => {},
47 _ = flow => {}
48 }
49}
50
51pub async fn init_no_ack_start<T: DeserializeOwned + Default>() -> DeployPorts<T> {
52 let mut input = String::new();
53 std::io::stdin().read_line(&mut input).unwrap();
54 let trimmed = input.trim();
55
56 let bind_config = serde_json::from_str::<InitConfig>(trimmed).unwrap();
57
58 let mut bind_results: HashMap<String, ServerPort> = HashMap::new();
60 let mut binds = HashMap::new();
61 for (name, config) in bind_config.0 {
62 let bound = config.bind().await;
63 bind_results.insert(name.clone(), bound.server_port());
64 binds.insert(name.clone(), bound);
65 }
66
67 let bind_serialized = serde_json::to_string(&bind_results).unwrap();
68 println!("ready: {bind_serialized}");
69
70 let mut start_buf = String::new();
71 std::io::stdin().read_line(&mut start_buf).unwrap();
72 let connection_defns = if start_buf.starts_with("start: ") {
73 serde_json::from_str::<HashMap<String, ServerPort>>(
74 start_buf.trim_start_matches("start: ").trim(),
75 )
76 .unwrap()
77 } else {
78 panic!("expected start");
79 };
80
81 let (client_conns, server_conns) = futures::join!(
82 connection_defns
83 .into_iter()
84 .map(|(name, defn)| async move { (name, Connection::AsClient(defn.connect().await)) })
85 .collect::<FuturesUnordered<_>>()
86 .collect::<Vec<_>>(),
87 binds
88 .into_iter()
89 .map(
90 |(name, defn)| async move { (name, Connection::AsServer(accept_bound(defn).await)) }
91 )
92 .collect::<FuturesUnordered<_>>()
93 .collect::<Vec<_>>()
94 );
95
96 let all_connected = client_conns
97 .into_iter()
98 .chain(server_conns.into_iter())
99 .collect();
100
101 DeployPorts {
102 ports: RefCell::new(all_connected),
103 meta: bind_config
104 .1
105 .map(|b| serde_json::from_str(&b).unwrap())
106 .unwrap_or_default(),
107 }
108}
109
110pub async fn init<T: DeserializeOwned + Default>() -> DeployPorts<T> {
111 let ret = init_no_ack_start::<T>().await;
112
113 println!("ack start");
114
115 ret
116}