1use std::collections::HashMap;
2use std::io;
3use std::ops::DerefMut;
4use std::pin::Pin;
5use std::task::{Context, Poll};
6
7use futures::{Sink, SinkExt, Stream, StreamExt};
8#[cfg(unix)]
9use tempfile::TempDir;
10use tokio::net::TcpListener;
11#[cfg(unix)]
12use tokio::net::UnixListener;
13use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
14use tokio::sync::mpsc;
15use tokio_stream::wrappers::UnboundedReceiverStream;
16use tokio_util::codec::{Decoder, Encoder, Framed, FramedRead, FramedWrite};
17
18use crate::{AcceptedServer, BoundServer, Connected, Connection};
19
/// The three parts of a server-side multi-connection endpoint, as assembled by
/// `Connected::from_defn`.
pub struct ConnectedMultiConnection<I, O, C: Decoder<Item = I> + Encoder<O>> {
    /// Accepts new clients and yields `(connection_id, decoded_item)` pairs from all of them.
    pub source: MultiConnectionSource<I, O, C>,
    /// Accepts `(connection_id, item)` pairs and forwards each item to that connection.
    pub sink: MultiConnectionSink<O, C>,
    /// Emits `(connection_id, true)` when a client is accepted and `(connection_id, false)`
    /// when its input stream ends or reports an error.
    pub membership: UnboundedReceiverStream<(u64, bool)>,
}
25
26impl<
27 I: 'static,
28 O: Send + Sync + 'static,
29 C: Decoder<Item = I> + Encoder<O> + Send + Sync + Default + 'static,
30> Connected for ConnectedMultiConnection<I, O, C>
31{
32 fn from_defn(pipe: Connection) -> Self {
33 match pipe {
34 Connection::AsServer(AcceptedServer::MultiConnection(bound_server)) => {
35 let (new_sink_sender, new_sink_receiver) = mpsc::unbounded_channel();
36 let (membership_sender, membership_receiver) = mpsc::unbounded_channel();
37
38 let source = match *bound_server {
39 #[cfg(unix)]
40 BoundServer::UnixSocket(listener, dir) => MultiConnectionSource {
41 unix_listener: Some(listener),
42 tcp_listener: None,
43 _dir_holder: Some(dir),
44 next_connection_id: 0,
45 active_connections: Vec::new(),
46 poll_cursor: 0,
47 new_sink_sender,
48 membership_sender,
49 },
50 BoundServer::TcpPort(listener, _) => MultiConnectionSource {
51 #[cfg(unix)]
52 unix_listener: None,
53 tcp_listener: Some(listener.into_inner()),
54 #[cfg(unix)]
55 _dir_holder: None,
56 next_connection_id: 0,
57 active_connections: Vec::new(),
58 poll_cursor: 0,
59 new_sink_sender,
60 membership_sender,
61 },
62 _ => panic!("MultiConnection only supports UnixSocket and TcpPort"),
63 };
64
65 let sink = MultiConnectionSink::<O, C> {
66 connection_sinks: HashMap::new(),
67 new_sink_receiver,
68 };
69
70 ConnectedMultiConnection {
71 source,
72 sink,
73 membership: UnboundedReceiverStream::new(membership_receiver),
74 }
75 }
76 _ => panic!("Cannot connect to a non-multi-connection pipe as a multi-connection"),
77 }
78 }
79}
80
/// Type-erased, pinned per-connection input: a stream of items decoded by codec `C`.
type DynDecodedStream<I, C> =
    Pin<Box<dyn Stream<Item = Result<I, <C as Decoder>::Error>> + Send + Sync>>;
