hydro_deploy_integration/
lib.rs

1use std::cell::RefCell;
2use std::collections::HashMap;
3use std::marker::PhantomData;
4use std::net::SocketAddr;
5use std::path::PathBuf;
6use std::pin::Pin;
7use std::task::{Context, Poll};
8use std::time::Duration;
9
10use async_recursion::async_recursion;
11use async_trait::async_trait;
12use bytes::{Bytes, BytesMut};
13use futures::sink::Buffer;
14use futures::stream::{FuturesUnordered, SplitSink, SplitStream};
15use futures::{Future, Sink, SinkExt, Stream, StreamExt, ready, stream};
16use pin_project_lite::pin_project;
17use serde::{Deserialize, Serialize};
18use tempfile::TempDir;
19use tokio::io;
20use tokio::net::{TcpListener, TcpStream};
21#[cfg(unix)]
22use tokio::net::{UnixListener, UnixStream};
23use tokio_stream::wrappers::TcpListenerStream;
24use tokio_util::codec::{Framed, LengthDelimitedCodec};
25
26pub mod multi_connection;
27
/// Initial configuration handed to a deployed program: the [`ServerBindConfig`] for each
/// named port, plus an optional string (NOTE(review): its meaning is defined by the caller —
/// presumably serialized metadata; confirm).
pub type InitConfig = (HashMap<String, ServerBindConfig>, Option<String>);
29
30/// Contains runtime information passed by Hydro Deploy to a program,
31/// describing how to connect to other services and metadata about them.
32pub struct DeployPorts<T = Option<()>> {
33    pub ports: RefCell<HashMap<String, Connection>>,
34    pub meta: T,
35}
36
37impl<T> DeployPorts<T> {
38    pub fn port(&self, name: &str) -> Connection {
39        self.ports
40            .try_borrow_mut()
41            .unwrap()
42            .remove(name)
43            .unwrap_or_else(|| panic!("port {} not found", name))
44    }
45}
46
// On platforms without Unix domain sockets, the Unix stream/listener types are replaced by
// the uninhabited `Infallible`. The `UnixSocket` variants below then still type-check but can
// never hold a value; the code paths that would create one panic instead.
#[cfg(not(unix))]
type UnixStream = std::convert::Infallible;

