1use std::fmt::{Debug, Formatter};
2use std::marker::PhantomData;
3
4use proc_macro2::Span;
5use quote::quote;
6use stageleft::runtime_support::{FreeVariableWithContextWithProps, QuoteTokens};
7use stageleft::{QuotedWithContextWithProps, quote_type};
8
9use super::dynamic::LocationId;
10use super::{Location, MemberId};
11use crate::compile::builder::FlowState;
12use crate::location::member_id::TaglessMemberId;
13use crate::staging_util::{Invariant, get_this_crate};
14
15pub struct Cluster<'a, ClusterTag> {
16 pub(crate) id: usize,
17 pub(crate) flow_state: FlowState,
18 pub(crate) _phantom: Invariant<'a, ClusterTag>,
19}
20
21impl<C> Debug for Cluster<'_, C> {
22 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
23 write!(f, "Cluster({})", self.id)
24 }
25}
26
27impl<C> Eq for Cluster<'_, C> {}
28impl<C> PartialEq for Cluster<'_, C> {
29 fn eq(&self, other: &Self) -> bool {
30 self.id == other.id && FlowState::ptr_eq(&self.flow_state, &other.flow_state)
31 }
32}
33
34impl<C> Clone for Cluster<'_, C> {
35 fn clone(&self) -> Self {
36 Cluster {
37 id: self.id,
38 flow_state: self.flow_state.clone(),
39 _phantom: PhantomData,
40 }
41 }
42}
43
44impl<'a, C> super::dynamic::DynLocation for Cluster<'a, C> {
45 fn id(&self) -> LocationId {
46 LocationId::Cluster(self.id)
47 }
48
49 fn flow_state(&self) -> &FlowState {
50 &self.flow_state
51 }
52
53 fn is_top_level() -> bool {
54 true
55 }
56}
57
58impl<'a, C> Location<'a> for Cluster<'a, C> {
59 type Root = Cluster<'a, C>;
60
61 fn root(&self) -> Self::Root {
62 self.clone()
63 }
64}
65
/// Free variable that, when spliced into staged code, expands to the identifier
/// `__hydro_lang_cluster_ids_{id}`, typed as `&'a [TaglessMemberId]`.
///
/// `Clone` is derived: the fields (`usize`, `PhantomData`) are both `Clone`, and there
/// are no type parameters, so the derive adds no extra bounds.
#[derive(Clone)]
pub struct ClusterIds<'a> {
    /// Numeric id used to form the generated identifier (presumably the cluster's
    /// `LocationId::Cluster` id — confirm at construction sites).
    pub id: usize,
    /// Carries the lifetime `'a` of the produced slice without storing any data.
    pub _phantom: PhantomData<&'a ()>,
}
79
80impl<'a, Ctx> FreeVariableWithContextWithProps<Ctx, ()> for ClusterIds<'a> {
81 type O = &'a [TaglessMemberId];
82
83 fn to_tokens(self, _ctx: &Ctx) -> (QuoteTokens, ())
84 where
85 Self: Sized,
86 {
87 let ident = syn::Ident::new(
88 &format!("__hydro_lang_cluster_ids_{}", self.id),
89 Span::call_site(),
90 );
91
92 (
93 QuoteTokens {
94 prelude: None,
95 expr: Some(quote! { #ident }),
96 },
97 (),
98 )
99 }
100}
101
102impl<'a, Ctx> QuotedWithContextWithProps<'a, &'a [TaglessMemberId], Ctx, ()> for ClusterIds<'a> {}
103
104pub trait IsCluster {
105 type Tag;
106}
107
108impl<C> IsCluster for Cluster<'_, C> {
109 type Tag = C;
110}
111
/// Handle type for [`CLUSTER_SELF_ID`]; see that static for usage.
#[derive(Clone, Copy)]
pub struct ClusterSelfId<'a> {
    /// Private field so the type cannot be constructed outside this module.
    _private: &'a (),
}

