topolotree/
main.rs

1#[cfg(test)]
2mod tests;
3
4use std::cell::RefCell;
5use std::collections::HashMap;
6use std::fmt::{Debug, Display};
7use std::io;
8use std::rc::Rc;
9use std::time::Duration;
10
11use dfir_rs::bytes::{Bytes, BytesMut};
12use dfir_rs::dfir_syntax;
13use dfir_rs::scheduled::graph::Dfir;
14use dfir_rs::util::deploy::{
15    ConnectedDemux, ConnectedDirect, ConnectedSink, ConnectedSource, ConnectedTagged,
16};
17use futures::{SinkExt, Stream};
18
19mod protocol;
20use dfir_rs::scheduled::ticks::TickInstant;
21use dfir_rs::util::{deserialize_from_bytes, serialize_to_bytes};
22use protocol::*;
23use tokio::time::Instant;
24
/// Identifier of a node in the topolotree, wrapping the raw `u32` id.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
struct NodeId(pub u32);

impl Display for NodeId {
    /// Renders the wrapped id exactly as the bare `u32` would be rendered,
    /// forwarding the caller's formatter (and thus its width/fill/flag options).
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let NodeId(raw) = self;
        <u32 as Display>::fmt(raw, formatter)
    }
}
33
// Item type flowing out of the `all_neighbor_data` join in `run_topolotree`:
// `(((key, contribution source), (data, tick of last change)), target neighbor)`.
// The source is `None` for the local contribution and `Some(neighbor)` otherwise.
type PostNeighborJoin = (((u64, Option<NodeId>), (i64, TickInstant)), NodeId);
35
// Shared, interior-mutable table of contributions:
// key -> (source -> (latest timestamped value, tick at which it last changed)).
// In `run_topolotree` the source is `Some(neighbor)` for values received from a
// neighbor and `None` for the locally accumulated value.
type ContributionAgg =
    Rc<RefCell<HashMap<u64, HashMap<Option<NodeId>, (Timestamped<i64>, TickInstant)>>>>;
