gossip_server/
main.rs

1use std::convert::Infallible;
2use std::fmt::Debug;
3use std::future::ready;
4use std::hash::Hash;
5use std::io::Error;
6use std::net::{IpAddr, Ipv4Addr, SocketAddr};
7use std::num::ParseFloatError;
8use std::time::Duration;
9
10use clap::Parser;
11use dfir_rs::futures::{Sink, SinkExt, StreamExt};
12use dfir_rs::tokio_stream::wrappers::IntervalStream;
13use dfir_rs::util::{bind_udp_bytes, ipv4_resolve};
14use dfir_rs::{bincode, bytes, tokio};
15use gossip_kv::membership::{MemberDataBuilder, Protocol};
16use gossip_kv::server::{SeedNode, server};
17use prometheus::{Encoder, TextEncoder, gather};
18use serde::Serialize;
19use tracing::{error, info, trace};
20use warp::Filter;
21
22use crate::config::{SeedNodeSettings, setup_settings_watch};
23use crate::membership::member_name;
24
25mod config;
26
27mod membership;
28
29#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Parser)]
30struct Opts {
31    /// Port to listen for gossip messages.
32    #[clap(short, long, default_value = "3000")]
33    gossip_port: u16,
34
35    /// Port to listen for client requests.
36    #[clap(short, long, default_value = "3001")]
37    client_port: u16,
38
39    /// The duration (in seconds) between gossip rounds.
40    #[clap(short, long, default_value = "5", value_parser = clap_duration_from_secs)]
41    gossip_frequency: Duration,
42
43    #[clap(env = "GOSSIP_MEMBER_SUFFIX_LEN", default_value = "4")]
44    member_suffix_len: usize,
45}
46
/// Parse a duration given as a (possibly fractional) number of seconds, e.g. `"5"` or `"0.5"`.
///
/// Used as a clap `value_parser`; the returned `String` error is shown to the user by clap.
///
/// # Errors
///
/// Returns an error when `arg` is not a number, or is negative, NaN, infinite, too large for a
/// [`Duration`], or rounds to zero. The float-to-`Duration` constructors that don't return a
/// `Result` panic on the first group of inputs, and `tokio::time::interval` panics on a zero
/// period, so all of these must be rejected here rather than crash the server later.
fn clap_duration_from_secs(arg: &str) -> Result<Duration, String> {
    // f64 rather than f32: e.g. "0.1" is exactly 100ms instead of 100.000001ms.
    let secs: f64 = arg
        .parse()
        .map_err(|e: ParseFloatError| format!("invalid number of seconds {arg:?}: {e}"))?;
    let duration = Duration::try_from_secs_f64(secs)
        .map_err(|e| format!("invalid duration {arg:?}: {e}"))?;
    if duration.is_zero() {
        return Err(format!("duration must be greater than zero, got {arg:?}"));
    }
    Ok(duration)
}
51
52/// Create a SeedNode from a SeedNodeSettings.
53/// Performs a DNS lookup on the address.
54fn make_seed_node(settings: &SeedNodeSettings) -> SeedNode<SocketAddr> {
55    SeedNode {
56        id: settings.id.clone(),
57        address: ipv4_resolve(&settings.address).unwrap(),
58    }
59}
60
61/// Handler for the /metrics route. Used to expose prometheus metrics for the server.
62async fn metrics_handler() -> Result<impl warp::Reply, Infallible> {
63    let encoder = TextEncoder::new();
64    let metric_families = gather();
65    let mut buffer = Vec::new();
66    encoder.encode(&metric_families, &mut buffer).unwrap();
67
68    Ok(warp::reply::with_header(
69        buffer,
70        "Content-Type",
71        encoder.format_type(),
72    ))
73}
74
75/// Setup serialization for outbound networking messages.
76fn setup_outbound_serialization<Outbound, Message>(
77    outbound: Outbound,
78) -> impl Sink<(Message, SocketAddr), Error = Error>
79where
80    Outbound: Sink<(bytes::Bytes, SocketAddr), Error = Error>,
81    Message: Serialize + Debug + Send + 'static,
82{
83    outbound.with(|(msg, addr): (Message, SocketAddr)| {
84        ready(Ok::<(bytes::Bytes, SocketAddr), Error>((
85            dfir_rs::util::serialize_to_bytes(msg),
86            addr,
87        )))
88    })
89}
90
91/// Setup deserialization for inbound networking messages.
92fn setup_inbound_deserialization<Inbound, Message>(
93    inbound: Inbound,
94) -> impl dfir_rs::futures::Stream<Item = (Message, SocketAddr)>
95where
96    Inbound: dfir_rs::futures::Stream<Item = Result<(bytes::BytesMut, SocketAddr), Error>>,
97    Message: for<'de> serde::Deserialize<'de> + Debug + Send + 'static,
98{
99    inbound.filter_map(|input| {
100        let mapped = match input {
101            Ok((bytes, addr)) => {
102                let msg: bincode::Result<Message> = dfir_rs::util::deserialize_from_bytes(&bytes);
103                match msg {
104                    Ok(msg) => Some((msg, addr)),
105                    Err(e) => {
106                        error!("Error deserializing message: {:?}", e);
107                        None
108                    }
109                }
110            }
111            Err(e) => {
112                error!("Error receiving message: {:?}", e);
113                None
114            }
115        };
116        ready(mapped)
117    })
118}
119
120#[dfir_rs::main]
121async fn main() {
122    tracing_subscriber::fmt::init();
123
124    let opts: Opts = Opts::parse();
125
126    // Setup metrics server
127    let metrics_route = warp::path("metrics").and_then(metrics_handler);
128    tokio::spawn(async move {
129        info!("Starting metrics server on port 4003");
130        warp::serve(metrics_route).run(([0, 0, 0, 0], 4003)).await;
131    });
132
133    // Setup protocol information for this member
134    let client_protocol_address =
135        SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), opts.client_port);
136    let gossip_protocol_address =
137        SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), opts.gossip_port);
138
139    let member_data = MemberDataBuilder::new(member_name(opts.member_suffix_len).clone())
140        .add_protocol(Protocol::new("gossip".into(), gossip_protocol_address))
141        .add_protocol(Protocol::new("client".into(), client_protocol_address))
142        .build();
143
144    // Bind to the UDP ports
145    let (client_outbound, client_inbound, _) = bind_udp_bytes(client_protocol_address).await;
146    let (gossip_outbound, gossip_inbound, _) = bind_udp_bytes(gossip_protocol_address).await;
147
148    info!(
149        "Server {:?} listening for client requests on: {:?}",
150        member_data.id, client_protocol_address
151    );
152
153    // Setup serde for client requests
154    let client_ob = setup_outbound_serialization(client_outbound);
155    let client_ib = setup_inbound_deserialization(client_inbound);
156
157    // Setup serde for gossip messages
158    let gossip_ob = setup_outbound_serialization(gossip_outbound);
159    let gossip_ib = setup_inbound_deserialization(gossip_inbound);
160
161    // Setup regular gossip triggers
162    let gossip_rx = IntervalStream::new(tokio::time::interval(opts.gossip_frequency)).map(|_| ());
163
164    // Setup watcher for setting changes
165    let (_watcher, server_settings, settings_stream) = setup_settings_watch();
166
167    let seed_nodes = server_settings
168        .seed_nodes
169        .iter()
170        .map(make_seed_node)
171        .collect::<Vec<_>>();
172
173    let seed_node_stream = settings_stream.map(|settings| {
174        trace!("Settings updated. Reloading seed nodes");
175        settings.seed_nodes.iter().map(make_seed_node).collect()
176    });
177
178    // Create and run the server
179    let mut server = server(
180        client_ib,
181        client_ob,
182        gossip_ib,
183        gossip_ob,
184        gossip_rx,
185        member_data,
186        seed_nodes,
187        seed_node_stream,
188    );
189
190    server.run_async().await;
191}