1use std::collections::HashMap;
2use std::sync::{Arc, OnceLock};
3use std::time::Duration;
4
5use anyhow::{Context, Result, bail};
6use async_trait::async_trait;
7use futures::Future;
8use hydro_deploy_integration::{InitConfig, ServerPort};
9use memo_map::MemoMap;
10use serde::Serialize;
11use tokio::sync::{OnceCell, RwLock, mpsc};
12
13use super::build::{BuildError, BuildOutput, BuildParams, build_crate_memoized};
14use super::ports::{self, RustCratePortConfig};
15use super::tracing_options::TracingOptions;
16#[cfg(feature = "profile-folding")]
17use crate::TracingResults;
18use crate::progress::ProgressTracker;
19use crate::{
20 BaseServerStrategy, Host, LaunchedBinary, LaunchedHost, PortNetworkHint, ResourceBatch,
21 ResourceResult, ServerStrategy, Service,
22};
23
/// A service backed by a Rust crate: it is built, copied to a host, launched
/// as a binary, and driven through the `Service` lifecycle
/// (`collect_resources` -> `deploy` -> `ready` -> `start` -> `stop`).
pub struct RustCrateService {
    /// Numeric id; used in the fallback label `service/{id}` when no
    /// `display_id` is given.
    id: usize,
    /// Host this service runs on. It is asked for a custom binary and ports in
    /// `collect_resources`, then provisioned in `deploy`.
    pub(super) on: Arc<dyn Host>,
    /// Parameters handed to `build_crate_memoized` to build the binary.
    build_params: BuildParams,
    /// Optional tracing options, forwarded to `launch_binary`.
    tracing: Option<TracingOptions>,
    /// Command-line arguments for the binary; an empty list is used when `None`.
    args: Option<Vec<String>>,
    /// Optional human-readable label for progress output and the launched
    /// binary; falls back to `service/{id}`.
    display_id: Option<String>,
    /// TCP ports requested on the host as `BaseServerStrategy::ExternalTcpPort`.
    external_ports: Vec<u16>,

    /// JSON-serialized metadata. Set at most once via `update_meta` and sent to
    /// the binary together with its bind config in `ready`.
    meta: OnceLock<String>,

    /// Port name -> server config. Each entry is resolved in `start` and sent
    /// to the binary in its `start:` message. Entries are inserted outside this
    /// file (presumably by the `ports` module -- confirm).
    pub(super) port_to_server: MemoMap<String, ports::ServerConfig>,
    /// Port name -> bind strategy. Requested on the host in
    /// `collect_resources` and turned into the bind config sent in `ready`.
    pub(super) port_to_bind: MemoMap<String, ServerStrategy>,

    /// Provisioned host; initialized once by `deploy`.
    launched_host: OnceCell<Arc<dyn LaunchedHost>>,

    /// Server port definitions parsed from the binary's `ready: ` line. Shared
    /// with every `RustCratePortConfig` produced by `get_port*`.
    pub(super) server_defns: Arc<RwLock<HashMap<String, ServerPort>>>,

