dfir_rs/util/
deploy.rs
1#![allow(clippy::allow_attributes, missing_docs, reason = "// TODO(mingwei)")]
3
4use std::cell::RefCell;
5use std::collections::HashMap;
6
7pub use hydro_deploy_integration::*;
8use serde::de::DeserializeOwned;
9
10use crate::scheduled::graph::Dfir;
11
/// Expands to an `async` block that performs the deployment handshake, builds a flow, and runs it.
///
/// The expression `$f` is called with a reference to the value returned by
/// `$crate::util::deploy::init_no_ack_start()`, and whatever it returns is handed to
/// `$crate::util::deploy::launch_flow`. The line `ack start` is printed to stdout after `$f`
/// returns and before the flow is launched.
#[macro_export]
macro_rules! launch {
    ($f:expr) => {
        async {
            let deploy_ports = $crate::util::deploy::init_no_ack_start().await;
            let dfir = $f(&deploy_ports);

            // Acknowledge only once the flow has been constructed.
            println!("ack start");

            $crate::util::deploy::launch_flow(dfir).await
        }
    };
}
25
26pub use crate::launch;
27
28pub async fn launch_flow(mut flow: Dfir<'_>) {
29 let stop = tokio::sync::oneshot::channel();
30 tokio::task::spawn_blocking(|| {
31 let mut line = String::new();
32 std::io::stdin().read_line(&mut line).unwrap();
33 if line.starts_with("stop") {
34 stop.0.send(()).unwrap();
35 } else {
36 eprintln!("Unexpected stdin input: {:?}", line);
37 }
38 });
39
40 let local_set = tokio::task::LocalSet::new();
41 let flow = local_set.run_until(flow.run_async());
42
43 tokio::select! {
44 _ = stop.1 => {},
45 _ = flow => {}
46 }
47}
48
49pub struct DeployPorts<T = Option<()>> {
52 ports: RefCell<HashMap<String, Connection>>,
53 pub meta: T,
54}
55
56impl<T> DeployPorts<T> {
57 pub fn port(&self, name: &str) -> Connection {
58 self.ports
59 .try_borrow_mut()
60 .unwrap()
61 .remove(name)
62 .unwrap_or_else(|| panic!("port {} not found", name))
63 }
64}
65
66pub async fn init_no_ack_start<T: DeserializeOwned + Default>() -> DeployPorts<T> {
67 let mut input = String::new();
68 std::io::stdin().read_line(&mut input).unwrap();
69 let trimmed = input.trim();
70
71 let bind_config = serde_json::from_str::<InitConfig>(trimmed).unwrap();
72
73 let mut bind_results: HashMap<String, ServerPort> = HashMap::new();
75 let mut binds = HashMap::new();
76 for (name, config) in bind_config.0 {
77 let bound = config.bind().await;
78 bind_results.insert(name.clone(), bound.server_port());
79 binds.insert(name.clone(), bound);
80 }
81
82 let bind_serialized = serde_json::to_string(&bind_results).unwrap();
83 println!("ready: {bind_serialized}");
84
85 let mut start_buf = String::new();
86 std::io::stdin().read_line(&mut start_buf).unwrap();
87 let connection_defns = if start_buf.starts_with("start: ") {
88 serde_json::from_str::<HashMap<String, ServerPort>>(
89 start_buf.trim_start_matches("start: ").trim(),
90 )
91 .unwrap()
92 } else {
93 panic!("expected start");
94 };
95
96 let mut all_connected = HashMap::new();
97 for (name, defn) in connection_defns {
98 all_connected.insert(name, Connection::AsClient(defn.connect()));
99 }
100
101 for (name, defn) in binds {
102 all_connected.insert(name, Connection::AsServer(defn));
103 }
104
105 DeployPorts {
106 ports: RefCell::new(all_connected),
107 meta: bind_config
108 .1
109 .map(|b| serde_json::from_str(&b).unwrap())
110 .unwrap_or_default(),
111 }
112}
113
114pub async fn init<T: DeserializeOwned + Default>() -> DeployPorts<T> {
115 let ret = init_no_ack_start::<T>().await;
116
117 println!("ack start");
118
119 ret
120}