Skip to main content

hydro_deploy_integration/
lib.rs

1use std::borrow::Cow;
2use std::cell::RefCell;
3use std::collections::{BTreeMap, HashMap};
4use std::marker::PhantomData;
5use std::net::SocketAddr;
6use std::path::PathBuf;
7use std::pin::Pin;
8use std::task::{Context, Poll};
9use std::time::Duration;
10
11use async_recursion::async_recursion;
12use bytes::{Bytes, BytesMut};
13use futures::sink::Buffer;
14use futures::stream::{FuturesUnordered, SplitSink, SplitStream};
15use futures::{Future, Sink, SinkExt, Stream, StreamExt, stream};
16use serde::{Deserialize, Serialize};
17use tempfile::TempDir;
18use tokio::io;
19use tokio::net::{TcpListener, TcpStream};
20#[cfg(unix)]
21use tokio::net::{UnixListener, UnixStream};
22use tokio_stream::wrappers::TcpListenerStream;
23use tokio_util::codec::{Framed, LengthDelimitedCodec};
24
25pub mod multi_connection;
26pub mod single_connection;
27
/// Initialization payload for a deployed program: a bind configuration per port (presumably keyed
/// by port name — confirm against the Hydro Deploy launcher), plus an optional string.
// NOTE(review): the meaning of the `Option<Cow<str>>` component is not visible in this file;
// it is presumably serialized metadata — confirm before relying on it.
pub type InitConfig<'a> = (HashMap<String, ServerBindConfig>, Option<Cow<'a, str>>);
29
/// Contains runtime information passed by Hydro Deploy to a program,
/// describing how to connect to other services and metadata about them.
pub struct DeployPorts<T = Option<()>> {
    /// Connections keyed by port name. Wrapped in a `RefCell` so that [`DeployPorts::port`] can
    /// remove (take ownership of) an entry through a shared reference.
    pub ports: RefCell<HashMap<String, Connection>>,
    /// Deployment-specific metadata; its type defaults to `Option<()>`.
    pub meta: T,
}
36
37impl<T> DeployPorts<T> {
38    pub fn port(&self, name: &str) -> Connection {
39        self.ports
40            .try_borrow_mut()
41            .unwrap()
42            .remove(name)
43            .unwrap_or_else(|| panic!("port {} not found", name))
44    }
45}
46
// On platforms without Unix domain sockets, these aliases stand in for the Tokio types so the
// `UnixSocket` variants of the connection enums still type-check. `Infallible` has no values,
// and the Unix-socket branches in this file panic on such platforms instead of using them.
#[cfg(not(unix))]
type UnixStream = std::convert::Infallible;

