1use std::cell::RefCell;
2use std::collections::HashMap;
3use std::marker::PhantomData;
4use std::net::SocketAddr;
5use std::path::PathBuf;
6use std::pin::Pin;
7use std::task::{Context, Poll};
8use std::time::Duration;
9
10use async_recursion::async_recursion;
11use async_trait::async_trait;
12use bytes::{Bytes, BytesMut};
13use futures::sink::Buffer;
14use futures::stream::{FuturesUnordered, SplitSink, SplitStream};
15use futures::{Future, Sink, SinkExt, Stream, StreamExt, ready, stream};
16use pin_project_lite::pin_project;
17use serde::{Deserialize, Serialize};
18use tempfile::TempDir;
19use tokio::io;
20use tokio::net::{TcpListener, TcpStream};
21#[cfg(unix)]
22use tokio::net::{UnixListener, UnixStream};
23use tokio_stream::wrappers::TcpListenerStream;
24use tokio_util::codec::{Framed, LengthDelimitedCodec};
25
26pub mod multi_connection;
27
/// Initialization data: a map of [`ServerBindConfig`]s keyed by string (presumably port
/// names — confirm with the producer of this value), plus an optional string.
// NOTE(review): the meaning of the `Option<String>` component is not visible in this file.
pub type InitConfig = (HashMap<String, ServerBindConfig>, Option<String>);
29
/// A set of named [`Connection`]s together with caller-defined metadata of type `T`.
pub struct DeployPorts<T = Option<()>> {
    /// Connections keyed by port name; each one is moved out when claimed via
    /// [`DeployPorts::port`].
    pub ports: RefCell<HashMap<String, Connection>>,
    /// Arbitrary metadata carried alongside the ports.
    pub meta: T,
}
36
37impl<T> DeployPorts<T> {
38 pub fn port(&self, name: &str) -> Connection {
39 self.ports
40 .try_borrow_mut()
41 .unwrap()
42 .remove(name)
43 .unwrap_or_else(|| panic!("port {} not found", name))
44 }
45}
46
// On non-Unix targets the Unix socket types do not exist, so they are replaced by the
// uninhabited `Infallible`. Any `UnixSocket(..)` variant holding one of these can never be
// constructed there, which lets the enums below keep a single definition for all platforms.
#[cfg(not(unix))]
type UnixStream = std::convert::Infallible;

