hydro_deploy/rust_crate/
ports.rs

1use std::any::Any;
2use std::collections::HashMap;
3use std::fmt::Debug;
4use std::ops::Deref;
5use std::sync::{Arc, Weak};
6
7use anyhow::{Result, bail};
8use append_only_vec::AppendOnlyVec;
9use async_recursion::async_recursion;
10use hydro_deploy_integration::ServerPort;
11use tokio::sync::RwLock;
12
13use super::RustCrateService;
14use crate::{ClientStrategy, Host, LaunchedHost, PortNetworkHint, ServerStrategy};
15
16pub trait RustCrateSource: Send + Sync {
17    fn source_path(&self) -> SourcePath;
18    fn record_server_config(&self, config: ServerConfig);
19
20    fn host(&self) -> Arc<dyn Host>;
21    fn server(&self) -> Arc<dyn RustCrateServer>;
22    fn record_server_strategy(&self, config: ServerStrategy);
23
24    fn wrap_reverse_server_config(&self, config: ServerConfig) -> ServerConfig {
25        config
26    }
27
28    fn send_to(&self, sink: &dyn RustCrateSink) {
29        let forward_res = sink.instantiate(&self.source_path());
30        if let Ok(instantiated) = forward_res {
31            self.record_server_config(instantiated());
32        } else {
33            drop(forward_res);
34            let instantiated = sink
35                .instantiate_reverse(&self.host(), self.server(), &|p| {
36                    self.wrap_reverse_server_config(p)
37                })
38                .unwrap();
39            self.record_server_strategy(instantiated(sink));
40        }
41    }
42}
43
/// The server side of a port: something that can report its `ServerPort` and the
/// launched host it belongs to.
pub trait RustCrateServer: Debug + Send + Sync {
    /// Returns the `ServerPort` of this server.
    fn get_port(&self) -> ServerPort;
    /// Returns the launched host associated with this server.
    fn launched_host(&self) -> Arc<dyn LaunchedHost>;
}
48
/// One-shot thunk returned by `RustCrateSink::instantiate_reverse`. It is called with the
/// sink (as `&dyn Any`) and yields the `ServerStrategy` for the connection.
pub type ReverseSinkInstantiator = Box<dyn FnOnce(&dyn Any) -> ServerStrategy>;
50
/// The receiving side of a connection between two ports.
pub trait RustCrateSink: Any + Send + Sync {
    /// Instantiate the sink as the source host connecting to the sink host.
    /// Returns a thunk that can be called to perform mutations that instantiate the sink.
    fn instantiate(&self, client_path: &SourcePath) -> Result<Box<dyn FnOnce() -> ServerConfig>>;

