1use std::collections::HashMap;
2use std::net::SocketAddr;
3use std::path::PathBuf;
4use std::sync::Arc;
5#[cfg(feature = "profile-folding")]
6use std::sync::OnceLock;
7use std::time::Duration;
8
9use anyhow::{Context as _, Result};
10use async_ssh2_russh::russh::client::{Config, Handler};
11use async_ssh2_russh::russh::{Disconnect, compression};
12use async_ssh2_russh::russh_sftp::protocol::{Status, StatusCode};
13use async_ssh2_russh::sftp::SftpError;
14use async_ssh2_russh::{AsyncChannel, AsyncSession, NoCheckHandler};
15use async_trait::async_trait;
16use hydro_deploy_integration::ServerBindConfig;
17#[cfg(feature = "profile-folding")]
18use inferno::collapse::Collapse;
19#[cfg(feature = "profile-folding")]
20use inferno::collapse::perf::Folder;
21use nanoid::nanoid;
22use tokio::fs::File;
23#[cfg(feature = "profile-folding")]
24use tokio::io::BufReader;
25use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt};
26use tokio::net::TcpListener;
27use tokio::sync::{mpsc, oneshot};
28use tokio_stream::StreamExt;
29use tokio_stream::wrappers::LinesStream;
30#[cfg(feature = "profile-folding")]
31use tokio_util::io::SyncIoBridge;
32
33#[cfg(feature = "profile-folding")]
34use crate::TracingResults;
35use crate::progress::ProgressTracker;
36use crate::rust_crate::build::BuildOutput;
37#[cfg(feature = "profile-folding")]
38use crate::rust_crate::flamegraph::handle_fold_data;
39use crate::rust_crate::tracing_options::TracingOptions;
40use crate::util::{PriorityBroadcast, async_retry, prioritized_broadcast};
41use crate::{BaseServerStrategy, LaunchedBinary, LaunchedHost, ResourceResult};
42
/// File name of the raw perf profile on the remote host.
///
/// It is passed to `perf record -o` when launching a traced binary, read back with
/// `perf script -i`, and opened over SFTP when downloading the raw data. The name is
/// relative, so it resolves against whatever working directory the remote commands run in.
const PERF_OUTFILE: &str = "__profile.perf.data";
44
/// Handle to a binary that was started on a remote host through an SSH channel.
struct LaunchedSshBinary {
    /// Never read by this type (hence the underscore).
    /// NOTE(review): presumably held so the provisioned resources outlive the running
    /// binary — confirm against `ResourceResult`.
    _resource_result: Arc<ResourceResult>,
    /// SSH session the binary was launched on. Wrapped in `Option` so that `Drop` can
    /// `take()` it and disconnect.
    session: Option<AsyncSession<NoCheckHandler>>,
    /// Channel on which the binary's command was executed; used to query the exit status
    /// and to stop the process.
    channel: AsyncChannel,
    /// Sender whose messages are written to the remote process's stdin by a background task.
    stdin_sender: mpsc::UnboundedSender<String>,
    /// Broadcast of the remote process's stdout lines.
    stdout_broadcast: PriorityBroadcast,
    /// Broadcast of the remote process's stderr lines.
    stderr_broadcast: PriorityBroadcast,
    /// Tracing (perf) configuration, if the binary was launched under `perf record`.
    tracing: Option<TracingOptions>,
    /// Folded profile data, set at most once by `stop()`.
    #[cfg(feature = "profile-folding")]
    tracing_results: OnceLock<TracingResults>,
}
59
60#[async_trait]
61impl LaunchedBinary for LaunchedSshBinary {
62 fn stdin(&self) -> mpsc::UnboundedSender<String> {
63 self.stdin_sender.clone()
64 }
65
66 fn deploy_stdout(&self) -> oneshot::Receiver<String> {
67 self.stdout_broadcast.receive_priority()
68 }
69
70 fn stdout(&self) -> mpsc::UnboundedReceiver<String> {
71 self.stdout_broadcast.receive(None)
72 }
73
74 fn stderr(&self) -> mpsc::UnboundedReceiver<String> {
75 self.stderr_broadcast.receive(None)
76 }
77
78 fn stdout_filter(&self, prefix: String) -> mpsc::UnboundedReceiver<String> {
79 self.stdout_broadcast.receive(Some(prefix))
80 }
81
82 fn stderr_filter(&self, prefix: String) -> mpsc::UnboundedReceiver<String> {
83 self.stderr_broadcast.receive(Some(prefix))
84 }
85
86 #[cfg(feature = "profile-folding")]
87 fn tracing_results(&self) -> Option<&TracingResults> {
88 self.tracing_results.get()
89 }
90
91 fn exit_code(&self) -> Option<i32> {
92 self.channel
94 .recv_exit_status()
95 .try_get()
96 .map(|&ec| ec as _)
97 .ok()
98 }
99
100 async fn wait(&self) -> Result<i32> {
101 let _ = self.channel.closed().wait().await;
102 Ok(*self.channel.recv_exit_status().try_get()? as _)
103 }
104
105 async fn stop(&self) -> Result<()> {
106 if !self.channel.closed().is_done() {
107 ProgressTracker::leaf("force stopping", async {
108 self.channel.eof().await?; self.channel.close().await?; self.channel.closed().wait().await;
112 Result::<_>::Ok(())
113 })
114 .await?;
115 }
116
117 if let Some(tracing) = self.tracing.as_ref() {
119 #[cfg(feature = "profile-folding")]
120 assert!(
121 self.tracing_results.get().is_none(),
122 "`tracing_results` already set! Was `stop()` called twice? This is a bug."
123 );
124
125 let session = self.session.as_ref().unwrap();
126 if let Some(local_raw_perf) = tracing.perf_raw_outfile.as_ref() {
127 ProgressTracker::progress_leaf("downloading perf data", |progress, _| async move {
128 let sftp =
129 async_retry(&|| session.open_sftp(), 10, Duration::from_secs(1)).await?;
130
131 let mut remote_raw_perf = sftp.open(PERF_OUTFILE).await?;
132 let mut local_raw_perf = File::create(local_raw_perf).await?;
133
134 let total_size = remote_raw_perf.metadata().await?.size.unwrap();
135
136 use tokio::io::AsyncWriteExt;
137 let mut index = 0;
138 loop {
139 let mut buffer = [0; 16 * 1024];
140 let n = remote_raw_perf.read(&mut buffer).await?;
141 if n == 0 {
142 break;
143 }
144 local_raw_perf.write_all(&buffer[..n]).await?;
145 index += n;
146 progress(((index as f64 / total_size as f64) * 100.0) as u64);
147 }
148
149 Ok::<(), anyhow::Error>(())
150 })
151 .await?;
152 }
153
154 #[cfg(feature = "profile-folding")]
155 let script_channel = session.open_channel().await?;
156 #[cfg(feature = "profile-folding")]
157 let mut fold_er = Folder::from(tracing.fold_perf_options.clone().unwrap_or_default());
158
159 #[cfg(feature = "profile-folding")]
160 let fold_data = ProgressTracker::leaf("perf script & folding", async move {
161 let mut stderr_lines = script_channel.stderr().lines();
162 let stdout = script_channel.stdout();
163
164 let ((), fold_data, ()) = tokio::try_join!(
166 async move {
167 while let Ok(Some(s)) = stderr_lines.next_line().await {
169 ProgressTracker::eprintln(format!("[perf stderr] {s}"));
170 }
171 Result::<_>::Ok(())
172 },
173 async move {
174 tokio::task::spawn_blocking(move || {
176 let mut fold_data = Vec::new();
177 fold_er.collapse(
178 SyncIoBridge::new(BufReader::new(stdout)),
179 &mut fold_data,
180 )?;
181 Ok(fold_data)
182 })
183 .await?
184 },
185 async move {
186 script_channel
188 .exec(false, format!("perf script --symfs=/ -i {PERF_OUTFILE}"))
189 .await?;
190 Ok(())
191 },
192 )?;
193 Result::<_>::Ok(fold_data)
194 })
195 .await?;
196
197 #[cfg(feature = "profile-folding")]
198 self.tracing_results
199 .set(TracingResults {
200 folded_data: fold_data.clone(),
201 })
202 .expect("`tracing_results` already set! This is a bug.");
203
204 #[cfg(feature = "profile-folding")]
205 handle_fold_data(tracing, fold_data).await?;
206 };
207
208 Ok(())
209 }
210}
211
212impl Drop for LaunchedSshBinary {
213 fn drop(&mut self) {
214 if let Some(session) = self.session.take() {
215 tokio::task::block_in_place(|| {
216 tokio::runtime::Handle::current().block_on(session.disconnect(
217 Disconnect::ByApplication,
218 "",
219 "",
220 ))
221 })
222 .unwrap();
223 }
224 }
225}
226
227#[async_trait]
228pub trait LaunchedSshHost: Send + Sync {
229 fn get_internal_ip(&self) -> &str;
230 fn get_external_ip(&self) -> Option<&str>;
231 fn get_cloud_provider(&self) -> &'static str;
232 fn resource_result(&self) -> &Arc<ResourceResult>;
233 fn ssh_user(&self) -> &str;
234
235 fn ssh_key_path(&self) -> PathBuf {
236 self.resource_result()
237 .terraform
238 .deployment_folder
239 .as_ref()
240 .unwrap()
241 .path()
242 .join(".ssh")
243 .join("vm_instance_ssh_key_pem")
244 }
245
246 async fn open_ssh_session(&self) -> Result<AsyncSession<NoCheckHandler>> {
247 let target_addr = SocketAddr::new(
248 self.get_external_ip()
249 .context(format!(
250 "{} host must be configured with an external IP to launch binaries",
251 self.get_cloud_provider()
252 ))?
253 .parse()
254 .unwrap(),
255 22,
256 );
257
258 let res = ProgressTracker::leaf(
259 format!("connecting to host @ {}", self.get_external_ip().unwrap()),
260 async_retry(
261 &|| async {
262 let mut config = Config::default();
263 config.preferred.compression = (&[
264 compression::ZLIB,
265 compression::ZLIB_LEGACY,
266 compression::NONE,
267 ])
268 .into();
269 AsyncSession::connect_publickey(
270 config,
271 target_addr,
272 self.ssh_user(),
273 self.ssh_key_path(),
274 )
275 .await
276 },
277 10,
278 Duration::from_secs(1),
279 ),
280 )
281 .await?;
282
283 Ok(res)
284 }
285}
286
287async fn create_channel<H>(session: &AsyncSession<H>) -> Result<AsyncChannel>
288where
289 H: 'static + Handler,
290{
291 async_retry(
292 &|| async {
293 Ok(tokio::time::timeout(Duration::from_secs(60), session.open_channel()).await??)
294 },
295 10,
296 Duration::from_secs(1),
297 )
298 .await
299}
300
301#[async_trait]
302impl<T: LaunchedSshHost> LaunchedHost for T {
303 fn base_server_config(&self, bind_type: &BaseServerStrategy) -> ServerBindConfig {
304 match bind_type {
305 BaseServerStrategy::UnixSocket => ServerBindConfig::UnixSocket,
306 BaseServerStrategy::InternalTcpPort(hint) => {
307 ServerBindConfig::TcpPort(self.get_internal_ip().to_owned(), *hint)
308 }
309 BaseServerStrategy::ExternalTcpPort(_) => todo!(),
310 }
311 }
312
313 async fn copy_binary(&self, binary: &BuildOutput) -> Result<()> {
314 let session = self.open_ssh_session().await?;
315
316 let sftp = async_retry(&|| session.open_sftp(), 10, Duration::from_secs(1)).await?;
317
318 let user = self.ssh_user();
319 let binary_path = format!("/home/{user}/hydro-{}", binary.unique_id());
321
322 if sftp.metadata(&binary_path).await.is_err() {
323 let random = nanoid!(8);
324 let temp_path = format!("/home/{user}/hydro-{random}");
325 let sftp = &sftp;
326
327 ProgressTracker::progress_leaf(
328 format!("uploading binary to {}", binary_path),
329 |set_progress, _| {
330 async move {
331 let mut created_file = sftp.create(&temp_path).await?;
332
333 let mut index = 0;
334 while index < binary.bin_data.len() {
335 let written = created_file
336 .write(
337 &binary.bin_data[index
338 ..std::cmp::min(index + 128 * 1024, binary.bin_data.len())],
339 )
340 .await?;
341 index += written;
342 set_progress(
343 ((index as f64 / binary.bin_data.len() as f64) * 100.0) as u64,
344 );
345 }
346 let mut orig_file_stat = sftp.metadata(&temp_path).await?;
347 orig_file_stat.permissions = Some(0o755); created_file.set_metadata(orig_file_stat).await?;
349 created_file.sync_all().await?;
350 drop(created_file);
351
352 match sftp.rename(&temp_path, binary_path).await {
353 Ok(_) => {}
354 Err(SftpError::Status(Status {
355 status_code: StatusCode::Failure, ..
357 })) => {
358 sftp.remove_file(temp_path).await?;
360 }
361 Err(e) => return Err(e.into()),
362 }
363
364 anyhow::Ok(())
365 }
366 },
367 )
368 .await?;
369 }
370 sftp.close().await?;
371
372 Ok(())
373 }
374
375 async fn launch_binary(
376 &self,
377 id: String,
378 binary: &BuildOutput,
379 args: &[String],
380 tracing: Option<TracingOptions>,
381 env: &HashMap<String, String>,
382 ) -> Result<Box<dyn LaunchedBinary>> {
383 let session = self.open_ssh_session().await?;
384
385 let user = self.ssh_user();
386 let binary_path = PathBuf::from(format!("/home/{user}/hydro-{}", binary.unique_id()));
387
388 let mut command = String::new();
389 for (k, v) in env {
391 command.push_str(&format!("{}={} ", k, shell_escape::unix::escape(v.into())));
392 }
393
394 command.push_str(binary_path.to_str().unwrap());
395 for arg in args {
396 command.push(' ');
397 command.push_str(&shell_escape::unix::escape(arg.into()))
398 }
399
400 if let Some(TracingOptions {
402 frequency,
403 setup_command,
404 ..
405 }) = tracing.clone()
406 {
407 let id_clone = id.clone();
408 ProgressTracker::leaf("install perf", async {
409 if let Some(setup_command) = setup_command {
411 let setup_channel = create_channel(&session).await?;
412 let (setup_stdout, setup_stderr) =
413 (setup_channel.stdout(), setup_channel.stderr());
414 setup_channel.exec(false, &*setup_command).await?;
415
416 let mut output_lines = LinesStream::new(setup_stdout.lines())
418 .merge(LinesStream::new(setup_stderr.lines()));
419 while let Some(line) = output_lines.next().await {
420 ProgressTracker::eprintln(format!(
421 "[{} install perf] {}",
422 id_clone,
423 line.unwrap()
424 ));
425 }
426
427 setup_channel.closed().wait().await;
428 let exit_code = setup_channel.recv_exit_status().try_get();
429 if Ok(&0) != exit_code {
430 anyhow::bail!("Failed to install perf on remote host");
431 }
432 }
433 Ok(())
434 })
435 .await?;
436
437 command = format!(
440 "perf record -F {frequency} -e cycles:u --call-graph dwarf,65528 -o {PERF_OUTFILE} {command}",
441 );
442 }
443
444 let (channel, stdout, stderr) = ProgressTracker::leaf(
445 format!("launching binary {}", binary_path.display()),
446 async {
447 let channel = create_channel(&session).await?;
448 let (stdout, stderr) = (channel.stdout(), channel.stderr());
450 channel.exec(false, command).await?;
451 anyhow::Ok((channel, stdout, stderr))
452 },
453 )
454 .await?;
455
456 let (stdin_sender, mut stdin_receiver) = mpsc::unbounded_channel::<String>();
457 let mut stdin = channel.stdin();
458
459 tokio::spawn(async move {
460 while let Some(line) = stdin_receiver.recv().await {
461 if stdin.write_all(line.as_bytes()).await.is_err() {
462 break;
463 }
464 stdin.flush().await.unwrap();
465 }
466 });
467
468 let id_clone = id.clone();
469 let stdout_broadcast = prioritized_broadcast(LinesStream::new(stdout.lines()), move |s| {
470 ProgressTracker::println(format!("[{id_clone}] {s}"));
471 });
472 let stderr_broadcast = prioritized_broadcast(LinesStream::new(stderr.lines()), move |s| {
473 ProgressTracker::println(format!("[{id} stderr] {s}"));
474 });
475
476 Ok(Box::new(LaunchedSshBinary {
477 _resource_result: self.resource_result().clone(),
478 session: Some(session),
479 channel,
480 stdin_sender,
481 stdout_broadcast,
482 stderr_broadcast,
483 tracing,
484 #[cfg(feature = "profile-folding")]
485 tracing_results: OnceLock::new(),
486 }))
487 }
488
489 async fn forward_port(&self, addr: &SocketAddr) -> Result<SocketAddr> {
490 let session = self.open_ssh_session().await?;
491
492 let local_port = TcpListener::bind("127.0.0.1:0").await?;
493 let local_addr = local_port.local_addr()?;
494
495 let internal_ip = addr.ip().to_string();
496 let port = addr.port();
497
498 tokio::spawn(async move {
499 #[expect(clippy::never_loop, reason = "tcp accept loop pattern")]
500 while let Ok((mut local_stream, _)) = local_port.accept().await {
501 let mut channel = session
502 .channel_open_direct_tcpip(internal_ip, port.into(), "127.0.0.1", 22)
503 .await
504 .unwrap()
505 .into_stream();
506 let _ = tokio::io::copy_bidirectional(&mut local_stream, &mut channel).await;
507 break;
508 }
511
512 ProgressTracker::println("[hydro] closing forwarded port");
513 });
514
515 Ok(local_addr)
516 }
517}