hydro_lang/location/cluster.rs

1use std::fmt::{Debug, Formatter};
2use std::marker::PhantomData;
3
4use proc_macro2::Span;
5use quote::quote;
6use stageleft::runtime_support::{FreeVariableWithContextWithProps, QuoteTokens};
7use stageleft::{QuotedWithContextWithProps, quote_type};
8
9use super::dynamic::LocationId;
10use super::{Location, MemberId};
11use crate::compile::builder::FlowState;
12use crate::location::member_id::TaglessMemberId;
13use crate::staging_util::{Invariant, get_this_crate};
14
15pub struct Cluster<'a, ClusterTag> {
16    pub(crate) id: usize,
17    pub(crate) flow_state: FlowState,
18    pub(crate) _phantom: Invariant<'a, ClusterTag>,
19}
20
21impl<C> Debug for Cluster<'_, C> {
22    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
23        write!(f, "Cluster({})", self.id)
24    }
25}
26
27impl<C> Eq for Cluster<'_, C> {}
28impl<C> PartialEq for Cluster<'_, C> {
29    fn eq(&self, other: &Self) -> bool {
30        self.id == other.id && FlowState::ptr_eq(&self.flow_state, &other.flow_state)
31    }
32}
33
34impl<C> Clone for Cluster<'_, C> {
35    fn clone(&self) -> Self {
36        Cluster {
37            id: self.id,
38            flow_state: self.flow_state.clone(),
39            _phantom: PhantomData,
40        }
41    }
42}
43
44impl<'a, C> super::dynamic::DynLocation for Cluster<'a, C> {
45    fn id(&self) -> LocationId {
46        LocationId::Cluster(self.id)
47    }
48
49    fn flow_state(&self) -> &FlowState {
50        &self.flow_state
51    }
52
53    fn is_top_level() -> bool {
54        true
55    }
56}
57
58impl<'a, C> Location<'a> for Cluster<'a, C> {
59    type Root = Cluster<'a, C>;
60
61    fn root(&self) -> Self::Root {
62        self.clone()
63    }
64}
65
/// A free variable that, when spliced into quoted code, refers to the slice of
/// member IDs (`&'a [TaglessMemberId]`) of the cluster with the given `id`.
///
/// `Clone` is derived: it only copies the `id` and the zero-sized marker, which
/// is exactly what the previous hand-written implementation did.
#[derive(Clone)]
pub struct ClusterIds<'a> {
    /// Numeric ID of the cluster whose member list is referenced; it is used to
    /// build the name of the identifier emitted into generated code.
    pub id: usize,
    /// Zero-sized marker tying the handle to the `'a` lifetime of the slice.
    pub _phantom: PhantomData<&'a ()>,
}
79
80impl<'a, Ctx> FreeVariableWithContextWithProps<Ctx, ()> for ClusterIds<'a> {
81    type O = &'a [TaglessMemberId];
82
83    fn to_tokens(self, _ctx: &Ctx) -> (QuoteTokens, ())
84    where
85        Self: Sized,
86    {
87        let ident = syn::Ident::new(
88            &format!("__hydro_lang_cluster_ids_{}", self.id),
89            Span::call_site(),
90        );
91
92        (
93            QuoteTokens {
94                prelude: None,
95                expr: Some(quote! { #ident }),
96            },
97            (),
98        )
99    }
100}
101
102impl<'a, Ctx> QuotedWithContextWithProps<'a, &'a [TaglessMemberId], Ctx, ()> for ClusterIds<'a> {}
103
/// Implemented by location types that are clusters; exposes the cluster's
/// marker (tag) type so it can be named generically.
pub trait IsCluster {
    /// The marker type that identifies the cluster at the type level.
    type Tag;
}
107
108impl<C> IsCluster for Cluster<'_, C> {
109    type Tag = C;
110}
111
/// A free variable standing for the ID of the cluster member executing the
/// code. Splicing it into a quoted snippet that runs on a cluster produces a
/// [`MemberId`] tagged with that cluster's tag type.
pub static CLUSTER_SELF_ID: ClusterSelfId = ClusterSelfId { _private: &() };

/// The type of [`CLUSTER_SELF_ID`].
///
/// Its only field is private, so values cannot be constructed outside this
/// module; callers obtain one through the static.
#[derive(Copy, Clone)]
pub struct ClusterSelfId<'a> {
    _private: &'a (),
}
120
121impl<'a, L> FreeVariableWithContextWithProps<L, ()> for ClusterSelfId<'a>
122where
123    L: Location<'a>,
124    <L as Location<'a>>::Root: IsCluster,
125{
126    type O = MemberId<<<L as Location<'a>>::Root as IsCluster>::Tag>;
127
128    fn to_tokens(self, ctx: &L) -> (QuoteTokens, ())
129    where
130        Self: Sized,
131    {
132        let cluster_id = if let LocationId::Cluster(id) = ctx.root().id() {
133            id
134        } else {
135            unreachable!()
136        };
137
138        let ident = syn::Ident::new(
139            &format!("__hydro_lang_cluster_self_id_{}", cluster_id),
140            Span::call_site(),
141        );
142        let root = get_this_crate();
143        let c_type: syn::Type = quote_type::<<<L as Location<'a>>::Root as IsCluster>::Tag>();
144
145        (
146            QuoteTokens {
147                prelude: None,
148                expr: Some(
149                    quote! { #root::__staged::location::MemberId::<#c_type>::from_tagless((#ident).clone()) },
150                ),
151            },
152            (),
153        )
154    }
155}
156
157impl<'a, L>
158    QuotedWithContextWithProps<'a, MemberId<<<L as Location<'a>>::Root as IsCluster>::Tag>, L, ()>
159    for ClusterSelfId<'a>
160where
161    L: Location<'a>,
162    <L as Location<'a>>::Root: IsCluster,
163{
164}
165
// Every test here needs the simulator, so the whole module is gated on the
// `sim` feature instead of repeating the attribute on each item.
#[cfg(all(test, feature = "sim"))]
mod tests {
    use std::collections::HashMap;

    use stageleft::q;

    use super::CLUSTER_SELF_ID;
    use crate::location::{Location, MemberId, MembershipEvent};
    use crate::networking::TCP;
    use crate::nondet::nondet;
    use crate::prelude::FlowBuilder;

    /// Members of two clusters (sizes 3 and 4) each send their own
    /// `CLUSTER_SELF_ID` to one process, which must receive exactly the IDs
    /// 0..3 and 0..4, in any order.
    #[test]
    fn sim_cluster_self_id() {
        let flow = FlowBuilder::new();
        let cluster1 = flow.cluster::<()>();
        let cluster2 = flow.cluster::<()>();

        let node = flow.process::<()>();

        let out_recv = cluster1
            .source_iter(q!(vec![CLUSTER_SELF_ID]))
            .send(&node, TCP.bincode())
            .values()
            .interleave(
                cluster2
                    .source_iter(q!(vec![CLUSTER_SELF_ID]))
                    .send(&node, TCP.bincode())
                    .values(),
            )
            .sim_output();

        flow.sim()
            .with_cluster_size(&cluster1, 3)
            .with_cluster_size(&cluster2, 4)
            .exhaustive(async || {
                out_recv
                    .assert_yields_only_unordered([0, 1, 2, 0, 1, 2, 3].map(MemberId::from_raw_id))
                    .await
            });
    }

    /// Each of the 2 members batches `[1, 2, 3]` into ticks and reports per-tick
    /// counts. Across every explored schedule, each member's counts must sum to 3.
    #[test]
    fn sim_cluster_with_tick() {
        let flow = FlowBuilder::new();
        let cluster = flow.cluster::<()>();
        let node = flow.process::<()>();

        let out_recv = cluster
            .source_iter(q!(vec![1, 2, 3]))
            .batch(&cluster.tick(), nondet!(/** test */))
            .count()
            .all_ticks()
            .send(&node, TCP.bincode())
            .entries()
            .map(q!(|(id, v)| (id, v)))
            .sim_output();

        let count = flow
            .sim()
            .with_cluster_size(&cluster, 2)
            .exhaustive(async || {
                let grouped = out_recv.collect_sorted::<Vec<_>>().await.into_iter().fold(
                    HashMap::new(),
                    |mut acc: HashMap<MemberId<()>, usize>, (id, v)| {
                        *acc.entry(id).or_default() += v;
                        acc
                    },
                );

                // `assert_eq!` reports both values on failure, unlike `assert!(a == b)`.
                assert_eq!(grouped.len(), 2, "expected one entry per cluster member");
                for total in grouped.values() {
                    assert_eq!(*total, 3, "each member should report 3 elements in total");
                }
            });

        // Not a perfect square, because we simulate all interleavings of ticks
        // across the 2 cluster members. Eventually we should be able to tell that
        // the members are independent (there are no dataflow cycles) and avoid
        // simulating the redundant interleavings.
        assert_eq!(count, 106);
    }

    /// A process subscribed to a 2-member cluster's membership must see a
    /// `Joined` event for each member ID.
    #[test]
    fn sim_cluster_membership() {
        let flow = FlowBuilder::new();
        let cluster = flow.cluster::<()>();
        let node = flow.process::<()>();

        let out_recv = node
            .source_cluster_members(&cluster)
            .entries()
            .map(q!(|(id, v)| (id, v)))
            .sim_output();

        flow.sim()
            .with_cluster_size(&cluster, 2)
            .exhaustive(async || {
                out_recv
                    .assert_yields_only_unordered(vec![
                        (MemberId::from_raw_id(0), MembershipEvent::Joined),
                        (MemberId::from_raw_id(1), MembershipEvent::Joined),
                    ])
                    .await;
            });
    }
}