Skip to main content

hydro_deploy/
ssh.rs

1use std::collections::HashMap;
2use std::net::SocketAddr;
3use std::path::PathBuf;
4use std::sync::Arc;
5#[cfg(feature = "profile-folding")]
6use std::sync::OnceLock;
7use std::time::Duration;
8
9use anyhow::{Context as _, Result};
10use async_ssh2_russh::russh::client::{Config, Handler};
11use async_ssh2_russh::russh::{Disconnect, compression};
12use async_ssh2_russh::russh_sftp::protocol::{Status, StatusCode};
13use async_ssh2_russh::sftp::SftpError;
14use async_ssh2_russh::{AsyncChannel, AsyncSession, NoCheckHandler};
15use async_trait::async_trait;
16use hydro_deploy_integration::ServerBindConfig;
17#[cfg(feature = "profile-folding")]
18use inferno::collapse::Collapse;
19#[cfg(feature = "profile-folding")]
20use inferno::collapse::perf::Folder;
21use nanoid::nanoid;
22use tokio::fs::File;
23#[cfg(feature = "profile-folding")]
24use tokio::io::BufReader;
25use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt};
26use tokio::net::TcpListener;
27use tokio::sync::{mpsc, oneshot};
28use tokio_stream::StreamExt;
29use tokio_stream::wrappers::LinesStream;
30#[cfg(feature = "profile-folding")]
31use tokio_util::io::SyncIoBridge;
32
33#[cfg(feature = "profile-folding")]
34use crate::TracingResults;
35use crate::progress::ProgressTracker;
36use crate::rust_crate::build::BuildOutput;
37#[cfg(feature = "profile-folding")]
38use crate::rust_crate::flamegraph::handle_fold_data;
39use crate::rust_crate::tracing_options::TracingOptions;
40use crate::util::{PriorityBroadcast, async_retry, prioritized_broadcast};
41use crate::{BaseServerStrategy, LaunchedBinary, LaunchedHost, ResourceResult};
42
/// File name passed to `perf record -o` on the remote host; the same name is later read back
/// over SFTP and by `perf script -i`.
const PERF_OUTFILE: &str = "__profile.perf.data";
44
/// Handle to a binary that was launched on a remote host over SSH.
struct LaunchedSshBinary {
    /// Not read by this type (note the `_` prefix); presumably held to keep the provisioned
    /// resources alive for as long as the binary handle exists — TODO confirm.
    _resource_result: Arc<ResourceResult>,
    // TODO(mingwei): instead of using `NoCheckHandler`, we should check the server's public key
    // fingerprint (get it somehow via terraform), but ssh `publickey` authentication already
    // generally prevents MITM attacks.
    /// SSH session the binary runs in. Wrapped in `Option` so that `Drop` can take ownership of
    /// it in order to disconnect.
    session: Option<AsyncSession<NoCheckHandler>>,
    /// Channel on which the binary's command was executed; used for exit status and stopping.
    channel: AsyncChannel,
    /// Sender half of the queue whose lines are written to the remote process's stdin.
    stdin_sender: mpsc::UnboundedSender<String>,
    /// Broadcast of the remote process's stdout lines.
    stdout_broadcast: PriorityBroadcast,
    /// Broadcast of the remote process's stderr lines.
    stderr_broadcast: PriorityBroadcast,
    /// Tracing (perf) options the binary was launched with, if any.
    tracing: Option<TracingOptions>,
    /// Folded perf data, set at most once by `stop()`.
    #[cfg(feature = "profile-folding")]
    tracing_results: OnceLock<TracingResults>,
}
59
60#[async_trait]
61impl LaunchedBinary for LaunchedSshBinary {
62    fn stdin(&self) -> mpsc::UnboundedSender<String> {
63        self.stdin_sender.clone()
64    }
65
66    fn deploy_stdout(&self) -> oneshot::Receiver<String> {
67        self.stdout_broadcast.receive_priority()
68    }
69
70    fn stdout(&self) -> mpsc::UnboundedReceiver<String> {
71        self.stdout_broadcast.receive(None)
72    }
73
74    fn stderr(&self) -> mpsc::UnboundedReceiver<String> {
75        self.stderr_broadcast.receive(None)
76    }
77
78    fn stdout_filter(&self, prefix: String) -> mpsc::UnboundedReceiver<String> {
79        self.stdout_broadcast.receive(Some(prefix))
80    }
81
82    fn stderr_filter(&self, prefix: String) -> mpsc::UnboundedReceiver<String> {
83        self.stderr_broadcast.receive(Some(prefix))
84    }
85
86    #[cfg(feature = "profile-folding")]
87    fn tracing_results(&self) -> Option<&TracingResults> {
88        self.tracing_results.get()
89    }
90
91    fn exit_code(&self) -> Option<i32> {
92        // until the program exits, the exit status is meaningless
93        self.channel
94            .recv_exit_status()
95            .try_get()
96            .map(|&ec| ec as _)
97            .ok()
98    }
99
100    async fn wait(&self) -> Result<i32> {
101        let _ = self.channel.closed().wait().await;
102        Ok(*self.channel.recv_exit_status().try_get()? as _)
103    }
104
105    async fn stop(&self) -> Result<()> {
106        if !self.channel.closed().is_done() {
107            ProgressTracker::leaf("force stopping", async {
108                // self.channel.signal(russh::Sig::INT).await?; // `^C`
109                self.channel.eof().await?; // Send EOF.
110                self.channel.close().await?; // Close the channel.
111                self.channel.closed().wait().await;
112                Result::<_>::Ok(())
113            })
114            .await?;
115        }
116
117        // Run perf post-processing and download perf output.
118        if let Some(tracing) = self.tracing.as_ref() {
119            #[cfg(feature = "profile-folding")]
120            assert!(
121                self.tracing_results.get().is_none(),
122                "`tracing_results` already set! Was `stop()` called twice? This is a bug."
123            );
124
125            let session = self.session.as_ref().unwrap();
126            if let Some(local_raw_perf) = tracing.perf_raw_outfile.as_ref() {
127                ProgressTracker::progress_leaf("downloading perf data", |progress, _| async move {
128                    let sftp =
129                        async_retry(&|| session.open_sftp(), 10, Duration::from_secs(1)).await?;
130
131                    let mut remote_raw_perf = sftp.open(PERF_OUTFILE).await?;
132                    let mut local_raw_perf = File::create(local_raw_perf).await?;
133
134                    let total_size = remote_raw_perf.metadata().await?.size.unwrap();
135
136                    use tokio::io::AsyncWriteExt;
137                    let mut index = 0;
138                    loop {
139                        let mut buffer = [0; 16 * 1024];
140                        let n = remote_raw_perf.read(&mut buffer).await?;
141                        if n == 0 {
142                            break;
143                        }
144                        local_raw_perf.write_all(&buffer[..n]).await?;
145                        index += n;
146                        progress(((index as f64 / total_size as f64) * 100.0) as u64);
147                    }
148
149                    Ok::<(), anyhow::Error>(())
150                })
151                .await?;
152            }
153
154            #[cfg(feature = "profile-folding")]
155            let script_channel = session.open_channel().await?;
156            #[cfg(feature = "profile-folding")]
157            let mut fold_er = Folder::from(tracing.fold_perf_options.clone().unwrap_or_default());
158
159            #[cfg(feature = "profile-folding")]
160            let fold_data = ProgressTracker::leaf("perf script & folding", async move {
161                let mut stderr_lines = script_channel.stderr().lines();
162                let stdout = script_channel.stdout();
163
164                // Pattern on `()` to make sure no `Result`s are ignored.
165                let ((), fold_data, ()) = tokio::try_join!(
166                    async move {
167                        // Log stderr.
168                        while let Ok(Some(s)) = stderr_lines.next_line().await {
169                            ProgressTracker::eprintln(format!("[perf stderr] {s}"));
170                        }
171                        Result::<_>::Ok(())
172                    },
173                    async move {
174                        // Download perf output and fold.
175                        tokio::task::spawn_blocking(move || {
176                            let mut fold_data = Vec::new();
177                            fold_er.collapse(
178                                SyncIoBridge::new(BufReader::new(stdout)),
179                                &mut fold_data,
180                            )?;
181                            Ok(fold_data)
182                        })
183                        .await?
184                    },
185                    async move {
186                        // Run command (last!).
187                        script_channel
188                            .exec(false, format!("perf script --symfs=/ -i {PERF_OUTFILE}"))
189                            .await?;
190                        Ok(())
191                    },
192                )?;
193                Result::<_>::Ok(fold_data)
194            })
195            .await?;
196
197            #[cfg(feature = "profile-folding")]
198            self.tracing_results
199                .set(TracingResults {
200                    folded_data: fold_data.clone(),
201                })
202                .expect("`tracing_results` already set! This is a bug.");
203
204            #[cfg(feature = "profile-folding")]
205            handle_fold_data(tracing, fold_data).await?;
206        };
207
208        Ok(())
209    }
210}
211
212impl Drop for LaunchedSshBinary {
213    fn drop(&mut self) {
214        if let Some(session) = self.session.take() {
215            tokio::task::block_in_place(|| {
216                tokio::runtime::Handle::current().block_on(session.disconnect(
217                    Disconnect::ByApplication,
218                    "",
219                    "",
220                ))
221            })
222            .unwrap();
223        }
224    }
225}
226
227#[async_trait]
228pub trait LaunchedSshHost: Send + Sync {
229    fn get_internal_ip(&self) -> &str;
230    fn get_external_ip(&self) -> Option<&str>;
231    fn get_cloud_provider(&self) -> &'static str;
232    fn resource_result(&self) -> &Arc<ResourceResult>;
233    fn ssh_user(&self) -> &str;
234
235    fn ssh_key_path(&self) -> PathBuf {
236        self.resource_result()
237            .terraform
238            .deployment_folder
239            .as_ref()
240            .unwrap()
241            .path()
242            .join(".ssh")
243            .join("vm_instance_ssh_key_pem")
244    }
245
246    async fn open_ssh_session(&self) -> Result<AsyncSession<NoCheckHandler>> {
247        let target_addr = SocketAddr::new(
248            self.get_external_ip()
249                .context(format!(
250                    "{} host must be configured with an external IP to launch binaries",
251                    self.get_cloud_provider()
252                ))?
253                .parse()
254                .unwrap(),
255            22,
256        );
257
258        let res = ProgressTracker::leaf(
259            format!("connecting to host @ {}", self.get_external_ip().unwrap()),
260            async_retry(
261                &|| async {
262                    let mut config = Config::default();
263                    config.preferred.compression = (&[
264                        compression::ZLIB,
265                        compression::ZLIB_LEGACY,
266                        compression::NONE,
267                    ])
268                        .into();
269                    AsyncSession::connect_publickey(
270                        config,
271                        target_addr,
272                        self.ssh_user(),
273                        self.ssh_key_path(),
274                    )
275                    .await
276                },
277                10,
278                Duration::from_secs(1),
279            ),
280        )
281        .await?;
282
283        Ok(res)
284    }
285}
286
287async fn create_channel<H>(session: &AsyncSession<H>) -> Result<AsyncChannel>
288where
289    H: 'static + Handler,
290{
291    async_retry(
292        &|| async {
293            Ok(tokio::time::timeout(Duration::from_secs(60), session.open_channel()).await??)
294        },
295        10,
296        Duration::from_secs(1),
297    )
298    .await
299}
300
301#[async_trait]
302impl<T: LaunchedSshHost> LaunchedHost for T {
303    fn base_server_config(&self, bind_type: &BaseServerStrategy) -> ServerBindConfig {
304        match bind_type {
305            BaseServerStrategy::UnixSocket => ServerBindConfig::UnixSocket,
306            BaseServerStrategy::InternalTcpPort(hint) => {
307                ServerBindConfig::TcpPort(self.get_internal_ip().to_owned(), *hint)
308            }
309            BaseServerStrategy::ExternalTcpPort(_) => todo!(),
310        }
311    }
312
313    async fn copy_binary(&self, binary: &BuildOutput) -> Result<()> {
314        let session = self.open_ssh_session().await?;
315
316        let sftp = async_retry(&|| session.open_sftp(), 10, Duration::from_secs(1)).await?;
317
318        let user = self.ssh_user();
319        // we may be deploying multiple binaries, so give each a unique name
320        let binary_path = format!("/home/{user}/hydro-{}", binary.unique_id());
321
322        if sftp.metadata(&binary_path).await.is_err() {
323            let random = nanoid!(8);
324            let temp_path = format!("/home/{user}/hydro-{random}");
325            let sftp = &sftp;
326
327            ProgressTracker::progress_leaf(
328                format!("uploading binary to {}", binary_path),
329                |set_progress, _| {
330                    async move {
331                        let mut created_file = sftp.create(&temp_path).await?;
332
333                        let mut index = 0;
334                        while index < binary.bin_data.len() {
335                            let written = created_file
336                                .write(
337                                    &binary.bin_data[index
338                                        ..std::cmp::min(index + 128 * 1024, binary.bin_data.len())],
339                                )
340                                .await?;
341                            index += written;
342                            set_progress(
343                                ((index as f64 / binary.bin_data.len() as f64) * 100.0) as u64,
344                            );
345                        }
346                        let mut orig_file_stat = sftp.metadata(&temp_path).await?;
347                        orig_file_stat.permissions = Some(0o755); // allow the copied binary to be executed by anyone
348                        created_file.set_metadata(orig_file_stat).await?;
349                        created_file.sync_all().await?;
350                        drop(created_file);
351
352                        match sftp.rename(&temp_path, binary_path).await {
353                            Ok(_) => {}
354                            Err(SftpError::Status(Status {
355                                status_code: StatusCode::Failure, // SSH_FXP_STATUS = 4
356                                ..
357                            })) => {
358                                // file already exists
359                                sftp.remove_file(temp_path).await?;
360                            }
361                            Err(e) => return Err(e.into()),
362                        }
363
364                        anyhow::Ok(())
365                    }
366                },
367            )
368            .await?;
369        }
370        sftp.close().await?;
371
372        Ok(())
373    }
374
375    async fn launch_binary(
376        &self,
377        id: String,
378        binary: &BuildOutput,
379        args: &[String],
380        tracing: Option<TracingOptions>,
381        env: &HashMap<String, String>,
382    ) -> Result<Box<dyn LaunchedBinary>> {
383        let session = self.open_ssh_session().await?;
384
385        let user = self.ssh_user();
386        let binary_path = PathBuf::from(format!("/home/{user}/hydro-{}", binary.unique_id()));
387
388        let mut command = String::new();
389        // Prepend env variables
390        for (k, v) in env {
391            command.push_str(&format!("{}={} ", k, shell_escape::unix::escape(v.into())));
392        }
393
394        command.push_str(binary_path.to_str().unwrap());
395        for arg in args {
396            command.push(' ');
397            command.push_str(&shell_escape::unix::escape(arg.into()))
398        }
399
400        // Launch with tracing if specified.
401        if let Some(TracingOptions {
402            frequency,
403            setup_command,
404            ..
405        }) = tracing.clone()
406        {
407            let id_clone = id.clone();
408            ProgressTracker::leaf("install perf", async {
409                // Run setup command
410                if let Some(setup_command) = setup_command {
411                    let setup_channel = create_channel(&session).await?;
412                    let (setup_stdout, setup_stderr) =
413                        (setup_channel.stdout(), setup_channel.stderr());
414                    setup_channel.exec(false, &*setup_command).await?;
415
416                    // log outputs
417                    let mut output_lines = LinesStream::new(setup_stdout.lines())
418                        .merge(LinesStream::new(setup_stderr.lines()));
419                    while let Some(line) = output_lines.next().await {
420                        ProgressTracker::eprintln(format!(
421                            "[{} install perf] {}",
422                            id_clone,
423                            line.unwrap()
424                        ));
425                    }
426
427                    setup_channel.closed().wait().await;
428                    let exit_code = setup_channel.recv_exit_status().try_get();
429                    if Ok(&0) != exit_code {
430                        anyhow::bail!("Failed to install perf on remote host");
431                    }
432                }
433                Ok(())
434            })
435            .await?;
436
437            // Attach perf to the command
438            // Note: `LaunchedSshHost` assumes `perf` on linux.
439            command = format!(
440                "perf record -F {frequency} -e cycles:u --call-graph dwarf,65528 -o {PERF_OUTFILE} {command}",
441            );
442        }
443
444        let (channel, stdout, stderr) = ProgressTracker::leaf(
445            format!("launching binary {}", binary_path.display()),
446            async {
447                let channel = create_channel(&session).await?;
448                // Make sure to begin reading stdout/stderr before running the command.
449                let (stdout, stderr) = (channel.stdout(), channel.stderr());
450                channel.exec(false, command).await?;
451                anyhow::Ok((channel, stdout, stderr))
452            },
453        )
454        .await?;
455
456        let (stdin_sender, mut stdin_receiver) = mpsc::unbounded_channel::<String>();
457        let mut stdin = channel.stdin();
458
459        tokio::spawn(async move {
460            while let Some(line) = stdin_receiver.recv().await {
461                if stdin.write_all(line.as_bytes()).await.is_err() {
462                    break;
463                }
464                stdin.flush().await.unwrap();
465            }
466        });
467
468        let id_clone = id.clone();
469        let stdout_broadcast = prioritized_broadcast(LinesStream::new(stdout.lines()), move |s| {
470            ProgressTracker::println(format!("[{id_clone}] {s}"));
471        });
472        let stderr_broadcast = prioritized_broadcast(LinesStream::new(stderr.lines()), move |s| {
473            ProgressTracker::println(format!("[{id} stderr] {s}"));
474        });
475
476        Ok(Box::new(LaunchedSshBinary {
477            _resource_result: self.resource_result().clone(),
478            session: Some(session),
479            channel,
480            stdin_sender,
481            stdout_broadcast,
482            stderr_broadcast,
483            tracing,
484            #[cfg(feature = "profile-folding")]
485            tracing_results: OnceLock::new(),
486        }))
487    }
488
489    async fn forward_port(&self, addr: &SocketAddr) -> Result<SocketAddr> {
490        let session = self.open_ssh_session().await?;
491
492        let local_port = TcpListener::bind("127.0.0.1:0").await?;
493        let local_addr = local_port.local_addr()?;
494
495        let internal_ip = addr.ip().to_string();
496        let port = addr.port();
497
498        tokio::spawn(async move {
499            #[expect(clippy::never_loop, reason = "tcp accept loop pattern")]
500            while let Ok((mut local_stream, _)) = local_port.accept().await {
501                let mut channel = session
502                    .channel_open_direct_tcpip(internal_ip, port.into(), "127.0.0.1", 22)
503                    .await
504                    .unwrap()
505                    .into_stream();
506                let _ = tokio::io::copy_bidirectional(&mut local_stream, &mut channel).await;
507                break;
508                // TODO(shadaj): we should be returning an Arc so that we know
509                // if anyone wants to connect to this forwarded port
510            }
511
512            ProgressTracker::println("[hydro] closing forwarded port");
513        });
514
515        Ok(local_addr)
516    }
517}