#[cfg(not(unix))]
type UnixListener = std::convert::Infallible;
52
/// A serializable description of how a client reaches a server-side port.
///
/// Produced by [`BoundServer::server_port`] and turned into a live connection by
/// [`ServerPort::connect`].
#[derive(Serialize, Deserialize, Clone, Debug)]
pub enum ServerPort {
    /// A Unix domain socket at the given filesystem path.
    UnixSocket(PathBuf),
    /// A TCP socket at the given address.
    TcpPort(SocketAddr),
    /// Several ports keyed by `u32` id.
    Demux(HashMap<u32, ServerPort>),
    /// Several ports whose incoming data is merged into one source when connected directly.
    Merge(Vec<ServerPort>),
    /// A port whose items are paired with the given `u32` tag when read.
    Tagged(Box<ServerPort>, u32),
    /// A port carrying no data.
    Null,
}
63
64impl ServerPort {
65 #[async_recursion]
66 pub async fn connect(&self) -> ClientConnection {
67 match self {
68 ServerPort::UnixSocket(path) => {
69 #[cfg(unix)]
70 {
71 let bound = UnixStream::connect(path.clone());
72 ClientConnection::UnixSocket(bound.await.unwrap())
73 }
74
75 #[cfg(not(unix))]
76 {
77 let _ = path;
78 panic!("Unix sockets are not supported on this platform")
79 }
80 }
81 ServerPort::TcpPort(addr) => {
82 let addr_clone = *addr;
83 let stream = async_retry(
84 move || TcpStream::connect(addr_clone),
85 10,
86 Duration::from_secs(1),
87 )
88 .await
89 .unwrap();
90 ClientConnection::TcpPort(stream)
91 }
92 ServerPort::Demux(bindings) => ClientConnection::Demux(
93 bindings
94 .iter()
95 .map(|(k, v)| async move { (*k, v.connect().await) })
96 .collect::<FuturesUnordered<_>>()
97 .collect::<Vec<_>>()
98 .await
99 .into_iter()
100 .collect(),
101 ),
102 ServerPort::Merge(ports) => ClientConnection::Merge(
103 ports
104 .iter()
105 .map(|p| p.connect())
106 .collect::<FuturesUnordered<_>>()
107 .collect::<Vec<_>>()
108 .await
109 .into_iter()
110 .collect(),
111 ),
112 ServerPort::Tagged(port, tag) => {
113 ClientConnection::Tagged(Box::new(port.as_ref().connect().await), *tag)
114 }
115 ServerPort::Null => ClientConnection::Null,
116 }
117 }
118
119 pub async fn instantiate(&self) -> Connection {
120 Connection::AsClient(self.connect().await)
121 }
122}
123
/// The client side of a connection established by [`ServerPort::connect`]; mirrors the
/// shape of the [`ServerPort`] it came from.
#[derive(Debug)]
pub enum ClientConnection {
    /// A connected Unix domain socket.
    UnixSocket(UnixStream),
    /// A connected TCP stream.
    TcpPort(TcpStream),
    /// Client connections keyed by `u32` id.
    Demux(HashMap<u32, ClientConnection>),
    /// Client connections, in the order they finished connecting.
    Merge(Vec<ClientConnection>),
    /// A client connection carrying a `u32` tag.
    Tagged(Box<ClientConnection>, u32),
    /// A connection carrying no data.
    Null,
}
133
/// A serializable request describing which server port(s) to bind.
///
/// [`ServerBindConfig::bind`] turns it into a [`BoundServer`].
#[derive(Serialize, Deserialize, Clone, Debug)]
pub enum ServerBindConfig {
    /// Bind a Unix domain socket inside a fresh temporary directory.
    UnixSocket,
    /// Bind a TCP listener.
    TcpPort(
        /// The host to bind on.
        String,
        /// The port to bind on; `None` binds port 0 so the OS picks a free port.
        Option<u16>,
    ),
    /// Bind every nested config, keyed by `u32` id.
    Demux(HashMap<u32, ServerBindConfig>),
    /// Bind every nested config; their sources are merged once connected.
    Merge(Vec<ServerBindConfig>),
    /// Bind the nested config and attach the given `u32` tag.
    Tagged(Box<ServerBindConfig>, u32),
    /// Bind the nested config; [`accept_bound`] passes it through without accepting.
    MultiConnection(Box<ServerBindConfig>),
    /// Bind nothing.
    Null,
}
151
152impl ServerBindConfig {
153 #[async_recursion]
154 pub async fn bind(self) -> BoundServer {
155 match self {
156 ServerBindConfig::UnixSocket => {
157 #[cfg(unix)]
158 {
159 let dir = tempfile::tempdir().unwrap();
160 let socket_path = dir.path().join("socket");
161 let bound = UnixListener::bind(socket_path).unwrap();
162 BoundServer::UnixSocket(bound, dir)
163 }
164
165 #[cfg(not(unix))]
166 {
167 panic!("Unix sockets are not supported on this platform")
168 }
169 }
170 ServerBindConfig::TcpPort(host, port) => {
171 let listener = TcpListener::bind((host, port.unwrap_or(0))).await.unwrap();
172 let addr = listener.local_addr().unwrap();
173 BoundServer::TcpPort(TcpListenerStream::new(listener), addr)
174 }
175 ServerBindConfig::Demux(bindings) => {
176 let mut demux = HashMap::new();
177 for (key, bind) in bindings {
178 demux.insert(key, bind.bind().await);
179 }
180 BoundServer::Demux(demux)
181 }
182 ServerBindConfig::Merge(bindings) => {
183 let mut merge = Vec::new();
184 for bind in bindings {
185 merge.push(bind.bind().await);
186 }
187 BoundServer::Merge(merge)
188 }
189 ServerBindConfig::Tagged(underlying, id) => {
190 BoundServer::Tagged(Box::new(underlying.bind().await), id)
191 }
192 ServerBindConfig::MultiConnection(underlying) => {
193 BoundServer::MultiConnection(Box::new(underlying.bind().await))
194 }
195 ServerBindConfig::Null => BoundServer::Null,
196 }
197 }
198}
199
/// One end of a port: either a client connection or an accepted server connection.
#[derive(Debug)]
pub enum Connection {
    /// The client end, e.g. from [`ServerPort::instantiate`].
    AsClient(ClientConnection),
    /// The server end, e.g. from [`accept_bound`].
    AsServer(AcceptedServer),
}
205
206impl Connection {
207 pub fn connect<T: Connected>(self) -> T {
208 T::from_defn(self)
209 }
210}
211
/// A boxed, pinned, thread-safe stream of fallible byte frames.
pub type DynStream = Pin<Box<dyn Stream<Item = Result<BytesMut, io::Error>> + Send + Sync>>;

