hydro_deploy/
ssh.rs

1use std::borrow::Cow;
2use std::collections::HashMap;
3use std::net::SocketAddr;
4use std::path::PathBuf;
5use std::sync::{Arc, Mutex};
6use std::time::Duration;
7
8use anyhow::{Context as _, Result};
9use async_ssh2_lite::ssh2::ErrorCode;
10use async_ssh2_lite::{AsyncChannel, AsyncSession, SessionConfiguration};
11use async_trait::async_trait;
12use futures::io::BufReader as FuturesBufReader;
13use futures::{AsyncBufReadExt, AsyncWriteExt};
14use hydro_deploy_integration::ServerBindConfig;
15use inferno::collapse::Collapse;
16use inferno::collapse::perf::Folder;
17use nanoid::nanoid;
18use tokio::fs::File;
19use tokio::io::{AsyncReadExt, BufReader as TokioBufReader};
20use tokio::net::{TcpListener, TcpStream};
21use tokio::runtime::Handle;
22use tokio::sync::{mpsc, oneshot};
23use tokio_stream::StreamExt;
24use tokio_util::compat::FuturesAsyncReadCompatExt;
25use tokio_util::io::SyncIoBridge;
26
27use crate::progress::ProgressTracker;
28use crate::rust_crate::build::BuildOutput;
29use crate::rust_crate::flamegraph::handle_fold_data;
30use crate::rust_crate::tracing_options::TracingOptions;
31use crate::util::{async_retry, prioritized_broadcast};
32use crate::{LaunchedBinary, LaunchedHost, ResourceResult, ServerStrategy, TracingResults};
33
/// File that `perf record` writes its profile to on the remote host (see `launch_binary`), later
/// read back by `stop`. It is a relative path, so it resolves against the remote session's working
/// directory (presumably the SSH user's home — TODO confirm).
const PERF_OUTFILE: &str = "__profile.perf.data";