#[cfg(not(unix))]
type UnixListener = std::convert::Infallible;
52
/// Describes how to connect to a service which is listening on some port.
///
/// A `ServerPort` is produced, for example, by [`BoundServer::server_port`] on the listening
/// side. [`ServerPort::connect`] turns it into a live [`ClientConnection`] of the same shape.
#[derive(Serialize, Deserialize, Clone, Debug)]
pub enum ServerPort {
    /// A Unix domain socket at the given filesystem path (supported on Unix platforms only).
    UnixSocket(PathBuf),
    /// A TCP socket address. Connecting retries on failure.
    TcpPort(SocketAddr),
    /// Ports keyed by `u32`. Connecting yields one connection per key.
    Demux(HashMap<u32, ServerPort>),
    /// Several ports whose connections are grouped into a single `ClientConnection::Merge`.
    Merge(Vec<ServerPort>),
    /// A port paired with a `u32` tag that is carried through to the connection.
    Tagged(Box<ServerPort>, u32),
    /// A placeholder that performs no I/O and connects to `ClientConnection::Null`.
    Null,
}
63
64impl ServerPort {
65    #[async_recursion]
66    pub async fn connect(&self) -> ClientConnection {
67        match self {
68            ServerPort::UnixSocket(path) => {
69                #[cfg(unix)]
70                {
71                    let bound = UnixStream::connect(path.clone());
72                    ClientConnection::UnixSocket(bound.await.unwrap())
73                }
74
75                #[cfg(not(unix))]
76                {
77                    let _ = path;
78                    panic!("Unix sockets are not supported on this platform")
79                }
80            }
81            ServerPort::TcpPort(addr) => {
82                let addr_clone = *addr;
83                let stream = async_retry(
84                    move || TcpStream::connect(addr_clone),
85                    10,
86                    Duration::from_secs(1),
87                )
88                .await
89                .unwrap();
90                ClientConnection::TcpPort(stream)
91            }
92            ServerPort::Demux(bindings) => ClientConnection::Demux(
93                bindings
94                    .iter()
95                    .map(|(k, v)| async move { (*k, v.connect().await) })
96                    .collect::<FuturesUnordered<_>>()
97                    .collect::<Vec<_>>()
98                    .await
99                    .into_iter()
100                    .collect(),
101            ),
102            ServerPort::Merge(ports) => ClientConnection::Merge(
103                ports
104                    .iter()
105                    .map(|p| p.connect())
106                    .collect::<FuturesUnordered<_>>()
107                    .collect::<Vec<_>>()
108                    .await
109                    .into_iter()
110                    .collect(),
111            ),
112            ServerPort::Tagged(port, tag) => {
113                ClientConnection::Tagged(Box::new(port.as_ref().connect().await), *tag)
114            }
115            ServerPort::Null => ClientConnection::Null,
116        }
117    }
118
119    pub async fn instantiate(&self) -> Connection {
120        Connection::AsClient(self.connect().await)
121    }
122}
123
/// An established client-side connection. It has the same shape as the [`ServerPort`] it was
/// created from by [`ServerPort::connect`].
#[derive(Debug)]
pub enum ClientConnection {
    /// A connected Unix domain socket stream (uninhabited on non-Unix platforms).
    UnixSocket(UnixStream),
    /// A connected TCP stream.
    TcpPort(TcpStream),
    /// One connection per demux key.
    Demux(HashMap<u32, ClientConnection>),
    /// Connections that consumers such as `ConnectedDirect` merge into a single source.
    Merge(Vec<ClientConnection>),
    /// A connection together with its `u32` tag.
    Tagged(Box<ClientConnection>, u32),
    /// No underlying connection.
    Null,
}
133
/// Describes how a service should bind the port(s) it listens on.
///
/// [`ServerBindConfig::bind`] turns this description into a [`BoundServer`].
#[derive(Serialize, Deserialize, Clone, Debug)]
pub enum ServerBindConfig {
    /// Bind a Unix domain socket named `socket` inside a fresh temporary directory
    /// (Unix platforms only).
    UnixSocket,
    /// Bind a TCP listener.
    TcpPort(
        /// The host the port should be bound on.
        String,
        /// The port the service should listen on.
        ///
        /// If `None`, the port will be chosen automatically (port 0 is requested).
        Option<u16>,
    ),
    /// Bind every member configuration, keyed by `u32`.
    Demux(HashMap<u32, ServerBindConfig>),
    /// Bind every member configuration. Their connections are later merged.
    Merge(Vec<ServerBindConfig>),
    /// Bind the inner configuration and attach a `u32` tag.
    Tagged(Box<ServerBindConfig>, u32),
    /// Bind the inner configuration. `accept_bound` passes the result through without
    /// accepting a connection.
    MultiConnection(Box<ServerBindConfig>),
    /// Bind nothing.
    Null,
}
151
152impl ServerBindConfig {
153    #[async_recursion]
154    pub async fn bind(self) -> BoundServer {
155        match self {
156            ServerBindConfig::UnixSocket => {
157                #[cfg(unix)]
158                {
159                    let dir = tempfile::tempdir().unwrap();
160                    let socket_path = dir.path().join("socket");
161                    let bound = UnixListener::bind(socket_path).unwrap();
162                    BoundServer::UnixSocket(bound, dir)
163                }
164
165                #[cfg(not(unix))]
166                {
167                    panic!("Unix sockets are not supported on this platform")
168                }
169            }
170            ServerBindConfig::TcpPort(host, port) => {
171                let listener = TcpListener::bind((host, port.unwrap_or(0))).await.unwrap();
172                let addr = listener.local_addr().unwrap();
173                BoundServer::TcpPort(TcpListenerStream::new(listener), addr)
174            }
175            ServerBindConfig::Demux(bindings) => {
176                let mut demux = HashMap::new();
177                for (key, bind) in bindings {
178                    demux.insert(key, bind.bind().await);
179                }
180                BoundServer::Demux(demux)
181            }
182            ServerBindConfig::Merge(bindings) => {
183                let mut merge = Vec::new();
184                for bind in bindings {
185                    merge.push(bind.bind().await);
186                }
187                BoundServer::Merge(merge)
188            }
189            ServerBindConfig::Tagged(underlying, id) => {
190                BoundServer::Tagged(Box::new(underlying.bind().await), id)
191            }
192            ServerBindConfig::MultiConnection(underlying) => {
193                BoundServer::MultiConnection(Box::new(underlying.bind().await))
194            }
195            ServerBindConfig::Null => BoundServer::Null,
196        }
197    }
198}
199
200#[derive(Debug)]
201pub enum Connection {
202    AsClient(ClientConnection),
203    AsServer(AcceptedServer),
204}
205
206impl Connection {
207    pub fn connect<T: Connected>(self) -> T {
208        T::from_defn(self)
209    }
210}
211
212pub type DynStream = Pin<Box<dyn Stream<Item = Result<BytesMut, io::Error>> + Send + Sync>>;
213
214pub type DynSink<Input> = Pin<Box<dyn Sink<Input, Error = io::Error> + Send + Sync>>;
215
216pub trait StreamSink:
217    Stream<Item = Result<BytesMut, io::Error>> + Sink<Bytes, Error = io::Error>
218{
219}
220impl<T: Stream<Item = Result<BytesMut, io::Error>> + Sink<Bytes, Error = io::Error>> StreamSink
221    for T
222{
223}
224
225pub type DynStreamSink = Pin<Box<dyn StreamSink + Send + Sync>>;
226
/// Types that can be built from a [`Connection`].
pub trait Connected: Send {
    /// Builds `Self` from `pipe`.
    ///
    /// The implementations in this file panic when `pipe` has a shape they cannot handle.
    fn from_defn(pipe: Connection) -> Self;
}

