1use std::collections::HashMap;
2use std::sync::{Arc, OnceLock};
3use std::time::Duration;
4
5use anyhow::{Context, Result, bail};
6use async_trait::async_trait;
7use futures::Future;
8use hydro_deploy_integration::{InitConfig, ServerPort};
9use memo_map::MemoMap;
10use serde::Serialize;
11use tokio::sync::{OnceCell, RwLock, mpsc};
12
13use super::build::{BuildError, BuildOutput, BuildParams, build_crate_memoized};
14use super::ports::{self, RustCratePortConfig};
15use super::tracing_options::TracingOptions;
16#[cfg(feature = "profile-folding")]
17use crate::TracingResults;
18use crate::progress::ProgressTracker;
19use crate::{
20 BaseServerStrategy, Host, LaunchedBinary, LaunchedHost, PortNetworkHint, ResourceBatch,
21 ResourceResult, ServerStrategy, Service,
22};
23
/// A service that builds a Rust crate and runs the resulting binary on a [`Host`].
///
/// Lifecycle state lives in write-once cells: `launched_host` is filled by
/// `Service::deploy`, `launched_binary` by `Service::ready`, and `started` by
/// `Service::start`.
pub struct RustCrateService {
    /// Numeric id; used to build the fallback display name `service/{id}`.
    id: usize,
    /// Host this service requests resources from and is provisioned on.
    pub(super) on: Arc<dyn Host>,
    /// Parameters passed to `build_crate_memoized` to build the crate.
    build_params: BuildParams,
    /// Optional tracing configuration forwarded when launching the binary.
    tracing: Option<TracingOptions>,
    /// Command-line arguments for the binary; `None` launches it with no arguments.
    args: Option<Vec<String>>,
    /// Name used for progress output and for the launched binary; falls back to `service/{id}`.
    display_id: Option<String>,
    /// TCP ports requested from the host as `BaseServerStrategy::ExternalTcpPort`.
    external_ports: Vec<u16>,
    /// Environment variables forwarded when launching the binary.
    env: HashMap<String, String>,

    /// JSON-serialized metadata, set at most once via `update_meta` and sent to
    /// the binary together with its bind configuration.
    meta: OnceLock<String>,

    /// Port name -> server config; each is instantiated and sent to the binary
    /// in the `start:` message.
    pub(super) port_to_server: MemoMap<String, ports::ServerConfig>,
    /// Port name -> bind strategy; each is requested from the host and turned
    /// into a server config in the init message.
    pub(super) port_to_bind: MemoMap<String, ServerStrategy>,

    /// Provisioned host, set once by `deploy`.
    launched_host: OnceCell<Arc<dyn LaunchedHost>>,

    /// Server port definitions parsed from the binary's `ready: ` line; the
    /// `Arc` is shared with every port handle returned by `get_port*`.
    pub(super) server_defns: Arc<RwLock<HashMap<String, ServerPort>>>,

