1#[cfg(test)]
2mod tests;
3
4use std::cell::RefCell;
5use std::collections::HashMap;
6use std::fmt::{Debug, Display};
7use std::io;
8use std::rc::Rc;
9use std::time::Duration;
10
11use dfir_rs::bytes::{Bytes, BytesMut};
12use dfir_rs::dfir_syntax;
13use dfir_rs::scheduled::graph::Dfir;
14use dfir_rs::util::deploy::{
15 ConnectedDemux, ConnectedDirect, ConnectedSink, ConnectedSource, ConnectedTagged,
16};
17use futures::{SinkExt, Stream};
18
19mod protocol;
20use dfir_rs::scheduled::ticks::TickInstant;
21use dfir_rs::util::{deserialize_from_bytes, serialize_to_bytes};
22use protocol::*;
23use tokio::time::Instant;
24
/// Identifier of a Topolotree peer, wrapping the raw `u32` id that the
/// deploy ports use to tag inbound frames and address outbound ones.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
struct NodeId(pub u32);

impl Display for NodeId {
    /// Renders exactly like the wrapped `u32`, forwarding the caller's
    /// formatter so width/fill/alignment flags are honoured.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let NodeId(raw) = self;
        <u32 as Display>::fmt(raw, f)
    }
}
33
/// Element type flowing out of the `all_neighbor_data` cross join in
/// `run_topolotree`: a contribution `((key, source), (value, change_tick))`
/// paired with a candidate target neighbor. `source` is `None` for this node's
/// own local contribution and `Some(peer)` for a value received from `peer`.
type PostNeighborJoin = (((u64, Option<NodeId>), (i64, TickInstant)), NodeId);

/// Shared, interior-mutable contribution table:
/// `key -> source -> (latest timestamped value, tick at which it last changed)`.
/// As with `PostNeighborJoin`, source `None` is the local node and
/// `Some(peer)` is a neighbor.
type ContributionAgg =
    Rc<RefCell<HashMap<u64, HashMap<Option<NodeId>, (Timestamped<i64>, TickInstant)>>>>;
