Skip to main content

hydro_deploy/rust_crate/
service.rs

1use std::collections::HashMap;
2use std::sync::{Arc, OnceLock};
3use std::time::Duration;
4
5use anyhow::{Context, Result, bail};
6use async_trait::async_trait;
7use futures::Future;
8use hydro_deploy_integration::{InitConfig, ServerPort};
9use memo_map::MemoMap;
10use serde::Serialize;
11use tokio::sync::{OnceCell, RwLock, mpsc};
12
13use super::build::{BuildError, BuildOutput, BuildParams, build_crate_memoized};
14use super::ports::{self, RustCratePortConfig};
15use super::tracing_options::TracingOptions;
16#[cfg(feature = "profile-folding")]
17use crate::TracingResults;
18use crate::progress::ProgressTracker;
19use crate::{
20    BaseServerStrategy, Host, LaunchedBinary, LaunchedHost, PortNetworkHint, ResourceBatch,
21    ResourceResult, ServerStrategy, Service,
22};
23
24pub struct RustCrateService {
25    id: usize,
26    pub(super) on: Arc<dyn Host>,
27    build_params: BuildParams,
28    tracing: Option<TracingOptions>,
29    args: Option<Vec<String>>,
30    display_id: Option<String>,
31    external_ports: Vec<u16>,
32    env: HashMap<String, String>,
33
34    meta: OnceLock<String>,
35
36    /// Configuration for the ports this service will connect to as a client.
37    pub(super) port_to_server: MemoMap<String, ports::ServerConfig>,
38    /// Configuration for the ports that this service will listen on a port for.
39    pub(super) port_to_bind: MemoMap<String, ServerStrategy>,
40
41    launched_host: OnceCell<Arc<dyn LaunchedHost>>,
42
43    /// A map of port names to config for how other services can connect to this one.
44    /// Only valid after `ready` has been called, only contains ports that are configured
45    /// in `server_ports`.
46    pub(super) server_defns: Arc<RwLock<HashMap<String, ServerPort>>>,
47
48    launched_binary: OnceCell<Box<dyn LaunchedBinary>>,
49    started: OnceCell<()>,
50}
51
52impl RustCrateService {
53    #[expect(clippy::too_many_arguments, reason = "internal use")]
54    pub fn new(
55        id: usize,
56        on: Arc<dyn Host>,
57        build_params: BuildParams,
58        tracing: Option<TracingOptions>,
59        args: Option<Vec<String>>,
60        display_id: Option<String>,
61        external_ports: Vec<u16>,
62        env: HashMap<String, String>,
63    ) -> Self {
64        Self {
65            id,
66            on,
67            build_params,
68            tracing,
69            args,
70            display_id,
71            external_ports,
72            env,
73            meta: OnceLock::new(),
74            port_to_server: MemoMap::new(),
75            port_to_bind: MemoMap::new(),
76            launched_host: OnceCell::new(),
77            server_defns: Arc::new(RwLock::new(HashMap::new())),
78            launched_binary: OnceCell::new(),
79            started: OnceCell::new(),
80        }
81    }
82
83    pub fn update_meta<T: Serialize>(&self, meta: T) {
84        if self.launched_binary.get().is_some() {
85            panic!("Cannot update meta after binary has been launched")
86        }
87        self.meta
88            .set(serde_json::to_string(&meta).unwrap())
89            .expect("Cannot set meta twice.");
90    }
91
92    pub fn get_port(self: &Arc<Self>, name: String) -> RustCratePortConfig {
93        RustCratePortConfig {
94            service: Arc::downgrade(self),
95            service_host: self.on.clone(),
96            service_server_defns: self.server_defns.clone(),
97            network_hint: PortNetworkHint::Auto,
98            port: name,
99            merge: false,
100        }
101    }
102
103    pub fn get_port_with_hint(
104        self: &Arc<Self>,
105        name: String,
106        network_hint: PortNetworkHint,
107    ) -> RustCratePortConfig {
108        RustCratePortConfig {
109            service: Arc::downgrade(self),
110            service_host: self.on.clone(),
111            service_server_defns: self.server_defns.clone(),
112            network_hint,
113            port: name,
114            merge: false,
115        }
116    }
117
118    pub fn stdout(&self) -> mpsc::UnboundedReceiver<String> {
119        self.launched_binary.get().unwrap().stdout()
120    }
121
122    pub fn stderr(&self) -> mpsc::UnboundedReceiver<String> {
123        self.launched_binary.get().unwrap().stderr()
124    }
125
126    pub fn stdout_filter(&self, prefix: String) -> mpsc::UnboundedReceiver<String> {
127        self.launched_binary.get().unwrap().stdout_filter(prefix)
128    }
129
130    pub fn stderr_filter(&self, prefix: String) -> mpsc::UnboundedReceiver<String> {
131        self.launched_binary.get().unwrap().stderr_filter(prefix)
132    }
133
134    #[cfg(feature = "profile-folding")]
135    pub fn tracing_results(&self) -> Option<&TracingResults> {
136        self.launched_binary.get().unwrap().tracing_results()
137    }
138
139    pub fn exit_code(&self) -> Option<i32> {
140        self.launched_binary.get().unwrap().exit_code()
141    }
142
143    fn build(
144        &self,
145    ) -> impl use<> + 'static + Future<Output = Result<&'static BuildOutput, BuildError>> {
146        // Memoized, so no caching in `self` is needed.
147        build_crate_memoized(self.build_params.clone())
148    }
149}
150
151#[async_trait]
152impl Service for RustCrateService {
153    fn collect_resources(&self, _resource_batch: &mut ResourceBatch) {
154        if self.launched_host.get().is_some() {
155            return;
156        }
157
158        tokio::task::spawn(self.build());
159
160        let host = &self.on;
161
162        host.request_custom_binary();
163        for (_, bind_type) in self.port_to_bind.iter() {
164            host.request_port(bind_type);
165        }
166
167        for port in self.external_ports.iter() {
168            host.request_port_base(&BaseServerStrategy::ExternalTcpPort(*port));
169        }
170    }
171
172    async fn deploy(&self, resource_result: &Arc<ResourceResult>) -> Result<()> {
173        self.launched_host
174            .get_or_try_init::<anyhow::Error, _, _>(|| {
175                ProgressTracker::with_group(
176                    self.display_id
177                        .clone()
178                        .unwrap_or_else(|| format!("service/{}", self.id)),
179                    None,
180                    || async {
181                        let built = self.build().await?;
182
183                        let host = &self.on;
184                        let launched = host.provision(resource_result);
185
186                        launched.copy_binary(built).await?;
187                        Ok(launched)
188                    },
189                )
190            })
191            .await?;
192        Ok(())
193    }
194
195    async fn ready(&self) -> Result<()> {
196        self.launched_binary
197            .get_or_try_init(|| {
198                ProgressTracker::with_group(
199                    self.display_id
200                        .clone()
201                        .unwrap_or_else(|| format!("service/{}", self.id)),
202                    None,
203                    || async {
204                        let launched_host = self.launched_host.get().unwrap();
205
206                        let built = self.build().await?;
207                        let args = self.args.as_ref().cloned().unwrap_or_default();
208
209                        let binary = launched_host
210                            .launch_binary(
211                                self.display_id
212                                    .clone()
213                                    .unwrap_or_else(|| format!("service/{}", self.id)),
214                                built,
215                                &args,
216                                self.tracing.clone(),
217                                &self.env,
218                            )
219                            .await?;
220
221                        let bind_config = self
222                            .port_to_bind
223                            .iter()
224                            .map(|(port_name, bind_type)| {
225                                (port_name.clone(), launched_host.server_config(bind_type))
226                            })
227                            .collect::<HashMap<_, _>>();
228
229                        let formatted_bind_config = serde_json::to_string::<InitConfig>(&(
230                            bind_config,
231                            self.meta.get().map(|s| s.as_str().into()),
232                        ))
233                        .unwrap();
234
235                        // request stdout before sending config so we don't miss the "ready" response
236                        let stdout_receiver = binary.deploy_stdout();
237
238                        binary.stdin().send(format!("{formatted_bind_config}\n"))?;
239
240                        let ready_line = ProgressTracker::leaf(
241                            "waiting for ready",
242                            tokio::time::timeout(Duration::from_secs(60), stdout_receiver),
243                        )
244                        .await
245                        .context("Timed out waiting for ready")?
246                        .context("Program unexpectedly quit")?;
247                        if let Some(line_rest) = ready_line.strip_prefix("ready: ") {
248                            *self.server_defns.try_write().unwrap() =
249                                serde_json::from_str(line_rest).unwrap();
250                        } else {
251                            bail!("expected ready");
252                        }
253                        Ok(binary)
254                    },
255                )
256            })
257            .await?;
258        Ok(())
259    }
260
261    async fn start(&self) -> Result<()> {
262        self.started
263            .get_or_try_init(|| async {
264                let sink_ports_futures =
265                    self.port_to_server
266                        .iter()
267                        .map(|(port_name, outgoing)| async {
268                            (&**port_name, outgoing.load_instantiated(&|p| p).await)
269                        });
270                let sink_ports = futures::future::join_all(sink_ports_futures)
271                    .await
272                    .into_iter()
273                    .collect::<HashMap<_, _>>();
274
275                let formatted_defns = serde_json::to_string(&sink_ports).unwrap();
276
277                let stdout_receiver = self.launched_binary.get().unwrap().deploy_stdout();
278
279                self.launched_binary
280                    .get()
281                    .unwrap()
282                    .stdin()
283                    .send(format!("start: {formatted_defns}\n"))
284                    .unwrap();
285
286                let start_ack_line = ProgressTracker::leaf(
287                    self.display_id
288                        .clone()
289                        .unwrap_or_else(|| format!("service/{}", self.id))
290                        + " / waiting for ack start",
291                    tokio::time::timeout(Duration::from_secs(60), stdout_receiver),
292                )
293                .await??;
294                if !start_ack_line.starts_with("ack start") {
295                    bail!("expected ack start");
296                }
297
298                Ok(())
299            })
300            .await?;
301
302        Ok(())
303    }
304
305    async fn stop(&self) -> Result<()> {
306        ProgressTracker::with_group(
307            self.display_id
308                .clone()
309                .unwrap_or_else(|| format!("service/{}", self.id)),
310            None,
311            || async {
312                let launched_binary = self.launched_binary.get().unwrap();
313                launched_binary.stdin().send("stop\n".to_owned())?;
314
315                let timeout_result = ProgressTracker::leaf(
316                    "waiting for exit",
317                    tokio::time::timeout(Duration::from_secs(60), launched_binary.wait()),
318                )
319                .await;
320                match timeout_result {
321                    Err(_timeout) => {} // `wait()` timed out, but stop will force quit.
322                    Ok(Err(unexpected_error)) => return Err(unexpected_error), // `wait()` errored.
323                    Ok(Ok(_exit_status)) => {}
324                }
325                launched_binary.stop().await?;
326
327                Ok(())
328            },
329        )
330        .await
331    }
332}