/// When quoted at a location whose root is a cluster, expands to the `MemberId` of the
/// cluster member evaluating the expression (one distinct id per member).
pub static CLUSTER_SELF_ID: ClusterSelfId<'static> = ClusterSelfId { _private: &() };
120
121impl<'a, L> FreeVariableWithContextWithProps<L, ()> for ClusterSelfId<'a>
122where
123 L: Location<'a>,
124 <L as Location<'a>>::Root: IsCluster,
125{
126 type O = MemberId<<<L as Location<'a>>::Root as IsCluster>::Tag>;
127
128 fn to_tokens(self, ctx: &L) -> (QuoteTokens, ())
129 where
130 Self: Sized,
131 {
132 let cluster_id = if let LocationId::Cluster(id) = ctx.root().id() {
133 id
134 } else {
135 unreachable!()
136 };
137
138 let ident = syn::Ident::new(
139 &format!("__hydro_lang_cluster_self_id_{}", cluster_id),
140 Span::call_site(),
141 );
142 let root = get_this_crate();
143 let c_type: syn::Type = quote_type::<<<L as Location<'a>>::Root as IsCluster>::Tag>();
144
145 (
146 QuoteTokens {
147 prelude: None,
148 expr: Some(
149 quote! { #root::__staged::location::MemberId::<#c_type>::from_tagless((#ident).clone()) },
150 ),
151 },
152 (),
153 )
154 }
155}
156
157impl<'a, L>
158 QuotedWithContextWithProps<'a, MemberId<<<L as Location<'a>>::Root as IsCluster>::Tag>, L, ()>
159 for ClusterSelfId<'a>
160where
161 L: Location<'a>,
162 <L as Location<'a>>::Root: IsCluster,
163{
164}
165
#[cfg(test)]
mod tests {
    #[cfg(feature = "sim")]
    use stageleft::q;

    #[cfg(feature = "sim")]
    use super::CLUSTER_SELF_ID;
    #[cfg(feature = "sim")]
    use crate::location::{Location, MemberId, MembershipEvent};
    #[cfg(feature = "sim")]
    use crate::networking::TCP;
    #[cfg(feature = "sim")]
    use crate::nondet::nondet;
    #[cfg(feature = "sim")]
    use crate::prelude::FlowBuilder;

    /// `CLUSTER_SELF_ID` resolves to each member's own id. A 3-member and a 4-member
    /// cluster together deliver ids 0, 1, 2 and 0, 1, 2, 3 to the receiving process.
    #[cfg(feature = "sim")]
    #[test]
    fn sim_cluster_self_id() {
        let flow = FlowBuilder::new();
        let cluster1 = flow.cluster::<()>();
        let cluster2 = flow.cluster::<()>();

        let node = flow.process::<()>();

        let out_recv = cluster1
            .source_iter(q!(vec![CLUSTER_SELF_ID]))
            .send(&node, TCP.bincode())
            .values()
            .interleave(
                cluster2
                    .source_iter(q!(vec![CLUSTER_SELF_ID]))
                    .send(&node, TCP.bincode())
                    .values(),
            )
            .sim_output();

        flow.sim()
            .with_cluster_size(&cluster1, 3)
            .with_cluster_size(&cluster2, 4)
            .exhaustive(async || {
                out_recv
                    .assert_yields_only_unordered([0, 1, 2, 0, 1, 2, 3].map(MemberId::from_raw_id))
                    .await
            });
    }

    /// Each of 2 members batches `[1, 2, 3]` into ticks and sends the per-tick count.
    /// Summed per member, the counts must total 3 in every explored execution.
    #[cfg(feature = "sim")]
    #[test]
    fn sim_cluster_with_tick() {
        use std::collections::HashMap;

        let flow = FlowBuilder::new();
        let cluster = flow.cluster::<()>();
        let node = flow.process::<()>();

        let out_recv = cluster
            .source_iter(q!(vec![1, 2, 3]))
            .batch(&cluster.tick(), nondet!())
            .count()
            .all_ticks()
            .send(&node, TCP.bincode())
            .entries()
            .map(q!(|(id, v)| (id, v)))
            .sim_output();

        let count = flow
            .sim()
            .with_cluster_size(&cluster, 2)
            .exhaustive(async || {
                let grouped = out_recv.collect_sorted::<Vec<_>>().await.into_iter().fold(
                    HashMap::new(),
                    |mut acc: HashMap<MemberId<()>, usize>, (id, v)| {
                        *acc.entry(id).or_default() += v;
                        acc
                    },
                );

                assert_eq!(grouped.len(), 2);
                for (_id, v) in grouped {
                    assert_eq!(v, 3);
                }
            });

        // Pinned value returned by `exhaustive` (presumably the number of explored executions).
        assert_eq!(count, 106);
    }

    /// A process reading `source_cluster_members` for a 2-member cluster sees exactly
    /// one `Joined` event per member.
    #[cfg(feature = "sim")]
    #[test]
    fn sim_cluster_membership() {
        let flow = FlowBuilder::new();
        let cluster = flow.cluster::<()>();
        let node = flow.process::<()>();

        let out_recv = node
            .source_cluster_members(&cluster)
            .entries()
            .map(q!(|(id, v)| (id, v)))
            .sim_output();

        flow.sim()
            .with_cluster_size(&cluster, 2)
            .exhaustive(async || {
                out_recv
                    .assert_yields_only_unordered(vec![
                        (MemberId::from_raw_id(0), MembershipEvent::Joined),
                        (MemberId::from_raw_id(1), MembershipEvent::Joined),
                    ])
                    .await;
            });
    }
}