dfir_rs/scheduled/net/network_vertex.rs
#![cfg(not(target_arch = "wasm32"))]
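
//! TCP networking vertices for [`Dfir`]: an inbound vertex that listens for
//! connections and an outbound vertex that establishes them, both carrying
//! bincode-encoded, length-delimited messages.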

use std::collections::HashMap;

use futures::{SinkExt, StreamExt};
use serde::Serialize;
use serde::de::DeserializeOwned;
use tokio::net::{TcpListener, TcpStream};
use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec};

use crate::scheduled::graph::Dfir;
use crate::scheduled::graph_ext::GraphExt;
use crate::scheduled::handoff::VecHandoff;
use crate::scheduled::port::{RecvPort, SendPort};
/// A network address in `host:port` form, as accepted by [`TcpStream::connect`].
pub type Address = String;

// These methods can't be wrapped up in a trait because async methods are not
// allowed in traits (yet).

impl Dfir<'_> {
    // TODO(justin): document these, but they're derivatives of inbound_tcp_vertex_internal.
    pub async fn inbound_tcp_vertex_port<T>(&mut self, port: u16) -> RecvPort<VecHandoff<T>>
    where
        T: 'static + DeserializeOwned + Send,
    {
        self.inbound_tcp_vertex_internal(Some(port)).await.1
    }

    pub async fn inbound_tcp_vertex<T>(&mut self) -> (u16, RecvPort<VecHandoff<T>>)
    where
        T: 'static + DeserializeOwned + Send,
    {
        self.inbound_tcp_vertex_internal(None).await
    }

    // TODO(justin): this needs to return a result/get rid of all the unwraps; I
    // guess we need a custom error type?
    /// Begins listening on some TCP port. Returns a [`RecvPort`] representing
    /// the stream of messages received. Currently there is no notion of
    /// identity for incoming connections; if a message is to be associated
    /// with some participant in the system, that identity must be included in
    /// the message itself.
    ///
    /// Messages are interpreted as bincode-encoded, length-delimited frames,
    /// as produced by [`Self::outbound_tcp_vertex`].
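    ///
    /// A minimal usage sketch via the public wrapper [`Self::inbound_tcp_vertex`]
    /// (not compiled as a doc test; `Dfir::new()` and a running tokio runtime
    /// are assumed):
    ///
    /// ```ignore
    /// let mut df = Dfir::new();
    /// let (port, msgs) = df.inbound_tcp_vertex::<String>().await;
    /// println!("listening on localhost:{}", port);
    /// // `msgs` is a `RecvPort<VecHandoff<String>>`, ready to be wired into a
    /// // subgraph, e.g. with `add_subgraph_sink`.
    /// ```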
    async fn inbound_tcp_vertex_internal<T>(
        &mut self,
        port: Option<u16>,
    ) -> (u16, RecvPort<VecHandoff<T>>)
    where
        T: 'static + DeserializeOwned + Send,
    {
        // Binding to port 0 lets the OS assign an arbitrary free port.
        let listener = TcpListener::bind(format!("localhost:{}", port.unwrap_or(0)))
            .await
            .unwrap();
        let port = listener.local_addr().unwrap().port();

        // TODO(justin): figure out an appropriate buffer size here.
        let (incoming_send, incoming_messages) = futures::channel::mpsc::channel(1024);

        // Listen for incoming connections and spawn a tokio task for each one,
        // which feeds into the channel.
        // TODO(justin): give some way to get a handle into this thing.
        tokio::spawn(async move {
            loop {
                let (socket, _) = listener.accept().await.unwrap();
                let (reader, _) = socket.into_split();
                let mut reader = FramedRead::new(reader, LengthDelimitedCodec::new());
                let mut incoming_send = incoming_send.clone();
                tokio::spawn(async move {
                    while let Some(msg) = reader.next().await {
                        // TODO(justin): figure out error handling here.
                        let msg = msg.unwrap();
                        let out = bincode::deserialize(&msg).unwrap();
                        incoming_send.send(out).await.unwrap();
                    }
                    // TODO(justin): the connection is closed, so we should
                    // clean up its metadata.
                });
            }
        });

        let (send_port, recv_port) = self.make_edge("tcp ingress handoff");
        self.add_input_from_stream("tcp ingress stream", send_port, incoming_messages.map(Some));

        (port, recv_port)
    }

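    /// Creates an outbound TCP vertex: returns a [`SendPort`] which accepts
    /// `(Address, T)` pairs. Each message is bincode-serialized, framed with a
    /// length delimiter, and sent over a TCP connection to the given address.
    /// Connections are opened lazily on first use, and messages are buffered
    /// while a connection is being established.
    ///
    /// A minimal usage sketch (not compiled as a doc test; `Dfir::new()` and a
    /// running tokio runtime are assumed):
    ///
    /// ```ignore
    /// let mut df = Dfir::new();
    /// let egress = df.outbound_tcp_vertex::<String>().await;
    /// // A subgraph can now push `(Address, String)` pairs into `egress`,
    /// // e.g. ("localhost:8080".to_owned(), "hello".to_owned()).
    /// ```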
    pub async fn outbound_tcp_vertex<T>(&mut self) -> SendPort<VecHandoff<(Address, T)>>
    where
        T: 'static + Serialize + Send,
    {
        let (mut connection_reqs_send, mut connection_reqs_recv) =
            futures::channel::mpsc::channel(1024);
        let (mut connections_send, mut connections_recv) = futures::channel::mpsc::channel(1024);

        // TODO(justin): handle errors here.
        // Spawn an actor which establishes connections.
        tokio::spawn(async move {
            while let Some(addr) = connection_reqs_recv.next().await {
                let addr: Address = addr;
                connections_send
                    .send((addr.clone(), TcpStream::connect(addr.clone()).await))
                    .await
                    .unwrap();
            }
        });

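        // Per-address connection state machine: while a connection is still
        // being established, its address maps to `Pending` and outgoing
        // messages are buffered; once it is up, the address maps to
        // `Connected` and buffered messages are flushed.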
        enum ConnStatus<T> {
            Pending(Vec<T>),
            Connected(FramedWrite<TcpStream, LengthDelimitedCodec>),
        }

        let (mut outbound_messages_send, mut outbound_messages_recv) =
            futures::channel::mpsc::channel(1024);
        tokio::spawn(async move {
            // TODO(justin): this cache should be global to the entire
            // instance so we can reuse connections from inbound connections.
            let mut connections = HashMap::<Address, ConnStatus<T>>::new();

            loop {
                tokio::select! {
                    Some((addr, msg)) = outbound_messages_recv.next() => {
                        let addr: Address = addr;
                        let msg: T = msg;
                        match connections.get_mut(&addr) {
                            None => {
                                // We have not seen this address before; open a
                                // connection to it and buffer the message to be
                                // sent once it's open.

                                // TODO(justin): what do we do if the buffer is full here?
                                connection_reqs_send.try_send(addr.clone()).unwrap();
                                connections.insert(addr, ConnStatus::Pending(vec![msg]));
                            }
                            Some(ConnStatus::Pending(msgs)) => {
                                // We have seen this address before but we're
                                // still trying to connect to it, so buffer this
                                // message so that when we _do_ connect we will
                                // send it.
                                msgs.push(msg);
                            }
                            Some(ConnStatus::Connected(conn)) => {
                                // TODO(justin): move the actual sending here
                                // into a different task so we don't have to
                                // wait for the send.
                                let msg = bincode::serialize(&msg).unwrap();
                                conn.send(msg.into()).await.unwrap();
                            }
                        }
                    },

                    Some((addr, conn)) = connections_recv.next() => {
                        match conn {
                            Ok(conn) => {
                                match connections.get_mut(&addr) {
                                    Some(ConnStatus::Pending(msgs)) => {
                                        let mut conn = FramedWrite::new(conn, LengthDelimitedCodec::new());
                                        for msg in msgs.drain(..) {
                                            // TODO(justin): move the actual sending here
                                            // into a different task so we don't have to
                                            // wait for the send.
                                            let msg = bincode::serialize(&msg).unwrap();
                                            conn.send(msg.into()).await.unwrap();
                                        }
                                        connections.insert(addr, ConnStatus::Connected(conn));
                                    }
                                    None => {
                                        // Nobody ever requested this connection,
                                        // so we shouldn't have initiated it in
                                        // the first place.
                                        unreachable!()
                                    }
                                    Some(ConnStatus::Connected(_tcp)) => {
                                        // We were already connected, so we
                                        // shouldn't have connected again. If the
                                        // connection cache becomes shared this
                                        // could become reachable.
                                        unreachable!()
                                    }
                                }
                            }
                            Err(e) => {
                                // We couldn't connect to the address for some
                                // reason.
                                // TODO(justin): once we have a clearer picture
                                // of error handling, we could send this error
                                // along a pipe to be handled by someone else.
                                // For now, just log it and drop any pending
                                // messages.
                                eprintln!("couldn't connect to {}: {}", addr, e);
                                connections.remove(&addr);
                            }
                        }
                    },
                    else => break,
                }
            }
        });

        let mut buffered_messages = Vec::new();
        let mut next_messages = Vec::new();
        let (input_port, output_port) = self.make_edge("tcp egress handoff");
        self.add_subgraph_sink("tcp egress stream", output_port, move |_ctx, recv| {
            buffered_messages.extend(recv.take_inner());
            for msg in buffered_messages.drain(..) {
                if let Err(e) = outbound_messages_send.try_send(msg) {
                    // If we weren't able to send a message (say, because the
                    // channel is full), the message is handed back to us
                    // inside the error. Hang onto it and try sending it again
                    // next time.
                    next_messages.push(e.into_inner());
                }
            }

            // NB. we don't need to flush the channel here due to the use of
            // `try_send`: it's guaranteed that there was space for the
            // messages and that they were sent.

            // TODO(justin): we do need to make sure we get rescheduled if
            // `next_messages` is non-empty here.

            std::mem::swap(&mut buffered_messages, &mut next_messages);
        });

        input_port
    }
}
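
// A rough end-to-end sketch of pairing the two vertices (hypothetical usage,
// not part of this module; `Dfir::new()` and a tokio runtime are assumed):
//
//     let mut df = Dfir::new();
//     let (port, incoming) = df.inbound_tcp_vertex::<String>().await;
//     let outgoing = df.outbound_tcp_vertex::<String>().await;
//     // Subgraphs can now consume `String`s from `incoming` and push
//     // `(Address, String)` pairs into `outgoing`.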