1use std::borrow::Cow;
2use std::collections::HashMap;
3use std::net::SocketAddr;
4use std::path::PathBuf;
5use std::sync::{Arc, Mutex};
6use std::time::Duration;
7
8use anyhow::{Context as _, Result};
9use async_ssh2_lite::ssh2::ErrorCode;
10use async_ssh2_lite::{AsyncChannel, AsyncSession, SessionConfiguration};
11use async_trait::async_trait;
12use futures::io::BufReader as FuturesBufReader;
13use futures::{AsyncBufReadExt, AsyncWriteExt};
14use hydro_deploy_integration::ServerBindConfig;
15use inferno::collapse::Collapse;
16use inferno::collapse::perf::Folder;
17use nanoid::nanoid;
18use tokio::fs::File;
19use tokio::io::{AsyncReadExt, BufReader as TokioBufReader};
20use tokio::net::{TcpListener, TcpStream};
21use tokio::runtime::Handle;
22use tokio::sync::{mpsc, oneshot};
23use tokio_stream::StreamExt;
24use tokio_util::compat::FuturesAsyncReadCompatExt;
25use tokio_util::io::SyncIoBridge;
26
27use crate::progress::ProgressTracker;
28use crate::rust_crate::build::BuildOutput;
29use crate::rust_crate::flamegraph::handle_fold_data;
30use crate::rust_crate::tracing_options::TracingOptions;
31use crate::util::{async_retry, prioritized_broadcast};
32use crate::{LaunchedBinary, LaunchedHost, ResourceResult, ServerStrategy, TracingResults};
33
/// Name of the raw `perf record` output file on the remote host.
///
/// It is written by `perf record -o` when launching with tracing, read back by
/// `perf script -i`, and downloaded over SFTP. It is a relative path, so it resolves
/// against the remote command / SFTP working directory (presumably the SSH user's home
/// directory — TODO confirm).
const PERF_OUTFILE: &str = "__profile.perf.data";