38
/// Builds the dataflow for one Topolotree node.
///
/// * `neighbors` – raw ids of this node's initial neighbors.
/// * `input_recv` – stream of `(sender id, serialized TopolotreeMessage)` frames.
/// * `increment_requests` – stream of serialized `OperationPayload`s.
/// * `output_send` – receives `(target id, serialized TopolotreeMessage)` pairs.
/// * `query_send` – receives serialized `QueryResponse`s.
///
/// For every key the node stores the highest-timestamped value reported by each
/// neighbor plus the running sum of its own local increments. Whenever any of a
/// key's contributions changes during a tick, it:
/// * publishes the total of all contributions on `query_send`, and
/// * sends each live neighbor the sum of every contribution that did *not* come
///   from that neighbor, stamped with this node's per-key counter.
///
/// It also pings live neighbors once per second and answers pings with pongs.
/// Neighbors it has not heard a ping or pong from for more than 5 whole seconds
/// are removed from the live set.
///
/// Panics on transport errors, undecodable messages, or if the receiver of
/// `output_send` / `query_send` has been dropped.
fn run_topolotree(
    neighbors: Vec<u32>,
    input_recv: impl Stream<Item = Result<(u32, BytesMut), io::Error>> + Unpin + 'static,
    increment_requests: impl Stream<Item = Result<BytesMut, io::Error>> + Unpin + 'static,
    output_send: tokio::sync::mpsc::UnboundedSender<(u32, Bytes)>,
    query_send: tokio::sync::mpsc::UnboundedSender<Bytes>,
) -> Dfir<'static> {
    // Combines two partial aggregates: the values are plain integer sums.
    fn merge(x: &mut i64, y: i64) {
        *x += y;
    }

    // Per-key counter used as the timestamp on outgoing payloads. It is bumped
    // once per local increment and once per accepted (strictly newer) neighbor
    // update, so it grows with every change to the key's aggregate.
    let self_timestamp = Rc::new(RefCell::new(HashMap::<u64, isize>::new()));

    // One handle for each closure below that touches the counter.
    let self_timestamp1 = Rc::clone(&self_timestamp);
    let self_timestamp2 = Rc::clone(&self_timestamp);
    let self_timestamp3 = Rc::clone(&self_timestamp);

    dfir_syntax! {
        // Decode each inbound `(sender, bytes)` frame and route it by message
        // kind. Panics on a transport error or an undecodable frame.
        parsed_input = source_stream(input_recv)
            -> map(Result::unwrap)
            -> map(|(src, x)| (NodeId(src), deserialize_from_bytes::<TopolotreeMessage>(&x).unwrap()))
            -> demux(|(src, msg), var_args!(payload, ping, pong)| {
                match msg {
                    TopolotreeMessage::Payload(p) => payload.give((src, p)),
                    TopolotreeMessage::Ping() => ping.give((src, ())),
                    TopolotreeMessage::Pong() => pong.give((src, ())),
                }
            });

        from_neighbors = parsed_input[payload] -> tee();
        pings = parsed_input[ping] -> tee();
        pongs = parsed_input[pong];

        // Answer every ping with a pong to the sender.
        pings -> map(|(src, _)| (src, TopolotreeMessage::Pong())) -> output;

        // On every 1-second interval tick, ping each neighbor currently in the
        // live set.
        // NOTE(review): neighbors already declared dead are never pinged again.
        neighbors -> [0]ping_generator;
        source_interval(Duration::from_secs(1)) -> [1]ping_generator;
        ping_generator = cross_join_multiset()
            -> map(|(src, _)| (src, TopolotreeMessage::Ping()))
            -> output;

        // Liveness tracking: remember, per node, the last time a ping or pong
        // arrived from it. The table is seeded with the initial neighbor list
        // and kept across ticks (`'static`). Nodes whose last message is more
        // than 5 whole seconds old (`as_secs() > 5`, i.e. >= 6 s) are emitted
        // as dead.
        // NOTE(review): this relies on `fold_keyed::<'static>` re-emitting its
        // whole table every tick so stale entries keep being reported — confirm
        // against the dfir operator docs.
        pongs -> dead_neighbors;
        pings -> dead_neighbors;
        new_neighbors -> map(|neighbor| (neighbor, ())) -> dead_neighbors;
        dead_neighbors = union() -> fold_keyed::<'static>(Instant::now, |acc: &mut Instant, _| {
                *acc = Instant::now();
            })
            -> filter_map(|(node_id, acc)| {
                if acc.elapsed().as_secs() > 5 {
                    Some(node_id)
                } else {
                    None
                }
            });

        // Keys mentioned by a neighbor payload or a local operation this tick,
        // deduplicated; only these keys are re-aggregated below.
        from_neighbors
            -> map(|(_, payload): (NodeId, Payload<i64>)| payload.key)
            -> touched_keys;

        operations
            -> map(|op| op.key)
            -> touched_keys;

        touched_keys = union() -> unique() -> [0]from_neighbors_unfiltered;

        // Neighbor contributions: per (key, Some(source)), keep only the value
        // with the highest timestamp seen so far. The -1 initial timestamp makes
        // the first update from a source (timestamp >= 0) always accepted. An
        // accepted update bumps this node's counter for the key and records the
        // current tick as the contribution's change tick.
        from_neighbors
            -> map(|(src, payload): (NodeId, Payload<i64>)| (src, (payload.key, payload.contents)))
            -> fold::<'static>(|| Rc::new(RefCell::new(HashMap::new())), |acc: &mut ContributionAgg, (source, (key, val)): (NodeId, (u64, Timestamped<i64>))| {
                let mut acc = acc.borrow_mut();
                let key_entry = acc.entry(key).or_default();
                let src_entry = key_entry.entry(Some(source)).or_insert((Timestamped { timestamp: -1, data: 0 }, TickInstant::default()));
                if val.timestamp > src_entry.0.timestamp {
                    src_entry.0 = val;
                    *self_timestamp1.borrow_mut().entry(key).or_insert(0) += 1;
                    src_entry.1 = context.current_tick();
                }
            })
            -> from_neighbors_to_filter;

        // Both contribution tables (neighbor and local) are joined against the
        // touched keys. For each touched key, every stored contribution is
        // emitted as `((key, source), (value, change_tick))`.
        from_neighbors_to_filter = union() -> [1]from_neighbors_unfiltered;
        from_neighbors_unfiltered =
            cross_join() ->
            flat_map(|(key, hashmap)| {
                let hashmap = hashmap.borrow();
                // Collect before returning: the iterator borrows `hashmap`,
                // the RefCell guard that is dropped at the end of this closure.
                hashmap.get(&key).iter().flat_map(|v| v.iter()).map(|t| ((key, *t.0), (t.1.0.data, t.1.1))).collect::<Vec<_>>().into_iter()
            }) ->
            from_neighbors_or_local;

        // Decode increment requests. Panics on a transport error or an
        // undecodable payload.
        operations = source_stream(increment_requests)
            -> map(|x| deserialize_from_bytes::<OperationPayload>(&x.unwrap()).unwrap())
            -> tee();

        // Local contribution: bump this node's counter for the key, then add
        // the change into the running local sum stored under source `None`,
        // stamped with the current tick.
        local_values = operations
            -> inspect(|change| {
                *self_timestamp2.borrow_mut().entry(change.key).or_insert(0) += 1;
            })
            -> map(|change_payload: OperationPayload| (change_payload.key, (change_payload.change, context.current_tick())))
            -> fold::<'static>(|| Rc::new(RefCell::new(HashMap::new())), |agg: &mut ContributionAgg, change: (u64, (i64, TickInstant))| {
                let mut agg = agg.borrow_mut();
                let agg_key = agg.entry(change.0).or_default();
                let agg_key = agg_key.entry(None).or_insert((Timestamped { timestamp: 0, data: 0 }, TickInstant::default()));

                agg_key.0.data += change.1.0;
                agg_key.1 = change.1.1;
            });

        local_values -> from_neighbors_to_filter;

        from_neighbors_or_local = tee();
        from_neighbors_or_local -> [0]all_neighbor_data;

        new_neighbors = source_iter(neighbors)
            -> map(NodeId)
            -> tee();

        // Live neighbor set: the initial neighbors (persisted across ticks)
        // minus the ones reported dead.
        // NOTE(review): contributions already received from a dead neighbor
        // stay in the aggregates.
        new_neighbors
            -> persist::<'static>()
            -> [pos]neighbors;
        dead_neighbors -> [neg]neighbors;
        neighbors = difference()
            -> tee();

        neighbors -> [1]all_neighbor_data;

        // Query output: sum every contribution per key and publish a
        // QueryResponse only for keys whose newest contribution changed during
        // this tick.
        query_result = from_neighbors_or_local
            -> map(|((key, _), payload): ((u64, _), (i64, TickInstant))| {
                (key, payload)
            })
            -> reduce_keyed(|acc: &mut (i64, TickInstant), (data, change_tick): (i64, TickInstant)| {
                merge(&mut acc.0, data);
                acc.1 = std::cmp::max(acc.1, change_tick);
            })
            -> filter(|(_, (_, change_tick))| *change_tick == context.current_tick())
            -> for_each(|(key, (data, _))| {
                let serialized = serialize_to_bytes(QueryResponse {
                    key,
                    value: data
                });
                query_send.send(serialized).unwrap();
            });

        // Per-neighbor output: pair every contribution with every live
        // neighbor, drop contributions that came from that neighbor, sum the
        // rest per (key, neighbor), and send a Payload only if something
        // changed this tick. The payload carries this node's counter for the
        // key as its timestamp.
        all_neighbor_data = cross_join_multiset()
            -> filter(|(((_, aggregate_from_this_guy), _), target_neighbor): &PostNeighborJoin| {
                aggregate_from_this_guy.iter().all(|source| source != target_neighbor)
            })
            -> map(|(((key, _), payload), target_neighbor)| {
                ((key, target_neighbor), payload)
            })
            -> reduce_keyed(|acc: &mut (i64, TickInstant), (data, change_tick): (i64, TickInstant)| {
                merge(&mut acc.0, data);
                acc.1 = std::cmp::max(acc.1, change_tick);
            })
            -> filter(|(_, (_, change_tick))| *change_tick == context.current_tick())
            -> map(|((key, target_neighbor), (data, _))| (target_neighbor, Payload {
                key,
                contents: Timestamped {
                    timestamp: self_timestamp3.borrow().get(&key).copied().unwrap_or(0),
                    data,
                }
            }))
            -> map(|(target_neighbor, payload)| (target_neighbor, TopolotreeMessage::Payload(payload)))
            -> output;

        // All outbound messages (pongs, pings, payloads) are serialized and
        // handed to the per-peer channel. Panics if its receiver was dropped.
        output = union() -> for_each(|(target_neighbor, output): (NodeId, TopolotreeMessage)| {
            let serialized = serialize_to_bytes(output);
            output_send.send((target_neighbor.0, serialized)).unwrap();
        });
    }
}
213
214#[dfir_rs::main]
215async fn main() {
216 let mut args = std::env::args().skip(1);
217 let _self_id: u32 = args.next().unwrap().parse().unwrap();
218 let neighbors: Vec<u32> = args.map(|x| x.parse().unwrap()).collect();
219
220 let ports = dfir_rs::util::deploy::init::<()>().await;
221
222 let input_recv = ports
223 .port("from_peer")
224 .connect::<ConnectedTagged<ConnectedDirect>>()
226 .await
227 .into_source();
228
229 let mut output_send = ports
230 .port("to_peer")
231 .connect::<ConnectedDemux<ConnectedDirect>>()
232 .await
233 .into_sink();
234
235 let operations_send = ports
236 .port("increment_requests")
237 .connect::<ConnectedDirect>()
239 .await
240 .into_source();
241
242 let mut query_responses = ports
243 .port("query_responses")
244 .connect::<ConnectedDirect>()
245 .await
246 .into_sink();
247
248 let (chan_tx, mut chan_rx) = tokio::sync::mpsc::unbounded_channel();
249
250 tokio::task::spawn_local(async move {
251 while let Some(msg) = chan_rx.recv().await {
252 output_send.feed(msg).await.unwrap();
253 while let Ok(msg) = chan_rx.try_recv() {
254 output_send.feed(msg).await.unwrap();
255 }
256 output_send.flush().await.unwrap();
257 }
258 });
259
260 let (query_tx, mut query_rx) = tokio::sync::mpsc::unbounded_channel();
261 tokio::task::spawn_local(async move {
262 while let Some(msg) = query_rx.recv().await {
263 query_responses.feed(msg).await.unwrap();
264 while let Ok(msg) = query_rx.try_recv() {
265 query_responses.feed(msg).await.unwrap();
266 }
267 query_responses.flush().await.unwrap();
268 }
269 });
270
271 let flow = run_topolotree(neighbors, input_recv, operations_send, chan_tx, query_tx);
272
273 let f1 = async move {
274 #[cfg(target_os = "linux")]
275 loop {
276 let x = procinfo::pid::stat_self().unwrap();
277 let bytes = x.rss * 1024 * 4;
278 println!("memory,{}", bytes);
279 tokio::time::sleep(Duration::from_secs(1)).await;
280 }
281 };
282
283 #[cfg(target_os = "linux")]
285 {
286 let x = procinfo::pid::stat_self().unwrap();
287 let bytes = x.rss * 1024 * 4;
288 println!("memory,{}", bytes);
289 }
290
291 let f1_handle = tokio::spawn(f1);
292 dfir_rs::util::deploy::launch_flow(flow).await;
293 f1_handle.abort();
294}