/// Connections that can be turned into a [`Sink`] of `Self::Input`.
pub trait ConnectedSink {
    /// Item type accepted by the sink.
    type Input: Send;
    /// Concrete sink type returned by [`ConnectedSink::into_sink`].
    type Sink: Sink<Self::Input, Error = io::Error> + Send + Sync;

    /// Consumes the connection and returns its sink.
    fn into_sink(self) -> Self::Sink;
}

/// Connections that can be turned into a [`Stream`] of `Self::Output` results.
pub trait ConnectedSource {
    /// Item type produced by the stream (wrapped in `Result<_, io::Error>`).
    type Output: Send;
    /// Concrete stream type returned by [`ConnectedSource::into_source`].
    type Stream: Stream<Item = Result<Self::Output, io::Error>> + Send + Sync;
    /// Consumes the connection and returns its stream.
    fn into_source(self) -> Self::Stream;
}
243
/// Listeners created by [`ServerBindConfig::bind`] that have not accepted a connection yet.
#[derive(Debug)]
pub enum BoundServer {
    /// A Unix listener plus the temporary directory containing its `socket` file.
    /// NOTE(review): the `TempDir` is presumably held here to keep the directory alive for as
    /// long as the listener exists; confirm.
    UnixSocket(UnixListener, TempDir),
    /// A TCP listener stream plus the local address it is actually bound to.
    TcpPort(TcpListenerStream, SocketAddr),
    /// Listeners keyed by `u32`.
    Demux(HashMap<u32, BoundServer>),
    /// Listeners whose connections are later merged.
    Merge(Vec<BoundServer>),
    /// A listener together with its `u32` tag.
    Tagged(Box<BoundServer>, u32),
    /// A listener that `accept_bound` passes through without accepting a connection.
    MultiConnection(Box<BoundServer>),
    /// Nothing bound.
    Null,
}
254
/// Server-side connections produced by `accept_bound` from a [`BoundServer`].
#[derive(Debug)]
pub enum AcceptedServer {
    /// An accepted Unix stream, still paired with the temporary directory of its socket.
    UnixSocket(UnixStream, TempDir),
    /// An accepted TCP stream.
    TcpPort(TcpStream),
    /// Accepted connections keyed by `u32`.
    Demux(HashMap<u32, AcceptedServer>),
    /// Accepted connections to be merged.
    Merge(Vec<AcceptedServer>),
    /// An accepted connection together with its `u32` tag.
    Tagged(Box<AcceptedServer>, u32),
    /// A still-bound server whose connections are handled elsewhere (not accepted here).
    MultiConnection(Box<BoundServer>),
    /// No connection.
    Null,
}
265
266#[async_recursion]
267pub async fn accept_bound(bound: BoundServer) -> AcceptedServer {
268    match bound {
269        BoundServer::UnixSocket(listener, dir) => {
270            #[cfg(unix)]
271            {
272                let stream = listener.accept().await.unwrap().0;
273                AcceptedServer::UnixSocket(stream, dir)
274            }
275
276            #[cfg(not(unix))]
277            {
278                let _ = listener;
279                let _ = dir;
280                panic!("Unix sockets are not supported on this platform")
281            }
282        }
283        BoundServer::TcpPort(mut listener, _) => {
284            let stream = listener.next().await.unwrap().unwrap();
285            AcceptedServer::TcpPort(stream)
286        }
287        BoundServer::Demux(bindings) => AcceptedServer::Demux(
288            bindings
289                .into_iter()
290                .map(|(k, b)| async move { (k, accept_bound(b).await) })
291                .collect::<FuturesUnordered<_>>()
292                .collect::<Vec<_>>()
293                .await
294                .into_iter()
295                .collect(),
296        ),
297        BoundServer::Merge(merge) => AcceptedServer::Merge(
298            merge
299                .into_iter()
300                .map(|b| async move { accept_bound(b).await })
301                .collect::<FuturesUnordered<_>>()
302                .collect::<Vec<_>>()
303                .await,
304        ),
305        BoundServer::Tagged(underlying, id) => {
306            AcceptedServer::Tagged(Box::new(accept_bound(*underlying).await), id)
307        }
308        BoundServer::MultiConnection(underlying) => AcceptedServer::MultiConnection(underlying),
309        BoundServer::Null => AcceptedServer::Null,
310    }
311}
312
313impl BoundServer {
314    pub fn server_port(&self) -> ServerPort {
315        match self {
316            BoundServer::UnixSocket(_, tempdir) => {
317                #[cfg(unix)]
318                {
319                    ServerPort::UnixSocket(tempdir.path().join("socket"))
320                }
321
322                #[cfg(not(unix))]
323                {
324                    let _ = tempdir;
325                    panic!("Unix sockets are not supported on this platform")
326                }
327            }
328            BoundServer::TcpPort(_, addr) => {
329                ServerPort::TcpPort(SocketAddr::new(addr.ip(), addr.port()))
330            }
331
332            BoundServer::Demux(bindings) => {
333                let mut demux = HashMap::new();
334                for (key, bind) in bindings {
335                    demux.insert(*key, bind.server_port());
336                }
337                ServerPort::Demux(demux)
338            }
339
340            BoundServer::Merge(bindings) => {
341                let mut merge = Vec::new();
342                for bind in bindings {
343                    merge.push(bind.server_port());
344                }
345                ServerPort::Merge(merge)
346            }
347
348            BoundServer::Tagged(underlying, id) => {
349                ServerPort::Tagged(Box::new(underlying.server_port()), *id)
350            }
351
352            BoundServer::MultiConnection(underlying) => underlying.server_port(),
353
354            BoundServer::Null => ServerPort::Null,
355        }
356    }
357}
358
359fn accept(bound: AcceptedServer) -> ConnectedDirect {
360    match bound {
361        AcceptedServer::UnixSocket(stream, _dir) => {
362            #[cfg(unix)]
363            {
364                ConnectedDirect {
365                    stream_sink: Some(Box::pin(unix_bytes(stream))),
366                    source_only: None,
367                    sink_only: None,
368                }
369            }
370
371            #[cfg(not(unix))]
372            {
373                let _ = stream;
374                panic!("Unix sockets are not supported on this platform")
375            }
376        }
377        AcceptedServer::TcpPort(stream) => ConnectedDirect {
378            stream_sink: Some(Box::pin(tcp_bytes(stream))),
379            source_only: None,
380            sink_only: None,
381        },
382        AcceptedServer::Merge(merge) => {
383            let mut sources = vec![];
384            for bound in merge {
385                sources.push(Some(Box::pin(accept(bound).into_source())));
386            }
387
388            let merge_source: DynStream = Box::pin(MergeSource {
389                marker: PhantomData,
390                sources,
391                poll_cursor: 0,
392            });
393
394            ConnectedDirect {
395                stream_sink: None,
396                source_only: Some(merge_source),
397                sink_only: None,
398            }
399        }
400        AcceptedServer::Demux(_) => panic!("Cannot connect to a demux pipe directly"),
401        AcceptedServer::Tagged(_, _) => panic!("Cannot connect to a tagged pipe directly"),
402        AcceptedServer::MultiConnection(_) => {
403            panic!("Cannot connect to a multi-connection pipe directly")
404        }
405        AcceptedServer::Null => {
406            ConnectedDirect::from_defn(Connection::AsClient(ClientConnection::Null))
407        }
408    }
409}
410
411fn tcp_bytes(stream: TcpStream) -> impl StreamSink {
412    Framed::new(stream, LengthDelimitedCodec::new())
413}
414
415#[cfg(unix)]
416fn unix_bytes(stream: UnixStream) -> impl StreamSink {
417    Framed::new(stream, LengthDelimitedCodec::new())
418}
419
420struct IoErrorDrain<T> {
421    marker: PhantomData<T>,
422}
423
424impl<T> Sink<T> for IoErrorDrain<T> {
425    type Error = io::Error;
426
427    fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
428        Poll::Ready(Ok(()))
429    }
430
431    fn start_send(self: Pin<&mut Self>, _item: T) -> Result<(), Self::Error> {
432        Ok(())
433    }
434
435    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
436        Poll::Ready(Ok(()))
437    }
438
439    fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
440        Poll::Ready(Ok(()))
441    }
442}
443
444async fn async_retry<T, E, F: Future<Output = Result<T, E>>>(
445    thunk: impl Fn() -> F,
446    count: usize,
447    delay: Duration,
448) -> Result<T, E> {
449    for _ in 1..count {
450        let result = thunk().await;
451        if result.is_ok() {
452            return result;
453        } else {
454            tokio::time::sleep(delay).await;
455        }
456    }
457
458    thunk().await
459}
460
461pub struct ConnectedDirect {
462    stream_sink: Option<DynStreamSink>,
463    source_only: Option<DynStream>,
464    sink_only: Option<DynSink<Bytes>>,
465}
466
467impl ConnectedDirect {
468    pub fn into_source_sink(self) -> (SplitStream<DynStreamSink>, SplitSink<DynStreamSink, Bytes>) {
469        let (sink, stream) = self.stream_sink.unwrap().split();
470        (stream, sink)
471    }
472}
473
474impl Connected for ConnectedDirect {
475    fn from_defn(pipe: Connection) -> Self {
476        match pipe {
477            Connection::AsClient(ClientConnection::UnixSocket(stream)) => {
478                #[cfg(unix)]
479                {
480                    ConnectedDirect {
481                        stream_sink: Some(Box::pin(unix_bytes(stream))),
482                        source_only: None,
483                        sink_only: None,
484                    }
485                }
486
487                #[cfg(not(unix))]
488                {
489                    let _ = stream;
490                    panic!("Unix sockets are not supported on this platform");
491                }
492            }
493            Connection::AsClient(ClientConnection::TcpPort(stream)) => {
494                stream.set_nodelay(true).unwrap();
495                ConnectedDirect {
496                    stream_sink: Some(Box::pin(tcp_bytes(stream))),
497                    source_only: None,
498                    sink_only: None,
499                }
500            }
501            Connection::AsClient(ClientConnection::Merge(merge)) => {
502                let sources = merge
503                    .into_iter()
504                    .map(|port| {
505                        Some(Box::pin(
506                            ConnectedDirect::from_defn(Connection::AsClient(port)).into_source(),
507                        ))
508                    })
509                    .collect::<Vec<_>>();
510
511                let merged = MergeSource {
512                    marker: PhantomData,
513                    sources,
514                    poll_cursor: 0,
515                };
516
517                ConnectedDirect {
518                    stream_sink: None,
519                    source_only: Some(Box::pin(merged)),
520                    sink_only: None,
521                }
522            }
523            Connection::AsClient(ClientConnection::Demux(_)) => {
524                panic!("Cannot connect to a demux pipe directly")
525            }
526
527            Connection::AsClient(ClientConnection::Tagged(_, _)) => {
528                panic!("Cannot connect to a tagged pipe directly")
529            }
530
531            Connection::AsClient(ClientConnection::Null) => ConnectedDirect {
532                stream_sink: None,
533                source_only: Some(Box::pin(stream::empty())),
534                sink_only: Some(Box::pin(IoErrorDrain {
535                    marker: PhantomData,
536                })),
537            },
538
539            Connection::AsServer(bound) => accept(bound),
540        }
541    }
542}
543
544impl ConnectedSource for ConnectedDirect {
545    type Output = BytesMut;
546    type Stream = DynStream;
547
548    fn into_source(mut self) -> DynStream {
549        if let Some(s) = self.stream_sink.take() {
550            Box::pin(s)
551        } else {
552            self.source_only.take().unwrap()
553        }
554    }
555}
556
557impl ConnectedSink for ConnectedDirect {
558    type Input = Bytes;
559    type Sink = DynSink<Bytes>;
560
561    fn into_sink(mut self) -> DynSink<Self::Input> {
562        if let Some(s) = self.stream_sink.take() {
563            Box::pin(s)
564        } else {
565            self.sink_only.take().unwrap()
566        }
567    }
568}
569
570pub type BufferedDrain<S, I> = DemuxDrain<I, Buffer<S, I>>;
571
572pub struct ConnectedDemux<T: ConnectedSink>
573where
574    <T as ConnectedSink>::Input: Sync,
575{
576    pub keys: Vec<u32>,
577    sink: Option<BufferedDrain<T::Sink, T::Input>>,
578}
579
580trait DrainableSink<T>: Sink<T, Error = io::Error> + Send + Sync {}
581
582impl<T, S: Sink<T, Error = io::Error> + Send + Sync + ?Sized> DrainableSink<T> for S {}
583
pin_project! {
    // Sink that routes `(key, item)` pairs to one of several per-key sinks (see its `Sink`
    // impl). `T` is the item type accepted by each inner sink; `marker` only records it.
    pub struct DemuxDrain<T, S: DrainableSink<T>> {
        marker: PhantomData<T>,
        // Per-key destination sinks, each individually pinned on the heap.
        #[pin]
        sinks: HashMap<u32, Pin<Box<S>>>,
    }
}
591
592impl<T, S: Sink<T, Error = io::Error> + Send + Sync> Sink<(u32, T)> for DemuxDrain<T, S> {
593    type Error = io::Error;
594
595    fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
596        for sink in self.project().sinks.values_mut() {
597            ready!(Sink::poll_ready(sink.as_mut(), _cx))?;
598        }
599
600        Poll::Ready(Ok(()))
601    }
602
603    fn start_send(self: Pin<&mut Self>, item: (u32, T)) -> Result<(), Self::Error> {
604        Sink::start_send(
605            self.project()
606                .sinks
607                .get_mut()
608                .get_mut(&item.0)
609                .unwrap_or_else(|| panic!("No sink in this demux for key {}", item.0))
610                .as_mut(),
611            item.1,
612        )
613    }
614
615    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
616        for sink in self.project().sinks.values_mut() {
617            ready!(Sink::poll_flush(sink.as_mut(), _cx))?;
618        }
619
620        Poll::Ready(Ok(()))
621    }
622
623    fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
624        for sink in self.project().sinks.values_mut() {
625            ready!(Sink::poll_close(sink.as_mut(), _cx))?;
626        }
627
628        Poll::Ready(Ok(()))
629    }
630}
631
632#[async_trait]
633impl<T: Connected + ConnectedSink> Connected for ConnectedDemux<T>
634where
635    <T as ConnectedSink>::Input: 'static + Sync,
636{
637    fn from_defn(pipe: Connection) -> Self {
638        match pipe {
639            Connection::AsClient(ClientConnection::Demux(demux)) => {
640                let mut connected_demux = HashMap::new();
641                let keys = demux.keys().cloned().collect();
642                for (id, pipe) in demux {
643                    connected_demux.insert(
644                        id,
645                        Box::pin(
646                            T::from_defn(Connection::AsClient(pipe))
647                                .into_sink()
648                                .buffer(1024),
649                        ),
650                    );
651                }
652
653                let demuxer = DemuxDrain {
654                    marker: PhantomData,
655                    sinks: connected_demux,
656                };
657
658                ConnectedDemux {
659                    keys,
660                    sink: Some(demuxer),
661                }
662            }
663
664            Connection::AsServer(AcceptedServer::Demux(demux)) => {
665                let mut connected_demux = HashMap::new();
666                let keys = demux.keys().cloned().collect();
667                for (id, bound) in demux {
668                    connected_demux.insert(
669                        id,
670                        Box::pin(
671                            T::from_defn(Connection::AsServer(bound))
672                                .into_sink()
673                                .buffer(1024),
674                        ),
675                    );
676                }
677
678                let demuxer = DemuxDrain {
679                    marker: PhantomData,
680                    sinks: connected_demux,
681                };
682
683                ConnectedDemux {
684                    keys,
685                    sink: Some(demuxer),
686                }
687            }
688            _ => panic!("Cannot connect to a non-demux pipe as a demux"),
689        }
690    }
691}
692
693impl<T: ConnectedSink> ConnectedSink for ConnectedDemux<T>
694where
695    <T as ConnectedSink>::Input: 'static + Sync,
696{
697    type Input = (u32, T::Input);
698    type Sink = BufferedDrain<T::Sink, T::Input>;
699
700    fn into_sink(mut self) -> Self::Sink {
701        self.sink.take().unwrap()
702    }
703}
704
705pub struct MergeSource<T: Unpin, S: Stream<Item = T> + Send + Sync + ?Sized> {
706    marker: PhantomData<T>,
707    /// Ordered list for fair polling, will never be `None` at the beginning of a poll
708    sources: Vec<Option<Pin<Box<S>>>>,
709    /// Cursor for fair round-robin polling
710    poll_cursor: usize,
711}
712
impl<T: Unpin, S: Stream<Item = T> + Send + Sync + ?Sized> Stream for MergeSource<T, S> {
    type Item = T;

