1use std::collections::HashMap;
2use std::net::SocketAddr;
3use std::sync::{Arc, OnceLock};
4
5use anyhow::{Result, bail};
6use async_process::{Command, Stdio};
7use async_trait::async_trait;
8use hydro_deploy_integration::ServerBindConfig;
9
10use crate::progress::ProgressTracker;
11use crate::rust_crate::build::BuildOutput;
12use crate::rust_crate::tracing_options::TracingOptions;
13use crate::{
14 BaseServerStrategy, ClientStrategy, Host, HostStrategyGetter, HostTargetType, LaunchedBinary,
15 LaunchedHost, PortNetworkHint, ResourceBatch, ResourceResult,
16};
17
18pub mod launched_binary;
19pub use launched_binary::*;
20
21#[cfg(feature = "profile-folding")]
22#[cfg(any(target_os = "macos", target_family = "windows"))]
23mod samply;
24
/// Cached output of `rustc --print target-libdir`.
///
/// Read when building the dynamic-library search path for binaries launched on localhost
/// (see `LaunchedLocalhost::launch_binary`).
static LOCAL_LIBDIR: OnceLock<String> = OnceLock::new();
26
/// A deployment [`Host`] representing the local machine.
///
/// A client-only instance refuses to act as a server in `strategy_as_server`.
#[derive(Debug)]
pub struct LocalhostHost {
    /// Host identifier; compared against other hosts' ids when choosing connection strategies.
    pub id: usize,
    /// When `true`, this host may only connect to servers, never act as one.
    client_only: bool,
}

impl LocalhostHost {
    /// Creates a localhost host with the given `id` that is allowed to act as a server.
    pub fn new(id: usize) -> LocalhostHost {
        Self {
            client_only: false,
            id,
        }
    }

