hydro_deploy/rust_crate/
service.rs

1use std::collections::HashMap;
2use std::path::PathBuf;
3use std::sync::Arc;
4use std::time::Duration;
5
6use anyhow::{Context, Result, bail};
7use async_trait::async_trait;
8use futures::Future;
9use hydro_deploy_integration::{InitConfig, ServerPort};
10use serde::Serialize;
11use tokio::sync::{RwLock, mpsc};
12
13use super::build::{BuildError, BuildOutput, BuildParams, build_crate_memoized};
14use super::ports::{self, RustCratePortConfig, RustCrateSink, SourcePath};
15use super::tracing_options::TracingOptions;
16use crate::progress::ProgressTracker;
17use crate::{
18    Host, LaunchedBinary, LaunchedHost, ResourceBatch, ResourceResult, ServerStrategy, Service,
19    TracingResults,
20};
21
/// A service that builds a Rust crate and runs the resulting binary on a [`Host`].
///
/// Holds the build parameters, the port wiring to other services, and the launch
/// state (provisioned host, launched binary, and whether `start` has completed).
pub struct RustCrateService {
    /// Numeric identifier; used to form the fallback display name `service/{id}`.
    id: usize,
    /// The host this service is deployed on.
    pub(super) on: Arc<dyn Host>,
    /// Parameters handed to `build_crate_memoized` to build the binary.
    build_params: BuildParams,
    /// Optional tracing options, forwarded to `launch_binary` in `ready`.
    tracing: Option<TracingOptions>,
    /// Arguments passed to the binary at launch; `None` launches it with no arguments.
    args: Option<Vec<String>>,
    /// Optional name used for progress groups and the launched binary; falls back to `service/{id}`.
    display_id: Option<String>,
    /// Ports requested from the host as `ServerStrategy::ExternalTcpPort` in `collect_resources`.
    external_ports: Vec<u16>,

    /// JSON-serialized metadata set by `update_meta`; sent to the binary together with the
    /// bind configuration in `ready`.
    meta: Option<String>,

    /// Configuration for the ports this service will connect to as a client.
    pub(super) port_to_server: HashMap<String, ports::ServerConfig>,

    /// Configuration for the ports that this service will listen on a port for.
    pub(super) port_to_bind: HashMap<String, ServerStrategy>,

    /// The provisioned host; set by `deploy` after the binary has been copied to it.
    launched_host: Option<Arc<dyn LaunchedHost>>,

    /// A map of port names to config for how other services can connect to this one.
    /// Only valid after `ready` has been called (it is filled from the binary's `ready: ` line),
    /// only contains ports that are configured in `server_ports`.
    // NOTE(review): no `server_ports` field exists in this struct; presumably this refers to
    // `port_to_bind` — confirm.
    pub(super) server_defns: Arc<RwLock<HashMap<String, ServerPort>>>,

