hydro_lang/
singleton_ref.rs1use std::cell::RefCell;
4use std::marker::PhantomData;
5use std::rc::Rc;
6
7use proc_macro2::Span;
8use quote::quote;
9use stageleft::runtime_support::{FreeVariableWithContextWithProps, QuoteTokens};
10
11use crate::compile::ir::{HydroNode, SharedNode};
12use crate::location::Location;
13
/// A copyable handle to a singleton's IR node that can be captured inside `q!()`
/// closures, where the generated code sees it as `&'a T`.
///
/// The handle stores a raw pointer produced by [`Rc::into_raw`] in
/// [`SingletonRef::new`]; the strong reference leaked there is not reclaimed by
/// this module, which keeps the node alive for every (copied) handle. Because it
/// holds a raw pointer, the type is `!Send`/`!Sync`, so the non-atomic `Rc`
/// refcount updates performed in `to_tokens` stay on the creating thread.
pub struct SingletonRef<'a, T, L> {
    /// Pointer obtained from `Rc::<RefCell<HydroNode>>::into_raw`; see the
    /// type-level docs for its ownership.
    pub(crate) node: *const RefCell<HydroNode>,
    /// Ties the handle to the lifetime `'a`, value type `T`, and location `L`
    /// without storing any of them.
    _phantom: PhantomData<(&'a (), T, L)>,
}
25impl<T, L> SingletonRef<'_, T, L> {
26 pub(crate) fn new(rc_ptr: Rc<RefCell<HydroNode>>) -> Self {
31 let node = Rc::into_raw(rc_ptr);
33 Self {
34 node,
35 _phantom: PhantomData,
36 }
37 }
38}
39
// Implemented by hand: `#[derive(Copy)]` would add `T: Copy` and `L: Copy`
// bounds, but the handle only holds a raw pointer and `PhantomData`.
impl<T, L> Copy for SingletonRef<'_, T, L> {}
// Implemented by hand for the same reason as `Copy`: a derive would add
// `T: Clone` and `L: Clone` bounds that the handle does not need.
impl<T, L> Clone for SingletonRef<'_, T, L> {
    // Canonical clone for a `Copy` type (see clippy `non_canonical_clone_impl`).
    fn clone(&self) -> Self {
        *self
    }
}
46
thread_local! {
    /// Per-thread slot used by [`with_singleton_capture`].
    ///
    /// `None` means no capture scope is active. `Some(refs)` accumulates one
    /// `(ident, HydroNode::Singleton)` pair for each `SingletonRef` turned into
    /// tokens while the scope is active.
    static SINGLETON_REFS: RefCell<Option<Vec<(syn::Ident, HydroNode)>>> = const { RefCell::new(None) };
}
52
53pub fn with_singleton_capture(
57 f: impl FnOnce() -> crate::compile::ir::DebugExpr,
58) -> crate::compile::ir::ClosureExpr {
59 SINGLETON_REFS.with(|cell| {
60 let prev = cell.borrow_mut().replace(Vec::new());
61 assert!(
62 prev.is_none(),
63 "nested singleton capture scopes are not supported"
64 );
65 });
66 let expr = f();
67 let singleton_refs = SINGLETON_REFS.with(|cell| cell.borrow_mut().take().unwrap());
68 crate::compile::ir::ClosureExpr::new(expr, singleton_refs)
69}
70
/// Source of unique suffixes for the `__hydro_singleton_ref_N` identifiers
/// generated in `SingletonRef::to_tokens`.
///
/// `Relaxed` ordering is sufficient here: `fetch_add` is an atomic
/// read-modify-write, so every caller gets a distinct value. The counter is only
/// used for naming and is not used to synchronize other memory.
static SINGLETON_REF_COUNTER: std::sync::atomic::AtomicUsize =
    std::sync::atomic::AtomicUsize::new(0);