    /// Returns a new host with the same `id` as `self`, marked as client-only.
    pub fn client_only(&self) -> LocalhostHost {
        Self {
            client_only: true,
            ..Self::new(self.id)
        }
    }
}
48
49impl Host for LocalhostHost {
50 fn target_type(&self) -> HostTargetType {
51 HostTargetType::Local
52 }
53
54 fn request_port_base(&self, _bind_type: &BaseServerStrategy) {}
55 fn collect_resources(&self, _resource_batch: &mut ResourceBatch) {}
56 fn request_custom_binary(&self) {}
57
58 fn id(&self) -> usize {
59 self.id
60 }
61
62 fn launched(&self) -> Option<Arc<dyn LaunchedHost>> {
63 Some(Arc::new(LaunchedLocalhost))
64 }
65
66 fn provision(&self, _resource_result: &Arc<ResourceResult>) -> Arc<dyn LaunchedHost> {
67 Arc::new(LaunchedLocalhost)
68 }
69
70 fn strategy_as_server<'a>(
71 &'a self,
72 connection_from: &dyn Host,
73 network_hint: PortNetworkHint,
74 ) -> Result<(ClientStrategy<'a>, HostStrategyGetter)> {
75 if self.client_only {
76 anyhow::bail!("Localhost cannot be a server if it is client only")
77 }
78
79 if matches!(network_hint, PortNetworkHint::Auto)
80 && connection_from.can_connect_to(ClientStrategy::UnixSocket(self.id))
81 {
82 Ok((
83 ClientStrategy::UnixSocket(self.id),
84 Box::new(|_| BaseServerStrategy::UnixSocket),
85 ))
86 } else if matches!(
87 network_hint,
88 PortNetworkHint::Auto | PortNetworkHint::TcpPort(_)
89 ) && connection_from.can_connect_to(ClientStrategy::InternalTcpPort(self))
90 {
91 Ok((
92 ClientStrategy::InternalTcpPort(self),
93 Box::new(move |_| {
94 BaseServerStrategy::InternalTcpPort(match network_hint {
95 PortNetworkHint::Auto => None,
96 PortNetworkHint::TcpPort(port) => port,
97 })
98 }),
99 ))
100 } else {
101 anyhow::bail!("Could not find a strategy to connect to localhost")
102 }
103 }
104
105 fn can_connect_to(&self, typ: ClientStrategy) -> bool {
106 match typ {
107 ClientStrategy::UnixSocket(id) => {
108 #[cfg(unix)]
109 {
110 self.id == id
111 }
112
113 #[cfg(not(unix))]
114 {
115 let _ = id;
116 false
117 }
118 }
119 ClientStrategy::InternalTcpPort(target_host) => self.id == target_host.id(),
120 ClientStrategy::ForwardedTcpPort(_) => true,
121 }
122 }
123}
124
125struct LaunchedLocalhost;
126
127#[async_trait]
128impl LaunchedHost for LaunchedLocalhost {
129 fn base_server_config(&self, bind_type: &BaseServerStrategy) -> ServerBindConfig {
130 match bind_type {
131 BaseServerStrategy::UnixSocket => ServerBindConfig::UnixSocket,
132 BaseServerStrategy::InternalTcpPort(port) => {
133 ServerBindConfig::TcpPort("127.0.0.1".to_owned(), *port)
134 }
135 BaseServerStrategy::ExternalTcpPort(_) => panic!("Cannot bind to external port"),
136 }
137 }
138
139 async fn copy_binary(&self, _binary: &BuildOutput) -> Result<()> {
140 Ok(())
141 }
142
143 async fn launch_binary(
144 &self,
145 id: String,
146 binary: &BuildOutput,
147 args: &[String],
148 tracing: Option<TracingOptions>,
149 env: &HashMap<String, String>,
150 ) -> Result<Box<dyn LaunchedBinary>> {
151 let (maybe_perf_outfile, mut command) = if let Some(tracing) = tracing.as_ref() {
152 if cfg!(any(target_os = "macos", target_family = "windows")) {
153 ProgressTracker::println(
155 format!("[{id} tracing] Profiling binary with `samply`.",),
156 );
157 let samply_outfile = tempfile::NamedTempFile::new()?;
158
159 let mut command = Command::new("samply");
160 command
161 .arg("record")
162 .arg("--save-only")
163 .arg("--output")
164 .arg(samply_outfile.as_ref())
165 .arg(&binary.bin_path)
166 .args(args);
167 (Some(samply_outfile), command)
168 } else if cfg!(target_family = "unix") {
169 ProgressTracker::println(format!("[{} tracing] Tracing binary with `perf`.", id));
171 let perf_outfile = tempfile::NamedTempFile::new()?;
172
173 let mut command = Command::new("perf");
174 command
175 .args([
176 "record",
177 "-F",
178 &tracing.frequency.to_string(),
179 "-e",
180 "cycles:u",
181 "--call-graph",
182 "dwarf,65528",
183 "-o",
184 ])
185 .arg(perf_outfile.as_ref())
186 .arg(&binary.bin_path)
187 .args(args);
188
189 (Some(perf_outfile), command)
190 } else {
191 bail!(
192 "Unknown OS for samply/perf tracing: {}",
193 std::env::consts::OS
194 );
195 }
196 } else {
197 let mut command = Command::new(&binary.bin_path);
198 command.args(args);
199 (None, command)
200 };
201
202 let dylib_path_var = if cfg!(windows) {
204 "PATH"
205 } else if cfg!(target_os = "macos") {
206 "DYLD_FALLBACK_LIBRARY_PATH"
207 } else if cfg!(target_os = "aix") {
208 "LIBPATH"
209 } else {
210 "LD_LIBRARY_PATH"
211 };
212
213 let local_libdir = LOCAL_LIBDIR.get_or_init(|| {
214 std::process::Command::new("rustc")
215 .arg("--print")
216 .arg("target-libdir")
217 .output()
218 .map(|output| str::from_utf8(&output.stdout).unwrap().trim().to_owned())
219 .unwrap()
220 });
221
222 command.env(
223 dylib_path_var,
224 std::env::var_os(dylib_path_var).map_or_else(
225 || {
226 std::env::join_paths(
227 [
228 binary.shared_library_path.as_ref(),
229 Some(&std::path::PathBuf::from(local_libdir)),
230 ]
231 .into_iter()
232 .flatten(),
233 )
234 .unwrap()
235 },
236 |paths| {
237 let mut paths = std::env::split_paths(&paths).collect::<Vec<_>>();
238 paths.insert(0, std::path::PathBuf::from(local_libdir));
239 if let Some(shared_path) = &binary.shared_library_path {
240 paths.insert(0, shared_path.to_path_buf());
241 }
242 std::env::join_paths(paths).unwrap()
243 },
244 ),
245 );
246
247 command.envs(env);
248
249 command
250 .stdin(Stdio::piped())
251 .stdout(Stdio::piped())
252 .stderr(Stdio::piped());
253
254 #[cfg(not(target_family = "unix"))]
255 command.kill_on_drop(true);
256
257 ProgressTracker::println(format!("[{}] running command: `{:?}`", id, command));
258
259 let child = command.spawn().map_err(|e| {
260 let msg = if maybe_perf_outfile.is_some() && std::io::ErrorKind::NotFound == e.kind() {
261 "Tracing executable not found, ensure it is installed"
262 } else {
263 "Failed to execute command"
264 };
265 anyhow::Error::new(e).context(format!("{}: {:?}", msg, command))
266 })?;
267
268 Ok(Box::new(LaunchedLocalhostBinary::new(
269 child,
270 id,
271 tracing,
272 maybe_perf_outfile.map(|f| TracingDataLocal { outfile: f }),
273 )))
274 }
275
276 async fn forward_port(&self, addr: &SocketAddr) -> Result<SocketAddr> {
277 Ok(*addr)
278 }
279}