    /// Instantiate the sink, but as the sink host connecting to the source host.
    /// Returns a thunk that can be called to perform mutations that instantiate the sink;
    /// the thunk is given a reference to this sink as `&dyn Any`.
    ///
    /// `wrap_client_port` is applied to the `ServerConfig` built for the client side
    /// before it is recorded.
    fn instantiate_reverse(
        &self,
        server_host: &Arc<dyn Host>,
        server_sink: Arc<dyn RustCrateServer>,
        wrap_client_port: &dyn Fn(ServerConfig) -> ServerConfig,
    ) -> Result<ReverseSinkInstantiator>;
}
65
66pub struct TaggedSource {
67    pub source: Arc<dyn RustCrateSource>,
68    pub tag: u32,
69}
70
71impl RustCrateSource for TaggedSource {
72    fn source_path(&self) -> SourcePath {
73        SourcePath::Tagged(Box::new(self.source.source_path()), self.tag)
74    }
75
76    fn record_server_config(&self, config: ServerConfig) {
77        self.source.record_server_config(config);
78    }
79
80    fn host(&self) -> Arc<dyn Host> {
81        self.source.host()
82    }
83
84    fn server(&self) -> Arc<dyn RustCrateServer> {
85        self.source.server()
86    }
87
88    fn wrap_reverse_server_config(&self, config: ServerConfig) -> ServerConfig {
89        ServerConfig::Tagged(Box::new(config), self.tag)
90    }
91
92    fn record_server_strategy(&self, config: ServerStrategy) {
93        self.source.record_server_strategy(config);
94    }
95}
96
97pub struct NullSourceSink;
98
99impl RustCrateSource for NullSourceSink {
100    fn source_path(&self) -> SourcePath {
101        SourcePath::Null
102    }
103
104    fn host(&self) -> Arc<dyn Host> {
105        panic!("null source has no host")
106    }
107
108    fn server(&self) -> Arc<dyn RustCrateServer> {
109        panic!("null source has no server")
110    }
111
112    fn record_server_config(&self, _config: ServerConfig) {}
113    fn record_server_strategy(&self, _config: ServerStrategy) {}
114}
115
116impl RustCrateSink for NullSourceSink {
117    fn instantiate(&self, _client_path: &SourcePath) -> Result<Box<dyn FnOnce() -> ServerConfig>> {
118        Ok(Box::new(|| ServerConfig::Null))
119    }
120
121    fn instantiate_reverse(
122        &self,
123        _server_host: &Arc<dyn Host>,
124        _server_sink: Arc<dyn RustCrateServer>,
125        _wrap_client_port: &dyn Fn(ServerConfig) -> ServerConfig,
126    ) -> Result<ReverseSinkInstantiator> {
127        Ok(Box::new(|_| ServerStrategy::Null))
128    }
129}
130
131pub struct DemuxSink {
132    pub demux: HashMap<u32, Arc<dyn RustCrateSink>>,
133}
134
135impl RustCrateSink for DemuxSink {
136    fn instantiate(&self, client_host: &SourcePath) -> Result<Box<dyn FnOnce() -> ServerConfig>> {
137        let mut thunk_map = HashMap::new();
138        for (key, target) in &self.demux {
139            thunk_map.insert(*key, target.instantiate(client_host)?);
140        }
141
142        Ok(Box::new(move || {
143            let instantiated_map = thunk_map
144                .into_iter()
145                .map(|(key, thunk)| (key, (thunk)()))
146                .collect();
147
148            ServerConfig::Demux(instantiated_map)
149        }))
150    }
151
152    fn instantiate_reverse(
153        &self,
154        server_host: &Arc<dyn Host>,
155        server_sink: Arc<dyn RustCrateServer>,
156        wrap_client_port: &dyn Fn(ServerConfig) -> ServerConfig,
157    ) -> Result<ReverseSinkInstantiator> {
158        let mut thunk_map = HashMap::new();
159        for (key, target) in &self.demux {
160            thunk_map.insert(
161                *key,
162                target.instantiate_reverse(
163                    server_host,
164                    server_sink.clone(),
165                    // the parent wrapper selects the demux port for the parent defn, so do that first
166                    &|p| ServerConfig::DemuxSelect(Box::new(wrap_client_port(p)), *key),
167                )?,
168            );
169        }
170
171        Ok(Box::new(move |me| {
172            let me = me.downcast_ref::<DemuxSink>().unwrap();
173            let instantiated_map = thunk_map
174                .into_iter()
175                .map(|(key, thunk)| (key, (thunk)(me.demux.get(&key).unwrap())))
176                .collect();
177
178            ServerStrategy::Demux(instantiated_map)
179        }))
180    }
181}
182
183#[derive(Clone, Debug)]
184pub struct RustCratePortConfig {
185    pub service: Weak<RustCrateService>,
186    pub service_host: Arc<dyn Host>,
187    pub service_server_defns: Arc<RwLock<HashMap<String, ServerPort>>>,
188    pub network_hint: PortNetworkHint,
189    pub port: String,
190    pub merge: bool,
191}
192
193impl RustCratePortConfig {
194    pub fn merge(mut self) -> Self {
195        self.merge = true;
196        self
197    }
198}
199
200impl RustCrateSource for RustCratePortConfig {
201    fn source_path(&self) -> SourcePath {
202        SourcePath::Direct(self.service.upgrade().unwrap().on.clone())
203    }
204
205    fn host(&self) -> Arc<dyn Host> {
206        self.service_host.clone()
207    }
208
209    fn server(&self) -> Arc<dyn RustCrateServer> {
210        let from = self.service.upgrade().unwrap();
211
212        Arc::new(RustCratePortConfig {
213            service: Arc::downgrade(&from),
214            service_host: from.on.clone(),
215            service_server_defns: from.server_defns.clone(),
216            network_hint: self.network_hint,
217            port: self.port.clone(),
218            merge: false,
219        })
220    }
221
222    fn record_server_config(&self, config: ServerConfig) {
223        let from = self.service.upgrade().unwrap();
224        // TODO(shadaj): if already in this map, we want to broadcast
225        assert!(
226            from.port_to_server.insert(self.port.clone(), config),
227            "The port configuration is incorrect, for example, are you using a ConnectedDirect instead of a ConnectedDemux?"
228        );
229    }
230
231    fn record_server_strategy(&self, config: ServerStrategy) {
232        let from = self.service.upgrade().unwrap();
233        assert!(
234            from.port_to_bind.insert(self.port.clone(), config),
235            "port already set!"
236        );
237    }
238}
239
240impl RustCrateServer for RustCratePortConfig {
241    fn get_port(&self) -> ServerPort {
242        // we are in `deployment.start()`, so no one should be writing
243        let server_defns = self.service_server_defns.try_read().unwrap();
244        server_defns.get(&self.port).unwrap().clone()
245    }
246
247    fn launched_host(&self) -> Arc<dyn LaunchedHost> {
248        self.service_host.launched().unwrap()
249    }
250}
251
252pub enum SourcePath {
253    Null,
254    Direct(Arc<dyn Host>),
255    Many(Arc<dyn Host>),
256    Tagged(Box<SourcePath>, u32),
257}
258
259impl SourcePath {
260    #[expect(
261        clippy::type_complexity,
262        reason = "internals (dyn Fn to defer instantiation)"
263    )]
264    fn plan<T: RustCrateServer + Clone + 'static>(
265        &self,
266        server: &T,
267        server_host: &dyn Host,
268        network_hint: PortNetworkHint,
269    ) -> Result<(Box<dyn FnOnce(&dyn Any) -> ServerStrategy>, ServerConfig)> {
270        match self {
271            SourcePath::Direct(client_host) => {
272                let (conn_type, bind_type) =
273                    server_host.strategy_as_server(client_host.deref(), network_hint)?;
274                let base_config = ServerConfig::from_strategy(&conn_type, Arc::new(server.clone()));
275                Ok((
276                    Box::new(|host| ServerStrategy::Direct(bind_type(host))),
277                    base_config,
278                ))
279            }
280
281            SourcePath::Many(client_host) => {
282                let (conn_type, bind_type) =
283                    server_host.strategy_as_server(client_host.deref(), network_hint)?;
284                let base_config = ServerConfig::from_strategy(&conn_type, Arc::new(server.clone()));
285                Ok((
286                    Box::new(|host| ServerStrategy::Many(bind_type(host))),
287                    base_config,
288                ))
289            }
290
291            SourcePath::Tagged(underlying, tag) => {
292                let (bind_type, base_config) =
293                    underlying.plan(server, server_host, network_hint)?;
294                let tag = *tag;
295                Ok((
296                    Box::new(move |host| ServerStrategy::Tagged(Box::new(bind_type(host)), tag)),
297                    ServerConfig::TaggedUnwrap(Box::new(base_config)),
298                ))
299            }
300
301            SourcePath::Null => Ok((Box::new(|_| ServerStrategy::Null), ServerConfig::Null)),
302        }
303    }
304}
305
306impl RustCrateSink for RustCratePortConfig {
307    fn instantiate(&self, client_path: &SourcePath) -> Result<Box<dyn FnOnce() -> ServerConfig>> {
308        let server = self.service.upgrade().unwrap();
309
310        let server_host = server.on.clone();
311
312        let (bind_type, base_config) =
313            client_path.plan(self, server_host.deref(), self.network_hint)?;
314
315        let server = server.clone();
316        let merge = self.merge;
317        let port = self.port.clone();
318        Ok(Box::new(move || {
319            let bind_type = (bind_type)(&*server.on);
320
321            if merge {
322                let merge_config = server
323                    .port_to_bind
324                    .get_or_insert_owned(port, || ServerStrategy::Merge(Default::default()));
325                let ServerStrategy::Merge(merge) = merge_config else {
326                    panic!("Expected a merge connection definition")
327                };
328                merge.push(bind_type);
329                ServerConfig::MergeSelect(Box::new(base_config), merge.len() - 1)
330            } else {
331                assert!(
332                    server.port_to_bind.insert(port.clone(), bind_type),
333                    "port already set!"
334                );
335                base_config
336            }
337        }))
338    }
339
340    fn instantiate_reverse(
341        &self,
342        server_host: &Arc<dyn Host>,
343        server_sink: Arc<dyn RustCrateServer>,
344        wrap_client_port: &dyn Fn(ServerConfig) -> ServerConfig,
345    ) -> Result<ReverseSinkInstantiator> {
346        if !matches!(self.network_hint, PortNetworkHint::Auto) {
347            bail!("Trying to form collection where I am the client, but I have server hint")
348        }
349
350        let client = self.service.upgrade().unwrap();
351
352        let server_host = server_host.clone();
353
354        let (conn_type, bind_type) =
355            server_host.strategy_as_server(&*client.on, PortNetworkHint::Auto)?;
356        let client_port = wrap_client_port(ServerConfig::from_strategy(&conn_type, server_sink));
357
358        let client = client.clone();
359        let merge = self.merge;
360        let port = self.port.clone();
361        Ok(Box::new(move |_| {
362            if merge {
363                let merge_config = client
364                    .port_to_server
365                    .get_or_insert_owned(port, || ServerConfig::Merge(Default::default()));
366                let ServerConfig::Merge(merge) = merge_config else {
367                    panic!()
368                };
369                merge.push(client_port);
370            } else {
371                assert!(
372                    client.port_to_server.insert(port.clone(), client_port),
373                    "port already set!"
374                );
375            };
376
377            ServerStrategy::Direct((bind_type)(&*client.on))
378        }))
379    }
380}
381
382#[derive(Clone, Debug)]
383pub enum ServerConfig {
384    Direct(Arc<dyn RustCrateServer>),
385    Forwarded(Arc<dyn RustCrateServer>),
386    /// A demux that will be used at runtime to listen to many connections.
387    Demux(HashMap<u32, ServerConfig>),
388    /// The other side of a demux, with a port to extract the appropriate connection.
389    DemuxSelect(Box<ServerConfig>, u32),
390    /// A merge that will be used at runtime to combine many connections.
391    /// AppendOnlyVec has a quite large inline array, so we box it.
392    Merge(Box<AppendOnlyVec<ServerConfig>>),
393    /// The other side of a merge, with a port to extract the appropriate connection.
394    MergeSelect(Box<ServerConfig>, usize),
395    Tagged(Box<ServerConfig>, u32),
396    TaggedUnwrap(Box<ServerConfig>),
397    Null,
398}
399
400impl ServerConfig {
401    pub fn from_strategy(
402        strategy: &ClientStrategy,
403        server: Arc<dyn RustCrateServer>,
404    ) -> ServerConfig {
405        match strategy {
406            ClientStrategy::UnixSocket(_) | ClientStrategy::InternalTcpPort(_) => {
407                ServerConfig::Direct(server)
408            }
409            ClientStrategy::ForwardedTcpPort(_) => ServerConfig::Forwarded(server),
410        }
411    }
412}
413
414#[async_recursion]
415async fn forward_connection(conn: &ServerPort, target: &dyn LaunchedHost) -> ServerPort {
416    match conn {
417        ServerPort::UnixSocket(_) => panic!("Expected a TCP port to be forwarded"),
418        ServerPort::TcpPort(addr) => ServerPort::TcpPort(target.forward_port(addr).await.unwrap()),
419        ServerPort::Demux(demux) => {
420            let mut forwarded_map = HashMap::new();
421            for (key, conn) in demux {
422                forwarded_map.insert(*key, forward_connection(conn, target).await);
423            }
424            ServerPort::Demux(forwarded_map)
425        }
426        ServerPort::Merge(merge) => {
427            let mut forwarded_vec = Vec::new();
428            for conn in merge {
429                forwarded_vec.push(forward_connection(conn, target).await);
430            }
431            ServerPort::Merge(forwarded_vec)
432        }
433        ServerPort::Tagged(underlying, id) => {
434            ServerPort::Tagged(Box::new(forward_connection(underlying, target).await), *id)
435        }
436        ServerPort::Null => ServerPort::Null,
437    }
438}
439
440impl ServerConfig {
441    #[async_recursion]
442    pub async fn load_instantiated(
443        &self,
444        select: &(dyn Fn(ServerPort) -> ServerPort + Send + Sync),
445    ) -> ServerPort {
446        match self {
447            ServerConfig::Direct(server) => select(server.get_port()),
448
449            ServerConfig::Forwarded(server) => {
450                let selected = select(server.get_port());
451                forward_connection(&selected, server.launched_host().as_ref()).await
452            }
453
454            ServerConfig::Demux(demux) => {
455                let mut demux_map = HashMap::new();
456                for (key, conn) in demux {
457                    demux_map.insert(*key, conn.load_instantiated(select).await);
458                }
459                ServerPort::Demux(demux_map)
460            }
461
462            ServerConfig::DemuxSelect(underlying, key) => {
463                let key = *key;
464                underlying
465                    .load_instantiated(
466                        &(move |p| {
467                            if let ServerPort::Demux(mut mapping) = p {
468                                select(mapping.remove(&key).unwrap())
469                            } else {
470                                panic!("Expected a demux connection definition")
471                            }
472                        }),
473                    )
474                    .await
475            }
476
477            ServerConfig::Merge(merge) => {
478                let mut merge_vec = Vec::new();
479                for conn in merge.iter() {
480                    merge_vec.push(conn.load_instantiated(select).await);
481                }
482                ServerPort::Merge(merge_vec)
483            }
484
485            ServerConfig::MergeSelect(underlying, key) => {
486                let key = *key;
487                underlying
488                    .load_instantiated(
489                        &(move |p| {
490                            if let ServerPort::Merge(mut mapping) = p {
491                                select(mapping.remove(key))
492                            } else {
493                                panic!("Expected a merge connection definition")
494                            }
495                        }),
496                    )
497                    .await
498            }
499
500            ServerConfig::Tagged(underlying, id) => {
501                ServerPort::Tagged(Box::new(underlying.load_instantiated(select).await), *id)
502            }
503
504            ServerConfig::TaggedUnwrap(underlying) => {
505                let loaded = underlying.load_instantiated(select).await;
506                if let ServerPort::Tagged(underlying, _) = loaded {
507                    *underlying
508                } else {
509                    panic!("Expected a tagged connection definition")
510                }
511            }
512
513            ServerConfig::Null => ServerPort::Null,
514        }
515    }
516}