/// A boxed, pinned, thread-safe sink accepting `Input` items.
pub type DynSink<Input> = Pin<Box<dyn Sink<Input, Error = io::Error> + Send + Sync>>;

/// Anything that is both a stream of byte frames and a sink of byte frames.
///
/// Implemented automatically (via the blanket impl below) for every type with matching
/// `Stream` and `Sink` impls.
pub trait StreamSink:
    Stream<Item = Result<BytesMut, io::Error>> + Sink<Bytes, Error = io::Error>
{
}
impl<T: Stream<Item = Result<BytesMut, io::Error>> + Sink<Bytes, Error = io::Error>> StreamSink
    for T
{
}

/// A boxed, pinned, thread-safe bidirectional byte-frame channel.
pub type DynStreamSink = Pin<Box<dyn StreamSink + Send + Sync>>;
226
/// Types that can be built from a raw [`Connection`].
pub trait Connected: Send {
    /// Wraps `pipe` as `Self`. The implementations in this file panic when `pipe` has a shape
    /// they do not support.
    fn from_defn(pipe: Connection) -> Self;
}

/// Connections that can be turned into a sink of `Input` items.
pub trait ConnectedSink {
    /// The item type the sink accepts.
    type Input: Send;
    /// The sink returned by [`ConnectedSink::into_sink`].
    type Sink: Sink<Self::Input, Error = io::Error> + Send + Sync;

    /// Consumes the connection and returns its sink.
    fn into_sink(self) -> Self::Sink;
}

