1use std::borrow::Cow;
2use std::cell::RefCell;
3use std::collections::{BTreeMap, HashMap};
4use std::marker::PhantomData;
5use std::net::SocketAddr;
6use std::path::PathBuf;
7use std::pin::Pin;
8use std::task::{Context, Poll};
9use std::time::Duration;
10
11use async_recursion::async_recursion;
12use bytes::{Bytes, BytesMut};
13use futures::sink::Buffer;
14use futures::stream::{FuturesUnordered, SplitSink, SplitStream};
15use futures::{Future, Sink, SinkExt, Stream, StreamExt, stream};
16use serde::{Deserialize, Serialize};
17use tempfile::TempDir;
18use tokio::io;
19use tokio::net::{TcpListener, TcpStream};
20#[cfg(unix)]
21use tokio::net::{UnixListener, UnixStream};
22use tokio_stream::wrappers::TcpListenerStream;
23use tokio_util::codec::{Framed, LengthDelimitedCodec};
24
25pub mod multi_connection;
26pub mod single_connection;
27
/// Initial port configuration: a map from port name to how that port should be bound,
/// plus an optional string. NOTE(review): the second element is presumably serialized
/// metadata for `DeployPorts::meta` — confirm against the code that consumes it.
pub type InitConfig<'a> = (HashMap<String, ServerBindConfig>, Option<Cow<'a, str>>);
29
30pub struct DeployPorts<T = Option<()>> {
33 pub ports: RefCell<HashMap<String, Connection>>,
34 pub meta: T,
35}
36
37impl<T> DeployPorts<T> {
38 pub fn port(&self, name: &str) -> Connection {
39 self.ports
40 .try_borrow_mut()
41 .unwrap()
42 .remove(name)
43 .unwrap_or_else(|| panic!("port {} not found", name))
44 }
45}
46
// On platforms without Unix sockets, stand in uninhabited placeholder types so the
// `UnixSocket` enum variants below still type-check; since `Infallible` has no values,
// those variants can never actually be constructed on such platforms.
#[cfg(not(unix))]
type UnixStream = std::convert::Infallible;

