1use std::collections::HashMap;
2use std::marker::PhantomData;
3use std::net::SocketAddr;
4use std::path::PathBuf;
5use std::pin::Pin;
6use std::task::{Context, Poll};
7use std::time::Duration;
8
9use async_recursion::async_recursion;
10use async_trait::async_trait;
11use bytes::{Bytes, BytesMut};
12use futures::sink::Buffer;
13use futures::{Future, Sink, SinkExt, Stream, ready, stream};
14use pin_project::pin_project;
15use serde::{Deserialize, Serialize};
16use tokio::io;
17use tokio::net::{TcpListener, TcpStream};
18#[cfg(unix)]
19use tokio::net::{UnixListener, UnixStream};
20use tokio::task::JoinHandle;
21use tokio_stream::StreamExt;
22use tokio_stream::wrappers::TcpListenerStream;
23use tokio_util::codec::{Framed, LengthDelimitedCodec};
24
/// Initial configuration: [`ServerBindConfig`]s keyed by `String`, plus an optional string.
// NOTE(review): the meaning of the `Option<String>` component is not visible in this file —
// confirm against the code that consumes `InitConfig`.
pub type InitConfig = (HashMap<String, ServerBindConfig>, Option<String>);
26
// Unix domain sockets do not exist on non-Unix targets. These uninhabited aliases let the
// enum variants and signatures that mention `UnixStream` / `UnixListener` still compile there;
// every code path that would actually use a Unix socket panics on such targets instead.
#[cfg(not(unix))]
type UnixStream = std::convert::Infallible;

