1use std::cell::RefCell;
9use std::marker::PhantomData;
10use std::rc::Rc;
11
12use proc_macro2::Span;
13use quote::quote;
14use stageleft::runtime_support::{FreeVariableWithContextWithProps, QuoteTokens};
15
16use crate::compile::ir::{HydroNode, SharedNode};
17use crate::location::Location;
18
19#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
21pub enum HandoffRefKind {
22 Singleton,
24 Optional,
26 Vec,
28}
29
30thread_local! {
34 static CAPTURED_REFS: RefCell<Option<Vec<(HydroNode, bool)>>> = const { RefCell::new(None) };
35}
36
37pub(crate) fn handoff_ref_ident(index: usize) -> syn::Ident {
39 syn::Ident::new(
40 &format!("__hydro_singleton_ref_{}", index),
41 Span::call_site(),
42 )
43}
44
45pub fn with_ref_capture(
49 f: impl FnOnce() -> crate::compile::ir::DebugExpr,
50) -> crate::compile::ir::ClosureExpr {
51 CAPTURED_REFS.with(|cell| {
52 let prev = cell.borrow_mut().replace(Vec::new());
53 assert!(
54 prev.is_none(),
55 "nested handoff reference capture scopes are not supported"
56 );
57 });
58 let expr = (f)();
59 let captured_refs = CAPTURED_REFS.with(|cell| cell.borrow_mut().take().unwrap());
60 crate::compile::ir::ClosureExpr::new(expr, captured_refs)
61}
62
63fn register_handoff_ref(
66 ir_node: &RefCell<HydroNode>,
67 is_mut: bool,
68 kind: HandoffRefKind,
69) -> syn::Ident {
70 CAPTURED_REFS.with(|cell| {
71 let mut guard = cell.borrow_mut();
72 let refs = guard.as_mut().expect(
73 "HandoffRef used inside q!() but no reference capture scope is active. \
74 This is a bug — reference capture should be set up by the operator that uses q!().",
75 );
76
77 let index = refs.len();
78 let ident = handoff_ref_ident(index);
79
80 let metadata = ir_node.borrow().metadata().clone();
81
82 if !matches!(&*ir_node.borrow(), HydroNode::Reference { .. }) {
85 let orig = ir_node.replace(HydroNode::Placeholder);
86 *ir_node.borrow_mut() = HydroNode::Reference {
87 inner: SharedNode(Rc::new(RefCell::new(orig))),
88 kind,
89 metadata: metadata.clone(),
90 };
91 }
92
93 let borrow: std::cell::Ref<'_, HydroNode> = ir_node.borrow();
94 let HydroNode::Reference { inner, .. } = &*borrow else {
95 unreachable!()
96 };
97
98 refs.push((
99 HydroNode::Reference {
100 inner: SharedNode(Rc::clone(&inner.0)),
101 kind,
102 metadata,
103 },
104 is_mut,
105 ));
106
107 ident
108 })
109}
110
/// Declares a `Copy` handle type that lets quoted (`q!()`) code borrow the value
/// behind a `HydroNode`.
///
/// Inputs, in order: any attributes (e.g. doc comments) for the generated
/// struct, the struct name, whether the borrow is mutable, the
/// `HandoffRefKind` to record, and the type that quoted code sees for the free
/// variable.
macro_rules! define_handoff_ref {
    (
        $(#[$attr:meta])*
        $handle:ident, $mutable:expr, $ref_kind:expr, $quoted_ty:ty
    ) => {
        $(#[$attr])*
        pub struct $handle<'a, 'slf, T, L> {
            pub(crate) ir_node: &'slf RefCell<HydroNode>,
            _phantom: PhantomData<(&'a T, L)>,
        }

        impl<'slf, T, L> $handle<'_, 'slf, T, L> {
            /// Wraps a borrowed IR node. Nothing is registered until the handle
            /// is used inside `q!()`.
            pub(crate) fn new(ir_node: &'slf RefCell<HydroNode>) -> Self {
                Self { ir_node, _phantom: PhantomData }
            }
        }

        // Written by hand instead of derived so that `T` and `L` are not
        // required to be `Clone`/`Copy` themselves.
        impl<T, L> Clone for $handle<'_, '_, T, L> {
            fn clone(&self) -> Self {
                *self
            }
        }
        impl<T, L> Copy for $handle<'_, '_, T, L> {}

        impl<'a, 'slf, T: 'a, L> FreeVariableWithContextWithProps<L, ()>
            for $handle<'a, 'slf, T, L>
        where
            L: Location<'a>,
        {
            type O = $quoted_ty;

            fn to_tokens(self, _ctx: &L) -> (QuoteTokens, ()) {
                // Registering in the active capture scope yields the identifier
                // that the generated code binds this reference to.
                let ident = register_handoff_ref(self.ir_node, $mutable, $ref_kind);
                let tokens = QuoteTokens {
                    prelude: None,
                    expr: Some(quote!(#ident)),
                };
                (tokens, ())
            }
        }
    };
}
159
160define_handoff_ref!(
161 SingletonRef, false, HandoffRefKind::Singleton, &'a T
165);
166
167define_handoff_ref!(
168 SingletonMut, true, HandoffRefKind::Singleton, &'a mut T
172);
173
174define_handoff_ref!(
175 OptionalRef, false, HandoffRefKind::Optional, &'a Option<T>
179);
180
181define_handoff_ref!(
182 OptionalMut, true, HandoffRefKind::Optional, &'a mut Option<T>
186);
187
188define_handoff_ref!(
189 StreamRef, false, HandoffRefKind::Vec, &'a Vec<T>
193);
194
195define_handoff_ref!(
196 StreamMut, true, HandoffRefKind::Vec, &'a mut Vec<T>
200);
201
// Build-level tests for handoff references. Each test captures a `by_ref()` /
// `by_mut()` handle inside a `q!()` closure of some operator, still consumes
// the original collection afterwards, and checks only that `flow.finalize()`
// completes. None of the tests make runtime assertions about dataflow output.
#[cfg(test)]
#[cfg(feature = "build")]
mod tests {
    use stageleft::q;

    use crate::compile::builder::FlowBuilder;
    use crate::location::Location;

    // Marker type naming the single process location used by every test.
    struct P1 {}

    #[test]
    /// Immutable singleton reference dereferenced inside `map`.
    fn singleton_by_ref_compiles() {
        let mut flow = FlowBuilder::new();
        let node = flow.process::<P1>();

        let my_count = node
            .source_iter(q!(0..5i32))
            .fold(q!(|| 0i32), q!(|acc: &mut i32, x| *acc += x));
        let count_ref = my_count.by_ref();

        node.source_iter(q!(1..=3i32))
            .map(q!(|x| x + *count_ref))
            .for_each(q!(|_| {}));

        // The referenced singleton is still consumed after being borrowed.
        my_count.into_stream().for_each(q!(|_| {}));
        let _built = flow.finalize();
    }

    #[test]
    /// Immutable reference to a non-`Copy` singleton (`Vec<i32>`), used through a
    /// method call instead of a dereferencing copy.
    fn singleton_by_ref_non_copy() {
        let mut flow = FlowBuilder::new();
        let node = flow.process::<P1>();

        let my_vec = node.source_iter(q!(0..5i32)).fold(
            q!(|| Vec::<i32>::new()),
            q!(|acc: &mut Vec<i32>, x| acc.push(x)),
        );
        let vec_ref = my_vec.by_ref();

        node.source_iter(q!(1..=3i32))
            .map(q!(|x| x + vec_ref.len() as i32))
            .for_each(q!(|_| {}));

        my_vec.into_stream().for_each(q!(|_| {}));
        let _built = flow.finalize();
    }

    #[test]
    /// Immutable singleton reference read inside a `filter` predicate.
    fn singleton_by_ref_filter() {
        let mut flow = FlowBuilder::new();
        let node = flow.process::<P1>();

        let threshold = node
            .source_iter(q!(0..5i32))
            .fold(q!(|| 0i32), q!(|acc: &mut i32, x| *acc += x));
        let threshold_ref = threshold.by_ref();

        node.source_iter(q!(1..=10i32))
            .filter(q!(|x| *x > *threshold_ref))
            .for_each(q!(|_| {}));

        threshold.into_stream().for_each(q!(|_| {}));
        let _built = flow.finalize();
    }

    #[test]
    /// Immutable singleton reference read inside `flat_map_ordered`, alongside a
    /// nested `move` closure.
    fn singleton_by_ref_flat_map() {
        let mut flow = FlowBuilder::new();
        let node = flow.process::<P1>();

        let count = node
            .source_iter(q!(0..3i32))
            .fold(q!(|| 0i32), q!(|acc: &mut i32, _| *acc += 1));
        let count_ref = count.by_ref();

        node.source_iter(q!(1..=2i32))
            .flat_map_ordered(q!(|x| (0..*count_ref).map(move |i| x + i)))
            .for_each(q!(|_| {}));

        count.into_stream().for_each(q!(|_| {}));
        let _built = flow.finalize();
    }

    #[test]
    /// Immutable singleton reference read inside `inspect` (formatted by `println!`).
    fn singleton_by_ref_inspect() {
        let mut flow = FlowBuilder::new();
        let node = flow.process::<P1>();

        let count = node
            .source_iter(q!(0..5i32))
            .fold(q!(|| 0i32), q!(|acc: &mut i32, _| *acc += 1));
        let count_ref = count.by_ref();

        node.source_iter(q!(1..=3i32))
            .inspect(q!(|x| println!("count={}, x={}", *count_ref, x)))
            .for_each(q!(|_| {}));

        count.into_stream().for_each(q!(|_| {}));
        let _built = flow.finalize();
    }

    #[test]
    /// Immutable singleton reference used by a `partition` predicate; both
    /// partition outputs are consumed directly.
    fn singleton_by_ref_partition() {
        let mut flow = FlowBuilder::new();
        let node = flow.process::<P1>();

        let threshold = node
            .source_iter(q!(0..5i32))
            .fold(q!(|| 0i32), q!(|acc: &mut i32, x| *acc += x));
        let threshold_ref = threshold.by_ref();

        let (above, below) = node
            .source_iter(q!(1..=10i32))
            .partition(q!(|x| *x > *threshold_ref));

        above.for_each(q!(|_| {}));
        below.for_each(q!(|_| {}));
        threshold.into_stream().for_each(q!(|_| {}));
        let _built = flow.finalize();
    }

    #[test]
    /// Same as `singleton_by_ref_partition`, but each partition output feeds a
    /// further `map` before its sink.
    fn singleton_by_ref_partition_with_downstream_ops() {
        let mut flow = FlowBuilder::new();
        let node = flow.process::<P1>();

        let threshold = node
            .source_iter(q!(0..5i32))
            .fold(q!(|| 0i32), q!(|acc: &mut i32, x| *acc += x));
        let threshold_ref = threshold.by_ref();

        let (above, below) = node
            .source_iter(q!(1..=10i32))
            .partition(q!(|x| *x > *threshold_ref));

        above.map(q!(|x| x * 2)).for_each(q!(|_| {}));
        below.map(q!(|x| x + 100)).for_each(q!(|_| {}));
        threshold.into_stream().for_each(q!(|_| {}));
        let _built = flow.finalize();
    }

    #[test]
    /// Mutable singleton reference written inside `map`.
    fn singleton_by_mut_compiles() {
        let mut flow = FlowBuilder::new();
        let node = flow.process::<P1>();

        let my_count = node
            .source_iter(q!(0..5i32))
            .fold(q!(|| 0i32), q!(|acc: &mut i32, x| *acc += x));
        let count_mut = my_count.by_mut();

        node.source_iter(q!(1..=3i32))
            .map(q!(|x| {
                *count_mut += x;
                x
            }))
            .for_each(q!(|_| {}));

        my_count.into_stream().for_each(q!(|_| {}));
        let _built = flow.finalize();
    }

    #[test]
    /// Immutable reference to an optional value (produced by `reduce`), read via
    /// `unwrap_or` inside `map`.
    fn optional_by_ref_compiles() {
        let mut flow = FlowBuilder::new();
        let node = flow.process::<P1>();

        let my_opt = node.source_iter(q!(0..5i32)).reduce(q!(|a, b| *a += b));
        let opt_ref = my_opt.by_ref();

        node.source_iter(q!(1..=3i32))
            .map(q!(|x| x + opt_ref.unwrap_or(0)))
            .for_each(q!(|_| {}));

        my_opt.into_stream().for_each(q!(|_| {}));
        let _built = flow.finalize();
    }

    #[test]
    /// Immutable reference to a stream, read via `len()` inside `map`.
    fn stream_by_ref_compiles() {
        let mut flow = FlowBuilder::new();
        let node = flow.process::<P1>();

        let my_stream = node.source_iter(q!(0..5i32));
        let stream_ref = my_stream.by_ref();

        node.source_iter(q!(1..=3i32))
            .map(q!(|x| x + stream_ref.len() as i32))
            .for_each(q!(|_| {}));

        my_stream.for_each(q!(|_| {}));
        let _built = flow.finalize();
    }

    #[test]
    /// Mutable singleton reference both written and read inside one `filter`
    /// predicate.
    fn singleton_by_mut_filter() {
        let mut flow = FlowBuilder::new();
        let node = flow.process::<P1>();

        let my_count = node
            .source_iter(q!(0..5i32))
            .fold(q!(|| 0i32), q!(|acc: &mut i32, x| *acc += x));
        let count_mut = my_count.by_mut();

        node.source_iter(q!(1..=3i32))
            .filter(q!(|x| {
                *count_mut += *x;
                *count_mut > 0
            }))
            .for_each(q!(|_| {}));

        my_count.into_stream().for_each(q!(|_| {}));
        let _built = flow.finalize();
    }

    #[test]
    /// Mutable singleton reference written and read inside `flat_map_ordered`.
    fn singleton_by_mut_flat_map() {
        let mut flow = FlowBuilder::new();
        let node = flow.process::<P1>();

        let my_count = node
            .source_iter(q!(0..5i32))
            .fold(q!(|| 0i32), q!(|acc: &mut i32, x| *acc += x));
        let count_mut = my_count.by_mut();

        node.source_iter(q!(1..=3i32))
            .flat_map_ordered(q!(|x| {
                *count_mut += x;
                vec![*count_mut]
            }))
            .for_each(q!(|_| {}));

        my_count.into_stream().for_each(q!(|_| {}));
        let _built = flow.finalize();
    }

    #[test]
    /// Mutable singleton reference written and read inside `filter_map`.
    fn singleton_by_mut_filter_map() {
        let mut flow = FlowBuilder::new();
        let node = flow.process::<P1>();

        let my_count = node
            .source_iter(q!(0..5i32))
            .fold(q!(|| 0i32), q!(|acc: &mut i32, x| *acc += x));
        let count_mut = my_count.by_mut();

        node.source_iter(q!(1..=3i32))
            .filter_map(q!(|x| {
                *count_mut += x;
                Some(*count_mut)
            }))
            .for_each(q!(|_| {}));

        my_count.into_stream().for_each(q!(|_| {}));
        let _built = flow.finalize();
    }

    #[test]
    /// Mutable singleton reference written inside `inspect`.
    fn singleton_by_mut_inspect() {
        let mut flow = FlowBuilder::new();
        let node = flow.process::<P1>();

        let my_count = node
            .source_iter(q!(0..5i32))
            .fold(q!(|| 0i32), q!(|acc: &mut i32, x| *acc += x));
        let count_mut = my_count.by_mut();

        node.source_iter(q!(1..=3i32))
            .inspect(q!(|x| {
                *count_mut += *x;
            }))
            .for_each(q!(|_| {}));

        my_count.into_stream().for_each(q!(|_| {}));
        let _built = flow.finalize();
    }

    #[test]
    /// Immutable singleton reference read inside a terminal `for_each`.
    fn singleton_by_ref_for_each() {
        let mut flow = FlowBuilder::new();
        let node = flow.process::<P1>();

        let my_count = node
            .source_iter(q!(0..5i32))
            .fold(q!(|| 0i32), q!(|acc: &mut i32, x| *acc += x));
        let count_ref = my_count.by_ref();

        node.source_iter(q!(1..=3i32))
            .for_each(q!(|x| println!("{}", x + *count_ref)));

        my_count.into_stream().for_each(q!(|_| {}));
        let _built = flow.finalize();
    }

    #[test]
    /// Mutable singleton reference written inside a terminal `for_each`.
    fn singleton_by_mut_for_each() {
        let mut flow = FlowBuilder::new();
        let node = flow.process::<P1>();

        let my_count = node
            .source_iter(q!(0..5i32))
            .fold(q!(|| 0i32), q!(|acc: &mut i32, x| *acc += x));
        let count_mut = my_count.by_mut();

        node.source_iter(q!(1..=3i32)).for_each(q!(|x| {
            *count_mut += x;
        }));

        my_count.into_stream().for_each(q!(|_| {}));
        let _built = flow.finalize();
    }
}