1#![cfg(not(target_arch = "wasm32"))]
2
3use std::collections::HashMap;
4use std::collections::hash_map::Entry::{Occupied, Vacant};
5use std::fmt::Debug;
6use std::net::SocketAddr;
7
8use futures::{SinkExt, StreamExt};
9use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
10use tokio::net::{TcpListener, TcpSocket, TcpStream};
11use tokio::select;
12use tokio::task::spawn_local;
13use tokio_stream::StreamMap;
14use tokio_util::codec::{
15 BytesCodec, Decoder, Encoder, FramedRead, FramedWrite, LengthDelimitedCodec, LinesCodec,
16};
17
18use super::unsync::mpsc::{Receiver, Sender};
19use super::unsync_channel;
20
21pub fn tcp_framed<Codec>(
24 stream: TcpStream,
25 codec: Codec,
26) -> (
27 FramedWrite<OwnedWriteHalf, Codec>,
28 FramedRead<OwnedReadHalf, Codec>,
29)
30where
31 Codec: Clone + Decoder,
32{
33 let (recv, send) = stream.into_split();
34 let send = FramedWrite::new(send, codec.clone());
35 let recv = FramedRead::new(recv, codec);
36 (send, recv)
37}
38
39pub fn tcp_bytes(
42 stream: TcpStream,
43) -> (
44 FramedWrite<OwnedWriteHalf, LengthDelimitedCodec>,
45 FramedRead<OwnedReadHalf, LengthDelimitedCodec>,
46) {
47 tcp_framed(stream, LengthDelimitedCodec::new())
48}
49
50pub fn tcp_bytestream(
52 stream: TcpStream,
53) -> (
54 FramedWrite<OwnedWriteHalf, BytesCodec>,
55 FramedRead<OwnedReadHalf, BytesCodec>,
56) {
57 tcp_framed(stream, BytesCodec::new())
58}
59
60pub fn tcp_lines(
62 stream: TcpStream,
63) -> (
64 FramedWrite<OwnedWriteHalf, LinesCodec>,
65 FramedRead<OwnedReadHalf, LinesCodec>,
66) {
67 tcp_framed(stream, LinesCodec::new())
68}
69
70pub type TcpFramedSink<T> = Sender<(T, SocketAddr)>;
72#[expect(type_alias_bounds, reason = "code readability")]
74pub type TcpFramedStream<Codec: Decoder> =
75 Receiver<Result<(<Codec as Decoder>::Item, SocketAddr), <Codec as Decoder>::Error>>;
76
77pub async fn bind_tcp<Item, Codec>(
80 endpoint: SocketAddr,
81 codec: Codec,
82) -> Result<(TcpFramedSink<Item>, TcpFramedStream<Codec>, SocketAddr), std::io::Error>
83where
84 Item: 'static,
85 Codec: 'static + Clone + Decoder + Encoder<Item>,
86 <Codec as Encoder<Item>>::Error: Debug,
87{
88 let listener = TcpListener::bind(endpoint).await?;
89
90 let bound_endpoint = listener.local_addr()?;
91
92 let (send_egress, mut recv_egress) = unsync_channel::<(Item, SocketAddr)>(None);
93 let (send_ingres, recv_ingres) = unsync_channel(None);
94
95 spawn_local(async move {
96 let send_ingress = send_ingres;
97 let mut peers_send = HashMap::new();
99 let mut peers_recv = StreamMap::<SocketAddr, FramedRead<OwnedReadHalf, Codec>>::new();
102
103 loop {
104 select! {
106 biased;
112 msg_send = recv_egress.next() => {
114 let Some((payload, peer_addr)) = msg_send else {
115 continue;
117 };
118 let Some(stream) = peers_send.get_mut(&peer_addr) else {
119 tracing::warn!("Dropping message to non-connected peer: {}", peer_addr);
120 continue;
121 };
122 if let Err(err) = SinkExt::send(stream, payload).await {
123 tracing::error!("IO or codec error sending message to peer {}, disconnecting: {:?}", peer_addr, err);
124 peers_send.remove(&peer_addr); };
126 }
127 msg_recv = peers_recv.next(), if !peers_recv.is_empty() => {
129 let Some((peer_addr, payload_result)) = msg_recv else {
132 continue; };
134 if let Err(err) = send_ingress.send(payload_result.map(|payload| (payload, peer_addr))).await {
135 tracing::error!("Error passing along received message: {:?}", err);
136 }
137 }
138 new_peer = listener.accept() => {
140 let Ok((stream, _addr)) = new_peer else {
141 continue;
142 };
143 let Ok(peer_addr) = stream.peer_addr() else {
144 continue;
145 };
146 let (peer_send, peer_recv) = tcp_framed(stream, codec.clone());
147
148 peers_send.insert(peer_addr, peer_send);
154 peers_recv.insert(peer_addr, peer_recv);
155 }
156 }
157 }
158 });
159
160 Ok((send_egress, recv_ingres, bound_endpoint))
161}
162
163pub fn connect_tcp<Item, Codec>(codec: Codec) -> (TcpFramedSink<Item>, TcpFramedStream<Codec>)
169where
170 Item: 'static,
171 Codec: 'static + Clone + Decoder + Encoder<Item>,
172 <Codec as Encoder<Item>>::Error: Debug,
173{
174 let (send_egress, mut recv_egress) = unsync_channel(None);
175 let (send_ingres, recv_ingres) = unsync_channel(None);
176
177 spawn_local(async move {
178 let send_ingres = send_ingres;
179 let mut peers_send = HashMap::new();
181 let mut peers_recv = StreamMap::new();
184
185 loop {
186 select! {
188 biased;
192 msg_send = recv_egress.next() => {
194 let Some((payload, peer_addr)) = msg_send else {
195 continue;
197 };
198
199 let stream = match peers_send.entry(peer_addr) {
200 Occupied(entry) => entry.into_mut(),
201 Vacant(entry) => {
202 let socket = TcpSocket::new_v4().unwrap();
203 let stream = socket.connect(peer_addr).await.unwrap();
204
205 let (peer_send, peer_recv) = tcp_framed(stream, codec.clone());
206
207 peers_recv.insert(peer_addr, peer_recv);
208 entry.insert(peer_send)
209 }
210 };
211
212 if let Err(err) = stream.send(payload).await {
213 tracing::error!("IO or codec error sending message to peer {}, disconnecting: {:?}", peer_addr, err);
214 peers_send.remove(&peer_addr); }
216 }
217 msg_recv = peers_recv.next(), if !peers_recv.is_empty() => {
219 let Some((peer_addr, payload_result)) = msg_recv else {
222 continue; };
224 if let Err(err) = send_ingres.send(payload_result.map(|payload| (payload, peer_addr))).await {
225 tracing::error!("Error passing along received message: {:?}", err);
226 }
227 }
228 }
229 }
230 });
231
232 (send_egress, recv_ingres)
233}