    /// Polls the sources round-robin, starting at `poll_cursor`, and returns the first item any
    /// of them yields.
    ///
    /// Each source is polled at most once per call, and the next call resumes with the source
    /// after the one that produced an item. Sources that report end-of-stream are removed after
    /// the round, and `poll_cursor` is shifted so it still refers to the same next source.
    /// Returns `Ready(None)` once no sources remain. Otherwise it returns `Pending` when every
    /// remaining source returned `Pending` during this round.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let me = self.get_mut();
        let mut out = Poll::Pending;
        let mut any_removed = false;

        if !me.sources.is_empty() {
            // Remember where this round started so each source is polled at most once.
            let start_cursor = me.poll_cursor;

            loop {
                let current_length = me.sources.len();
                let source = &mut me.sources[me.poll_cursor];

                // Move cursor to next source for next poll
                me.poll_cursor = (me.poll_cursor + 1) % current_length;

                match source.as_mut().unwrap().as_mut().poll_next(cx) {
                    Poll::Ready(Some(data)) => {
                        out = Poll::Ready(Some(data));
                        break;
                    }
                    Poll::Ready(None) => {
                        // Only mark it here; the vector is compacted after the loop so that
                        // indices stay valid during the round.
                        *source = None; // Mark source as removed
                        any_removed = true;
                    }
                    Poll::Pending => {}
                }

                // Check if we've completed a full round
                if me.poll_cursor == start_cursor {
                    break;
                }
            }
        }

