pn_delta/
pn_delta.rs

1use std::cell::RefCell;
2use std::collections::HashMap;
3use std::ops::Deref;
4use std::rc::Rc;
5
6use dfir_rs::dfir_syntax;
7use dfir_rs::scheduled::ticks::TickInstant;
8use dfir_rs::serde::{Deserialize, Serialize};
9use dfir_rs::util::deploy::{
10    ConnectedDemux, ConnectedDirect, ConnectedSink, ConnectedSource, ConnectedTagged,
11};
12use dfir_rs::util::{deserialize_from_bytes, serialize_to_bytes};
13
14mod protocol;
15use protocol::*;
16
17#[derive(Serialize, Deserialize, Clone, Debug)]
18enum GossipOrIncrement {
19    Gossip(Vec<(u64, (usize, u64, u64))>),
20    Increment(u64, i64),
21}
22
/// Per-replica positive and negative contribution vectors of one PN-counter,
/// indexed by replica id: `.0` holds increments, `.1` holds decrements.
type PnCounts = (Vec<u64>, Vec<u64>);

/// One counter touched during the current tick, as emitted by `next_state`:
/// `(counter_id, changed_locally, shared handle to its PN vectors)`.
type NextStateType = (u64, bool, Rc<RefCell<PnCounts>>);
24
25#[dfir_rs::main]
26async fn main() {
27    let ports = dfir_rs::util::deploy::init::<()>().await;
28
29    let my_id: Vec<usize> = serde_json::from_str(&std::env::args().nth(1).unwrap()).unwrap();
30    let my_id = my_id[0];
31    let num_replicas: Vec<usize> = serde_json::from_str(&std::env::args().nth(2).unwrap()).unwrap();
32    let num_replicas = num_replicas[0];
33
34    let increment_requests = ports
35        .port("increment_requests")
36        .connect::<ConnectedDirect>()
37        .await
38        .into_source();
39
40    let query_responses = ports
41        .port("query_responses")
42        .connect::<ConnectedDirect>()
43        .await
44        .into_sink();
45
46    let to_peer = ports
47        .port("to_peer")
48        .connect::<ConnectedDemux<ConnectedDirect>>()
49        .await
50        .into_sink();
51
52    let from_peer = ports
53        .port("from_peer")
54        .connect::<ConnectedTagged<ConnectedDirect>>()
55        .await
56        .into_source();
57
58    let f1 = async move {
59        #[cfg(target_os = "linux")]
60        loop {
61            let x = procinfo::pid::stat_self().unwrap();
62            let bytes = x.rss * 1024 * 4;
63            println!("memory,{}", bytes);
64            tokio::time::sleep(std::time::Duration::from_secs(1)).await;
65        }
66    };
67
68    let df = dfir_syntax! {
69        next_state = union()
70            -> fold::<'static>(|| (HashMap::<u64, Rc<RefCell<(Vec<u64>, Vec<u64>)>>>::new(), HashMap::new(), TickInstant::default()), |(cur_state, modified_tweets, last_tick): &mut (HashMap<_, _>, HashMap<_, _>, _), goi| {
71                if context.current_tick() != *last_tick {
72                    modified_tweets.clear();
73                }
74
75                match goi {
76                    GossipOrIncrement::Gossip(gossip) => {
77                        for (counter_id, (gossip_i, gossip_pos, gossip_neg)) in gossip.iter() {
78                            let gossip_i = *gossip_i;
79                            let cur_value = cur_state.entry(*counter_id).or_insert(Rc::new(RefCell::new((
80                                vec![0; num_replicas], vec![0; num_replicas]
81                            ))));
82                            let mut cur_value = cur_value.as_ref().borrow_mut();
83
84                            if *gossip_pos > cur_value.0[gossip_i] {
85                                cur_value.0[gossip_i] = *gossip_pos;
86                                modified_tweets.entry(*counter_id).or_insert(false);
87                            }
88
89                            if *gossip_neg > cur_value.1[gossip_i] {
90                                cur_value.1[gossip_i] = *gossip_neg;
91                                modified_tweets.entry(*counter_id).or_insert(false);
92                            }
93                        }
94                    }
95                    GossipOrIncrement::Increment(counter_id, delta) => {
96                        let cur_value = cur_state.entry(counter_id).or_insert(Rc::new(RefCell::new((
97                            vec![0; num_replicas], vec![0; num_replicas]
98                        ))));
99                        let mut cur_value = cur_value.as_ref().borrow_mut();
100
101                        if delta > 0 {
102                            cur_value.0[my_id] += delta as u64;
103                        } else {
104                            cur_value.1[my_id] += (-delta) as u64;
105                        }
106
107                        *modified_tweets.entry(counter_id).or_insert(false) |= true;
108                    }
109                }
110
111                *last_tick = context.current_tick();
112            })
113            -> filter(|(_, _, tick)| *tick == context.current_tick())
114            -> filter(|(_, modified_tweets, _)| !modified_tweets.is_empty())
115            -> map(|(state, modified_tweets, _)| modified_tweets.iter().map(|(t, is_local)| (*t, *is_local, state.get(t).unwrap().clone())).collect::<Vec<_>>())
116            -> tee();
117
118        source_stream(from_peer)
119            -> map(|x| deserialize_from_bytes::<GossipOrIncrement>(&x.unwrap().1).unwrap())
120            -> next_state;
121
122        source_stream(increment_requests)
123            -> map(|x| deserialize_from_bytes::<OperationPayload>(&x.unwrap()).unwrap())
124            -> map(|t| GossipOrIncrement::Increment(t.key, t.change))
125            -> next_state;
126
127        all_peers = source_iter(0..num_replicas)
128            -> filter(|x| *x != my_id);
129
130        all_peers -> [0] broadcaster;
131        next_state -> [1] broadcaster;
132        broadcaster = cross_join::<'static, 'tick>()
133            -> map(|(peer, state): (_, Vec<NextStateType>)| {
134                (peer as u32, state.iter().filter(|t| t.1).map(|(k, _, v)| (*k, (my_id, v.as_ref().borrow().0[my_id], v.as_ref().borrow().1[my_id]))).collect())
135            })
136            -> filter(|(_, gossip): &(_, Vec<_>)| !gossip.is_empty())
137            -> map(|(peer, gossip): (_, _)| {
138                (peer, serialize_to_bytes(GossipOrIncrement::Gossip(gossip)))
139            })
140            -> dest_sink(to_peer);
141
142        next_state
143            -> flat_map(|a: Vec<NextStateType>| {
144                a.into_iter().map(|(k, _, rc_array)| {
145                    let rc_borrowed = rc_array.as_ref().borrow();
146                    let (pos, neg) = rc_borrowed.deref();
147                    QueryResponse {
148                        key: k,
149                        value: pos.iter().sum::<u64>() as i64 - neg.iter().sum::<u64>() as i64
150                    }
151                }).collect::<Vec<_>>()
152            })
153            -> map(serialize_to_bytes::<QueryResponse>)
154            -> dest_sink(query_responses);
155    };
156
157    // initial memory
158    #[cfg(target_os = "linux")]
159    {
160        let x = procinfo::pid::stat_self().unwrap();
161        let bytes = x.rss * 1024 * 4;
162        println!("memory,{}", bytes);
163    }
164
165    let f1_handle = tokio::spawn(f1);
166    dfir_rs::util::deploy::launch_flow(df).await;
167    f1_handle.abort();
168}