73
74impl<'a, T: 'a, L> FreeVariableWithContextWithProps<L, ()> for SingletonRef<'a, T, L>
75where
76 L: Location<'a>,
77{
78 type O = &'a T;
79
80 fn to_tokens(self, _ctx: &L) -> (QuoteTokens, ()) {
81 let id = SINGLETON_REF_COUNTER.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
82 let ident = syn::Ident::new(&format!("__hydro_singleton_ref_{}", id), Span::call_site());
83
84 SINGLETON_REFS.with(|cell| {
85 let mut guard = cell.borrow_mut();
86 let refs = guard.as_mut().expect(
87 "SingletonRef used inside q!() but no singleton capture scope is active. \
88 This is a bug — singleton capture should be set up by the operator that uses q!().",
89 );
90 let rc = unsafe { Rc::from_raw(self.node) };
94 let cloned = rc.clone();
95 std::mem::forget(rc); let metadata = cloned.borrow().metadata().clone(); refs.push((
99 ident.clone(),
100 HydroNode::Singleton {
101 inner: SharedNode(cloned),
102 metadata,
103 },
104 ));
105 });
106
107 (
108 QuoteTokens {
109 prelude: None,
110 expr: Some(quote!(#ident)),
111 },
112 (),
113 )
114 }
115}
116
// Build-time smoke tests. Each test builds a small single-process flow in which
// a `q!()` closure dereferences a `by_ref()` handle to a `fold` singleton, keeps
// consuming the singleton normally, and then calls `flow.finalize()`. A test
// passes as long as expansion, type-checking and finalization do not fail.
#[cfg(test)]
#[cfg(feature = "build")]
mod tests {
    use stageleft::q;

    use crate::compile::builder::FlowBuilder;
    use crate::location::Location;

    /// Marker type naming the single process location used by every test.
    struct P1 {}

    #[test]
    /// A `Copy` singleton (an `i32` sum) dereferenced inside a `map` closure.
    fn singleton_by_ref_compiles() {
        let mut flow = FlowBuilder::new();
        let node = flow.process::<P1>();

        let my_count = node
            .source_iter(q!(0..5i32))
            .fold(q!(|| 0i32), q!(|acc: &mut i32, x| *acc += x));
        let count_ref = my_count.by_ref();

        node.source_iter(q!(1..=3i32))
            .map(q!(|x| x + *count_ref))
            .for_each(q!(|_| {}));

        // The singleton is still consumed directly after `by_ref()` was taken.
        my_count.into_stream().for_each(q!(|_| {}));

        let _built = flow.finalize();
    }

    #[test]
    /// A non-`Copy` singleton (`Vec<i32>`) used through the reference via a
    /// method call (`len()`) inside a `map` closure.
    fn singleton_by_ref_non_copy() {
        let mut flow = FlowBuilder::new();
        let node = flow.process::<P1>();

        let my_vec = node.source_iter(q!(0..5i32)).fold(
            q!(|| Vec::<i32>::new()),
            q!(|acc: &mut Vec<i32>, x| acc.push(x)),
        );
        let vec_ref = my_vec.by_ref();

        node.source_iter(q!(1..=3i32))
            .map(q!(|x| x + vec_ref.len() as i32))
            .for_each(q!(|_| {}));

        my_vec.into_stream().for_each(q!(|_| {}));

        let _built = flow.finalize();
    }

    #[test]
    /// A singleton reference used in a `filter` predicate.
    fn singleton_by_ref_filter() {
        let mut flow = FlowBuilder::new();
        let node = flow.process::<P1>();

        let threshold = node
            .source_iter(q!(0..5i32))
            .fold(q!(|| 0i32), q!(|acc: &mut i32, x| *acc += x));
        let threshold_ref = threshold.by_ref();

        node.source_iter(q!(1..=10i32))
            .filter(q!(|x| *x > *threshold_ref))
            .for_each(q!(|_| {}));

        threshold.into_stream().for_each(q!(|_| {}));
        let _built = flow.finalize();
    }

    #[test]
    /// A singleton reference used inside a `flat_map_ordered` closure, including
    /// within a nested `move` closure.
    fn singleton_by_ref_flat_map() {
        let mut flow = FlowBuilder::new();
        let node = flow.process::<P1>();

        let count = node
            .source_iter(q!(0..3i32))
            .fold(q!(|| 0i32), q!(|acc: &mut i32, _| *acc += 1));
        let count_ref = count.by_ref();

        node.source_iter(q!(1..=2i32))
            .flat_map_ordered(q!(|x| (0..*count_ref).map(move |i| x + i)))
            .for_each(q!(|_| {}));

        count.into_stream().for_each(q!(|_| {}));
        let _built = flow.finalize();
    }

    #[test]
    /// A singleton reference read inside an `inspect` closure.
    fn singleton_by_ref_inspect() {
        let mut flow = FlowBuilder::new();
        let node = flow.process::<P1>();

        let count = node
            .source_iter(q!(0..5i32))
            .fold(q!(|| 0i32), q!(|acc: &mut i32, _| *acc += 1));
        let count_ref = count.by_ref();

        node.source_iter(q!(1..=3i32))
            .inspect(q!(|x| println!("count={}, x={}", *count_ref, x)))
            .for_each(q!(|_| {}));

        count.into_stream().for_each(q!(|_| {}));
        let _built = flow.finalize();
    }

    #[test]
    /// A singleton reference used in a `partition` predicate, with both output
    /// streams consumed directly.
    fn singleton_by_ref_partition() {
        let mut flow = FlowBuilder::new();
        let node = flow.process::<P1>();

        let threshold = node
            .source_iter(q!(0..5i32))
            .fold(q!(|| 0i32), q!(|acc: &mut i32, x| *acc += x));
        let threshold_ref = threshold.by_ref();

        let (above, below) = node
            .source_iter(q!(1..=10i32))
            .partition(q!(|x| *x > *threshold_ref));

        above.for_each(q!(|_| {}));
        below.for_each(q!(|_| {}));
        threshold.into_stream().for_each(q!(|_| {}));
        let _built = flow.finalize();
    }

    #[test]
    /// Like `singleton_by_ref_partition`, but each partition output flows
    /// through a further `map` before its sink. This exercises the captured
    /// singleton when the partition is not directly followed by sinks.
    fn singleton_by_ref_partition_with_downstream_ops() {
        let mut flow = FlowBuilder::new();
        let node = flow.process::<P1>();

        let threshold = node
            .source_iter(q!(0..5i32))
            .fold(q!(|| 0i32), q!(|acc: &mut i32, x| *acc += x));
        let threshold_ref = threshold.by_ref();

        let (above, below) = node
            .source_iter(q!(1..=10i32))
            .partition(q!(|x| *x > *threshold_ref));

        above.map(q!(|x| x * 2)).for_each(q!(|_| {}));
        // The second branch gets its own, different downstream `map`.
        below.map(q!(|x| x + 100)).for_each(q!(|_| {}));
        threshold.into_stream().for_each(q!(|_| {}));
        let _built = flow.finalize();
    }
}