/// Type-erased, pinned per-connection output: a sink that encodes `O` values with codec `C`.
type DynEncodedSink<O, C> = Pin<Box<dyn Sink<O, Error = <C as Encoder<O>>::Error> + Send + Sync>>;
84
/// Input half of a server-side multi-connection endpoint.
///
/// It accepts clients on whichever listener is set. `ConnectedMultiConnection::from_defn`
/// sets exactly one of them. As a `Stream`, it yields `(connection_id, item)` pairs decoded
/// from every connected client.
pub struct MultiConnectionSource<I, O, C: Decoder<Item = I> + Encoder<O>> {
    /// Unix-domain listener; `Some` when built from `BoundServer::UnixSocket`.
    #[cfg(unix)]
    unix_listener: Option<UnixListener>,
    /// TCP listener; `Some` when built from `BoundServer::TcpPort`.
    tcp_listener: Option<TcpListener>,
    /// Owns the `TempDir` that came with the Unix listener so it lives as long as this source.
    /// NOTE(review): presumably the directory holding the socket file — confirm.
    #[cfg(unix)]
    _dir_holder: Option<TempDir>,
    /// Id assigned to the next accepted connection; incremented on every accept.
    next_connection_id: u64,
    /// Decoded input streams of live connections, tagged with their ids. `poll_next` sets a
    /// slot to `None` when its stream ends and compacts those slots before returning.
    active_connections: Vec<Option<(u64, DynDecodedStream<I, C>)>>,
    /// Index in `active_connections` where the next round-robin scan starts.
    poll_cursor: usize,
    /// Hands each accepted connection's encoding sink to the paired `MultiConnectionSink`.
    new_sink_sender: mpsc::UnboundedSender<(u64, DynEncodedSink<O, C>)>,
    /// Publishes `(id, true)` on accept and `(id, false)` when a connection's stream ends.
    membership_sender: mpsc::UnboundedSender<(u64, bool)>,
}
99
/// Output half of a server-side multi-connection endpoint: routes `(connection_id, item)`
/// pairs to the matching connection's sink.
pub struct MultiConnectionSink<O, C: Encoder<O>> {
    /// Encoding sinks of known connections, keyed by connection id.
    connection_sinks: HashMap<u64, DynEncodedSink<O, C>>,
    /// Receives the sinks of newly accepted connections from `MultiConnectionSource`.
    new_sink_receiver: mpsc::UnboundedReceiver<(u64, DynEncodedSink<O, C>)>,
}
104
105impl<
106 I,
107 O: Send + Sync + 'static,
108 C: Decoder<Item = I> + Encoder<O> + Send + Sync + Default + 'static,
109> Stream for MultiConnectionSource<I, O, C>
110{
111 type Item = Result<(u64, I), <C as Decoder>::Error>;
112
113 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
114 let me = self.deref_mut();
115 #[cfg(unix)]
117 if let Some(listener) = me.unix_listener.as_mut() {
118 loop {
119 match listener.poll_accept(cx) {
120 Poll::Ready(Ok((stream, _))) => {
121 use futures::{SinkExt, StreamExt};
122 use tokio_util::codec::Framed;
123
124 let connection_id = me.next_connection_id;
125 me.next_connection_id += 1;
126
127 let framed = Framed::new(stream, C::default());
128 let (sink, stream) = framed.split();
129
130 let boxed_stream: Pin<
131 Box<dyn Stream<Item = Result<I, <C as Decoder>::Error>> + Send + Sync>,
132 > = Box::pin(stream);
133
134 let boxed_sink: Pin<
136 Box<dyn Sink<O, Error = <C as Encoder<O>>::Error> + Send + Sync>,
137 > = Box::pin(sink.buffer(1024));
138
139 me.active_connections
140 .push(Some((connection_id, boxed_stream)));
141
142 let _ = me.new_sink_sender.send((connection_id, boxed_sink));
143 let _ = me.membership_sender.send((connection_id, true));
144 }
145 Poll::Ready(Err(e)) => {
146 if !me.active_connections.iter().any(|conn| conn.is_some()) {
147 return Poll::Ready(Some(Err(e.into())));
148 } else {
149 break;
150 }
151 }
152 Poll::Pending => {
153 break;
154 }
155 }
156 }
157 }
158
159 if let Some(listener) = me.tcp_listener.as_mut() {
161 loop {
162 match listener.poll_accept(cx) {
163 Poll::Ready(Ok((stream, _))) => {
164 let connection_id = me.next_connection_id;
165 me.next_connection_id += 1;
166
167 let framed = Framed::new(stream, C::default());
168 let (sink, stream) = framed.split();
169
170 let boxed_stream: Pin<
171 Box<dyn Stream<Item = Result<I, <C as Decoder>::Error>> + Send + Sync>,
172 > = Box::pin(stream);
173
174 let boxed_sink: Pin<
176 Box<dyn Sink<O, Error = <C as Encoder<O>>::Error> + Send + Sync>,
177 > = Box::pin(sink.buffer(1024));
178
179 me.active_connections
180 .push(Some((connection_id, boxed_stream)));
181
182 let _ = me.new_sink_sender.send((connection_id, boxed_sink));
183 let _ = me.membership_sender.send((connection_id, true));
184 }
185 Poll::Ready(Err(e)) => {
186 if !me.active_connections.iter().any(|conn| conn.is_some()) {
187 return Poll::Ready(Some(Err(e.into())));
188 } else {
189 break;
190 }
191 }
192 Poll::Pending => {
193 break;
194 }
195 }
196 }
197 }
198
199 let mut out = Poll::Pending;
201 let mut any_removed = false;
202
203 if !me.active_connections.is_empty() {
204 let start_cursor = me.poll_cursor;
205
206 loop {
207 let current_length = me.active_connections.len();
208 let id_and_stream = &mut me.active_connections[me.poll_cursor];
209 let (connection_id, stream) = id_and_stream.as_mut().unwrap();
210 let connection_id = *connection_id; me.poll_cursor = (me.poll_cursor + 1) % current_length;
214
215 match stream.as_mut().poll_next(cx) {
216 Poll::Ready(Some(Ok(data))) => {
217 out = Poll::Ready(Some(Ok((connection_id, data))));
218 break;
219 }
220 Poll::Ready(Some(Err(_))) | Poll::Ready(None) => {
221 let _ = me.membership_sender.send((connection_id, false));
222 *id_and_stream = None; any_removed = true;
224 }
225 Poll::Pending => {}
226 }
227
228 if me.poll_cursor == start_cursor {
230 break;
231 }
232 }
233 }
234
235 let mut current_index = 0;
237 let original_cursor = me.poll_cursor;
238
239 if any_removed {
240 me.active_connections.retain(|conn| {
241 if conn.is_none() && current_index < original_cursor {
242 me.poll_cursor -= 1;
243 }
244 current_index += 1;
245 conn.is_some()
246 });
247 }
248
249 if me.poll_cursor == me.active_connections.len() {
250 me.poll_cursor = 0;
251 }
252
253 out
254 }
255}
256
257impl<O, C: Encoder<O>> Sink<(u64, O)> for MultiConnectionSink<O, C> {
258 type Error = <C as Encoder<O>>::Error;
259
260 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
261 loop {
262 match self.new_sink_receiver.poll_recv(cx) {
263 Poll::Ready(Some((connection_id, sink))) => {
264 self.connection_sinks.insert(connection_id, sink);
265 }
266 Poll::Ready(None) => {
267 if self.connection_sinks.is_empty() {
268 return Poll::Ready(Err(io::Error::new(
269 io::ErrorKind::BrokenPipe,
270 "No additional sinks are available (was the stream dropped)?",
271 )
272 .into()));
273 } else {
274 break;
275 }
276 }
277 Poll::Pending => {
278 break;
279 }
280 }
281 }
282
283 let mut any_pending = false;
285 self.connection_sinks
286 .retain(|_, sink| match sink.as_mut().poll_ready(cx) {
287 Poll::Ready(Ok(())) => true,
288 Poll::Ready(Err(_)) => false,
289 Poll::Pending => {
290 any_pending = true;
291 true
292 }
293 });
294
295 if any_pending {
296 Poll::Pending
297 } else {
298 Poll::Ready(Ok(())) }
300 }
301
302 fn start_send(mut self: Pin<&mut Self>, item: (u64, O)) -> Result<(), Self::Error> {
303 if let Some(sink) = self.connection_sinks.get_mut(&item.0) {
304 let _ = sink.as_mut().start_send(item.1); }
306 Ok(())
308 }
309
310 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
311 let mut any_pending = false;
312
313 self.connection_sinks
314 .retain(|_, sink| match sink.as_mut().poll_flush(cx) {
315 Poll::Ready(Ok(())) => true,
316 Poll::Ready(Err(_)) => false,
317 Poll::Pending => {
318 any_pending = true;
319 true
320 }
321 });
322
323 if any_pending {
324 Poll::Pending
325 } else {
326 Poll::Ready(Ok(()))
327 }
328 }
329
330 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
331 let mut any_pending = false;
332
333 self.connection_sinks.retain(|_, sink| {
334 match sink.as_mut().poll_close(cx) {
335 Poll::Ready(Ok(()) | Err(_)) => false, Poll::Pending => {
337 any_pending = true;
338 true
339 }
340 }
341 });
342
343 if any_pending {
344 Poll::Pending
345 } else {
346 Poll::Ready(Ok(()))
347 }
348 }
349}
350
/// TCP-only multi-client source.
///
/// It accepts clients on `listener` and, as a `Stream`, yields `(connection_id, frame)` pairs
/// decoded from each client's read half.
pub struct TcpMultiConnectionSource<C: Decoder> {
    /// Listener polled for new clients on every `poll_next`.
    pub listener: TcpListener,
    /// Id assigned to the next accepted client; incremented on every accept.
    pub next_connection_id: u64,
    /// Framed read halves of live clients, tagged with their ids. `poll_next` sets a slot to
    /// `None` when that client's stream ends and compacts those slots before returning.
    pub active_connections: Vec<Option<(u64, FramedRead<OwnedReadHalf, C>)>>,
    /// Index in `active_connections` where the next round-robin scan starts.
    pub poll_cursor: usize,
    /// Delivers each client's framed write half to the paired `TcpMultiConnectionSink`.
    pub new_sink_sender: mpsc::UnboundedSender<(u64, FramedWrite<OwnedWriteHalf, C>)>,
    /// Publishes `(id, true)` on accept and `(id, false)` when a client's stream ends or errors.
    pub membership_sender: mpsc::UnboundedSender<(u64, bool)>,
}
366
367impl<C: Decoder + Default + Unpin> Stream for TcpMultiConnectionSource<C>
368where
369 C::Error: From<io::Error>,
370{
371 type Item = Result<(u64, C::Item), C::Error>;
372
373 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
374 let me = self.deref_mut();
375
376 loop {
378 match me.listener.poll_accept(cx) {
379 Poll::Ready(Ok((stream, _peer))) => {
380 let connection_id = me.next_connection_id;
381 me.next_connection_id += 1;
382
383 let (rx, tx) = stream.into_split();
384 let fr = FramedRead::new(rx, C::default());
385 let fw = FramedWrite::new(tx, C::default());
386
387 me.active_connections.push(Some((connection_id, fr)));
388 let _ = me.new_sink_sender.send((connection_id, fw));
389 let _ = me.membership_sender.send((connection_id, true));
390 }
391 Poll::Ready(Err(e)) => {
392 if !me.active_connections.iter().any(|c| c.is_some()) {
393 return Poll::Ready(Some(Err(e.into())));
394 } else {
395 break;
396 }
397 }
398 Poll::Pending => {
399 break;
400 }
401 }
402 }
403
404 let mut out = Poll::Pending;
406 let mut any_removed = false;
407
408 if !me.active_connections.is_empty() {
409 let start_cursor = me.poll_cursor;
410
411 loop {
412 let current_length = me.active_connections.len();
413 let id_and_stream = &mut me.active_connections[me.poll_cursor];
414 let (connection_id, stream) = id_and_stream.as_mut().unwrap();
415 let connection_id = *connection_id; me.poll_cursor = (me.poll_cursor + 1) % current_length;
419
420 match Pin::new(stream).poll_next(cx) {
421 Poll::Ready(Some(Ok(data))) => {
422 out = Poll::Ready(Some(Ok((connection_id, data))));
423 break;
424 }
425 Poll::Ready(Some(Err(_))) | Poll::Ready(None) => {
426 let _ = me.membership_sender.send((connection_id, false));
427 *id_and_stream = None; any_removed = true;
429 }
430 Poll::Pending => {}
431 }
432
433 if me.poll_cursor == start_cursor {
435 break;
436 }
437 }
438 }
439
440 let mut current_index = 0;
442 let original_cursor = me.poll_cursor;
443
444 if any_removed {
445 me.active_connections.retain(|conn| {
446 if conn.is_none() && current_index < original_cursor {
447 me.poll_cursor -= 1;
448 }
449 current_index += 1;
450 conn.is_some()
451 });
452 }
453
454 if me.poll_cursor == me.active_connections.len() {
455 me.poll_cursor = 0;
456 }
457
458 out
459 }
460}
461
/// TCP-only multi-client sink: routes `(connection_id, item)` pairs to the matching client's
/// framed write half. Items addressed to unknown ids are dropped.
pub struct TcpMultiConnectionSink<I, C: Encoder<I>> {
    /// Framed write halves of known clients, keyed by connection id.
    pub connection_sinks: HashMap<u64, FramedWrite<OwnedWriteHalf, C>>,
    /// Receives the write halves of newly accepted clients from `TcpMultiConnectionSource`.
    pub new_sink_receiver: mpsc::UnboundedReceiver<(u64, FramedWrite<OwnedWriteHalf, C>)>,
    // Ties the item type `I` to the struct without storing one.
    _marker: std::marker::PhantomData<fn(I) -> I>, }
