hydro_deploy_integration/
multi_connection.rs

1use std::collections::HashMap;
2use std::io;
3use std::marker::PhantomData;
4use std::ops::DerefMut;
5use std::pin::Pin;
6use std::task::{Context, Poll, ready};
7
8use futures::{Sink, SinkExt, Stream, StreamExt};
9#[cfg(unix)]
10use tempfile::TempDir;
11use tokio::net::TcpListener;
12#[cfg(unix)]
13use tokio::net::UnixListener;
14use tokio::sync::mpsc;
15use tokio_stream::wrappers::UnboundedReceiverStream;
16use tokio_util::codec::{Decoder, Encoder, Framed};
17
18use crate::{AcceptedServer, BoundServer, Connected, Connection};
19
20pub struct ConnectedMultiConnection<I, O, C: Decoder<Item = I> + Encoder<O>> {
21    pub source: MultiConnectionSource<I, O, C>,
22    pub sink: MultiConnectionSink<O, C>,
23    pub membership: UnboundedReceiverStream<(u64, bool)>,
24}
25
26impl<
27    I: 'static,
28    O: Send + Sync + 'static,
29    C: Decoder<Item = I> + Encoder<O> + Send + Sync + Default + 'static,
30> Connected for ConnectedMultiConnection<I, O, C>
31{
32    fn from_defn(pipe: Connection) -> Self {
33        match pipe {
34            Connection::AsServer(AcceptedServer::MultiConnection(bound_server)) => {
35                let (new_sink_sender, new_sink_receiver) = mpsc::unbounded_channel();
36                let (membership_sender, membership_receiver) = mpsc::unbounded_channel();
37
38                let source = match *bound_server {
39                    #[cfg(unix)]
40                    BoundServer::UnixSocket(listener, dir) => MultiConnectionSource {
41                        unix_listener: Some(listener),
42                        tcp_listener: None,
43                        _dir_holder: Some(dir),
44                        next_connection_id: 0,
45                        active_connections: Vec::new(),
46                        poll_cursor: 0,
47                        new_sink_sender,
48                        membership_sender,
49                        _phantom: Default::default(),
50                    },
51                    BoundServer::TcpPort(listener, _) => MultiConnectionSource {
52                        #[cfg(unix)]
53                        unix_listener: None,
54                        tcp_listener: Some(listener.into_inner()),
55                        #[cfg(unix)]
56                        _dir_holder: None,
57                        next_connection_id: 0,
58                        active_connections: Vec::new(),
59                        poll_cursor: 0,
60                        new_sink_sender,
61                        membership_sender,
62                        _phantom: Default::default(),
63                    },
64                    _ => panic!("MultiConnection only supports UnixSocket and TcpPort"),
65                };
66
67                let sink = MultiConnectionSink::<O, C> {
68                    connection_sinks: HashMap::new(),
69                    new_sink_receiver,
70                    _phantom: Default::default(),
71                };
72
73                ConnectedMultiConnection {
74                    source,
75                    sink,
76                    membership: UnboundedReceiverStream::new(membership_receiver),
77                }
78            }
79            _ => panic!("Cannot connect to a non-multi-connection pipe as a multi-connection"),
80        }
81    }
82}
83
84type DynDecodedStream<I, C> =
85    Pin<Box<dyn Stream<Item = Result<I, <C as Decoder>::Error>> + Send + Sync>>;
86type DynEncodedSink<O, C> = Pin<Box<dyn Sink<O, Error = <C as Encoder<O>>::Error> + Send + Sync>>;
87
88pub struct MultiConnectionSource<I, O, C: Decoder<Item = I> + Encoder<O>> {
89    #[cfg(unix)]
90    unix_listener: Option<UnixListener>,
91    tcp_listener: Option<TcpListener>,
92    #[cfg(unix)]
93    _dir_holder: Option<TempDir>, // keeps the folder containing the socket alive
94    next_connection_id: u64,
95    /// Ordered list for fair polling, will never be `None` at the beginning of a poll
96    active_connections: Vec<Option<(u64, DynDecodedStream<I, C>)>>,
97    /// Cursor for fair round-robin polling
98    poll_cursor: usize,
99    new_sink_sender: mpsc::UnboundedSender<(u64, DynEncodedSink<O, C>)>,
100    membership_sender: mpsc::UnboundedSender<(u64, bool)>,
101    _phantom: PhantomData<(Box<O>, Box<C>)>,
102}
103
104pub struct MultiConnectionSink<O, C: Encoder<O>> {
105    connection_sinks: HashMap<u64, DynEncodedSink<O, C>>,
106    new_sink_receiver: mpsc::UnboundedReceiver<(u64, DynEncodedSink<O, C>)>,
107    _phantom: PhantomData<(Box<O>, Box<C>)>,
108}
109
110impl<
111    I,
112    O: Send + Sync + 'static,
113    C: Decoder<Item = I> + Encoder<O> + Send + Sync + Default + 'static,
114> Stream for MultiConnectionSource<I, O, C>
115{
116    type Item = Result<(u64, I), <C as Decoder>::Error>;
117
118    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
119        let me = self.deref_mut();
120        // Handle Unix socket accepts
121        #[cfg(unix)]
122        if let Some(listener) = me.unix_listener.as_mut() {
123            loop {
124                match listener.poll_accept(cx) {
125                    Poll::Ready(Ok((stream, _))) => {
126                        use futures::{SinkExt, StreamExt};
127                        use tokio_util::codec::Framed;
128
129                        let connection_id = me.next_connection_id;
130                        me.next_connection_id += 1;
131
132                        let framed = Framed::new(stream, C::default());
133                        let (sink, stream) = framed.split();
134
135                        let boxed_stream: Pin<
136                            Box<dyn Stream<Item = Result<I, <C as Decoder>::Error>> + Send + Sync>,
137                        > = Box::pin(stream);
138
139                        // Buffer so that a stalled output does not prevent sending to others
140                        let boxed_sink: Pin<
141                            Box<dyn Sink<O, Error = <C as Encoder<O>>::Error> + Send + Sync>,
142                        > = Box::pin(sink.buffer(1024));
143
144                        me.active_connections
145                            .push(Some((connection_id, boxed_stream)));
146
147                        let _ = me.new_sink_sender.send((connection_id, boxed_sink));
148                        let _ = me.membership_sender.send((connection_id, true));
149                    }
150                    Poll::Ready(Err(e)) => {
151                        if !me.active_connections.iter().any(|conn| conn.is_some()) {
152                            return Poll::Ready(Some(Err(e.into())));
153                        } else {
154                            break;
155                        }
156                    }
157                    Poll::Pending => {
158                        break;
159                    }
160                }
161            }
162        }
163
164        // Handle TCP socket accepts
165        if let Some(listener) = me.tcp_listener.as_mut() {
166            loop {
167                match listener.poll_accept(cx) {
168                    Poll::Ready(Ok((stream, _))) => {
169                        let connection_id = me.next_connection_id;
170                        me.next_connection_id += 1;
171
172                        let framed = Framed::new(stream, C::default());
173                        let (sink, stream) = framed.split();
174
175                        let boxed_stream: Pin<
176                            Box<dyn Stream<Item = Result<I, <C as Decoder>::Error>> + Send + Sync>,
177                        > = Box::pin(stream);
178
179                        // Buffer so that a stalled output does not prevent sending to others
180                        let boxed_sink: Pin<
181                            Box<dyn Sink<O, Error = <C as Encoder<O>>::Error> + Send + Sync>,
182                        > = Box::pin(sink.buffer(1024));
183
184                        me.active_connections
185                            .push(Some((connection_id, boxed_stream)));
186
187                        let _ = me.new_sink_sender.send((connection_id, boxed_sink));
188                        let _ = me.membership_sender.send((connection_id, true));
189                    }
190                    Poll::Ready(Err(e)) => {
191                        if !me.active_connections.iter().any(|conn| conn.is_some()) {
192                            return Poll::Ready(Some(Err(e.into())));
193                        } else {
194                            break;
195                        }
196                    }
197                    Poll::Pending => {
198                        break;
199                    }
200                }
201            }
202        }
203
204        // Poll all active connections for data using fair round-robin cursor
205        let mut out = Poll::Pending;
206        let mut any_removed = false;
207
208        if !me.active_connections.is_empty() {
209            let start_cursor = me.poll_cursor;
210
211            loop {
212                let current_length = me.active_connections.len();
213                let id_and_stream = &mut me.active_connections[me.poll_cursor];
214                let (connection_id, stream) = id_and_stream.as_mut().unwrap();
215                let connection_id = *connection_id; // Copy the ID before borrowing stream
216
217                // Move cursor to next source for next poll
218                me.poll_cursor = (me.poll_cursor + 1) % current_length;
219
220                match stream.as_mut().poll_next(cx) {
221                    Poll::Ready(Some(Ok(data))) => {
222                        out = Poll::Ready(Some(Ok((connection_id, data))));
223                        break;
224                    }
225                    Poll::Ready(Some(Err(_))) | Poll::Ready(None) => {
226                        let _ = me.membership_sender.send((connection_id, false));
227                        *id_and_stream = None; // Mark connection as removed
228                        any_removed = true;
229                    }
230                    Poll::Pending => {}
231                }
232
233                // Check if we've completed a full round
234                if me.poll_cursor == start_cursor {
235                    break;
236                }
237            }
238        }
239
240        // Clean up None entries and adjust cursor
241        let mut current_index = 0;
242        let original_cursor = me.poll_cursor;
243
244        if any_removed {
245            me.active_connections.retain(|conn| {
246                if conn.is_none() && current_index < original_cursor {
247                    me.poll_cursor -= 1;
248                }
249                current_index += 1;
250                conn.is_some()
251            });
252        }
253
254        if me.poll_cursor == me.active_connections.len() {
255            me.poll_cursor = 0;
256        }
257
258        out
259    }
260}
261
262impl<O, C: Encoder<O>> Sink<(u64, O)> for MultiConnectionSink<O, C> {
263    type Error = <C as Encoder<O>>::Error;
264
265    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
266        loop {
267            match self.new_sink_receiver.poll_recv(cx) {
268                Poll::Ready(Some((connection_id, sink))) => {
269                    self.connection_sinks.insert(connection_id, sink);
270                }
271                Poll::Ready(None) => {
272                    if self.connection_sinks.is_empty() {
273                        return Poll::Ready(Err(io::Error::new(
274                            io::ErrorKind::BrokenPipe,
275                            "No additional sinks are available (was the stream dropped)?",
276                        )
277                        .into()));
278                    } else {
279                        break;
280                    }
281                }
282                Poll::Pending => {
283                    break;
284                }
285            }
286        }
287
288        // Check if all sinks are ready, removing any that are closed
289        let mut closed_connections = Vec::new();
290        for (&connection_id, sink) in self.connection_sinks.iter_mut() {
291            match ready!(sink.as_mut().poll_ready(cx)) {
292                Ok(()) => {}
293                Err(_) => {
294                    closed_connections.push(connection_id);
295                }
296            }
297        }
298
299        for connection_id in closed_connections {
300            self.connection_sinks.remove(&connection_id);
301        }
302
303        Poll::Ready(Ok(())) // always ready, because we drop messages if there is no sink
304    }
305
306    fn start_send(mut self: Pin<&mut Self>, item: (u64, O)) -> Result<(), Self::Error> {
307        if let Some(sink) = self.connection_sinks.get_mut(&item.0) {
308            let _ = sink.as_mut().start_send(item.1); // TODO(shadaj): log errors when we have principled logging
309        }
310        // If connection doesn't exist, silently drop (connection may have closed)
311        Ok(())
312    }
313
314    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
315        let mut closed_connections = Vec::new();
316        let mut any_pending = false;
317
318        for (&connection_id, sink) in self.connection_sinks.iter_mut() {
319            match sink.as_mut().poll_flush(cx) {
320                Poll::Ready(Ok(())) => {}
321                Poll::Ready(Err(_)) => {
322                    closed_connections.push(connection_id);
323                }
324                Poll::Pending => {
325                    any_pending = true;
326                }
327            }
328        }
329
330        for connection_id in closed_connections {
331            self.connection_sinks.remove(&connection_id);
332        }
333
334        if any_pending {
335            Poll::Pending
336        } else {
337            Poll::Ready(Ok(()))
338        }
339    }
340
341    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
342        let mut closed_connections = Vec::new();
343        let mut any_pending = false;
344
345        for (&connection_id, sink) in self.connection_sinks.iter_mut() {
346            match sink.as_mut().poll_close(cx) {
347                Poll::Ready(Ok(())) => {
348                    closed_connections.push(connection_id);
349                }
350                Poll::Ready(Err(_)) => {
351                    closed_connections.push(connection_id);
352                }
353                Poll::Pending => {
354                    any_pending = true;
355                }
356            }
357        }
358
359        for connection_id in closed_connections {
360            self.connection_sinks.remove(&connection_id);
361        }
362
363        if any_pending {
364            Poll::Pending
365        } else {
366            Poll::Ready(Ok(()))
367        }
368    }
369}