// Only referenced from `#[cfg(unix)]` code, hence the expected dead_code lint.
#[cfg(not(unix))]
#[expect(dead_code, reason = "conditional compilation placeholder")]
type UnixListener = std::convert::Infallible;
33
/// A serializable description of where a server can be reached.
///
/// Produced from a bound server by [`BoundServer::server_port`] and turned into live
/// connection attempts by [`ServerPort::connect`].
#[derive(Serialize, Deserialize, Clone, Debug)]
pub enum ServerPort {
    /// Path of a Unix domain socket.
    UnixSocket(PathBuf),
    /// A TCP socket address.
    TcpPort(SocketAddr),
    /// A set of ports, each identified by a `u32` key.
    Demux(HashMap<u32, ServerPort>),
    /// Several ports whose incoming streams are merged into a single stream.
    Merge(Vec<ServerPort>),
    /// A port whose received items are paired with the given `u32` tag.
    Tagged(Box<ServerPort>, u32),
    /// A port that is not connected to anything.
    Null,
}
44
45impl ServerPort {
46 pub fn connect(&self) -> ClientConnection {
47 match self {
48 ServerPort::UnixSocket(path) => {
49 #[cfg(unix)]
50 {
51 let bound = UnixStream::connect(path.clone());
52 ClientConnection::UnixSocket(tokio::spawn(bound))
53 }
54
55 #[cfg(not(unix))]
56 {
57 let _ = path;
58 panic!("Unix sockets are not supported on this platform")
59 }
60 }
61 ServerPort::TcpPort(addr) => {
62 let addr_clone = *addr;
63 let bound = async_retry(
64 move || TcpStream::connect(addr_clone),
65 10,
66 Duration::from_secs(1),
67 );
68 ClientConnection::TcpPort(tokio::spawn(bound))
69 }
70 ServerPort::Demux(bindings) => {
71 ClientConnection::Demux(bindings.iter().map(|(k, v)| (*k, v.connect())).collect())
72 }
73 ServerPort::Merge(ports) => {
74 ClientConnection::Merge(ports.iter().map(|p| p.connect()).collect())
75 }
76 ServerPort::Tagged(port, tag) => {
77 ClientConnection::Tagged(Box::new(port.as_ref().connect()), *tag)
78 }
79 ServerPort::Null => ClientConnection::Null,
80 }
81 }
82
83 pub fn instantiate(&self) -> Connection {
84 Connection::AsClient(self.connect())
85 }
86}
87
/// The client side of a [`ServerPort`] whose connection attempts have been started.
///
/// Socket variants hold the [`JoinHandle`] of the spawned connect task; composite variants
/// mirror the shape of the [`ServerPort`] they came from.
#[derive(Debug)]
pub enum ClientConnection {
    /// Pending Unix socket connect.
    UnixSocket(JoinHandle<io::Result<UnixStream>>),
    /// Pending (retried) TCP connect.
    TcpPort(JoinHandle<io::Result<TcpStream>>),
    /// Member connections keyed by `u32`.
    Demux(HashMap<u32, ClientConnection>),
    /// Connections whose incoming streams get merged.
    Merge(Vec<ClientConnection>),
    /// A connection whose received items are paired with the given `u32` tag.
    Tagged(Box<ClientConnection>, u32),
    /// No connection.
    Null,
}
97
/// A serializable description of a server endpoint that has not been bound yet.
///
/// [`ServerBindConfig::bind`] turns it into a [`BoundServer`].
#[derive(Serialize, Deserialize, Clone, Debug)]
pub enum ServerBindConfig {
    /// A Unix domain socket created inside a fresh temporary directory.
    UnixSocket,
    /// A TCP listener on an OS-assigned port.
    TcpPort(
        /// Host to bind to; the port is always 0 so the OS picks one.
        String,
    ),
    /// Member endpoints keyed by `u32`.
    Demux(HashMap<u32, ServerBindConfig>),
    /// Endpoints whose incoming streams get merged.
    Merge(Vec<ServerBindConfig>),
    /// An endpoint whose received items are paired with the given `u32` tag.
    Tagged(Box<ServerBindConfig>, u32),
    /// No endpoint.
    Null,
}
110
111impl ServerBindConfig {
112 #[async_recursion]
113 pub async fn bind(self) -> BoundServer {
114 match self {
115 ServerBindConfig::UnixSocket => {
116 #[cfg(unix)]
117 {
118 let dir = tempfile::tempdir().unwrap();
119 let socket_path = dir.path().join("socket");
120 let bound = UnixListener::bind(socket_path).unwrap();
121 BoundServer::UnixSocket(
122 tokio::spawn(async move { Ok(bound.accept().await?.0) }),
123 dir,
124 )
125 }
126
127 #[cfg(not(unix))]
128 {
129 panic!("Unix sockets are not supported on this platform")
130 }
131 }
132 ServerBindConfig::TcpPort(host) => {
133 let listener = TcpListener::bind((host, 0)).await.unwrap();
134 let addr = listener.local_addr().unwrap();
135 BoundServer::TcpPort(TcpListenerStream::new(listener), addr)
136 }
137 ServerBindConfig::Demux(bindings) => {
138 let mut demux = HashMap::new();
139 for (key, bind) in bindings {
140 demux.insert(key, bind.bind().await);
141 }
142 BoundServer::Demux(demux)
143 }
144 ServerBindConfig::Merge(bindings) => {
145 let mut merge = Vec::new();
146 for bind in bindings {
147 merge.push(bind.bind().await);
148 }
149 BoundServer::Merge(merge)
150 }
151 ServerBindConfig::Tagged(underlying, id) => {
152 BoundServer::Tagged(Box::new(underlying.bind().await), id)
153 }
154 ServerBindConfig::Null => BoundServer::Null,
155 }
156 }
157}
158
/// Either side of a pipe before it is turned into a typed [`Connected`] endpoint.
#[derive(Debug)]
pub enum Connection {
    /// Client side: connection attempts started by [`ServerPort::connect`].
    AsClient(ClientConnection),
    /// Server side: listeners created by [`ServerBindConfig::bind`].
    AsServer(BoundServer),
}
164
165impl Connection {
166 pub async fn connect<T: Connected>(self) -> T {
167 T::from_defn(self).await
168 }
169
170 pub fn connect_local_blocking<T: Connected>(self) -> T {
171 let handle = tokio::runtime::Handle::current();
172 let _guard = handle.enter();
173 futures::executor::block_on(T::from_defn(self))
174 }
175
176 pub async fn accept_tcp(&mut self) -> TcpStream {
177 if let Connection::AsServer(BoundServer::TcpPort(handle, _)) = self {
178 handle.next().await.unwrap().unwrap()
179 } else {
180 panic!("Not a TCP port")
181 }
182 }
183}
184
/// A boxed, pinned stream of `BytesMut` items (or I/O errors).
pub type DynStream = Pin<Box<dyn Stream<Item = Result<BytesMut, io::Error>> + Send + Sync>>;