/// A registered output-line subscriber: an optional prefix filter paired with the channel
/// that receives lines. NOTE(review): the filtering itself is applied wherever these lists
/// are consumed (presumably `prioritized_broadcast`); confirm the exact prefix semantics there.
pub type PrefixFilteredChannel = (Option<String>, mpsc::UnboundedSender<String>);
37
/// A binary running on a remote host over SSH, as returned by `launch_binary`.
struct LaunchedSshBinary {
    /// Never read in this file; presumably held to keep the host's resources alive while
    /// the binary runs — TODO confirm.
    _resource_result: Arc<ResourceResult>,
    /// SSH session the binary was launched on. `Some` from launch until `Drop` takes it
    /// to disconnect.
    session: Option<AsyncSession<TcpStream>>,
    /// Exec channel running the binary (wrapped in `perf record` when tracing is enabled).
    channel: AsyncChannel<TcpStream>,
    /// Lines sent here are written to the remote process's stdin by a background task.
    stdin_sender: mpsc::UnboundedSender<String>,
    /// Subscribers to stdout lines, each with an optional prefix filter.
    stdout_receivers: Arc<Mutex<Vec<PrefixFilteredChannel>>>,
    /// At most one pending oneshot sender, registered via `deploy_stdout`.
    stdout_deploy_receivers: Arc<Mutex<Option<oneshot::Sender<String>>>>,
    /// Subscribers to stderr lines, each with an optional prefix filter.
    stderr_receivers: Arc<Mutex<Vec<PrefixFilteredChannel>>>,
    /// Profiling options; when set, `stop` collects and folds the perf data.
    tracing: Option<TracingOptions>,
    /// Folded perf output, filled in by `stop` when tracing is enabled.
    tracing_results: Option<TracingResults>,
}
49
50#[async_trait]
51impl LaunchedBinary for LaunchedSshBinary {
52 fn stdin(&self) -> mpsc::UnboundedSender<String> {
53 self.stdin_sender.clone()
54 }
55
56 fn deploy_stdout(&self) -> oneshot::Receiver<String> {
57 let mut receivers = self.stdout_deploy_receivers.lock().unwrap();
58
59 if receivers.is_some() {
60 panic!("Only one deploy stdout receiver is allowed at a time");
61 }
62
63 let (sender, receiver) = oneshot::channel::<String>();
64 *receivers = Some(sender);
65 receiver
66 }
67
68 fn stdout(&self) -> mpsc::UnboundedReceiver<String> {
69 let mut receivers = self.stdout_receivers.lock().unwrap();
70 let (sender, receiver) = mpsc::unbounded_channel::<String>();
71 receivers.push((None, sender));
72 receiver
73 }
74
75 fn stderr(&self) -> mpsc::UnboundedReceiver<String> {
76 let mut receivers = self.stderr_receivers.lock().unwrap();
77 let (sender, receiver) = mpsc::unbounded_channel::<String>();
78 receivers.push((None, sender));
79 receiver
80 }
81
82 fn stdout_filter(&self, prefix: String) -> mpsc::UnboundedReceiver<String> {
83 let mut receivers = self.stdout_receivers.lock().unwrap();
84 let (sender, receiver) = mpsc::unbounded_channel::<String>();
85 receivers.push((Some(prefix), sender));
86 receiver
87 }
88
89 fn stderr_filter(&self, prefix: String) -> mpsc::UnboundedReceiver<String> {
90 let mut receivers = self.stderr_receivers.lock().unwrap();
91 let (sender, receiver) = mpsc::unbounded_channel::<String>();
92 receivers.push((Some(prefix), sender));
93 receiver
94 }
95
96 fn tracing_results(&self) -> Option<&TracingResults> {
97 self.tracing_results.as_ref()
98 }
99
100 fn exit_code(&self) -> Option<i32> {
101 if self.channel.eof() {
103 self.channel.exit_status().ok()
104 } else {
105 None
106 }
107 }
108
109 async fn wait(&mut self) -> Result<i32> {
110 self.channel.wait_eof().await.unwrap();
111 let exit_code = self.channel.exit_status()?;
112 self.channel.wait_close().await.unwrap();
113
114 Ok(exit_code)
115 }
116
117 async fn stop(&mut self) -> Result<()> {
118 if !self.channel.eof() {
119 ProgressTracker::leaf("force stopping", async {
120 self.channel.write_all(b"\x03").await?; self.channel.send_eof().await?;
122 self.channel.wait_eof().await?;
123 self.channel.wait_close().await?;
125 Result::<_>::Ok(())
126 })
127 .await?;
128 }
129
130 if let Some(tracing) = self.tracing.as_ref() {
132 let session = self.session.as_ref().unwrap();
133 if let Some(local_raw_perf) = tracing.perf_raw_outfile.as_ref() {
134 ProgressTracker::progress_leaf("downloading perf data", |progress, _| async move {
135 let sftp = async_retry(
136 &|| async { Ok(session.sftp().await?) },
137 10,
138 Duration::from_secs(1),
139 )
140 .await?;
141
142 let mut remote_raw_perf = sftp.open(&PathBuf::from(PERF_OUTFILE)).await?;
143 let mut local_raw_perf = File::create(local_raw_perf).await?;
144
145 let total_size = remote_raw_perf.stat().await?.size.unwrap();
146 let mut remote_tokio = remote_raw_perf.compat();
147
148 use tokio::io::AsyncWriteExt;
149 let mut index = 0;
150 loop {
151 let mut buffer = [0; 16 * 1024];
152 let n = remote_tokio.read(&mut buffer).await?;
153 if n == 0 {
154 break;
155 }
156 local_raw_perf.write_all(&buffer[..n]).await?;
157 index += n;
158 progress(((index as f64 / total_size as f64) * 100.0) as u64);
159 }
160
161 Ok::<(), anyhow::Error>(())
162 })
163 .await?;
164 }
165
166 let mut script_channel = session.channel_session().await?;
167 let mut fold_er = Folder::from(tracing.fold_perf_options.clone().unwrap_or_default());
168
169 let fold_data = ProgressTracker::leaf("perf script & folding", async move {
170 let mut stderr_lines = FuturesBufReader::new(script_channel.stderr()).lines();
171 let stdout = script_channel.stream(0);
172
173 let ((), fold_data, ()) = tokio::try_join!(
175 async move {
176 while let Some(Ok(s)) = stderr_lines.next().await {
178 ProgressTracker::eprintln(format!("[perf stderr] {s}"));
179 }
180 Result::<_>::Ok(())
181 },
182 async move {
183 tokio::task::spawn_blocking(move || {
185 let mut fold_data = Vec::new();
186 fold_er.collapse(
187 SyncIoBridge::new(TokioBufReader::new(stdout)),
188 &mut fold_data,
189 )?;
190 Ok(fold_data)
191 })
192 .await?
193 },
194 async move {
195 script_channel
197 .exec(&format!("perf script --symfs=/ -i {PERF_OUTFILE}"))
198 .await?;
199 Ok(())
200 },
201 )?;
202 Result::<_>::Ok(fold_data)
203 })
204 .await?;
205
206 self.tracing_results = Some(TracingResults {
207 folded_data: fold_data.clone(),
208 });
209
210 handle_fold_data(tracing, fold_data).await?;
211 };
212
213 Ok(())
214 }
215}
216
217impl Drop for LaunchedSshBinary {
218 fn drop(&mut self) {
219 if let Some(session) = self.session.take() {
220 tokio::task::block_in_place(|| {
221 Handle::current().block_on(session.disconnect(None, "", None))
222 })
223 .unwrap();
224 }
225 }
226}
227
228#[async_trait]
229pub trait LaunchedSshHost: Send + Sync {
230 fn get_internal_ip(&self) -> String;
231 fn get_external_ip(&self) -> Option<String>;
232 fn get_cloud_provider(&self) -> String;
233 fn resource_result(&self) -> &Arc<ResourceResult>;
234 fn ssh_user(&self) -> &str;
235
236 fn ssh_key_path(&self) -> PathBuf {
237 self.resource_result()
238 .terraform
239 .deployment_folder
240 .as_ref()
241 .unwrap()
242 .path()
243 .join(".ssh")
244 .join("vm_instance_ssh_key_pem")
245 }
246
247 fn server_config(&self, bind_type: &ServerStrategy) -> ServerBindConfig {
248 match bind_type {
249 ServerStrategy::UnixSocket => ServerBindConfig::UnixSocket,
250 ServerStrategy::InternalTcpPort => {
251 ServerBindConfig::TcpPort(self.get_internal_ip().clone())
252 }
253 ServerStrategy::ExternalTcpPort(_) => todo!(),
254 ServerStrategy::Demux(demux) => {
255 let mut config_map = HashMap::new();
256 for (key, underlying) in demux {
257 config_map.insert(*key, LaunchedSshHost::server_config(self, underlying));
258 }
259
260 ServerBindConfig::Demux(config_map)
261 }
262 ServerStrategy::Merge(merge) => {
263 let mut configs = vec![];
264 for underlying in merge {
265 configs.push(LaunchedSshHost::server_config(self, underlying));
266 }
267
268 ServerBindConfig::Merge(configs)
269 }
270 ServerStrategy::Tagged(underlying, id) => ServerBindConfig::Tagged(
271 Box::new(LaunchedSshHost::server_config(self, underlying)),
272 *id,
273 ),
274 ServerStrategy::Null => ServerBindConfig::Null,
275 }
276 }
277
278 async fn open_ssh_session(&self) -> Result<AsyncSession<TcpStream>> {
279 let target_addr = SocketAddr::new(
280 self.get_external_ip()
281 .as_ref()
282 .context(
283 self.get_cloud_provider()
284 + " host must be configured with an external IP to launch binaries",
285 )?
286 .parse()
287 .unwrap(),
288 22,
289 );
290
291 let res = ProgressTracker::leaf(
292 format!(
293 "connecting to host @ {}",
294 self.get_external_ip().as_ref().unwrap()
295 ),
296 async_retry(
297 &|| async {
298 let mut config = SessionConfiguration::new();
299 config.set_compress(true);
300
301 let mut session =
302 AsyncSession::<TcpStream>::connect(target_addr, Some(config)).await?;
303
304 session.handshake().await?;
305
306 session
307 .userauth_pubkey_file(
308 self.ssh_user(),
309 None,
310 self.ssh_key_path().as_path(),
311 None,
312 )
313 .await?;
314
315 Ok(session)
316 },
317 10,
318 Duration::from_secs(1),
319 ),
320 )
321 .await?;
322
323 Ok(res)
324 }
325}
326
327#[async_trait]
328impl<T: LaunchedSshHost> LaunchedHost for T {
329 fn server_config(&self, bind_type: &ServerStrategy) -> ServerBindConfig {
330 LaunchedSshHost::server_config(self, bind_type)
331 }
332
333 async fn copy_binary(&self, binary: &BuildOutput) -> Result<()> {
334 let session = self.open_ssh_session().await?;
335
336 let sftp = async_retry(
337 &|| async { Ok(session.sftp().await?) },
338 10,
339 Duration::from_secs(1),
340 )
341 .await?;
342
343 let unique_name = &binary.unique_id;
345
346 let user = self.ssh_user();
347 let binary_path = PathBuf::from(format!("/home/{user}/hydro-{unique_name}"));
348
349 if sftp.stat(&binary_path).await.is_err() {
350 let random = nanoid!(8);
351 let temp_path = PathBuf::from(format!("/home/{user}/hydro-{random}"));
352 let sftp = &sftp;
353
354 ProgressTracker::progress_leaf(
355 format!("uploading binary to {}", binary_path.display()),
356 |set_progress, _| {
357 let binary = &binary;
358 let binary_path = &binary_path;
359 async move {
360 let mut created_file = sftp.create(&temp_path).await?;
361
362 let mut index = 0;
363 while index < binary.bin_data.len() {
364 let written = created_file
365 .write(
366 &binary.bin_data[index
367 ..std::cmp::min(index + 128 * 1024, binary.bin_data.len())],
368 )
369 .await?;
370 index += written;
371 set_progress(
372 ((index as f64 / binary.bin_data.len() as f64) * 100.0) as u64,
373 );
374 }
375 let mut orig_file_stat = sftp.stat(&temp_path).await?;
376 orig_file_stat.perm = Some(0o755); created_file.setstat(orig_file_stat).await?;
378 created_file.close().await?;
379 drop(created_file);
380
381 match sftp.rename(&temp_path, binary_path, None).await {
382 Ok(_) => {}
383 Err(async_ssh2_lite::Error::Ssh2(e))
384 if e.code() == ErrorCode::SFTP(4) =>
385 {
386 sftp.unlink(&temp_path).await?;
388 }
389 Err(e) => return Err(e.into()),
390 }
391
392 anyhow::Ok(())
393 }
394 },
395 )
396 .await?;
397 }
398 drop(sftp);
399
400 Ok(())
401 }
402
403 async fn launch_binary(
404 &self,
405 id: String,
406 binary: &BuildOutput,
407 args: &[String],
408 tracing: Option<TracingOptions>,
409 ) -> Result<Box<dyn LaunchedBinary>> {
410 let session = self.open_ssh_session().await?;
411
412 let unique_name = &binary.unique_id;
413
414 let user = self.ssh_user();
415 let binary_path = PathBuf::from(format!("/home/{user}/hydro-{unique_name}"));
416
417 let channel = ProgressTracker::leaf(
418 format!("launching binary {}", binary_path.display()),
419 async {
420 let mut channel =
421 async_retry(
422 &|| async {
423 Ok(tokio::time::timeout(
424 Duration::from_secs(60),
425 session.channel_session(),
426 )
427 .await??)
428 },
429 10,
430 Duration::from_secs(1),
431 )
432 .await?;
433
434 let mut command = binary_path.to_str().unwrap().to_owned();
435 for arg in args{
436 command.push(' ');
437 command.push_str(&shell_escape::unix::escape(Cow::Borrowed(arg)))
438 }
439 if let Some(TracingOptions { frequency, .. }) = tracing.clone() {
441 command = format!(
443 "perf record -F {frequency} -e cycles:u --call-graph dwarf,65528 -o {PERF_OUTFILE} {command}",
444 );
445 }
446 channel.exec(&command).await?;
447 anyhow::Ok(channel)
448 },
449 )
450 .await?;
451
452 let (stdin_sender, mut stdin_receiver) = mpsc::unbounded_channel::<String>();
453 let mut stdin = channel.stream(0); tokio::spawn(async move {
455 while let Some(line) = stdin_receiver.recv().await {
456 if stdin.write_all(line.as_bytes()).await.is_err() {
457 break;
458 }
459
460 stdin.flush().await.unwrap();
461 }
462 });
463
464 let id_clone = id.clone();
465 let (stdout_deploy_receivers, stdout_receivers) =
466 prioritized_broadcast(FuturesBufReader::new(channel.stream(0)).lines(), move |s| {
467 ProgressTracker::println(format!("[{id_clone}] {s}"));
468 });
469 let (_, stderr_receivers) =
470 prioritized_broadcast(FuturesBufReader::new(channel.stderr()).lines(), move |s| {
471 ProgressTracker::println(format!("[{id} stderr] {s}"));
472 });
473
474 Ok(Box::new(LaunchedSshBinary {
475 _resource_result: self.resource_result().clone(),
476 session: Some(session),
477 channel,
478 stdin_sender,
479 stdout_deploy_receivers,
480 stdout_receivers,
481 stderr_receivers,
482 tracing,
483 tracing_results: None,
484 }))
485 }
486
487 async fn forward_port(&self, addr: &SocketAddr) -> Result<SocketAddr> {
488 let session = self.open_ssh_session().await?;
489
490 let local_port = TcpListener::bind("127.0.0.1:0").await?;
491 let local_addr = local_port.local_addr()?;
492
493 let internal_ip = addr.ip().to_string();
494 let port = addr.port();
495
496 tokio::spawn(async move {
497 #[expect(clippy::never_loop, reason = "tcp accept loop pattern")]
498 while let Ok((mut local_stream, _)) = local_port.accept().await {
499 let mut channel = session
500 .channel_direct_tcpip(&internal_ip, port, None)
501 .await
502 .unwrap();
503 let _ = tokio::io::copy_bidirectional(&mut local_stream, &mut channel).await;
504 break;
505 }
508
509 ProgressTracker::println("[hydro] closing forwarded port");
510 });
511
512 Ok(local_addr)
513 }
514}