/// Connections that can be turned into a fallible stream of `Output` items.
pub trait ConnectedSource {
    /// The item type the stream yields on success.
    type Output: Send;
    /// The stream returned by [`ConnectedSource::into_source`].
    type Stream: Stream<Item = Result<Self::Output, io::Error>> + Send + Sync;
    /// Consumes the connection and returns its stream.
    fn into_source(self) -> Self::Stream;
}
243
/// A server port that has been bound but not yet accepted; produced by
/// [`ServerBindConfig::bind`].
#[derive(Debug)]
pub enum BoundServer {
    /// A Unix listener plus the temporary directory holding its socket file.
    UnixSocket(UnixListener, TempDir),
    /// A TCP listener plus the local address it is bound to.
    TcpPort(TcpListenerStream, SocketAddr),
    /// Bound servers keyed by `u32` id.
    Demux(HashMap<u32, BoundServer>),
    /// Bound servers whose sources are merged once connected.
    Merge(Vec<BoundServer>),
    /// A bound server carrying a `u32` tag.
    Tagged(Box<BoundServer>, u32),
    /// A bound server that [`accept_bound`] passes through without accepting.
    MultiConnection(Box<BoundServer>),
    /// No server.
    Null,
}
254
/// A server port after [`accept_bound`] has accepted a connection on it.
#[derive(Debug)]
pub enum AcceptedServer {
    /// An accepted Unix stream plus the temporary directory of its socket, which is kept
    /// alive as long as this value.
    UnixSocket(UnixStream, TempDir),
    /// An accepted TCP stream.
    TcpPort(TcpStream),
    /// Accepted connections keyed by `u32` id.
    Demux(HashMap<u32, AcceptedServer>),
    /// Accepted connections, in the order their accepts completed.
    Merge(Vec<AcceptedServer>),
    /// An accepted connection carrying a `u32` tag.
    Tagged(Box<AcceptedServer>, u32),
    /// A still-bound server, passed through without accepting.
    MultiConnection(Box<BoundServer>),
    /// No connection.
    Null,
}
265
266#[async_recursion]
267pub async fn accept_bound(bound: BoundServer) -> AcceptedServer {
268 match bound {
269 BoundServer::UnixSocket(listener, dir) => {
270 #[cfg(unix)]
271 {
272 let stream = listener.accept().await.unwrap().0;
273 AcceptedServer::UnixSocket(stream, dir)
274 }
275
276 #[cfg(not(unix))]
277 {
278 let _ = listener;
279 let _ = dir;
280 panic!("Unix sockets are not supported on this platform")
281 }
282 }
283 BoundServer::TcpPort(mut listener, _) => {
284 let stream = listener.next().await.unwrap().unwrap();
285 AcceptedServer::TcpPort(stream)
286 }
287 BoundServer::Demux(bindings) => AcceptedServer::Demux(
288 bindings
289 .into_iter()
290 .map(|(k, b)| async move { (k, accept_bound(b).await) })
291 .collect::<FuturesUnordered<_>>()
292 .collect::<Vec<_>>()
293 .await
294 .into_iter()
295 .collect(),
296 ),
297 BoundServer::Merge(merge) => AcceptedServer::Merge(
298 merge
299 .into_iter()
300 .map(|b| async move { accept_bound(b).await })
301 .collect::<FuturesUnordered<_>>()
302 .collect::<Vec<_>>()
303 .await,
304 ),
305 BoundServer::Tagged(underlying, id) => {
306 AcceptedServer::Tagged(Box::new(accept_bound(*underlying).await), id)
307 }
308 BoundServer::MultiConnection(underlying) => AcceptedServer::MultiConnection(underlying),
309 BoundServer::Null => AcceptedServer::Null,
310 }
311}
312
313impl BoundServer {
314 pub fn server_port(&self) -> ServerPort {
315 match self {
316 BoundServer::UnixSocket(_, tempdir) => {
317 #[cfg(unix)]
318 {
319 ServerPort::UnixSocket(tempdir.path().join("socket"))
320 }
321
322 #[cfg(not(unix))]
323 {
324 let _ = tempdir;
325 panic!("Unix sockets are not supported on this platform")
326 }
327 }
328 BoundServer::TcpPort(_, addr) => {
329 ServerPort::TcpPort(SocketAddr::new(addr.ip(), addr.port()))
330 }
331
332 BoundServer::Demux(bindings) => {
333 let mut demux = HashMap::new();
334 for (key, bind) in bindings {
335 demux.insert(*key, bind.server_port());
336 }
337 ServerPort::Demux(demux)
338 }
339
340 BoundServer::Merge(bindings) => {
341 let mut merge = Vec::new();
342 for bind in bindings {
343 merge.push(bind.server_port());
344 }
345 ServerPort::Merge(merge)
346 }
347
348 BoundServer::Tagged(underlying, id) => {
349 ServerPort::Tagged(Box::new(underlying.server_port()), *id)
350 }
351
352 BoundServer::MultiConnection(underlying) => underlying.server_port(),
353
354 BoundServer::Null => ServerPort::Null,
355 }
356 }
357}
358
359fn accept(bound: AcceptedServer) -> ConnectedDirect {
360 match bound {
361 AcceptedServer::UnixSocket(stream, _dir) => {
362 #[cfg(unix)]
363 {
364 ConnectedDirect {
365 stream_sink: Some(Box::pin(unix_bytes(stream))),
366 source_only: None,
367 sink_only: None,
368 }
369 }
370
371 #[cfg(not(unix))]
372 {
373 let _ = stream;
374 panic!("Unix sockets are not supported on this platform")
375 }
376 }
377 AcceptedServer::TcpPort(stream) => ConnectedDirect {
378 stream_sink: Some(Box::pin(tcp_bytes(stream))),
379 source_only: None,
380 sink_only: None,
381 },
382 AcceptedServer::Merge(merge) => {
383 let mut sources = vec![];
384 for bound in merge {
385 sources.push(Some(Box::pin(accept(bound).into_source())));
386 }
387
388 let merge_source: DynStream = Box::pin(MergeSource {
389 marker: PhantomData,
390 sources,
391 poll_cursor: 0,
392 });
393
394 ConnectedDirect {
395 stream_sink: None,
396 source_only: Some(merge_source),
397 sink_only: None,
398 }
399 }
400 AcceptedServer::Demux(_) => panic!("Cannot connect to a demux pipe directly"),
401 AcceptedServer::Tagged(_, _) => panic!("Cannot connect to a tagged pipe directly"),
402 AcceptedServer::MultiConnection(_) => {
403 panic!("Cannot connect to a multi-connection pipe directly")
404 }
405 AcceptedServer::Null => {
406 ConnectedDirect::from_defn(Connection::AsClient(ClientConnection::Null))
407 }
408 }
409}
410
411fn tcp_bytes(stream: TcpStream) -> impl StreamSink {
412 Framed::new(stream, LengthDelimitedCodec::new())
413}
414
415#[cfg(unix)]
416fn unix_bytes(stream: UnixStream) -> impl StreamSink {
417 Framed::new(stream, LengthDelimitedCodec::new())
418}
419
420struct IoErrorDrain<T> {
421 marker: PhantomData<T>,
422}
423
424impl<T> Sink<T> for IoErrorDrain<T> {
425 type Error = io::Error;
426
427 fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
428 Poll::Ready(Ok(()))
429 }
430
431 fn start_send(self: Pin<&mut Self>, _item: T) -> Result<(), Self::Error> {
432 Ok(())
433 }
434
435 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
436 Poll::Ready(Ok(()))
437 }
438
439 fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
440 Poll::Ready(Ok(()))
441 }
442}
443
444async fn async_retry<T, E, F: Future<Output = Result<T, E>>>(
445 thunk: impl Fn() -> F,
446 count: usize,
447 delay: Duration,
448) -> Result<T, E> {
449 for _ in 1..count {
450 let result = thunk().await;
451 if result.is_ok() {
452 return result;
453 } else {
454 tokio::time::sleep(delay).await;
455 }
456 }
457
458 thunk().await
459}
460
/// A single, un-multiplexed connection.
///
/// Which fields are populated depends on how it was built. A socket connection fills only
/// `stream_sink`, a merged connection fills only `source_only`, and a null connection fills
/// `source_only` and `sink_only`.
pub struct ConnectedDirect {
    /// Bidirectional framed channel, when backed by a single TCP/Unix stream.
    stream_sink: Option<DynStreamSink>,
    /// Read-only source, used when there is no `stream_sink`.
    source_only: Option<DynStream>,
    /// Write-only sink, used when there is no `stream_sink`.
    sink_only: Option<DynSink<Bytes>>,
}
466
467impl ConnectedDirect {
468 pub fn into_source_sink(self) -> (SplitStream<DynStreamSink>, SplitSink<DynStreamSink, Bytes>) {
469 let (sink, stream) = self.stream_sink.unwrap().split();
470 (stream, sink)
471 }
472}
473
474impl Connected for ConnectedDirect {
475 fn from_defn(pipe: Connection) -> Self {
476 match pipe {
477 Connection::AsClient(ClientConnection::UnixSocket(stream)) => {
478 #[cfg(unix)]
479 {
480 ConnectedDirect {
481 stream_sink: Some(Box::pin(unix_bytes(stream))),
482 source_only: None,
483 sink_only: None,
484 }
485 }
486
487 #[cfg(not(unix))]
488 {
489 let _ = stream;
490 panic!("Unix sockets are not supported on this platform");
491 }
492 }
493 Connection::AsClient(ClientConnection::TcpPort(stream)) => {
494 stream.set_nodelay(true).unwrap();
495 ConnectedDirect {
496 stream_sink: Some(Box::pin(tcp_bytes(stream))),
497 source_only: None,
498 sink_only: None,
499 }
500 }
501 Connection::AsClient(ClientConnection::Merge(merge)) => {
502 let sources = merge
503 .into_iter()
504 .map(|port| {
505 Some(Box::pin(
506 ConnectedDirect::from_defn(Connection::AsClient(port)).into_source(),
507 ))
508 })
509 .collect::<Vec<_>>();
510
511 let merged = MergeSource {
512 marker: PhantomData,
513 sources,
514 poll_cursor: 0,
515 };
516
517 ConnectedDirect {
518 stream_sink: None,
519 source_only: Some(Box::pin(merged)),
520 sink_only: None,
521 }
522 }
523 Connection::AsClient(ClientConnection::Demux(_)) => {
524 panic!("Cannot connect to a demux pipe directly")
525 }
526
527 Connection::AsClient(ClientConnection::Tagged(_, _)) => {
528 panic!("Cannot connect to a tagged pipe directly")
529 }
530
531 Connection::AsClient(ClientConnection::Null) => ConnectedDirect {
532 stream_sink: None,
533 source_only: Some(Box::pin(stream::empty())),
534 sink_only: Some(Box::pin(IoErrorDrain {
535 marker: PhantomData,
536 })),
537 },
538
539 Connection::AsServer(bound) => accept(bound),
540 }
541 }
542}
543
544impl ConnectedSource for ConnectedDirect {
545 type Output = BytesMut;
546 type Stream = DynStream;
547
548 fn into_source(mut self) -> DynStream {
549 if let Some(s) = self.stream_sink.take() {
550 Box::pin(s)
551 } else {
552 self.source_only.take().unwrap()
553 }
554 }
555}
556
557impl ConnectedSink for ConnectedDirect {
558 type Input = Bytes;
559 type Sink = DynSink<Bytes>;
560
561 fn into_sink(mut self) -> DynSink<Self::Input> {
562 if let Some(s) = self.stream_sink.take() {
563 Box::pin(s)
564 } else {
565 self.sink_only.take().unwrap()
566 }
567 }
568}
569
/// A [`DemuxDrain`] whose per-key sinks are each wrapped in a futures [`Buffer`].
pub type BufferedDrain<S, I> = DemuxDrain<I, Buffer<S, I>>;

