1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
#![cfg(not(target_arch = "wasm32"))]

use std::collections::hash_map::Entry::{Occupied, Vacant};
use std::collections::HashMap;
use std::fmt::Debug;
use std::net::SocketAddr;

use futures::{SinkExt, StreamExt};
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
use tokio::net::{TcpListener, TcpSocket, TcpStream};
use tokio::select;
use tokio::task::spawn_local;
use tokio_stream::StreamMap;
use tokio_util::codec::{
    BytesCodec, Decoder, Encoder, FramedRead, FramedWrite, LengthDelimitedCodec, LinesCodec,
};

use super::unsync::mpsc::{Receiver, Sender};
use super::unsync_channel;

/// Helper creates a TCP `Stream` and `Sink` from the given socket, using the given `Codec` to
/// handle delineation between inputs/outputs.
pub fn tcp_framed<Codec>(
    stream: TcpStream,
    codec: Codec,
) -> (
    FramedWrite<OwnedWriteHalf, Codec>,
    FramedRead<OwnedReadHalf, Codec>,
)
where
    Codec: Clone + Decoder,
{
    let (recv, send) = stream.into_split();
    let send = FramedWrite::new(send, codec.clone());
    let recv = FramedRead::new(recv, codec);
    (send, recv)
}

/// Helper creates a TCP `Stream` and `Sink` for `Bytes` strings where each string is
/// length-delimited.
pub fn tcp_bytes(
    stream: TcpStream,
) -> (
    FramedWrite<OwnedWriteHalf, LengthDelimitedCodec>,
    FramedRead<OwnedReadHalf, LengthDelimitedCodec>,
) {
    tcp_framed(stream, LengthDelimitedCodec::new())
}

/// Helper creates a TCP `Stream` and `Sink` for undelimited streams of `Bytes`.
pub fn tcp_bytestream(
    stream: TcpStream,
) -> (
    FramedWrite<OwnedWriteHalf, BytesCodec>,
    FramedRead<OwnedReadHalf, BytesCodec>,
) {
    tcp_framed(stream, BytesCodec::new())
}

/// Helper creates a TCP `Stream` and `Sink` for `str`ings delimited by newlines.
pub fn tcp_lines(
    stream: TcpStream,
) -> (
    FramedWrite<OwnedWriteHalf, LinesCodec>,
    FramedRead<OwnedReadHalf, LinesCodec>,
) {
    tcp_framed(stream, LinesCodec::new())
}

/// A framed TCP `Sink` (sending).
pub type TcpFramedSink<T> = Sender<(T, SocketAddr)>;
/// A framed TCP `Stream` (receiving).
#[expect(type_alias_bounds, reason = "code readability")]
pub type TcpFramedStream<Codec: Decoder> =
    Receiver<Result<(<Codec as Decoder>::Item, SocketAddr), <Codec as Decoder>::Error>>;

