1use std::any::Any;
2use std::collections::HashMap;
3use std::ops::Deref;
4use std::sync::{Arc, Weak};
5
6use anyhow::Result;
7use async_recursion::async_recursion;
8use async_trait::async_trait;
9use dyn_clone::DynClone;
10use hydro_deploy_integration::ServerPort;
11use tokio::sync::RwLock;
12
13use super::RustCrateService;
14use crate::{ClientStrategy, Host, HostStrategyGetter, LaunchedHost, ServerStrategy};
15
16pub trait RustCrateSource: Send + Sync {
17 fn source_path(&self) -> SourcePath;
18 fn record_server_config(&self, config: ServerConfig);
19
20 fn host(&self) -> Arc<dyn Host>;
21 fn server(&self) -> Arc<dyn RustCrateServer>;
22 fn record_server_strategy(&self, config: ServerStrategy);
23
24 fn wrap_reverse_server_config(&self, config: ServerConfig) -> ServerConfig {
25 config
26 }
27
28 fn send_to(&self, sink: &dyn RustCrateSink) {
29 let forward_res = sink.instantiate(&self.source_path());
30 if let Ok(instantiated) = forward_res {
31 self.record_server_config(instantiated());
32 } else {
33 drop(forward_res);
34 let instantiated = sink
35 .instantiate_reverse(&self.host(), self.server(), &|p| {
36 self.wrap_reverse_server_config(p)
37 })
38 .unwrap();
39 self.record_server_strategy(instantiated(sink.as_any()));
40 }
41 }
42}
43
44pub trait RustCrateServer: DynClone + Send + Sync {
45 fn get_port(&self) -> ServerPort;
46 fn launched_host(&self) -> Arc<dyn LaunchedHost>;
47}
48
49pub type ReverseSinkInstantiator = Box<dyn FnOnce(&dyn Any) -> ServerStrategy>;
50
51pub trait RustCrateSink: Send + Sync {
52 fn as_any(&self) -> &dyn Any;
53
54 fn instantiate(&self, client_path: &SourcePath) -> Result<Box<dyn FnOnce() -> ServerConfig>>;
57
58 fn instantiate_reverse(
61 &self,
62 server_host: &Arc<dyn Host>,
63 server_sink: Arc<dyn RustCrateServer>,
64 wrap_client_port: &dyn Fn(ServerConfig) -> ServerConfig,
65 ) -> Result<ReverseSinkInstantiator>;
66}
67
68pub struct TaggedSource {
69 pub source: Arc<dyn RustCrateSource>,
70 pub tag: u32,
71}
72
73impl RustCrateSource for TaggedSource {
74 fn source_path(&self) -> SourcePath {
75 SourcePath::Tagged(Box::new(self.source.source_path()), self.tag)
76 }
77
78 fn record_server_config(&self, config: ServerConfig) {
79 self.source.record_server_config(config);
80 }
81
82 fn host(&self) -> Arc<dyn Host> {
83 self.source.host()
84 }
85
86 fn server(&self) -> Arc<dyn RustCrateServer> {
87 self.source.server()
88 }
89
90 fn wrap_reverse_server_config(&self, config: ServerConfig) -> ServerConfig {
91 ServerConfig::Tagged(Box::new(config), self.tag)
92 }
93
94 fn record_server_strategy(&self, config: ServerStrategy) {
95 self.source.record_server_strategy(config);
96 }
97}
98
99pub struct NullSourceSink;
100
101impl RustCrateSource for NullSourceSink {
102 fn source_path(&self) -> SourcePath {
103 SourcePath::Null
104 }
105
106 fn host(&self) -> Arc<dyn Host> {
107 panic!("null source has no host")
108 }
109
110 fn server(&self) -> Arc<dyn RustCrateServer> {
111 panic!("null source has no server")
112 }
113
114 fn record_server_config(&self, _config: ServerConfig) {}
115 fn record_server_strategy(&self, _config: ServerStrategy) {}
116}
117
118impl RustCrateSink for NullSourceSink {
119 fn as_any(&self) -> &dyn Any {
120 self
121 }
122
123 fn instantiate(&self, _client_path: &SourcePath) -> Result<Box<dyn FnOnce() -> ServerConfig>> {
124 Ok(Box::new(|| ServerConfig::Null))
125 }
126
127 fn instantiate_reverse(
128 &self,
129 _server_host: &Arc<dyn Host>,
130 _server_sink: Arc<dyn RustCrateServer>,
131 _wrap_client_port: &dyn Fn(ServerConfig) -> ServerConfig,
132 ) -> Result<ReverseSinkInstantiator> {
133 Ok(Box::new(|_| ServerStrategy::Null))
134 }
135}
136
137pub struct DemuxSink {
138 pub demux: HashMap<u32, Arc<dyn RustCrateSink>>,
139}
140
141impl RustCrateSink for DemuxSink {
142 fn as_any(&self) -> &dyn Any {
143 self
144 }
145
146 fn instantiate(&self, client_host: &SourcePath) -> Result<Box<dyn FnOnce() -> ServerConfig>> {
147 let mut thunk_map = HashMap::new();
148 for (key, target) in &self.demux {
149 thunk_map.insert(*key, target.instantiate(client_host)?);
150 }
151
152 Ok(Box::new(move || {
153 let instantiated_map = thunk_map
154 .into_iter()
155 .map(|(key, thunk)| (key, thunk()))
156 .collect();
157
158 ServerConfig::Demux(instantiated_map)
159 }))
160 }
161
162 fn instantiate_reverse(
163 &self,
164 server_host: &Arc<dyn Host>,
165 server_sink: Arc<dyn RustCrateServer>,
166 wrap_client_port: &dyn Fn(ServerConfig) -> ServerConfig,
167 ) -> Result<ReverseSinkInstantiator> {
168 let mut thunk_map = HashMap::new();
169 for (key, target) in &self.demux {
170 thunk_map.insert(
171 *key,
172 target.instantiate_reverse(
173 server_host,
174 server_sink.clone(),
175 &|p| ServerConfig::DemuxSelect(Box::new(wrap_client_port(p)), *key),
177 )?,
178 );
179 }
180
181 Ok(Box::new(move |me| {
182 let me = me.downcast_ref::<DemuxSink>().unwrap();
183 let instantiated_map = thunk_map
184 .into_iter()
185 .map(|(key, thunk)| (key, thunk(me.demux.get(&key).unwrap().as_any())))
186 .collect();
187
188 ServerStrategy::Demux(instantiated_map)
189 }))
190 }
191}
192
193#[derive(Clone)]
194pub struct RustCratePortConfig {
195 pub service: Weak<RwLock<RustCrateService>>,
196 pub service_host: Arc<dyn Host>,
197 pub service_server_defns: Arc<RwLock<HashMap<String, ServerPort>>>,
198 pub port: String,
199 pub merge: bool,
200}
201
202impl RustCratePortConfig {
203 pub fn merge(&self) -> Self {
204 Self {
205 service: self.service.clone(),
206 service_host: self.service_host.clone(),
207 service_server_defns: self.service_server_defns.clone(),
208 port: self.port.clone(),
209 merge: true,
210 }
211 }
212}
213
214impl RustCrateSource for RustCratePortConfig {
215 fn source_path(&self) -> SourcePath {
216 SourcePath::Direct(
217 self.service
218 .upgrade()
219 .unwrap()
220 .try_read()
221 .unwrap()
222 .on
223 .clone(),
224 )
225 }
226
227 fn host(&self) -> Arc<dyn Host> {
228 self.service_host.clone()
229 }
230
231 fn server(&self) -> Arc<dyn RustCrateServer> {
232 let from = self.service.upgrade().unwrap();
233 let from_read = from.try_read().unwrap();
234
235 Arc::new(RustCratePortConfig {
236 service: Arc::downgrade(&from),
237 service_host: from_read.on.clone(),
238 service_server_defns: from_read.server_defns.clone(),
239 port: self.port.clone(),
240 merge: false,
241 })
242 }
243
244 fn record_server_config(&self, config: ServerConfig) {
245 let from = self.service.upgrade().unwrap();
246 let mut from_write = from.try_write().unwrap();
247
248 assert!(
250 !from_write.port_to_server.contains_key(&self.port),
251 "The port configuration is incorrect, for example, are you using a ConnectedDirect instead of a ConnectedDemux?"
252 );
253 from_write.port_to_server.insert(self.port.clone(), config);
254 }
255
256 fn record_server_strategy(&self, config: ServerStrategy) {
257 let from = self.service.upgrade().unwrap();
258 let mut from_write = from.try_write().unwrap();
259
260 assert!(!from_write.port_to_bind.contains_key(&self.port));
261 from_write.port_to_bind.insert(self.port.clone(), config);
262 }
263}
264
265#[async_trait]
266impl RustCrateServer for RustCratePortConfig {
267 fn get_port(&self) -> ServerPort {
268 let server_defns = self.service_server_defns.try_read().unwrap();
270 server_defns.get(&self.port).unwrap().clone()
271 }
272
273 fn launched_host(&self) -> Arc<dyn LaunchedHost> {
274 self.service_host.launched().unwrap()
275 }
276}
277
278pub enum SourcePath {
279 Null,
280 Direct(Arc<dyn Host>),
281 Tagged(Box<SourcePath>, u32),
282}
283
284impl SourcePath {
285 fn plan<T: RustCrateServer + Clone + 'static>(
286 &self,
287 server: &T,
288 server_host: &dyn Host,
289 ) -> Result<(HostStrategyGetter, ServerConfig)> {
290 match self {
291 SourcePath::Direct(client_host) => {
292 let (conn_type, bind_type) = server_host.strategy_as_server(client_host.deref())?;
293 let base_config = ServerConfig::from_strategy(&conn_type, Arc::new(server.clone()));
294 Ok((bind_type, base_config))
295 }
296
297 SourcePath::Tagged(underlying, tag) => {
298 let (bind_type, base_config) = underlying.plan(server, server_host)?;
299 let tag = *tag;
300 let strategy_getter: HostStrategyGetter =
301 Box::new(move |host| ServerStrategy::Tagged(Box::new(bind_type(host)), tag));
302 Ok((
303 strategy_getter,
304 ServerConfig::TaggedUnwrap(Box::new(base_config)),
305 ))
306 }
307
308 SourcePath::Null => {
309 let strategy_getter: HostStrategyGetter = Box::new(|_| ServerStrategy::Null);
310 Ok((strategy_getter, ServerConfig::Null))
311 }
312 }
313 }
314}
315
316impl RustCrateSink for RustCratePortConfig {
317 fn as_any(&self) -> &dyn Any {
318 self
319 }
320
321 fn instantiate(&self, client_path: &SourcePath) -> Result<Box<dyn FnOnce() -> ServerConfig>> {
322 let server = self.service.upgrade().unwrap();
323 let server_read = server.try_read().unwrap();
324
325 let server_host = server_read.on.clone();
326
327 let (bind_type, base_config) = client_path.plan(self, server_host.deref())?;
328
329 let server = server.clone();
330 let merge = self.merge;
331 let port = self.port.clone();
332 Ok(Box::new(move || {
333 let mut server_write = server.try_write().unwrap();
334 let bind_type = (bind_type)(server_write.on.as_any());
335
336 if merge {
337 let merge_config = server_write
338 .port_to_bind
339 .entry(port.clone())
340 .or_insert(ServerStrategy::Merge(vec![]));
341 let merge_index = if let ServerStrategy::Merge(merge) = merge_config {
342 merge.push(bind_type);
343 merge.len() - 1
344 } else {
345 panic!("Expected a merge connection definition")
346 };
347
348 ServerConfig::MergeSelect(Box::new(base_config), merge_index)
349 } else {
350 assert!(!server_write.port_to_bind.contains_key(&port));
351 server_write.port_to_bind.insert(port.clone(), bind_type);
352 base_config
353 }
354 }))
355 }
356
357 fn instantiate_reverse(
358 &self,
359 server_host: &Arc<dyn Host>,
360 server_sink: Arc<dyn RustCrateServer>,
361 wrap_client_port: &dyn Fn(ServerConfig) -> ServerConfig,
362 ) -> Result<ReverseSinkInstantiator> {
363 let client = self.service.upgrade().unwrap();
364 let client_read = client.try_read().unwrap();
365
366 let server_host = server_host.clone();
367
368 let (conn_type, bind_type) = server_host.strategy_as_server(client_read.on.deref())?;
369 let client_port = wrap_client_port(ServerConfig::from_strategy(&conn_type, server_sink));
370
371 let client = client.clone();
372 let merge = self.merge;
373 let port = self.port.clone();
374 Ok(Box::new(move |_| {
375 let mut client_write = client.try_write().unwrap();
376
377 if merge {
378 let merge_config = client_write
379 .port_to_server
380 .entry(port.clone())
381 .or_insert(ServerConfig::Merge(vec![]));
382
383 if let ServerConfig::Merge(merge) = merge_config {
384 merge.push(client_port);
385 } else {
386 panic!()
387 };
388 } else {
389 assert!(!client_write.port_to_server.contains_key(&port));
390 client_write
391 .port_to_server
392 .insert(port.clone(), client_port);
393 };
394
395 (bind_type)(client_write.on.as_any())
396 }))
397 }
398}
399
400#[derive(Clone)]
401pub enum ServerConfig {
402 Direct(Arc<dyn RustCrateServer>),
403 Forwarded(Arc<dyn RustCrateServer>),
404 Demux(HashMap<u32, ServerConfig>),
406 DemuxSelect(Box<ServerConfig>, u32),
408 Merge(Vec<ServerConfig>),
410 MergeSelect(Box<ServerConfig>, usize),
412 Tagged(Box<ServerConfig>, u32),
413 TaggedUnwrap(Box<ServerConfig>),
414 Null,
415}
416
417impl ServerConfig {
418 pub fn from_strategy(
419 strategy: &ClientStrategy,
420 server: Arc<dyn RustCrateServer>,
421 ) -> ServerConfig {
422 match strategy {
423 ClientStrategy::UnixSocket(_) | ClientStrategy::InternalTcpPort(_) => {
424 ServerConfig::Direct(server)
425 }
426 ClientStrategy::ForwardedTcpPort(_) => ServerConfig::Forwarded(server),
427 }
428 }
429}
430
431#[async_recursion]
432async fn forward_connection(conn: &ServerPort, target: &dyn LaunchedHost) -> ServerPort {
433 match conn {
434 ServerPort::UnixSocket(_) => panic!("Expected a TCP port to be forwarded"),
435 ServerPort::TcpPort(addr) => ServerPort::TcpPort(target.forward_port(addr).await.unwrap()),
436 ServerPort::Demux(demux) => {
437 let mut forwarded_map = HashMap::new();
438 for (key, conn) in demux {
439 forwarded_map.insert(*key, forward_connection(conn, target).await);
440 }
441 ServerPort::Demux(forwarded_map)
442 }
443 ServerPort::Merge(merge) => {
444 let mut forwarded_vec = Vec::new();
445 for conn in merge {
446 forwarded_vec.push(forward_connection(conn, target).await);
447 }
448 ServerPort::Merge(forwarded_vec)
449 }
450 ServerPort::Tagged(underlying, id) => {
451 ServerPort::Tagged(Box::new(forward_connection(underlying, target).await), *id)
452 }
453 ServerPort::Null => ServerPort::Null,
454 }
455}
456
impl ServerConfig {
    /// Resolves this configuration into a concrete `ServerPort`.
    ///
    /// `select` is applied to the port read from each leaf server (`Direct` / `Forwarded`).
    /// `DemuxSelect` and `MergeSelect` compose it: they first pick their key or index out of
    /// the server's port and then apply the outer `select` to the result.
    ///
    /// # Panics
    ///
    /// - `DemuxSelect`: the port is not a demux, or the key is missing.
    /// - `MergeSelect`: the port is not a merge, or the index is out of bounds.
    /// - `TaggedUnwrap`: the loaded port is not tagged.
    #[async_recursion]
    pub async fn load_instantiated(
        &self,
        select: &(dyn Fn(ServerPort) -> ServerPort + Send + Sync),
    ) -> ServerPort {
        match self {
            ServerConfig::Direct(server) => select(server.get_port()),

            ServerConfig::Forwarded(server) => {
                // Select first, so only the chosen part of the server's port is forwarded.
                let selected = select(server.get_port());
                forward_connection(&selected, server.launched_host().as_ref()).await
            }

            ServerConfig::Demux(demux) => {
                let mut demux_map = HashMap::new();
                for (key, conn) in demux {
                    demux_map.insert(*key, conn.load_instantiated(select).await);
                }
                ServerPort::Demux(demux_map)
            }

            ServerConfig::DemuxSelect(underlying, key) => {
                let key = *key;
                underlying
                    .load_instantiated(
                        // Narrow the underlying demux port to `key`, then apply the caller's select.
                        &(move |p| {
                            if let ServerPort::Demux(mut mapping) = p {
                                select(mapping.remove(&key).unwrap())
                            } else {
                                panic!("Expected a demux connection definition")
                            }
                        }),
                    )
                    .await
            }

            ServerConfig::Merge(merge) => {
                let mut merge_vec = Vec::new();
                for conn in merge {
                    merge_vec.push(conn.load_instantiated(select).await);
                }
                ServerPort::Merge(merge_vec)
            }

            ServerConfig::MergeSelect(underlying, key) => {
                let key = *key;
                underlying
                    .load_instantiated(
                        // Narrow the underlying merge port to index `key`, then apply the caller's select.
                        &(move |p| {
                            if let ServerPort::Merge(mut mapping) = p {
                                select(mapping.remove(key))
                            } else {
                                panic!("Expected a merge connection definition")
                            }
                        }),
                    )
                    .await
            }

            ServerConfig::Tagged(underlying, id) => {
                ServerPort::Tagged(Box::new(underlying.load_instantiated(select).await), *id)
            }

            ServerConfig::TaggedUnwrap(underlying) => {
                // Discard the tag and keep only the wrapped port.
                let loaded = underlying.load_instantiated(select).await;
                if let ServerPort::Tagged(underlying, _) = loaded {
                    *underlying
                } else {
                    panic!("Expected a tagged connection definition")
                }
            }

            ServerConfig::Null => ServerPort::Null,
        }
    }
}