#[cfg(not(unix))]
type UnixListener = std::convert::Infallible;
52
/// Serializable description of where a bound server can be reached, used by clients to
/// connect (see `ServerPort::connect`). Produced from a bound server by
/// `BoundServer::server_port`.
#[derive(Serialize, Deserialize, Clone, Debug)]
pub enum ServerPort {
    /// Path to a Unix domain socket.
    UnixSocket(PathBuf),
    /// TCP address to connect to.
    TcpPort(SocketAddr),
    /// Several sub-ports, keyed by a `u32` id.
    Demux(BTreeMap<u32, ServerPort>),
    /// Several sub-ports whose connections are combined into one.
    Merge(Vec<ServerPort>),
    /// A sub-port labelled with a `u32` tag.
    Tagged(Box<ServerPort>, u32),
    /// A port with no underlying socket.
    Null,
}
63
64impl ServerPort {
65 #[async_recursion]
66 pub async fn connect(&self) -> ClientConnection {
67 match self {
68 ServerPort::UnixSocket(path) => {
69 #[cfg(unix)]
70 {
71 let bound = UnixStream::connect(path.clone());
72 ClientConnection::UnixSocket(bound.await.unwrap())
73 }
74
75 #[cfg(not(unix))]
76 {
77 let _ = path;
78 panic!("Unix sockets are not supported on this platform")
79 }
80 }
81 ServerPort::TcpPort(addr) => {
82 let addr_clone = *addr;
83 let stream = async_retry(
84 move || TcpStream::connect(addr_clone),
85 10,
86 Duration::from_secs(1),
87 )
88 .await
89 .unwrap();
90 ClientConnection::TcpPort(stream)
91 }
92 ServerPort::Demux(bindings) => ClientConnection::Demux(
93 bindings
94 .iter()
95 .map(|(k, v)| async move { (*k, v.connect().await) })
96 .collect::<FuturesUnordered<_>>()
97 .collect::<BTreeMap<_, _>>()
98 .await,
99 ),
100 ServerPort::Merge(ports) => ClientConnection::Merge(
101 ports
102 .iter()
103 .map(|p| p.connect())
104 .collect::<FuturesUnordered<_>>()
105 .collect::<Vec<_>>()
106 .await,
107 ),
108 ServerPort::Tagged(port, tag) => {
109 ClientConnection::Tagged(Box::new(port.as_ref().connect().await), *tag)
110 }
111 ServerPort::Null => ClientConnection::Null,
112 }
113 }
114
115 pub async fn instantiate(&self) -> Connection {
116 Connection::AsClient(self.connect().await)
117 }
118}
119
/// Client-side result of `ServerPort::connect`: the same shape as the `ServerPort`, with
/// every leaf replaced by a connected socket.
#[derive(Debug)]
pub enum ClientConnection {
    /// A connected Unix domain socket.
    UnixSocket(UnixStream),
    /// A connected TCP stream.
    TcpPort(TcpStream),
    /// Connected children keyed by a `u32` id.
    Demux(BTreeMap<u32, ClientConnection>),
    /// Connected children to be merged into one connection.
    Merge(Vec<ClientConnection>),
    /// A connected child labelled with a `u32` tag.
    Tagged(Box<ClientConnection>, u32),
    /// No underlying socket.
    Null,
}
129
/// Serializable description of how a server-side port should be bound (see
/// `ServerBindConfig::bind`).
#[derive(Serialize, Deserialize, Clone, Debug)]
pub enum ServerBindConfig {
    /// Bind a Unix domain socket named `socket` inside a fresh temporary directory.
    UnixSocket,
    /// Bind a TCP listener.
    TcpPort(
        /// Host half of the `(host, port)` bind address.
        String,
        /// Port to bind; `None` binds port 0, i.e. lets the OS pick a free port.
        Option<u16>,
    ),
    /// Several child bindings keyed by a `u32` id.
    Demux(BTreeMap<u32, ServerBindConfig>),
    /// Several child bindings whose connections are combined into one.
    Merge(Vec<ServerBindConfig>),
    /// A child binding labelled with a `u32` tag.
    Tagged(Box<ServerBindConfig>, u32),
    /// A child binding that is kept as a listener rather than accepted once
    /// (`accept_bound` passes it through unaccepted).
    MultiConnection(Box<ServerBindConfig>),
    /// No underlying socket.
    Null,
}
147
148impl ServerBindConfig {
149 #[async_recursion]
150 pub async fn bind(self) -> BoundServer {
151 match self {
152 ServerBindConfig::UnixSocket => {
153 #[cfg(unix)]
154 {
155 let dir = tempfile::tempdir().unwrap();
156 let socket_path = dir.path().join("socket");
157 let bound = UnixListener::bind(socket_path).unwrap();
158 BoundServer::UnixSocket(bound, dir)
159 }
160
161 #[cfg(not(unix))]
162 {
163 panic!("Unix sockets are not supported on this platform")
164 }
165 }
166 ServerBindConfig::TcpPort(host, port) => {
167 let listener = TcpListener::bind((host, port.unwrap_or(0)))
168 .await
169 .unwrap_or_else(|e| panic!("Failed to bind port {:?}: {}", port, e));
170 let addr = listener.local_addr().unwrap();
171 BoundServer::TcpPort(TcpListenerStream::new(listener), addr)
172 }
173 ServerBindConfig::Demux(bindings) => {
174 let mut demux = BTreeMap::new();
175 for (key, bind) in bindings {
176 demux.insert(key, bind.bind().await); }
178 BoundServer::Demux(demux)
179 }
180 ServerBindConfig::Merge(bindings) => {
181 let mut merge = Vec::new();
182 for bind in bindings {
183 merge.push(bind.bind().await); }
185 BoundServer::Merge(merge)
186 }
187 ServerBindConfig::Tagged(underlying, id) => {
188 BoundServer::Tagged(Box::new(underlying.bind().await), id)
189 }
190 ServerBindConfig::MultiConnection(underlying) => {
191 BoundServer::MultiConnection(Box::new(underlying.bind().await))
192 }
193 ServerBindConfig::Null => BoundServer::Null,
194 }
195 }
196}
197
198#[derive(Debug)]
199pub enum Connection {
200 AsClient(ClientConnection),
201 AsServer(AcceptedServer),
202}
203
204impl Connection {
205 pub fn connect<T: Connected>(self) -> T {
206 T::from_defn(self)
207 }
208}
209
/// Boxed, pinned stream of received frames (`BytesMut`) or I/O errors.
pub type DynStream = Pin<Box<dyn Stream<Item = Result<BytesMut, io::Error>> + Send + Sync>>;