#[cfg(not(unix))]
type UnixListener = std::convert::Infallible;
52
/// Describes how to connect to a service which is listening on some port.
#[derive(Serialize, Deserialize, Clone, Debug)]
pub enum ServerPort {
    /// A Unix domain socket at the given filesystem path (Unix platforms only).
    UnixSocket(PathBuf),
    /// A TCP socket at the given address.
    TcpPort(SocketAddr),
    /// Several ports keyed by a `u32` id; connecting yields a [`ClientConnection::Demux`].
    Demux(BTreeMap<u32, ServerPort>),
    /// Several ports whose incoming data is merged; connecting yields a
    /// [`ClientConnection::Merge`].
    Merge(Vec<ServerPort>),
    /// A port annotated with a `u32` tag.
    Tagged(Box<ServerPort>, u32),
    /// A placeholder that connects to nothing.
    Null,
}
63
64impl ServerPort {
65    #[async_recursion]
66    pub async fn connect(&self) -> ClientConnection {
67        match self {
68            ServerPort::UnixSocket(path) => {
69                #[cfg(unix)]
70                {
71                    let bound = UnixStream::connect(path.clone());
72                    ClientConnection::UnixSocket(bound.await.unwrap())
73                }
74
75                #[cfg(not(unix))]
76                {
77                    let _ = path;
78                    panic!("Unix sockets are not supported on this platform")
79                }
80            }
81            ServerPort::TcpPort(addr) => {
82                let addr_clone = *addr;
83                let stream = async_retry(
84                    move || TcpStream::connect(addr_clone),
85                    10,
86                    Duration::from_secs(1),
87                )
88                .await
89                .unwrap();
90                ClientConnection::TcpPort(stream)
91            }
92            ServerPort::Demux(bindings) => ClientConnection::Demux(
93                bindings
94                    .iter()
95                    .map(|(k, v)| async move { (*k, v.connect().await) })
96                    .collect::<FuturesUnordered<_>>()
97                    .collect::<BTreeMap<_, _>>()
98                    .await,
99            ),
100            ServerPort::Merge(ports) => ClientConnection::Merge(
101                ports
102                    .iter()
103                    .map(|p| p.connect())
104                    .collect::<FuturesUnordered<_>>()
105                    .collect::<Vec<_>>()
106                    .await,
107            ),
108            ServerPort::Tagged(port, tag) => {
109                ClientConnection::Tagged(Box::new(port.as_ref().connect().await), *tag)
110            }
111            ServerPort::Null => ClientConnection::Null,
112        }
113    }
114
115    pub async fn instantiate(&self) -> Connection {
116        Connection::AsClient(self.connect().await)
117    }
118}
119
/// An established client-side connection, mirroring the shape of the [`ServerPort`] it was
/// created from by [`ServerPort::connect`].
#[derive(Debug)]
pub enum ClientConnection {
    /// A connected Unix domain socket.
    UnixSocket(UnixStream),
    /// A connected TCP stream.
    TcpPort(TcpStream),
    /// Connections keyed by id.
    Demux(BTreeMap<u32, ClientConnection>),
    /// Connections whose incoming data is to be merged. When produced by
    /// [`ServerPort::connect`], the order reflects connection completion order, not the order of
    /// the original ports.
    Merge(Vec<ClientConnection>),
    /// A connection annotated with a `u32` tag.
    Tagged(Box<ClientConnection>, u32),
    /// A placeholder connection that carries no data.
    Null,
}
129
/// Describes how a service should bind the port(s) it listens on; see [`ServerBindConfig::bind`].
#[derive(Serialize, Deserialize, Clone, Debug)]
pub enum ServerBindConfig {
    /// A Unix domain socket created inside a fresh temporary directory (Unix platforms only).
    UnixSocket,
    /// A TCP listener.
    TcpPort(
        /// The host the port should be bound on.
        String,
        /// The port the service should listen on.
        ///
        /// If `None`, the port will be chosen automatically.
        Option<u16>,
    ),
    /// Several bindings keyed by id.
    Demux(BTreeMap<u32, ServerBindConfig>),
    /// Several bindings whose incoming data is merged.
    Merge(Vec<ServerBindConfig>),
    /// A binding annotated with a `u32` tag.
    Tagged(Box<ServerBindConfig>, u32),
    /// A binding whose listener is not accepted by [`accept_bound`]; it is passed through as
    /// [`AcceptedServer::MultiConnection`] (presumably for the `multi_connection` module — confirm).
    MultiConnection(Box<ServerBindConfig>),
    /// A placeholder that binds nothing.
    Null,
}
147
148impl ServerBindConfig {
149    #[async_recursion]
150    pub async fn bind(self) -> BoundServer {
151        match self {
152            ServerBindConfig::UnixSocket => {
153                #[cfg(unix)]
154                {
155                    let dir = tempfile::tempdir().unwrap();
156                    let socket_path = dir.path().join("socket");
157                    let bound = UnixListener::bind(socket_path).unwrap();
158                    BoundServer::UnixSocket(bound, dir)
159                }
160
161                #[cfg(not(unix))]
162                {
163                    panic!("Unix sockets are not supported on this platform")
164                }
165            }
166            ServerBindConfig::TcpPort(host, port) => {
167                let listener = TcpListener::bind((host, port.unwrap_or(0)))
168                    .await
169                    .unwrap_or_else(|e| panic!("Failed to bind port {:?}: {}", port, e));
170                let addr = listener.local_addr().unwrap();
171                BoundServer::TcpPort(TcpListenerStream::new(listener), addr)
172            }
173            ServerBindConfig::Demux(bindings) => {
174                let mut demux = BTreeMap::new();
175                for (key, bind) in bindings {
176                    demux.insert(key, bind.bind().await); // TODO(mingwei): Do in parallel.
177                }
178                BoundServer::Demux(demux)
179            }
180            ServerBindConfig::Merge(bindings) => {
181                let mut merge = Vec::new();
182                for bind in bindings {
183                    merge.push(bind.bind().await); // TODO(mingwei): Do in parallel.
184                }
185                BoundServer::Merge(merge)
186            }
187            ServerBindConfig::Tagged(underlying, id) => {
188                BoundServer::Tagged(Box::new(underlying.bind().await), id)
189            }
190            ServerBindConfig::MultiConnection(underlying) => {
191                BoundServer::MultiConnection(Box::new(underlying.bind().await))
192            }
193            ServerBindConfig::Null => BoundServer::Null,
194        }
195    }
196}
197
/// A raw connection handed to a program. It is either dialed out as a client
/// ([`Connection::AsClient`]) or accepted on a bound listener ([`Connection::AsServer`]).
#[derive(Debug)]
pub enum Connection {
    /// A connection this side initiated (see [`ServerPort::instantiate`]).
    AsClient(ClientConnection),
    /// A connection accepted from a bound listener (see [`accept_bound`]).
    AsServer(AcceptedServer),
}
203
204impl Connection {
205    pub fn connect<T: Connected>(self) -> T {
206        T::from_defn(self)
207    }
208}
209
210pub type DynStream = Pin<Box<dyn Stream<Item = Result<BytesMut, io::Error>> + Send + Sync>>;
211
212pub type DynSink<Input> = Pin<Box<dyn Sink<Input, Error = io::Error> + Send + Sync>>;
213
214pub trait StreamSink:
215    Stream<Item = Result<BytesMut, io::Error>> + Sink<Bytes, Error = io::Error>
216{
217}
218impl<T: Stream<Item = Result<BytesMut, io::Error>> + Sink<Bytes, Error = io::Error>> StreamSink
219    for T
220{
221}
222
223pub type DynStreamSink = Pin<Box<dyn StreamSink + Send + Sync>>;
224
/// A typed endpoint that can be constructed from a raw [`Connection`].
pub trait Connected: Send {
    /// Builds the endpoint from `pipe`. The implementations in this file panic when `pipe` has a
    /// shape they cannot handle.
    fn from_defn(pipe: Connection) -> Self;
}