        // Clean up None entries and adjust cursor
        let mut current_index = 0;
        let original_cursor = me.poll_cursor;

        if any_removed {
            // Every removed entry before the cursor shifts the remaining ones left by one, so
            // the cursor moves left with them.
            me.sources.retain(|source| {
                if source.is_none() && current_index < original_cursor {
                    me.poll_cursor -= 1;
                }
                current_index += 1;
                source.is_some()
            });
        }

        // Wrap the cursor when it points one past the end (also covers an empty vector).
        if me.poll_cursor == me.sources.len() {
            me.poll_cursor = 0;
        }

        if me.sources.is_empty() {
            Poll::Ready(None)
        } else {
            out
        }
    }
}
775
776pub struct TaggedSource<T: Unpin, S: Stream<Item = Result<T, io::Error>> + Send + Sync + ?Sized> {
777    marker: PhantomData<T>,
778    id: u32,
779    source: Pin<Box<S>>,
780}
781
782impl<T: Unpin, S: Stream<Item = Result<T, io::Error>> + Send + Sync + ?Sized> Stream
783    for TaggedSource<T, S>
784{
785    type Item = Result<(u32, T), io::Error>;
786
787    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
788        let id = self.as_ref().id;
789        let source = &mut self.get_mut().source;
790        match source.as_mut().poll_next(cx) {
791            Poll::Ready(Some(v)) => Poll::Ready(Some(v.map(|d| (id, d)))),
792            Poll::Ready(None) => Poll::Ready(None),
793            Poll::Pending => Poll::Pending,
794        }
795    }
796}
797
798type MergedMux<T> = MergeSource<
799    Result<(u32, <T as ConnectedSource>::Output), io::Error>,
800    TaggedSource<<T as ConnectedSource>::Output, <T as ConnectedSource>::Stream>,
801>;
802
803pub struct ConnectedTagged<T: ConnectedSource>
804where
805    <T as ConnectedSource>::Output: 'static + Sync + Unpin,
806{
807    source: MergedMux<T>,
808}
809
810#[async_trait]
811impl<T: Connected + ConnectedSource> Connected for ConnectedTagged<T>
812where
813    <T as ConnectedSource>::Output: 'static + Sync + Unpin,
814{
815    fn from_defn(pipe: Connection) -> Self {
816        let sources = match pipe {
817            Connection::AsClient(ClientConnection::Tagged(pipe, id)) => {
818                vec![(
819                    Box::pin(T::from_defn(Connection::AsClient(*pipe)).into_source()),
820                    id,
821                )]
822            }
823
824            Connection::AsClient(ClientConnection::Merge(m)) => {
825                let mut sources = Vec::new();
826                for port in m {
827                    if let ClientConnection::Tagged(pipe, id) = port {
828                        sources.push((
829                            Box::pin(T::from_defn(Connection::AsClient(*pipe)).into_source()),
830                            id,
831                        ));
832                    } else {
833                        panic!("Merge port must be tagged");
834                    }
835                }
836
837                sources
838            }
839
840            Connection::AsServer(AcceptedServer::Tagged(pipe, id)) => {
841                vec![(
842                    Box::pin(T::from_defn(Connection::AsServer(*pipe)).into_source()),
843                    id,
844                )]
845            }
846
847            Connection::AsServer(AcceptedServer::Merge(m)) => {
848                let mut sources = Vec::new();
849                for port in m {
850                    if let AcceptedServer::Tagged(pipe, id) = port {
851                        sources.push((
852                            Box::pin(T::from_defn(Connection::AsServer(*pipe)).into_source()),
853                            id,
854                        ));
855                    } else {
856                        panic!("Merge port must be tagged");
857                    }
858                }
859
860                sources
861            }
862
863            _ => panic!("Cannot connect to a non-tagged pipe as a tagged"),
864        };
865
866        let mut connected_mux = Vec::new();
867        for (pipe, id) in sources {
868            connected_mux.push(Some(Box::pin(TaggedSource {
869                marker: PhantomData,
870                id,
871                source: pipe,
872            })));
873        }
874
875        let muxer = MergeSource {
876            marker: PhantomData,
877            sources: connected_mux,
878            poll_cursor: 0,
879        };
880
881        ConnectedTagged { source: muxer }
882    }
883}
884
885impl<T: ConnectedSource> ConnectedSource for ConnectedTagged<T>
886where
887    <T as ConnectedSource>::Output: 'static + Sync + Unpin,
888{
889    type Output = (u32, T::Output);
890    type Stream = MergeSource<Result<Self::Output, io::Error>, TaggedSource<T::Output, T::Stream>>;
891
892    fn into_source(self) -> Self::Stream {
893        self.source
894    }
895}
896
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::task::{Context, Poll};

    use futures::stream;

    use super::*;

    /// Waker that ignores wake-ups; the test streams are always ready, so it is never used.
    struct TestWaker;
    impl std::task::Wake for TestWaker {
        fn wake(self: Arc<Self>) {}
    }

    #[test]
    fn test_merge_source_fair_polling() {
        // Three always-ready streams whose values interleave to 1..=9 when polled fairly.
        let mut merge_source = MergeSource {
            marker: PhantomData,
            sources: vec![
                Some(Box::pin(stream::iter(vec![1, 4, 7]))),
                Some(Box::pin(stream::iter(vec![2, 5, 8]))),
                Some(Box::pin(stream::iter(vec![3, 6, 9]))),
            ],
            poll_cursor: 0,
        };

        let waker = Arc::new(TestWaker).into();
        let mut cx = Context::from_waker(&waker);

        // Drain the merged stream; stop on end-of-stream (or on `Pending`, which these
        // streams never return).
        let results: Vec<_> = std::iter::from_fn(|| {
            match Pin::new(&mut merge_source).poll_next(&mut cx) {
                Poll::Ready(next) => next,
                Poll::Pending => None,
            }
        })
        .collect();

        // Round-robin polling must interleave the streams: 1, 2, 3, 4, 5, 6, 7, 8, 9.
        assert_eq!(results, vec![1, 2, 3, 4, 5, 6, 7, 8, 9]);
    }
}