hydro_deploy/rust_crate/
ports.rs

1use std::any::Any;
2use std::collections::HashMap;
3use std::fmt::Debug;
4use std::ops::Deref;
5use std::sync::{Arc, Weak};
6
7use anyhow::{Result, bail};
8use async_recursion::async_recursion;
9use async_trait::async_trait;
10use dyn_clone::DynClone;
11use hydro_deploy_integration::ServerPort;
12use tokio::sync::RwLock;
13
14use super::RustCrateService;
15use crate::{ClientStrategy, Host, LaunchedHost, PortNetworkHint, ServerStrategy};
16
17pub trait RustCrateSource: Send + Sync {
18    fn source_path(&self) -> SourcePath;
19    fn record_server_config(&self, config: ServerConfig);
20
21    fn host(&self) -> Arc<dyn Host>;
22    fn server(&self) -> Arc<dyn RustCrateServer>;
23    fn record_server_strategy(&self, config: ServerStrategy);
24
25    fn wrap_reverse_server_config(&self, config: ServerConfig) -> ServerConfig {
26        config
27    }
28
29    fn send_to(&self, sink: &dyn RustCrateSink) {
30        let forward_res = sink.instantiate(&self.source_path());
31        if let Ok(instantiated) = forward_res {
32            self.record_server_config(instantiated());
33        } else {
34            drop(forward_res);
35            let instantiated = sink
36                .instantiate_reverse(&self.host(), self.server(), &|p| {
37                    self.wrap_reverse_server_config(p)
38                })
39                .unwrap();
40            self.record_server_strategy(instantiated(sink));
41        }
42    }
43}
44
/// The server side of a connection: reports the port clients should connect to and
/// the launched host it runs on.
pub trait RustCrateServer: DynClone + Debug + Send + Sync {
    /// Returns the port clients should connect to.
    fn get_port(&self) -> ServerPort;
    /// Returns the launched host this server runs on; used when a port must be
    /// forwarded (see `ServerConfig::Forwarded`).
    fn launched_host(&self) -> Arc<dyn LaunchedHost>;
}
49
/// Deferred step returned by `RustCrateSink::instantiate_reverse`: called with the sink
/// itself (as `&dyn Any`), it performs the recorded mutations and returns the
/// `ServerStrategy` the source records via `record_server_strategy`.
pub type ReverseSinkInstantiator = Box<dyn FnOnce(&dyn Any) -> ServerStrategy>;
51
/// The receiving side of a connection between deployed Rust crates.
///
/// `RustCrateSource::send_to` asks a sink to connect in the forward direction first
/// (`instantiate`); if that returns `Err`, it falls back to `instantiate_reverse`.
pub trait RustCrateSink: Any + Send + Sync {
    /// Instantiate the sink as the source host connecting to the sink host.
    /// Returns a thunk that can be called to perform mutations that instantiate the sink.
    fn instantiate(&self, client_path: &SourcePath) -> Result<Box<dyn FnOnce() -> ServerConfig>>;