/// A boxed, pinned sink of `Input` items that fails with [`io::Error`].
pub type DynSink<Input> = Pin<Box<dyn Sink<Input, Error = io::Error> + Send + Sync>>;

/// Anything that is both a stream of `BytesMut` frames and a sink of `Bytes` frames.
///
/// The blanket impl below covers every type meeting both bounds; the trait exists to name that
/// combination so it can be used as a single trait object ([`DynStreamSink`]).
pub trait StreamSink:
    Stream<Item = Result<BytesMut, io::Error>> + Sink<Bytes, Error = io::Error>
{
}
impl<T: Stream<Item = Result<BytesMut, io::Error>> + Sink<Bytes, Error = io::Error>> StreamSink
    for T
{
}

/// A boxed, pinned bidirectional [`StreamSink`].
pub type DynStreamSink = Pin<Box<dyn StreamSink + Send + Sync>>;
199
/// A typed endpoint that can be built from an untyped [`Connection`].
#[async_trait]
pub trait Connected: Send {
    /// Builds `Self` from `pipe`. The implementations in this file panic when `pipe` has a
    /// shape they do not support.
    async fn from_defn(pipe: Connection) -> Self;
}
/// A connected endpoint that can be turned into a [`Sink`] of `Input` items.
pub trait ConnectedSink {
    /// Item type accepted by the sink.
    type Input: Send;
    /// Concrete sink type returned by [`ConnectedSink::into_sink`].
    type Sink: Sink<Self::Input, Error = io::Error> + Send + Sync;

    /// Consumes the endpoint and returns its sending half.
    fn into_sink(self) -> Self::Sink;
}