/// Boxed, pinned sink accepting `Input` items and failing with I/O errors.
pub type DynSink<Input> = Pin<Box<dyn Sink<Input, Error = io::Error> + Send + Sync>>;

/// A bidirectional byte-frame channel: a stream of received `BytesMut` frames that is
/// also a sink of `Bytes` frames (e.g. the framed sockets built by `tcp_bytes` and
/// `unix_bytes`).
pub trait StreamSink:
    Stream<Item = Result<BytesMut, io::Error>> + Sink<Bytes, Error = io::Error>
{
}
// Blanket impl: any type with the right `Stream` and `Sink` impls is a `StreamSink`.
impl<T: Stream<Item = Result<BytesMut, io::Error>> + Sink<Bytes, Error = io::Error>> StreamSink
    for T
{
}

/// Boxed, pinned bidirectional byte-frame channel.
pub type DynStreamSink = Pin<Box<dyn StreamSink + Send + Sync>>;
224
/// A connected endpoint type that can be built from a [`Connection`] definition.
pub trait Connected: Send {
    /// Builds the endpoint from `pipe`. Implementations in this file panic when `pipe`
    /// has a shape they do not support.
    fn from_defn(pipe: Connection) -> Self;
}

/// An endpoint that can be turned into a sink for sending `Input` items.
pub trait ConnectedSink {
    /// Item type accepted by the sink.
    type Input: Send;
    /// Concrete sink type produced by [`ConnectedSink::into_sink`].
    type Sink: Sink<Self::Input, Error = io::Error> + Send + Sync;

    /// Consumes the endpoint and returns its sending half.
    fn into_sink(self) -> Self::Sink;
}

/// An endpoint that can be turned into a stream of received `Output` items.
pub trait ConnectedSource {
    /// Item type yielded (inside `Ok`) by the stream.
    type Output: Send;
    /// Concrete stream type produced by [`ConnectedSource::into_source`].
    type Stream: Stream<Item = Result<Self::Output, io::Error>> + Send + Sync;
    /// Consumes the endpoint and returns its receiving half.
    fn into_source(self) -> Self::Stream;
}
241
/// Server-side listeners produced by `ServerBindConfig::bind`, not yet accepting.
#[derive(Debug)]
pub enum BoundServer {
    /// A Unix socket listener plus the temporary directory that contains its socket file.
    UnixSocket(UnixListener, TempDir),
    /// A TCP listener plus the local address it is actually bound to.
    TcpPort(TcpListenerStream, SocketAddr),
    /// Child listeners keyed by a `u32` id.
    Demux(BTreeMap<u32, BoundServer>),
    /// Child listeners whose connections are combined into one.
    Merge(Vec<BoundServer>),
    /// A child listener labelled with a `u32` tag.
    Tagged(Box<BoundServer>, u32),
    /// A child listener that stays a listener after `accept_bound`.
    MultiConnection(Box<BoundServer>),
    /// No underlying socket.
    Null,
}