    /// Instantiate the sink, but as the sink host connecting to the source host.
    /// Returns a thunk that can be called to perform mutations that instantiate the sink.
    /// The thunk is called with this sink upcast to `&dyn Any` (a shared reference, not a
    /// mutable one), so implementations can downcast it back to their concrete type.
    ///
    /// `wrap_client_port` is applied to the client-side config before it is recorded.
    fn instantiate_reverse(
        &self,
        server_host: &Arc<dyn Host>,
        server_sink: Arc<dyn RustCrateServer>,
        wrap_client_port: &dyn Fn(ServerConfig) -> ServerConfig,
    ) -> Result<ReverseSinkInstantiator>;
}
66
/// A source that wraps another source and labels its connections with `tag`
/// (see `SourcePath::Tagged` and `ServerConfig::Tagged`).
pub struct TaggedSource {
    /// The wrapped source; host, server and recording calls are delegated to it.
    pub source: Arc<dyn RustCrateSource>,
    /// Label attached to connections from this source.
    pub tag: u32,
}
71
72impl RustCrateSource for TaggedSource {
73    fn source_path(&self) -> SourcePath {
74        SourcePath::Tagged(Box::new(self.source.source_path()), self.tag)
75    }
76
77    fn record_server_config(&self, config: ServerConfig) {
78        self.source.record_server_config(config);
79    }
80
81    fn host(&self) -> Arc<dyn Host> {
82        self.source.host()
83    }
84
85    fn server(&self) -> Arc<dyn RustCrateServer> {
86        self.source.server()
87    }
88
89    fn wrap_reverse_server_config(&self, config: ServerConfig) -> ServerConfig {
90        ServerConfig::Tagged(Box::new(config), self.tag)
91    }
92
93    fn record_server_strategy(&self, config: ServerStrategy) {
94        self.source.record_server_strategy(config);
95    }
96}
97
/// A placeholder that is both a source and a sink but never really connects: it plans
/// `SourcePath::Null` / `ServerConfig::Null` / `ServerStrategy::Null`, discards recorded
/// configs, and panics if asked for its host or server.
pub struct NullSourceSink;
99
100impl RustCrateSource for NullSourceSink {
101    fn source_path(&self) -> SourcePath {
102        SourcePath::Null
103    }
104
105    fn host(&self) -> Arc<dyn Host> {
106        panic!("null source has no host")
107    }
108
109    fn server(&self) -> Arc<dyn RustCrateServer> {
110        panic!("null source has no server")
111    }
112
113    fn record_server_config(&self, _config: ServerConfig) {}
114    fn record_server_strategy(&self, _config: ServerStrategy) {}
115}
116
117impl RustCrateSink for NullSourceSink {
118    fn instantiate(&self, _client_path: &SourcePath) -> Result<Box<dyn FnOnce() -> ServerConfig>> {
119        Ok(Box::new(|| ServerConfig::Null))
120    }
121
122    fn instantiate_reverse(
123        &self,
124        _server_host: &Arc<dyn Host>,
125        _server_sink: Arc<dyn RustCrateServer>,
126        _wrap_client_port: &dyn Fn(ServerConfig) -> ServerConfig,
127    ) -> Result<ReverseSinkInstantiator> {
128        Ok(Box::new(|_| ServerStrategy::Null))
129    }
130}
131
/// A sink that fans out to several sinks keyed by a `u32`. Forward instantiation yields a
/// `ServerConfig::Demux`; reverse instantiation yields a `ServerStrategy::Demux`.
pub struct DemuxSink {
    /// Target sink for each demux key.
    pub demux: HashMap<u32, Arc<dyn RustCrateSink>>,
}
135
136impl RustCrateSink for DemuxSink {
137    fn instantiate(&self, client_host: &SourcePath) -> Result<Box<dyn FnOnce() -> ServerConfig>> {
138        let mut thunk_map = HashMap::new();
139        for (key, target) in &self.demux {
140            thunk_map.insert(*key, target.instantiate(client_host)?);
141        }
142
143        Ok(Box::new(move || {
144            let instantiated_map = thunk_map
145                .into_iter()
146                .map(|(key, thunk)| (key, thunk()))
147                .collect();
148
149            ServerConfig::Demux(instantiated_map)
150        }))
151    }
152
153    fn instantiate_reverse(
154        &self,
155        server_host: &Arc<dyn Host>,
156        server_sink: Arc<dyn RustCrateServer>,
157        wrap_client_port: &dyn Fn(ServerConfig) -> ServerConfig,
158    ) -> Result<ReverseSinkInstantiator> {
159        let mut thunk_map = HashMap::new();
160        for (key, target) in &self.demux {
161            thunk_map.insert(
162                *key,
163                target.instantiate_reverse(
164                    server_host,
165                    server_sink.clone(),
166                    // the parent wrapper selects the demux port for the parent defn, so do that first
167                    &|p| ServerConfig::DemuxSelect(Box::new(wrap_client_port(p)), *key),
168                )?,
169            );
170        }
171
172        Ok(Box::new(move |me| {
173            let me = me.downcast_ref::<DemuxSink>().unwrap();
174            let instantiated_map = thunk_map
175                .into_iter()
176                .map(|(key, thunk)| (key, (thunk)(me.demux.get(&key).unwrap())))
177                .collect();
178
179            ServerStrategy::Demux(instantiated_map)
180        }))
181    }
182}
183
/// A named port on a `RustCrateService`. It acts as a source, as a sink, and as a server
/// handle (`RustCrateServer`).
#[derive(Clone, Debug)]
pub struct RustCratePortConfig {
    /// The owning service. It is held weakly and every operation upgrades it with
    /// `unwrap`, so the service must still be alive when those operations run.
    pub service: Weak<RwLock<RustCrateService>>,
    /// Host the service runs on; returned by `host()` and used by `launched_host()`.
    pub service_host: Arc<dyn Host>,
    /// Shared map from port name to the `ServerPort` the service bound; read by `get_port()`.
    pub service_server_defns: Arc<RwLock<HashMap<String, ServerPort>>>,
    /// Hint passed to `Host::strategy_as_server` when planning; reverse instantiation
    /// requires `PortNetworkHint::Auto`.
    pub network_hint: PortNetworkHint,
    /// Name of the port on the service.
    pub port: String,
    /// When true, multiple connections to this port are combined into a `Merge` instead of
    /// the port accepting a single connection.
    pub merge: bool,
}
193
194impl RustCratePortConfig {
195    pub fn merge(&self) -> Self {
196        Self {
197            service: self.service.clone(),
198            service_host: self.service_host.clone(),
199            service_server_defns: self.service_server_defns.clone(),
200            network_hint: self.network_hint,
201            port: self.port.clone(),
202            merge: true,
203        }
204    }
205}
206
207impl RustCrateSource for RustCratePortConfig {
208    fn source_path(&self) -> SourcePath {
209        SourcePath::Direct(
210            self.service
211                .upgrade()
212                .unwrap()
213                .try_read()
214                .unwrap()
215                .on
216                .clone(),
217        )
218    }
219
220    fn host(&self) -> Arc<dyn Host> {
221        self.service_host.clone()
222    }
223
224    fn server(&self) -> Arc<dyn RustCrateServer> {
225        let from = self.service.upgrade().unwrap();
226        let from_read = from.try_read().unwrap();
227
228        Arc::new(RustCratePortConfig {
229            service: Arc::downgrade(&from),
230            service_host: from_read.on.clone(),
231            service_server_defns: from_read.server_defns.clone(),
232            network_hint: self.network_hint,
233            port: self.port.clone(),
234            merge: false,
235        })
236    }
237
238    fn record_server_config(&self, config: ServerConfig) {
239        let from = self.service.upgrade().unwrap();
240        let mut from_write = from.try_write().unwrap();
241
242        // TODO(shadaj): if already in this map, we want to broadcast
243        assert!(
244            !from_write.port_to_server.contains_key(&self.port),
245            "The port configuration is incorrect, for example, are you using a ConnectedDirect instead of a ConnectedDemux?"
246        );
247        from_write.port_to_server.insert(self.port.clone(), config);
248    }
249
250    fn record_server_strategy(&self, config: ServerStrategy) {
251        let from = self.service.upgrade().unwrap();
252        let mut from_write = from.try_write().unwrap();
253
254        assert!(!from_write.port_to_bind.contains_key(&self.port));
255        from_write.port_to_bind.insert(self.port.clone(), config);
256    }
257}
258
259#[async_trait]
260impl RustCrateServer for RustCratePortConfig {
261    fn get_port(&self) -> ServerPort {
262        // we are in `deployment.start()`, so no one should be writing
263        let server_defns = self.service_server_defns.try_read().unwrap();
264        server_defns.get(&self.port).unwrap().clone()
265    }
266
267    fn launched_host(&self) -> Arc<dyn LaunchedHost> {
268        self.service_host.launched().unwrap()
269    }
270}
271
/// Describes where a connection originates; sinks use it to plan the server side
/// (see `SourcePath::plan`).
pub enum SourcePath {
    /// No source; planning yields `ServerStrategy::Null` and `ServerConfig::Null`.
    Null,
    /// A source on the given host; its bind strategy is wrapped in `ServerStrategy::Direct`.
    Direct(Arc<dyn Host>),
    /// A source on the given host whose bind strategy is wrapped in `ServerStrategy::Many`.
    /// NOTE(review): not constructed in this file; presumably for multi-connection
    /// sources — confirm against callers.
    Many(Arc<dyn Host>),
    /// The wrapped path labelled with a tag (produced by `TaggedSource`).
    Tagged(Box<SourcePath>, u32),
}
278
279impl SourcePath {
280    #[expect(
281        clippy::type_complexity,
282        reason = "internals (dyn Fn to defer instantiation)"
283    )]
284    fn plan<T: RustCrateServer + Clone + 'static>(
285        &self,
286        server: &T,
287        server_host: &dyn Host,
288        network_hint: PortNetworkHint,
289    ) -> Result<(Box<dyn FnOnce(&dyn Any) -> ServerStrategy>, ServerConfig)> {
290        match self {
291            SourcePath::Direct(client_host) => {
292                let (conn_type, bind_type) =
293                    server_host.strategy_as_server(client_host.deref(), network_hint)?;
294                let base_config = ServerConfig::from_strategy(&conn_type, Arc::new(server.clone()));
295                Ok((
296                    Box::new(|host| ServerStrategy::Direct(bind_type(host))),
297                    base_config,
298                ))
299            }
300
301            SourcePath::Many(client_host) => {
302                let (conn_type, bind_type) =
303                    server_host.strategy_as_server(client_host.deref(), network_hint)?;
304                let base_config = ServerConfig::from_strategy(&conn_type, Arc::new(server.clone()));
305                Ok((
306                    Box::new(|host| ServerStrategy::Many(bind_type(host))),
307                    base_config,
308                ))
309            }
310
311            SourcePath::Tagged(underlying, tag) => {
312                let (bind_type, base_config) =
313                    underlying.plan(server, server_host, network_hint)?;
314                let tag = *tag;
315                Ok((
316                    Box::new(move |host| ServerStrategy::Tagged(Box::new(bind_type(host)), tag)),
317                    ServerConfig::TaggedUnwrap(Box::new(base_config)),
318                ))
319            }
320
321            SourcePath::Null => Ok((Box::new(|_| ServerStrategy::Null), ServerConfig::Null)),
322        }
323    }
324}
325
326impl RustCrateSink for RustCratePortConfig {
327    fn instantiate(&self, client_path: &SourcePath) -> Result<Box<dyn FnOnce() -> ServerConfig>> {
328        let server = self.service.upgrade().unwrap();
329        let server_read = server.try_read().unwrap();
330
331        let server_host = server_read.on.clone();
332
333        let (bind_type, base_config) =
334            client_path.plan(self, server_host.deref(), self.network_hint)?;
335
336        let server = server.clone();
337        let merge = self.merge;
338        let port = self.port.clone();
339        Ok(Box::new(move || {
340            let mut server_write = server.try_write().unwrap();
341            let bind_type = (bind_type)(&*server_write.on);
342
343            if merge {
344                let merge_config = server_write
345                    .port_to_bind
346                    .entry(port.clone())
347                    .or_insert(ServerStrategy::Merge(vec![]));
348                let merge_index = if let ServerStrategy::Merge(merge) = merge_config {
349                    merge.push(bind_type);
350                    merge.len() - 1
351                } else {
352                    panic!("Expected a merge connection definition")
353                };
354
355                ServerConfig::MergeSelect(Box::new(base_config), merge_index)
356            } else {
357                assert!(!server_write.port_to_bind.contains_key(&port));
358                server_write.port_to_bind.insert(port.clone(), bind_type);
359                base_config
360            }
361        }))
362    }
363
364    fn instantiate_reverse(
365        &self,
366        server_host: &Arc<dyn Host>,
367        server_sink: Arc<dyn RustCrateServer>,
368        wrap_client_port: &dyn Fn(ServerConfig) -> ServerConfig,
369    ) -> Result<ReverseSinkInstantiator> {
370        if !matches!(self.network_hint, PortNetworkHint::Auto) {
371            bail!("Trying to form collection where I am the client, but I have server hint")
372        }
373
374        let client = self.service.upgrade().unwrap();
375        let client_read = client.try_read().unwrap();
376
377        let server_host = server_host.clone();
378
379        let (conn_type, bind_type) =
380            server_host.strategy_as_server(client_read.on.deref(), PortNetworkHint::Auto)?;
381        let client_port = wrap_client_port(ServerConfig::from_strategy(&conn_type, server_sink));
382
383        let client = client.clone();
384        let merge = self.merge;
385        let port = self.port.clone();
386        Ok(Box::new(move |_| {
387            let mut client_write = client.try_write().unwrap();
388
389            if merge {
390                let merge_config = client_write
391                    .port_to_server
392                    .entry(port.clone())
393                    .or_insert(ServerConfig::Merge(vec![]));
394
395                if let ServerConfig::Merge(merge) = merge_config {
396                    merge.push(client_port);
397                } else {
398                    panic!()
399                };
400            } else {
401                assert!(!client_write.port_to_server.contains_key(&port));
402                client_write
403                    .port_to_server
404                    .insert(port.clone(), client_port);
405            };
406
407            ServerStrategy::Direct((bind_type)(&*client_write.on))
408        }))
409    }
410}
411
/// How a client reaches its server(s); resolved into a concrete `ServerPort` by
/// `ServerConfig::load_instantiated`.
#[derive(Clone, Debug)]
pub enum ServerConfig {
    /// Connect to the port the server reports via `get_port`.
    Direct(Arc<dyn RustCrateServer>),
    /// Like `Direct`, but the selected port is forwarded through the server's launched host.
    Forwarded(Arc<dyn RustCrateServer>),
    /// A demux that will be used at runtime to listen to many connections.
    Demux(HashMap<u32, ServerConfig>),
    /// The other side of a demux, with the key used to extract the appropriate connection.
    DemuxSelect(Box<ServerConfig>, u32),
    /// A merge that will be used at runtime to combine many connections.
    Merge(Vec<ServerConfig>),
    /// The other side of a merge, with the index used to extract the appropriate connection.
    MergeSelect(Box<ServerConfig>, usize),
    /// Wraps the loaded port in `ServerPort::Tagged` with the given tag.
    Tagged(Box<ServerConfig>, u32),
    /// Expects the loaded port to be `ServerPort::Tagged` and strips the tag (panics otherwise).
    TaggedUnwrap(Box<ServerConfig>),
    /// No connection; loads as `ServerPort::Null`.
    Null,
}
428
429impl ServerConfig {
430    pub fn from_strategy(
431        strategy: &ClientStrategy,
432        server: Arc<dyn RustCrateServer>,
433    ) -> ServerConfig {
434        match strategy {
435            ClientStrategy::UnixSocket(_) | ClientStrategy::InternalTcpPort(_) => {
436                ServerConfig::Direct(server)
437            }
438            ClientStrategy::ForwardedTcpPort(_) => ServerConfig::Forwarded(server),
439        }
440    }
441}
442
443#[async_recursion]
444async fn forward_connection(conn: &ServerPort, target: &dyn LaunchedHost) -> ServerPort {
445    match conn {
446        ServerPort::UnixSocket(_) => panic!("Expected a TCP port to be forwarded"),
447        ServerPort::TcpPort(addr) => ServerPort::TcpPort(target.forward_port(addr).await.unwrap()),
448        ServerPort::Demux(demux) => {
449            let mut forwarded_map = HashMap::new();
450            for (key, conn) in demux {
451                forwarded_map.insert(*key, forward_connection(conn, target).await);
452            }
453            ServerPort::Demux(forwarded_map)
454        }
455        ServerPort::Merge(merge) => {
456            let mut forwarded_vec = Vec::new();
457            for conn in merge {
458                forwarded_vec.push(forward_connection(conn, target).await);
459            }
460            ServerPort::Merge(forwarded_vec)
461        }
462        ServerPort::Tagged(underlying, id) => {
463            ServerPort::Tagged(Box::new(forward_connection(underlying, target).await), *id)
464        }
465        ServerPort::Null => ServerPort::Null,
466    }
467}
468
469impl ServerConfig {
470    #[async_recursion]
471    pub async fn load_instantiated(
472        &self,
473        select: &(dyn Fn(ServerPort) -> ServerPort + Send + Sync),
474    ) -> ServerPort {
475        match self {
476            ServerConfig::Direct(server) => select(server.get_port()),
477
478            ServerConfig::Forwarded(server) => {
479                let selected = select(server.get_port());
480                forward_connection(&selected, server.launched_host().as_ref()).await
481            }
482
483            ServerConfig::Demux(demux) => {
484                let mut demux_map = HashMap::new();
485                for (key, conn) in demux {
486                    demux_map.insert(*key, conn.load_instantiated(select).await);
487                }
488                ServerPort::Demux(demux_map)
489            }
490
491            ServerConfig::DemuxSelect(underlying, key) => {
492                let key = *key;
493                underlying
494                    .load_instantiated(
495                        &(move |p| {
496                            if let ServerPort::Demux(mut mapping) = p {
497                                select(mapping.remove(&key).unwrap())
498                            } else {
499                                panic!("Expected a demux connection definition")
500                            }
501                        }),
502                    )
503                    .await
504            }
505
506            ServerConfig::Merge(merge) => {
507                let mut merge_vec = Vec::new();
508                for conn in merge {
509                    merge_vec.push(conn.load_instantiated(select).await);
510                }
511                ServerPort::Merge(merge_vec)
512            }
513
514            ServerConfig::MergeSelect(underlying, key) => {
515                let key = *key;
516                underlying
517                    .load_instantiated(
518                        &(move |p| {
519                            if let ServerPort::Merge(mut mapping) = p {
520                                select(mapping.remove(key))
521                            } else {
522                                panic!("Expected a merge connection definition")
523                            }
524                        }),
525                    )
526                    .await
527            }
528
529            ServerConfig::Tagged(underlying, id) => {
530                ServerPort::Tagged(Box::new(underlying.load_instantiated(select).await), *id)
531            }
532
533            ServerConfig::TaggedUnwrap(underlying) => {
534                let loaded = underlying.load_instantiated(select).await;
535                if let ServerPort::Tagged(underlying, _) = loaded {
536                    *underlying
537                } else {
538                    panic!("Expected a tagged connection definition")
539                }
540            }
541
542            ServerConfig::Null => ServerPort::Null,
543        }
544    }
545}