/// An endpoint that can be turned into a [`Sink`] of `Input` items.
pub trait ConnectedSink {
    /// The item type accepted by the sink.
    type Input: Send;
    /// The concrete sink type.
    type Sink: Sink<Self::Input, Error = io::Error> + Send + Sync;

    /// Consumes the endpoint and returns its sending side.
    fn into_sink(self) -> Self::Sink;
}

/// An endpoint that can be turned into a [`Stream`] of `Output` items.
pub trait ConnectedSource {
    /// The item type yielded (inside `Result`) by the stream.
    type Output: Send;
    /// The concrete stream type.
    type Stream: Stream<Item = Result<Self::Output, io::Error>> + Send + Sync;
    /// Consumes the endpoint and returns its receiving side.
    fn into_source(self) -> Self::Stream;
}
241
/// Listener(s) bound according to a [`ServerBindConfig`] that have not yet accepted a
/// connection.
#[derive(Debug)]
pub enum BoundServer {
    /// A Unix listener, plus the temporary directory holding its socket file (kept alive with it).
    UnixSocket(UnixListener, TempDir),
    /// A TCP listener together with its local address.
    TcpPort(TcpListenerStream, SocketAddr),
    /// Listeners keyed by id.
    Demux(BTreeMap<u32, BoundServer>),
    /// Listeners whose incoming data is merged.
    Merge(Vec<BoundServer>),
    /// A listener annotated with a `u32` tag.
    Tagged(Box<BoundServer>, u32),
    /// A listener that [`accept_bound`] passes through without accepting.
    MultiConnection(Box<BoundServer>),
    /// A placeholder with no listener.
    Null,
}