/// A demultiplexed connection: `(key, value)` items are routed to the sink for `key`.
pub struct ConnectedDemux<T: ConnectedSink>
where
    <T as ConnectedSink>::Input: Sync,
{
    /// The keys of the underlying demux, in the order they were iterated at construction.
    pub keys: Vec<u32>,
    /// The routing sink; taken by `into_sink`.
    sink: Option<BufferedDrain<T::Sink, T::Input>>,
}
579
/// Shorthand bound for sinks that can sit inside a [`DemuxDrain`]; blanket-implemented below.
trait DrainableSink<T>: Sink<T, Error = io::Error> + Send + Sync {}

impl<T, S: Sink<T, Error = io::Error> + Send + Sync + ?Sized> DrainableSink<T> for S {}

pin_project! {
    /// A sink of `(key, item)` pairs that forwards each `item` to the sink registered
    /// under `key`.
    pub struct DemuxDrain<T, S: DrainableSink<T>> {
        // Ties the item type `T` to the struct without storing one.
        marker: PhantomData<T>,
        // Destination sinks, one per key.
        #[pin]
        sinks: HashMap<u32, Pin<Box<S>>>,
    }
}
591
592impl<T, S: Sink<T, Error = io::Error> + Send + Sync> Sink<(u32, T)> for DemuxDrain<T, S> {
593 type Error = io::Error;
594
595 fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
596 for sink in self.project().sinks.values_mut() {
597 ready!(Sink::poll_ready(sink.as_mut(), _cx))?;
598 }
599
600 Poll::Ready(Ok(()))
601 }
602
603 fn start_send(self: Pin<&mut Self>, item: (u32, T)) -> Result<(), Self::Error> {
604 Sink::start_send(
605 self.project()
606 .sinks
607 .get_mut()
608 .get_mut(&item.0)
609 .unwrap_or_else(|| panic!("No sink in this demux for key {}", item.0))
610 .as_mut(),
611 item.1,
612 )
613 }
614
615 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
616 for sink in self.project().sinks.values_mut() {
617 ready!(Sink::poll_flush(sink.as_mut(), _cx))?;
618 }
619
620 Poll::Ready(Ok(()))
621 }
622
623 fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
624 for sink in self.project().sinks.values_mut() {
625 ready!(Sink::poll_close(sink.as_mut(), _cx))?;
626 }
627
628 Poll::Ready(Ok(()))
629 }
630}
631
632#[async_trait]
633impl<T: Connected + ConnectedSink> Connected for ConnectedDemux<T>
634where
635 <T as ConnectedSink>::Input: 'static + Sync,
636{
637 fn from_defn(pipe: Connection) -> Self {
638 match pipe {
639 Connection::AsClient(ClientConnection::Demux(demux)) => {
640 let mut connected_demux = HashMap::new();
641 let keys = demux.keys().cloned().collect();
642 for (id, pipe) in demux {
643 connected_demux.insert(
644 id,
645 Box::pin(
646 T::from_defn(Connection::AsClient(pipe))
647 .into_sink()
648 .buffer(1024),
649 ),
650 );
651 }
652
653 let demuxer = DemuxDrain {
654 marker: PhantomData,
655 sinks: connected_demux,
656 };
657
658 ConnectedDemux {
659 keys,
660 sink: Some(demuxer),
661 }
662 }
663
664 Connection::AsServer(AcceptedServer::Demux(demux)) => {
665 let mut connected_demux = HashMap::new();
666 let keys = demux.keys().cloned().collect();
667 for (id, bound) in demux {
668 connected_demux.insert(
669 id,
670 Box::pin(
671 T::from_defn(Connection::AsServer(bound))
672 .into_sink()
673 .buffer(1024),
674 ),
675 );
676 }
677
678 let demuxer = DemuxDrain {
679 marker: PhantomData,
680 sinks: connected_demux,
681 };
682
683 ConnectedDemux {
684 keys,
685 sink: Some(demuxer),
686 }
687 }
688 _ => panic!("Cannot connect to a non-demux pipe as a demux"),
689 }
690 }
691}
692
693impl<T: ConnectedSink> ConnectedSink for ConnectedDemux<T>
694where
695 <T as ConnectedSink>::Input: 'static + Sync,
696{
697 type Input = (u32, T::Input);
698 type Sink = BufferedDrain<T::Sink, T::Input>;
699
700 fn into_sink(mut self) -> Self::Sink {
701 self.sink.take().unwrap()
702 }
703}
704
/// A stream that polls several boxed sources round-robin. It yields an item from whichever
/// source is ready, drops sources as they end, and ends once every source has ended.
pub struct MergeSource<T: Unpin, S: Stream<Item = T> + Send + Sync + ?Sized> {
    marker: PhantomData<T>,
    /// Remaining sources. An entry is set to `None` when its source ends and is removed
    /// before `poll_next` returns.
    sources: Vec<Option<Pin<Box<S>>>>,
    /// Index of the source to poll first on the next `poll_next` call.
    poll_cursor: usize,
}
712
impl<T: Unpin, S: Stream<Item = T> + Send + Sync + ?Sized> Stream for MergeSource<T, S> {
    type Item = T;

