dfir_rs/util/
deploy.rs

1//! Hydro Deploy integration for DFIR.
2#![allow(clippy::allow_attributes, missing_docs, reason = "// TODO(mingwei)")]
3
4use std::cell::RefCell;
5use std::collections::HashMap;
6
7pub use hydro_deploy_integration::*;
8use serde::de::DeserializeOwned;
9
10use crate::scheduled::graph::Dfir;
11
/// Launches a DFIR program under Hydro Deploy.
///
/// Expands to an `async` block that:
/// 1. runs `init_no_ack_start` to bind and connect this process's ports,
/// 2. calls `$f(&ports)` to construct the `Dfir` flow from those ports,
/// 3. prints `ack start` — only after the flow has been constructed,
/// 4. drives the flow with `launch_flow` until it finishes or is stopped.
///
/// Note that, unlike calling `init` up front, the `ack start` line is
/// deferred until `$f` has returned.
#[macro_export]
macro_rules! launch {
    ($f:expr) => {
        async {
            let deploy_ports = $crate::util::deploy::init_no_ack_start().await;
            let dfir_flow = $f(&deploy_ports);

            // Acknowledge start only once the flow has been built.
            println!("ack start");

            $crate::util::deploy::launch_flow(dfir_flow).await
        }
    };
}
25
26pub use crate::launch;
27
28pub async fn launch_flow(mut flow: Dfir<'_>) {
29    let stop = tokio::sync::oneshot::channel();
30    tokio::task::spawn_blocking(|| {
31        let mut line = String::new();
32        std::io::stdin().read_line(&mut line).unwrap();
33        if line.starts_with("stop") {
34            stop.0.send(()).unwrap();
35        } else {
36            eprintln!("Unexpected stdin input: {:?}", line);
37        }
38    });
39
40    let local_set = tokio::task::LocalSet::new();
41    let flow = local_set.run_until(flow.run_async());
42
43    tokio::select! {
44        _ = stop.1 => {},
45        _ = flow => {}
46    }
47}
48
49/// Contains runtime information passed by Hydro Deploy to a program,
50/// describing how to connect to other services and metadata about them.
51pub struct DeployPorts<T = Option<()>> {
52    ports: RefCell<HashMap<String, Connection>>,
53    pub meta: T,
54}
55
56impl<T> DeployPorts<T> {
57    pub fn port(&self, name: &str) -> Connection {
58        self.ports
59            .try_borrow_mut()
60            .unwrap()
61            .remove(name)
62            .unwrap_or_else(|| panic!("port {} not found", name))
63    }
64}
65
66pub async fn init_no_ack_start<T: DeserializeOwned + Default>() -> DeployPorts<T> {
67    let mut input = String::new();
68    std::io::stdin().read_line(&mut input).unwrap();
69    let trimmed = input.trim();
70
71    let bind_config = serde_json::from_str::<InitConfig>(trimmed).unwrap();
72
73    // config telling other services how to connect to me
74    let mut bind_results: HashMap<String, ServerPort> = HashMap::new();
75    let mut binds = HashMap::new();
76    for (name, config) in bind_config.0 {
77        let bound = config.bind().await;
78        bind_results.insert(name.clone(), bound.server_port());
79        binds.insert(name.clone(), bound);
80    }
81
82    let bind_serialized = serde_json::to_string(&bind_results).unwrap();
83    println!("ready: {bind_serialized}");
84
85    let mut start_buf = String::new();
86    std::io::stdin().read_line(&mut start_buf).unwrap();
87    let connection_defns = if start_buf.starts_with("start: ") {
88        serde_json::from_str::<HashMap<String, ServerPort>>(
89            start_buf.trim_start_matches("start: ").trim(),
90        )
91        .unwrap()
92    } else {
93        panic!("expected start");
94    };
95
96    let mut all_connected = HashMap::new();
97    for (name, defn) in connection_defns {
98        all_connected.insert(name, Connection::AsClient(defn.connect()));
99    }
100
101    for (name, defn) in binds {
102        all_connected.insert(name, Connection::AsServer(defn));
103    }
104
105    DeployPorts {
106        ports: RefCell::new(all_connected),
107        meta: bind_config
108            .1
109            .map(|b| serde_json::from_str(&b).unwrap())
110            .unwrap_or_default(),
111    }
112}
113
114pub async fn init<T: DeserializeOwned + Default>() -> DeployPorts<T> {
115    let ret = init_no_ack_start::<T>().await;
116
117    println!("ack start");
118
119    ret
120}