1use std::convert::Infallible;
2use std::num::{NonZeroU32, ParseFloatError};
3use std::thread::sleep;
4use std::time::Duration;
5
6use clap::Parser;
7use dfir_rs::util::{unbounded_channel, unsync_channel};
8use gossip_kv::membership::{MemberDataBuilder, Protocol};
9use gossip_kv::{ClientRequest, GossipMessage};
10use governor::{Quota, RateLimiter};
11use prometheus::{Encoder, TextEncoder, gather};
12use tokio::sync::mpsc::UnboundedSender;
13use tokio::task;
14use tracing::{error, info, trace};
15use warp::Filter;
16
/// Address of a simulated server in this load test.
///
/// Addresses are handed out by `Switchboard::new_outbox` as the index of the server's
/// inbox sender, so an address doubles as an index into `Switchboard::gossip_outboxes`.
type LoadTestAddress = u64;
18
19use dfir_rs::futures::sink::drain;
20use dfir_rs::futures::stream;
21use dfir_rs::tokio_stream::StreamExt;
22use dfir_rs::tokio_stream::wrappers::UnboundedReceiverStream;
23use gossip_kv::server::{SeedNode, server};
24use lattices::cc_traits::Iter;
25
26const UNKNOWN_ADDRESS: LoadTestAddress = 9999999999;
27
28#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Parser)]
29struct Opts {
30 #[clap(short, long, default_value = "5")]
32 thread_count: usize,
33
34 #[clap(short, long, default_value = "10", value_parser = clap_duration_from_secs)]
36 gossip_frequency: Duration,
37
38 #[clap(short, long, default_value = "1")]
40 max_set_throughput: u32,
41}
42
/// Parses a CLI argument given in (possibly fractional) seconds into a [`Duration`].
///
/// # Errors
/// Returns a human-readable message if `arg` is not a number, or if the number is
/// negative, NaN, or too large to be represented as a `Duration`.
fn clap_duration_from_secs(arg: &str) -> Result<Duration, String> {
    let secs: f32 = arg
        .parse()
        .map_err(|err: ParseFloatError| format!("`{arg}` is not a number of seconds: {err}"))?;
    // `try_from_secs_f32` instead of `from_secs_f32`: the latter panics on negative,
    // NaN, or overflowing input, which would crash argument parsing.
    Duration::try_from_secs_f32(secs)
        .map_err(|err| format!("`{arg}` is not a valid duration in seconds: {err}"))
}
47
48fn run_server(
49 server_name: String,
50 gossip_address: LoadTestAddress,
51 gossip_input_rx: UnboundedReceiverStream<(GossipMessage, LoadTestAddress)>,
52 switchboard: Switchboard,
53 seed_nodes: Vec<SeedNode<LoadTestAddress>>,
54 opts: Opts,
55) {
56 std::thread::spawn(move || {
57 let rt = tokio::runtime::Builder::new_current_thread()
58 .enable_all()
59 .build()
60 .unwrap();
61
62 let (gossip_output_tx, mut gossip_output_rx) = unsync_channel(None);
63
64 let (gossip_trigger_tx, gossip_trigger_rx) = unbounded_channel();
65
66 let member_data = MemberDataBuilder::new(server_name.clone())
67 .add_protocol(Protocol::new("gossip".into(), gossip_address))
68 .build();
69
70 rt.block_on(async {
71 let local = task::LocalSet::new();
72
73 let (client_input_tx, client_input_rx) = unbounded_channel();
74
75 let put_throughput = opts.max_set_throughput;
76 local.spawn_local(async move {
77 let rate_limiter = RateLimiter::direct(Quota::per_second(
78 NonZeroU32::new(put_throughput).unwrap(),
79 ));
80 loop {
81 rate_limiter.until_ready().await;
82 let key = "/usr/table/key".parse().unwrap();
83 let request = ClientRequest::Set {
84 key,
85 value: "FOOBAR".to_string(),
86 };
87 client_input_tx.send((request, UNKNOWN_ADDRESS)).unwrap();
88 }
89 });
90
91 let gossip_frequency = opts.gossip_frequency;
92 local.spawn_local(async move {
93 loop {
94 tokio::time::sleep(gossip_frequency).await;
95 gossip_trigger_tx.send(()).unwrap();
96 }
97 });
98
99 local.spawn_local(async move {
101 while let Some((msg, addr)) = gossip_output_rx.next().await {
102 trace!("Sending gossip message: {:?} to {}", msg, addr);
103 let outbox = switchboard.gossip_outboxes.get(addr as usize).unwrap();
104 if let Err(e) = outbox.send((msg, gossip_address)) {
105 error!("Failed to send gossip message: {:?}", e);
106 }
107 }
108 });
109
110 local.spawn_local(async {
111 let mut server = server(
112 client_input_rx,
113 drain(), gossip_input_rx,
115 gossip_output_tx,
116 gossip_trigger_rx,
117 member_data,
118 seed_nodes,
119 stream::empty(),
120 );
121
122 server.run_async().await
123 });
124
125 local.await
126 });
127 });
128}
129
130struct Switchboard {
131 gossip_outboxes: Vec<UnboundedSender<(GossipMessage, LoadTestAddress)>>,
132}
133
134impl Clone for Switchboard {
135 fn clone(&self) -> Self {
136 Self {
137 gossip_outboxes: self.gossip_outboxes.clone(),
138 }
139 }
140}
141
142impl Switchboard {
143 fn new() -> Self {
144 Self {
145 gossip_outboxes: Vec::new(),
146 }
147 }
148 fn new_outbox(
149 &mut self,
150 ) -> (
151 LoadTestAddress,
152 UnboundedReceiverStream<(GossipMessage, LoadTestAddress)>,
153 ) {
154 let addr: LoadTestAddress = self.gossip_outboxes.len() as LoadTestAddress;
155 let (tx, rx) = unbounded_channel();
156 self.gossip_outboxes.push(tx);
157 (addr, rx)
158 }
159}
160
161async fn metrics_handler() -> Result<impl warp::Reply, Infallible> {
162 let encoder = TextEncoder::new();
163 let metric_families = gather();
164 let mut buffer = Vec::new();
165 encoder.encode(&metric_families, &mut buffer).unwrap();
166
167 Ok(warp::reply::with_header(
168 buffer,
169 "Content-Type",
170 encoder.format_type(),
171 ))
172}
173
174fn main() {
175 tracing_subscriber::fmt::init();
176
177 let opts: Opts = Opts::parse();
178
179 std::thread::spawn(move || {
180 let metrics_route = warp::path("metrics").and_then(metrics_handler);
181
182 let rt = tokio::runtime::Builder::new_current_thread()
183 .enable_all()
184 .build()
185 .unwrap();
186
187 rt.block_on(async move {
188 info!("Starting metrics server on port 4003");
189 warp::serve(metrics_route).run(([0, 0, 0, 0], 4003)).await;
190 });
191 });
192
193 info!("Starting load test with with {} threads", opts.thread_count);
194
195 let mut switchboard = Switchboard::new();
196
197 let outboxes: Vec<_> = (0..opts.thread_count)
198 .map(|_| {
199 let (addr, rx) = switchboard.new_outbox();
200 (format!("SERVER-{}", addr), addr, rx)
201 })
202 .collect();
203
204 let seed_nodes: Vec<_> = outboxes
205 .iter()
206 .map(|(name, addr, _)| SeedNode {
207 id: name.clone(),
208 address: *addr,
209 })
210 .collect();
211
212 outboxes.into_iter().for_each(|(name, addr, outbox)| {
213 run_server(
214 name,
215 addr,
216 outbox,
217 switchboard.clone(),
218 seed_nodes.clone(),
219 opts,
220 );
221 });
222
223 loop {
224 sleep(Duration::from_secs(1));
225 }
226}