dfir_rs/util/
deploy.rs

1//! Hydro Deploy integration for DFIR.
2#![allow(clippy::allow_attributes, missing_docs, reason = "// TODO(mingwei)")]
3
4use std::cell::RefCell;
5use std::collections::HashMap;
6
7use futures::StreamExt;
8use futures::stream::FuturesUnordered;
9pub use hydro_deploy_integration::*;
10use serde::de::DeserializeOwned;
11
12use crate::scheduled::graph::Dfir;
13
/// Expands to a future that performs the Hydro Deploy startup sequence and then runs a flow.
///
/// `$f` is called with a reference to the ports returned by
/// `$crate::util::deploy::init_no_ack_start`, and must produce the `Dfir` flow to run.
/// `ack start` is printed to stdout only after `$f` has built the flow. The flow is then
/// handed to `$crate::util::deploy::launch_flow`, and the future resolves when that returns.
#[macro_export]
macro_rules! launch {
    ($f:expr) => {
        async {
            let deploy_ports = $crate::util::deploy::init_no_ack_start().await;
            let dfir_flow = $f(&deploy_ports);

            // Acknowledge the start only once the flow has been constructed.
            println!("ack start");

            $crate::util::deploy::launch_flow(dfir_flow).await
        }
    };
}
27
28pub use crate::launch;
29
30pub async fn launch_flow(mut flow: Dfir<'_>) {
31    let stop = tokio::sync::oneshot::channel();
32    tokio::task::spawn_blocking(|| {
33        let mut line = String::new();
34        std::io::stdin().read_line(&mut line).unwrap();
35        if line.starts_with("stop") {
36            stop.0.send(()).unwrap();
37        } else {
38            eprintln!("Unexpected stdin input: {:?}", line);
39        }
40    });
41
42    let local_set = tokio::task::LocalSet::new();
43    let flow = local_set.run_until(flow.run_async());
44
45    tokio::select! {
46        _ = stop.1 => {},
47        _ = flow => {}
48    }
49}
50
51pub async fn init_no_ack_start<T: DeserializeOwned + Default>() -> DeployPorts<T> {
52    let mut input = String::new();
53    std::io::stdin().read_line(&mut input).unwrap();
54    let trimmed = input.trim();
55
56    let bind_config = serde_json::from_str::<InitConfig>(trimmed).unwrap();
57
58    // config telling other services how to connect to me
59    let mut bind_results: HashMap<String, ServerPort> = HashMap::new();
60    let mut binds = HashMap::new();
61    for (name, config) in bind_config.0 {
62        let bound = config.bind().await;
63        bind_results.insert(name.clone(), bound.server_port());
64        binds.insert(name.clone(), bound);
65    }
66
67    let bind_serialized = serde_json::to_string(&bind_results).unwrap();
68    println!("ready: {bind_serialized}");
69
70    let mut start_buf = String::new();
71    std::io::stdin().read_line(&mut start_buf).unwrap();
72    let connection_defns = if start_buf.starts_with("start: ") {
73        serde_json::from_str::<HashMap<String, ServerPort>>(
74            start_buf.trim_start_matches("start: ").trim(),
75        )
76        .unwrap()
77    } else {
78        panic!("expected start");
79    };
80
81    let (client_conns, server_conns) = futures::join!(
82        connection_defns
83            .into_iter()
84            .map(|(name, defn)| async move { (name, Connection::AsClient(defn.connect().await)) })
85            .collect::<FuturesUnordered<_>>()
86            .collect::<Vec<_>>(),
87        binds
88            .into_iter()
89            .map(
90                |(name, defn)| async move { (name, Connection::AsServer(accept_bound(defn).await)) }
91            )
92            .collect::<FuturesUnordered<_>>()
93            .collect::<Vec<_>>()
94    );
95
96    let all_connected = client_conns
97        .into_iter()
98        .chain(server_conns.into_iter())
99        .collect();
100
101    DeployPorts {
102        ports: RefCell::new(all_connected),
103        meta: bind_config
104            .1
105            .map(|b| serde_json::from_str(&b).unwrap())
106            .unwrap_or_default(),
107    }
108}
109
110pub async fn init<T: DeserializeOwned + Default>() -> DeployPorts<T> {
111    let ret = init_no_ack_start::<T>().await;
112
113    println!("ack start");
114
115    ret
116}