// TODO(mingwei): this temporary code should be replaced with a properly thought out networking system.
/// Create a listening tcp socket, and then as new connections come in, receive their data and forward it to a queue.
pub async fn bind_tcp<Item, Codec>(
    endpoint: SocketAddr,
    codec: Codec,
) -> Result<(TcpFramedSink<Item>, TcpFramedStream<Codec>, SocketAddr), std::io::Error>
where
    Item: 'static,
    Codec: 'static + Clone + Decoder + Encoder<Item>,
    <Codec as Encoder<Item>>::Error: Debug,
{
    let listener = TcpListener::bind(endpoint).await?;

    let bound_endpoint = listener.local_addr()?;

    let (send_egress, mut recv_egress) = unsync_channel::<(Item, SocketAddr)>(None);
    let (send_ingres, recv_ingres) = unsync_channel(None);

    spawn_local(async move {
        let send_ingress = send_ingres;
        // Map of `addr -> peers`, to send messages to.
        let mut peers_send = HashMap::new();
        // `StreamMap` of `addr -> peers`, to receive messages from. Automatically removes streams
        // when they disconnect.
        let mut peers_recv = StreamMap::<SocketAddr, FramedRead<OwnedReadHalf, Codec>>::new();

        loop {
            // Calling methods in a loop, futures must be cancel-safe.
            select! {
                // `biased` means the cases will be prioritized in the order they are listed.
                // First we accept any new connections
                // This is not strictly neccessary, but lets us do our internal work (send outgoing
                // messages) before accepting more work (receiving more messages, accepting new
                // clients).
                biased;
                // Send outgoing messages.
                msg_send = recv_egress.next() => {
                    let Some((payload, peer_addr)) = msg_send else {
                        // `None` if the send side has been dropped (no more send messages will ever come).
                        continue;
                    };
                    let Some(stream) = peers_send.get_mut(&peer_addr) else {
                        tracing::warn!("Dropping message to non-connected peer: {}", peer_addr);
                        continue;
                    };
                    if let Err(err) = SinkExt::send(stream, payload).await {
                        tracing::error!("IO or codec error sending message to peer {}, disconnecting: {:?}", peer_addr, err);
                        peers_send.remove(&peer_addr); // `Drop` disconnects.
                    };
                }
                // Receive incoming messages.
                msg_recv = peers_recv.next(), if !peers_recv.is_empty() => {
                    // If `peers_recv` is empty then `next()` will immediately return `None` which
                    // would cause the loop to spin.
                    let Some((peer_addr, payload_result)) = msg_recv else {
                        continue; // => `peers_recv.is_empty()`.
                    };
                    if let Err(err) = send_ingress.send(payload_result.map(|payload| (payload, peer_addr))).await {
                        tracing::error!("Error passing along received message: {:?}", err);
                    }
                }
                // Accept new clients.
                new_peer = listener.accept() => {
                    let Ok((stream, _addr)) = new_peer else {
                        continue;
                    };
                    let Ok(peer_addr) = stream.peer_addr() else {
                        continue;
                    };
                    let (peer_send, peer_recv) = tcp_framed(stream, codec.clone());

                    // TODO: Using peer_addr here as the key is a little bit sketchy.
                    // It's possible that a peer could send a message, disconnect, then another peer connects from the
                    // same IP address (and the same src port), and then the response could be sent to that new client.
                    // This can be solved by using monotonically increasing IDs for each new peer, but would break the
                    // similarity with the UDP versions of this function.
                    peers_send.insert(peer_addr, peer_send);
                    peers_recv.insert(peer_addr, peer_recv);
                }
            }
        }
    });

    Ok((send_egress, recv_ingres, bound_endpoint))
}

/// The inverse of [`bind_tcp`].
///
/// When messages enqueued into the returned sender, tcp sockets will be created and connected as
/// necessary to send out the requests. As the responses come back, they will be forwarded to the
/// returned receiver.
pub fn connect_tcp<Item, Codec>(codec: Codec) -> (TcpFramedSink<Item>, TcpFramedStream<Codec>)
where
    Item: 'static,
    Codec: 'static + Clone + Decoder + Encoder<Item>,
    <Codec as Encoder<Item>>::Error: Debug,
{
    let (send_egress, mut recv_egress) = unsync_channel(None);
    let (send_ingres, recv_ingres) = unsync_channel(None);

    spawn_local(async move {
        let send_ingres = send_ingres;
        // Map of `addr -> peers`, to send messages to.
        let mut peers_send = HashMap::new();
        // `StreamMap` of `addr -> peers`, to receive messages from. Automatically removes streams
        // when they disconnect.
        let mut peers_recv = StreamMap::new();

        loop {
            // Calling methods in a loop, futures must be cancel-safe.
            select! {
                // `biased` means the cases will be prioritized in the order they are listed.
                // This is not strictly neccessary, but lets us do our internal work (send outgoing
                // messages) before accepting more work (receiving more messages).
                biased;
                // Send outgoing messages.
                msg_send = recv_egress.next() => {
                    let Some((payload, peer_addr)) = msg_send else {
                        // `None` if the send side has been dropped (no more send messages will ever come).
                        continue;
                    };

                    let stream = match peers_send.entry(peer_addr) {
                        Occupied(entry) => entry.into_mut(),
                        Vacant(entry) => {
                            let socket = TcpSocket::new_v4().unwrap();
                            let stream = socket.connect(peer_addr).await.unwrap();

                            let (peer_send, peer_recv) = tcp_framed(stream, codec.clone());

                            peers_recv.insert(peer_addr, peer_recv);
                            entry.insert(peer_send)
                        }
                    };

                    if let Err(err) = stream.send(payload).await {
                        tracing::error!("IO or codec error sending message to peer {}, disconnecting: {:?}", peer_addr, err);
                        peers_send.remove(&peer_addr); // `Drop` disconnects.
                    }
                }
                // Receive incoming messages.
                msg_recv = peers_recv.next(), if !peers_recv.is_empty() => {
                    // If `peers_recv` is empty then `next()` will immediately return `None` which
                    // would cause the loop to spin.
                    let Some((peer_addr, payload_result)) = msg_recv else {
                        continue; // => `peers_recv.is_empty()`.
                    };
                    if let Err(err) = send_ingres.send(payload_result.map(|payload| (payload, peer_addr))).await {
                        tracing::error!("Error passing along received message: {:?}", err);
                    }
                }
            }
        }
    });

    (send_egress, recv_ingres)
}