/// Result of `accept_bound`: the same shape as the `BoundServer`, with each single-use
/// listener replaced by the connection it accepted.
#[derive(Debug)]
pub enum AcceptedServer {
    /// An accepted Unix socket stream; the temp dir is carried along from the listener.
    UnixSocket(UnixStream, TempDir),
    /// An accepted TCP stream.
    TcpPort(TcpStream),
    /// Accepted children keyed by a `u32` id.
    Demux(BTreeMap<u32, AcceptedServer>),
    /// Accepted children to be merged into one connection.
    Merge(Vec<AcceptedServer>),
    /// An accepted child labelled with a `u32` tag.
    Tagged(Box<AcceptedServer>, u32),
    /// A listener passed through unaccepted by `accept_bound`.
    MultiConnection(Box<BoundServer>),
    /// No underlying socket.
    Null,
}
263
264#[async_recursion]
265pub async fn accept_bound(bound: BoundServer) -> AcceptedServer {
266 match bound {
267 BoundServer::UnixSocket(listener, dir) => {
268 #[cfg(unix)]
269 {
270 let stream = listener.accept().await.unwrap().0;
271 AcceptedServer::UnixSocket(stream, dir)
272 }
273
274 #[cfg(not(unix))]
275 {
276 let _ = listener;
277 let _ = dir;
278 panic!("Unix sockets are not supported on this platform")
279 }
280 }
281 BoundServer::TcpPort(mut listener, _) => {
282 let stream = listener.next().await.unwrap().unwrap();
283 AcceptedServer::TcpPort(stream)
284 }
285 BoundServer::Demux(bindings) => AcceptedServer::Demux(
286 bindings
287 .into_iter()
288 .map(|(k, b)| async move { (k, accept_bound(b).await) })
289 .collect::<FuturesUnordered<_>>()
290 .collect::<Vec<_>>()
291 .await
292 .into_iter()
293 .collect(),
294 ),
295 BoundServer::Merge(merge) => AcceptedServer::Merge(
296 merge
297 .into_iter()
298 .map(|b| async move { accept_bound(b).await })
299 .collect::<FuturesUnordered<_>>()
300 .collect::<Vec<_>>()
301 .await,
302 ),
303 BoundServer::Tagged(underlying, id) => {
304 AcceptedServer::Tagged(Box::new(accept_bound(*underlying).await), id)
305 }
306 BoundServer::MultiConnection(underlying) => AcceptedServer::MultiConnection(underlying),
307 BoundServer::Null => AcceptedServer::Null,
308 }
309}
310
311impl BoundServer {
312 pub fn server_port(&self) -> ServerPort {
313 match self {
314 BoundServer::UnixSocket(_, tempdir) => {
315 #[cfg(unix)]
316 {
317 ServerPort::UnixSocket(tempdir.path().join("socket"))
318 }
319
320 #[cfg(not(unix))]
321 {
322 let _ = tempdir;
323 panic!("Unix sockets are not supported on this platform")
324 }
325 }
326 BoundServer::TcpPort(_, addr) => {
327 ServerPort::TcpPort(SocketAddr::new(addr.ip(), addr.port()))
328 }
329
330 BoundServer::Demux(bindings) => {
331 let mut demux = BTreeMap::new();
332 for (key, bind) in bindings {
333 demux.insert(*key, bind.server_port());
334 }
335 ServerPort::Demux(demux)
336 }
337
338 BoundServer::Merge(bindings) => {
339 let mut merge = Vec::new();
340 for bind in bindings {
341 merge.push(bind.server_port());
342 }
343 ServerPort::Merge(merge)
344 }
345
346 BoundServer::Tagged(underlying, id) => {
347 ServerPort::Tagged(Box::new(underlying.server_port()), *id)
348 }
349
350 BoundServer::MultiConnection(underlying) => underlying.server_port(),
351
352 BoundServer::Null => ServerPort::Null,
353 }
354 }
355}
356
357fn accept(bound: AcceptedServer) -> ConnectedDirect {
358 match bound {
359 AcceptedServer::UnixSocket(stream, _dir) => {
360 #[cfg(unix)]
361 {
362 ConnectedDirect {
363 stream_sink: Some(Box::pin(unix_bytes(stream))),
364 source_only: None,
365 sink_only: None,
366 }
367 }
368
369 #[cfg(not(unix))]
370 {
371 let _ = stream;
372 panic!("Unix sockets are not supported on this platform")
373 }
374 }
375 AcceptedServer::TcpPort(stream) => ConnectedDirect {
376 stream_sink: Some(Box::pin(tcp_bytes(stream))),
377 source_only: None,
378 sink_only: None,
379 },
380 AcceptedServer::Merge(merge) => {
381 let mut sources = vec![];
382 for bound in merge {
383 sources.push(Some(Box::pin(accept(bound).into_source())));
384 }
385
386 let merge_source: DynStream = Box::pin(MergeSource {
387 marker: PhantomData,
388 sources,
389 poll_cursor: 0,
390 });
391
392 ConnectedDirect {
393 stream_sink: None,
394 source_only: Some(merge_source),
395 sink_only: None,
396 }
397 }
398 AcceptedServer::Demux(_) => panic!("Cannot connect to a demux pipe directly"),
399 AcceptedServer::Tagged(_, _) => panic!("Cannot connect to a tagged pipe directly"),
400 AcceptedServer::MultiConnection(_) => {
401 panic!("Cannot connect to a multi-connection pipe directly")
402 }
403 AcceptedServer::Null => {
404 ConnectedDirect::from_defn(Connection::AsClient(ClientConnection::Null))
405 }
406 }
407}
408
409fn tcp_bytes(stream: TcpStream) -> impl StreamSink {
410 Framed::new(stream, LengthDelimitedCodec::new())
411}
412
413#[cfg(unix)]
414fn unix_bytes(stream: UnixStream) -> impl StreamSink {
415 Framed::new(stream, LengthDelimitedCodec::new())
416}
417
418struct IoErrorDrain<T> {
419 marker: PhantomData<T>,
420}
421
422impl<T> Sink<T> for IoErrorDrain<T> {
423 type Error = io::Error;
424
425 fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
426 Poll::Ready(Ok(()))
427 }
428
429 fn start_send(self: Pin<&mut Self>, _item: T) -> Result<(), Self::Error> {
430 Ok(())
431 }
432
433 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
434 Poll::Ready(Ok(()))
435 }
436
437 fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
438 Poll::Ready(Ok(()))
439 }
440}
441
442async fn async_retry<T, E, F: Future<Output = Result<T, E>>>(
443 thunk: impl Fn() -> F,
444 count: usize,
445 delay: Duration,
446) -> Result<T, E> {
447 for _ in 1..count {
448 let result = thunk().await;
449 if result.is_ok() {
450 return result;
451 } else {
452 tokio::time::sleep(delay).await;
453 }
454 }
455
456 thunk().await
457}
458
459pub struct ConnectedDirect {
460 stream_sink: Option<DynStreamSink>,
461 source_only: Option<DynStream>,
462 sink_only: Option<DynSink<Bytes>>,
463}
464
465impl ConnectedDirect {
466 pub fn into_source_sink(self) -> (SplitStream<DynStreamSink>, SplitSink<DynStreamSink, Bytes>) {
467 let (sink, stream) = self.stream_sink.unwrap().split();
468 (stream, sink)
469 }
470}
471
472impl Connected for ConnectedDirect {
473 fn from_defn(pipe: Connection) -> Self {
474 match pipe {
475 Connection::AsClient(ClientConnection::UnixSocket(stream)) => {
476 #[cfg(unix)]
477 {
478 ConnectedDirect {
479 stream_sink: Some(Box::pin(unix_bytes(stream))),
480 source_only: None,
481 sink_only: None,
482 }
483 }
484
485 #[cfg(not(unix))]
486 {
487 let _ = stream;
488 panic!("Unix sockets are not supported on this platform");
489 }
490 }
491 Connection::AsClient(ClientConnection::TcpPort(stream)) => {
492 stream.set_nodelay(true).unwrap();
493 ConnectedDirect {
494 stream_sink: Some(Box::pin(tcp_bytes(stream))),
495 source_only: None,
496 sink_only: None,
497 }
498 }
499 Connection::AsClient(ClientConnection::Merge(merge)) => {
500 let sources = merge
501 .into_iter()
502 .map(|port| {
503 Some(Box::pin(
504 ConnectedDirect::from_defn(Connection::AsClient(port)).into_source(),
505 ))
506 })
507 .collect::<Vec<_>>();
508
509 let merged = MergeSource {
510 marker: PhantomData,
511 sources,
512 poll_cursor: 0,
513 };
514
515 ConnectedDirect {
516 stream_sink: None,
517 source_only: Some(Box::pin(merged)),
518 sink_only: None,
519 }
520 }
521 Connection::AsClient(ClientConnection::Demux(_)) => {
522 panic!("Cannot connect to a demux pipe directly")
523 }
524
525 Connection::AsClient(ClientConnection::Tagged(_, _)) => {
526 panic!("Cannot connect to a tagged pipe directly")
527 }
528
529 Connection::AsClient(ClientConnection::Null) => ConnectedDirect {
530 stream_sink: None,
531 source_only: Some(Box::pin(stream::empty())),
532 sink_only: Some(Box::pin(IoErrorDrain {
533 marker: PhantomData,
534 })),
535 },
536
537 Connection::AsServer(bound) => accept(bound),
538 }
539 }
540}
541
542impl ConnectedSource for ConnectedDirect {
543 type Output = BytesMut;
544 type Stream = DynStream;
545
546 fn into_source(mut self) -> DynStream {
547 if let Some(s) = self.stream_sink.take() {
548 Box::pin(s)
549 } else {
550 self.source_only.take().unwrap()
551 }
552 }
553}
554
555impl ConnectedSink for ConnectedDirect {
556 type Input = Bytes;
557 type Sink = DynSink<Bytes>;
558
559 fn into_sink(mut self) -> DynSink<Self::Input> {
560 if let Some(s) = self.stream_sink.take() {
561 Box::pin(s)
562 } else {
563 self.sink_only.take().unwrap()
564 }
565 }
566}
567
/// Sink that routes `(key, item)` pairs to per-key buffered child sinks (built with
/// `sinktools::demux_map`).
pub type BufferedDrain<S, I> = sinktools::demux_map::DemuxMap<u32, Pin<Box<Buffer<S, I>>>>;

