Skip to main content

hydro_deploy_integration/
single_connection.rs

1use std::ops::DerefMut;
2use std::pin::Pin;
3#[cfg(unix)]
4use std::sync::Arc;
5use std::task::{Context, Poll};
6
7use futures::{Sink, SinkExt, Stream, StreamExt, ready};
8#[cfg(unix)]
9use tempfile::TempDir;
10use tokio::sync::mpsc;
11use tokio_util::codec::{Decoder, Encoder, Framed};
12
13use crate::{AcceptedServer, BoundServer, Connected, Connection};
14
15/// A connected implementation which only allows a single live connection for the
16/// lifetime of the server. The first accepted connection is kept and its sink
17/// and stream are exposed; any subsequent connections are ignored.
18pub struct ConnectedSingleConnection<I, O, C: Decoder<Item = I> + Encoder<O>> {
19    pub source: SingleConnectionSource<I, C>,
20    pub sink: SingleConnectionSink<O, C>,
21}
22
23impl<
24    I: 'static,
25    O: Send + Sync + 'static,
26    C: Decoder<Item = I> + Encoder<O> + Send + Sync + Default + 'static,
27> Connected for ConnectedSingleConnection<I, O, C>
28{
29    fn from_defn(pipe: Connection) -> Self {
30        match pipe {
31            Connection::AsServer(AcceptedServer::MultiConnection(bound_server)) => {
32                let (new_stream_sender, new_stream_receiver) = mpsc::unbounded_channel();
33                let (new_sink_sender, new_sink_receiver) = mpsc::unbounded_channel();
34
35                #[cfg_attr(
36                    not(unix),
37                    expect(unused_variables, reason = "dir is only used on non-Unix")
38                )]
39                let dir = match *bound_server {
40                    #[cfg(unix)]
41                    BoundServer::UnixSocket(listener, dir) => {
42                        tokio::spawn(async move {
43                            tokio::task::yield_now().await;
44                            match listener.accept().await {
45                                Ok((stream, _)) => {
46                                    let framed = Framed::new(stream, C::default());
47                                    let (sink, stream) = framed.split();
48
49                                    let boxed_stream: DynDecodedStream<I, C> = Box::pin(stream);
50                                    let boxed_sink: DynEncodedSink<O, C> =
51                                        Box::pin(sink.buffer(1024));
52
53                                    let _ = new_stream_sender.send(boxed_stream);
54                                    let _ = new_sink_sender.send(boxed_sink);
55                                }
56                                Err(e) => {
57                                    eprintln!("Error accepting Unix connection: {}", e);
58                                }
59                            }
60                        });
61
62                        Some(dir)
63                    }
64                    BoundServer::TcpPort(listener, _) => {
65                        tokio::spawn(async move {
66                            tokio::task::yield_now().await;
67                            match listener.into_inner().accept().await {
68                                Ok((stream, _)) => {
69                                    let framed = Framed::new(stream, C::default());
70                                    let (sink, stream) = framed.split();
71
72                                    let boxed_stream: DynDecodedStream<I, C> = Box::pin(stream);
73                                    let boxed_sink: DynEncodedSink<O, C> =
74                                        Box::pin(sink.buffer(1024));
75
76                                    let _ = new_stream_sender.send(boxed_stream);
77                                    let _ = new_sink_sender.send(boxed_sink);
78                                }
79                                Err(e) => {
80                                    eprintln!("Error accepting TCP connection: {}", e);
81                                }
82                            }
83                        });
84
85                        #[cfg(unix)]
86                        {
87                            None
88                        }
89
90                        #[cfg(not(unix))]
91                        {
92                            None::<()>
93                        }
94                    }
95                    _ => panic!("SingleConnection only supports UnixSocket and TcpPort"),
96                };
97
98                #[cfg(unix)]
99                let dir_holder_arc = dir.map(Arc::new);
100
101                let source = SingleConnectionSource {
102                    new_stream_receiver,
103                    #[cfg(unix)]
104                    _dir_holder: dir_holder_arc.clone(),
105                    active_stream: None,
106                };
107
108                let sink = SingleConnectionSink::<O, C> {
109                    #[cfg(unix)]
110                    _dir_holder: dir_holder_arc,
111                    connection_sink: None,
112                    new_sink_receiver,
113                };
114
115                ConnectedSingleConnection { source, sink }
116            }
117            _ => panic!("Cannot connect to a non-multi-connection pipe as a single-connection"),
118        }
119    }
120}
121
122type DynDecodedStream<I, C> =
123    Pin<Box<dyn Stream<Item = Result<I, <C as Decoder>::Error>> + Send + Sync>>;
124type DynEncodedSink<O, C> = Pin<Box<dyn Sink<O, Error = <C as Encoder<O>>::Error> + Send + Sync>>;
125
126pub struct SingleConnectionSource<I, C: Decoder<Item = I>> {
127    new_stream_receiver: mpsc::UnboundedReceiver<DynDecodedStream<I, C>>,
128    #[cfg(unix)]
129    _dir_holder: Option<Arc<TempDir>>, // keeps the folder containing the socket alive
130    /// The active stream for the single connection, if taken
131    active_stream: Option<DynDecodedStream<I, C>>,
132}
133
134pub struct SingleConnectionSink<O, C: Encoder<O>> {
135    #[cfg(unix)]
136    _dir_holder: Option<Arc<TempDir>>, // keeps the folder containing the socket alive
137    connection_sink: Option<DynEncodedSink<O, C>>,
138    new_sink_receiver: mpsc::UnboundedReceiver<DynEncodedSink<O, C>>,
139}
140
141impl<I, C: Decoder<Item = I> + Send + Sync + Default + 'static> Stream
142    for SingleConnectionSource<I, C>
143{
144    type Item = Result<I, <C as Decoder>::Error>;
145
146    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
147        let me = self.deref_mut();
148
149        if me.active_stream.is_none() {
150            if let Some(stream) = ready!(me.new_stream_receiver.poll_recv(cx)) {
151                me.active_stream = Some(stream);
152            } else {
153                return Poll::Ready(None);
154            }
155        }
156
157        me.active_stream.as_mut().unwrap().as_mut().poll_next(cx)
158    }
159}
160
161impl<O, C: Encoder<O>> Sink<O> for SingleConnectionSink<O, C> {
162    type Error = <C as Encoder<O>>::Error;
163
164    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
165        if self.connection_sink.is_none() {
166            match ready!(self.new_sink_receiver.poll_recv(cx)) {
167                Some(sink) => {
168                    self.connection_sink = Some(sink);
169                }
170                None => return Poll::Pending,
171            }
172        }
173
174        self.connection_sink
175            .as_mut()
176            .unwrap()
177            .as_mut()
178            .poll_ready(cx)
179    }
180
181    fn start_send(mut self: Pin<&mut Self>, item: O) -> Result<(), Self::Error> {
182        self.connection_sink
183            .as_mut()
184            .unwrap()
185            .as_mut()
186            .start_send(item)
187    }
188
189    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
190        if let Some(sink) = self.connection_sink.as_mut() {
191            sink.as_mut().poll_flush(cx)
192        } else {
193            Poll::Ready(Ok(()))
194        }
195    }
196
197    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
198        if let Some(sink) = self.connection_sink.as_mut() {
199            sink.as_mut().poll_close(cx)
200        } else {
201            Poll::Ready(Ok(()))
202        }
203    }
204}