/// One stdout/stderr subscriber: an optional line-prefix filter (`None` = every line) paired with
/// the channel that receives the lines.
pub type PrefixFilteredChannel = (Option<String>, mpsc::UnboundedSender<String>);
37
/// A binary running on a remote host inside an SSH `exec` channel.
///
/// Fields are dropped in declaration order (after `Drop::drop` has run), so reorder with care.
struct LaunchedSshBinary {
    /// Held but never read in this file.
    /// NOTE(review): presumably keeps the provisioned host resources alive while the binary runs.
    _resource_result: Arc<ResourceResult>,
    /// SSH session the channel belongs to. `Option` only so that `Drop` can `take()` it and
    /// disconnect; it is `Some` for the whole life of the value before that.
    session: Option<AsyncSession<TcpStream>>,
    /// Channel in which the binary's command line was `exec`'d.
    channel: AsyncChannel<TcpStream>,
    /// Lines sent here are written to the channel's stream 0 (remote stdin) by a background task.
    stdin_sender: mpsc::UnboundedSender<String>,
    /// Stdout subscribers, each with an optional prefix filter.
    stdout_receivers: Arc<Mutex<Vec<PrefixFilteredChannel>>>,
    /// Slot for at most one one-shot stdout receiver, registered through `deploy_stdout`.
    stdout_deploy_receivers: Arc<Mutex<Option<oneshot::Sender<String>>>>,
    /// Stderr subscribers, each with an optional prefix filter.
    stderr_receivers: Arc<Mutex<Vec<PrefixFilteredChannel>>>,
    /// When `Some`, the binary was started under `perf record` and `stop` post-processes the profile.
    tracing: Option<TracingOptions>,
    /// Folded perf data, filled in by `stop` when `tracing` is set.
    tracing_results: Option<TracingResults>,
}
49
50#[async_trait]
51impl LaunchedBinary for LaunchedSshBinary {
52    fn stdin(&self) -> mpsc::UnboundedSender<String> {
53        self.stdin_sender.clone()
54    }
55
56    fn deploy_stdout(&self) -> oneshot::Receiver<String> {
57        let mut receivers = self.stdout_deploy_receivers.lock().unwrap();
58
59        if receivers.is_some() {
60            panic!("Only one deploy stdout receiver is allowed at a time");
61        }
62
63        let (sender, receiver) = oneshot::channel::<String>();
64        *receivers = Some(sender);
65        receiver
66    }
67
68    fn stdout(&self) -> mpsc::UnboundedReceiver<String> {
69        let mut receivers = self.stdout_receivers.lock().unwrap();
70        let (sender, receiver) = mpsc::unbounded_channel::<String>();
71        receivers.push((None, sender));
72        receiver
73    }
74
75    fn stderr(&self) -> mpsc::UnboundedReceiver<String> {
76        let mut receivers = self.stderr_receivers.lock().unwrap();
77        let (sender, receiver) = mpsc::unbounded_channel::<String>();
78        receivers.push((None, sender));
79        receiver
80    }
81
82    fn stdout_filter(&self, prefix: String) -> mpsc::UnboundedReceiver<String> {
83        let mut receivers = self.stdout_receivers.lock().unwrap();
84        let (sender, receiver) = mpsc::unbounded_channel::<String>();
85        receivers.push((Some(prefix), sender));
86        receiver
87    }
88
89    fn stderr_filter(&self, prefix: String) -> mpsc::UnboundedReceiver<String> {
90        let mut receivers = self.stderr_receivers.lock().unwrap();
91        let (sender, receiver) = mpsc::unbounded_channel::<String>();
92        receivers.push((Some(prefix), sender));
93        receiver
94    }
95
96    fn tracing_results(&self) -> Option<&TracingResults> {
97        self.tracing_results.as_ref()
98    }
99
100    fn exit_code(&self) -> Option<i32> {
101        // until the program exits, the exit status is meaningless
102        if self.channel.eof() {
103            self.channel.exit_status().ok()
104        } else {
105            None
106        }
107    }
108
109    async fn wait(&mut self) -> Result<i32> {
110        self.channel.wait_eof().await.unwrap();
111        let exit_code = self.channel.exit_status()?;
112        self.channel.wait_close().await.unwrap();
113
114        Ok(exit_code)
115    }
116
117    async fn stop(&mut self) -> Result<()> {
118        if !self.channel.eof() {
119            ProgressTracker::leaf("force stopping", async {
120                self.channel.write_all(b"\x03").await?; // `^C`
121                self.channel.send_eof().await?;
122                self.channel.wait_eof().await?;
123                // `exit_status()`
124                self.channel.wait_close().await?;
125                Result::<_>::Ok(())
126            })
127            .await?;
128        }
129
130        // Run perf post-processing and download perf output.
131        if let Some(tracing) = self.tracing.as_ref() {
132            let session = self.session.as_ref().unwrap();
133            if let Some(local_raw_perf) = tracing.perf_raw_outfile.as_ref() {
134                ProgressTracker::progress_leaf("downloading perf data", |progress, _| async move {
135                    let sftp = async_retry(
136                        &|| async { Ok(session.sftp().await?) },
137                        10,
138                        Duration::from_secs(1),
139                    )
140                    .await?;
141
142                    let mut remote_raw_perf = sftp.open(&PathBuf::from(PERF_OUTFILE)).await?;
143                    let mut local_raw_perf = File::create(local_raw_perf).await?;
144
145                    let total_size = remote_raw_perf.stat().await?.size.unwrap();
146                    let mut remote_tokio = remote_raw_perf.compat();
147
148                    use tokio::io::AsyncWriteExt;
149                    let mut index = 0;
150                    loop {
151                        let mut buffer = [0; 16 * 1024];
152                        let n = remote_tokio.read(&mut buffer).await?;
153                        if n == 0 {
154                            break;
155                        }
156                        local_raw_perf.write_all(&buffer[..n]).await?;
157                        index += n;
158                        progress(((index as f64 / total_size as f64) * 100.0) as u64);
159                    }
160
161                    Ok::<(), anyhow::Error>(())
162                })
163                .await?;
164            }
165
166            let mut script_channel = session.channel_session().await?;
167            let mut fold_er = Folder::from(tracing.fold_perf_options.clone().unwrap_or_default());
168
169            let fold_data = ProgressTracker::leaf("perf script & folding", async move {
170                let mut stderr_lines = FuturesBufReader::new(script_channel.stderr()).lines();
171                let stdout = script_channel.stream(0);
172
173                // Pattern on `()` to make sure no `Result`s are ignored.
174                let ((), fold_data, ()) = tokio::try_join!(
175                    async move {
176                        // Log stderr.
177                        while let Some(Ok(s)) = stderr_lines.next().await {
178                            ProgressTracker::eprintln(format!("[perf stderr] {s}"));
179                        }
180                        Result::<_>::Ok(())
181                    },
182                    async move {
183                        // Download perf output and fold.
184                        tokio::task::spawn_blocking(move || {
185                            let mut fold_data = Vec::new();
186                            fold_er.collapse(
187                                SyncIoBridge::new(TokioBufReader::new(stdout)),
188                                &mut fold_data,
189                            )?;
190                            Ok(fold_data)
191                        })
192                        .await?
193                    },
194                    async move {
195                        // Run command (last!).
196                        script_channel
197                            .exec(&format!("perf script --symfs=/ -i {PERF_OUTFILE}"))
198                            .await?;
199                        Ok(())
200                    },
201                )?;
202                Result::<_>::Ok(fold_data)
203            })
204            .await?;
205
206            self.tracing_results = Some(TracingResults {
207                folded_data: fold_data.clone(),
208            });
209
210            handle_fold_data(tracing, fold_data).await?;
211        };
212
213        Ok(())
214    }
215}
216
217impl Drop for LaunchedSshBinary {
218    fn drop(&mut self) {
219        if let Some(session) = self.session.take() {
220            tokio::task::block_in_place(|| {
221                Handle::current().block_on(session.disconnect(None, "", None))
222            })
223            .unwrap();
224        }
225    }
226}
227
228#[async_trait]
229pub trait LaunchedSshHost: Send + Sync {
230    fn get_internal_ip(&self) -> String;
231    fn get_external_ip(&self) -> Option<String>;
232    fn get_cloud_provider(&self) -> String;
233    fn resource_result(&self) -> &Arc<ResourceResult>;
234    fn ssh_user(&self) -> &str;
235
236    fn ssh_key_path(&self) -> PathBuf {
237        self.resource_result()
238            .terraform
239            .deployment_folder
240            .as_ref()
241            .unwrap()
242            .path()
243            .join(".ssh")
244            .join("vm_instance_ssh_key_pem")
245    }
246
247    fn server_config(&self, bind_type: &ServerStrategy) -> ServerBindConfig {
248        match bind_type {
249            ServerStrategy::UnixSocket => ServerBindConfig::UnixSocket,
250            ServerStrategy::InternalTcpPort => {
251                ServerBindConfig::TcpPort(self.get_internal_ip().clone())
252            }
253            ServerStrategy::ExternalTcpPort(_) => todo!(),
254            ServerStrategy::Demux(demux) => {
255                let mut config_map = HashMap::new();
256                for (key, underlying) in demux {
257                    config_map.insert(*key, LaunchedSshHost::server_config(self, underlying));
258                }
259
260                ServerBindConfig::Demux(config_map)
261            }
262            ServerStrategy::Merge(merge) => {
263                let mut configs = vec![];
264                for underlying in merge {
265                    configs.push(LaunchedSshHost::server_config(self, underlying));
266                }
267
268                ServerBindConfig::Merge(configs)
269            }
270            ServerStrategy::Tagged(underlying, id) => ServerBindConfig::Tagged(
271                Box::new(LaunchedSshHost::server_config(self, underlying)),
272                *id,
273            ),
274            ServerStrategy::Null => ServerBindConfig::Null,
275        }
276    }
277
278    async fn open_ssh_session(&self) -> Result<AsyncSession<TcpStream>> {
279        let target_addr = SocketAddr::new(
280            self.get_external_ip()
281                .as_ref()
282                .context(
283                    self.get_cloud_provider()
284                        + " host must be configured with an external IP to launch binaries",
285                )?
286                .parse()
287                .unwrap(),
288            22,
289        );
290
291        let res = ProgressTracker::leaf(
292            format!(
293                "connecting to host @ {}",
294                self.get_external_ip().as_ref().unwrap()
295            ),
296            async_retry(
297                &|| async {
298                    let mut config = SessionConfiguration::new();
299                    config.set_compress(true);
300
301                    let mut session =
302                        AsyncSession::<TcpStream>::connect(target_addr, Some(config)).await?;
303
304                    session.handshake().await?;
305
306                    session
307                        .userauth_pubkey_file(
308                            self.ssh_user(),
309                            None,
310                            self.ssh_key_path().as_path(),
311                            None,
312                        )
313                        .await?;
314
315                    Ok(session)
316                },
317                10,
318                Duration::from_secs(1),
319            ),
320        )
321        .await?;
322
323        Ok(res)
324    }
325}
326
327#[async_trait]
328impl<T: LaunchedSshHost> LaunchedHost for T {
329    fn server_config(&self, bind_type: &ServerStrategy) -> ServerBindConfig {
330        LaunchedSshHost::server_config(self, bind_type)
331    }
332
333    async fn copy_binary(&self, binary: &BuildOutput) -> Result<()> {
334        let session = self.open_ssh_session().await?;
335
336        let sftp = async_retry(
337            &|| async { Ok(session.sftp().await?) },
338            10,
339            Duration::from_secs(1),
340        )
341        .await?;
342
343        // we may be deploying multiple binaries, so give each a unique name
344        let unique_name = &binary.unique_id;
345
346        let user = self.ssh_user();
347        let binary_path = PathBuf::from(format!("/home/{user}/hydro-{unique_name}"));
348
349        if sftp.stat(&binary_path).await.is_err() {
350            let random = nanoid!(8);
351            let temp_path = PathBuf::from(format!("/home/{user}/hydro-{random}"));
352            let sftp = &sftp;
353
354            ProgressTracker::progress_leaf(
355                format!("uploading binary to {}", binary_path.display()),
356                |set_progress, _| {
357                    let binary = &binary;
358                    let binary_path = &binary_path;
359                    async move {
360                        let mut created_file = sftp.create(&temp_path).await?;
361
362                        let mut index = 0;
363                        while index < binary.bin_data.len() {
364                            let written = created_file
365                                .write(
366                                    &binary.bin_data[index
367                                        ..std::cmp::min(index + 128 * 1024, binary.bin_data.len())],
368                                )
369                                .await?;
370                            index += written;
371                            set_progress(
372                                ((index as f64 / binary.bin_data.len() as f64) * 100.0) as u64,
373                            );
374                        }
375                        let mut orig_file_stat = sftp.stat(&temp_path).await?;
376                        orig_file_stat.perm = Some(0o755); // allow the copied binary to be executed by anyone
377                        created_file.setstat(orig_file_stat).await?;
378                        created_file.close().await?;
379                        drop(created_file);
380
381                        match sftp.rename(&temp_path, binary_path, None).await {
382                            Ok(_) => {}
383                            Err(async_ssh2_lite::Error::Ssh2(e))
384                                if e.code() == ErrorCode::SFTP(4) =>
385                            {
386                                // file already exists
387                                sftp.unlink(&temp_path).await?;
388                            }
389                            Err(e) => return Err(e.into()),
390                        }
391
392                        anyhow::Ok(())
393                    }
394                },
395            )
396            .await?;
397        }
398        drop(sftp);
399
400        Ok(())
401    }
402
403    async fn launch_binary(
404        &self,
405        id: String,
406        binary: &BuildOutput,
407        args: &[String],
408        tracing: Option<TracingOptions>,
409    ) -> Result<Box<dyn LaunchedBinary>> {
410        let session = self.open_ssh_session().await?;
411
412        let unique_name = &binary.unique_id;
413
414        let user = self.ssh_user();
415        let binary_path = PathBuf::from(format!("/home/{user}/hydro-{unique_name}"));
416
417        let channel = ProgressTracker::leaf(
418            format!("launching binary {}", binary_path.display()),
419            async {
420                let mut channel =
421                    async_retry(
422                        &|| async {
423                            Ok(tokio::time::timeout(
424                                Duration::from_secs(60),
425                                session.channel_session(),
426                            )
427                            .await??)
428                        },
429                        10,
430                        Duration::from_secs(1),
431                    )
432                    .await?;
433
434                let mut command = binary_path.to_str().unwrap().to_owned();
435                for arg in args{
436                    command.push(' ');
437                    command.push_str(&shell_escape::unix::escape(Cow::Borrowed(arg)))
438                }
439                // Launch with perf if specified, also copy local binary to expected place for perf report to work
440                if let Some(TracingOptions { frequency, .. }) = tracing.clone() {
441                    // Attach perf to the command
442                    command = format!(
443                        "perf record -F {frequency} -e cycles:u --call-graph dwarf,65528 -o {PERF_OUTFILE} {command}",
444                    );
445                }
446                channel.exec(&command).await?;
447                anyhow::Ok(channel)
448            },
449        )
450        .await?;
451
452        let (stdin_sender, mut stdin_receiver) = mpsc::unbounded_channel::<String>();
453        let mut stdin = channel.stream(0); // stream 0 is stdout/stdin, we use it for stdin
454        tokio::spawn(async move {
455            while let Some(line) = stdin_receiver.recv().await {
456                if stdin.write_all(line.as_bytes()).await.is_err() {
457                    break;
458                }
459
460                stdin.flush().await.unwrap();
461            }
462        });
463
464        let id_clone = id.clone();
465        let (stdout_deploy_receivers, stdout_receivers) =
466            prioritized_broadcast(FuturesBufReader::new(channel.stream(0)).lines(), move |s| {
467                ProgressTracker::println(format!("[{id_clone}] {s}"));
468            });
469        let (_, stderr_receivers) =
470            prioritized_broadcast(FuturesBufReader::new(channel.stderr()).lines(), move |s| {
471                ProgressTracker::println(format!("[{id} stderr] {s}"));
472            });
473
474        Ok(Box::new(LaunchedSshBinary {
475            _resource_result: self.resource_result().clone(),
476            session: Some(session),
477            channel,
478            stdin_sender,
479            stdout_deploy_receivers,
480            stdout_receivers,
481            stderr_receivers,
482            tracing,
483            tracing_results: None,
484        }))
485    }
486
487    async fn forward_port(&self, addr: &SocketAddr) -> Result<SocketAddr> {
488        let session = self.open_ssh_session().await?;
489
490        let local_port = TcpListener::bind("127.0.0.1:0").await?;
491        let local_addr = local_port.local_addr()?;
492
493        let internal_ip = addr.ip().to_string();
494        let port = addr.port();
495
496        tokio::spawn(async move {
497            #[expect(clippy::never_loop, reason = "tcp accept loop pattern")]
498            while let Ok((mut local_stream, _)) = local_port.accept().await {
499                let mut channel = session
500                    .channel_direct_tcpip(&internal_ip, port, None)
501                    .await
502                    .unwrap();
503                let _ = tokio::io::copy_bidirectional(&mut local_stream, &mut channel).await;
504                break;
505                // TODO(shadaj): we should be returning an Arc so that we know
506                // if anyone wants to connect to this forwarded port
507            }
508
509            ProgressTracker::println("[hydro] closing forwarded port");
510        });
511
512        Ok(local_addr)
513    }
514}