dfir_rs/scheduled/net/
network_vertex.rs

1#![cfg(not(target_arch = "wasm32"))]
2
3use std::collections::HashMap;
4
5use futures::{SinkExt, StreamExt};
6use serde::Serialize;
7use serde::de::DeserializeOwned;
8use tokio::net::{TcpListener, TcpStream};
9use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec};
10
11use crate::scheduled::graph::Dfir;
12use crate::scheduled::graph_ext::GraphExt;
13use crate::scheduled::handoff::VecHandoff;
14use crate::scheduled::port::{RecvPort, SendPort};
15
/// Address of a remote peer, in the textual form handed to
/// `TcpStream::connect`. Also used as the key of the outbound connection
/// cache, so two spellings of the same endpoint count as distinct peers.
pub type Address = String;
17
18// These methods can't be wrapped up in a trait because async methods are not
19// allowed in traits (yet).
20
21impl Dfir<'_> {
22    // TODO(justin): document these, but they're derivatives of inbound_tcp_vertex_internal.
23    pub async fn inbound_tcp_vertex_port<T>(&mut self, port: u16) -> RecvPort<VecHandoff<T>>
24    where
25        T: 'static + DeserializeOwned + Send,
26    {
27        self.inbound_tcp_vertex_internal(Some(port)).await.1
28    }
29
30    pub async fn inbound_tcp_vertex<T>(&mut self) -> (u16, RecvPort<VecHandoff<T>>)
31    where
32        T: 'static + DeserializeOwned + Send,
33    {
34        self.inbound_tcp_vertex_internal(None).await
35    }
36    // TODO(justin): this needs to return a result/get rid of all the unwraps, I
37    // guess we need a custom error type?
38    /// Begins listening on some TCP port. Returns an [`RecvPort`] representing
39    /// the stream of messages received. Currently there is no notion of
40    /// identity to the connections received, if they are to be attached to some
41    /// participant in the system, that needs to be included in the message
42    /// directly.
43    ///
44    /// The messages will be interpreted to be bincode-encoded, length-delimited
45    /// messages, as produced by [Self::outbound_tcp_vertex].
46    async fn inbound_tcp_vertex_internal<T>(
47        &mut self,
48        port: Option<u16>,
49    ) -> (u16, RecvPort<VecHandoff<T>>)
50    where
51        T: 'static + DeserializeOwned + Send,
52    {
53        let listener = TcpListener::bind(format!("localhost:{}", port.unwrap_or(0)))
54            .await
55            .unwrap();
56        let port = listener.local_addr().unwrap().port();
57
58        // TODO(justin): figure out an appropriate buffer here.
59        let (incoming_send, incoming_messages) = futures::channel::mpsc::channel(1024);
60
61        // Listen to incoming connections and spawn a tokio task for each one,
62        // which feeds into the channel.
63        // TODO(justin): give some way to get a handle into this thing.
64        tokio::spawn(async move {
65            loop {
66                let (socket, _) = listener.accept().await.unwrap();
67                let (reader, _) = socket.into_split();
68                let mut reader = FramedRead::new(reader, LengthDelimitedCodec::new());
69                let mut incoming_send = incoming_send.clone();
70                tokio::spawn(async move {
71                    while let Some(msg) = reader.next().await {
72                        // TODO(justin): figure out error handling here.
73                        let msg = msg.unwrap();
74                        let out = bincode::deserialize(&msg).unwrap();
75                        incoming_send.send(out).await.unwrap();
76                    }
77                    // TODO(justin): The connection is closed, so we should
78                    // clean up its metadata.
79                });
80            }
81        });
82
83        let (send_port, recv_port) = self.make_edge("tcp ingress handoff");
84        self.add_input_from_stream("tcp ingress stream", send_port, incoming_messages.map(Some));
85
86        (port, recv_port)
87    }
88
89    pub async fn outbound_tcp_vertex<T>(&mut self) -> SendPort<VecHandoff<(Address, T)>>
90    where
91        T: 'static + Serialize + Send,
92    {
93        let (mut connection_reqs_send, mut connection_reqs_recv) =
94            futures::channel::mpsc::channel(1024);
95        let (mut connections_send, mut connections_recv) = futures::channel::mpsc::channel(1024);
96
97        // TODO(justin): handle errors here.
98        // Spawn an actor which establishes connections.
99        tokio::spawn(async move {
100            while let Some(addr) = connection_reqs_recv.next().await {
101                let addr: Address = addr;
102                connections_send
103                    .send((addr.clone(), TcpStream::connect(addr.clone()).await))
104                    .await
105                    .unwrap();
106            }
107        });
108
109        enum ConnStatus<T> {
110            Pending(Vec<T>),
111            Connected(FramedWrite<TcpStream, LengthDelimitedCodec>),
112        }
113
114        let (mut outbound_messages_send, mut outbound_messages_recv) =
115            futures::channel::mpsc::channel(1024);
116        tokio::spawn(async move {
117            // TODO(justin): this cache should be global to the entire
118            // instance so we can reuse connections from inbound connections.
119            let mut connections = HashMap::<Address, ConnStatus<T>>::new();
120
121            loop {
122                tokio::select! {
123                    Some((addr, msg)) = outbound_messages_recv.next() => {
124                        let addr: Address = addr;
125                        let msg: T = msg;
126                        match connections.get_mut(&addr) {
127                            None => {
128                                // We have not seen this address before, open a
129                                // connection to it and buffer the message to be
130                                // sent once it's open.
131
132                                // TODO(justin): what do we do if the buffer is full here?
133                                connection_reqs_send.try_send(addr.clone()).unwrap();
134                                connections.insert(addr, ConnStatus::Pending(vec![msg]));
135                            }
136                            Some(ConnStatus::Pending(msgs)) => {
137                                // We have seen this address before but we're
138                                // still trying to connect to it, so buffer this
139                                // message so that when we _do_ connect we will
140                                // send it.
141                                msgs.push(msg);
142                            }
143                            Some(ConnStatus::Connected(conn)) => {
144                                // TODO(justin): move the actual sending here
145                                // into a different task so we don't have to
146                                // wait for the send.
147                                let msg = bincode::serialize(&msg).unwrap();
148                                conn.send(msg.into()).await.unwrap();
149                            }
150                        }
151                    },
152
153                    Some((addr, conn)) = connections_recv.next() => {
154                        match conn {
155                            Ok(conn) => {
156                                match connections.get_mut(&addr) {
157                                    Some(ConnStatus::Pending(msgs)) => {
158                                        let mut conn = FramedWrite::new(conn, LengthDelimitedCodec::new());
159                                        for msg in msgs.drain(..) {
160                                            // TODO(justin): move the actual sending here
161                                            // into a different task so we don't have to
162                                            // wait for the send.
163                                            let msg = bincode::serialize(&msg).unwrap();
164                                            conn.send(msg.into()).await.unwrap();
165                                        }
166                                        connections.insert(addr, ConnStatus::Connected(conn));
167                                    }
168                                    None => {
169                                        // This means nobody ever requested this
170                                        // connection, so we shouldn't have initiated it
171                                        // in the first place.
172                                        unreachable!()
173                                    }
174                                    Some(ConnStatus::Connected(_tcp)) => {
175                                        // This means we were already connected, so we
176                                        // shouldn't have connected again. If the
177                                        // connection cache becomes shared this could
178                                        // become reachable.
179                                        unreachable!()
180                                    }
181                                }
182                            }
183                            Err(e) => {
184                                // We couldn't connect to the address for some
185                                // reason.
186                                // TODO(justin): once we have a clearer picture
187                                // of error handling, we could do something like
188                                // send this error along a pipe to be handled by
189                                // someone else. For now, just log it and drop
190                                // any pending messages.
191                                eprintln!("couldn't connect to {}: {}", addr, e);
192                                connections.remove(&addr);
193                            }
194                        }
195                    },
196                    else => break,
197                }
198            }
199        });
200
201        let mut buffered_messages = Vec::new();
202        let mut next_messages = Vec::new();
203        let (input_port, output_port) = self.make_edge("tcp egress handoff");
204        self.add_subgraph_sink("tcp egress stream", output_port, move |_ctx, recv| {
205            buffered_messages.extend(recv.take_inner());
206            for msg in buffered_messages.drain(..) {
207                if let Err(e) = outbound_messages_send.try_send(msg) {
208                    // If we weren't able to send a message (say, because the
209                    // buffer is full), we get handed it back in the error. If
210                    // this happens we hang onto the message to try sending it
211                    // again next time.
212                    next_messages.push(e.into_inner());
213                }
214            }
215
216            // NB. we don't need to flush the channel here due to the use of
217            // `try_send`.  It's guaranteed that there was space for the
218            // messages and that they were sent.
219
220            // TODO(justin): we do need to make sure we get rescheduled if
221            // next_messages is empty here.
222
223            std::mem::swap(&mut buffered_messages, &mut next_messages);
224        });
225
226        input_port
227    }
228}