1use std::any::Any;
2use std::collections::HashMap;
3use std::fmt::Debug;
4use std::ops::Deref;
5use std::sync::{Arc, Weak};
6
7use anyhow::{Result, bail};
8use async_recursion::async_recursion;
9use async_trait::async_trait;
10use dyn_clone::DynClone;
11use hydro_deploy_integration::ServerPort;
12use tokio::sync::RwLock;
13
14use super::RustCrateService;
15use crate::{ClientStrategy, Host, LaunchedHost, PortNetworkHint, ServerStrategy};
16
17pub trait RustCrateSource: Send + Sync {
18 fn source_path(&self) -> SourcePath;
19 fn record_server_config(&self, config: ServerConfig);
20
21 fn host(&self) -> Arc<dyn Host>;
22 fn server(&self) -> Arc<dyn RustCrateServer>;
23 fn record_server_strategy(&self, config: ServerStrategy);
24
25 fn wrap_reverse_server_config(&self, config: ServerConfig) -> ServerConfig {
26 config
27 }
28
29 fn send_to(&self, sink: &dyn RustCrateSink) {
30 let forward_res = sink.instantiate(&self.source_path());
31 if let Ok(instantiated) = forward_res {
32 self.record_server_config(instantiated());
33 } else {
34 drop(forward_res);
35 let instantiated = sink
36 .instantiate_reverse(&self.host(), self.server(), &|p| {
37 self.wrap_reverse_server_config(p)
38 })
39 .unwrap();
40 self.record_server_strategy(instantiated(sink));
41 }
42 }
43}
44
/// A port that can be connected to as a server.
///
/// `ServerConfig::load_instantiated` uses this trait to resolve a connection: it reads
/// the port with [`RustCrateServer::get_port`] and, for forwarded connections, forwards
/// it through [`RustCrateServer::launched_host`].
pub trait RustCrateServer: DynClone + Debug + Send + Sync {
    /// Returns the `ServerPort` definition of this server.
    fn get_port(&self) -> ServerPort;
    /// Returns the launched host used to forward this server's port.
    fn launched_host(&self) -> Arc<dyn LaunchedHost>;
}
49
/// Deferred second step of a reverse connection, as returned by
/// `RustCrateSink::instantiate_reverse`.
///
/// `RustCrateSource::send_to` calls it once with the sink (as `&dyn Any`) and records the
/// returned `ServerStrategy` on the source.
pub type ReverseSinkInstantiator = Box<dyn FnOnce(&dyn Any) -> ServerStrategy>;
51
/// The receiving side of a connection between two ports.
///
/// Both methods work in two phases: the call itself plans the connection and may fail,
/// and the returned closure is invoked afterwards to finish the wiring.
pub trait RustCrateSink: Any + Send + Sync {
    /// Plans a forward connection from the source described by `client_path`.
    ///
    /// On success, returns a thunk that yields the `ServerConfig` to be recorded on the
    /// source.
    fn instantiate(&self, client_path: &SourcePath) -> Result<Box<dyn FnOnce() -> ServerConfig>>;