/// A connected endpoint that can be turned into a [`Stream`] of `Output` items.
pub trait ConnectedSource {
    /// Item type produced by the stream (each wrapped in `Result<_, io::Error>`).
    type Output: Send;
    /// Concrete stream type returned by [`ConnectedSource::into_source`].
    type Stream: Stream<Item = Result<Self::Output, io::Error>> + Send + Sync;
    /// Consumes the endpoint and returns its receiving half.
    fn into_source(self) -> Self::Stream;
}
217
/// The server side of a pipe after [`ServerBindConfig::bind`].
///
/// Its address can be published to clients with [`BoundServer::server_port`].
#[derive(Debug)]
pub enum BoundServer {
    /// A spawned task accepting one Unix socket connection, plus the `TempDir` that contains
    /// the socket file (tempfile removes the directory when the `TempDir` is dropped).
    UnixSocket(JoinHandle<io::Result<UnixStream>>, tempfile::TempDir),
    /// A TCP listener and the local address it is bound to.
    TcpPort(TcpListenerStream, SocketAddr),
    /// Member servers keyed by `u32`.
    Demux(HashMap<u32, BoundServer>),
    /// Servers whose incoming streams get merged.
    Merge(Vec<BoundServer>),
    /// A server whose received items are paired with the given `u32` tag.
    Tagged(Box<BoundServer>, u32),
    /// No server.
    Null,
}
227
228impl BoundServer {
229 pub fn server_port(&self) -> ServerPort {
230 match self {
231 BoundServer::UnixSocket(_, tempdir) => {
232 #[cfg(unix)]
233 {
234 ServerPort::UnixSocket(tempdir.path().join("socket"))
235 }
236
237 #[cfg(not(unix))]
238 {
239 let _ = tempdir;
240 panic!("Unix sockets are not supported on this platform")
241 }
242 }
243 BoundServer::TcpPort(_, addr) => {
244 ServerPort::TcpPort(SocketAddr::new(addr.ip(), addr.port()))
245 }
246
247 BoundServer::Demux(bindings) => {
248 let mut demux = HashMap::new();
249 for (key, bind) in bindings {
250 demux.insert(*key, bind.server_port());
251 }
252 ServerPort::Demux(demux)
253 }
254
255 BoundServer::Merge(bindings) => {
256 let mut merge = Vec::new();
257 for bind in bindings {
258 merge.push(bind.server_port());
259 }
260 ServerPort::Merge(merge)
261 }
262
263 BoundServer::Tagged(underlying, id) => {
264 ServerPort::Tagged(Box::new(underlying.server_port()), *id)
265 }
266
267 BoundServer::Null => ServerPort::Null,
268 }
269 }
270}
271
272#[async_recursion]
273async fn accept(bound: BoundServer) -> ConnectedDirect {
274 match bound {
275 BoundServer::UnixSocket(listener, _) => {
276 #[cfg(unix)]
277 {
278 let stream = listener.await.unwrap().unwrap();
279 ConnectedDirect {
280 stream_sink: Some(Box::pin(unix_bytes(stream))),
281 source_only: None,
282 sink_only: None,
283 }
284 }
285
286 #[cfg(not(unix))]
287 {
288 drop(listener);
289 panic!("Unix sockets are not supported on this platform")
290 }
291 }
292 BoundServer::TcpPort(mut listener, _) => {
293 let stream = listener.next().await.unwrap().unwrap();
294 ConnectedDirect {
295 stream_sink: Some(Box::pin(tcp_bytes(stream))),
296 source_only: None,
297 sink_only: None,
298 }
299 }
300 BoundServer::Merge(merge) => {
301 let mut sources = vec![];
302 for bound in merge {
303 sources.push(accept(bound).await.into_source());
304 }
305
306 let merge_source: DynStream = Box::pin(MergeSource {
307 marker: PhantomData,
308 sources,
309 });
310
311 ConnectedDirect {
312 stream_sink: None,
313 source_only: Some(merge_source),
314 sink_only: None,
315 }
316 }
317 BoundServer::Demux(_) => panic!("Cannot connect to a demux pipe directly"),
318 BoundServer::Tagged(_, _) => panic!("Cannot connect to a tagged pipe directly"),
319 BoundServer::Null => {
320 ConnectedDirect::from_defn(Connection::AsClient(ClientConnection::Null)).await
321 }
322 }
323}
324
325fn tcp_bytes(stream: TcpStream) -> impl StreamSink {
326 Framed::new(stream, LengthDelimitedCodec::new())
327}
328
329#[cfg(unix)]
330fn unix_bytes(stream: UnixStream) -> impl StreamSink {
331 Framed::new(stream, LengthDelimitedCodec::new())
332}
333
334struct IoErrorDrain<T> {
335 marker: PhantomData<T>,
336}
337
338impl<T> Sink<T> for IoErrorDrain<T> {
339 type Error = io::Error;
340
341 fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
342 Poll::Ready(Ok(()))
343 }
344
345 fn start_send(self: Pin<&mut Self>, _item: T) -> Result<(), Self::Error> {
346 Ok(())
347 }
348
349 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
350 Poll::Ready(Ok(()))
351 }
352
353 fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
354 Poll::Ready(Ok(()))
355 }
356}
357
358async fn async_retry<T, E, F: Future<Output = Result<T, E>>>(
359 thunk: impl Fn() -> F,
360 count: usize,
361 delay: Duration,
362) -> Result<T, E> {
363 for _ in 1..count {
364 let result = thunk().await;
365 if result.is_ok() {
366 return result;
367 } else {
368 tokio::time::sleep(delay).await;
369 }
370 }
371
372 thunk().await
373}
374
/// A direct (non-demux, non-tagged) endpoint carrying length-delimited byte frames.
///
/// Which halves are present depends on how it was built: socket connections fill
/// `stream_sink`; `Merge` fills only `source_only`; `Null` fills `source_only` and `sink_only`.
pub struct ConnectedDirect {
    // Bidirectional framed socket; when present it serves as both the source and the sink.
    stream_sink: Option<DynStreamSink>,
    // Receive-only stream, used by `into_source` when `stream_sink` is `None`.
    source_only: Option<DynStream>,
    // Send-only sink, used by `into_sink` when `stream_sink` is `None`.
    sink_only: Option<DynSink<Bytes>>,
}
380
381#[async_trait]
382impl Connected for ConnectedDirect {
383 async fn from_defn(pipe: Connection) -> Self {
384 match pipe {
385 Connection::AsClient(ClientConnection::UnixSocket(stream)) => {
386 #[cfg(unix)]
387 {
388 let stream = stream.await.unwrap().unwrap();
389 ConnectedDirect {
390 stream_sink: Some(Box::pin(unix_bytes(stream))),
391 source_only: None,
392 sink_only: None,
393 }
394 }
395
396 #[cfg(not(unix))]
397 {
398 drop(stream);
399 panic!("Unix sockets are not supported on this platform");
400 }
401 }
402 Connection::AsClient(ClientConnection::TcpPort(stream)) => {
403 let stream = stream.await.unwrap().unwrap();
404 stream.set_nodelay(true).unwrap();
405 ConnectedDirect {
406 stream_sink: Some(Box::pin(tcp_bytes(stream))),
407 source_only: None,
408 sink_only: None,
409 }
410 }
411 Connection::AsClient(ClientConnection::Merge(merge)) => {
412 let sources = futures::future::join_all(merge.into_iter().map(|port| async {
413 ConnectedDirect::from_defn(Connection::AsClient(port))
414 .await
415 .into_source()
416 }))
417 .await;
418
419 let merged = MergeSource {
420 marker: PhantomData,
421 sources,
422 };
423
424 ConnectedDirect {
425 stream_sink: None,
426 source_only: Some(Box::pin(merged)),
427 sink_only: None,
428 }
429 }
430 Connection::AsClient(ClientConnection::Demux(_)) => {
431 panic!("Cannot connect to a demux pipe directly")
432 }
433
434 Connection::AsClient(ClientConnection::Tagged(_, _)) => {
435 panic!("Cannot connect to a tagged pipe directly")
436 }
437
438 Connection::AsClient(ClientConnection::Null) => ConnectedDirect {
439 stream_sink: None,
440 source_only: Some(Box::pin(stream::empty())),
441 sink_only: Some(Box::pin(IoErrorDrain {
442 marker: PhantomData,
443 })),
444 },
445
446 Connection::AsServer(bound) => accept(bound).await,
447 }
448 }
449}
450
451impl ConnectedSource for ConnectedDirect {
452 type Output = BytesMut;
453 type Stream = DynStream;
454
455 fn into_source(mut self) -> DynStream {
456 if let Some(s) = self.stream_sink.take() {
457 Box::pin(s)
458 } else {
459 self.source_only.take().unwrap()
460 }
461 }
462}
463
464impl ConnectedSink for ConnectedDirect {
465 type Input = Bytes;
466 type Sink = DynSink<Bytes>;
467
468 fn into_sink(mut self) -> DynSink<Self::Input> {
469 if let Some(s) = self.stream_sink.take() {
470 Box::pin(s)
471 } else {
472 self.sink_only.take().unwrap()
473 }
474 }
475}
476
/// A [`DemuxDrain`] whose per-key sinks are each wrapped in a `futures` [`Buffer`].
pub type BufferedDrain<S, I> = DemuxDrain<I, Buffer<S, I>>;
478
/// An endpoint that routes `(key, item)` pairs to one of several underlying `T` sinks.
pub struct ConnectedDemux<T: ConnectedSink>
where
    <T as ConnectedSink>::Input: Sync,
{
    /// Keys that have an underlying sink.
    pub keys: Vec<u32>,
    // The routing sink; set by `from_defn` and taken by `into_sink`.
    sink: Option<BufferedDrain<T::Sink, T::Input>>,
}
486
/// A sink of `(key, item)` pairs that forwards each `item` to the sink registered for `key`.
#[pin_project]
pub struct DemuxDrain<T, S: Sink<T, Error = io::Error> + Send + Sync + ?Sized> {
    marker: PhantomData<T>,
    // Destination sinks keyed by demux key.
    #[pin]
    sinks: HashMap<u32, Pin<Box<S>>>,
}
493
494impl<T, S: Sink<T, Error = io::Error> + Send + Sync> Sink<(u32, T)> for DemuxDrain<T, S> {
495 type Error = io::Error;
496
497 fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
498 for sink in self.project().sinks.values_mut() {
499 ready!(Sink::poll_ready(sink.as_mut(), _cx))?;
500 }
501
502 Poll::Ready(Ok(()))
503 }
504
505 fn start_send(self: Pin<&mut Self>, item: (u32, T)) -> Result<(), Self::Error> {
506 Sink::start_send(
507 self.project()
508 .sinks
509 .get_mut()
510 .get_mut(&item.0)
511 .unwrap_or_else(|| panic!("No sink in this demux for key {}", item.0))
512 .as_mut(),
513 item.1,
514 )
515 }
516
517 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
518 for sink in self.project().sinks.values_mut() {
519 ready!(Sink::poll_flush(sink.as_mut(), _cx))?;
520 }
521
522 Poll::Ready(Ok(()))
523 }
524
525 fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
526 for sink in self.project().sinks.values_mut() {
527 ready!(Sink::poll_close(sink.as_mut(), _cx))?;
528 }
529
530 Poll::Ready(Ok(()))
531 }
532}
533
534#[async_trait]
535impl<T: Connected + ConnectedSink> Connected for ConnectedDemux<T>
536where
537 <T as ConnectedSink>::Input: 'static + Sync,
538{
539 async fn from_defn(pipe: Connection) -> Self {
540 match pipe {
541 Connection::AsClient(ClientConnection::Demux(demux)) => {
542 let mut connected_demux = HashMap::new();
543 let keys = demux.keys().cloned().collect();
544 for (id, pipe) in demux {
545 connected_demux.insert(
546 id,
547 Box::pin(
548 T::from_defn(Connection::AsClient(pipe))
549 .await
550 .into_sink()
551 .buffer(1024),
552 ),
553 );
554 }
555
556 let demuxer = DemuxDrain {
557 marker: PhantomData,
558 sinks: connected_demux,
559 };
560
561 ConnectedDemux {
562 keys,
563 sink: Some(demuxer),
564 }
565 }
566
567 Connection::AsServer(BoundServer::Demux(demux)) => {
568 let mut connected_demux = HashMap::new();
569 let keys = demux.keys().cloned().collect();
570 for (id, bound) in demux {
571 connected_demux.insert(
572 id,
573 Box::pin(
574 T::from_defn(Connection::AsServer(bound))
575 .await
576 .into_sink()
577 .buffer(1024),
578 ),
579 );
580 }
581
582 let demuxer = DemuxDrain {
583 marker: PhantomData,
584 sinks: connected_demux,
585 };
586
587 ConnectedDemux {
588 keys,
589 sink: Some(demuxer),
590 }
591 }
592 _ => panic!("Cannot connect to a non-demux pipe as a demux"),
593 }
594 }
595}
596
597impl<T: ConnectedSink> ConnectedSink for ConnectedDemux<T>
598where
599 <T as ConnectedSink>::Input: 'static + Sync,
600{
601 type Input = (u32, T::Input);
602 type Sink = BufferedDrain<T::Sink, T::Input>;
603
604 fn into_sink(mut self) -> Self::Sink {
605 self.sink.take().unwrap()
606 }
607}
608
609pub struct MergeSource<T: Unpin, S: Stream<Item = T> + Send + Sync + ?Sized> {
610 marker: PhantomData<T>,
611 sources: Vec<Pin<Box<S>>>,
612}
613
614impl<T: Unpin, S: Stream<Item = T> + Send + Sync + ?Sized> Stream for MergeSource<T, S> {
615 type Item = T;
616
617 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
618 let sources = &mut self.get_mut().sources;
619 let mut next = None;
620
621 let mut i = 0;
622 while i < sources.len() {
623 match sources[i].as_mut().poll_next(cx) {
624 Poll::Ready(Some(v)) => {
625 next = Some(v);
626 break;
627 }
628 Poll::Ready(None) => {
629 sources.remove(i);
631 }
632 Poll::Pending => {
633 i += 1;
634 }
635 }
636 }
637
638 if sources.is_empty() {
639 Poll::Ready(None)
640 } else if next.is_none() {
641 Poll::Pending
642 } else {
643 Poll::Ready(next)
644 }
645 }
646}
647
648pub struct TaggedSource<T: Unpin, S: Stream<Item = Result<T, io::Error>> + Send + Sync + ?Sized> {
649 marker: PhantomData<T>,
650 id: u32,
651 source: Pin<Box<S>>,
652}
653
654impl<T: Unpin, S: Stream<Item = Result<T, io::Error>> + Send + Sync + ?Sized> Stream
655 for TaggedSource<T, S>
656{
657 type Item = Result<(u32, T), io::Error>;
658
659 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
660 let id = self.as_ref().id;
661 let source = &mut self.get_mut().source;
662 match source.as_mut().poll_next(cx) {
663 Poll::Ready(Some(v)) => Poll::Ready(Some(v.map(|d| (id, d)))),
664 Poll::Ready(None) => Poll::Ready(None),
665 Poll::Pending => Poll::Pending,
666 }
667 }
668}
669
/// The stream held by [`ConnectedTagged`]: several sources of `T`, each wrapped in a
/// [`TaggedSource`] and merged with [`MergeSource`], yielding `(tag, item)` pairs.
type MergedMux<T> = MergeSource<
    Result<(u32, <T as ConnectedSource>::Output), io::Error>,
    TaggedSource<<T as ConnectedSource>::Output, <T as ConnectedSource>::Stream>,
