hydro_deploy/rust_crate/service.rs

1use std::collections::HashMap;
2use std::sync::{Arc, OnceLock};
3use std::time::Duration;
4
5use anyhow::{Context, Result, bail};
6use async_trait::async_trait;
7use futures::Future;
8use hydro_deploy_integration::{InitConfig, ServerPort};
9use memo_map::MemoMap;
10use serde::Serialize;
11use tokio::sync::{OnceCell, RwLock, mpsc};
12
13use super::build::{BuildError, BuildOutput, BuildParams, build_crate_memoized};
14use super::ports::{self, RustCratePortConfig};
15use super::tracing_options::TracingOptions;
16#[cfg(feature = "profile-folding")]
17use crate::TracingResults;
18use crate::progress::ProgressTracker;
19use crate::{
20    BaseServerStrategy, Host, LaunchedBinary, LaunchedHost, PortNetworkHint, ResourceBatch,
21    ResourceResult, ServerStrategy, Service,
22};
23
24pub struct RustCrateService {
25    id: usize,
26    pub(super) on: Arc<dyn Host>,
27    build_params: BuildParams,
28    tracing: Option<TracingOptions>,
29    args: Option<Vec<String>>,
30    display_id: Option<String>,
31    external_ports: Vec<u16>,
32
33    meta: OnceLock<String>,
34
35    /// Configuration for the ports this service will connect to as a client.
36    pub(super) port_to_server: MemoMap<String, ports::ServerConfig>,
37    /// Configuration for the ports that this service will listen on a port for.
38    pub(super) port_to_bind: MemoMap<String, ServerStrategy>,
39
40    launched_host: OnceCell<Arc<dyn LaunchedHost>>,
41
42    /// A map of port names to config for how other services can connect to this one.
43    /// Only valid after `ready` has been called, only contains ports that are configured
44    /// in `server_ports`.
45    pub(super) server_defns: Arc<RwLock<HashMap<String, ServerPort>>>,
46
47    launched_binary: OnceCell<Box<dyn LaunchedBinary>>,
48    started: OnceCell<()>,
49}
50
51impl RustCrateService {
52    pub fn new(
53        id: usize,
54        on: Arc<dyn Host>,
55        build_params: BuildParams,
56        tracing: Option<TracingOptions>,
57        args: Option<Vec<String>>,
58        display_id: Option<String>,
59        external_ports: Vec<u16>,
60    ) -> Self {
61        Self {
62            id,
63            on,
64            build_params,
65            tracing,
66            args,
67            display_id,
68            external_ports,
69            meta: OnceLock::new(),
70            port_to_server: MemoMap::new(),
71            port_to_bind: MemoMap::new(),
72            launched_host: OnceCell::new(),
73            server_defns: Arc::new(RwLock::new(HashMap::new())),
74            launched_binary: OnceCell::new(),
75            started: OnceCell::new(),
76        }
77    }
78
79    pub fn update_meta<T: Serialize>(&self, meta: T) {
80        if self.launched_binary.get().is_some() {
81            panic!("Cannot update meta after binary has been launched")
82        }
83        self.meta
84            .set(serde_json::to_string(&meta).unwrap())
85            .expect("Cannot set meta twice.");
86    }
87
88    pub fn get_port(self: &Arc<Self>, name: String) -> RustCratePortConfig {
89        RustCratePortConfig {
90            service: Arc::downgrade(self),
91            service_host: self.on.clone(),
92            service_server_defns: self.server_defns.clone(),
93            network_hint: PortNetworkHint::Auto,
94            port: name,
95            merge: false,
96        }
97    }
98
99    pub fn get_port_with_hint(
100        self: &Arc<Self>,
101        name: String,
102        network_hint: PortNetworkHint,
103    ) -> RustCratePortConfig {
104        RustCratePortConfig {
105            service: Arc::downgrade(self),
106            service_host: self.on.clone(),
107            service_server_defns: self.server_defns.clone(),
108            network_hint,
109            port: name,
110            merge: false,
111        }
112    }
113
114    pub fn stdout(&self) -> mpsc::UnboundedReceiver<String> {
115        self.launched_binary.get().unwrap().stdout()
116    }
117
118    pub fn stderr(&self) -> mpsc::UnboundedReceiver<String> {
119        self.launched_binary.get().unwrap().stderr()
120    }
121
122    pub fn stdout_filter(&self, prefix: String) -> mpsc::UnboundedReceiver<String> {
123        self.launched_binary.get().unwrap().stdout_filter(prefix)
124    }
125
126    pub fn stderr_filter(&self, prefix: String) -> mpsc::UnboundedReceiver<String> {
127        self.launched_binary.get().unwrap().stderr_filter(prefix)
128    }
129
130    #[cfg(feature = "profile-folding")]
131    pub fn tracing_results(&self) -> Option<&TracingResults> {
132        self.launched_binary.get().unwrap().tracing_results()
133    }
134
135    pub fn exit_code(&self) -> Option<i32> {
136        self.launched_binary.get().unwrap().exit_code()
137    }
138
139    fn build(
140        &self,
141    ) -> impl use<> + 'static + Future<Output = Result<&'static BuildOutput, BuildError>> {
142        // Memoized, so no caching in `self` is needed.
143        build_crate_memoized(self.build_params.clone())
144    }
145}
146
147#[async_trait]
148impl Service for RustCrateService {
149    fn collect_resources(&self, _resource_batch: &mut ResourceBatch) {
150        if self.launched_host.get().is_some() {
151            return;
152        }
153
154        tokio::task::spawn(self.build());
155
156        let host = &self.on;
157
158        host.request_custom_binary();
159        for (_, bind_type) in self.port_to_bind.iter() {
160            host.request_port(bind_type);
161        }
162
163        for port in self.external_ports.iter() {
164            host.request_port_base(&BaseServerStrategy::ExternalTcpPort(*port));
165        }
166    }
167
168    async fn deploy(&self, resource_result: &Arc<ResourceResult>) -> Result<()> {
169        self.launched_host
170            .get_or_try_init::<anyhow::Error, _, _>(|| {
171                ProgressTracker::with_group(
172                    self.display_id
173                        .clone()
174                        .unwrap_or_else(|| format!("service/{}", self.id)),
175                    None,
176                    || async {
177                        let built = self.build().await?;
178
179                        let host = &self.on;
180                        let launched = host.provision(resource_result);
181
182                        launched.copy_binary(built).await?;
183                        Ok(launched)
184                    },
185                )
186            })
187            .await?;
188        Ok(())
189    }
190
191    async fn ready(&self) -> Result<()> {
192        self.launched_binary
193            .get_or_try_init(|| {
194                ProgressTracker::with_group(
195                    self.display_id
196                        .clone()
197                        .unwrap_or_else(|| format!("service/{}", self.id)),
198                    None,
199                    || async {
200                        let launched_host = self.launched_host.get().unwrap();
201
202                        let built = self.build().await?;
203                        let args = self.args.as_ref().cloned().unwrap_or_default();
204
205                        let binary = launched_host
206                            .launch_binary(
207                                self.display_id
208                                    .clone()
209                                    .unwrap_or_else(|| format!("service/{}", self.id)),
210                                built,
211                                &args,
212                                self.tracing.clone(),
213                            )
214                            .await?;
215
216                        let bind_config = self
217                            .port_to_bind
218                            .iter()
219                            .map(|(port_name, bind_type)| {
220                                (port_name.clone(), launched_host.server_config(bind_type))
221                            })
222                            .collect::<HashMap<_, _>>();
223
224                        let formatted_bind_config = serde_json::to_string::<InitConfig>(&(
225                            bind_config,
226                            self.meta.get().map(|s| s.as_str().into()),
227                        ))
228                        .unwrap();
229
230                        // request stdout before sending config so we don't miss the "ready" response
231                        let stdout_receiver = binary.deploy_stdout();
232
233                        binary.stdin().send(format!("{formatted_bind_config}\n"))?;
234
235                        let ready_line = ProgressTracker::leaf(
236                            "waiting for ready",
237                            tokio::time::timeout(Duration::from_secs(60), stdout_receiver),
238                        )
239                        .await
240                        .context("Timed out waiting for ready")?
241                        .context("Program unexpectedly quit")?;
242                        if let Some(line_rest) = ready_line.strip_prefix("ready: ") {
243                            *self.server_defns.try_write().unwrap() =
244                                serde_json::from_str(line_rest).unwrap();
245                        } else {
246                            bail!("expected ready");
247                        }
248                        Ok(binary)
249                    },
250                )
251            })
252            .await?;
253        Ok(())
254    }
255
256    async fn start(&self) -> Result<()> {
257        self.started
258            .get_or_try_init(|| async {
259                let sink_ports_futures =
260                    self.port_to_server
261                        .iter()
262                        .map(|(port_name, outgoing)| async {
263                            (&**port_name, outgoing.load_instantiated(&|p| p).await)
264                        });
265                let sink_ports = futures::future::join_all(sink_ports_futures)
266                    .await
267                    .into_iter()
268                    .collect::<HashMap<_, _>>();
269
270                let formatted_defns = serde_json::to_string(&sink_ports).unwrap();
271
272                let stdout_receiver = self.launched_binary.get().unwrap().deploy_stdout();
273
274                self.launched_binary
275                    .get()
276                    .unwrap()
277                    .stdin()
278                    .send(format!("start: {formatted_defns}\n"))
279                    .unwrap();
280
281                let start_ack_line = ProgressTracker::leaf(
282                    self.display_id
283                        .clone()
284                        .unwrap_or_else(|| format!("service/{}", self.id))
285                        + " / waiting for ack start",
286                    tokio::time::timeout(Duration::from_secs(60), stdout_receiver),
287                )
288                .await??;
289                if !start_ack_line.starts_with("ack start") {
290                    bail!("expected ack start");
291                }
292
293                Ok(())
294            })
295            .await?;
296
297        Ok(())
298    }
299
300    async fn stop(&self) -> Result<()> {
301        ProgressTracker::with_group(
302            self.display_id
303                .clone()
304                .unwrap_or_else(|| format!("service/{}", self.id)),
305            None,
306            || async {
307                let launched_binary = self.launched_binary.get().unwrap();
308                launched_binary.stdin().send("stop\n".to_string())?;
309
310                let timeout_result = ProgressTracker::leaf(
311                    "waiting for exit",
312                    tokio::time::timeout(Duration::from_secs(60), launched_binary.wait()),
313                )
314                .await;
315                match timeout_result {
316                    Err(_timeout) => {} // `wait()` timed out, but stop will force quit.
317                    Ok(Err(unexpected_error)) => return Err(unexpected_error), // `wait()` errored.
318                    Ok(Ok(_exit_status)) => {}
319                }
320                launched_binary.stop().await?;
321
322                Ok(())
323            },
324        )
325        .await
326    }
327}