1use std::convert::Infallible;
2use std::fmt::Debug;
3use std::future::ready;
4use std::hash::Hash;
5use std::io::Error;
6use std::net::{IpAddr, Ipv4Addr, SocketAddr};
7use std::num::ParseFloatError;
8use std::time::Duration;
9
10use clap::Parser;
11use dfir_rs::futures::{Sink, SinkExt, StreamExt};
12use dfir_rs::tokio_stream::wrappers::IntervalStream;
13use dfir_rs::util::{bind_udp_bytes, ipv4_resolve};
14use dfir_rs::{bincode, bytes, tokio};
15use gossip_kv::membership::{MemberDataBuilder, Protocol};
16use gossip_kv::server::{SeedNode, server};
17use prometheus::{Encoder, TextEncoder, gather};
18use serde::Serialize;
19use tracing::{error, info, trace};
20use warp::Filter;
21
22use crate::config::{SeedNodeSettings, setup_settings_watch};
23use crate::membership::member_name;
24
25mod config;
26
27mod membership;
28
29#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Parser)]
30struct Opts {
31 #[clap(short, long, default_value = "3000")]
33 gossip_port: u16,
34
35 #[clap(short, long, default_value = "3001")]
37 client_port: u16,
38
39 #[clap(short, long, default_value = "5", value_parser = clap_duration_from_secs)]
41 gossip_frequency: Duration,
42
43 #[clap(env = "GOSSIP_MEMBER_SUFFIX_LEN", default_value = "4")]
44 member_suffix_len: usize,
45}
46
/// Parses a (possibly fractional) number of seconds, e.g. `"5"` or `"0.25"`, into a [`Duration`].
///
/// Used as the clap `value_parser` for `--gossip-frequency`.
///
/// # Errors
///
/// Returns a human-readable message (reported by clap as a usage error) when `arg` is not a
/// floating-point number. It also errors when the value is negative, NaN, infinite, too large
/// for a [`Duration`], or rounds to zero. Zero is rejected because `tokio::time::interval`
/// panics on a zero period.
fn clap_duration_from_secs(arg: &str) -> Result<Duration, String> {
    let secs: f64 = arg
        .parse()
        .map_err(|e: ParseFloatError| format!("`{arg}` is not a number of seconds: {e}"))?;
    // Unlike `Duration::from_secs_f*`, the `try_` variant reports bad values instead of panicking.
    let duration = Duration::try_from_secs_f64(secs)
        .map_err(|e| format!("`{arg}` is not a valid duration: {e}"))?;
    if duration.is_zero() {
        return Err(format!("`{arg}` must be greater than zero seconds"));
    }
    Ok(duration)
}
51
52fn make_seed_node(settings: &SeedNodeSettings) -> SeedNode<SocketAddr> {
55 SeedNode {
56 id: settings.id.clone(),
57 address: ipv4_resolve(&settings.address).unwrap(),
58 }
59}
60
61async fn metrics_handler() -> Result<impl warp::Reply, Infallible> {
63 let encoder = TextEncoder::new();
64 let metric_families = gather();
65 let mut buffer = Vec::new();
66 encoder.encode(&metric_families, &mut buffer).unwrap();
67
68 Ok(warp::reply::with_header(
69 buffer,
70 "Content-Type",
71 encoder.format_type(),
72 ))
73}
74
75fn setup_outbound_serialization<Outbound, Message>(
77 outbound: Outbound,
78) -> impl Sink<(Message, SocketAddr), Error = Error>
79where
80 Outbound: Sink<(bytes::Bytes, SocketAddr), Error = Error>,
81 Message: Serialize + Debug + Send + 'static,
82{
83 outbound.with(|(msg, addr): (Message, SocketAddr)| {
84 ready(Ok::<(bytes::Bytes, SocketAddr), Error>((
85 dfir_rs::util::serialize_to_bytes(msg),
86 addr,
87 )))
88 })
89}
90
91fn setup_inbound_deserialization<Inbound, Message>(
93 inbound: Inbound,
94) -> impl dfir_rs::futures::Stream<Item = (Message, SocketAddr)>
95where
96 Inbound: dfir_rs::futures::Stream<Item = Result<(bytes::BytesMut, SocketAddr), Error>>,
97 Message: for<'de> serde::Deserialize<'de> + Debug + Send + 'static,
98{
99 inbound.filter_map(|input| {
100 let mapped = match input {
101 Ok((bytes, addr)) => {
102 let msg: bincode::Result<Message> = dfir_rs::util::deserialize_from_bytes(&bytes);
103 match msg {
104 Ok(msg) => Some((msg, addr)),
105 Err(e) => {
106 error!("Error deserializing message: {:?}", e);
107 None
108 }
109 }
110 }
111 Err(e) => {
112 error!("Error receiving message: {:?}", e);
113 None
114 }
115 };
116 ready(mapped)
117 })
118}
119
120#[dfir_rs::main]
121async fn main() {
122 tracing_subscriber::fmt::init();
123
124 let opts: Opts = Opts::parse();
125
126 let metrics_route = warp::path("metrics").and_then(metrics_handler);
128 tokio::spawn(async move {
129 info!("Starting metrics server on port 4003");
130 warp::serve(metrics_route).run(([0, 0, 0, 0], 4003)).await;
131 });
132
133 let client_protocol_address =
135 SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), opts.client_port);
136 let gossip_protocol_address =
137 SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), opts.gossip_port);
138
139 let member_data = MemberDataBuilder::new(member_name(opts.member_suffix_len).clone())
140 .add_protocol(Protocol::new("gossip".into(), gossip_protocol_address))
141 .add_protocol(Protocol::new("client".into(), client_protocol_address))
142 .build();
143
144 let (client_outbound, client_inbound, _) = bind_udp_bytes(client_protocol_address).await;
146 let (gossip_outbound, gossip_inbound, _) = bind_udp_bytes(gossip_protocol_address).await;
147
148 info!(
149 "Server {:?} listening for client requests on: {:?}",
150 member_data.id, client_protocol_address
151 );
152
153 let client_ob = setup_outbound_serialization(client_outbound);
155 let client_ib = setup_inbound_deserialization(client_inbound);
156
157 let gossip_ob = setup_outbound_serialization(gossip_outbound);
159 let gossip_ib = setup_inbound_deserialization(gossip_inbound);
160
161 let gossip_rx = IntervalStream::new(tokio::time::interval(opts.gossip_frequency)).map(|_| ());
163
164 let (_watcher, server_settings, settings_stream) = setup_settings_watch();
166
167 let seed_nodes = server_settings
168 .seed_nodes
169 .iter()
170 .map(make_seed_node)
171 .collect::<Vec<_>>();
172
173 let seed_node_stream = settings_stream.map(|settings| {
174 trace!("Settings updated. Reloading seed nodes");
175 settings.seed_nodes.iter().map(make_seed_node).collect()
176 });
177
178 let mut server = server(
180 client_ib,
181 client_ob,
182 gossip_ib,
183 gossip_ob,
184 gossip_rx,
185 member_data,
186 seed_nodes,
187 seed_node_stream,
188 );
189
190 server.run_async().await;
191}