hydro_test/cluster/
kv_replica.rs

1use std::collections::HashMap;
2use std::fmt::Debug;
3use std::hash::Hash;
4
5use hydro_lang::*;
6use serde::de::DeserializeOwned;
7use serde::{Deserialize, Serialize};
8
/// Marker type naming the replica cluster.
///
/// It carries no data; it is only used as the type parameter of `Cluster<'a, Replica>`
/// so that replica-located streams are distinguished at the type level.
pub struct Replica {}
10
11pub trait KvKey: Serialize + DeserializeOwned + Hash + Eq + Clone + Debug {}
12impl<K: Serialize + DeserializeOwned + Hash + Eq + Clone + Debug> KvKey for K {}
13
14pub trait KvValue: Serialize + DeserializeOwned + Eq + Clone + Debug {}
15impl<V: Serialize + DeserializeOwned + Eq + Clone + Debug> KvValue for V {}
16
17#[derive(Serialize, Deserialize, PartialEq, Eq, Clone, Debug)]
18pub struct KvPayload<K, V> {
19    pub key: K,
20    pub value: V,
21}
22
23#[derive(Serialize, Deserialize, PartialEq, Eq, Clone, Debug)]
24pub struct SequencedKv<K, V> {
25    // Note: Important that seq is the first member of the struct for sorting
26    pub seq: usize,
27    pub kv: Option<KvPayload<K, V>>,
28}
29
30impl<K: KvKey, V: KvValue> Ord for SequencedKv<K, V> {
31    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
32        self.seq.cmp(&other.seq)
33    }
34}
35
36impl<K: KvKey, V: KvValue> PartialOrd for SequencedKv<K, V> {
37    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
38        Some(self.cmp(other))
39    }
40}
41
42// Replicas. All relations for replicas will be prefixed with r. Expects ReplicaPayload on p_to_replicas, outputs a stream of (client address, ReplicaPayload) after processing.
43#[expect(clippy::type_complexity, reason = "internal paxos code // TODO")]
44pub fn kv_replica<'a, K: KvKey, V: KvValue>(
45    replicas: &Cluster<'a, Replica>,
46    p_to_replicas: impl Into<
47        Stream<(usize, Option<KvPayload<K, V>>), Cluster<'a, Replica>, Unbounded, NoOrder>,
48    >,
49    checkpoint_frequency: usize,
50) -> (
51    Stream<usize, Cluster<'a, Replica>, Unbounded>,
52    Stream<KvPayload<K, V>, Cluster<'a, Replica>, Unbounded>,
53) {
54    let p_to_replicas: Stream<SequencedKv<K, V>, Cluster<'a, Replica>, Unbounded, NoOrder> =
55        p_to_replicas
56            .into()
57            .map(q!(|(slot, kv)| SequencedKv { seq: slot, kv }));
58
59    let replica_tick = replicas.tick();
60
61    let (r_buffered_payloads_complete_cycle, r_buffered_payloads) = replica_tick.cycle();
62    // p_to_replicas.inspect(q!(|payload: ReplicaPayload| println!("Replica received payload: {:?}", payload)));
63    let r_sorted_payloads = unsafe {
64        // SAFETY: because we fill slots one-by-one, we can safely batch
65        // because non-determinism is resolved when we sort by slots
66        p_to_replicas.tick_batch(&replica_tick)
67    }
68        .chain(r_buffered_payloads) // Combine with all payloads that we've received and not processed yet
69        .sort();
70    // Create a cycle since we'll use this seq before we define it
71    let (r_next_slot_complete_cycle, r_next_slot) =
72        replica_tick.cycle_with_initial(replica_tick.singleton(q!(0)));
73    // Find highest the sequence number of any payload that can be processed in this tick. This is the payload right before a hole.
74    let r_next_slot_after_processing_payloads = r_sorted_payloads
75        .clone()
76        .cross_singleton(r_next_slot.clone())
77        .fold(
78            q!(|| 0),
79            q!(|new_next_slot, (sorted_payload, next_slot)| {
80                if sorted_payload.seq == std::cmp::max(*new_next_slot, next_slot) {
81                    *new_next_slot = sorted_payload.seq + 1;
82                }
83            }),
84        );
85    // Find all payloads that can and cannot be processed in this tick.
86    let r_processable_payloads = r_sorted_payloads
87        .clone()
88        .cross_singleton(r_next_slot_after_processing_payloads.clone())
89        .filter(q!(
90            |(sorted_payload, highest_seq)| sorted_payload.seq < *highest_seq
91        ))
92        .map(q!(|(sorted_payload, _)| { sorted_payload }));
93    let r_new_non_processable_payloads = r_sorted_payloads
94        .clone()
95        .cross_singleton(r_next_slot_after_processing_payloads.clone())
96        .filter(q!(
97            |(sorted_payload, highest_seq)| sorted_payload.seq > *highest_seq
98        ))
99        .map(q!(|(sorted_payload, _)| { sorted_payload }));
100    // Save these, we can process them once the hole has been filled
101    r_buffered_payloads_complete_cycle.complete_next_tick(r_new_non_processable_payloads);
102
103    let r_kv_store = r_processable_payloads
104        .clone()
105        .persist() // Optimization: all_ticks() + fold() = fold<static>, where the state of the previous fold is saved and persisted values are deleted.
106        .fold(q!(|| (HashMap::new(), 0)), q!(|(kv_store, next_slot), payload| {
107            if let Some(kv) = payload.kv {
108                kv_store.insert(kv.key, kv.value);
109            }
110            *next_slot = payload.seq + 1;
111        }));
112    // Update the highest seq for the next tick
113    r_next_slot_complete_cycle
114        .complete_next_tick(r_kv_store.map(q!(|(_kv_store, next_slot)| next_slot)));
115
116    // Send checkpoints to the acceptors when we've processed enough payloads
117    let (r_checkpointed_seqs_complete_cycle, r_checkpointed_seqs) =
118        replica_tick.cycle::<Optional<usize, _, _>>();
119    let r_max_checkpointed_seq = r_checkpointed_seqs.persist().max().into_singleton();
120    let r_checkpoint_seq_new = r_max_checkpointed_seq
121        .zip(r_next_slot)
122        .filter_map(q!(
123            move |(max_checkpointed_seq, next_slot)| if max_checkpointed_seq
124                .map(|m| next_slot - m >= checkpoint_frequency)
125                .unwrap_or(true)
126            {
127                Some(next_slot)
128            } else {
129                None
130            }
131        ));
132    r_checkpointed_seqs_complete_cycle.complete_next_tick(r_checkpoint_seq_new.clone());
133
134    // Tell clients that the payload has been committed. All ReplicaPayloads contain the client's machine ID (to string) as value.
135    let r_to_clients = r_processable_payloads
136        .filter_map(q!(|payload| payload.kv))
137        .all_ticks();
138    (r_checkpoint_seq_new.all_ticks(), r_to_clients)
139}