    /// Polls each remaining source at most once, starting at `poll_cursor`, and returns the
    /// first item produced. Sources that end are removed, and `poll_cursor` is adjusted so
    /// the next call resumes with the source after the last one polled. Returns
    /// `Ready(None)` once no sources remain, or `Pending` if every polled source was pending.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let me = self.get_mut();
        let mut out = Poll::Pending;
        let mut any_removed = false;

        if !me.sources.is_empty() {
            // Remember where we started so each source is polled at most once per call.
            let start_cursor = me.poll_cursor;

            loop {
                let current_length = me.sources.len();
                let source = &mut me.sources[me.poll_cursor];

                // Advance before polling so the next call starts after this source.
                me.poll_cursor = (me.poll_cursor + 1) % current_length;

                match source.as_mut().unwrap().as_mut().poll_next(cx) {
                    Poll::Ready(Some(data)) => {
                        out = Poll::Ready(Some(data));
                        break;
                    }
                    Poll::Ready(None) => {
                        // Only mark the finished source here; it is removed after the loop
                        // so indices stay stable while iterating.
                        *source = None;
                        any_removed = true;
                    }
                    Poll::Pending => {}
                }

                // Wrapped all the way around: every source has been polled once.
                if me.poll_cursor == start_cursor {
                    break;
                }
            }
        }

