dfir_rs/util/
tcp.rs

1#![cfg(not(target_arch = "wasm32"))]
2
3use std::collections::HashMap;
4use std::collections::hash_map::Entry::{Occupied, Vacant};
5use std::fmt::Debug;
6use std::net::SocketAddr;
7
8use futures::{SinkExt, StreamExt};
9use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
10use tokio::net::{TcpListener, TcpSocket, TcpStream};
11use tokio::select;
12use tokio::task::spawn_local;
13use tokio_stream::StreamMap;
14use tokio_util::codec::{
15    BytesCodec, Decoder, Encoder, FramedRead, FramedWrite, LengthDelimitedCodec, LinesCodec,
16};
17
18use super::unsync::mpsc::{Receiver, Sender};
19use super::unsync_channel;
20
21/// Helper creates a TCP `Stream` and `Sink` from the given socket, using the given `Codec` to
22/// handle delineation between inputs/outputs.
23pub fn tcp_framed<Codec>(
24    stream: TcpStream,
25    codec: Codec,
26) -> (
27    FramedWrite<OwnedWriteHalf, Codec>,
28    FramedRead<OwnedReadHalf, Codec>,
29)
30where
31    Codec: Clone + Decoder,
32{
33    let (recv, send) = stream.into_split();
34    let send = FramedWrite::new(send, codec.clone());
35    let recv = FramedRead::new(recv, codec);
36    (send, recv)
37}
38
39/// Helper creates a TCP `Stream` and `Sink` for `Bytes` strings where each string is
40/// length-delimited.
41pub fn tcp_bytes(
42    stream: TcpStream,
43) -> (
44    FramedWrite<OwnedWriteHalf, LengthDelimitedCodec>,
45    FramedRead<OwnedReadHalf, LengthDelimitedCodec>,
46) {
47    tcp_framed(stream, LengthDelimitedCodec::new())
48}
49
50/// Helper creates a TCP `Stream` and `Sink` for undelimited streams of `Bytes`.
51pub fn tcp_bytestream(
52    stream: TcpStream,
53) -> (
54    FramedWrite<OwnedWriteHalf, BytesCodec>,
55    FramedRead<OwnedReadHalf, BytesCodec>,
56) {
57    tcp_framed(stream, BytesCodec::new())
58}
59
60/// Helper creates a TCP `Stream` and `Sink` for `str`ings delimited by newlines.
61pub fn tcp_lines(
62    stream: TcpStream,
63) -> (
64    FramedWrite<OwnedWriteHalf, LinesCodec>,
65    FramedRead<OwnedReadHalf, LinesCodec>,
66) {
67    tcp_framed(stream, LinesCodec::new())
68}
69
70/// A framed TCP `Sink` (sending).
71pub type TcpFramedSink<T> = Sender<(T, SocketAddr)>;
72/// A framed TCP `Stream` (receiving).
73#[expect(type_alias_bounds, reason = "code readability")]
74pub type TcpFramedStream<Codec: Decoder> =
75    Receiver<Result<(<Codec as Decoder>::Item, SocketAddr), <Codec as Decoder>::Error>>;
76
77// TODO(mingwei): this temporary code should be replaced with a properly thought out networking system.
78/// Create a listening tcp socket, and then as new connections come in, receive their data and forward it to a queue.
79pub async fn bind_tcp<Item, Codec>(
80    endpoint: SocketAddr,
81    codec: Codec,
82) -> Result<(TcpFramedSink<Item>, TcpFramedStream<Codec>, SocketAddr), std::io::Error>
83where
84    Item: 'static,
85    Codec: 'static + Clone + Decoder + Encoder<Item>,
86    <Codec as Encoder<Item>>::Error: Debug,
87{
88    let listener = TcpListener::bind(endpoint).await?;
89
90    let bound_endpoint = listener.local_addr()?;
91
92    let (send_egress, mut recv_egress) = unsync_channel::<(Item, SocketAddr)>(None);
93    let (send_ingres, recv_ingres) = unsync_channel(None);
94
95    spawn_local(async move {
96        let send_ingress = send_ingres;
97        // Map of `addr -> peers`, to send messages to.
98        let mut peers_send = HashMap::new();
99        // `StreamMap` of `addr -> peers`, to receive messages from. Automatically removes streams
100        // when they disconnect.
101        let mut peers_recv = StreamMap::<SocketAddr, FramedRead<OwnedReadHalf, Codec>>::new();
102
103        loop {
104            // Calling methods in a loop, futures must be cancel-safe.
105            select! {
106                // `biased` means the cases will be prioritized in the order they are listed.
107                // First we accept any new connections
108                // This is not strictly neccessary, but lets us do our internal work (send outgoing
109                // messages) before accepting more work (receiving more messages, accepting new
110                // clients).
111                biased;
112                // Send outgoing messages.
113                msg_send = recv_egress.next() => {
114                    let Some((payload, peer_addr)) = msg_send else {
115                        // `None` if the send side has been dropped (no more send messages will ever come).
116                        continue;
117                    };
118                    let Some(stream) = peers_send.get_mut(&peer_addr) else {
119                        tracing::warn!("Dropping message to non-connected peer: {}", peer_addr);
120                        continue;
121                    };
122                    if let Err(err) = SinkExt::send(stream, payload).await {
123                        tracing::error!("IO or codec error sending message to peer {}, disconnecting: {:?}", peer_addr, err);
124                        peers_send.remove(&peer_addr); // `Drop` disconnects.
125                    };
126                }
127                // Receive incoming messages.
128                msg_recv = peers_recv.next(), if !peers_recv.is_empty() => {
129                    // If `peers_recv` is empty then `next()` will immediately return `None` which
130                    // would cause the loop to spin.
131                    let Some((peer_addr, payload_result)) = msg_recv else {
132                        continue; // => `peers_recv.is_empty()`.
133                    };
134                    if let Err(err) = send_ingress.send(payload_result.map(|payload| (payload, peer_addr))).await {
135                        tracing::error!("Error passing along received message: {:?}", err);
136                    }
137                }
138                // Accept new clients.
139                new_peer = listener.accept() => {
140                    let Ok((stream, _addr)) = new_peer else {
141                        continue;
142                    };
143                    let Ok(peer_addr) = stream.peer_addr() else {
144                        continue;
145                    };
146                    let (peer_send, peer_recv) = tcp_framed(stream, codec.clone());
147
148                    // TODO: Using peer_addr here as the key is a little bit sketchy.
149                    // It's possible that a peer could send a message, disconnect, then another peer connects from the
150                    // same IP address (and the same src port), and then the response could be sent to that new client.
151                    // This can be solved by using monotonically increasing IDs for each new peer, but would break the
152                    // similarity with the UDP versions of this function.
153                    peers_send.insert(peer_addr, peer_send);
154                    peers_recv.insert(peer_addr, peer_recv);
155                }
156            }
157        }
158    });
159
160    Ok((send_egress, recv_ingres, bound_endpoint))
161}
162
163/// The inverse of [`bind_tcp`].
164///
165/// When messages enqueued into the returned sender, tcp sockets will be created and connected as
166/// necessary to send out the requests. As the responses come back, they will be forwarded to the
167/// returned receiver.
168pub fn connect_tcp<Item, Codec>(codec: Codec) -> (TcpFramedSink<Item>, TcpFramedStream<Codec>)
169where
170    Item: 'static,
171    Codec: 'static + Clone + Decoder + Encoder<Item>,
172    <Codec as Encoder<Item>>::Error: Debug,
173{
174    let (send_egress, mut recv_egress) = unsync_channel(None);
175    let (send_ingres, recv_ingres) = unsync_channel(None);
176
177    spawn_local(async move {
178        let send_ingres = send_ingres;
179        // Map of `addr -> peers`, to send messages to.
180        let mut peers_send = HashMap::new();
181        // `StreamMap` of `addr -> peers`, to receive messages from. Automatically removes streams
182        // when they disconnect.
183        let mut peers_recv = StreamMap::new();
184
185        loop {
186            // Calling methods in a loop, futures must be cancel-safe.
187            select! {
188                // `biased` means the cases will be prioritized in the order they are listed.
189                // This is not strictly neccessary, but lets us do our internal work (send outgoing
190                // messages) before accepting more work (receiving more messages).
191                biased;
192                // Send outgoing messages.
193                msg_send = recv_egress.next() => {
194                    let Some((payload, peer_addr)) = msg_send else {
195                        // `None` if the send side has been dropped (no more send messages will ever come).
196                        continue;
197                    };
198
199                    let stream = match peers_send.entry(peer_addr) {
200                        Occupied(entry) => entry.into_mut(),
201                        Vacant(entry) => {
202                            let socket = TcpSocket::new_v4().unwrap();
203                            let stream = socket.connect(peer_addr).await.unwrap();
204
205                            let (peer_send, peer_recv) = tcp_framed(stream, codec.clone());
206
207                            peers_recv.insert(peer_addr, peer_recv);
208                            entry.insert(peer_send)
209                        }
210                    };
211
212                    if let Err(err) = stream.send(payload).await {
213                        tracing::error!("IO or codec error sending message to peer {}, disconnecting: {:?}", peer_addr, err);
214                        peers_send.remove(&peer_addr); // `Drop` disconnects.
215                    }
216                }
217                // Receive incoming messages.
218                msg_recv = peers_recv.next(), if !peers_recv.is_empty() => {
219                    // If `peers_recv` is empty then `next()` will immediately return `None` which
220                    // would cause the loop to spin.
221                    let Some((peer_addr, payload_result)) = msg_recv else {
222                        continue; // => `peers_recv.is_empty()`.
223                    };
224                    if let Err(err) = send_ingres.send(payload_result.map(|payload| (payload, peer_addr))).await {
225                        tracing::error!("Error passing along received message: {:?}", err);
226                    }
227                }
228            }
229        }
230    });
231
232    (send_egress, recv_ingres)
233}