/// A demultiplexing sender: items are sent as `(key, item)` and routed to the child
/// connection registered under `key`.
pub struct ConnectedDemux<T: ConnectedSink>
where
    <T as ConnectedSink>::Input: Sync,
{
    /// The demux keys; `from_defn` fills these from a `BTreeMap`, so they are ascending.
    pub keys: Vec<u32>,
    // `Option` so that `into_sink` can move the sink out.
    sink: Option<BufferedDrain<T::Sink, T::Input>>,
}
577
578impl<T: Connected + ConnectedSink> Connected for ConnectedDemux<T>
579where
580 <T as ConnectedSink>::Input: 'static + Sync,
581{
582 fn from_defn(pipe: Connection) -> Self {
583 match pipe {
584 Connection::AsClient(ClientConnection::Demux(demux)) => {
585 let mut connected_demux = HashMap::new();
586 let keys = demux.keys().cloned().collect();
587 for (id, pipe) in demux {
588 connected_demux.insert(
589 id,
590 Box::pin(
591 T::from_defn(Connection::AsClient(pipe))
592 .into_sink()
593 .buffer(1024),
594 ),
595 );
596 }
597
598 let demuxer = sinktools::demux_map(connected_demux);
599
600 ConnectedDemux {
601 keys,
602 sink: Some(demuxer),
603 }
604 }
605
606 Connection::AsServer(AcceptedServer::Demux(demux)) => {
607 let mut connected_demux = HashMap::new();
608 let keys = demux.keys().cloned().collect();
609 for (id, bound) in demux {
610 connected_demux.insert(
611 id,
612 Box::pin(
613 T::from_defn(Connection::AsServer(bound))
614 .into_sink()
615 .buffer(1024),
616 ),
617 );
618 }
619
620 let demuxer = sinktools::demux_map(connected_demux);
621
622 ConnectedDemux {
623 keys,
624 sink: Some(demuxer),
625 }
626 }
627 _ => panic!("Cannot connect to a non-demux pipe as a demux"),
628 }
629 }
630}
631
632impl<T: ConnectedSink> ConnectedSink for ConnectedDemux<T>
633where
634 <T as ConnectedSink>::Input: 'static + Sync,
635{
636 type Input = (u32, T::Input);
637 type Sink = BufferedDrain<T::Sink, T::Input>;
638
639 fn into_sink(mut self) -> Self::Sink {
640 self.sink.take().unwrap()
641 }
642}
643
/// A stream that merges several sub-streams, polling them round-robin.
///
/// Each `poll_next` call starts at `poll_cursor` and polls at most one full cycle of the
/// sources, returning the first item produced. The cursor is then left just past the
/// source that produced it, so the next call starts with the following source. Sources
/// that end are removed. The merged stream ends once every source has ended, or
/// immediately if it has none.
pub struct MergeSource<T: Unpin, S: Stream<Item = T> + Send + Sync + ?Sized> {
    marker: PhantomData<T>,
    // Entries are expected to be `Some` whenever `poll_next` is entered (`poll_next`
    // unwraps them). An entry becomes `None` only transiently, when its source ends during
    // a call, and is compacted away before that call returns.
    sources: Vec<Option<Pin<Box<S>>>>,
    // Index of the source to poll first on the next call.
    poll_cursor: usize,
}