    /// The running binary; set by `ready` once the binary has reported that it is ready.
    launched_binary: Option<Box<dyn LaunchedBinary>>,
    /// Set by `start` after the binary has acknowledged the start message.
    started: bool,
}
49
50impl RustCrateService {
51    #[expect(clippy::too_many_arguments, reason = "internal code")]
52    pub fn new(
53        id: usize,
54        src: PathBuf,
55        on: Arc<dyn Host>,
56        bin: Option<String>,
57        example: Option<String>,
58        profile: Option<String>,
59        rustflags: Option<String>,
60        target_dir: Option<PathBuf>,
61        no_default_features: bool,
62        tracing: Option<TracingOptions>,
63        features: Option<Vec<String>>,
64        args: Option<Vec<String>>,
65        display_id: Option<String>,
66        external_ports: Vec<u16>,
67    ) -> Self {
68        let target_type = on.target_type();
69
70        let build_params = BuildParams::new(
71            src,
72            bin,
73            example,
74            profile,
75            rustflags,
76            target_dir,
77            no_default_features,
78            target_type,
79            features,
80        );
81
82        Self {
83            id,
84            on,
85            build_params,
86            tracing,
87            args,
88            display_id,
89            external_ports,
90            meta: None,
91            port_to_server: HashMap::new(),
92            port_to_bind: HashMap::new(),
93            launched_host: None,
94            server_defns: Arc::new(RwLock::new(HashMap::new())),
95            launched_binary: None,
96            started: false,
97        }
98    }
99
100    pub fn update_meta<T: Serialize>(&mut self, meta: T) {
101        if self.launched_binary.is_some() {
102            panic!("Cannot update meta after binary has been launched")
103        }
104
105        self.meta = Some(serde_json::to_string(&meta).unwrap());
106    }
107
108    pub fn get_port(
109        &self,
110        name: String,
111        self_arc: &Arc<RwLock<RustCrateService>>,
112    ) -> RustCratePortConfig {
113        RustCratePortConfig {
114            service: Arc::downgrade(self_arc),
115            service_host: self.on.clone(),
116            service_server_defns: self.server_defns.clone(),
117            port: name,
118            merge: false,
119        }
120    }
121
122    pub fn add_connection(
123        &mut self,
124        self_arc: &Arc<RwLock<RustCrateService>>,
125        my_port: String,
126        sink: &dyn RustCrateSink,
127    ) -> Result<()> {
128        let forward_res = sink.instantiate(&SourcePath::Direct(self.on.clone()));
129        if let Ok(instantiated) = forward_res {
130            // TODO(shadaj): if already in this map, we want to broadcast
131            assert!(!self.port_to_server.contains_key(&my_port));
132            self.port_to_server.insert(my_port, instantiated());
133            Ok(())
134        } else {
135            drop(forward_res);
136            let instantiated = sink.instantiate_reverse(
137                &self.on,
138                Arc::new(RustCratePortConfig {
139                    service: Arc::downgrade(self_arc),
140                    service_host: self.on.clone(),
141                    service_server_defns: self.server_defns.clone(),
142                    port: my_port.clone(),
143                    merge: false,
144                }),
145                &|p| p,
146            )?;
147
148            assert!(!self.port_to_bind.contains_key(&my_port));
149            self.port_to_bind
150                .insert(my_port, instantiated(sink.as_any()));
151
152            Ok(())
153        }
154    }
155
156    pub fn stdout(&self) -> mpsc::UnboundedReceiver<String> {
157        self.launched_binary.as_ref().unwrap().stdout()
158    }
159
160    pub fn stderr(&self) -> mpsc::UnboundedReceiver<String> {
161        self.launched_binary.as_ref().unwrap().stderr()
162    }
163
164    pub fn stdout_filter(&self, prefix: String) -> mpsc::UnboundedReceiver<String> {
165        self.launched_binary.as_ref().unwrap().stdout_filter(prefix)
166    }
167
168    pub fn stderr_filter(&self, prefix: String) -> mpsc::UnboundedReceiver<String> {
169        self.launched_binary.as_ref().unwrap().stderr_filter(prefix)
170    }
171
172    pub fn tracing_results(&self) -> Option<&TracingResults> {
173        self.launched_binary.as_ref().unwrap().tracing_results()
174    }
175
176    pub fn exit_code(&self) -> Option<i32> {
177        self.launched_binary.as_ref().unwrap().exit_code()
178    }
179
180    fn build(
181        &self,
182    ) -> impl use<> + 'static + Future<Output = Result<&'static BuildOutput, BuildError>> {
183        // Memoized, so no caching in `self` is needed.
184        build_crate_memoized(self.build_params.clone())
185    }
186}
187
188#[async_trait]
189impl Service for RustCrateService {
190    fn collect_resources(&self, _resource_batch: &mut ResourceBatch) {
191        if self.launched_host.is_some() {
192            return;
193        }
194
195        tokio::task::spawn(self.build());
196
197        let host = &self.on;
198
199        host.request_custom_binary();
200        for (_, bind_type) in self.port_to_bind.iter() {
201            host.request_port(bind_type);
202        }
203
204        for port in self.external_ports.iter() {
205            host.request_port(&ServerStrategy::ExternalTcpPort(*port));
206        }
207    }
208
209    async fn deploy(&mut self, resource_result: &Arc<ResourceResult>) -> Result<()> {
210        if self.launched_host.is_some() {
211            return Ok(());
212        }
213
214        ProgressTracker::with_group(
215            self.display_id
216                .clone()
217                .unwrap_or_else(|| format!("service/{}", self.id)),
218            None,
219            || async {
220                let built = ProgressTracker::leaf("build", self.build()).await?;
221
222                let host = &self.on;
223                let launched = host.provision(resource_result);
224
225                launched.copy_binary(built).await?;
226
227                self.launched_host = Some(launched);
228                Ok(())
229            },
230        )
231        .await
232    }
233
234    async fn ready(&mut self) -> Result<()> {
235        if self.launched_binary.is_some() {
236            return Ok(());
237        }
238
239        ProgressTracker::with_group(
240            self.display_id
241                .clone()
242                .unwrap_or_else(|| format!("service/{}", self.id)),
243            None,
244            || async {
245                let launched_host = self.launched_host.as_ref().unwrap();
246
247                let built = self.build().await?;
248                let args = self.args.as_ref().cloned().unwrap_or_default();
249
250                let binary = launched_host
251                    .launch_binary(
252                        self.display_id
253                            .clone()
254                            .unwrap_or_else(|| format!("service/{}", self.id)),
255                        built,
256                        &args,
257                        self.tracing.clone(),
258                    )
259                    .await?;
260
261                let mut bind_config = HashMap::new();
262                for (port_name, bind_type) in self.port_to_bind.iter() {
263                    bind_config.insert(port_name.clone(), launched_host.server_config(bind_type));
264                }
265
266                let formatted_bind_config =
267                    serde_json::to_string::<InitConfig>(&(bind_config, self.meta.clone())).unwrap();
268
269                // request stdout before sending config so we don't miss the "ready" response
270                let stdout_receiver = binary.deploy_stdout();
271
272                binary.stdin().send(format!("{formatted_bind_config}\n"))?;
273
274                let ready_line = ProgressTracker::leaf(
275                    "waiting for ready",
276                    tokio::time::timeout(Duration::from_secs(60), stdout_receiver),
277                )
278                .await
279                .context("Timed out waiting for ready")?
280                .context("Program unexpectedly quit")?;
281                if ready_line.starts_with("ready: ") {
282                    *self.server_defns.try_write().unwrap() =
283                        serde_json::from_str(ready_line.trim_start_matches("ready: ")).unwrap();
284                } else {
285                    bail!("expected ready");
286                }
287
288                self.launched_binary = Some(binary);
289
290                Ok(())
291            },
292        )
293        .await
294    }
295
296    async fn start(&mut self) -> Result<()> {
297        if self.started {
298            return Ok(());
299        }
300
301        let mut sink_ports = HashMap::new();
302        for (port_name, outgoing) in self.port_to_server.drain() {
303            sink_ports.insert(port_name.clone(), outgoing.load_instantiated(&|p| p).await);
304        }
305
306        let formatted_defns = serde_json::to_string(&sink_ports).unwrap();
307
308        let stdout_receiver = self.launched_binary.as_ref().unwrap().deploy_stdout();
309
310        self.launched_binary
311            .as_ref()
312            .unwrap()
313            .stdin()
314            .send(format!("start: {formatted_defns}\n"))
315            .unwrap();
316
317        let start_ack_line = ProgressTracker::leaf(
318            self.display_id
319                .clone()
320                .unwrap_or_else(|| format!("service/{}", self.id))
321                + " / waiting for ack start",
322            tokio::time::timeout(Duration::from_secs(60), stdout_receiver),
323        )
324        .await??;
325        if !start_ack_line.starts_with("ack start") {
326            bail!("expected ack start");
327        }
328
329        self.started = true;
330        Ok(())
331    }
332
333    async fn stop(&mut self) -> Result<()> {
334        ProgressTracker::with_group(
335            self.display_id
336                .clone()
337                .unwrap_or_else(|| format!("service/{}", self.id)),
338            None,
339            || async {
340                let launched_binary = self.launched_binary.as_mut().unwrap();
341                launched_binary.stdin().send("stop\n".to_string())?;
342
343                let timeout_result = ProgressTracker::leaf(
344                    "waiting for exit",
345                    tokio::time::timeout(Duration::from_secs(60), launched_binary.wait()),
346                )
347                .await;
348                match timeout_result {
349                    Err(_timeout) => {} // `wait()` timed out, but stop will force quit.
350                    Ok(Err(unexpected_error)) => return Err(unexpected_error), // `wait()` errored.
351                    Ok(Ok(_exit_status)) => {}
352                }
353                launched_binary.stop().await?;
354
355                Ok(())
356            },
357        )
358        .await
359    }
360}