dfir_rs/util/
deploy.rs
1#![allow(clippy::allow_attributes, missing_docs, reason = "// TODO(mingwei)")]
3
4use std::cell::RefCell;
5use std::collections::HashMap;
6
7pub use hydro_deploy_integration::*;
8use serde::de::DeserializeOwned;
9
10use crate::scheduled::graph::Dfir;
11
/// Starts a deployed DFIR program.
///
/// Expands to an `async` block that, in order:
/// 1. awaits `init_no_ack_start` (in `util::deploy`) to obtain the deploy ports,
/// 2. builds the flow by calling `$f` with a reference to those ports,
/// 3. prints `ack start` to stdout, and
/// 4. awaits `launch_flow` (in `util::deploy`) with the constructed flow.
///
/// `$f` must be callable with `&DeployPorts<T>` and return a `Dfir` flow; the metadata type
/// `T` is inferred from `$f`'s parameter type.
#[macro_export]
macro_rules! launch {
    ($f:expr) => {
        async {
            let deploy_ports = $crate::util::deploy::init_no_ack_start().await;
            let dfir_flow = $f(&deploy_ports);

            // `ack start` is emitted only after `$f` has finished building the flow.
            println!("ack start");

            $crate::util::deploy::launch_flow(dfir_flow).await
        }
    };
}
25
26pub use crate::launch;
27
28pub async fn launch_flow(mut flow: Dfir<'_>) {
29 let stop = tokio::sync::oneshot::channel();
30 tokio::task::spawn_blocking(|| {
31 let mut line = String::new();
32 std::io::stdin().read_line(&mut line).unwrap();
33 if line.starts_with("stop") {
34 stop.0.send(()).unwrap();
35 } else {
36 eprintln!("Unexpected stdin input: {:?}", line);
37 }
38 });
39
40 let local_set = tokio::task::LocalSet::new();
41 let flow = local_set.run_until(flow.run_async());
42
43 tokio::select! {
44 _ = stop.1 => {},
45 _ = flow => {}
46 }
47}
48
49pub async fn init_no_ack_start<T: DeserializeOwned + Default>() -> DeployPorts<T> {
50 let mut input = String::new();
51 std::io::stdin().read_line(&mut input).unwrap();
52 let trimmed = input.trim();
53
54 let bind_config = serde_json::from_str::<InitConfig>(trimmed).unwrap();
55
56 let mut bind_results: HashMap<String, ServerPort> = HashMap::new();
58 let mut binds = HashMap::new();
59 for (name, config) in bind_config.0 {
60 let bound = config.bind().await;
61 bind_results.insert(name.clone(), bound.server_port());
62 binds.insert(name.clone(), bound);
63 }
64
65 let bind_serialized = serde_json::to_string(&bind_results).unwrap();
66 println!("ready: {bind_serialized}");
67
68 let mut start_buf = String::new();
69 std::io::stdin().read_line(&mut start_buf).unwrap();
70 let connection_defns = if start_buf.starts_with("start: ") {
71 serde_json::from_str::<HashMap<String, ServerPort>>(
72 start_buf.trim_start_matches("start: ").trim(),
73 )
74 .unwrap()
75 } else {
76 panic!("expected start");
77 };
78
79 let mut all_connected = HashMap::new();
80 for (name, defn) in connection_defns {
81 all_connected.insert(name, Connection::AsClient(defn.connect()));
82 }
83
84 for (name, defn) in binds {
85 all_connected.insert(name, Connection::AsServer(defn));
86 }
87
88 DeployPorts {
89 ports: RefCell::new(all_connected),
90 meta: bind_config
91 .1
92 .map(|b| serde_json::from_str(&b).unwrap())
93 .unwrap_or_default(),
94 }
95}
96
97pub async fn init<T: DeserializeOwned + Default>() -> DeployPorts<T> {
98 let ret = init_no_ack_start::<T>().await;
99
100 println!("ack start");
101
102 ret
103}