    /// Running binary, set once by `ready`.
    launched_binary: OnceCell<Box<dyn LaunchedBinary>>,
    /// Set once the `start` handshake has been acknowledged.
    started: OnceCell<()>,
}
51
52impl RustCrateService {
53 #[expect(clippy::too_many_arguments, reason = "internal use")]
54 pub fn new(
55 id: usize,
56 on: Arc<dyn Host>,
57 build_params: BuildParams,
58 tracing: Option<TracingOptions>,
59 args: Option<Vec<String>>,
60 display_id: Option<String>,
61 external_ports: Vec<u16>,
62 env: HashMap<String, String>,
63 ) -> Self {
64 Self {
65 id,
66 on,
67 build_params,
68 tracing,
69 args,
70 display_id,
71 external_ports,
72 env,
73 meta: OnceLock::new(),
74 port_to_server: MemoMap::new(),
75 port_to_bind: MemoMap::new(),
76 launched_host: OnceCell::new(),
77 server_defns: Arc::new(RwLock::new(HashMap::new())),
78 launched_binary: OnceCell::new(),
79 started: OnceCell::new(),
80 }
81 }
82
83 pub fn update_meta<T: Serialize>(&self, meta: T) {
84 if self.launched_binary.get().is_some() {
85 panic!("Cannot update meta after binary has been launched")
86 }
87 self.meta
88 .set(serde_json::to_string(&meta).unwrap())
89 .expect("Cannot set meta twice.");
90 }
91
92 pub fn get_port(self: &Arc<Self>, name: String) -> RustCratePortConfig {
93 RustCratePortConfig {
94 service: Arc::downgrade(self),
95 service_host: self.on.clone(),
96 service_server_defns: self.server_defns.clone(),
97 network_hint: PortNetworkHint::Auto,
98 port: name,
99 merge: false,
100 }
101 }
102
103 pub fn get_port_with_hint(
104 self: &Arc<Self>,
105 name: String,
106 network_hint: PortNetworkHint,
107 ) -> RustCratePortConfig {
108 RustCratePortConfig {
109 service: Arc::downgrade(self),
110 service_host: self.on.clone(),
111 service_server_defns: self.server_defns.clone(),
112 network_hint,
113 port: name,
114 merge: false,
115 }
116 }
117
118 pub fn stdout(&self) -> mpsc::UnboundedReceiver<String> {
119 self.launched_binary.get().unwrap().stdout()
120 }
121
122 pub fn stderr(&self) -> mpsc::UnboundedReceiver<String> {
123 self.launched_binary.get().unwrap().stderr()
124 }
125
126 pub fn stdout_filter(&self, prefix: String) -> mpsc::UnboundedReceiver<String> {
127 self.launched_binary.get().unwrap().stdout_filter(prefix)
128 }
129
130 pub fn stderr_filter(&self, prefix: String) -> mpsc::UnboundedReceiver<String> {
131 self.launched_binary.get().unwrap().stderr_filter(prefix)
132 }
133
134 #[cfg(feature = "profile-folding")]
135 pub fn tracing_results(&self) -> Option<&TracingResults> {
136 self.launched_binary.get().unwrap().tracing_results()
137 }
138
139 pub fn exit_code(&self) -> Option<i32> {
140 self.launched_binary.get().unwrap().exit_code()
141 }
142
143 fn build(
144 &self,
145 ) -> impl use<> + 'static + Future<Output = Result<&'static BuildOutput, BuildError>> {
146 build_crate_memoized(self.build_params.clone())
148 }
149}
150
151#[async_trait]
152impl Service for RustCrateService {
153 fn collect_resources(&self, _resource_batch: &mut ResourceBatch) {
154 if self.launched_host.get().is_some() {
155 return;
156 }
157
158 tokio::task::spawn(self.build());
159
160 let host = &self.on;
161
162 host.request_custom_binary();
163 for (_, bind_type) in self.port_to_bind.iter() {
164 host.request_port(bind_type);
165 }
166
167 for port in self.external_ports.iter() {
168 host.request_port_base(&BaseServerStrategy::ExternalTcpPort(*port));
169 }
170 }
171
172 async fn deploy(&self, resource_result: &Arc<ResourceResult>) -> Result<()> {
173 self.launched_host
174 .get_or_try_init::<anyhow::Error, _, _>(|| {
175 ProgressTracker::with_group(
176 self.display_id
177 .clone()
178 .unwrap_or_else(|| format!("service/{}", self.id)),
179 None,
180 || async {
181 let built = self.build().await?;
182
183 let host = &self.on;
184 let launched = host.provision(resource_result);
185
186 launched.copy_binary(built).await?;
187 Ok(launched)
188 },
189 )
190 })
191 .await?;
192 Ok(())
193 }
194
195 async fn ready(&self) -> Result<()> {
196 self.launched_binary
197 .get_or_try_init(|| {
198 ProgressTracker::with_group(
199 self.display_id
200 .clone()
201 .unwrap_or_else(|| format!("service/{}", self.id)),
202 None,
203 || async {
204 let launched_host = self.launched_host.get().unwrap();
205
206 let built = self.build().await?;
207 let args = self.args.as_ref().cloned().unwrap_or_default();
208
209 let binary = launched_host
210 .launch_binary(
211 self.display_id
212 .clone()
213 .unwrap_or_else(|| format!("service/{}", self.id)),
214 built,
215 &args,
216 self.tracing.clone(),
217 &self.env,
218 )
219 .await?;
220
221 let bind_config = self
222 .port_to_bind
223 .iter()
224 .map(|(port_name, bind_type)| {
225 (port_name.clone(), launched_host.server_config(bind_type))
226 })
227 .collect::<HashMap<_, _>>();
228
229 let formatted_bind_config = serde_json::to_string::<InitConfig>(&(
230 bind_config,
231 self.meta.get().map(|s| s.as_str().into()),
232 ))
233 .unwrap();
234
235 let stdout_receiver = binary.deploy_stdout();
237
238 binary.stdin().send(format!("{formatted_bind_config}\n"))?;
239
240 let ready_line = ProgressTracker::leaf(
241 "waiting for ready",
242 tokio::time::timeout(Duration::from_secs(60), stdout_receiver),
243 )
244 .await
245 .context("Timed out waiting for ready")?
246 .context("Program unexpectedly quit")?;
247 if let Some(line_rest) = ready_line.strip_prefix("ready: ") {
248 *self.server_defns.try_write().unwrap() =
249 serde_json::from_str(line_rest).unwrap();
250 } else {
251 bail!("expected ready");
252 }
253 Ok(binary)
254 },
255 )
256 })
257 .await?;
258 Ok(())
259 }
260
261 async fn start(&self) -> Result<()> {
262 self.started
263 .get_or_try_init(|| async {
264 let sink_ports_futures =
265 self.port_to_server
266 .iter()
267 .map(|(port_name, outgoing)| async {
268 (&**port_name, outgoing.load_instantiated(&|p| p).await)
269 });
270 let sink_ports = futures::future::join_all(sink_ports_futures)
271 .await
272 .into_iter()
273 .collect::<HashMap<_, _>>();
274
275 let formatted_defns = serde_json::to_string(&sink_ports).unwrap();
276
277 let stdout_receiver = self.launched_binary.get().unwrap().deploy_stdout();
278
279 self.launched_binary
280 .get()
281 .unwrap()
282 .stdin()
283 .send(format!("start: {formatted_defns}\n"))
284 .unwrap();
285
286 let start_ack_line = ProgressTracker::leaf(
287 self.display_id
288 .clone()
289 .unwrap_or_else(|| format!("service/{}", self.id))
290 + " / waiting for ack start",
291 tokio::time::timeout(Duration::from_secs(60), stdout_receiver),
292 )
293 .await??;
294 if !start_ack_line.starts_with("ack start") {
295 bail!("expected ack start");
296 }
297
298 Ok(())
299 })
300 .await?;
301
302 Ok(())
303 }
304
305 async fn stop(&self) -> Result<()> {
306 ProgressTracker::with_group(
307 self.display_id
308 .clone()
309 .unwrap_or_else(|| format!("service/{}", self.id)),
310 None,
311 || async {
312 let launched_binary = self.launched_binary.get().unwrap();
313 launched_binary.stdin().send("stop\n".to_owned())?;
314
315 let timeout_result = ProgressTracker::leaf(
316 "waiting for exit",
317 tokio::time::timeout(Duration::from_secs(60), launched_binary.wait()),
318 )
319 .await;
320 match timeout_result {
321 Err(_timeout) => {} Ok(Err(unexpected_error)) => return Err(unexpected_error), Ok(Ok(_exit_status)) => {}
324 }
325 launched_binary.stop().await?;
326
327 Ok(())
328 },
329 )
330 .await
331 }
332}