        let mut current_index = 0;
        let original_cursor = me.poll_cursor;

        if any_removed {
            // Remove finished sources. Each one removed before the cursor shifts the cursor
            // left by one, so it keeps pointing at the same next source.
            me.sources.retain(|source| {
                if source.is_none() && current_index < original_cursor {
                    me.poll_cursor -= 1;
                }
                current_index += 1;
                source.is_some()
            });
        }

        // If trailing sources were removed, the cursor can now equal the length; wrap it.
        if me.poll_cursor == me.sources.len() {
            me.poll_cursor = 0;
        }

        if me.sources.is_empty() {
            Poll::Ready(None)
        } else {
            out
        }
    }
}
775
/// Wraps a fallible source, pairing every successful item with a fixed `u32` tag.
pub struct TaggedSource<T: Unpin, S: Stream<Item = Result<T, io::Error>> + Send + Sync + ?Sized> {
    marker: PhantomData<T>,
    /// The tag attached to every `Ok` item.
    id: u32,
    /// The wrapped source.
    source: Pin<Box<S>>,
}
781
782impl<T: Unpin, S: Stream<Item = Result<T, io::Error>> + Send + Sync + ?Sized> Stream
783 for TaggedSource<T, S>
784{
785 type Item = Result<(u32, T), io::Error>;
786
787 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
788 let id = self.as_ref().id;
789 let source = &mut self.get_mut().source;
790 match source.as_mut().poll_next(cx) {
791 Poll::Ready(Some(v)) => Poll::Ready(Some(v.map(|d| (id, d)))),
792 Poll::Ready(None) => Poll::Ready(None),
793 Poll::Pending => Poll::Pending,
794 }
795 }
796}
797
/// The merged stream of tagged sources produced by [`ConnectedTagged`] for inner type `T`.
type MergedMux<T> = MergeSource<
    Result<(u32, <T as ConnectedSource>::Output), io::Error>,
    TaggedSource<<T as ConnectedSource>::Output, <T as ConnectedSource>::Stream>,