impl<T: Unpin, S: Stream<Item = T> + Send + Sync + ?Sized> Stream for MergeSource<T, S> {
    type Item = T;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // `MergeSource` is `Unpin` (sources are boxed and `T: Unpin`), so plain mutable
        // access is fine.
        let me = self.get_mut();
        let mut out = Poll::Pending;
        let mut any_removed = false;

        if !me.sources.is_empty() {
            // Poll at most one full cycle, starting where the previous call left off.
            let start_cursor = me.poll_cursor;

            loop {
                let current_length = me.sources.len();
                let source = &mut me.sources[me.poll_cursor];

                // Advance before polling, so that breaking on an item leaves the cursor
                // on the next source.
                me.poll_cursor = (me.poll_cursor + 1) % current_length;

                match source.as_mut().unwrap().as_mut().poll_next(cx) {
                    Poll::Ready(Some(data)) => {
                        out = Poll::Ready(Some(data));
                        break;
                    }
                    Poll::Ready(None) => {
                        // Only mark the finished source here; removing it now would shift
                        // indices mid-cycle. It is compacted out below.
                        *source = None;
                        any_removed = true;
                    }
                    Poll::Pending => {}
                }

                if me.poll_cursor == start_cursor {
                    break;
                }
            }
        }

        let mut current_index = 0;
        let original_cursor = me.poll_cursor;

