Skip to main content

hydro_lang/
handoff_ref.rs

1//! Reference handles for capturing singletons, optionals, and streams in `q!()` closures.
2//!
3//! Each handle type wraps a `&RefCell<HydroNode>` and, when captured inside a `q!()` closure,
4//! registers itself with the current capture scope. At codegen time, the IR node is lowered
5//! to the corresponding DFIR pseudo-operator (`singleton()`, `optional()`, or `handoff()`),
6//! and the reference resolves to the appropriate borrow type.
7
8use std::cell::RefCell;
9use std::marker::PhantomData;
10use std::rc::Rc;
11
12use proc_macro2::Span;
13use quote::quote;
14use stageleft::runtime_support::{FreeVariableWithContextWithProps, QuoteTokens};
15
16use crate::compile::ir::{HydroNode, SharedNode};
17use crate::location::Location;
18
19/// Determines which DFIR pseudo-operator a reference node lowers to.
20#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
21pub enum HandoffRefKind {
22    /// `-> singleton()` — exactly one item, `#var` gives `&T`.
23    Singleton,
24    /// `-> optional()` — zero or one item, `#var` gives `&Option<T>`.
25    Optional,
26    /// `-> handoff()` — zero or more items, `#var` gives `&Vec<T>`.
27    Vec,
28}
29
30// Thread-local storage for handoff references captured during `q!()` expansion.
31// Stores the HydroNode `(node, is_mut)` for each reference captured in the current closure.
32// The index determines the ident name via `handoff_ref_ident`.
33thread_local! {
34    static CAPTURED_REFS: RefCell<Option<Vec<(HydroNode, bool)>>> = const { RefCell::new(None) };
35}
36
37/// Returns the canonical ident for a captured ref at the given index within a closure.
38pub(crate) fn handoff_ref_ident(index: usize) -> syn::Ident {
39    syn::Ident::new(
40        &format!("__hydro_singleton_ref_{}", index),
41        Span::call_site(),
42    )
43}
44
45/// Activate the reference capture context. Must be called before `q!()` expansion
46/// that may capture handoff references. Returns a `ClosureExpr` bundling the expression with any
47/// captured references.
48pub fn with_ref_capture(
49    f: impl FnOnce() -> crate::compile::ir::DebugExpr,
50) -> crate::compile::ir::ClosureExpr {
51    CAPTURED_REFS.with(|cell| {
52        let prev = cell.borrow_mut().replace(Vec::new());
53        assert!(
54            prev.is_none(),
55            "nested handoff reference capture scopes are not supported"
56        );
57    });
58    let expr = (f)();
59    let captured_refs = CAPTURED_REFS.with(|cell| cell.borrow_mut().take().unwrap());
60    crate::compile::ir::ClosureExpr::new(expr, captured_refs)
61}
62
63/// Shared registration logic: wraps the IR node in `HydroNode::Reference` if needed,
64/// pushes it to the capture list, and returns the ident to use in the closure body.
65fn register_handoff_ref(
66    ir_node: &RefCell<HydroNode>,
67    is_mut: bool,
68    kind: HandoffRefKind,
69) -> syn::Ident {
70    CAPTURED_REFS.with(|cell| {
71        let mut guard = cell.borrow_mut();
72        let refs = guard.as_mut().expect(
73            "HandoffRef used inside q!() but no reference capture scope is active. \
74             This is a bug — reference capture should be set up by the operator that uses q!().",
75        );
76
77        let index = refs.len();
78        let ident = handoff_ref_ident(index);
79
80        let metadata = ir_node.borrow().metadata().clone();
81
82        // Wrap in HydroNode::Reference for materialization + identity tracking.
83        // If already a Reference node, reuse it.
84        if !matches!(&*ir_node.borrow(), HydroNode::Reference { .. }) {
85            let orig = ir_node.replace(HydroNode::Placeholder);
86            *ir_node.borrow_mut() = HydroNode::Reference {
87                inner: SharedNode(Rc::new(RefCell::new(orig))),
88                kind,
89                metadata: metadata.clone(),
90            };
91        }
92
93        let borrow: std::cell::Ref<'_, HydroNode> = ir_node.borrow();
94        let HydroNode::Reference { inner, .. } = &*borrow else {
95            unreachable!()
96        };
97
98        refs.push((
99            HydroNode::Reference {
100                inner: SharedNode(Rc::clone(&inner.0)),
101                kind,
102                metadata,
103            },
104            is_mut,
105        ));
106
107        ident
108    })
109}
110
/// Defines a public reference-handle struct together with the impls every handle needs.
///
/// The arguments, in order, are:
/// - optional outer attributes (typically doc comments) placed on the generated struct,
/// - the struct name,
/// - whether the handle is mutable (forwarded to `register_handoff_ref` as `is_mut`),
/// - the [`HandoffRefKind`] the referenced node lowers to,
/// - the type the handle resolves to inside `q!()` (the `O` of
///   `FreeVariableWithContextWithProps`).
macro_rules! define_handoff_ref {
    (
        $(#[$attr:meta])*
        $handle:ident, $mutable:expr, $ref_kind:expr, $resolves_to:ty
    ) => {
        $(#[$attr])*
        pub struct $handle<'a, 'slf, T, L> {
            pub(crate) ir_node: &'slf RefCell<HydroNode>,
            _phantom: PhantomData<(&'a T, L)>,
        }

        impl<'slf, T, L> $handle<'_, 'slf, T, L> {
            /// Wraps an IR node cell in a handle. Nothing is registered until the handle is
            /// turned into tokens by `q!()`.
            pub(crate) fn new(ir_node: &'slf RefCell<HydroNode>) -> Self {
                Self { ir_node, _phantom: PhantomData }
            }
        }

        // Written by hand instead of derived: `#[derive(Clone, Copy)]` would require
        // `T: Clone` / `L: Clone`, but the handle only holds a shared reference and a
        // `PhantomData`, so it is copyable for any `T` and `L`.
        impl<T, L> Clone for $handle<'_, '_, T, L> {
            fn clone(&self) -> Self {
                *self
            }
        }
        impl<T, L> Copy for $handle<'_, '_, T, L> {}

        impl<'a, 'slf, T: 'a, L> FreeVariableWithContextWithProps<L, ()> for $handle<'a, 'slf, T, L>
        where
            L: Location<'a>,
        {
            type O = $resolves_to;

            fn to_tokens(self, _ctx: &L) -> (QuoteTokens, ()) {
                // Registering yields the identifier the closure body refers to.
                let ident = register_handoff_ref(self.ir_node, $mutable, $ref_kind);
                let tokens = QuoteTokens {
                    prelude: None,
                    expr: Some(quote!(#ident)),
                };
                (tokens, ())
            }
        }
    };
}
159
160define_handoff_ref!(
161    /// A shared reference handle to a singleton, resolves to `&T` at runtime.
162    ///
163    /// Created via [`Singleton::by_ref()`](crate::live_collections::Singleton::by_ref).
164    SingletonRef, false, HandoffRefKind::Singleton, &'a T
165);
166
167define_handoff_ref!(
168    /// A mutable reference handle to a singleton, resolves to `&mut T` at runtime.
169    ///
170    /// Created via [`Singleton::by_mut()`](crate::live_collections::Singleton::by_mut).
171    SingletonMut, true, HandoffRefKind::Singleton, &'a mut T
172);
173
174define_handoff_ref!(
175    /// A shared reference handle to an optional, resolves to `&Option<T>` at runtime.
176    ///
177    /// Created via [`Optional::by_ref()`](crate::live_collections::Optional::by_ref).
178    OptionalRef, false, HandoffRefKind::Optional, &'a Option<T>
179);
180
181define_handoff_ref!(
182    /// A mutable reference handle to an optional, resolves to `&mut Option<T>` at runtime.
183    ///
184    /// Created via [`Optional::by_mut()`](crate::live_collections::Optional::by_mut).
185    OptionalMut, true, HandoffRefKind::Optional, &'a mut Option<T>
186);
187
188define_handoff_ref!(
189    /// A shared reference handle to a stream's handoff buffer, resolves to `&Vec<T>` at runtime.
190    ///
191    /// Created via [`Stream::by_ref()`](crate::live_collections::Stream::by_ref).
192    StreamRef, false, HandoffRefKind::Vec, &'a Vec<T>
193);
194
195define_handoff_ref!(
196    /// A mutable reference handle to a stream's handoff buffer, resolves to `&mut Vec<T>` at runtime.
197    ///
198    /// Created via [`Stream::by_mut()`](crate::live_collections::Stream::by_mut).
199    StreamMut, true, HandoffRefKind::Vec, &'a mut Vec<T>
200);
201
#[cfg(test)]
#[cfg(feature = "build")]
mod tests {
    // Every test follows the same shape:
    // 1. Build a small flow on one process.
    // 2. Capture a handle (`by_ref()` / `by_mut()`) to a collection inside another
    //    operator's `q!()` closure.
    // 3. Call `flow.finalize()`.
    //
    // The finalized flow is never executed here. These tests check that IR construction
    // with captured references succeeds; they do not check runtime values.
    //
    // The referenced collection is also consumed directly afterwards (e.g.
    // `my_count.into_stream().for_each(...)`). NOTE(review): presumably this gives the
    // referenced node a regular consumer as well; confirm whether it is required.
    use stageleft::q;

    use crate::compile::builder::FlowBuilder;
    use crate::location::Location;

    /// Marker type naming the single process every test builds on.
    struct P1 {}

    /// Compile-only test: verifies that `by_ref()` + `q!()` produces valid IR.
    #[test]
    fn singleton_by_ref_compiles() {
        let mut flow = FlowBuilder::new();
        let node = flow.process::<P1>();

        let my_count = node
            .source_iter(q!(0..5i32))
            .fold(q!(|| 0i32), q!(|acc: &mut i32, x| *acc += x));
        let count_ref = my_count.by_ref();

        // `count_ref` resolves to `&i32` inside the closure, hence the explicit deref.
        node.source_iter(q!(1..=3i32))
            .map(q!(|x| x + *count_ref))
            .for_each(q!(|_| {}));

        my_count.into_stream().for_each(q!(|_| {}));
        let _built = flow.finalize();
    }

    /// Test with a non-Copy type (Vec) to ensure we're borrowing, not copying.
    #[test]
    fn singleton_by_ref_non_copy() {
        let mut flow = FlowBuilder::new();
        let node = flow.process::<P1>();

        let my_vec = node.source_iter(q!(0..5i32)).fold(
            q!(|| Vec::<i32>::new()),
            q!(|acc: &mut Vec<i32>, x| acc.push(x)),
        );
        let vec_ref = my_vec.by_ref();

        // `vec_ref` resolves to `&Vec<i32>`; `.len()` works through the borrow.
        node.source_iter(q!(1..=3i32))
            .map(q!(|x| x + vec_ref.len() as i32))
            .for_each(q!(|_| {}));

        my_vec.into_stream().for_each(q!(|_| {}));
        let _built = flow.finalize();
    }

    /// Compile-only: singleton ref inside filter closure.
    #[test]
    fn singleton_by_ref_filter() {
        let mut flow = FlowBuilder::new();
        let node = flow.process::<P1>();

        let threshold = node
            .source_iter(q!(0..5i32))
            .fold(q!(|| 0i32), q!(|acc: &mut i32, x| *acc += x));
        let threshold_ref = threshold.by_ref();

        node.source_iter(q!(1..=10i32))
            .filter(q!(|x| *x > *threshold_ref))
            .for_each(q!(|_| {}));

        threshold.into_stream().for_each(q!(|_| {}));
        let _built = flow.finalize();
    }

    /// Compile-only: singleton ref inside flat_map closure.
    #[test]
    fn singleton_by_ref_flat_map() {
        let mut flow = FlowBuilder::new();
        let node = flow.process::<P1>();

        let count = node
            .source_iter(q!(0..3i32))
            .fold(q!(|| 0i32), q!(|acc: &mut i32, _| *acc += 1));
        let count_ref = count.by_ref();

        // The reference is dereferenced eagerly when building the range; the inner
        // `move` closure only captures `x`.
        node.source_iter(q!(1..=2i32))
            .flat_map_ordered(q!(|x| (0..*count_ref).map(move |i| x + i)))
            .for_each(q!(|_| {}));

        count.into_stream().for_each(q!(|_| {}));
        let _built = flow.finalize();
    }

    /// Compile-only: singleton ref inside inspect closure.
    #[test]
    fn singleton_by_ref_inspect() {
        let mut flow = FlowBuilder::new();
        let node = flow.process::<P1>();

        let count = node
            .source_iter(q!(0..5i32))
            .fold(q!(|| 0i32), q!(|acc: &mut i32, _| *acc += 1));
        let count_ref = count.by_ref();

        node.source_iter(q!(1..=3i32))
            .inspect(q!(|x| println!("count={}, x={}", *count_ref, x)))
            .for_each(q!(|_| {}));

        count.into_stream().for_each(q!(|_| {}));
        let _built = flow.finalize();
    }

    /// Compile-only: singleton ref inside partition predicate.
    #[test]
    fn singleton_by_ref_partition() {
        let mut flow = FlowBuilder::new();
        let node = flow.process::<P1>();

        let threshold = node
            .source_iter(q!(0..5i32))
            .fold(q!(|| 0i32), q!(|acc: &mut i32, x| *acc += x));
        let threshold_ref = threshold.by_ref();

        let (above, below) = node
            .source_iter(q!(1..=10i32))
            .partition(q!(|x| *x > *threshold_ref));

        // Both partition outputs get a sink.
        above.for_each(q!(|_| {}));
        below.for_each(q!(|_| {}));
        threshold.into_stream().for_each(q!(|_| {}));
        let _built = flow.finalize();
    }

    /// Compile-only: singleton ref inside partition with downstream operators on both branches.
    #[test]
    fn singleton_by_ref_partition_with_downstream_ops() {
        let mut flow = FlowBuilder::new();
        let node = flow.process::<P1>();

        let threshold = node
            .source_iter(q!(0..5i32))
            .fold(q!(|| 0i32), q!(|acc: &mut i32, x| *acc += x));
        let threshold_ref = threshold.by_ref();

        let (above, below) = node
            .source_iter(q!(1..=10i32))
            .partition(q!(|x| *x > *threshold_ref));

        above.map(q!(|x| x * 2)).for_each(q!(|_| {}));
        below.map(q!(|x| x + 100)).for_each(q!(|_| {}));
        threshold.into_stream().for_each(q!(|_| {}));
        let _built = flow.finalize();
    }

    /// Compile-only test: singleton by_mut.
    #[test]
    fn singleton_by_mut_compiles() {
        let mut flow = FlowBuilder::new();
        let node = flow.process::<P1>();

        let my_count = node
            .source_iter(q!(0..5i32))
            .fold(q!(|| 0i32), q!(|acc: &mut i32, x| *acc += x));
        let count_mut = my_count.by_mut();

        // `count_mut` resolves to `&mut i32`, so the closure can update the singleton in place.
        node.source_iter(q!(1..=3i32))
            .map(q!(|x| {
                *count_mut += x;
                x
            }))
            .for_each(q!(|_| {}));

        my_count.into_stream().for_each(q!(|_| {}));
        let _built = flow.finalize();
    }

    /// Compile-only test: optional by_ref.
    #[test]
    fn optional_by_ref_compiles() {
        let mut flow = FlowBuilder::new();
        let node = flow.process::<P1>();

        let my_opt = node.source_iter(q!(0..5i32)).reduce(q!(|a, b| *a += b));
        let opt_ref = my_opt.by_ref();

        // Assuming `reduce` yields an `Optional`, `opt_ref` resolves to `&Option<i32>`.
        // `Option<i32>` is `Copy`, so `unwrap_or` can be called through the reference.
        node.source_iter(q!(1..=3i32))
            .map(q!(|x| x + opt_ref.unwrap_or(0)))
            .for_each(q!(|_| {}));

        my_opt.into_stream().for_each(q!(|_| {}));
        let _built = flow.finalize();
    }

    /// Compile-only test: stream by_ref.
    #[test]
    fn stream_by_ref_compiles() {
        let mut flow = FlowBuilder::new();
        let node = flow.process::<P1>();

        let my_stream = node.source_iter(q!(0..5i32));
        let stream_ref = my_stream.by_ref();

        // `stream_ref` resolves to `&Vec<i32>` (the stream's handoff buffer).
        node.source_iter(q!(1..=3i32))
            .map(q!(|x| x + stream_ref.len() as i32))
            .for_each(q!(|_| {}));

        my_stream.for_each(q!(|_| {}));
        let _built = flow.finalize();
    }

    /// Compile-only test: singleton by_mut in filter (TotalOrder).
    #[test]
    fn singleton_by_mut_filter() {
        let mut flow = FlowBuilder::new();
        let node = flow.process::<P1>();

        let my_count = node
            .source_iter(q!(0..5i32))
            .fold(q!(|| 0i32), q!(|acc: &mut i32, x| *acc += x));
        let count_mut = my_count.by_mut();

        // The predicate both mutates and reads the singleton through the same handle.
        node.source_iter(q!(1..=3i32))
            .filter(q!(|x| {
                *count_mut += *x;
                *count_mut > 0
            }))
            .for_each(q!(|_| {}));

        my_count.into_stream().for_each(q!(|_| {}));
        let _built = flow.finalize();
    }

    /// Compile-only test: singleton by_mut in flat_map_ordered (TotalOrder).
    #[test]
    fn singleton_by_mut_flat_map() {
        let mut flow = FlowBuilder::new();
        let node = flow.process::<P1>();

        let my_count = node
            .source_iter(q!(0..5i32))
            .fold(q!(|| 0i32), q!(|acc: &mut i32, x| *acc += x));
        let count_mut = my_count.by_mut();

        node.source_iter(q!(1..=3i32))
            .flat_map_ordered(q!(|x| {
                *count_mut += x;
                vec![*count_mut]
            }))
            .for_each(q!(|_| {}));

        my_count.into_stream().for_each(q!(|_| {}));
        let _built = flow.finalize();
    }

    /// Compile-only test: singleton by_mut in filter_map (TotalOrder).
    #[test]
    fn singleton_by_mut_filter_map() {
        let mut flow = FlowBuilder::new();
        let node = flow.process::<P1>();

        let my_count = node
            .source_iter(q!(0..5i32))
            .fold(q!(|| 0i32), q!(|acc: &mut i32, x| *acc += x));
        let count_mut = my_count.by_mut();

        node.source_iter(q!(1..=3i32))
            .filter_map(q!(|x| {
                *count_mut += x;
                Some(*count_mut)
            }))
            .for_each(q!(|_| {}));

        my_count.into_stream().for_each(q!(|_| {}));
        let _built = flow.finalize();
    }

    /// Compile-only test: singleton by_mut in inspect (TotalOrder).
    #[test]
    fn singleton_by_mut_inspect() {
        let mut flow = FlowBuilder::new();
        let node = flow.process::<P1>();

        let my_count = node
            .source_iter(q!(0..5i32))
            .fold(q!(|| 0i32), q!(|acc: &mut i32, x| *acc += x));
        let count_mut = my_count.by_mut();

        node.source_iter(q!(1..=3i32))
            .inspect(q!(|x| {
                *count_mut += *x;
            }))
            .for_each(q!(|_| {}));

        my_count.into_stream().for_each(q!(|_| {}));
        let _built = flow.finalize();
    }

    /// Compile-only test: singleton by_ref in for_each.
    #[test]
    fn singleton_by_ref_for_each() {
        let mut flow = FlowBuilder::new();
        let node = flow.process::<P1>();

        let my_count = node
            .source_iter(q!(0..5i32))
            .fold(q!(|| 0i32), q!(|acc: &mut i32, x| *acc += x));
        let count_ref = my_count.by_ref();

        // The reference is captured by a sink (`for_each`) rather than an intermediate operator.
        node.source_iter(q!(1..=3i32))
            .for_each(q!(|x| println!("{}", x + *count_ref)));

        my_count.into_stream().for_each(q!(|_| {}));
        let _built = flow.finalize();
    }

    /// Compile-only test: singleton by_mut in for_each.
    #[test]
    fn singleton_by_mut_for_each() {
        let mut flow = FlowBuilder::new();
        let node = flow.process::<P1>();

        let my_count = node
            .source_iter(q!(0..5i32))
            .fold(q!(|| 0i32), q!(|acc: &mut i32, x| *acc += x));
        let count_mut = my_count.by_mut();

        node.source_iter(q!(1..=3i32)).for_each(q!(|x| {
            *count_mut += x;
        }));

        my_count.into_stream().for_each(q!(|_| {}));
        let _built = flow.finalize();
    }
}