1use std::net::SocketAddr;
2use std::path::PathBuf;
3use std::sync::Arc;
4#[cfg(feature = "profile-folding")]
5use std::sync::OnceLock;
6use std::time::Duration;
7
8use anyhow::{Context as _, Result};
9use async_ssh2_russh::russh::client::{Config, Handler};
10use async_ssh2_russh::russh::{Disconnect, compression};
11use async_ssh2_russh::russh_sftp::protocol::{Status, StatusCode};
12use async_ssh2_russh::sftp::SftpError;
13use async_ssh2_russh::{AsyncChannel, AsyncSession, NoCheckHandler};
14use async_trait::async_trait;
15use hydro_deploy_integration::ServerBindConfig;
16#[cfg(feature = "profile-folding")]
17use inferno::collapse::Collapse;
18#[cfg(feature = "profile-folding")]
19use inferno::collapse::perf::Folder;
20use nanoid::nanoid;
21use tokio::fs::File;
22#[cfg(feature = "profile-folding")]
23use tokio::io::BufReader;
24use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt};
25use tokio::net::TcpListener;
26use tokio::sync::{mpsc, oneshot};
27use tokio_stream::StreamExt;
28use tokio_stream::wrappers::LinesStream;
29#[cfg(feature = "profile-folding")]
30use tokio_util::io::SyncIoBridge;
31
32#[cfg(feature = "profile-folding")]
33use crate::TracingResults;
34use crate::progress::ProgressTracker;
35use crate::rust_crate::build::BuildOutput;
36#[cfg(feature = "profile-folding")]
37use crate::rust_crate::flamegraph::handle_fold_data;
38use crate::rust_crate::tracing_options::TracingOptions;
39use crate::util::{PriorityBroadcast, async_retry, prioritized_broadcast};
40use crate::{BaseServerStrategy, LaunchedBinary, LaunchedHost, ResourceResult};
41
/// File name, relative to the remote working directory, that `perf record` writes to
/// and that is later read back (raw download and `perf script -i`).
const PERF_OUTFILE: &str = "__profile.perf.data";
43
44struct LaunchedSshBinary {
45 _resource_result: Arc<ResourceResult>,
46 session: Option<AsyncSession<NoCheckHandler>>,
50 channel: AsyncChannel,
51 stdin_sender: mpsc::UnboundedSender<String>,
52 stdout_broadcast: PriorityBroadcast,
53 stderr_broadcast: PriorityBroadcast,
54 tracing: Option<TracingOptions>,
55 #[cfg(feature = "profile-folding")]
56 tracing_results: OnceLock<TracingResults>,
57}
58
59#[async_trait]
60impl LaunchedBinary for LaunchedSshBinary {
61 fn stdin(&self) -> mpsc::UnboundedSender<String> {
62 self.stdin_sender.clone()
63 }
64
65 fn deploy_stdout(&self) -> oneshot::Receiver<String> {
66 self.stdout_broadcast.receive_priority()
67 }
68
69 fn stdout(&self) -> mpsc::UnboundedReceiver<String> {
70 self.stdout_broadcast.receive(None)
71 }
72
73 fn stderr(&self) -> mpsc::UnboundedReceiver<String> {
74 self.stderr_broadcast.receive(None)
75 }
76
77 fn stdout_filter(&self, prefix: String) -> mpsc::UnboundedReceiver<String> {
78 self.stdout_broadcast.receive(Some(prefix))
79 }
80
81 fn stderr_filter(&self, prefix: String) -> mpsc::UnboundedReceiver<String> {
82 self.stderr_broadcast.receive(Some(prefix))
83 }
84
85 #[cfg(feature = "profile-folding")]
86 fn tracing_results(&self) -> Option<&TracingResults> {
87 self.tracing_results.get()
88 }
89
90 fn exit_code(&self) -> Option<i32> {
91 self.channel
93 .recv_exit_status()
94 .try_get()
95 .map(|&ec| ec as _)
96 .ok()
97 }
98
99 async fn wait(&self) -> Result<i32> {
100 let _ = self.channel.closed().wait().await;
101 Ok(*self.channel.recv_exit_status().try_get()? as _)
102 }
103
104 async fn stop(&self) -> Result<()> {
105 if !self.channel.closed().is_done() {
106 ProgressTracker::leaf("force stopping", async {
107 self.channel.eof().await?; self.channel.close().await?; self.channel.closed().wait().await;
111 Result::<_>::Ok(())
112 })
113 .await?;
114 }
115
116 if let Some(tracing) = self.tracing.as_ref() {
118 #[cfg(feature = "profile-folding")]
119 assert!(
120 self.tracing_results.get().is_none(),
121 "`tracing_results` already set! Was `stop()` called twice? This is a bug."
122 );
123
124 let session = self.session.as_ref().unwrap();
125 if let Some(local_raw_perf) = tracing.perf_raw_outfile.as_ref() {
126 ProgressTracker::progress_leaf("downloading perf data", |progress, _| async move {
127 let sftp =
128 async_retry(&|| session.open_sftp(), 10, Duration::from_secs(1)).await?;
129
130 let mut remote_raw_perf = sftp.open(PERF_OUTFILE).await?;
131 let mut local_raw_perf = File::create(local_raw_perf).await?;
132
133 let total_size = remote_raw_perf.metadata().await?.size.unwrap();
134
135 use tokio::io::AsyncWriteExt;
136 let mut index = 0;
137 loop {
138 let mut buffer = [0; 16 * 1024];
139 let n = remote_raw_perf.read(&mut buffer).await?;
140 if n == 0 {
141 break;
142 }
143 local_raw_perf.write_all(&buffer[..n]).await?;
144 index += n;
145 progress(((index as f64 / total_size as f64) * 100.0) as u64);
146 }
147
148 Ok::<(), anyhow::Error>(())
149 })
150 .await?;
151 }
152
153 #[cfg(feature = "profile-folding")]
154 let script_channel = session.open_channel().await?;
155 #[cfg(feature = "profile-folding")]
156 let mut fold_er = Folder::from(tracing.fold_perf_options.clone().unwrap_or_default());
157
158 #[cfg(feature = "profile-folding")]
159 let fold_data = ProgressTracker::leaf("perf script & folding", async move {
160 let mut stderr_lines = script_channel.stderr().lines();
161 let stdout = script_channel.stdout();
162
163 let ((), fold_data, ()) = tokio::try_join!(
165 async move {
166 while let Ok(Some(s)) = stderr_lines.next_line().await {
168 ProgressTracker::eprintln(format!("[perf stderr] {s}"));
169 }
170 Result::<_>::Ok(())
171 },
172 async move {
173 tokio::task::spawn_blocking(move || {
175 let mut fold_data = Vec::new();
176 fold_er.collapse(
177 SyncIoBridge::new(BufReader::new(stdout)),
178 &mut fold_data,
179 )?;
180 Ok(fold_data)
181 })
182 .await?
183 },
184 async move {
185 script_channel
187 .exec(false, format!("perf script --symfs=/ -i {PERF_OUTFILE}"))
188 .await?;
189 Ok(())
190 },
191 )?;
192 Result::<_>::Ok(fold_data)
193 })
194 .await?;
195
196 #[cfg(feature = "profile-folding")]
197 self.tracing_results
198 .set(TracingResults {
199 folded_data: fold_data.clone(),
200 })
201 .expect("`tracing_results` already set! This is a bug.");
202
203 #[cfg(feature = "profile-folding")]
204 handle_fold_data(tracing, fold_data).await?;
205 };
206
207 Ok(())
208 }
209}
210
211impl Drop for LaunchedSshBinary {
212 fn drop(&mut self) {
213 if let Some(session) = self.session.take() {
214 tokio::task::block_in_place(|| {
215 tokio::runtime::Handle::current().block_on(session.disconnect(
216 Disconnect::ByApplication,
217 "",
218 "",
219 ))
220 })
221 .unwrap();
222 }
223 }
224}
225
226#[async_trait]
227pub trait LaunchedSshHost: Send + Sync {
228 fn get_internal_ip(&self) -> String;
229 fn get_external_ip(&self) -> Option<String>;
230 fn get_cloud_provider(&self) -> String;
231 fn resource_result(&self) -> &Arc<ResourceResult>;
232 fn ssh_user(&self) -> &str;
233
234 fn ssh_key_path(&self) -> PathBuf {
235 self.resource_result()
236 .terraform
237 .deployment_folder
238 .as_ref()
239 .unwrap()
240 .path()
241 .join(".ssh")
242 .join("vm_instance_ssh_key_pem")
243 }
244
245 async fn open_ssh_session(&self) -> Result<AsyncSession<NoCheckHandler>> {
246 let target_addr = SocketAddr::new(
247 self.get_external_ip()
248 .as_ref()
249 .context(
250 self.get_cloud_provider()
251 + " host must be configured with an external IP to launch binaries",
252 )?
253 .parse()
254 .unwrap(),
255 22,
256 );
257
258 let res = ProgressTracker::leaf(
259 format!(
260 "connecting to host @ {}",
261 self.get_external_ip().as_ref().unwrap()
262 ),
263 async_retry(
264 &|| async {
265 let mut config = Config::default();
266 config.preferred.compression = (&[
267 compression::ZLIB,
268 compression::ZLIB_LEGACY,
269 compression::NONE,
270 ])
271 .into();
272 AsyncSession::connect_publickey(
273 config,
274 target_addr,
275 self.ssh_user(),
276 self.ssh_key_path(),
277 )
278 .await
279 },
280 10,
281 Duration::from_secs(1),
282 ),
283 )
284 .await?;
285
286 Ok(res)
287 }
288}
289
290async fn create_channel<H>(session: &AsyncSession<H>) -> Result<AsyncChannel>
291where
292 H: 'static + Handler,
293{
294 async_retry(
295 &|| async {
296 Ok(tokio::time::timeout(Duration::from_secs(60), session.open_channel()).await??)
297 },
298 10,
299 Duration::from_secs(1),
300 )
301 .await
302}
303
304#[async_trait]
305impl<T: LaunchedSshHost> LaunchedHost for T {
306 fn base_server_config(&self, bind_type: &BaseServerStrategy) -> ServerBindConfig {
307 match bind_type {
308 BaseServerStrategy::UnixSocket => ServerBindConfig::UnixSocket,
309 BaseServerStrategy::InternalTcpPort(hint) => {
310 ServerBindConfig::TcpPort(self.get_internal_ip().clone(), *hint)
311 }
312 BaseServerStrategy::ExternalTcpPort(_) => todo!(),
313 }
314 }
315
316 async fn copy_binary(&self, binary: &BuildOutput) -> Result<()> {
317 let session = self.open_ssh_session().await?;
318
319 let sftp = async_retry(&|| session.open_sftp(), 10, Duration::from_secs(1)).await?;
320
321 let user = self.ssh_user();
322 let binary_path = format!("/home/{user}/hydro-{}", binary.unique_id());
324
325 if sftp.metadata(&binary_path).await.is_err() {
326 let random = nanoid!(8);
327 let temp_path = format!("/home/{user}/hydro-{random}");
328 let sftp = &sftp;
329
330 ProgressTracker::progress_leaf(
331 format!("uploading binary to {}", binary_path),
332 |set_progress, _| {
333 async move {
334 let mut created_file = sftp.create(&temp_path).await?;
335
336 let mut index = 0;
337 while index < binary.bin_data.len() {
338 let written = created_file
339 .write(
340 &binary.bin_data[index
341 ..std::cmp::min(index + 128 * 1024, binary.bin_data.len())],
342 )
343 .await?;
344 index += written;
345 set_progress(
346 ((index as f64 / binary.bin_data.len() as f64) * 100.0) as u64,
347 );
348 }
349 let mut orig_file_stat = sftp.metadata(&temp_path).await?;
350 orig_file_stat.permissions = Some(0o755); created_file.set_metadata(orig_file_stat).await?;
352 created_file.sync_all().await?;
353 drop(created_file);
354
355 match sftp.rename(&temp_path, binary_path).await {
356 Ok(_) => {}
357 Err(SftpError::Status(Status {
358 status_code: StatusCode::Failure, ..
360 })) => {
361 sftp.remove_file(temp_path).await?;
363 }
364 Err(e) => return Err(e.into()),
365 }
366
367 anyhow::Ok(())
368 }
369 },
370 )
371 .await?;
372 }
373 sftp.close().await?;
374
375 Ok(())
376 }
377
378 async fn launch_binary(
379 &self,
380 id: String,
381 binary: &BuildOutput,
382 args: &[String],
383 tracing: Option<TracingOptions>,
384 ) -> Result<Box<dyn LaunchedBinary>> {
385 let session = self.open_ssh_session().await?;
386
387 let user = self.ssh_user();
388 let binary_path = PathBuf::from(format!("/home/{user}/hydro-{}", binary.unique_id()));
389
390 let mut command = binary_path.to_str().unwrap().to_owned();
391 for arg in args {
392 command.push(' ');
393 command.push_str(&shell_escape::unix::escape(arg.into()))
394 }
395
396 if let Some(TracingOptions {
398 frequency,
399 setup_command,
400 ..
401 }) = tracing.clone()
402 {
403 let id_clone = id.clone();
404 ProgressTracker::leaf("install perf", async {
405 if let Some(setup_command) = setup_command {
407 let setup_channel = create_channel(&session).await?;
408 let (setup_stdout, setup_stderr) =
409 (setup_channel.stdout(), setup_channel.stderr());
410 setup_channel.exec(false, &*setup_command).await?;
411
412 let mut output_lines = LinesStream::new(setup_stdout.lines())
414 .merge(LinesStream::new(setup_stderr.lines()));
415 while let Some(line) = output_lines.next().await {
416 ProgressTracker::eprintln(format!(
417 "[{} install perf] {}",
418 id_clone,
419 line.unwrap()
420 ));
421 }
422
423 setup_channel.closed().wait().await;
424 let exit_code = setup_channel.recv_exit_status().try_get();
425 if Ok(&0) != exit_code {
426 anyhow::bail!("Failed to install perf on remote host");
427 }
428 }
429 Ok(())
430 })
431 .await?;
432
433 command = format!(
436 "perf record -F {frequency} -e cycles:u --call-graph dwarf,65528 -o {PERF_OUTFILE} {command}",
437 );
438 }
439
440 let (channel, stdout, stderr) = ProgressTracker::leaf(
441 format!("launching binary {}", binary_path.display()),
442 async {
443 let channel = create_channel(&session).await?;
444 let (stdout, stderr) = (channel.stdout(), channel.stderr());
446 channel.exec(false, command).await?;
447 anyhow::Ok((channel, stdout, stderr))
448 },
449 )
450 .await?;
451
452 let (stdin_sender, mut stdin_receiver) = mpsc::unbounded_channel::<String>();
453 let mut stdin = channel.stdin();
454
455 tokio::spawn(async move {
456 while let Some(line) = stdin_receiver.recv().await {
457 if stdin.write_all(line.as_bytes()).await.is_err() {
458 break;
459 }
460 stdin.flush().await.unwrap();
461 }
462 });
463
464 let id_clone = id.clone();
465 let stdout_broadcast = prioritized_broadcast(LinesStream::new(stdout.lines()), move |s| {
466 ProgressTracker::println(format!("[{id_clone}] {s}"));
467 });
468 let stderr_broadcast = prioritized_broadcast(LinesStream::new(stderr.lines()), move |s| {
469 ProgressTracker::println(format!("[{id} stderr] {s}"));
470 });
471
472 Ok(Box::new(LaunchedSshBinary {
473 _resource_result: self.resource_result().clone(),
474 session: Some(session),
475 channel,
476 stdin_sender,
477 stdout_broadcast,
478 stderr_broadcast,
479 tracing,
480 #[cfg(feature = "profile-folding")]
481 tracing_results: OnceLock::new(),
482 }))
483 }
484
485 async fn forward_port(&self, addr: &SocketAddr) -> Result<SocketAddr> {
486 let session = self.open_ssh_session().await?;
487
488 let local_port = TcpListener::bind("127.0.0.1:0").await?;
489 let local_addr = local_port.local_addr()?;
490
491 let internal_ip = addr.ip().to_string();
492 let port = addr.port();
493
494 tokio::spawn(async move {
495 #[expect(clippy::never_loop, reason = "tcp accept loop pattern")]
496 while let Ok((mut local_stream, _)) = local_port.accept().await {
497 let mut channel = session
498 .channel_open_direct_tcpip(internal_ip, port.into(), "127.0.0.1", 22)
499 .await
500 .unwrap()
501 .into_stream();
502 let _ = tokio::io::copy_bidirectional(&mut local_stream, &mut channel).await;
503 break;
504 }
507
508 ProgressTracker::println("[hydro] closing forwarded port");
509 });
510
511 Ok(local_addr)
512 }
513}