        if any_removed {
            // Drop finished sources. For each removed entry before the cursor, shift the
            // cursor left so it keeps pointing at the same logical source.
            me.sources.retain(|source| {
                if source.is_none() && current_index < original_cursor {
                    me.poll_cursor -= 1;
                }
                current_index += 1;
                source.is_some()
            });
        }

        // After compaction the cursor can equal `sources.len()`; wrap it to the front.
        if me.poll_cursor == me.sources.len() {
            me.poll_cursor = 0;
        }

        // `out` is still `Pending` only if a full cycle produced no item. In that case every
        // remaining source returned `Pending` during this call.
        if me.sources.is_empty() {
            Poll::Ready(None)
        } else {
            out
        }
    }
}
714
715pub struct TaggedSource<T: Unpin, S: Stream<Item = Result<T, io::Error>> + Send + Sync + ?Sized> {
716 marker: PhantomData<T>,
717 id: u32,
718 source: Pin<Box<S>>,
719}
720
721impl<T: Unpin, S: Stream<Item = Result<T, io::Error>> + Send + Sync + ?Sized> Stream
722 for TaggedSource<T, S>
723{
724 type Item = Result<(u32, T), io::Error>;
725
726 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
727 let id = self.as_ref().id;
728 let source = &mut self.get_mut().source;
729 match source.as_mut().poll_next(cx) {
730 Poll::Ready(Some(v)) => Poll::Ready(Some(v.map(|d| (id, d)))),
731 Poll::Ready(None) => Poll::Ready(None),
732 Poll::Pending => Poll::Pending,
733 }
734 }
735}
736
/// Round-robin merge of several tagged sources of `T`'s output, yielding
/// `(tag, item)` pairs.
type MergedMux<T> = MergeSource<
    Result<(u32, <T as ConnectedSource>::Output), io::Error>,
    TaggedSource<<T as ConnectedSource>::Output, <T as ConnectedSource>::Stream>,
>;