471
472impl<I, C: Encoder<I> + Unpin> Sink<(u64, I)> for TcpMultiConnectionSink<I, C> {
473 type Error = C::Error;
474
475 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
476 let me = self.get_mut();
477 while let Poll::Ready(Some((id, sink))) = me.new_sink_receiver.poll_recv(cx) {
479 me.connection_sinks.insert(id, sink);
480 }
481 Poll::Ready(Ok(()))
482 }
483
484 fn start_send(self: Pin<&mut Self>, item: (u64, I)) -> Result<(), Self::Error> {
485 let me = self.get_mut();
486 if let Some(sink) = me.connection_sinks.get_mut(&item.0) {
487 let _ = Pin::new(sink).start_send(item.1);
488 }
489 Ok(())
490 }
491
492 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
493 let me = self.get_mut();
494 let mut any_pending = false;
495
496 me.connection_sinks
497 .retain(|_id, sink| match Pin::new(sink).poll_flush(cx) {
498 Poll::Ready(Ok(())) => true,
499 Poll::Ready(Err(_)) => false,
500 Poll::Pending => {
501 any_pending = true;
502 true
503 }
504 });
505
506 if any_pending {
507 Poll::Pending
508 } else {
509 Poll::Ready(Ok(()))
510 }
511 }
512
513 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
514 let me = self.get_mut();
515 let mut any_pending = false;
516
517 me.connection_sinks
518 .retain(|_id, sink| match Pin::new(sink).poll_close(cx) {
519 Poll::Ready(Ok(())) => false,
520 Poll::Ready(Err(_)) => false,
521 Poll::Pending => {
522 any_pending = true;
523 true
524 }
525 });
526
527 if any_pending {
528 Poll::Pending
529 } else {
530 Poll::Ready(Ok(()))
531 }
532 }
533}
534
/// `(source, sink, membership)` triple returned by `tcp_multi_connection`.
type TcpMultiConnectionParts<I, C> = (
    TcpMultiConnectionSource<C>,
    TcpMultiConnectionSink<I, C>,
    UnboundedReceiverStream<(u64, bool)>,
);
540
541pub fn tcp_multi_connection<I, C>(listener: TcpListener) -> TcpMultiConnectionParts<I, C>
542where
543 C: Decoder + Encoder<I> + Default,
544{
545 let (new_sink_sender, new_sink_receiver) = mpsc::unbounded_channel();
546 let (membership_sender, membership_receiver) = mpsc::unbounded_channel();
547
548 let source = TcpMultiConnectionSource {
549 listener,
550 next_connection_id: 0,
551 active_connections: Vec::new(),
552 poll_cursor: 0,
553 new_sink_sender,
554 membership_sender,
555 };
556
557 let sink = TcpMultiConnectionSink {
558 connection_sinks: HashMap::new(),
559 new_sink_receiver,
560 _marker: std::marker::PhantomData,
561 };
562
563 let membership = UnboundedReceiverStream::new(membership_receiver);
564
565 (source, sink, membership)
566}