hydro_deploy_integration/
multi_connection.rs1use std::collections::HashMap;
2use std::io;
3use std::marker::PhantomData;
4use std::ops::DerefMut;
5use std::pin::Pin;
6use std::task::{Context, Poll, ready};
7
8use futures::{Sink, SinkExt, Stream, StreamExt};
9#[cfg(unix)]
10use tempfile::TempDir;
11use tokio::net::TcpListener;
12#[cfg(unix)]
13use tokio::net::UnixListener;
14use tokio::sync::mpsc;
15use tokio_stream::wrappers::UnboundedReceiverStream;
16use tokio_util::codec::{Decoder, Encoder, Framed};
17
18use crate::{AcceptedServer, BoundServer, Connected, Connection};
19
20pub struct ConnectedMultiConnection<I, O, C: Decoder<Item = I> + Encoder<O>> {
21 pub source: MultiConnectionSource<I, O, C>,
22 pub sink: MultiConnectionSink<O, C>,
23 pub membership: UnboundedReceiverStream<(u64, bool)>,
24}
25
26impl<
27 I: 'static,
28 O: Send + Sync + 'static,
29 C: Decoder<Item = I> + Encoder<O> + Send + Sync + Default + 'static,
30> Connected for ConnectedMultiConnection<I, O, C>
31{
32 fn from_defn(pipe: Connection) -> Self {
33 match pipe {
34 Connection::AsServer(AcceptedServer::MultiConnection(bound_server)) => {
35 let (new_sink_sender, new_sink_receiver) = mpsc::unbounded_channel();
36 let (membership_sender, membership_receiver) = mpsc::unbounded_channel();
37
38 let source = match *bound_server {
39 #[cfg(unix)]
40 BoundServer::UnixSocket(listener, dir) => MultiConnectionSource {
41 unix_listener: Some(listener),
42 tcp_listener: None,
43 _dir_holder: Some(dir),
44 next_connection_id: 0,
45 active_connections: Vec::new(),
46 poll_cursor: 0,
47 new_sink_sender,
48 membership_sender,
49 _phantom: Default::default(),
50 },
51 BoundServer::TcpPort(listener, _) => MultiConnectionSource {
52 #[cfg(unix)]
53 unix_listener: None,
54 tcp_listener: Some(listener.into_inner()),
55 #[cfg(unix)]
56 _dir_holder: None,
57 next_connection_id: 0,
58 active_connections: Vec::new(),
59 poll_cursor: 0,
60 new_sink_sender,
61 membership_sender,
62 _phantom: Default::default(),
63 },
64 _ => panic!("MultiConnection only supports UnixSocket and TcpPort"),
65 };
66
67 let sink = MultiConnectionSink::<O, C> {
68 connection_sinks: HashMap::new(),
69 new_sink_receiver,
70 _phantom: Default::default(),
71 };
72
73 ConnectedMultiConnection {
74 source,
75 sink,
76 membership: UnboundedReceiverStream::new(membership_receiver),
77 }
78 }
79 _ => panic!("Cannot connect to a non-multi-connection pipe as a multi-connection"),
80 }
81 }
82}
83
84type DynDecodedStream<I, C> =
85 Pin<Box<dyn Stream<Item = Result<I, <C as Decoder>::Error>> + Send + Sync>>;
86type DynEncodedSink<O, C> = Pin<Box<dyn Sink<O, Error = <C as Encoder<O>>::Error> + Send + Sync>>;
87
88pub struct MultiConnectionSource<I, O, C: Decoder<Item = I> + Encoder<O>> {
89 #[cfg(unix)]
90 unix_listener: Option<UnixListener>,
91 tcp_listener: Option<TcpListener>,
92 #[cfg(unix)]
93 _dir_holder: Option<TempDir>, next_connection_id: u64,
95 active_connections: Vec<Option<(u64, DynDecodedStream<I, C>)>>,
97 poll_cursor: usize,
99 new_sink_sender: mpsc::UnboundedSender<(u64, DynEncodedSink<O, C>)>,
100 membership_sender: mpsc::UnboundedSender<(u64, bool)>,
101 _phantom: PhantomData<(Box<O>, Box<C>)>,
102}
103
104pub struct MultiConnectionSink<O, C: Encoder<O>> {
105 connection_sinks: HashMap<u64, DynEncodedSink<O, C>>,
106 new_sink_receiver: mpsc::UnboundedReceiver<(u64, DynEncodedSink<O, C>)>,
107 _phantom: PhantomData<(Box<O>, Box<C>)>,
108}
109
110impl<
111 I,
112 O: Send + Sync + 'static,
113 C: Decoder<Item = I> + Encoder<O> + Send + Sync + Default + 'static,
114> Stream for MultiConnectionSource<I, O, C>
115{
116 type Item = Result<(u64, I), <C as Decoder>::Error>;
117
118 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
119 let me = self.deref_mut();
120 #[cfg(unix)]
122 if let Some(listener) = me.unix_listener.as_mut() {
123 loop {
124 match listener.poll_accept(cx) {
125 Poll::Ready(Ok((stream, _))) => {
126 use futures::{SinkExt, StreamExt};
127 use tokio_util::codec::Framed;
128
129 let connection_id = me.next_connection_id;
130 me.next_connection_id += 1;
131
132 let framed = Framed::new(stream, C::default());
133 let (sink, stream) = framed.split();
134
135 let boxed_stream: Pin<
136 Box<dyn Stream<Item = Result<I, <C as Decoder>::Error>> + Send + Sync>,
137 > = Box::pin(stream);
138
139 let boxed_sink: Pin<
141 Box<dyn Sink<O, Error = <C as Encoder<O>>::Error> + Send + Sync>,
142 > = Box::pin(sink.buffer(1024));
143
144 me.active_connections
145 .push(Some((connection_id, boxed_stream)));
146
147 let _ = me.new_sink_sender.send((connection_id, boxed_sink));
148 let _ = me.membership_sender.send((connection_id, true));
149 }
150 Poll::Ready(Err(e)) => {
151 if !me.active_connections.iter().any(|conn| conn.is_some()) {
152 return Poll::Ready(Some(Err(e.into())));
153 } else {
154 break;
155 }
156 }
157 Poll::Pending => {
158 break;
159 }
160 }
161 }
162 }
163
164 if let Some(listener) = me.tcp_listener.as_mut() {
166 loop {
167 match listener.poll_accept(cx) {
168 Poll::Ready(Ok((stream, _))) => {
169 let connection_id = me.next_connection_id;
170 me.next_connection_id += 1;
171
172 let framed = Framed::new(stream, C::default());
173 let (sink, stream) = framed.split();
174
175 let boxed_stream: Pin<
176 Box<dyn Stream<Item = Result<I, <C as Decoder>::Error>> + Send + Sync>,
177 > = Box::pin(stream);
178
179 let boxed_sink: Pin<
181 Box<dyn Sink<O, Error = <C as Encoder<O>>::Error> + Send + Sync>,
182 > = Box::pin(sink.buffer(1024));
183
184 me.active_connections
185 .push(Some((connection_id, boxed_stream)));
186
187 let _ = me.new_sink_sender.send((connection_id, boxed_sink));
188 let _ = me.membership_sender.send((connection_id, true));
189 }
190 Poll::Ready(Err(e)) => {
191 if !me.active_connections.iter().any(|conn| conn.is_some()) {
192 return Poll::Ready(Some(Err(e.into())));
193 } else {
194 break;
195 }
196 }
197 Poll::Pending => {
198 break;
199 }
200 }
201 }
202 }
203
204 let mut out = Poll::Pending;
206 let mut any_removed = false;
207
208 if !me.active_connections.is_empty() {
209 let start_cursor = me.poll_cursor;
210
211 loop {
212 let current_length = me.active_connections.len();
213 let id_and_stream = &mut me.active_connections[me.poll_cursor];
214 let (connection_id, stream) = id_and_stream.as_mut().unwrap();
215 let connection_id = *connection_id; me.poll_cursor = (me.poll_cursor + 1) % current_length;
219
220 match stream.as_mut().poll_next(cx) {
221 Poll::Ready(Some(Ok(data))) => {
222 out = Poll::Ready(Some(Ok((connection_id, data))));
223 break;
224 }
225 Poll::Ready(Some(Err(_))) | Poll::Ready(None) => {
226 let _ = me.membership_sender.send((connection_id, false));
227 *id_and_stream = None; any_removed = true;
229 }
230 Poll::Pending => {}
231 }
232
233 if me.poll_cursor == start_cursor {
235 break;
236 }
237 }
238 }
239
240 let mut current_index = 0;
242 let original_cursor = me.poll_cursor;
243
244 if any_removed {
245 me.active_connections.retain(|conn| {
246 if conn.is_none() && current_index < original_cursor {
247 me.poll_cursor -= 1;
248 }
249 current_index += 1;
250 conn.is_some()
251 });
252 }
253
254 if me.poll_cursor == me.active_connections.len() {
255 me.poll_cursor = 0;
256 }
257
258 out
259 }
260}
261
262impl<O, C: Encoder<O>> Sink<(u64, O)> for MultiConnectionSink<O, C> {
263 type Error = <C as Encoder<O>>::Error;
264
265 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
266 loop {
267 match self.new_sink_receiver.poll_recv(cx) {
268 Poll::Ready(Some((connection_id, sink))) => {
269 self.connection_sinks.insert(connection_id, sink);
270 }
271 Poll::Ready(None) => {
272 if self.connection_sinks.is_empty() {
273 return Poll::Ready(Err(io::Error::new(
274 io::ErrorKind::BrokenPipe,
275 "No additional sinks are available (was the stream dropped)?",
276 )
277 .into()));
278 } else {
279 break;
280 }
281 }
282 Poll::Pending => {
283 break;
284 }
285 }
286 }
287
288 let mut closed_connections = Vec::new();
290 for (&connection_id, sink) in self.connection_sinks.iter_mut() {
291 match ready!(sink.as_mut().poll_ready(cx)) {
292 Ok(()) => {}
293 Err(_) => {
294 closed_connections.push(connection_id);
295 }
296 }
297 }
298
299 for connection_id in closed_connections {
300 self.connection_sinks.remove(&connection_id);
301 }
302
303 Poll::Ready(Ok(())) }
305
306 fn start_send(mut self: Pin<&mut Self>, item: (u64, O)) -> Result<(), Self::Error> {
307 if let Some(sink) = self.connection_sinks.get_mut(&item.0) {
308 let _ = sink.as_mut().start_send(item.1); }
310 Ok(())
312 }
313
314 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
315 let mut closed_connections = Vec::new();
316 let mut any_pending = false;
317
318 for (&connection_id, sink) in self.connection_sinks.iter_mut() {
319 match sink.as_mut().poll_flush(cx) {
320 Poll::Ready(Ok(())) => {}
321 Poll::Ready(Err(_)) => {
322 closed_connections.push(connection_id);
323 }
324 Poll::Pending => {
325 any_pending = true;
326 }
327 }
328 }
329
330 for connection_id in closed_connections {
331 self.connection_sinks.remove(&connection_id);
332 }
333
334 if any_pending {
335 Poll::Pending
336 } else {
337 Poll::Ready(Ok(()))
338 }
339 }
340
341 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
342 let mut closed_connections = Vec::new();
343 let mut any_pending = false;
344
345 for (&connection_id, sink) in self.connection_sinks.iter_mut() {
346 match sink.as_mut().poll_close(cx) {
347 Poll::Ready(Ok(())) => {
348 closed_connections.push(connection_id);
349 }
350 Poll::Ready(Err(_)) => {
351 closed_connections.push(connection_id);
352 }
353 Poll::Pending => {
354 any_pending = true;
355 }
356 }
357 }
358
359 for connection_id in closed_connections {
360 self.connection_sinks.remove(&connection_id);
361 }
362
363 if any_pending {
364 Poll::Pending
365 } else {
366 Poll::Ready(Ok(()))
367 }
368 }
369}