/// A receiver built from one tagged connection, or a merge of tagged connections. Each
/// received item is labelled with the tag of the connection it arrived on.
pub struct ConnectedTagged<T: ConnectedSource>
where
    <T as ConnectedSource>::Output: 'static + Sync + Unpin,
{
    source: MergedMux<T>,
}
748
749impl<T: Connected + ConnectedSource> Connected for ConnectedTagged<T>
750where
751 <T as ConnectedSource>::Output: 'static + Sync + Unpin,
752{
753 fn from_defn(pipe: Connection) -> Self {
754 let sources = match pipe {
755 Connection::AsClient(ClientConnection::Tagged(pipe, id)) => {
756 vec![(
757 Box::pin(T::from_defn(Connection::AsClient(*pipe)).into_source()),
758 id,
759 )]
760 }
761
762 Connection::AsClient(ClientConnection::Merge(m)) => {
763 let mut sources = Vec::new();
764 for port in m {
765 if let ClientConnection::Tagged(pipe, id) = port {
766 sources.push((
767 Box::pin(T::from_defn(Connection::AsClient(*pipe)).into_source()),
768 id,
769 ));
770 } else {
771 panic!("Merge port must be tagged");
772 }
773 }
774
775 sources
776 }
777
778 Connection::AsServer(AcceptedServer::Tagged(pipe, id)) => {
779 vec![(
780 Box::pin(T::from_defn(Connection::AsServer(*pipe)).into_source()),
781 id,
782 )]
783 }
784
785 Connection::AsServer(AcceptedServer::Merge(m)) => {
786 let mut sources = Vec::new();
787 for port in m {
788 if let AcceptedServer::Tagged(pipe, id) = port {
789 sources.push((
790 Box::pin(T::from_defn(Connection::AsServer(*pipe)).into_source()),
791 id,
792 ));
793 } else {
794 panic!("Merge port must be tagged");
795 }
796 }
797
798 sources
799 }
800
801 _ => panic!("Cannot connect to a non-tagged pipe as a tagged"),
802 };
803
804 let mut connected_mux = Vec::new();
805 for (pipe, id) in sources {
806 connected_mux.push(Some(Box::pin(TaggedSource {
807 marker: PhantomData,
808 id,
809 source: pipe,
810 })));
811 }
812
813 let muxer = MergeSource {
814 marker: PhantomData,
815 sources: connected_mux,
816 poll_cursor: 0,
817 };
818
819 ConnectedTagged { source: muxer }
820 }
821}
822
823impl<T: ConnectedSource> ConnectedSource for ConnectedTagged<T>
824where
825 <T as ConnectedSource>::Output: 'static + Sync + Unpin,
826{
827 type Output = (u32, T::Output);
828 type Stream = MergeSource<Result<Self::Output, io::Error>, TaggedSource<T::Output, T::Stream>>;
829
830 fn into_source(self) -> Self::Stream {
831 self.source
832 }
833}
834
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::task::{Context, Poll};

    use futures::stream;

    use super::*;

    /// A waker that does nothing; these tests drive polling by hand.
    struct TestWaker;
    impl std::task::Wake for TestWaker {
        fn wake(self: Arc<Self>) {}
    }

    /// Polls `source` until it ends or returns `Pending`, collecting the items produced.
    fn drain_ready<T: Unpin, S: Stream<Item = T> + Send + Sync + ?Sized>(
        source: &mut MergeSource<T, S>,
    ) -> Vec<T> {
        let waker = Arc::new(TestWaker).into();
        let mut cx = Context::from_waker(&waker);
        let mut results = Vec::new();
        while let Poll::Ready(Some(value)) = Pin::new(&mut *source).poll_next(&mut cx) {
            results.push(value);
        }
        results
    }

    #[test]
    fn test_merge_source_fair_polling() {
        let stream1 = Box::pin(stream::iter(vec![1, 4, 7]));
        let stream2 = Box::pin(stream::iter(vec![2, 5, 8]));
        let stream3 = Box::pin(stream::iter(vec![3, 6, 9]));

        let mut merge_source = MergeSource {
            marker: PhantomData,
            sources: vec![Some(stream1), Some(stream2), Some(stream3)],
            poll_cursor: 0,
        };

        assert_eq!(
            drain_ready(&mut merge_source),
            vec![1, 2, 3, 4, 5, 6, 7, 8, 9]
        );
    }

    #[test]
    fn test_merge_source_keeps_rotation_after_removal() {
        // The first source ends early. Removing it must shift the cursor so that rotation
        // continues with the third source (4 -> 5), rather than skipping it (4 -> 6).
        let mut merge_source = MergeSource {
            marker: PhantomData,
            sources: vec![
                Some(Box::pin(stream::iter(vec![1]))),
                Some(Box::pin(stream::iter(vec![2, 4, 6]))),
                Some(Box::pin(stream::iter(vec![3, 5, 7]))),
            ],
            poll_cursor: 0,
        };

        assert_eq!(drain_ready(&mut merge_source), vec![1, 2, 3, 4, 5, 6, 7]);
        assert!(merge_source.sources.is_empty());
    }

    #[test]
    fn test_merge_source_empty_ends_immediately() {
        let mut merge_source: MergeSource<i32, stream::Iter<std::vec::IntoIter<i32>>> =
            MergeSource {
                marker: PhantomData,
                sources: Vec::new(),
                poll_cursor: 0,
            };

        let waker = Arc::new(TestWaker).into();
        let mut cx = Context::from_waker(&waker);
        assert!(matches!(
            Pin::new(&mut merge_source).poll_next(&mut cx),
            Poll::Ready(None)
        ));
    }
}
879}