>;

/// A read-only connection whose items are paired with the tag of the member they came from.
pub struct ConnectedTagged<T: ConnectedSource>
where
    <T as ConnectedSource>::Output: 'static + Sync + Unpin,
{
    /// The merged, tagged source.
    source: MergedMux<T>,
}
809
810#[async_trait]
811impl<T: Connected + ConnectedSource> Connected for ConnectedTagged<T>
812where
813 <T as ConnectedSource>::Output: 'static + Sync + Unpin,
814{
815 fn from_defn(pipe: Connection) -> Self {
816 let sources = match pipe {
817 Connection::AsClient(ClientConnection::Tagged(pipe, id)) => {
818 vec![(
819 Box::pin(T::from_defn(Connection::AsClient(*pipe)).into_source()),
820 id,
821 )]
822 }
823
824 Connection::AsClient(ClientConnection::Merge(m)) => {
825 let mut sources = Vec::new();
826 for port in m {
827 if let ClientConnection::Tagged(pipe, id) = port {
828 sources.push((
829 Box::pin(T::from_defn(Connection::AsClient(*pipe)).into_source()),
830 id,
831 ));
832 } else {
833 panic!("Merge port must be tagged");
834 }
835 }
836
837 sources
838 }
839
840 Connection::AsServer(AcceptedServer::Tagged(pipe, id)) => {
841 vec![(
842 Box::pin(T::from_defn(Connection::AsServer(*pipe)).into_source()),
843 id,
844 )]
845 }
846
847 Connection::AsServer(AcceptedServer::Merge(m)) => {
848 let mut sources = Vec::new();
849 for port in m {
850 if let AcceptedServer::Tagged(pipe, id) = port {
851 sources.push((
852 Box::pin(T::from_defn(Connection::AsServer(*pipe)).into_source()),
853 id,
854 ));
855 } else {
856 panic!("Merge port must be tagged");
857 }
858 }
859
860 sources
861 }
862
863 _ => panic!("Cannot connect to a non-tagged pipe as a tagged"),
864 };
865
866 let mut connected_mux = Vec::new();
867 for (pipe, id) in sources {
868 connected_mux.push(Some(Box::pin(TaggedSource {
869 marker: PhantomData,
870 id,
871 source: pipe,
872 })));
873 }
874
875 let muxer = MergeSource {
876 marker: PhantomData,
877 sources: connected_mux,
878 poll_cursor: 0,
879 };
880
881 ConnectedTagged { source: muxer }
882 }
883}
884
885impl<T: ConnectedSource> ConnectedSource for ConnectedTagged<T>
886where
887 <T as ConnectedSource>::Output: 'static + Sync + Unpin,
888{
889 type Output = (u32, T::Output);
890 type Stream = MergeSource<Result<Self::Output, io::Error>, TaggedSource<T::Output, T::Stream>>;
891
892 fn into_source(self) -> Self::Stream {
893 self.source
894 }
895}
896
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::task::{Context, Poll};

    use futures::stream;

    use super::*;

    /// A waker that does nothing. The test streams are always ready, so no wake-ups are needed.
    struct TestWaker;
    impl std::task::Wake for TestWaker {
        fn wake(self: Arc<Self>) {}
    }

    /// `MergeSource` should take one item from each ready source in turn, and end once every
    /// source is exhausted.
    #[test]
    fn test_merge_source_fair_polling() {
        let sources = vec![
            Some(Box::pin(stream::iter(vec![1, 4, 7]))),
            Some(Box::pin(stream::iter(vec![2, 5, 8]))),
            Some(Box::pin(stream::iter(vec![3, 6, 9]))),
        ];
        let mut merge_source = MergeSource {
            marker: PhantomData,
            sources,
            poll_cursor: 0,
        };

        let waker = std::task::Waker::from(Arc::new(TestWaker));
        let mut cx = Context::from_waker(&waker);

        // Stop at the end of the stream or at the first `Pending`.
        let mut results = Vec::new();
        while let Poll::Ready(Some(value)) = Pin::new(&mut merge_source).poll_next(&mut cx) {
            results.push(value);
        }

        assert_eq!(results, vec![1, 2, 3, 4, 5, 6, 7, 8, 9]);
    }
}