/// The result of accepting on a [`BoundServer`] with [`accept_bound`]; same shape as the input.
#[derive(Debug)]
pub enum AcceptedServer {
    /// An accepted Unix stream, plus the temporary directory that held the socket file.
    UnixSocket(UnixStream, TempDir),
    /// An accepted TCP stream.
    TcpPort(TcpStream),
    /// Accepted connections keyed by id.
    Demux(BTreeMap<u32, AcceptedServer>),
    /// Accepted connections whose incoming data is merged. The order follows accept completion,
    /// not the order in the originating `BoundServer::Merge`.
    Merge(Vec<AcceptedServer>),
    /// An accepted connection annotated with a `u32` tag.
    Tagged(Box<AcceptedServer>, u32),
    /// A listener passed through unaccepted.
    MultiConnection(Box<BoundServer>),
    /// A placeholder with no connection.
    Null,
}
263
264#[async_recursion]
265pub async fn accept_bound(bound: BoundServer) -> AcceptedServer {
266    match bound {
267        BoundServer::UnixSocket(listener, dir) => {
268            #[cfg(unix)]
269            {
270                let stream = listener.accept().await.unwrap().0;
271                AcceptedServer::UnixSocket(stream, dir)
272            }
273
274            #[cfg(not(unix))]
275            {
276                let _ = listener;
277                let _ = dir;
278                panic!("Unix sockets are not supported on this platform")
279            }
280        }
281        BoundServer::TcpPort(mut listener, _) => {
282            let stream = listener.next().await.unwrap().unwrap();
283            AcceptedServer::TcpPort(stream)
284        }
285        BoundServer::Demux(bindings) => AcceptedServer::Demux(
286            bindings
287                .into_iter()
288                .map(|(k, b)| async move { (k, accept_bound(b).await) })
289                .collect::<FuturesUnordered<_>>()
290                .collect::<Vec<_>>()
291                .await
292                .into_iter()
293                .collect(),
294        ),
295        BoundServer::Merge(merge) => AcceptedServer::Merge(
296            merge
297                .into_iter()
298                .map(|b| async move { accept_bound(b).await })
299                .collect::<FuturesUnordered<_>>()
300                .collect::<Vec<_>>()
301                .await,
302        ),
303        BoundServer::Tagged(underlying, id) => {
304            AcceptedServer::Tagged(Box::new(accept_bound(*underlying).await), id)
305        }
306        BoundServer::MultiConnection(underlying) => AcceptedServer::MultiConnection(underlying),
307        BoundServer::Null => AcceptedServer::Null,
308    }
309}
310
311impl BoundServer {
312    pub fn server_port(&self) -> ServerPort {
313        match self {
314            BoundServer::UnixSocket(_, tempdir) => {
315                #[cfg(unix)]
316                {
317                    ServerPort::UnixSocket(tempdir.path().join("socket"))
318                }
319
320                #[cfg(not(unix))]
321                {
322                    let _ = tempdir;
323                    panic!("Unix sockets are not supported on this platform")
324                }
325            }
326            BoundServer::TcpPort(_, addr) => {
327                ServerPort::TcpPort(SocketAddr::new(addr.ip(), addr.port()))
328            }
329
330            BoundServer::Demux(bindings) => {
331                let mut demux = BTreeMap::new();
332                for (key, bind) in bindings {
333                    demux.insert(*key, bind.server_port());
334                }
335                ServerPort::Demux(demux)
336            }
337
338            BoundServer::Merge(bindings) => {
339                let mut merge = Vec::new();
340                for bind in bindings {
341                    merge.push(bind.server_port());
342                }
343                ServerPort::Merge(merge)
344            }
345
346            BoundServer::Tagged(underlying, id) => {
347                ServerPort::Tagged(Box::new(underlying.server_port()), *id)
348            }
349
350            BoundServer::MultiConnection(underlying) => underlying.server_port(),
351
352            BoundServer::Null => ServerPort::Null,
353        }
354    }
355}
356
357fn accept(bound: AcceptedServer) -> ConnectedDirect {
358    match bound {
359        AcceptedServer::UnixSocket(stream, _dir) => {
360            #[cfg(unix)]
361            {
362                ConnectedDirect {
363                    stream_sink: Some(Box::pin(unix_bytes(stream))),
364                    source_only: None,
365                    sink_only: None,
366                }
367            }
368
369            #[cfg(not(unix))]
370            {
371                let _ = stream;
372                panic!("Unix sockets are not supported on this platform")
373            }
374        }
375        AcceptedServer::TcpPort(stream) => ConnectedDirect {
376            stream_sink: Some(Box::pin(tcp_bytes(stream))),
377            source_only: None,
378            sink_only: None,
379        },
380        AcceptedServer::Merge(merge) => {
381            let mut sources = vec![];
382            for bound in merge {
383                sources.push(Some(Box::pin(accept(bound).into_source())));
384            }
385
386            let merge_source: DynStream = Box::pin(MergeSource {
387                marker: PhantomData,
388                sources,
389                poll_cursor: 0,
390            });
391
392            ConnectedDirect {
393                stream_sink: None,
394                source_only: Some(merge_source),
395                sink_only: None,
396            }
397        }
398        AcceptedServer::Demux(_) => panic!("Cannot connect to a demux pipe directly"),
399        AcceptedServer::Tagged(_, _) => panic!("Cannot connect to a tagged pipe directly"),
400        AcceptedServer::MultiConnection(_) => {
401            panic!("Cannot connect to a multi-connection pipe directly")
402        }
403        AcceptedServer::Null => {
404            ConnectedDirect::from_defn(Connection::AsClient(ClientConnection::Null))
405        }
406    }
407}
408
409fn tcp_bytes(stream: TcpStream) -> impl StreamSink {
410    Framed::new(stream, LengthDelimitedCodec::new())
411}
412
413#[cfg(unix)]
414fn unix_bytes(stream: UnixStream) -> impl StreamSink {
415    Framed::new(stream, LengthDelimitedCodec::new())
416}
417
418struct IoErrorDrain<T> {
419    marker: PhantomData<T>,
420}
421
422impl<T> Sink<T> for IoErrorDrain<T> {
423    type Error = io::Error;
424
425    fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
426        Poll::Ready(Ok(()))
427    }
428
429    fn start_send(self: Pin<&mut Self>, _item: T) -> Result<(), Self::Error> {
430        Ok(())
431    }
432
433    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
434        Poll::Ready(Ok(()))
435    }
436
437    fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
438        Poll::Ready(Ok(()))
439    }
440}
441
442async fn async_retry<T, E, F: Future<Output = Result<T, E>>>(
443    thunk: impl Fn() -> F,
444    count: usize,
445    delay: Duration,
446) -> Result<T, E> {
447    for _ in 1..count {
448        let result = thunk().await;
449        if result.is_ok() {
450            return result;
451        } else {
452            tokio::time::sleep(delay).await;
453        }
454    }
455
456    thunk().await
457}
458
/// A byte-frame endpoint built by [`Connected::from_defn`], possibly over merged connections.
///
/// Which fields are populated depends on the connection shape:
/// - Unix/TCP streams set `stream_sink`.
/// - Merged inputs set only `source_only`.
/// - `Null` sets `source_only` and `sink_only`.
pub struct ConnectedDirect {
    /// A duplex channel usable as both source and sink.
    stream_sink: Option<DynStreamSink>,
    /// A receive-only stream, used when there is no duplex channel.
    source_only: Option<DynStream>,
    /// A send-only sink, used when there is no duplex channel.
    sink_only: Option<DynSink<Bytes>>,
}
464
465impl ConnectedDirect {
466    pub fn into_source_sink(self) -> (SplitStream<DynStreamSink>, SplitSink<DynStreamSink, Bytes>) {
467        let (sink, stream) = self.stream_sink.unwrap().split();
468        (stream, sink)
469    }
470}
471
472impl Connected for ConnectedDirect {
473    fn from_defn(pipe: Connection) -> Self {
474        match pipe {
475            Connection::AsClient(ClientConnection::UnixSocket(stream)) => {
476                #[cfg(unix)]
477                {
478                    ConnectedDirect {
479                        stream_sink: Some(Box::pin(unix_bytes(stream))),
480                        source_only: None,
481                        sink_only: None,
482                    }
483                }
484
485                #[cfg(not(unix))]
486                {
487                    let _ = stream;
488                    panic!("Unix sockets are not supported on this platform");
489                }
490            }
491            Connection::AsClient(ClientConnection::TcpPort(stream)) => {
492                stream.set_nodelay(true).unwrap();
493                ConnectedDirect {
494                    stream_sink: Some(Box::pin(tcp_bytes(stream))),
495                    source_only: None,
496                    sink_only: None,
497                }
498            }
499            Connection::AsClient(ClientConnection::Merge(merge)) => {
500                let sources = merge
501                    .into_iter()
502                    .map(|port| {
503                        Some(Box::pin(
504                            ConnectedDirect::from_defn(Connection::AsClient(port)).into_source(),
505                        ))
506                    })
507                    .collect::<Vec<_>>();
508
509                let merged = MergeSource {
510                    marker: PhantomData,
511                    sources,
512                    poll_cursor: 0,
513                };
514
515                ConnectedDirect {
516                    stream_sink: None,
517                    source_only: Some(Box::pin(merged)),
518                    sink_only: None,
519                }
520            }
521            Connection::AsClient(ClientConnection::Demux(_)) => {
522                panic!("Cannot connect to a demux pipe directly")
523            }
524
525            Connection::AsClient(ClientConnection::Tagged(_, _)) => {
526                panic!("Cannot connect to a tagged pipe directly")
527            }
528
529            Connection::AsClient(ClientConnection::Null) => ConnectedDirect {
530                stream_sink: None,
531                source_only: Some(Box::pin(stream::empty())),
532                sink_only: Some(Box::pin(IoErrorDrain {
533                    marker: PhantomData,
534                })),
535            },
536
537            Connection::AsServer(bound) => accept(bound),
538        }
539    }
540}
541
542impl ConnectedSource for ConnectedDirect {
543    type Output = BytesMut;
544    type Stream = DynStream;
545
546    fn into_source(mut self) -> DynStream {
547        if let Some(s) = self.stream_sink.take() {
548            Box::pin(s)
549        } else {
550            self.source_only.take().unwrap()
551        }
552    }
553}
554
555impl ConnectedSink for ConnectedDirect {
556    type Input = Bytes;
557    type Sink = DynSink<Bytes>;
558
559    fn into_sink(mut self) -> DynSink<Self::Input> {
560        if let Some(s) = self.stream_sink.take() {
561            Box::pin(s)
562        } else {
563            self.sink_only.take().unwrap()
564        }
565    }
566}
567
/// The sink type produced by [`ConnectedDemux`]: a `sinktools` demux map over per-key sinks,
/// each wrapped in a [`Buffer`]. It presumably routes each `(key, item)` pair to the sink
/// registered under `key`; see `sinktools::demux_map` to confirm.
pub type BufferedDrain<S, I> = sinktools::demux_map::DemuxMap<u32, Pin<Box<Buffer<S, I>>>>;