38
// Builds the DFIR dataflow for a single topolotree node.
//
// Parameters:
// - `neighbors`: raw ids of the initial neighbors (wrapped into `NodeId` inside the flow).
// - `input_recv`: stream of `(sender id, serialized TopolotreeMessage)` pairs from peers.
// - `increment_requests`: stream of serialized `OperationPayload`s (local updates).
// - `output_send`: receives `(target id, serialized TopolotreeMessage)` for peers.
// - `query_send`: receives a serialized `QueryResponse` for each key that changed.
//
// Returns the assembled `Dfir` graph; the caller is responsible for running it.
//
// Panics (via `unwrap`) if an input stream yields an error, if a message fails to
// deserialize, or if sending on `output_send` / `query_send` fails.
fn run_topolotree(
    neighbors: Vec<u32>,
    input_recv: impl Stream<Item = Result<(u32, BytesMut), io::Error>> + Unpin + 'static,
    increment_requests: impl Stream<Item = Result<BytesMut, io::Error>> + Unpin + 'static,
    output_send: tokio::sync::mpsc::UnboundedSender<(u32, Bytes)>,
    query_send: tokio::sync::mpsc::UnboundedSender<Bytes>,
) -> Dfir<'static> {
    // Combines two contributions for the same key by summing them.
    fn merge(x: &mut i64, y: i64) {
        *x += y;
    }

    // Timestamp handling is a bit complicated. There is a proper dataflow way to do it,
    // but it would require at least one more join and one more cross join just for the
    // local timestamps. Until we need it to be proper, we take a shortcut and use an
    // `Rc<RefCell<..>>` holding a per-key local timestamp counter.
    let self_timestamp = Rc::new(RefCell::new(HashMap::<u64, isize>::new()));

    // One handle per closure that touches the counter:
    // 1 = neighbor updates, 2 = local operations, 3 = outgoing payloads.
    let self_timestamp1 = Rc::clone(&self_timestamp);
    let self_timestamp2 = Rc::clone(&self_timestamp);
    let self_timestamp3 = Rc::clone(&self_timestamp);

    // we use current tick to keep track of which *keys* have been modified

    dfir_syntax! {
        // Decode incoming peer messages and split them by message kind.
        parsed_input = source_stream(input_recv)
            -> map(Result::unwrap)
            -> map(|(src, x)| (NodeId(src), deserialize_from_bytes::<TopolotreeMessage>(&x).unwrap()))
            -> demux(|(src, msg), var_args!(payload, ping, pong)| {
                match msg {
                    TopolotreeMessage::Payload(p) => payload.give((src, p)),
                    TopolotreeMessage::Ping() => ping.give((src, ())),
                    TopolotreeMessage::Pong() => pong.give((src, ())),
                }
            });

        from_neighbors = parsed_input[payload] -> tee();
        pings = parsed_input[ping] -> tee();
        pongs = parsed_input[pong];

        // Answer every ping with a pong addressed to its sender.
        pings -> map(|(src, _)| (src, TopolotreeMessage::Pong())) -> output;

        // generate a ping every second
        // (the first tuple element, named `src` below, is the neighbor being pinged)
        neighbors -> [0]ping_generator;
        source_interval(Duration::from_secs(1)) -> [1]ping_generator;
        ping_generator = cross_join_multiset()
            -> map(|(src, _)| (src, TopolotreeMessage::Ping()))
            -> output;

        // Liveness tracking: every pong, ping, or newly added neighbor resets that
        // node's `Instant`; a node is emitted as dead once its last reset is more
        // than 5 whole seconds old (`as_secs` truncates).
        pongs -> dead_neighbors;
        pings -> dead_neighbors;
        new_neighbors -> map(|neighbor| (neighbor, ())) -> dead_neighbors; // fake pong
        dead_neighbors = union() -> fold_keyed::<'static>(Instant::now, |acc: &mut Instant, _| {
                *acc = Instant::now();
            })
            -> filter_map(|(node_id, acc)| {
                if acc.elapsed().as_secs() > 5 {
                    Some(node_id)
                } else {
                    None
                }
            });

        // Keys touched by a neighbor payload...
        from_neighbors
            -> map(|(_, payload): (NodeId, Payload<i64>)| payload.key)
            -> touched_keys;

        // ...or by a local operation.
        operations
            -> map(|op| op.key)
            -> touched_keys;

        // De-duplicated touched keys drive the lookup into the contribution tables.
        touched_keys = union() -> unique() -> [0]from_neighbors_unfiltered;

        // Fold neighbor payloads into a key -> source -> value table, keeping for each
        // (key, source) only the value with the largest timestamp seen so far.
        from_neighbors
            -> map(|(src, payload): (NodeId, Payload<i64>)| (src, (payload.key, payload.contents)))
            -> fold::<'static>(|| Rc::new(RefCell::new(HashMap::new())), |acc: &mut ContributionAgg, (source, (key, val)): (NodeId, (u64, Timestamped<i64>))| {
                let mut acc = acc.borrow_mut();
                let key_entry = acc.entry(key).or_default();
                // Timestamp -1 marks "nothing received yet", so any received value
                // with a timestamp >= 0 replaces it.
                let src_entry = key_entry.entry(Some(source)).or_insert((Timestamped { timestamp: -1, data: 0 }, TickInstant::default()));
                if val.timestamp > src_entry.0.timestamp {
                    src_entry.0 = val;
                    // Accepting a newer value bumps our own timestamp for the key and
                    // records the tick so downstream filters see the key as changed.
                    *self_timestamp1.borrow_mut().entry(key).or_insert(0) += 1;
                    src_entry.1 = context.current_tick();
                }
            })
            -> from_neighbors_to_filter;

        // Pair every touched key with each contribution table (neighbor and local) and
        // emit `((key, source), (data, change tick))` for every source under that key.
        from_neighbors_to_filter = union() -> [1]from_neighbors_unfiltered;
        from_neighbors_unfiltered =
            cross_join() ->
            flat_map(|(key, hashmap)| {
                let hashmap = hashmap.borrow();
                // Collected eagerly so the returned iterator does not borrow from the
                // `RefCell` guard, which is dropped at the end of this closure.
                hashmap.get(&key).iter().flat_map(|v| v.iter()).map(|t| ((key, *t.0), (t.1.0.data, t.1.1))).collect::<Vec<_>>().into_iter()
            }) ->
            from_neighbors_or_local;

        // Local increment requests.
        operations = source_stream(increment_requests)
            -> map(|x| deserialize_from_bytes::<OperationPayload>(&x.unwrap()).unwrap())
            -> tee();
        // Accumulate local changes per key under the `None` source, bumping the local
        // timestamp for the key and recording the tick of the change.
        local_values = operations
            -> inspect(|change| {
                *self_timestamp2.borrow_mut().entry(change.key).or_insert(0) += 1;
            })
            -> map(|change_payload: OperationPayload| (change_payload.key, (change_payload.change, context.current_tick())))
            -> fold::<'static>(|| Rc::new(RefCell::new(HashMap::new())), |agg: &mut ContributionAgg, change: (u64, (i64, TickInstant))| {
                let mut agg = agg.borrow_mut();
                let agg_key = agg.entry(change.0).or_default();
                let agg_key = agg_key.entry(None).or_insert((Timestamped { timestamp: 0, data: 0 }, TickInstant::default()));

                agg_key.0.data += change.1.0;
                agg_key.1 = change.1.1;
            });

        local_values -> from_neighbors_to_filter;

        from_neighbors_or_local = tee();
        from_neighbors_or_local -> [0]all_neighbor_data;

        // Initial neighbor set, from the function argument.
        new_neighbors = source_iter(neighbors)
            -> map(NodeId)
            -> tee();

        // Current neighbors = persisted initial neighbors minus those reported dead.
        new_neighbors
            -> persist::<'static>()
            -> [pos]neighbors;
        dead_neighbors -> [neg]neighbors;
        neighbors = difference()
            -> tee();

        neighbors -> [1]all_neighbor_data;

        // Query results: sum the contributions of all sources for each key and report
        // only keys with at least one contribution changed in the current tick.
        query_result = from_neighbors_or_local
            -> map(|((key, _), payload): ((u64, _), (i64, TickInstant))| {
                (key, payload)
            })
            -> reduce_keyed(|acc: &mut (i64, TickInstant), (data, change_tick): (i64, TickInstant)| {
                merge(&mut acc.0, data);
                acc.1 = std::cmp::max(acc.1, change_tick);
            })
            -> filter(|(_, (_, change_tick))| *change_tick == context.current_tick())
            -> for_each(|(key, (data, _))| {
                let serialized = serialize_to_bytes(QueryResponse {
                    key,
                    value: data
                });
                query_send.send(serialized).unwrap();
            });

        // Per-neighbor aggregates: for each target neighbor, sum every contribution
        // for a key except the one that came from that neighbor itself, and send it
        // only if something changed in the current tick.
        all_neighbor_data = cross_join_multiset()
            -> filter(|(((_, aggregate_from_this_guy), _), target_neighbor): &PostNeighborJoin| {
                // Keeps local (`None`) contributions and those from other neighbors.
                aggregate_from_this_guy.iter().all(|source| source != target_neighbor)
            })
            -> map(|(((key, _), payload), target_neighbor)| {
                ((key, target_neighbor), payload)
            })
            -> reduce_keyed(|acc: &mut (i64, TickInstant), (data, change_tick): (i64, TickInstant)| {
                merge(&mut acc.0, data);
                acc.1 = std::cmp::max(acc.1, change_tick);
            })
            -> filter(|(_, (_, change_tick))| *change_tick == context.current_tick())
            -> map(|((key, target_neighbor), (data, _))| (target_neighbor, Payload {
                key,
                contents: Timestamped {
                    // Stamp with our current local timestamp for the key (0 if never bumped).
                    timestamp: self_timestamp3.borrow().get(&key).copied().unwrap_or(0),
                    data,
                }
            }))
            -> map(|(target_neighbor, payload)| (target_neighbor, TopolotreeMessage::Payload(payload)))
            -> output;

        // Single exit point for all peer-bound messages (pongs, pings, payloads).
        output = union() -> for_each(|(target_neighbor, output): (NodeId, TopolotreeMessage)| {
            let serialized = serialize_to_bytes(output);
            output_send.send((target_neighbor.0, serialized)).unwrap();
        });
    }
}
213
214#[dfir_rs::main]
215async fn main() {
216    let mut args = std::env::args().skip(1);
217    let _self_id: u32 = args.next().unwrap().parse().unwrap();
218    let neighbors: Vec<u32> = args.map(|x| x.parse().unwrap()).collect();
219
220    let ports = dfir_rs::util::deploy::init::<()>().await;
221
222    let input_recv = ports
223        .port("from_peer")
224        // connect to the port with a single recipient
225        .connect::<ConnectedTagged<ConnectedDirect>>()
226        .await
227        .into_source();
228
229    let mut output_send = ports
230        .port("to_peer")
231        .connect::<ConnectedDemux<ConnectedDirect>>()
232        .await
233        .into_sink();
234
235    let operations_send = ports
236        .port("increment_requests")
237        // connect to the port with a single recipient
238        .connect::<ConnectedDirect>()
239        .await
240        .into_source();
241
242    let mut query_responses = ports
243        .port("query_responses")
244        .connect::<ConnectedDirect>()
245        .await
246        .into_sink();
247
248    let (chan_tx, mut chan_rx) = tokio::sync::mpsc::unbounded_channel();
249
250    tokio::task::spawn_local(async move {
251        while let Some(msg) = chan_rx.recv().await {
252            output_send.feed(msg).await.unwrap();
253            while let Ok(msg) = chan_rx.try_recv() {
254                output_send.feed(msg).await.unwrap();
255            }
256            output_send.flush().await.unwrap();
257        }
258    });
259
260    let (query_tx, mut query_rx) = tokio::sync::mpsc::unbounded_channel();
261    tokio::task::spawn_local(async move {
262        while let Some(msg) = query_rx.recv().await {
263            query_responses.feed(msg).await.unwrap();
264            while let Ok(msg) = query_rx.try_recv() {
265                query_responses.feed(msg).await.unwrap();
266            }
267            query_responses.flush().await.unwrap();
268        }
269    });
270
271    let flow = run_topolotree(neighbors, input_recv, operations_send, chan_tx, query_tx);
272
273    let f1 = async move {
274        #[cfg(target_os = "linux")]
275        loop {
276            let x = procinfo::pid::stat_self().unwrap();
277            let bytes = x.rss * 1024 * 4;
278            println!("memory,{}", bytes);
279            tokio::time::sleep(Duration::from_secs(1)).await;
280        }
281    };
282
283    // initial memory
284    #[cfg(target_os = "linux")]
285    {
286        let x = procinfo::pid::stat_self().unwrap();
287        let bytes = x.rss * 1024 * 4;
288        println!("memory,{}", bytes);
289    }
290
291    let f1_handle = tokio::spawn(f1);
292    dfir_rs::util::deploy::launch_flow(flow).await;
293    f1_handle.abort();
294}