Skip to main content

hydro_deploy_integration/
multi_connection.rs

1use std::collections::HashMap;
2use std::io;
3use std::ops::DerefMut;
4use std::pin::Pin;
5use std::task::{Context, Poll};
6
7use futures::{Sink, SinkExt, Stream, StreamExt};
8#[cfg(unix)]
9use tempfile::TempDir;
10use tokio::net::TcpListener;
11#[cfg(unix)]
12use tokio::net::UnixListener;
13use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
14use tokio::sync::mpsc;
15use tokio_stream::wrappers::UnboundedReceiverStream;
16use tokio_util::codec::{Decoder, Encoder, Framed, FramedRead, FramedWrite};
17
18use crate::{AcceptedServer, BoundServer, Connected, Connection};
19
20pub struct ConnectedMultiConnection<I, O, C: Decoder<Item = I> + Encoder<O>> {
21    pub source: MultiConnectionSource<I, O, C>,
22    pub sink: MultiConnectionSink<O, C>,
23    pub membership: UnboundedReceiverStream<(u64, bool)>,
24}
25
26impl<
27    I: 'static,
28    O: Send + Sync + 'static,
29    C: Decoder<Item = I> + Encoder<O> + Send + Sync + Default + 'static,
30> Connected for ConnectedMultiConnection<I, O, C>
31{
32    fn from_defn(pipe: Connection) -> Self {
33        match pipe {
34            Connection::AsServer(AcceptedServer::MultiConnection(bound_server)) => {
35                let (new_sink_sender, new_sink_receiver) = mpsc::unbounded_channel();
36                let (membership_sender, membership_receiver) = mpsc::unbounded_channel();
37
38                let source = match *bound_server {
39                    #[cfg(unix)]
40                    BoundServer::UnixSocket(listener, dir) => MultiConnectionSource {
41                        unix_listener: Some(listener),
42                        tcp_listener: None,
43                        _dir_holder: Some(dir),
44                        next_connection_id: 0,
45                        active_connections: Vec::new(),
46                        poll_cursor: 0,
47                        new_sink_sender,
48                        membership_sender,
49                    },
50                    BoundServer::TcpPort(listener, _) => MultiConnectionSource {
51                        #[cfg(unix)]
52                        unix_listener: None,
53                        tcp_listener: Some(listener.into_inner()),
54                        #[cfg(unix)]
55                        _dir_holder: None,
56                        next_connection_id: 0,
57                        active_connections: Vec::new(),
58                        poll_cursor: 0,
59                        new_sink_sender,
60                        membership_sender,
61                    },
62                    _ => panic!("MultiConnection only supports UnixSocket and TcpPort"),
63                };
64
65                let sink = MultiConnectionSink::<O, C> {
66                    connection_sinks: HashMap::new(),
67                    new_sink_receiver,
68                };
69
70                ConnectedMultiConnection {
71                    source,
72                    sink,
73                    membership: UnboundedReceiverStream::new(membership_receiver),
74                }
75            }
76            _ => panic!("Cannot connect to a non-multi-connection pipe as a multi-connection"),
77        }
78    }
79}
80
81type DynDecodedStream<I, C> =
82    Pin<Box<dyn Stream<Item = Result<I, <C as Decoder>::Error>> + Send + Sync>>;
83type DynEncodedSink<O, C> = Pin<Box<dyn Sink<O, Error = <C as Encoder<O>>::Error> + Send + Sync>>;
84
85pub struct MultiConnectionSource<I, O, C: Decoder<Item = I> + Encoder<O>> {
86    #[cfg(unix)]
87    unix_listener: Option<UnixListener>,
88    tcp_listener: Option<TcpListener>,
89    #[cfg(unix)]
90    _dir_holder: Option<TempDir>, // keeps the folder containing the socket alive
91    next_connection_id: u64,
92    /// Ordered list for fair polling, will never be `None` at the beginning of a poll
93    active_connections: Vec<Option<(u64, DynDecodedStream<I, C>)>>,
94    /// Cursor for fair round-robin polling
95    poll_cursor: usize,
96    new_sink_sender: mpsc::UnboundedSender<(u64, DynEncodedSink<O, C>)>,
97    membership_sender: mpsc::UnboundedSender<(u64, bool)>,
98}
99
100pub struct MultiConnectionSink<O, C: Encoder<O>> {
101    connection_sinks: HashMap<u64, DynEncodedSink<O, C>>,
102    new_sink_receiver: mpsc::UnboundedReceiver<(u64, DynEncodedSink<O, C>)>,
103}
104
105impl<
106    I,
107    O: Send + Sync + 'static,
108    C: Decoder<Item = I> + Encoder<O> + Send + Sync + Default + 'static,
109> Stream for MultiConnectionSource<I, O, C>
110{
111    type Item = Result<(u64, I), <C as Decoder>::Error>;
112
113    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
114        let me = self.deref_mut();
115        // Handle Unix socket accepts
116        #[cfg(unix)]
117        if let Some(listener) = me.unix_listener.as_mut() {
118            loop {
119                match listener.poll_accept(cx) {
120                    Poll::Ready(Ok((stream, _))) => {
121                        use futures::{SinkExt, StreamExt};
122                        use tokio_util::codec::Framed;
123
124                        let connection_id = me.next_connection_id;
125                        me.next_connection_id += 1;
126
127                        let framed = Framed::new(stream, C::default());
128                        let (sink, stream) = framed.split();
129
130                        let boxed_stream: Pin<
131                            Box<dyn Stream<Item = Result<I, <C as Decoder>::Error>> + Send + Sync>,
132                        > = Box::pin(stream);
133
134                        // Buffer so that a stalled output does not prevent sending to others
135                        let boxed_sink: Pin<
136                            Box<dyn Sink<O, Error = <C as Encoder<O>>::Error> + Send + Sync>,
137                        > = Box::pin(sink.buffer(1024));
138
139                        me.active_connections
140                            .push(Some((connection_id, boxed_stream)));
141
142                        let _ = me.new_sink_sender.send((connection_id, boxed_sink));
143                        let _ = me.membership_sender.send((connection_id, true));
144                    }
145                    Poll::Ready(Err(e)) => {
146                        if !me.active_connections.iter().any(|conn| conn.is_some()) {
147                            return Poll::Ready(Some(Err(e.into())));
148                        } else {
149                            break;
150                        }
151                    }
152                    Poll::Pending => {
153                        break;
154                    }
155                }
156            }
157        }
158
159        // Handle TCP socket accepts
160        if let Some(listener) = me.tcp_listener.as_mut() {
161            loop {
162                match listener.poll_accept(cx) {
163                    Poll::Ready(Ok((stream, _))) => {
164                        let connection_id = me.next_connection_id;
165                        me.next_connection_id += 1;
166
167                        let framed = Framed::new(stream, C::default());
168                        let (sink, stream) = framed.split();
169
170                        let boxed_stream: Pin<
171                            Box<dyn Stream<Item = Result<I, <C as Decoder>::Error>> + Send + Sync>,
172                        > = Box::pin(stream);
173
174                        // Buffer so that a stalled output does not prevent sending to others
175                        let boxed_sink: Pin<
176                            Box<dyn Sink<O, Error = <C as Encoder<O>>::Error> + Send + Sync>,
177                        > = Box::pin(sink.buffer(1024));
178
179                        me.active_connections
180                            .push(Some((connection_id, boxed_stream)));
181
182                        let _ = me.new_sink_sender.send((connection_id, boxed_sink));
183                        let _ = me.membership_sender.send((connection_id, true));
184                    }
185                    Poll::Ready(Err(e)) => {
186                        if !me.active_connections.iter().any(|conn| conn.is_some()) {
187                            return Poll::Ready(Some(Err(e.into())));
188                        } else {
189                            break;
190                        }
191                    }
192                    Poll::Pending => {
193                        break;
194                    }
195                }
196            }
197        }
198
199        // Poll all active connections for data using fair round-robin cursor
200        let mut out = Poll::Pending;
201        let mut any_removed = false;
202
203        if !me.active_connections.is_empty() {
204            let start_cursor = me.poll_cursor;
205
206            loop {
207                let current_length = me.active_connections.len();
208                let id_and_stream = &mut me.active_connections[me.poll_cursor];
209                let (connection_id, stream) = id_and_stream.as_mut().unwrap();
210                let connection_id = *connection_id; // Copy the ID before borrowing stream
211
212                // Move cursor to next source for next poll
213                me.poll_cursor = (me.poll_cursor + 1) % current_length;
214
215                match stream.as_mut().poll_next(cx) {
216                    Poll::Ready(Some(Ok(data))) => {
217                        out = Poll::Ready(Some(Ok((connection_id, data))));
218                        break;
219                    }
220                    Poll::Ready(Some(Err(_))) | Poll::Ready(None) => {
221                        let _ = me.membership_sender.send((connection_id, false));
222                        *id_and_stream = None; // Mark connection as removed
223                        any_removed = true;
224                    }
225                    Poll::Pending => {}
226                }
227
228                // Check if we've completed a full round
229                if me.poll_cursor == start_cursor {
230                    break;
231                }
232            }
233        }
234
235        // Clean up None entries and adjust cursor
236        let mut current_index = 0;
237        let original_cursor = me.poll_cursor;
238
239        if any_removed {
240            me.active_connections.retain(|conn| {
241                if conn.is_none() && current_index < original_cursor {
242                    me.poll_cursor -= 1;
243                }
244                current_index += 1;
245                conn.is_some()
246            });
247        }
248
249        if me.poll_cursor == me.active_connections.len() {
250            me.poll_cursor = 0;
251        }
252
253        out
254    }
255}
256
257impl<O, C: Encoder<O>> Sink<(u64, O)> for MultiConnectionSink<O, C> {
258    type Error = <C as Encoder<O>>::Error;
259
260    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
261        loop {
262            match self.new_sink_receiver.poll_recv(cx) {
263                Poll::Ready(Some((connection_id, sink))) => {
264                    self.connection_sinks.insert(connection_id, sink);
265                }
266                Poll::Ready(None) => {
267                    if self.connection_sinks.is_empty() {
268                        return Poll::Ready(Err(io::Error::new(
269                            io::ErrorKind::BrokenPipe,
270                            "No additional sinks are available (was the stream dropped)?",
271                        )
272                        .into()));
273                    } else {
274                        break;
275                    }
276                }
277                Poll::Pending => {
278                    break;
279                }
280            }
281        }
282
283        // Check if all sinks are ready, removing any that are closed
284        let mut any_pending = false;
285        self.connection_sinks
286            .retain(|_, sink| match sink.as_mut().poll_ready(cx) {
287                Poll::Ready(Ok(())) => true,
288                Poll::Ready(Err(_)) => false,
289                Poll::Pending => {
290                    any_pending = true;
291                    true
292                }
293            });
294
295        if any_pending {
296            Poll::Pending
297        } else {
298            Poll::Ready(Ok(())) // always ready, because we drop messages if there is no sink
299        }
300    }
301
302    fn start_send(mut self: Pin<&mut Self>, item: (u64, O)) -> Result<(), Self::Error> {
303        if let Some(sink) = self.connection_sinks.get_mut(&item.0) {
304            let _ = sink.as_mut().start_send(item.1); // TODO(shadaj): log errors when we have principled logging
305        }
306        // If connection doesn't exist, silently drop (connection may have closed)
307        Ok(())
308    }
309
310    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
311        let mut any_pending = false;
312
313        self.connection_sinks
314            .retain(|_, sink| match sink.as_mut().poll_flush(cx) {
315                Poll::Ready(Ok(())) => true,
316                Poll::Ready(Err(_)) => false,
317                Poll::Pending => {
318                    any_pending = true;
319                    true
320                }
321            });
322
323        if any_pending {
324            Poll::Pending
325        } else {
326            Poll::Ready(Ok(()))
327        }
328    }
329
330    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
331        let mut any_pending = false;
332
333        self.connection_sinks.retain(|_, sink| {
334            match sink.as_mut().poll_close(cx) {
335                Poll::Ready(Ok(()) | Err(_)) => false, // Remove regardless of ok/err
336                Poll::Pending => {
337                    any_pending = true;
338                    true
339                }
340            }
341        });
342
343        if any_pending {
344            Poll::Pending
345        } else {
346            Poll::Ready(Ok(()))
347        }
348    }
349}
350
351/// TCP-only concrete type versions for use in containerized deployments
352pub struct TcpMultiConnectionSource<C: Decoder> {
353    /// The TCP listener accepting new connections
354    pub listener: TcpListener,
355    /// Counter for assigning unique connection IDs
356    pub next_connection_id: u64,
357    /// Active connections with their IDs and framed readers
358    pub active_connections: Vec<Option<(u64, FramedRead<OwnedReadHalf, C>)>>,
359    /// Cursor for fair round-robin polling
360    pub poll_cursor: usize,
361    /// Channel to send new sinks to the TcpMultiConnectionSink
362    pub new_sink_sender: mpsc::UnboundedSender<(u64, FramedWrite<OwnedWriteHalf, C>)>,
363    /// Channel to send membership events
364    pub membership_sender: mpsc::UnboundedSender<(u64, bool)>,
365}
366
367impl<C: Decoder + Default + Unpin> Stream for TcpMultiConnectionSource<C>
368where
369    C::Error: From<io::Error>,
370{
371    type Item = Result<(u64, C::Item), C::Error>;
372
373    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
374        let me = self.deref_mut();
375
376        // Accept new connections
377        loop {
378            match me.listener.poll_accept(cx) {
379                Poll::Ready(Ok((stream, _peer))) => {
380                    let connection_id = me.next_connection_id;
381                    me.next_connection_id += 1;
382
383                    let (rx, tx) = stream.into_split();
384                    let fr = FramedRead::new(rx, C::default());
385                    let fw = FramedWrite::new(tx, C::default());
386
387                    me.active_connections.push(Some((connection_id, fr)));
388                    let _ = me.new_sink_sender.send((connection_id, fw));
389                    let _ = me.membership_sender.send((connection_id, true));
390                }
391                Poll::Ready(Err(e)) => {
392                    if !me.active_connections.iter().any(|c| c.is_some()) {
393                        return Poll::Ready(Some(Err(e.into())));
394                    } else {
395                        break;
396                    }
397                }
398                Poll::Pending => {
399                    break;
400                }
401            }
402        }
403
404        // Poll all active connections for data using fair round-robin cursor
405        let mut out = Poll::Pending;
406        let mut any_removed = false;
407
408        if !me.active_connections.is_empty() {
409            let start_cursor = me.poll_cursor;
410
411            loop {
412                let current_length = me.active_connections.len();
413                let id_and_stream = &mut me.active_connections[me.poll_cursor];
414                let (connection_id, stream) = id_and_stream.as_mut().unwrap();
415                let connection_id = *connection_id; // Copy the ID before borrowing stream
416
417                // Move cursor to next source for next poll
418                me.poll_cursor = (me.poll_cursor + 1) % current_length;
419
420                match Pin::new(stream).poll_next(cx) {
421                    Poll::Ready(Some(Ok(data))) => {
422                        out = Poll::Ready(Some(Ok((connection_id, data))));
423                        break;
424                    }
425                    Poll::Ready(Some(Err(_))) | Poll::Ready(None) => {
426                        let _ = me.membership_sender.send((connection_id, false));
427                        *id_and_stream = None; // Mark connection as removed
428                        any_removed = true;
429                    }
430                    Poll::Pending => {}
431                }
432
433                // Check if we've completed a full round
434                if me.poll_cursor == start_cursor {
435                    break;
436                }
437            }
438        }
439
440        // Clean up None entries and adjust cursor
441        let mut current_index = 0;
442        let original_cursor = me.poll_cursor;
443
444        if any_removed {
445            me.active_connections.retain(|conn| {
446                if conn.is_none() && current_index < original_cursor {
447                    me.poll_cursor -= 1;
448                }
449                current_index += 1;
450                conn.is_some()
451            });
452        }
453
454        if me.poll_cursor == me.active_connections.len() {
455            me.poll_cursor = 0;
456        }
457
458        out
459    }
460}
461
462/// TCP-only multi-connection sink using concrete types (no boxing).
463/// Routes (connection_id, data) to the appropriate connection.
464pub struct TcpMultiConnectionSink<I, C: Encoder<I>> {
465    /// Map of connection IDs to their framed writers
466    pub connection_sinks: HashMap<u64, FramedWrite<OwnedWriteHalf, C>>,
467    /// Channel to receive new sinks from TcpMultiConnectionSource
468    pub new_sink_receiver: mpsc::UnboundedReceiver<(u64, FramedWrite<OwnedWriteHalf, C>)>,
469    _marker: std::marker::PhantomData<fn(I) -> I>, /* fn(I) -> I instead of just I to keep the struct invariant over I, which keeps it Unpin. */
470}
471
472impl<I, C: Encoder<I> + Unpin> Sink<(u64, I)> for TcpMultiConnectionSink<I, C> {
473    type Error = C::Error;
474
475    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
476        let me = self.get_mut();
477        // Receive any new sinks
478        while let Poll::Ready(Some((id, sink))) = me.new_sink_receiver.poll_recv(cx) {
479            me.connection_sinks.insert(id, sink);
480        }
481        Poll::Ready(Ok(()))
482    }
483
484    fn start_send(self: Pin<&mut Self>, item: (u64, I)) -> Result<(), Self::Error> {
485        let me = self.get_mut();
486        if let Some(sink) = me.connection_sinks.get_mut(&item.0) {
487            let _ = Pin::new(sink).start_send(item.1);
488        }
489        Ok(())
490    }
491
492    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
493        let me = self.get_mut();
494        let mut any_pending = false;
495
496        me.connection_sinks
497            .retain(|_id, sink| match Pin::new(sink).poll_flush(cx) {
498                Poll::Ready(Ok(())) => true,
499                Poll::Ready(Err(_)) => false,
500                Poll::Pending => {
501                    any_pending = true;
502                    true
503                }
504            });
505
506        if any_pending {
507            Poll::Pending
508        } else {
509            Poll::Ready(Ok(()))
510        }
511    }
512
513    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
514        let me = self.get_mut();
515        let mut any_pending = false;
516
517        me.connection_sinks
518            .retain(|_id, sink| match Pin::new(sink).poll_close(cx) {
519                Poll::Ready(Ok(())) => false,
520                Poll::Ready(Err(_)) => false,
521                Poll::Pending => {
522                    any_pending = true;
523                    true
524                }
525            });
526
527        if any_pending {
528            Poll::Pending
529        } else {
530            Poll::Ready(Ok(()))
531        }
532    }
533}
534
535type TcpMultiConnectionParts<I, C> = (
536    TcpMultiConnectionSource<C>,
537    TcpMultiConnectionSink<I, C>,
538    UnboundedReceiverStream<(u64, bool)>,
539);
540
541pub fn tcp_multi_connection<I, C>(listener: TcpListener) -> TcpMultiConnectionParts<I, C>
542where
543    C: Decoder + Encoder<I> + Default,
544{
545    let (new_sink_sender, new_sink_receiver) = mpsc::unbounded_channel();
546    let (membership_sender, membership_receiver) = mpsc::unbounded_channel();
547
548    let source = TcpMultiConnectionSource {
549        listener,
550        next_connection_id: 0,
551        active_connections: Vec::new(),
552        poll_cursor: 0,
553        new_sink_sender,
554        membership_sender,
555    };
556
557    let sink = TcpMultiConnectionSink {
558        connection_sinks: HashMap::new(),
559        new_sink_receiver,
560        _marker: std::marker::PhantomData,
561    };
562
563    let membership = UnboundedReceiverStream::new(membership_receiver);
564
565    (source, sink, membership)
566}