/// A send-only endpoint over a `Demux` connection. Each inner connection is addressed by its
/// `u32` key.
pub struct ConnectedDemux<T: ConnectedSink>
where
    <T as ConnectedSink>::Input: Sync,
{
    /// The keys of all inner connections, in ascending order (taken from a `BTreeMap`).
    pub keys: Vec<u32>,
    /// The demultiplexing sink. It is always `Some` when built by `from_defn`; the `Option` lets
    /// `into_sink` move it out.
    sink: Option<BufferedDrain<T::Sink, T::Input>>,
}
577
578impl<T: Connected + ConnectedSink> Connected for ConnectedDemux<T>
579where
580    <T as ConnectedSink>::Input: 'static + Sync,
581{
582    fn from_defn(pipe: Connection) -> Self {
583        match pipe {
584            Connection::AsClient(ClientConnection::Demux(demux)) => {
585                let mut connected_demux = HashMap::new();
586                let keys = demux.keys().cloned().collect();
587                for (id, pipe) in demux {
588                    connected_demux.insert(
589                        id,
590                        Box::pin(
591                            T::from_defn(Connection::AsClient(pipe))
592                                .into_sink()
593                                .buffer(1024),
594                        ),
595                    );
596                }
597
598                let demuxer = sinktools::demux_map(connected_demux);
599
600                ConnectedDemux {
601                    keys,
602                    sink: Some(demuxer),
603                }
604            }
605
606            Connection::AsServer(AcceptedServer::Demux(demux)) => {
607                let mut connected_demux = HashMap::new();
608                let keys = demux.keys().cloned().collect();
609                for (id, bound) in demux {
610                    connected_demux.insert(
611                        id,
612                        Box::pin(
613                            T::from_defn(Connection::AsServer(bound))
614                                .into_sink()
615                                .buffer(1024),
616                        ),
617                    );
618                }
619
620                let demuxer = sinktools::demux_map(connected_demux);
621
622                ConnectedDemux {
623                    keys,
624                    sink: Some(demuxer),
625                }
626            }
627            _ => panic!("Cannot connect to a non-demux pipe as a demux"),
628        }
629    }
630}
631
632impl<T: ConnectedSink> ConnectedSink for ConnectedDemux<T>
633where
634    <T as ConnectedSink>::Input: 'static + Sync,
635{
636    type Input = (u32, T::Input);
637    type Sink = BufferedDrain<T::Sink, T::Input>;
638
639    fn into_sink(mut self) -> Self::Sink {
640        self.sink.take().unwrap()
641    }
642}
643
644pub struct MergeSource<T: Unpin, S: Stream<Item = T> + Send + Sync + ?Sized> {
645    marker: PhantomData<T>,
646    /// Ordered list for fair polling, will never be `None` at the beginning of a poll
647    sources: Vec<Option<Pin<Box<S>>>>,
648    /// Cursor for fair round-robin polling
649    poll_cursor: usize,
650}
651
652impl<T: Unpin, S: Stream<Item = T> + Send + Sync + ?Sized> Stream for MergeSource<T, S> {
653    type Item = T;
654
655    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
656        let me = self.get_mut();
657        let mut out = Poll::Pending;
658        let mut any_removed = false;
659
660        if !me.sources.is_empty() {
661            let start_cursor = me.poll_cursor;
662
663            loop {
664                let current_length = me.sources.len();
665                let source = &mut me.sources[me.poll_cursor];
666
667                // Move cursor to next source for next poll
668                me.poll_cursor = (me.poll_cursor + 1) % current_length;
669
670                match source.as_mut().unwrap().as_mut().poll_next(cx) {
671                    Poll::Ready(Some(data)) => {
672                        out = Poll::Ready(Some(data));
673                        break;
674                    }
675                    Poll::Ready(None) => {
676                        *source = None; // Mark source as removed
677                        any_removed = true;
678                    }
679                    Poll::Pending => {}
680                }
681
682                // Check if we've completed a full round
683                if me.poll_cursor == start_cursor {
684                    break;
685                }
686            }
687        }
688
689        // Clean up None entries and adjust cursor
690        let mut current_index = 0;
691        let original_cursor = me.poll_cursor;
692
693        if any_removed {
694            me.sources.retain(|source| {
695                if source.is_none() && current_index < original_cursor {
696                    me.poll_cursor -= 1;
697                }
698                current_index += 1;
699                source.is_some()
700            });
701        }
702
703        if me.poll_cursor == me.sources.len() {
704            me.poll_cursor = 0;
705        }
706
707        if me.sources.is_empty() {
708            Poll::Ready(None)
709        } else {
710            out
711        }
712    }
713}
714
715pub struct TaggedSource<T: Unpin, S: Stream<Item = Result<T, io::Error>> + Send + Sync + ?Sized> {
716    marker: PhantomData<T>,
717    id: u32,
718    source: Pin<Box<S>>,
719}
720
721impl<T: Unpin, S: Stream<Item = Result<T, io::Error>> + Send + Sync + ?Sized> Stream
722    for TaggedSource<T, S>
723{
724    type Item = Result<(u32, T), io::Error>;
725
726    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
727        let id = self.as_ref().id;
728        let source = &mut self.get_mut().source;
729        match source.as_mut().poll_next(cx) {
730            Poll::Ready(Some(v)) => Poll::Ready(Some(v.map(|d| (id, d)))),
731            Poll::Ready(None) => Poll::Ready(None),
732            Poll::Pending => Poll::Pending,
733        }
734    }
735}
736
/// The merged stream stored inside [`ConnectedTagged`]: a [`MergeSource`] over one
/// [`TaggedSource`] per tagged connection.
type MergedMux<T> = MergeSource<
    Result<(u32, <T as ConnectedSource>::Output), io::Error>,
    TaggedSource<<T as ConnectedSource>::Output, <T as ConnectedSource>::Stream>,
