dfir_rs/util/
deploy.rs

1//! Hydro Deploy integration for DFIR.
2#![allow(clippy::allow_attributes, missing_docs, reason = "// TODO(mingwei)")]
3
4use std::cell::RefCell;
5use std::collections::HashMap;
6
7pub use hydro_deploy_integration::*;
8use serde::de::DeserializeOwned;
9
10use crate::scheduled::graph::Dfir;
11
/// Builds an `async` block that initializes the deploy ports, constructs a flow from them,
/// and then runs that flow via [`launch_flow`](crate::util::deploy::launch_flow).
///
/// `$f` is an expression that is called with a reference to the value returned by
/// [`init_no_ack_start`](crate::util::deploy::init_no_ack_start) and must return the flow
/// to run. `ack start` is printed to stdout only after `$f` has returned, i.e. once the
/// flow has been constructed but before it starts running.
#[macro_export]
macro_rules! launch {
    ($f:expr) => {
        async {
            // Handshake over stdin/stdout; the start acknowledgement is deferred until
            // the flow has been built below.
            let deploy_ports = $crate::util::deploy::init_no_ack_start().await;
            let dfir = $f(&deploy_ports);

            println!("ack start");

            $crate::util::deploy::launch_flow(dfir).await
        }
    };
}
25
26pub use crate::launch;
27
28pub async fn launch_flow(mut flow: Dfir<'_>) {
29    let stop = tokio::sync::oneshot::channel();
30    tokio::task::spawn_blocking(|| {
31        let mut line = String::new();
32        std::io::stdin().read_line(&mut line).unwrap();
33        if line.starts_with("stop") {
34            stop.0.send(()).unwrap();
35        } else {
36            eprintln!("Unexpected stdin input: {:?}", line);
37        }
38    });
39
40    let local_set = tokio::task::LocalSet::new();
41    let flow = local_set.run_until(flow.run_async());
42
43    tokio::select! {
44        _ = stop.1 => {},
45        _ = flow => {}
46    }
47}
48
49pub async fn init_no_ack_start<T: DeserializeOwned + Default>() -> DeployPorts<T> {
50    let mut input = String::new();
51    std::io::stdin().read_line(&mut input).unwrap();
52    let trimmed = input.trim();
53
54    let bind_config = serde_json::from_str::<InitConfig>(trimmed).unwrap();
55
56    // config telling other services how to connect to me
57    let mut bind_results: HashMap<String, ServerPort> = HashMap::new();
58    let mut binds = HashMap::new();
59    for (name, config) in bind_config.0 {
60        let bound = config.bind().await;
61        bind_results.insert(name.clone(), bound.server_port());
62        binds.insert(name.clone(), bound);
63    }
64
65    let bind_serialized = serde_json::to_string(&bind_results).unwrap();
66    println!("ready: {bind_serialized}");
67
68    let mut start_buf = String::new();
69    std::io::stdin().read_line(&mut start_buf).unwrap();
70    let connection_defns = if start_buf.starts_with("start: ") {
71        serde_json::from_str::<HashMap<String, ServerPort>>(
72            start_buf.trim_start_matches("start: ").trim(),
73        )
74        .unwrap()
75    } else {
76        panic!("expected start");
77    };
78
79    let mut all_connected = HashMap::new();
80    for (name, defn) in connection_defns {
81        all_connected.insert(name, Connection::AsClient(defn.connect()));
82    }
83
84    for (name, defn) in binds {
85        all_connected.insert(name, Connection::AsServer(defn));
86    }
87
88    DeployPorts {
89        ports: RefCell::new(all_connected),
90        meta: bind_config
91            .1
92            .map(|b| serde_json::from_str(&b).unwrap())
93            .unwrap_or_default(),
94    }
95}
96
97pub async fn init<T: DeserializeOwned + Default>() -> DeployPorts<T> {
98    let ret = init_no_ack_start::<T>().await;
99
100    println!("ack start");
101
102    ret
103}