hydro_deploy/
ssh.rs

1use std::net::SocketAddr;
2use std::path::PathBuf;
3use std::sync::Arc;
4#[cfg(feature = "profile-folding")]
5use std::sync::OnceLock;
6use std::time::Duration;
7
8use anyhow::{Context as _, Result};
9use async_ssh2_russh::russh::client::{Config, Handler};
10use async_ssh2_russh::russh::{Disconnect, compression};
11use async_ssh2_russh::russh_sftp::protocol::{Status, StatusCode};
12use async_ssh2_russh::sftp::SftpError;
13use async_ssh2_russh::{AsyncChannel, AsyncSession, NoCheckHandler};
14use async_trait::async_trait;
15use hydro_deploy_integration::ServerBindConfig;
16#[cfg(feature = "profile-folding")]
17use inferno::collapse::Collapse;
18#[cfg(feature = "profile-folding")]
19use inferno::collapse::perf::Folder;
20use nanoid::nanoid;
21use tokio::fs::File;
22#[cfg(feature = "profile-folding")]
23use tokio::io::BufReader;
24use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt};
25use tokio::net::TcpListener;
26use tokio::sync::{mpsc, oneshot};
27use tokio_stream::StreamExt;
28use tokio_stream::wrappers::LinesStream;
29#[cfg(feature = "profile-folding")]
30use tokio_util::io::SyncIoBridge;
31
32#[cfg(feature = "profile-folding")]
33use crate::TracingResults;
34use crate::progress::ProgressTracker;
35use crate::rust_crate::build::BuildOutput;
36#[cfg(feature = "profile-folding")]
37use crate::rust_crate::flamegraph::handle_fold_data;
38use crate::rust_crate::tracing_options::TracingOptions;
39use crate::util::{PriorityBroadcast, async_retry, prioritized_broadcast};
40use crate::{BaseServerStrategy, LaunchedBinary, LaunchedHost, ResourceResult};
41
/// Remote file that `perf record -o` writes the profile to, and that both `perf script -i`
/// and the SFTP download in `stop()` read back. It is a relative path, so it presumably
/// resolves against the SSH user's working directory on the remote host — TODO confirm.
const PERF_OUTFILE: &str = "__profile.perf.data";
43
/// A binary running on a remote host, launched by `exec`-ing a command over an SSH channel.
///
/// Holds the SSH session and channel, a sender feeding the remote process's stdin, and
/// broadcasts that fan out the remote stdout/stderr lines to subscribers.
struct LaunchedSshBinary {
    // Never read in this file; NOTE(review): presumably held to keep the host's resources
    // alive for as long as the binary handle exists — confirm.
    _resource_result: Arc<ResourceResult>,
    // TODO(mingwei): instead of using `NoCheckHandler`, we should check the server's public key
    // fingerprint (get it somehow via terraform), but ssh `publickey` authentication already
    // generally prevents MITM attacks.
    // Wrapped in `Option` so `Drop` can `take()` the session and disconnect it.
    session: Option<AsyncSession<NoCheckHandler>>,
    // Channel on which the (possibly `perf record`-wrapped) command was `exec`ed.
    channel: AsyncChannel,
    // Strings sent here are written to the remote process's stdin.
    stdin_sender: mpsc::UnboundedSender<String>,
    // Fan-out of the remote process's stdout lines.
    stdout_broadcast: PriorityBroadcast,
    // Fan-out of the remote process's stderr lines.
    stderr_broadcast: PriorityBroadcast,
    // `Some` when launched under `perf`; `stop()` then downloads/folds the profile.
    tracing: Option<TracingOptions>,
    // Folded perf output, written at most once by `stop()`.
    #[cfg(feature = "profile-folding")]
    tracing_results: OnceLock<TracingResults>,
}
58
59#[async_trait]
60impl LaunchedBinary for LaunchedSshBinary {
61    fn stdin(&self) -> mpsc::UnboundedSender<String> {
62        self.stdin_sender.clone()
63    }
64
65    fn deploy_stdout(&self) -> oneshot::Receiver<String> {
66        self.stdout_broadcast.receive_priority()
67    }
68
69    fn stdout(&self) -> mpsc::UnboundedReceiver<String> {
70        self.stdout_broadcast.receive(None)
71    }
72
73    fn stderr(&self) -> mpsc::UnboundedReceiver<String> {
74        self.stderr_broadcast.receive(None)
75    }
76
77    fn stdout_filter(&self, prefix: String) -> mpsc::UnboundedReceiver<String> {
78        self.stdout_broadcast.receive(Some(prefix))
79    }
80
81    fn stderr_filter(&self, prefix: String) -> mpsc::UnboundedReceiver<String> {
82        self.stderr_broadcast.receive(Some(prefix))
83    }
84
85    #[cfg(feature = "profile-folding")]
86    fn tracing_results(&self) -> Option<&TracingResults> {
87        self.tracing_results.get()
88    }
89
90    fn exit_code(&self) -> Option<i32> {
91        // until the program exits, the exit status is meaningless
92        self.channel
93            .recv_exit_status()
94            .try_get()
95            .map(|&ec| ec as _)
96            .ok()
97    }
98
99    async fn wait(&self) -> Result<i32> {
100        let _ = self.channel.closed().wait().await;
101        Ok(*self.channel.recv_exit_status().try_get()? as _)
102    }
103
104    async fn stop(&self) -> Result<()> {
105        if !self.channel.closed().is_done() {
106            ProgressTracker::leaf("force stopping", async {
107                // self.channel.signal(russh::Sig::INT).await?; // `^C`
108                self.channel.eof().await?; // Send EOF.
109                self.channel.close().await?; // Close the channel.
110                self.channel.closed().wait().await;
111                Result::<_>::Ok(())
112            })
113            .await?;
114        }
115
116        // Run perf post-processing and download perf output.
117        if let Some(tracing) = self.tracing.as_ref() {
118            #[cfg(feature = "profile-folding")]
119            assert!(
120                self.tracing_results.get().is_none(),
121                "`tracing_results` already set! Was `stop()` called twice? This is a bug."
122            );
123
124            let session = self.session.as_ref().unwrap();
125            if let Some(local_raw_perf) = tracing.perf_raw_outfile.as_ref() {
126                ProgressTracker::progress_leaf("downloading perf data", |progress, _| async move {
127                    let sftp =
128                        async_retry(&|| session.open_sftp(), 10, Duration::from_secs(1)).await?;
129
130                    let mut remote_raw_perf = sftp.open(PERF_OUTFILE).await?;
131                    let mut local_raw_perf = File::create(local_raw_perf).await?;
132
133                    let total_size = remote_raw_perf.metadata().await?.size.unwrap();
134
135                    use tokio::io::AsyncWriteExt;
136                    let mut index = 0;
137                    loop {
138                        let mut buffer = [0; 16 * 1024];
139                        let n = remote_raw_perf.read(&mut buffer).await?;
140                        if n == 0 {
141                            break;
142                        }
143                        local_raw_perf.write_all(&buffer[..n]).await?;
144                        index += n;
145                        progress(((index as f64 / total_size as f64) * 100.0) as u64);
146                    }
147
148                    Ok::<(), anyhow::Error>(())
149                })
150                .await?;
151            }
152
153            #[cfg(feature = "profile-folding")]
154            let script_channel = session.open_channel().await?;
155            #[cfg(feature = "profile-folding")]
156            let mut fold_er = Folder::from(tracing.fold_perf_options.clone().unwrap_or_default());
157
158            #[cfg(feature = "profile-folding")]
159            let fold_data = ProgressTracker::leaf("perf script & folding", async move {
160                let mut stderr_lines = script_channel.stderr().lines();
161                let stdout = script_channel.stdout();
162
163                // Pattern on `()` to make sure no `Result`s are ignored.
164                let ((), fold_data, ()) = tokio::try_join!(
165                    async move {
166                        // Log stderr.
167                        while let Ok(Some(s)) = stderr_lines.next_line().await {
168                            ProgressTracker::eprintln(format!("[perf stderr] {s}"));
169                        }
170                        Result::<_>::Ok(())
171                    },
172                    async move {
173                        // Download perf output and fold.
174                        tokio::task::spawn_blocking(move || {
175                            let mut fold_data = Vec::new();
176                            fold_er.collapse(
177                                SyncIoBridge::new(BufReader::new(stdout)),
178                                &mut fold_data,
179                            )?;
180                            Ok(fold_data)
181                        })
182                        .await?
183                    },
184                    async move {
185                        // Run command (last!).
186                        script_channel
187                            .exec(false, format!("perf script --symfs=/ -i {PERF_OUTFILE}"))
188                            .await?;
189                        Ok(())
190                    },
191                )?;
192                Result::<_>::Ok(fold_data)
193            })
194            .await?;
195
196            #[cfg(feature = "profile-folding")]
197            self.tracing_results
198                .set(TracingResults {
199                    folded_data: fold_data.clone(),
200                })
201                .expect("`tracing_results` already set! This is a bug.");
202
203            #[cfg(feature = "profile-folding")]
204            handle_fold_data(tracing, fold_data).await?;
205        };
206
207        Ok(())
208    }
209}
210
211impl Drop for LaunchedSshBinary {
212    fn drop(&mut self) {
213        if let Some(session) = self.session.take() {
214            tokio::task::block_in_place(|| {
215                tokio::runtime::Handle::current().block_on(session.disconnect(
216                    Disconnect::ByApplication,
217                    "",
218                    "",
219                ))
220            })
221            .unwrap();
222        }
223    }
224}
225
226#[async_trait]
227pub trait LaunchedSshHost: Send + Sync {
228    fn get_internal_ip(&self) -> String;
229    fn get_external_ip(&self) -> Option<String>;
230    fn get_cloud_provider(&self) -> String;
231    fn resource_result(&self) -> &Arc<ResourceResult>;
232    fn ssh_user(&self) -> &str;
233
234    fn ssh_key_path(&self) -> PathBuf {
235        self.resource_result()
236            .terraform
237            .deployment_folder
238            .as_ref()
239            .unwrap()
240            .path()
241            .join(".ssh")
242            .join("vm_instance_ssh_key_pem")
243    }
244
245    async fn open_ssh_session(&self) -> Result<AsyncSession<NoCheckHandler>> {
246        let target_addr = SocketAddr::new(
247            self.get_external_ip()
248                .as_ref()
249                .context(
250                    self.get_cloud_provider()
251                        + " host must be configured with an external IP to launch binaries",
252                )?
253                .parse()
254                .unwrap(),
255            22,
256        );
257
258        let res = ProgressTracker::leaf(
259            format!(
260                "connecting to host @ {}",
261                self.get_external_ip().as_ref().unwrap()
262            ),
263            async_retry(
264                &|| async {
265                    let mut config = Config::default();
266                    config.preferred.compression = (&[
267                        compression::ZLIB,
268                        compression::ZLIB_LEGACY,
269                        compression::NONE,
270                    ])
271                        .into();
272                    AsyncSession::connect_publickey(
273                        config,
274                        target_addr,
275                        self.ssh_user(),
276                        self.ssh_key_path(),
277                    )
278                    .await
279                },
280                10,
281                Duration::from_secs(1),
282            ),
283        )
284        .await?;
285
286        Ok(res)
287    }
288}
289
290async fn create_channel<H>(session: &AsyncSession<H>) -> Result<AsyncChannel>
291where
292    H: 'static + Handler,
293{
294    async_retry(
295        &|| async {
296            Ok(tokio::time::timeout(Duration::from_secs(60), session.open_channel()).await??)
297        },
298        10,
299        Duration::from_secs(1),
300    )
301    .await
302}
303
304#[async_trait]
305impl<T: LaunchedSshHost> LaunchedHost for T {
306    fn base_server_config(&self, bind_type: &BaseServerStrategy) -> ServerBindConfig {
307        match bind_type {
308            BaseServerStrategy::UnixSocket => ServerBindConfig::UnixSocket,
309            BaseServerStrategy::InternalTcpPort(hint) => {
310                ServerBindConfig::TcpPort(self.get_internal_ip().clone(), *hint)
311            }
312            BaseServerStrategy::ExternalTcpPort(_) => todo!(),
313        }
314    }
315
316    async fn copy_binary(&self, binary: &BuildOutput) -> Result<()> {
317        let session = self.open_ssh_session().await?;
318
319        let sftp = async_retry(&|| session.open_sftp(), 10, Duration::from_secs(1)).await?;
320
321        let user = self.ssh_user();
322        // we may be deploying multiple binaries, so give each a unique name
323        let binary_path = format!("/home/{user}/hydro-{}", binary.unique_id());
324
325        if sftp.metadata(&binary_path).await.is_err() {
326            let random = nanoid!(8);
327            let temp_path = format!("/home/{user}/hydro-{random}");
328            let sftp = &sftp;
329
330            ProgressTracker::progress_leaf(
331                format!("uploading binary to {}", binary_path),
332                |set_progress, _| {
333                    async move {
334                        let mut created_file = sftp.create(&temp_path).await?;
335
336                        let mut index = 0;
337                        while index < binary.bin_data.len() {
338                            let written = created_file
339                                .write(
340                                    &binary.bin_data[index
341                                        ..std::cmp::min(index + 128 * 1024, binary.bin_data.len())],
342                                )
343                                .await?;
344                            index += written;
345                            set_progress(
346                                ((index as f64 / binary.bin_data.len() as f64) * 100.0) as u64,
347                            );
348                        }
349                        let mut orig_file_stat = sftp.metadata(&temp_path).await?;
350                        orig_file_stat.permissions = Some(0o755); // allow the copied binary to be executed by anyone
351                        created_file.set_metadata(orig_file_stat).await?;
352                        created_file.sync_all().await?;
353                        drop(created_file);
354
355                        match sftp.rename(&temp_path, binary_path).await {
356                            Ok(_) => {}
357                            Err(SftpError::Status(Status {
358                                status_code: StatusCode::Failure, // SSH_FXP_STATUS = 4
359                                ..
360                            })) => {
361                                // file already exists
362                                sftp.remove_file(temp_path).await?;
363                            }
364                            Err(e) => return Err(e.into()),
365                        }
366
367                        anyhow::Ok(())
368                    }
369                },
370            )
371            .await?;
372        }
373        sftp.close().await?;
374
375        Ok(())
376    }
377
378    async fn launch_binary(
379        &self,
380        id: String,
381        binary: &BuildOutput,
382        args: &[String],
383        tracing: Option<TracingOptions>,
384    ) -> Result<Box<dyn LaunchedBinary>> {
385        let session = self.open_ssh_session().await?;
386
387        let user = self.ssh_user();
388        let binary_path = PathBuf::from(format!("/home/{user}/hydro-{}", binary.unique_id()));
389
390        let mut command = binary_path.to_str().unwrap().to_owned();
391        for arg in args {
392            command.push(' ');
393            command.push_str(&shell_escape::unix::escape(arg.into()))
394        }
395
396        // Launch with tracing if specified.
397        if let Some(TracingOptions {
398            frequency,
399            setup_command,
400            ..
401        }) = tracing.clone()
402        {
403            let id_clone = id.clone();
404            ProgressTracker::leaf("install perf", async {
405                // Run setup command
406                if let Some(setup_command) = setup_command {
407                    let setup_channel = create_channel(&session).await?;
408                    let (setup_stdout, setup_stderr) =
409                        (setup_channel.stdout(), setup_channel.stderr());
410                    setup_channel.exec(false, &*setup_command).await?;
411
412                    // log outputs
413                    let mut output_lines = LinesStream::new(setup_stdout.lines())
414                        .merge(LinesStream::new(setup_stderr.lines()));
415                    while let Some(line) = output_lines.next().await {
416                        ProgressTracker::eprintln(format!(
417                            "[{} install perf] {}",
418                            id_clone,
419                            line.unwrap()
420                        ));
421                    }
422
423                    setup_channel.closed().wait().await;
424                    let exit_code = setup_channel.recv_exit_status().try_get();
425                    if Ok(&0) != exit_code {
426                        anyhow::bail!("Failed to install perf on remote host");
427                    }
428                }
429                Ok(())
430            })
431            .await?;
432
433            // Attach perf to the command
434            // Note: `LaunchedSshHost` assumes `perf` on linux.
435            command = format!(
436                "perf record -F {frequency} -e cycles:u --call-graph dwarf,65528 -o {PERF_OUTFILE} {command}",
437            );
438        }
439
440        let (channel, stdout, stderr) = ProgressTracker::leaf(
441            format!("launching binary {}", binary_path.display()),
442            async {
443                let channel = create_channel(&session).await?;
444                // Make sure to begin reading stdout/stderr before running the command.
445                let (stdout, stderr) = (channel.stdout(), channel.stderr());
446                channel.exec(false, command).await?;
447                anyhow::Ok((channel, stdout, stderr))
448            },
449        )
450        .await?;
451
452        let (stdin_sender, mut stdin_receiver) = mpsc::unbounded_channel::<String>();
453        let mut stdin = channel.stdin();
454
455        tokio::spawn(async move {
456            while let Some(line) = stdin_receiver.recv().await {
457                if stdin.write_all(line.as_bytes()).await.is_err() {
458                    break;
459                }
460                stdin.flush().await.unwrap();
461            }
462        });
463
464        let id_clone = id.clone();
465        let stdout_broadcast = prioritized_broadcast(LinesStream::new(stdout.lines()), move |s| {
466            ProgressTracker::println(format!("[{id_clone}] {s}"));
467        });
468        let stderr_broadcast = prioritized_broadcast(LinesStream::new(stderr.lines()), move |s| {
469            ProgressTracker::println(format!("[{id} stderr] {s}"));
470        });
471
472        Ok(Box::new(LaunchedSshBinary {
473            _resource_result: self.resource_result().clone(),
474            session: Some(session),
475            channel,
476            stdin_sender,
477            stdout_broadcast,
478            stderr_broadcast,
479            tracing,
480            #[cfg(feature = "profile-folding")]
481            tracing_results: OnceLock::new(),
482        }))
483    }
484
485    async fn forward_port(&self, addr: &SocketAddr) -> Result<SocketAddr> {
486        let session = self.open_ssh_session().await?;
487
488        let local_port = TcpListener::bind("127.0.0.1:0").await?;
489        let local_addr = local_port.local_addr()?;
490
491        let internal_ip = addr.ip().to_string();
492        let port = addr.port();
493
494        tokio::spawn(async move {
495            #[expect(clippy::never_loop, reason = "tcp accept loop pattern")]
496            while let Ok((mut local_stream, _)) = local_port.accept().await {
497                let mut channel = session
498                    .channel_open_direct_tcpip(internal_ip, port.into(), "127.0.0.1", 22)
499                    .await
500                    .unwrap()
501                    .into_stream();
502                let _ = tokio::io::copy_bidirectional(&mut local_stream, &mut channel).await;
503                break;
504                // TODO(shadaj): we should be returning an Arc so that we know
505                // if anyone wants to connect to this forwarded port
506            }
507
508            ProgressTracker::println("[hydro] closing forwarded port");
509        });
510
511        Ok(local_addr)
512    }
513}