1use std::cell::RefCell;
2use std::collections::{HashMap, HashSet};
3use std::ops::Deref;
4use std::rc::Rc;
5
6use dfir_rs::dfir_syntax;
7use dfir_rs::scheduled::ticks::TickInstant;
8use dfir_rs::serde::{Deserialize, Serialize};
9use dfir_rs::util::deploy::{
10 ConnectedDemux, ConnectedDirect, ConnectedSink, ConnectedSource, ConnectedTagged,
11};
12use dfir_rs::util::{deserialize_from_bytes, serialize_to_bytes};
13
14mod protocol;
15use protocol::*;
16
/// One changed counter as emitted by the state fold: the counter id paired with a
/// shared handle to its per-replica `(positive, negative)` tallies.
type NextStateType = (u64, std::rc::Rc<std::cell::RefCell<(Vec<u64>, Vec<u64>)>>);
18
/// Input to the replica's counter-state fold.
///
/// A value is either a batch of counter state exchanged between replicas or a
/// single local increment request.
///
/// NOTE(review): values of this enum go through `serialize_to_bytes` /
/// `deserialize_from_bytes` when sent between replicas. The variant order and
/// payload shapes therefore presumably form the wire format, so avoid reordering
/// or reshaping them without updating every replica.
#[derive(Serialize, Deserialize, Clone, Debug)]
enum GossipOrIncrement {
    /// `(counter_id, per-replica (positive, negative) tallies)` pairs broadcast by a
    /// replica. The receiver merges them into its own state by element-wise max.
    Gossip(Vec<NextStateType>),
    /// `(counter_id, delta)` for a local change. A positive delta is added to this
    /// replica's positive tally; any other delta is added (as a magnitude) to its
    /// negative tally.
    Increment(u64, i64),
}
24
25#[dfir_rs::main]
26async fn main() {
27 let ports = dfir_rs::util::deploy::init::<()>().await;
28
29 let my_id: Vec<usize> = serde_json::from_str(&std::env::args().nth(1).unwrap()).unwrap();
30 let my_id = my_id[0];
31 let num_replicas: Vec<usize> = serde_json::from_str(&std::env::args().nth(2).unwrap()).unwrap();
32 let num_replicas = num_replicas[0];
33
34 let increment_requests = ports
35 .port("increment_requests")
36 .connect::<ConnectedDirect>()
37 .await
38 .into_source();
39
40 let query_responses = ports
41 .port("query_responses")
42 .connect::<ConnectedDirect>()
43 .await
44 .into_sink();
45
46 let to_peer = ports
47 .port("to_peer")
48 .connect::<ConnectedDemux<ConnectedDirect>>()
49 .await
50 .into_sink();
51
52 let from_peer = ports
53 .port("from_peer")
54 .connect::<ConnectedTagged<ConnectedDirect>>()
55 .await
56 .into_source();
57
58 let f1 = async move {
59 #[cfg(target_os = "linux")]
60 loop {
61 let x = procinfo::pid::stat_self().unwrap();
62 let bytes = x.rss * 1024 * 4;
63 println!("memory,{}", bytes);
64 tokio::time::sleep(std::time::Duration::from_secs(1)).await;
65 }
66 };
67
68 let df = dfir_syntax! {
69 next_state = union()
70 -> fold::<'static>(|| (HashMap::<u64, Rc<RefCell<(Vec<u64>, Vec<u64>)>>>::new(), HashSet::new(), TickInstant::default()), |(cur_state, modified_tweets, last_tick): &mut (HashMap<_, _>, HashSet<_>, _), goi| {
71 if context.current_tick() != *last_tick {
72 modified_tweets.clear();
73 }
74
75 match goi {
76 GossipOrIncrement::Gossip(gossip) => {
77 for (counter_id, gossip_rc) in gossip.iter() {
78 let gossip_borrowed = gossip_rc.as_ref().borrow();
79 let (pos, neg) = gossip_borrowed.deref();
80 let cur_value = cur_state.entry(*counter_id).or_insert(Rc::new(RefCell::new((
81 vec![0; num_replicas], vec![0; num_replicas]
82 ))));
83 let mut cur_value = cur_value.as_ref().borrow_mut();
84
85 for i in 0..num_replicas {
86 if pos[i] > cur_value.0[i] {
87 cur_value.0[i] = pos[i];
88 modified_tweets.insert(*counter_id);
89 }
90
91 if neg[i] > cur_value.1[i] {
92 cur_value.1[i] = neg[i];
93 modified_tweets.insert(*counter_id);
94 }
95 }
96 }
97 }
98 GossipOrIncrement::Increment(counter_id, delta) => {
99 let cur_value = cur_state.entry(counter_id).or_insert(Rc::new(RefCell::new((
100 vec![0; num_replicas], vec![0; num_replicas]
101 ))));
102 let mut cur_value = cur_value.as_ref().borrow_mut();
103
104 if delta > 0 {
105 cur_value.0[my_id] += delta as u64;
106 } else {
107 cur_value.1[my_id] += (-delta) as u64;
108 }
109
110 modified_tweets.insert(counter_id);
111 }
112 }
113
114 *last_tick = context.current_tick();
115 })
116 -> filter(|(_, _, tick)| *tick == context.current_tick())
117 -> filter(|(_, modified_tweets, _)| !modified_tweets.is_empty())
118 -> map(|(state, modified_tweets, _)| modified_tweets.iter().map(|t| (*t, state.get(t).unwrap().clone())).collect::<Vec<_>>())
119 -> tee();
120
121 source_stream(from_peer)
122 -> map(|x| deserialize_from_bytes::<GossipOrIncrement>(&x.unwrap().1).unwrap())
123 -> next_state;
124
125 source_stream(increment_requests)
126 -> map(|x| deserialize_from_bytes::<OperationPayload>(&x.unwrap()).unwrap())
127 -> map(|t| GossipOrIncrement::Increment(t.key, t.change))
128 -> next_state;
129
130 all_peers = source_iter(0..num_replicas)
131 -> filter(|x| *x != my_id);
132
133 all_peers -> [0] broadcaster;
134 next_state -> [1] broadcaster;
135 broadcaster = cross_join::<'static, 'tick>()
136 -> map(|(peer, state)| {
137 (peer as u32, serialize_to_bytes(GossipOrIncrement::Gossip(state)))
138 })
139 -> dest_sink(to_peer);
140
141 next_state
142 -> flat_map(|a: Vec<NextStateType>| {
143 a.into_iter().map(|(k, rc_array)| {
144 let rc_borrowed = rc_array.as_ref().borrow();
145 let (pos, neg) = rc_borrowed.deref();
146 QueryResponse {
147 key: k,
148 value: pos.iter().sum::<u64>() as i64 - neg.iter().sum::<u64>() as i64
149 }
150 }).collect::<Vec<_>>()
151 })
152 -> map(serialize_to_bytes::<QueryResponse>)
153 -> dest_sink(query_responses);
154 };
155
156 #[cfg(target_os = "linux")]
158 {
159 let x = procinfo::pid::stat_self().unwrap();
160 let bytes = x.rss * 1024 * 4;
161 println!("memory,{}", bytes);
162 }
163
164 let f1_handle = tokio::spawn(f1);
165 dfir_rs::util::deploy::launch_flow(df).await;
166 f1_handle.abort();
167}