hydro_deploy/
ssh.rs

1use std::borrow::Cow;
2use std::collections::HashMap;
3use std::net::SocketAddr;
4use std::path::PathBuf;
5use std::sync::{Arc, Mutex};
6use std::time::Duration;
7
8use anyhow::{Context as _, Result};
9use async_ssh2_lite::ssh2::ErrorCode;
10use async_ssh2_lite::{AsyncChannel, AsyncSession, SessionConfiguration};
11use async_trait::async_trait;
12use futures::io::BufReader as FuturesBufReader;
13use futures::{AsyncBufReadExt, AsyncWriteExt};
14use hydro_deploy_integration::ServerBindConfig;
15use inferno::collapse::Collapse;
16use inferno::collapse::perf::Folder;
17use nanoid::nanoid;
18use tokio::fs::File;
19use tokio::io::{AsyncReadExt, BufReader as TokioBufReader};
20use tokio::net::{TcpListener, TcpStream};
21use tokio::runtime::Handle;
22use tokio::sync::{mpsc, oneshot};
23use tokio_stream::StreamExt;
24use tokio_util::compat::FuturesAsyncReadCompatExt;
25use tokio_util::io::SyncIoBridge;
26
27use crate::progress::ProgressTracker;
28use crate::rust_crate::build::BuildOutput;
29use crate::rust_crate::flamegraph::handle_fold_data;
30use crate::rust_crate::tracing_options::TracingOptions;
31use crate::util::{async_retry, prioritized_broadcast};
32use crate::{LaunchedBinary, LaunchedHost, ResourceResult, ServerStrategy, TracingResults};
33
/// Remote file that `perf record` writes its raw profile to. The same file is read
/// back by `perf script` and by the SFTP download of the raw profile.
///
/// It is a relative path, so it is presumably resolved against the remote SSH
/// user's working/login directory — confirm if the launch directory ever changes.
const PERF_OUTFILE: &str = "__profile.perf.data";
35
36pub type PrefixFilteredChannel = (Option<String>, mpsc::UnboundedSender<String>);
37
38struct LaunchedSshBinary {
39    _resource_result: Arc<ResourceResult>,
40    session: Option<AsyncSession<TcpStream>>,
41    channel: AsyncChannel<TcpStream>,
42    stdin_sender: mpsc::UnboundedSender<String>,
43    stdout_receivers: Arc<Mutex<Vec<PrefixFilteredChannel>>>,
44    stdout_deploy_receivers: Arc<Mutex<Option<oneshot::Sender<String>>>>,
45    stderr_receivers: Arc<Mutex<Vec<PrefixFilteredChannel>>>,
46    tracing: Option<TracingOptions>,
47    tracing_results: Option<TracingResults>,
48}
49
50#[async_trait]
51impl LaunchedBinary for LaunchedSshBinary {
52    fn stdin(&self) -> mpsc::UnboundedSender<String> {
53        self.stdin_sender.clone()
54    }
55
56    fn deploy_stdout(&self) -> oneshot::Receiver<String> {
57        let mut receivers = self.stdout_deploy_receivers.lock().unwrap();
58
59        if receivers.is_some() {
60            panic!("Only one deploy stdout receiver is allowed at a time");
61        }
62
63        let (sender, receiver) = oneshot::channel::<String>();
64        *receivers = Some(sender);
65        receiver
66    }
67
68    fn stdout(&self) -> mpsc::UnboundedReceiver<String> {
69        let mut receivers = self.stdout_receivers.lock().unwrap();
70        let (sender, receiver) = mpsc::unbounded_channel::<String>();
71        receivers.push((None, sender));
72        receiver
73    }
74
75    fn stderr(&self) -> mpsc::UnboundedReceiver<String> {
76        let mut receivers = self.stderr_receivers.lock().unwrap();
77        let (sender, receiver) = mpsc::unbounded_channel::<String>();
78        receivers.push((None, sender));
79        receiver
80    }
81
82    fn stdout_filter(&self, prefix: String) -> mpsc::UnboundedReceiver<String> {
83        let mut receivers = self.stdout_receivers.lock().unwrap();
84        let (sender, receiver) = mpsc::unbounded_channel::<String>();
85        receivers.push((Some(prefix), sender));
86        receiver
87    }
88
89    fn stderr_filter(&self, prefix: String) -> mpsc::UnboundedReceiver<String> {
90        let mut receivers = self.stderr_receivers.lock().unwrap();
91        let (sender, receiver) = mpsc::unbounded_channel::<String>();
92        receivers.push((Some(prefix), sender));
93        receiver
94    }
95
96    fn tracing_results(&self) -> Option<&TracingResults> {
97        self.tracing_results.as_ref()
98    }
99
100    fn exit_code(&self) -> Option<i32> {
101        // until the program exits, the exit status is meaningless
102        if self.channel.eof() {
103            self.channel.exit_status().ok()
104        } else {
105            None
106        }
107    }
108
109    async fn wait(&mut self) -> Result<i32> {
110        self.channel.wait_eof().await.unwrap();
111        let exit_code = self.channel.exit_status()?;
112        self.channel.wait_close().await.unwrap();
113
114        Ok(exit_code)
115    }
116
117    async fn stop(&mut self) -> Result<()> {
118        if !self.channel.eof() {
119            ProgressTracker::leaf("force stopping", async {
120                self.channel.write_all(b"\x03").await?; // `^C`
121                self.channel.send_eof().await?;
122                self.channel.wait_eof().await?;
123                // `exit_status()`
124                self.channel.wait_close().await?;
125                Result::<_>::Ok(())
126            })
127            .await?;
128        }
129
130        // Run perf post-processing and download perf output.
131        if let Some(tracing) = self.tracing.as_ref() {
132            let session = self.session.as_ref().unwrap();
133            if let Some(local_raw_perf) = tracing.perf_raw_outfile.as_ref() {
134                ProgressTracker::progress_leaf("downloading perf data", |progress, _| async move {
135                    let sftp = async_retry(
136                        &|| async { Ok(session.sftp().await?) },
137                        10,
138                        Duration::from_secs(1),
139                    )
140                    .await?;
141
142                    let mut remote_raw_perf = sftp.open(&PathBuf::from(PERF_OUTFILE)).await?;
143                    let mut local_raw_perf = File::create(local_raw_perf).await?;
144
145                    let total_size = remote_raw_perf.stat().await?.size.unwrap();
146                    let mut remote_tokio = remote_raw_perf.compat();
147
148                    use tokio::io::AsyncWriteExt;
149                    let mut index = 0;
150                    loop {
151                        let mut buffer = [0; 16 * 1024];
152                        let n = remote_tokio.read(&mut buffer).await?;
153                        if n == 0 {
154                            break;
155                        }
156                        local_raw_perf.write_all(&buffer[..n]).await?;
157                        index += n;
158                        progress(((index as f64 / total_size as f64) * 100.0) as u64);
159                    }
160
161                    Ok::<(), anyhow::Error>(())
162                })
163                .await?;
164            }
165
166            let mut script_channel = session.channel_session().await?;
167            let mut fold_er = Folder::from(tracing.fold_perf_options.clone().unwrap_or_default());
168
169            let fold_data = ProgressTracker::leaf("perf script & folding", async move {
170                let mut stderr_lines = FuturesBufReader::new(script_channel.stderr()).lines();
171                let stdout = script_channel.stream(0);
172
173                // Pattern on `()` to make sure no `Result`s are ignored.
174                let ((), fold_data, ()) = tokio::try_join!(
175                    async move {
176                        // Log stderr.
177                        while let Some(Ok(s)) = stderr_lines.next().await {
178                            ProgressTracker::eprintln(format!("[perf stderr] {s}"));
179                        }
180                        Result::<_>::Ok(())
181                    },
182                    async move {
183                        // Download perf output and fold.
184                        tokio::task::spawn_blocking(move || {
185                            let mut fold_data = Vec::new();
186                            fold_er.collapse(
187                                SyncIoBridge::new(TokioBufReader::new(stdout)),
188                                &mut fold_data,
189                            )?;
190                            Ok(fold_data)
191                        })
192                        .await?
193                    },
194                    async move {
195                        // Run command (last!).
196                        script_channel
197                            .exec(&format!("perf script --symfs=/ -i {PERF_OUTFILE}"))
198                            .await?;
199                        Ok(())
200                    },
201                )?;
202                Result::<_>::Ok(fold_data)
203            })
204            .await?;
205
206            self.tracing_results = Some(TracingResults {
207                folded_data: fold_data.clone(),
208            });
209
210            handle_fold_data(tracing, fold_data).await?;
211        };
212
213        Ok(())
214    }
215}
216
217impl Drop for LaunchedSshBinary {
218    fn drop(&mut self) {
219        if let Some(session) = self.session.take() {
220            tokio::task::block_in_place(|| {
221                Handle::current().block_on(session.disconnect(None, "", None))
222            })
223            .unwrap();
224        }
225    }
226}
227
228#[async_trait]
229pub trait LaunchedSshHost: Send + Sync {
230    fn get_internal_ip(&self) -> String;
231    fn get_external_ip(&self) -> Option<String>;
232    fn get_cloud_provider(&self) -> String;
233    fn resource_result(&self) -> &Arc<ResourceResult>;
234    fn ssh_user(&self) -> &str;
235
236    fn ssh_key_path(&self) -> PathBuf {
237        self.resource_result()
238            .terraform
239            .deployment_folder
240            .as_ref()
241            .unwrap()
242            .path()
243            .join(".ssh")
244            .join("vm_instance_ssh_key_pem")
245    }
246
247    fn server_config(&self, bind_type: &ServerStrategy) -> ServerBindConfig {
248        match bind_type {
249            ServerStrategy::UnixSocket => ServerBindConfig::UnixSocket,
250            ServerStrategy::InternalTcpPort => {
251                ServerBindConfig::TcpPort(self.get_internal_ip().clone())
252            }
253            ServerStrategy::ExternalTcpPort(_) => todo!(),
254            ServerStrategy::Demux(demux) => {
255                let mut config_map = HashMap::new();
256                for (key, underlying) in demux {
257                    config_map.insert(*key, LaunchedSshHost::server_config(self, underlying));
258                }
259
260                ServerBindConfig::Demux(config_map)
261            }
262            ServerStrategy::Merge(merge) => {
263                let mut configs = vec![];
264                for underlying in merge {
265                    configs.push(LaunchedSshHost::server_config(self, underlying));
266                }
267
268                ServerBindConfig::Merge(configs)
269            }
270            ServerStrategy::Tagged(underlying, id) => ServerBindConfig::Tagged(
271                Box::new(LaunchedSshHost::server_config(self, underlying)),
272                *id,
273            ),
274            ServerStrategy::Null => ServerBindConfig::Null,
275        }
276    }
277
278    async fn open_ssh_session(&self) -> Result<AsyncSession<TcpStream>> {
279        let target_addr = SocketAddr::new(
280            self.get_external_ip()
281                .as_ref()
282                .context(
283                    self.get_cloud_provider()
284                        + " host must be configured with an external IP to launch binaries",
285                )?
286                .parse()
287                .unwrap(),
288            22,
289        );
290
291        let res = ProgressTracker::leaf(
292            format!(
293                "connecting to host @ {}",
294                self.get_external_ip().as_ref().unwrap()
295            ),
296            async_retry(
297                &|| async {
298                    let mut config = SessionConfiguration::new();
299                    config.set_compress(true);
300
301                    let mut session =
302                        AsyncSession::<TcpStream>::connect(target_addr, Some(config)).await?;
303
304                    session.handshake().await?;
305
306                    session
307                        .userauth_pubkey_file(
308                            self.ssh_user(),
309                            None,
310                            self.ssh_key_path().as_path(),
311                            None,
312                        )
313                        .await?;
314
315                    Ok(session)
316                },
317                10,
318                Duration::from_secs(1),
319            ),
320        )
321        .await?;
322
323        Ok(res)
324    }
325}
326
327async fn create_channel(session: &AsyncSession<TcpStream>) -> Result<AsyncChannel<TcpStream>> {
328    async_retry(
329        &|| async {
330            Ok(tokio::time::timeout(Duration::from_secs(60), session.channel_session()).await??)
331        },
332        10,
333        Duration::from_secs(1),
334    )
335    .await
336}
337
338#[async_trait]
339impl<T: LaunchedSshHost> LaunchedHost for T {
340    fn server_config(&self, bind_type: &ServerStrategy) -> ServerBindConfig {
341        LaunchedSshHost::server_config(self, bind_type)
342    }
343
344    async fn copy_binary(&self, binary: &BuildOutput) -> Result<()> {
345        let session = self.open_ssh_session().await?;
346
347        let sftp = async_retry(
348            &|| async { Ok(session.sftp().await?) },
349            10,
350            Duration::from_secs(1),
351        )
352        .await?;
353
354        // we may be deploying multiple binaries, so give each a unique name
355        let unique_name = &binary.unique_id;
356
357        let user = self.ssh_user();
358        let binary_path = PathBuf::from(format!("/home/{user}/hydro-{unique_name}"));
359
360        if sftp.stat(&binary_path).await.is_err() {
361            let random = nanoid!(8);
362            let temp_path = PathBuf::from(format!("/home/{user}/hydro-{random}"));
363            let sftp = &sftp;
364
365            ProgressTracker::progress_leaf(
366                format!("uploading binary to {}", binary_path.display()),
367                |set_progress, _| {
368                    let binary = &binary;
369                    let binary_path = &binary_path;
370                    async move {
371                        let mut created_file = sftp.create(&temp_path).await?;
372
373                        let mut index = 0;
374                        while index < binary.bin_data.len() {
375                            let written = created_file
376                                .write(
377                                    &binary.bin_data[index
378                                        ..std::cmp::min(index + 128 * 1024, binary.bin_data.len())],
379                                )
380                                .await?;
381                            index += written;
382                            set_progress(
383                                ((index as f64 / binary.bin_data.len() as f64) * 100.0) as u64,
384                            );
385                        }
386                        let mut orig_file_stat = sftp.stat(&temp_path).await?;
387                        orig_file_stat.perm = Some(0o755); // allow the copied binary to be executed by anyone
388                        created_file.setstat(orig_file_stat).await?;
389                        created_file.close().await?;
390                        drop(created_file);
391
392                        match sftp.rename(&temp_path, binary_path, None).await {
393                            Ok(_) => {}
394                            Err(async_ssh2_lite::Error::Ssh2(e))
395                                if e.code() == ErrorCode::SFTP(4) =>
396                            {
397                                // file already exists
398                                sftp.unlink(&temp_path).await?;
399                            }
400                            Err(e) => return Err(e.into()),
401                        }
402
403                        anyhow::Ok(())
404                    }
405                },
406            )
407            .await?;
408        }
409        drop(sftp);
410
411        Ok(())
412    }
413
414    async fn launch_binary(
415        &self,
416        id: String,
417        binary: &BuildOutput,
418        args: &[String],
419        tracing: Option<TracingOptions>,
420    ) -> Result<Box<dyn LaunchedBinary>> {
421        let session = self.open_ssh_session().await?;
422
423        let unique_name = &binary.unique_id;
424
425        let user = self.ssh_user();
426        let binary_path = PathBuf::from(format!("/home/{user}/hydro-{unique_name}"));
427
428        let channel = ProgressTracker::leaf(format!("launching binary {}", binary_path.display()),
429            async {
430                let mut channel = create_channel(&session).await?;
431
432                let mut command = binary_path.to_str().unwrap().to_owned();
433                for arg in args{
434                    command.push(' ');
435                    command.push_str(&shell_escape::unix::escape(Cow::Borrowed(arg)))
436                }
437                // Launch with tracing if specified, also copy local binary to expected place for perf report to work
438                if let Some(TracingOptions { frequency, setup_command, .. }) = tracing.clone() {
439
440                    // Run setup command
441                    if let Some(setup_command) = setup_command {
442                        let mut setup_channel = create_channel(&session).await?;
443                        setup_channel
444                            .exec(&setup_command)
445                            .await?;
446
447                        // log outputs
448                        let mut setup_stdout = FuturesBufReader::new(setup_channel.stream(0)).lines();
449                        while let Some(line) = setup_stdout.next().await {
450                            ProgressTracker::eprintln(format!("[install perf] {}", line.unwrap()));
451                        }
452
453                        setup_channel.wait_eof().await?;
454                        let exit_code = setup_channel.exit_status()?;
455                        setup_channel.wait_close().await?;
456                        if exit_code != 0 {
457                            anyhow::bail!("Failed to install perf on remote host");
458                        }
459                    }
460
461                    // Attach perf to the command
462                    // Note: `LaunchedSshHost` assumes `perf` on linux.
463                    command = format!(
464                        "perf record -F {frequency} -e cycles:u --call-graph dwarf,65528 -o {PERF_OUTFILE} {command}",
465                    );
466                }
467                channel.exec(&command).await?;
468                anyhow::Ok(channel)
469            }
470        )
471        .await?;
472
473        let (stdin_sender, mut stdin_receiver) = mpsc::unbounded_channel::<String>();
474        let mut stdin = channel.stream(0); // stream 0 is stdout/stdin, we use it for stdin
475        tokio::spawn(async move {
476            while let Some(line) = stdin_receiver.recv().await {
477                if stdin.write_all(line.as_bytes()).await.is_err() {
478                    break;
479                }
480
481                stdin.flush().await.unwrap();
482            }
483        });
484
485        let id_clone = id.clone();
486        let (stdout_deploy_receivers, stdout_receivers) =
487            prioritized_broadcast(FuturesBufReader::new(channel.stream(0)).lines(), move |s| {
488                ProgressTracker::println(format!("[{id_clone}] {s}"));
489            });
490        let (_, stderr_receivers) =
491            prioritized_broadcast(FuturesBufReader::new(channel.stderr()).lines(), move |s| {
492                ProgressTracker::println(format!("[{id} stderr] {s}"));
493            });
494
495        Ok(Box::new(LaunchedSshBinary {
496            _resource_result: self.resource_result().clone(),
497            session: Some(session),
498            channel,
499            stdin_sender,
500            stdout_deploy_receivers,
501            stdout_receivers,
502            stderr_receivers,
503            tracing,
504            tracing_results: None,
505        }))
506    }
507
508    async fn forward_port(&self, addr: &SocketAddr) -> Result<SocketAddr> {
509        let session = self.open_ssh_session().await?;
510
511        let local_port = TcpListener::bind("127.0.0.1:0").await?;
512        let local_addr = local_port.local_addr()?;
513
514        let internal_ip = addr.ip().to_string();
515        let port = addr.port();
516
517        tokio::spawn(async move {
518            #[expect(clippy::never_loop, reason = "tcp accept loop pattern")]
519            while let Ok((mut local_stream, _)) = local_port.accept().await {
520                let mut channel = session
521                    .channel_direct_tcpip(&internal_ip, port, None)
522                    .await
523                    .unwrap();
524                let _ = tokio::io::copy_bidirectional(&mut local_stream, &mut channel).await;
525                break;
526                // TODO(shadaj): we should be returning an Arc so that we know
527                // if anyone wants to connect to this forwarded port
528            }
529
530            ProgressTracker::println("[hydro] closing forwarded port");
531        });
532
533        Ok(local_addr)
534    }
535}