    /// The running binary; initialized once by `ready`.
    launched_binary: OnceCell<Box<dyn LaunchedBinary>>,
    /// Initialized once `start` has received the binary's `ack start` line.
    started: OnceCell<()>,
}
50
51impl RustCrateService {
52 pub fn new(
53 id: usize,
54 on: Arc<dyn Host>,
55 build_params: BuildParams,
56 tracing: Option<TracingOptions>,
57 args: Option<Vec<String>>,
58 display_id: Option<String>,
59 external_ports: Vec<u16>,
60 ) -> Self {
61 Self {
62 id,
63 on,
64 build_params,
65 tracing,
66 args,
67 display_id,
68 external_ports,
69 meta: OnceLock::new(),
70 port_to_server: MemoMap::new(),
71 port_to_bind: MemoMap::new(),
72 launched_host: OnceCell::new(),
73 server_defns: Arc::new(RwLock::new(HashMap::new())),
74 launched_binary: OnceCell::new(),
75 started: OnceCell::new(),
76 }
77 }
78
79 pub fn update_meta<T: Serialize>(&self, meta: T) {
80 if self.launched_binary.get().is_some() {
81 panic!("Cannot update meta after binary has been launched")
82 }
83 self.meta
84 .set(serde_json::to_string(&meta).unwrap())
85 .expect("Cannot set meta twice.");
86 }
87
88 pub fn get_port(self: &Arc<Self>, name: String) -> RustCratePortConfig {
89 RustCratePortConfig {
90 service: Arc::downgrade(self),
91 service_host: self.on.clone(),
92 service_server_defns: self.server_defns.clone(),
93 network_hint: PortNetworkHint::Auto,
94 port: name,
95 merge: false,
96 }
97 }
98
99 pub fn get_port_with_hint(
100 self: &Arc<Self>,
101 name: String,
102 network_hint: PortNetworkHint,
103 ) -> RustCratePortConfig {
104 RustCratePortConfig {
105 service: Arc::downgrade(self),
106 service_host: self.on.clone(),
107 service_server_defns: self.server_defns.clone(),
108 network_hint,
109 port: name,
110 merge: false,
111 }
112 }
113
114 pub fn stdout(&self) -> mpsc::UnboundedReceiver<String> {
115 self.launched_binary.get().unwrap().stdout()
116 }
117
118 pub fn stderr(&self) -> mpsc::UnboundedReceiver<String> {
119 self.launched_binary.get().unwrap().stderr()
120 }
121
122 pub fn stdout_filter(&self, prefix: String) -> mpsc::UnboundedReceiver<String> {
123 self.launched_binary.get().unwrap().stdout_filter(prefix)
124 }
125
126 pub fn stderr_filter(&self, prefix: String) -> mpsc::UnboundedReceiver<String> {
127 self.launched_binary.get().unwrap().stderr_filter(prefix)
128 }
129
130 #[cfg(feature = "profile-folding")]
131 pub fn tracing_results(&self) -> Option<&TracingResults> {
132 self.launched_binary.get().unwrap().tracing_results()
133 }
134
135 pub fn exit_code(&self) -> Option<i32> {
136 self.launched_binary.get().unwrap().exit_code()
137 }
138
139 fn build(
140 &self,
141 ) -> impl use<> + 'static + Future<Output = Result<&'static BuildOutput, BuildError>> {
142 build_crate_memoized(self.build_params.clone())
144 }
145}
146
147#[async_trait]
148impl Service for RustCrateService {
149 fn collect_resources(&self, _resource_batch: &mut ResourceBatch) {
150 if self.launched_host.get().is_some() {
151 return;
152 }
153
154 tokio::task::spawn(self.build());
155
156 let host = &self.on;
157
158 host.request_custom_binary();
159 for (_, bind_type) in self.port_to_bind.iter() {
160 host.request_port(bind_type);
161 }
162
163 for port in self.external_ports.iter() {
164 host.request_port_base(&BaseServerStrategy::ExternalTcpPort(*port));
165 }
166 }
167
168 async fn deploy(&self, resource_result: &Arc<ResourceResult>) -> Result<()> {
169 self.launched_host
170 .get_or_try_init::<anyhow::Error, _, _>(|| {
171 ProgressTracker::with_group(
172 self.display_id
173 .clone()
174 .unwrap_or_else(|| format!("service/{}", self.id)),
175 None,
176 || async {
177 let built = self.build().await?;
178
179 let host = &self.on;
180 let launched = host.provision(resource_result);
181
182 launched.copy_binary(built).await?;
183 Ok(launched)
184 },
185 )
186 })
187 .await?;
188 Ok(())
189 }
190
191 async fn ready(&self) -> Result<()> {
192 self.launched_binary
193 .get_or_try_init(|| {
194 ProgressTracker::with_group(
195 self.display_id
196 .clone()
197 .unwrap_or_else(|| format!("service/{}", self.id)),
198 None,
199 || async {
200 let launched_host = self.launched_host.get().unwrap();
201
202 let built = self.build().await?;
203 let args = self.args.as_ref().cloned().unwrap_or_default();
204
205 let binary = launched_host
206 .launch_binary(
207 self.display_id
208 .clone()
209 .unwrap_or_else(|| format!("service/{}", self.id)),
210 built,
211 &args,
212 self.tracing.clone(),
213 )
214 .await?;
215
216 let bind_config = self
217 .port_to_bind
218 .iter()
219 .map(|(port_name, bind_type)| {
220 (port_name.clone(), launched_host.server_config(bind_type))
221 })
222 .collect::<HashMap<_, _>>();
223
224 let formatted_bind_config = serde_json::to_string::<InitConfig>(&(
225 bind_config,
226 self.meta.get().map(|s| s.as_str().into()),
227 ))
228 .unwrap();
229
230 let stdout_receiver = binary.deploy_stdout();
232
233 binary.stdin().send(format!("{formatted_bind_config}\n"))?;
234
235 let ready_line = ProgressTracker::leaf(
236 "waiting for ready",
237 tokio::time::timeout(Duration::from_secs(60), stdout_receiver),
238 )
239 .await
240 .context("Timed out waiting for ready")?
241 .context("Program unexpectedly quit")?;
242 if let Some(line_rest) = ready_line.strip_prefix("ready: ") {
243 *self.server_defns.try_write().unwrap() =
244 serde_json::from_str(line_rest).unwrap();
245 } else {
246 bail!("expected ready");
247 }
248 Ok(binary)
249 },
250 )
251 })
252 .await?;
253 Ok(())
254 }
255
256 async fn start(&self) -> Result<()> {
257 self.started
258 .get_or_try_init(|| async {
259 let sink_ports_futures =
260 self.port_to_server
261 .iter()
262 .map(|(port_name, outgoing)| async {
263 (&**port_name, outgoing.load_instantiated(&|p| p).await)
264 });
265 let sink_ports = futures::future::join_all(sink_ports_futures)
266 .await
267 .into_iter()
268 .collect::<HashMap<_, _>>();
269
270 let formatted_defns = serde_json::to_string(&sink_ports).unwrap();
271
272 let stdout_receiver = self.launched_binary.get().unwrap().deploy_stdout();
273
274 self.launched_binary
275 .get()
276 .unwrap()
277 .stdin()
278 .send(format!("start: {formatted_defns}\n"))
279 .unwrap();
280
281 let start_ack_line = ProgressTracker::leaf(
282 self.display_id
283 .clone()
284 .unwrap_or_else(|| format!("service/{}", self.id))
285 + " / waiting for ack start",
286 tokio::time::timeout(Duration::from_secs(60), stdout_receiver),
287 )
288 .await??;
289 if !start_ack_line.starts_with("ack start") {
290 bail!("expected ack start");
291 }
292
293 Ok(())
294 })
295 .await?;
296
297 Ok(())
298 }
299
300 async fn stop(&self) -> Result<()> {
301 ProgressTracker::with_group(
302 self.display_id
303 .clone()
304 .unwrap_or_else(|| format!("service/{}", self.id)),
305 None,
306 || async {
307 let launched_binary = self.launched_binary.get().unwrap();
308 launched_binary.stdin().send("stop\n".to_string())?;
309
310 let timeout_result = ProgressTracker::leaf(
311 "waiting for exit",
312 tokio::time::timeout(Duration::from_secs(60), launched_binary.wait()),
313 )
314 .await;
315 match timeout_result {
316 Err(_timeout) => {} Ok(Err(unexpected_error)) => return Err(unexpected_error), Ok(Ok(_exit_status)) => {}
319 }
320 launched_binary.stop().await?;
321
322 Ok(())
323 },
324 )
325 .await
326 }
327}