    /// Plans a reverse connection, in which `server_host` / `server_sink` (supplied by the
    /// source) act as the server.
    ///
    /// `wrap_client_port` is applied to the `ServerConfig` built for the connection. On
    /// success, returns an instantiator that is later called with the sink itself as
    /// `&dyn Any` and yields the `ServerStrategy` to be recorded on the source.
    fn instantiate_reverse(
        &self,
        server_host: &Arc<dyn Host>,
        server_sink: Arc<dyn RustCrateServer>,
        wrap_client_port: &dyn Fn(ServerConfig) -> ServerConfig,
    ) -> Result<ReverseSinkInstantiator>;
}
66
/// A source wrapper that attaches a numeric tag to the connections of another source.
pub struct TaggedSource {
    /// The wrapped source that all recording calls are delegated to.
    pub source: Arc<dyn RustCrateSource>,
    /// The tag added to the source path and to reverse-connection server configs.
    pub tag: u32,
}
71
72impl RustCrateSource for TaggedSource {
73 fn source_path(&self) -> SourcePath {
74 SourcePath::Tagged(Box::new(self.source.source_path()), self.tag)
75 }
76
77 fn record_server_config(&self, config: ServerConfig) {
78 self.source.record_server_config(config);
79 }
80
81 fn host(&self) -> Arc<dyn Host> {
82 self.source.host()
83 }
84
85 fn server(&self) -> Arc<dyn RustCrateServer> {
86 self.source.server()
87 }
88
89 fn wrap_reverse_server_config(&self, config: ServerConfig) -> ServerConfig {
90 ServerConfig::Tagged(Box::new(config), self.tag)
91 }
92
93 fn record_server_strategy(&self, config: ServerStrategy) {
94 self.source.record_server_strategy(config);
95 }
96}
97
/// A placeholder that can be used as either a source or a sink. As a source it reports
/// `SourcePath::Null` and ignores whatever is recorded on it; as a sink it always
/// produces `ServerConfig::Null` / `ServerStrategy::Null`.
pub struct NullSourceSink;
99
100impl RustCrateSource for NullSourceSink {
101 fn source_path(&self) -> SourcePath {
102 SourcePath::Null
103 }
104
105 fn host(&self) -> Arc<dyn Host> {
106 panic!("null source has no host")
107 }
108
109 fn server(&self) -> Arc<dyn RustCrateServer> {
110 panic!("null source has no server")
111 }
112
113 fn record_server_config(&self, _config: ServerConfig) {}
114 fn record_server_strategy(&self, _config: ServerStrategy) {}
115}
116
117impl RustCrateSink for NullSourceSink {
118 fn instantiate(&self, _client_path: &SourcePath) -> Result<Box<dyn FnOnce() -> ServerConfig>> {
119 Ok(Box::new(|| ServerConfig::Null))
120 }
121
122 fn instantiate_reverse(
123 &self,
124 _server_host: &Arc<dyn Host>,
125 _server_sink: Arc<dyn RustCrateServer>,
126 _wrap_client_port: &dyn Fn(ServerConfig) -> ServerConfig,
127 ) -> Result<ReverseSinkInstantiator> {
128 Ok(Box::new(|_| ServerStrategy::Null))
129 }
130}
131
/// A sink that fans a connection out to several target sinks, addressed by a `u32` key.
pub struct DemuxSink {
    /// The target sink for each demux key.
    pub demux: HashMap<u32, Arc<dyn RustCrateSink>>,
}
135
136impl RustCrateSink for DemuxSink {
137 fn instantiate(&self, client_host: &SourcePath) -> Result<Box<dyn FnOnce() -> ServerConfig>> {
138 let mut thunk_map = HashMap::new();
139 for (key, target) in &self.demux {
140 thunk_map.insert(*key, target.instantiate(client_host)?);
141 }
142
143 Ok(Box::new(move || {
144 let instantiated_map = thunk_map
145 .into_iter()
146 .map(|(key, thunk)| (key, thunk()))
147 .collect();
148
149 ServerConfig::Demux(instantiated_map)
150 }))
151 }
152
153 fn instantiate_reverse(
154 &self,
155 server_host: &Arc<dyn Host>,
156 server_sink: Arc<dyn RustCrateServer>,
157 wrap_client_port: &dyn Fn(ServerConfig) -> ServerConfig,
158 ) -> Result<ReverseSinkInstantiator> {
159 let mut thunk_map = HashMap::new();
160 for (key, target) in &self.demux {
161 thunk_map.insert(
162 *key,
163 target.instantiate_reverse(
164 server_host,
165 server_sink.clone(),
166 &|p| ServerConfig::DemuxSelect(Box::new(wrap_client_port(p)), *key),
168 )?,
169 );
170 }
171
172 Ok(Box::new(move |me| {
173 let me = me.downcast_ref::<DemuxSink>().unwrap();
174 let instantiated_map = thunk_map
175 .into_iter()
176 .map(|(key, thunk)| (key, (thunk)(me.demux.get(&key).unwrap())))
177 .collect();
178
179 ServerStrategy::Demux(instantiated_map)
180 }))
181 }
182}
183
/// A named port of a `RustCrateService`, usable as a source, a sink and a server.
#[derive(Clone, Debug)]
pub struct RustCratePortConfig {
    /// Weak handle to the service that owns the port; upgraded on every access.
    pub service: Weak<RwLock<RustCrateService>>,
    /// The host reported by this port, and the one whose launched form is used for
    /// port forwarding.
    pub service_host: Arc<dyn Host>,
    /// Server port definitions keyed by port name; read when resolving this port.
    pub service_server_defns: Arc<RwLock<HashMap<String, ServerPort>>>,
    /// Hint passed on when planning a connection; a reverse connection requires
    /// `PortNetworkHint::Auto`.
    pub network_hint: PortNetworkHint,
    /// Name of the port; used as the key in the service's port maps.
    pub port: String,
    /// When `true`, connections to this port are accumulated in a merge list instead of
    /// requiring the port to be unused.
    pub merge: bool,
}
193
194impl RustCratePortConfig {
195 pub fn merge(&self) -> Self {
196 Self {
197 service: self.service.clone(),
198 service_host: self.service_host.clone(),
199 service_server_defns: self.service_server_defns.clone(),
200 network_hint: self.network_hint,
201 port: self.port.clone(),
202 merge: true,
203 }
204 }
205}
206
207impl RustCrateSource for RustCratePortConfig {
208 fn source_path(&self) -> SourcePath {
209 SourcePath::Direct(
210 self.service
211 .upgrade()
212 .unwrap()
213 .try_read()
214 .unwrap()
215 .on
216 .clone(),
217 )
218 }
219
220 fn host(&self) -> Arc<dyn Host> {
221 self.service_host.clone()
222 }
223
224 fn server(&self) -> Arc<dyn RustCrateServer> {
225 let from = self.service.upgrade().unwrap();
226 let from_read = from.try_read().unwrap();
227
228 Arc::new(RustCratePortConfig {
229 service: Arc::downgrade(&from),
230 service_host: from_read.on.clone(),
231 service_server_defns: from_read.server_defns.clone(),
232 network_hint: self.network_hint,
233 port: self.port.clone(),
234 merge: false,
235 })
236 }
237
238 fn record_server_config(&self, config: ServerConfig) {
239 let from = self.service.upgrade().unwrap();
240 let mut from_write = from.try_write().unwrap();
241
242 assert!(
244 !from_write.port_to_server.contains_key(&self.port),
245 "The port configuration is incorrect, for example, are you using a ConnectedDirect instead of a ConnectedDemux?"
246 );
247 from_write.port_to_server.insert(self.port.clone(), config);
248 }
249
250 fn record_server_strategy(&self, config: ServerStrategy) {
251 let from = self.service.upgrade().unwrap();
252 let mut from_write = from.try_write().unwrap();
253
254 assert!(!from_write.port_to_bind.contains_key(&self.port));
255 from_write.port_to_bind.insert(self.port.clone(), config);
256 }
257}
258
259#[async_trait]
260impl RustCrateServer for RustCratePortConfig {
261 fn get_port(&self) -> ServerPort {
262 let server_defns = self.service_server_defns.try_read().unwrap();
264 server_defns.get(&self.port).unwrap().clone()
265 }
266
267 fn launched_host(&self) -> Arc<dyn LaunchedHost> {
268 self.service_host.launched().unwrap()
269 }
270}
271
/// Description of a source, as seen by the sink that plans a connection from it.
pub enum SourcePath {
    /// No real source; planning yields `ServerStrategy::Null` and `ServerConfig::Null`.
    Null,
    /// A source on the given host; planned as `ServerStrategy::Direct`.
    Direct(Arc<dyn Host>),
    /// A source on the given host; planned as `ServerStrategy::Many`.
    Many(Arc<dyn Host>),
    /// Another path annotated with a tag; planned as `ServerStrategy::Tagged` around the
    /// plan of the inner path.
    Tagged(Box<SourcePath>, u32),
}
278
279impl SourcePath {
280 #[expect(
281 clippy::type_complexity,
282 reason = "internals (dyn Fn to defer instantiation)"
283 )]
284 fn plan<T: RustCrateServer + Clone + 'static>(
285 &self,
286 server: &T,
287 server_host: &dyn Host,
288 network_hint: PortNetworkHint,
289 ) -> Result<(Box<dyn FnOnce(&dyn Any) -> ServerStrategy>, ServerConfig)> {
290 match self {
291 SourcePath::Direct(client_host) => {
292 let (conn_type, bind_type) =
293 server_host.strategy_as_server(client_host.deref(), network_hint)?;
294 let base_config = ServerConfig::from_strategy(&conn_type, Arc::new(server.clone()));
295 Ok((
296 Box::new(|host| ServerStrategy::Direct(bind_type(host))),
297 base_config,
298 ))
299 }
300
301 SourcePath::Many(client_host) => {
302 let (conn_type, bind_type) =
303 server_host.strategy_as_server(client_host.deref(), network_hint)?;
304 let base_config = ServerConfig::from_strategy(&conn_type, Arc::new(server.clone()));
305 Ok((
306 Box::new(|host| ServerStrategy::Many(bind_type(host))),
307 base_config,
308 ))
309 }
310
311 SourcePath::Tagged(underlying, tag) => {
312 let (bind_type, base_config) =
313 underlying.plan(server, server_host, network_hint)?;
314 let tag = *tag;
315 Ok((
316 Box::new(move |host| ServerStrategy::Tagged(Box::new(bind_type(host)), tag)),
317 ServerConfig::TaggedUnwrap(Box::new(base_config)),
318 ))
319 }
320
321 SourcePath::Null => Ok((Box::new(|_| ServerStrategy::Null), ServerConfig::Null)),
322 }
323 }
324}
325
326impl RustCrateSink for RustCratePortConfig {
327 fn instantiate(&self, client_path: &SourcePath) -> Result<Box<dyn FnOnce() -> ServerConfig>> {
328 let server = self.service.upgrade().unwrap();
329 let server_read = server.try_read().unwrap();
330
331 let server_host = server_read.on.clone();
332
333 let (bind_type, base_config) =
334 client_path.plan(self, server_host.deref(), self.network_hint)?;
335
336 let server = server.clone();
337 let merge = self.merge;
338 let port = self.port.clone();
339 Ok(Box::new(move || {
340 let mut server_write = server.try_write().unwrap();
341 let bind_type = (bind_type)(&*server_write.on);
342
343 if merge {
344 let merge_config = server_write
345 .port_to_bind
346 .entry(port.clone())
347 .or_insert(ServerStrategy::Merge(vec![]));
348 let merge_index = if let ServerStrategy::Merge(merge) = merge_config {
349 merge.push(bind_type);
350 merge.len() - 1
351 } else {
352 panic!("Expected a merge connection definition")
353 };
354
355 ServerConfig::MergeSelect(Box::new(base_config), merge_index)
356 } else {
357 assert!(!server_write.port_to_bind.contains_key(&port));
358 server_write.port_to_bind.insert(port.clone(), bind_type);
359 base_config
360 }
361 }))
362 }
363
364 fn instantiate_reverse(
365 &self,
366 server_host: &Arc<dyn Host>,
367 server_sink: Arc<dyn RustCrateServer>,
368 wrap_client_port: &dyn Fn(ServerConfig) -> ServerConfig,
369 ) -> Result<ReverseSinkInstantiator> {
370 if !matches!(self.network_hint, PortNetworkHint::Auto) {
371 bail!("Trying to form collection where I am the client, but I have server hint")
372 }
373
374 let client = self.service.upgrade().unwrap();
375 let client_read = client.try_read().unwrap();
376
377 let server_host = server_host.clone();
378
379 let (conn_type, bind_type) =
380 server_host.strategy_as_server(client_read.on.deref(), PortNetworkHint::Auto)?;
381 let client_port = wrap_client_port(ServerConfig::from_strategy(&conn_type, server_sink));
382
383 let client = client.clone();
384 let merge = self.merge;
385 let port = self.port.clone();
386 Ok(Box::new(move |_| {
387 let mut client_write = client.try_write().unwrap();
388
389 if merge {
390 let merge_config = client_write
391 .port_to_server
392 .entry(port.clone())
393 .or_insert(ServerConfig::Merge(vec![]));
394
395 if let ServerConfig::Merge(merge) = merge_config {
396 merge.push(client_port);
397 } else {
398 panic!()
399 };
400 } else {
401 assert!(!client_write.port_to_server.contains_key(&port));
402 client_write
403 .port_to_server
404 .insert(port.clone(), client_port);
405 };
406
407 ServerStrategy::Direct((bind_type)(&*client_write.on))
408 }))
409 }
410}
411
/// Client-side description of how to obtain the `ServerPort` of a connection.
///
/// The description is resolved into a concrete `ServerPort` by
/// `ServerConfig::load_instantiated`.
#[derive(Clone, Debug)]
pub enum ServerConfig {
    /// Resolves to the server's own port.
    Direct(Arc<dyn RustCrateServer>),
    /// Resolves to the server's port after forwarding it through the server's
    /// launched host.
    Forwarded(Arc<dyn RustCrateServer>),
    /// Resolves every entry and collects them into `ServerPort::Demux`.
    Demux(HashMap<u32, ServerConfig>),
    /// Resolves the inner config, picking the entry with the given key out of a
    /// `ServerPort::Demux`.
    DemuxSelect(Box<ServerConfig>, u32),
    /// Resolves every element and collects them into `ServerPort::Merge`.
    Merge(Vec<ServerConfig>),
    /// Resolves the inner config, picking the element at the given index out of a
    /// `ServerPort::Merge`.
    MergeSelect(Box<ServerConfig>, usize),
    /// Resolves the inner config and wraps the result in `ServerPort::Tagged` with the
    /// given id.
    Tagged(Box<ServerConfig>, u32),
    /// Resolves the inner config, which must yield `ServerPort::Tagged`, and strips
    /// that wrapper.
    TaggedUnwrap(Box<ServerConfig>),
    /// Resolves to `ServerPort::Null`.
    Null,
}
428
429impl ServerConfig {
430 pub fn from_strategy(
431 strategy: &ClientStrategy,
432 server: Arc<dyn RustCrateServer>,
433 ) -> ServerConfig {
434 match strategy {
435 ClientStrategy::UnixSocket(_) | ClientStrategy::InternalTcpPort(_) => {
436 ServerConfig::Direct(server)
437 }
438 ClientStrategy::ForwardedTcpPort(_) => ServerConfig::Forwarded(server),
439 }
440 }
441}
442
443#[async_recursion]
444async fn forward_connection(conn: &ServerPort, target: &dyn LaunchedHost) -> ServerPort {
445 match conn {
446 ServerPort::UnixSocket(_) => panic!("Expected a TCP port to be forwarded"),
447 ServerPort::TcpPort(addr) => ServerPort::TcpPort(target.forward_port(addr).await.unwrap()),
448 ServerPort::Demux(demux) => {
449 let mut forwarded_map = HashMap::new();
450 for (key, conn) in demux {
451 forwarded_map.insert(*key, forward_connection(conn, target).await);
452 }
453 ServerPort::Demux(forwarded_map)
454 }
455 ServerPort::Merge(merge) => {
456 let mut forwarded_vec = Vec::new();
457 for conn in merge {
458 forwarded_vec.push(forward_connection(conn, target).await);
459 }
460 ServerPort::Merge(forwarded_vec)
461 }
462 ServerPort::Tagged(underlying, id) => {
463 ServerPort::Tagged(Box::new(forward_connection(underlying, target).await), *id)
464 }
465 ServerPort::Null => ServerPort::Null,
466 }
467}
468
469impl ServerConfig {
470 #[async_recursion]
471 pub async fn load_instantiated(
472 &self,
473 select: &(dyn Fn(ServerPort) -> ServerPort + Send + Sync),
474 ) -> ServerPort {
475 match self {
476 ServerConfig::Direct(server) => select(server.get_port()),
477
478 ServerConfig::Forwarded(server) => {
479 let selected = select(server.get_port());
480 forward_connection(&selected, server.launched_host().as_ref()).await
481 }
482
483 ServerConfig::Demux(demux) => {
484 let mut demux_map = HashMap::new();
485 for (key, conn) in demux {
486 demux_map.insert(*key, conn.load_instantiated(select).await);
487 }
488 ServerPort::Demux(demux_map)
489 }
490
491 ServerConfig::DemuxSelect(underlying, key) => {
492 let key = *key;
493 underlying
494 .load_instantiated(
495 &(move |p| {
496 if let ServerPort::Demux(mut mapping) = p {
497 select(mapping.remove(&key).unwrap())
498 } else {
499 panic!("Expected a demux connection definition")
500 }
501 }),
502 )
503 .await
504 }
505
506 ServerConfig::Merge(merge) => {
507 let mut merge_vec = Vec::new();
508 for conn in merge {
509 merge_vec.push(conn.load_instantiated(select).await);
510 }
511 ServerPort::Merge(merge_vec)
512 }
513
514 ServerConfig::MergeSelect(underlying, key) => {
515 let key = *key;
516 underlying
517 .load_instantiated(
518 &(move |p| {
519 if let ServerPort::Merge(mut mapping) = p {
520 select(mapping.remove(key))
521 } else {
522 panic!("Expected a merge connection definition")
523 }
524 }),
525 )
526 .await
527 }
528
529 ServerConfig::Tagged(underlying, id) => {
530 ServerPort::Tagged(Box::new(underlying.load_instantiated(select).await), *id)
531 }
532
533 ServerConfig::TaggedUnwrap(underlying) => {
534 let loaded = underlying.load_instantiated(select).await;
535 if let ServerPort::Tagged(underlying, _) = loaded {
536 *underlying
537 } else {
538 panic!("Expected a tagged connection definition")
539 }
540 }
541
542 ServerConfig::Null => ServerPort::Null,
543 }
544 }
545}