// hydro_deploy/rust_crate/ports.rs

1use std::any::Any;
2use std::collections::HashMap;
3use std::ops::Deref;
4use std::sync::{Arc, Weak};
5
6use anyhow::Result;
7use async_recursion::async_recursion;
8use async_trait::async_trait;
9use dyn_clone::DynClone;
10use hydro_deploy_integration::ServerPort;
11use tokio::sync::RwLock;
12
13use super::RustCrateService;
14use crate::{ClientStrategy, Host, HostStrategyGetter, LaunchedHost, ServerStrategy};
15
16pub trait RustCrateSource: Send + Sync {
17    fn source_path(&self) -> SourcePath;
18    fn record_server_config(&self, config: ServerConfig);
19
20    fn host(&self) -> Arc<dyn Host>;
21    fn server(&self) -> Arc<dyn RustCrateServer>;
22    fn record_server_strategy(&self, config: ServerStrategy);
23
24    fn wrap_reverse_server_config(&self, config: ServerConfig) -> ServerConfig {
25        config
26    }
27
28    fn send_to(&self, sink: &dyn RustCrateSink) {
29        let forward_res = sink.instantiate(&self.source_path());
30        if let Ok(instantiated) = forward_res {
31            self.record_server_config(instantiated());
32        } else {
33            drop(forward_res);
34            let instantiated = sink
35                .instantiate_reverse(&self.host(), self.server(), &|p| {
36                    self.wrap_reverse_server_config(p)
37                })
38                .unwrap();
39            self.record_server_strategy(instantiated(sink.as_any()));
40        }
41    }
42}
43
44pub trait RustCrateServer: DynClone + Send + Sync {
45    fn get_port(&self) -> ServerPort;
46    fn launched_host(&self) -> Arc<dyn LaunchedHost>;
47}
48
49pub type ReverseSinkInstantiator = Box<dyn FnOnce(&dyn Any) -> ServerStrategy>;
50
51pub trait RustCrateSink: Send + Sync {
52    fn as_any(&self) -> &dyn Any;
53
54    /// Instantiate the sink as the source host connecting to the sink host.
55    /// Returns a thunk that can be called to perform mutations that instantiate the sink.
56    fn instantiate(&self, client_path: &SourcePath) -> Result<Box<dyn FnOnce() -> ServerConfig>>;
57
58    /// Instantiate the sink, but as the sink host connecting to the source host.
59    /// Returns a thunk that can be called to perform mutations that instantiate the sink, taking a mutable reference to this sink.
60    fn instantiate_reverse(
61        &self,
62        server_host: &Arc<dyn Host>,
63        server_sink: Arc<dyn RustCrateServer>,
64        wrap_client_port: &dyn Fn(ServerConfig) -> ServerConfig,
65    ) -> Result<ReverseSinkInstantiator>;
66}
67
68pub struct TaggedSource {
69    pub source: Arc<dyn RustCrateSource>,
70    pub tag: u32,
71}
72
73impl RustCrateSource for TaggedSource {
74    fn source_path(&self) -> SourcePath {
75        SourcePath::Tagged(Box::new(self.source.source_path()), self.tag)
76    }
77
78    fn record_server_config(&self, config: ServerConfig) {
79        self.source.record_server_config(config);
80    }
81
82    fn host(&self) -> Arc<dyn Host> {
83        self.source.host()
84    }
85
86    fn server(&self) -> Arc<dyn RustCrateServer> {
87        self.source.server()
88    }
89
90    fn wrap_reverse_server_config(&self, config: ServerConfig) -> ServerConfig {
91        ServerConfig::Tagged(Box::new(config), self.tag)
92    }
93
94    fn record_server_strategy(&self, config: ServerStrategy) {
95        self.source.record_server_strategy(config);
96    }
97}
98
/// A placeholder endpoint usable as both a source and a sink. It records nothing and always
/// instantiates to the `Null` config / strategy.
pub struct NullSourceSink;
100
101impl RustCrateSource for NullSourceSink {
102    fn source_path(&self) -> SourcePath {
103        SourcePath::Null
104    }
105
106    fn host(&self) -> Arc<dyn Host> {
107        panic!("null source has no host")
108    }
109
110    fn server(&self) -> Arc<dyn RustCrateServer> {
111        panic!("null source has no server")
112    }
113
114    fn record_server_config(&self, _config: ServerConfig) {}
115    fn record_server_strategy(&self, _config: ServerStrategy) {}
116}
117
118impl RustCrateSink for NullSourceSink {
119    fn as_any(&self) -> &dyn Any {
120        self
121    }
122
123    fn instantiate(&self, _client_path: &SourcePath) -> Result<Box<dyn FnOnce() -> ServerConfig>> {
124        Ok(Box::new(|| ServerConfig::Null))
125    }
126
127    fn instantiate_reverse(
128        &self,
129        _server_host: &Arc<dyn Host>,
130        _server_sink: Arc<dyn RustCrateServer>,
131        _wrap_client_port: &dyn Fn(ServerConfig) -> ServerConfig,
132    ) -> Result<ReverseSinkInstantiator> {
133        Ok(Box::new(|_| ServerStrategy::Null))
134    }
135}
136
137pub struct DemuxSink {
138    pub demux: HashMap<u32, Arc<dyn RustCrateSink>>,
139}
140
141impl RustCrateSink for DemuxSink {
142    fn as_any(&self) -> &dyn Any {
143        self
144    }
145
146    fn instantiate(&self, client_host: &SourcePath) -> Result<Box<dyn FnOnce() -> ServerConfig>> {
147        let mut thunk_map = HashMap::new();
148        for (key, target) in &self.demux {
149            thunk_map.insert(*key, target.instantiate(client_host)?);
150        }
151
152        Ok(Box::new(move || {
153            let instantiated_map = thunk_map
154                .into_iter()
155                .map(|(key, thunk)| (key, thunk()))
156                .collect();
157
158            ServerConfig::Demux(instantiated_map)
159        }))
160    }
161
162    fn instantiate_reverse(
163        &self,
164        server_host: &Arc<dyn Host>,
165        server_sink: Arc<dyn RustCrateServer>,
166        wrap_client_port: &dyn Fn(ServerConfig) -> ServerConfig,
167    ) -> Result<ReverseSinkInstantiator> {
168        let mut thunk_map = HashMap::new();
169        for (key, target) in &self.demux {
170            thunk_map.insert(
171                *key,
172                target.instantiate_reverse(
173                    server_host,
174                    server_sink.clone(),
175                    // the parent wrapper selects the demux port for the parent defn, so do that first
176                    &|p| ServerConfig::DemuxSelect(Box::new(wrap_client_port(p)), *key),
177                )?,
178            );
179        }
180
181        Ok(Box::new(move |me| {
182            let me = me.downcast_ref::<DemuxSink>().unwrap();
183            let instantiated_map = thunk_map
184                .into_iter()
185                .map(|(key, thunk)| (key, thunk(me.demux.get(&key).unwrap().as_any())))
186                .collect();
187
188            ServerStrategy::Demux(instantiated_map)
189        }))
190    }
191}
192
193#[derive(Clone)]
194pub struct RustCratePortConfig {
195    pub service: Weak<RwLock<RustCrateService>>,
196    pub service_host: Arc<dyn Host>,
197    pub service_server_defns: Arc<RwLock<HashMap<String, ServerPort>>>,
198    pub port: String,
199    pub merge: bool,
200}
201
202impl RustCratePortConfig {
203    pub fn merge(&self) -> Self {
204        Self {
205            service: self.service.clone(),
206            service_host: self.service_host.clone(),
207            service_server_defns: self.service_server_defns.clone(),
208            port: self.port.clone(),
209            merge: true,
210        }
211    }
212}
213
214impl RustCrateSource for RustCratePortConfig {
215    fn source_path(&self) -> SourcePath {
216        SourcePath::Direct(
217            self.service
218                .upgrade()
219                .unwrap()
220                .try_read()
221                .unwrap()
222                .on
223                .clone(),
224        )
225    }
226
227    fn host(&self) -> Arc<dyn Host> {
228        self.service_host.clone()
229    }
230
231    fn server(&self) -> Arc<dyn RustCrateServer> {
232        let from = self.service.upgrade().unwrap();
233        let from_read = from.try_read().unwrap();
234
235        Arc::new(RustCratePortConfig {
236            service: Arc::downgrade(&from),
237            service_host: from_read.on.clone(),
238            service_server_defns: from_read.server_defns.clone(),
239            port: self.port.clone(),
240            merge: false,
241        })
242    }
243
244    fn record_server_config(&self, config: ServerConfig) {
245        let from = self.service.upgrade().unwrap();
246        let mut from_write = from.try_write().unwrap();
247
248        // TODO(shadaj): if already in this map, we want to broadcast
249        assert!(
250            !from_write.port_to_server.contains_key(&self.port),
251            "The port configuration is incorrect, for example, are you using a ConnectedDirect instead of a ConnectedDemux?"
252        );
253        from_write.port_to_server.insert(self.port.clone(), config);
254    }
255
256    fn record_server_strategy(&self, config: ServerStrategy) {
257        let from = self.service.upgrade().unwrap();
258        let mut from_write = from.try_write().unwrap();
259
260        assert!(!from_write.port_to_bind.contains_key(&self.port));
261        from_write.port_to_bind.insert(self.port.clone(), config);
262    }
263}
264
265#[async_trait]
266impl RustCrateServer for RustCratePortConfig {
267    fn get_port(&self) -> ServerPort {
268        // we are in `deployment.start()`, so no one should be writing
269        let server_defns = self.service_server_defns.try_read().unwrap();
270        server_defns.get(&self.port).unwrap().clone()
271    }
272
273    fn launched_host(&self) -> Arc<dyn LaunchedHost> {
274        self.service_host.launched().unwrap()
275    }
276}
277
278pub enum SourcePath {
279    Null,
280    Direct(Arc<dyn Host>),
281    Tagged(Box<SourcePath>, u32),
282}
283
284impl SourcePath {
285    fn plan<T: RustCrateServer + Clone + 'static>(
286        &self,
287        server: &T,
288        server_host: &dyn Host,
289    ) -> Result<(HostStrategyGetter, ServerConfig)> {
290        match self {
291            SourcePath::Direct(client_host) => {
292                let (conn_type, bind_type) = server_host.strategy_as_server(client_host.deref())?;
293                let base_config = ServerConfig::from_strategy(&conn_type, Arc::new(server.clone()));
294                Ok((bind_type, base_config))
295            }
296
297            SourcePath::Tagged(underlying, tag) => {
298                let (bind_type, base_config) = underlying.plan(server, server_host)?;
299                let tag = *tag;
300                let strategy_getter: HostStrategyGetter =
301                    Box::new(move |host| ServerStrategy::Tagged(Box::new(bind_type(host)), tag));
302                Ok((
303                    strategy_getter,
304                    ServerConfig::TaggedUnwrap(Box::new(base_config)),
305                ))
306            }
307
308            SourcePath::Null => {
309                let strategy_getter: HostStrategyGetter = Box::new(|_| ServerStrategy::Null);
310                Ok((strategy_getter, ServerConfig::Null))
311            }
312        }
313    }
314}
315
316impl RustCrateSink for RustCratePortConfig {
317    fn as_any(&self) -> &dyn Any {
318        self
319    }
320
321    fn instantiate(&self, client_path: &SourcePath) -> Result<Box<dyn FnOnce() -> ServerConfig>> {
322        let server = self.service.upgrade().unwrap();
323        let server_read = server.try_read().unwrap();
324
325        let server_host = server_read.on.clone();
326
327        let (bind_type, base_config) = client_path.plan(self, server_host.deref())?;
328
329        let server = server.clone();
330        let merge = self.merge;
331        let port = self.port.clone();
332        Ok(Box::new(move || {
333            let mut server_write = server.try_write().unwrap();
334            let bind_type = (bind_type)(server_write.on.as_any());
335
336            if merge {
337                let merge_config = server_write
338                    .port_to_bind
339                    .entry(port.clone())
340                    .or_insert(ServerStrategy::Merge(vec![]));
341                let merge_index = if let ServerStrategy::Merge(merge) = merge_config {
342                    merge.push(bind_type);
343                    merge.len() - 1
344                } else {
345                    panic!("Expected a merge connection definition")
346                };
347
348                ServerConfig::MergeSelect(Box::new(base_config), merge_index)
349            } else {
350                assert!(!server_write.port_to_bind.contains_key(&port));
351                server_write.port_to_bind.insert(port.clone(), bind_type);
352                base_config
353            }
354        }))
355    }
356
357    fn instantiate_reverse(
358        &self,
359        server_host: &Arc<dyn Host>,
360        server_sink: Arc<dyn RustCrateServer>,
361        wrap_client_port: &dyn Fn(ServerConfig) -> ServerConfig,
362    ) -> Result<ReverseSinkInstantiator> {
363        let client = self.service.upgrade().unwrap();
364        let client_read = client.try_read().unwrap();
365
366        let server_host = server_host.clone();
367
368        let (conn_type, bind_type) = server_host.strategy_as_server(client_read.on.deref())?;
369        let client_port = wrap_client_port(ServerConfig::from_strategy(&conn_type, server_sink));
370
371        let client = client.clone();
372        let merge = self.merge;
373        let port = self.port.clone();
374        Ok(Box::new(move |_| {
375            let mut client_write = client.try_write().unwrap();
376
377            if merge {
378                let merge_config = client_write
379                    .port_to_server
380                    .entry(port.clone())
381                    .or_insert(ServerConfig::Merge(vec![]));
382
383                if let ServerConfig::Merge(merge) = merge_config {
384                    merge.push(client_port);
385                } else {
386                    panic!()
387                };
388            } else {
389                assert!(!client_write.port_to_server.contains_key(&port));
390                client_write
391                    .port_to_server
392                    .insert(port.clone(), client_port);
393            };
394
395            (bind_type)(client_write.on.as_any())
396        }))
397    }
398}
399
400#[derive(Clone)]
401pub enum ServerConfig {
402    Direct(Arc<dyn RustCrateServer>),
403    Forwarded(Arc<dyn RustCrateServer>),
404    /// A demux that will be used at runtime to listen to many connections.
405    Demux(HashMap<u32, ServerConfig>),
406    /// The other side of a demux, with a port to extract the appropriate connection.
407    DemuxSelect(Box<ServerConfig>, u32),
408    /// A merge that will be used at runtime to combine many connections.
409    Merge(Vec<ServerConfig>),
410    /// The other side of a merge, with a port to extract the appropriate connection.
411    MergeSelect(Box<ServerConfig>, usize),
412    Tagged(Box<ServerConfig>, u32),
413    TaggedUnwrap(Box<ServerConfig>),
414    Null,
415}
416
417impl ServerConfig {
418    pub fn from_strategy(
419        strategy: &ClientStrategy,
420        server: Arc<dyn RustCrateServer>,
421    ) -> ServerConfig {
422        match strategy {
423            ClientStrategy::UnixSocket(_) | ClientStrategy::InternalTcpPort(_) => {
424                ServerConfig::Direct(server)
425            }
426            ClientStrategy::ForwardedTcpPort(_) => ServerConfig::Forwarded(server),
427        }
428    }
429}
430
431#[async_recursion]
432async fn forward_connection(conn: &ServerPort, target: &dyn LaunchedHost) -> ServerPort {
433    match conn {
434        ServerPort::UnixSocket(_) => panic!("Expected a TCP port to be forwarded"),
435        ServerPort::TcpPort(addr) => ServerPort::TcpPort(target.forward_port(addr).await.unwrap()),
436        ServerPort::Demux(demux) => {
437            let mut forwarded_map = HashMap::new();
438            for (key, conn) in demux {
439                forwarded_map.insert(*key, forward_connection(conn, target).await);
440            }
441            ServerPort::Demux(forwarded_map)
442        }
443        ServerPort::Merge(merge) => {
444            let mut forwarded_vec = Vec::new();
445            for conn in merge {
446                forwarded_vec.push(forward_connection(conn, target).await);
447            }
448            ServerPort::Merge(forwarded_vec)
449        }
450        ServerPort::Tagged(underlying, id) => {
451            ServerPort::Tagged(Box::new(forward_connection(underlying, target).await), *id)
452        }
453        ServerPort::Null => ServerPort::Null,
454    }
455}
456
457impl ServerConfig {
458    #[async_recursion]
459    pub async fn load_instantiated(
460        &self,
461        select: &(dyn Fn(ServerPort) -> ServerPort + Send + Sync),
462    ) -> ServerPort {
463        match self {
464            ServerConfig::Direct(server) => select(server.get_port()),
465
466            ServerConfig::Forwarded(server) => {
467                let selected = select(server.get_port());
468                forward_connection(&selected, server.launched_host().as_ref()).await
469            }
470
471            ServerConfig::Demux(demux) => {
472                let mut demux_map = HashMap::new();
473                for (key, conn) in demux {
474                    demux_map.insert(*key, conn.load_instantiated(select).await);
475                }
476                ServerPort::Demux(demux_map)
477            }
478
479            ServerConfig::DemuxSelect(underlying, key) => {
480                let key = *key;
481                underlying
482                    .load_instantiated(
483                        &(move |p| {
484                            if let ServerPort::Demux(mut mapping) = p {
485                                select(mapping.remove(&key).unwrap())
486                            } else {
487                                panic!("Expected a demux connection definition")
488                            }
489                        }),
490                    )
491                    .await
492            }
493
494            ServerConfig::Merge(merge) => {
495                let mut merge_vec = Vec::new();
496                for conn in merge {
497                    merge_vec.push(conn.load_instantiated(select).await);
498                }
499                ServerPort::Merge(merge_vec)
500            }
501
502            ServerConfig::MergeSelect(underlying, key) => {
503                let key = *key;
504                underlying
505                    .load_instantiated(
506                        &(move |p| {
507                            if let ServerPort::Merge(mut mapping) = p {
508                                select(mapping.remove(key))
509                            } else {
510                                panic!("Expected a merge connection definition")
511                            }
512                        }),
513                    )
514                    .await
515            }
516
517            ServerConfig::Tagged(underlying, id) => {
518                ServerPort::Tagged(Box::new(underlying.load_instantiated(select).await), *id)
519            }
520
521            ServerConfig::TaggedUnwrap(underlying) => {
522                let loaded = underlying.load_instantiated(select).await;
523                if let ServerPort::Tagged(underlying, _) = loaded {
524                    *underlying
525                } else {
526                    panic!("Expected a tagged connection definition")
527                }
528            }
529
530            ServerConfig::Null => ServerPort::Null,
531        }
532    }
533}