>;

/// A receive-only endpoint over one or more `Tagged` connections. Every received item is paired
/// with the tag of the connection it arrived on.
pub struct ConnectedTagged<T: ConnectedSource>
where
    <T as ConnectedSource>::Output: 'static + Sync + Unpin,
{
    source: MergedMux<T>,
}
748
749impl<T: Connected + ConnectedSource> Connected for ConnectedTagged<T>
750where
751    <T as ConnectedSource>::Output: 'static + Sync + Unpin,
752{
753    fn from_defn(pipe: Connection) -> Self {
754        let sources = match pipe {
755            Connection::AsClient(ClientConnection::Tagged(pipe, id)) => {
756                vec![(
757                    Box::pin(T::from_defn(Connection::AsClient(*pipe)).into_source()),
758                    id,
759                )]
760            }
761
762            Connection::AsClient(ClientConnection::Merge(m)) => {
763                let mut sources = Vec::new();
764                for port in m {
765                    if let ClientConnection::Tagged(pipe, id) = port {
766                        sources.push((
767                            Box::pin(T::from_defn(Connection::AsClient(*pipe)).into_source()),
768                            id,
769                        ));
770                    } else {
771                        panic!("Merge port must be tagged");
772                    }
773                }
774
775                sources
776            }
777
778            Connection::AsServer(AcceptedServer::Tagged(pipe, id)) => {
779                vec![(
780                    Box::pin(T::from_defn(Connection::AsServer(*pipe)).into_source()),
781                    id,
782                )]
783            }
784
785            Connection::AsServer(AcceptedServer::Merge(m)) => {
786                let mut sources = Vec::new();
787                for port in m {
788                    if let AcceptedServer::Tagged(pipe, id) = port {
789                        sources.push((
790                            Box::pin(T::from_defn(Connection::AsServer(*pipe)).into_source()),
791                            id,
792                        ));
793                    } else {
794                        panic!("Merge port must be tagged");
795                    }
796                }
797
798                sources
799            }
800
801            _ => panic!("Cannot connect to a non-tagged pipe as a tagged"),
802        };
803
804        let mut connected_mux = Vec::new();
805        for (pipe, id) in sources {
806            connected_mux.push(Some(Box::pin(TaggedSource {
807                marker: PhantomData,
808                id,
809                source: pipe,
810            })));
811        }
812
813        let muxer = MergeSource {
814            marker: PhantomData,
815            sources: connected_mux,
816            poll_cursor: 0,
817        };
818
819        ConnectedTagged { source: muxer }
820    }
821}
822
823impl<T: ConnectedSource> ConnectedSource for ConnectedTagged<T>
824where
825    <T as ConnectedSource>::Output: 'static + Sync + Unpin,
826{
827    type Output = (u32, T::Output);
828    type Stream = MergeSource<Result<Self::Output, io::Error>, TaggedSource<T::Output, T::Stream>>;
829
830    fn into_source(self) -> Self::Stream {
831        self.source
832    }
833}
834
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::task::{Context, Poll};

    use futures::stream;

    use super::*;

    /// A waker that does nothing; every stream in these tests is always ready.
    struct TestWaker;
    impl std::task::Wake for TestWaker {
        fn wake(self: Arc<Self>) {}
    }

    /// Polls `merge_source` until it reports the end of the stream, returning every item in the
    /// order produced.
    ///
    /// Panics on `Poll::Pending`. All test sources are always ready, so `Pending` would indicate
    /// a bug (e.g. failing to report termination) rather than the end of the stream.
    fn drain<S: Stream<Item = i32> + Send + Sync>(
        merge_source: &mut MergeSource<i32, S>,
    ) -> Vec<i32> {
        let waker = Arc::new(TestWaker).into();
        let mut cx = Context::from_waker(&waker);
        let mut results = Vec::new();

        loop {
            match Pin::new(&mut *merge_source).poll_next(&mut cx) {
                Poll::Ready(Some(value)) => results.push(value),
                Poll::Ready(None) => return results,
                Poll::Pending => panic!("MergeSource returned Pending with always-ready sources"),
            }
        }
    }

    #[test]
    fn test_merge_source_fair_polling() {
        // Create test streams that yield values in a predictable pattern
        let stream1 = Box::pin(stream::iter(vec![1, 4, 7]));
        let stream2 = Box::pin(stream::iter(vec![2, 5, 8]));
        let stream3 = Box::pin(stream::iter(vec![3, 6, 9]));

        let mut merge_source = MergeSource {
            marker: PhantomData,
            sources: vec![Some(stream1), Some(stream2), Some(stream3)],
            poll_cursor: 0,
        };

        // With fair polling, we should get values in round-robin order: 1, 2, 3, 4, 5, 6, 7, 8, 9
        assert_eq!(drain(&mut merge_source), vec![1, 2, 3, 4, 5, 6, 7, 8, 9]);
        // Once finished, the stream keeps reporting its end.
        assert_eq!(drain(&mut merge_source), Vec::<i32>::new());
    }

    #[test]
    fn test_merge_source_fairness_survives_exhausted_sources() {
        // Sources of different lengths exercise the cursor adjustment that happens when an
        // exhausted source is removed in the middle of a round.
        let mut merge_source = MergeSource {
            marker: PhantomData,
            sources: vec![
                Some(Box::pin(stream::iter(vec![1]))),
                Some(Box::pin(stream::iter(vec![2, 4, 6]))),
                Some(Box::pin(stream::iter(vec![3, 5]))),
            ],
            poll_cursor: 0,
        };

        assert_eq!(drain(&mut merge_source), vec![1, 2, 3, 4, 5, 6]);
    }

    #[test]
    fn test_merge_source_without_sources_ends_immediately() {
        let mut merge_source: MergeSource<i32, stream::Iter<std::vec::IntoIter<i32>>> =
            MergeSource {
                marker: PhantomData,
                sources: Vec::new(),
                poll_cursor: 0,
            };

        assert_eq!(drain(&mut merge_source), Vec::<i32>::new());
    }
}