>;
674
/// An endpoint that receives from one or more tagged pipes and yields `(tag, item)` pairs.
pub struct ConnectedTagged<T: ConnectedSource>
where
    <T as ConnectedSource>::Output: 'static + Sync + Unpin,
{
    // Merged stream of every tagged member's items.
    source: MergedMux<T>,
}
681
682#[async_trait]
683impl<T: Connected + ConnectedSource> Connected for ConnectedTagged<T>
684where
685 <T as ConnectedSource>::Output: 'static + Sync + Unpin,
686{
687 async fn from_defn(pipe: Connection) -> Self {
688 let sources = match pipe {
689 Connection::AsClient(ClientConnection::Tagged(pipe, id)) => {
690 vec![(
691 Box::pin(
692 T::from_defn(Connection::AsClient(*pipe))
693 .await
694 .into_source(),
695 ),
696 id,
697 )]
698 }
699
700 Connection::AsClient(ClientConnection::Merge(m)) => {
701 let mut sources = Vec::new();
702 for port in m {
703 if let ClientConnection::Tagged(pipe, id) = port {
704 sources.push((
705 Box::pin(
706 T::from_defn(Connection::AsClient(*pipe))
707 .await
708 .into_source(),
709 ),
710 id,
711 ));
712 } else {
713 panic!("Merge port must be tagged");
714 }
715 }
716
717 sources
718 }
719
720 Connection::AsServer(BoundServer::Tagged(pipe, id)) => {
721 vec![(
722 Box::pin(
723 T::from_defn(Connection::AsServer(*pipe))
724 .await
725 .into_source(),
726 ),
727 id,
728 )]
729 }
730
731 Connection::AsServer(BoundServer::Merge(m)) => {
732 let mut sources = Vec::new();
733 for port in m {
734 if let BoundServer::Tagged(pipe, id) = port {
735 sources.push((
736 Box::pin(
737 T::from_defn(Connection::AsServer(*pipe))
738 .await
739 .into_source(),
740 ),
741 id,
742 ));
743 } else {
744 panic!("Merge port must be tagged");
745 }
746 }
747
748 sources
749 }
750
751 _ => panic!("Cannot connect to a non-tagged pipe as a tagged"),
752 };
753
754 let mut connected_mux = Vec::new();
755 for (pipe, id) in sources {
756 connected_mux.push(Box::pin(TaggedSource {
757 marker: PhantomData,
758 id,
759 source: pipe,
760 }));
761 }
762
763 let muxer = MergeSource {
764 marker: PhantomData,
765 sources: connected_mux,
766 };
767
768 ConnectedTagged { source: muxer }
769 }
770}
771
772impl<T: ConnectedSource> ConnectedSource for ConnectedTagged<T>
773where
774 <T as ConnectedSource>::Output: 'static + Sync + Unpin,
775{
776 type Output = (u32, T::Output);
777 type Stream = MergeSource<Result<Self::Output, io::Error>, TaggedSource<T::Output, T::Stream>>;
778
779 fn into_source(self) -> Self::Stream {
780 self.source
781 }
782}