Skip to main content

dfir_rs/scheduled/
context.rs

1//! Module for the inline DFIR runtime context and execution engine.
2//!
3//! Provides [`Context`] (the lightweight operator context) and
4//! [`Dfir`] (the dataflow execution wrapper).
5
6use std::future::Future;
7use std::pin::Pin;
8use std::rc::Rc;
9use std::sync::Arc;
10use std::sync::atomic::Ordering;
11use std::task::Wake;
12
13#[cfg(feature = "meta")]
14use dfir_lang::diagnostic::{Diagnostic, Diagnostics, SerdeSpan};
15#[cfg(feature = "meta")]
16use dfir_lang::graph::DfirGraph;
17
18use super::metrics::{DfirMetrics, DfirMetricsIntervals};
19use crate::scheduled::ticks::TickInstant;
20
21/// Coordinates waking between [`Context`] (inside the tick closure) and [`Dfir`]
22/// (the external runner). Shared via `Arc` between both.
23///
24/// When external data arrives (e.g., a tokio stream receives a message), the [`Context::waker`]
25/// fires, which sets `can_start_tick` and wakes the [`Dfir::run`](Dfir::run) task so it starts a new tick.
26/// Implements [`Wake`] directly so it can be used as a `Waker` without an extra wrapper.
27#[doc(hidden)]
28pub struct WakeState {
29    /// Set to `true` when external data arrives, signaling that a new tick should run.
30    /// Checked by [`Dfir::run_tick`](Dfir::run_tick) and [`Dfir::run_available`](Dfir::run_available).
31    can_start_tick: std::sync::atomic::AtomicBool,
32    /// Wakes the [`Dfir::run`](Dfir::run) task from its idle `poll_fn` sleep.
33    task_waker: futures::task::AtomicWaker,
34}
35
36impl Default for WakeState {
37    fn default() -> Self {
38        Self {
39            can_start_tick: std::sync::atomic::AtomicBool::new(false),
40            task_waker: futures::task::AtomicWaker::new(),
41        }
42    }
43}
44
45impl Wake for WakeState {
46    fn wake(self: Arc<Self>) {
47        self.wake_by_ref();
48    }
49
50    fn wake_by_ref(self: &Arc<Self>) {
51        self.can_start_tick.store(true, Ordering::Relaxed);
52        self.task_waker.wake();
53    }
54}
55
56/// A lightweight context for inline codegen that avoids the overhead of the full
57/// scheduled graph (no tokio channels, no scheduler queues, no loop machinery).
58///
59/// Exposes methods that operator-generated code calls on both
60/// `df` (for prologues: `request_task`) and
61/// `context` (for iterators: `current_tick`, `schedule_subgraph`, etc.).
62#[doc(hidden)]
63#[derive(Default)]
64pub struct Context {
65    /// Counter for number of ticks run.
66    current_tick: TickInstant,
67    /// Coordinates waking between [`Context`] (inside the tick closure) and [`Dfir`]
68    /// (the external runner). Shared via `Arc` between both. Implements [`Wake`].
69    wake_state: Arc<WakeState>,
70    /// Live-updating DFIR runtime metrics via interior mutability.
71    metrics: Rc<DfirMetrics>,
72    /// Tasks buffered via [`Self::request_task`], spawned by [`Dfir::spawn_tasks`]
73    /// once the runtime is running inside a tokio `LocalSet`.
74    tasks_to_spawn: Vec<Pin<Box<dyn Future<Output = ()> + 'static>>>,
75}
76
77impl Context {
78    /// Create a new inline context with shared wake state and metrics.
79    pub fn new(wake_state: Arc<WakeState>, metrics: Rc<DfirMetrics>) -> Self {
80        Self {
81            current_tick: TickInstant::default(),
82            wake_state,
83            metrics,
84            tasks_to_spawn: Vec::new(),
85        }
86    }
87
88    // --- Methods called as `df.xxx()` in operator prologues ---
89
90    /// Buffers an async task to be spawned later by [`Dfir::spawn_tasks`].
91    ///
92    /// Tasks are deferred because `write_prologue` runs during graph construction,
93    /// which may occur before a tokio `LocalSet` is entered. Buffered tasks are
94    /// drained and spawned via `tokio::task::spawn_local` at the start of
95    /// [`Dfir::run_tick`]. Tasks requested after that point remain buffered until
96    /// the next call to [`Dfir::run_tick`].
97    pub fn request_task<Fut>(&mut self, future: Fut)
98    where
99        Fut: Future<Output = ()> + 'static,
100    {
101        self.tasks_to_spawn.push(Box::pin(future));
102    }
103
104    // --- Methods called as `context.xxx()` in operator iterators ---
105
106    /// Gets the current tick count.
107    pub fn current_tick(&self) -> TickInstant {
108        self.current_tick
109    }
110
111    /// Returns a reference to the runtime metrics.
112    pub fn metrics(&self) -> &Rc<DfirMetrics> {
113        &self.metrics
114    }
115
116    /// Signals that external data has arrived and a new tick should be started.
117    pub fn schedule_subgraph(&self, is_external: bool) {
118        if is_external {
119            self.wake_state.wake_by_ref();
120        }
121    }
122
123    /// Returns a waker that signals external data has arrived.
124    pub fn waker(&self) -> std::task::Waker {
125        std::task::Waker::from(self.wake_state.clone())
126    }
127
128    /// Increments the tick counter.
129    /// Called by the generated tick closure at the end of each tick.
130    #[doc(hidden)]
131    pub fn __end_tick(&mut self) {
132        self.current_tick += crate::scheduled::ticks::TickDuration::SINGLE_TICK;
133    }
134}
135
136/// A wrapper around an inline-codegen tick closure that provides [`Self::run`],
137/// [`Self::run_available`], and [`Self::run_tick`] methods — mirroring the [`Dfir`](super::context::Dfir)
138/// API.
139///
140/// # Design
141///
142/// The inline codegen generates an `async move |df: &mut Context|` closure that captures
143/// dataflow-specific state (handoff buffers, source iterators) and receives the [`Context`]
144/// (tick counter, metrics) by reference each tick. `Dfir` owns both the
145/// closure and the context, and coordinates tick lifecycle and idle/wake behavior.
146///
147/// We use a single opaque closure rather than generating a bespoke struct per dataflow because:
148/// - The closure naturally captures exactly the state it needs with correct lifetimes
149/// - No codegen needed for struct definitions, field accessors, or initialization
150/// - Rust's async closure machinery handles the complex state machine (suspend/resume across
151///   `.await` points) that would be very difficult to replicate in a generated struct
152///
153/// The `Tick` type parameter is bounded by [`TickClosure`] (not `AsyncFnMut` directly) to
154/// support type erasure via [`TickClosureErased`] / [`DfirErased`] for heterogeneous
155/// collections (e.g., the sim runtime storing multiple locations in a `Vec`). The concrete
156/// (non-erased) path used by trybuild and embedded has zero overhead.
157#[doc(hidden)]
158pub struct Dfir<Tick> {
159    /// Async closure which runs a single tick when called.
160    tick_closure: Tick,
161    /// Coordinates waking between [`Context`] (inside the tick closure) and [`Dfir`]
162    /// (the external runner). Shared via `Arc` between both. Implements [`Wake`].
163    wake_state: Arc<WakeState>,
164    /// The inline context, owned by `Dfir` and passed to the tick closure by reference.
165    context: Context,
166    /// See [`Self::meta_graph()`].
167    #[cfg(feature = "meta")]
168    meta_graph: Option<DfirGraph>,
169    /// See [`Self::diagnostics()`].
170    #[cfg(feature = "meta")]
171    diagnostics: Option<Vec<Diagnostic<SerdeSpan>>>,
172}
173
174/// Trait for tick closures — abstracts over both concrete async closures
175/// and type-erased boxed versions ([`TickClosureErased`]).
176///
177/// The `&mut Context` parameter is owned by [`Dfir`] and lent to the
178/// closure each tick, avoiding shared-ownership overhead for the context.
179#[doc(hidden)]
180pub trait TickClosure {
181    /// Call the tick closure. Returns `true` if any subgraph received input data.
182    fn call_tick<'a>(&'a mut self, ctx: &'a mut Context) -> impl Future<Output = bool> + 'a;
183}
184
185impl<F: for<'a> AsyncFnMut(&'a mut Context) -> bool> TickClosure for F {
186    fn call_tick<'a>(&'a mut self, ctx: &'a mut Context) -> impl Future<Output = bool> + 'a {
187        self(ctx)
188    }
189}
190
191/// No-op `TickClosure`.
192#[doc(hidden)]
193pub struct NullTickClosure;
194
195impl TickClosure for NullTickClosure {
196    fn call_tick<'a>(&'a mut self, _ctx: &'a mut Context) -> impl Future<Output = bool> + 'a {
197        std::future::ready(false)
198    }
199}
200
201/// Type-erased tick function for use in heterogeneous collections (e.g., the sim runtime).
202#[doc(hidden)]
203pub struct TickClosureErased(Box<dyn TickClosureErasedInner>);
204
205/// Object-safe inner trait for [`TickClosureErased`]. Needed because `AsyncFnMut` is not
206/// object-safe (GAT return type), but a trait with `&mut self -> Pin<Box<dyn Future + '_>>`
207/// is — the returned future borrows from the trait object which owns the closure.
208trait TickClosureErasedInner {
209    fn call_tick<'a>(
210        &'a mut self,
211        ctx: &'a mut Context,
212    ) -> Pin<Box<dyn Future<Output = bool> + 'a>>;
213}
214
215impl<F: for<'a> AsyncFnMut(&'a mut Context) -> bool> TickClosureErasedInner for F {
216    fn call_tick<'a>(
217        &'a mut self,
218        ctx: &'a mut Context,
219    ) -> Pin<Box<dyn Future<Output = bool> + 'a>> {
220        Box::pin(self(ctx))
221    }
222}
223
224impl TickClosure for TickClosureErased {
225    fn call_tick<'a>(&'a mut self, ctx: &'a mut Context) -> impl Future<Output = bool> + 'a {
226        self.0.call_tick(ctx)
227    }
228}
229
230/// Type alias for a type-erased [`Dfir`] that can be stored in heterogeneous collections.
231/// Created via [`Dfir::into_erased`].
232pub type DfirErased = Dfir<TickClosureErased>;
233
234impl<Tick: TickClosure> Dfir<Tick> {
235    /// Create a new `Dfir` from a tick closure, inline context,
236    /// and meta graph / diagnostics JSON strings.
237    #[doc(hidden)]
238    pub fn new(
239        tick_closure: Tick,
240        context: Context,
241        meta_graph_json: Option<&str>,
242        diagnostics_json: Option<&str>,
243    ) -> Self {
244        #[cfg(not(feature = "meta"))]
245        let _ = (meta_graph_json, diagnostics_json);
246        Self {
247            tick_closure,
248            wake_state: context.wake_state.clone(),
249            context,
250            #[cfg(feature = "meta")]
251            meta_graph: meta_graph_json.map(|json| {
252                let mut meta_graph: DfirGraph =
253                    serde_json::from_str(json).expect("Failed to deserialize graph.");
254                let mut op_inst_diagnostics = Diagnostics::new();
255                meta_graph.insert_node_op_insts_all(&mut op_inst_diagnostics);
256                assert!(
257                    op_inst_diagnostics.is_empty(),
258                    "Expected no diagnostics, got: {:#?}",
259                    op_inst_diagnostics
260                );
261                meta_graph
262            }),
263            #[cfg(feature = "meta")]
264            diagnostics: diagnostics_json.map(|json| {
265                serde_json::from_str(json).expect("Failed to deserialize diagnostics.")
266            }),
267        }
268    }
269
270    /// Return a handle to the meta graph, if set.
271    #[cfg(feature = "meta")]
272    #[cfg_attr(docsrs, doc(cfg(feature = "meta")))]
273    pub fn meta_graph(&self) -> Option<&DfirGraph> {
274        self.meta_graph.as_ref()
275    }
276
277    /// Returns any diagnostics generated by the surface syntax macro.
278    #[cfg(feature = "meta")]
279    #[cfg_attr(docsrs, doc(cfg(feature = "meta")))]
280    pub fn diagnostics(&self) -> Option<&[Diagnostic<SerdeSpan>]> {
281        self.diagnostics.as_deref()
282    }
283
284    /// Returns a reference-counted handle to the continually-updated runtime metrics for this DFIR instance.
285    pub fn metrics(&self) -> Rc<DfirMetrics> {
286        Rc::clone(self.context.metrics())
287    }
288
289    /// Gets the current tick (local time) count.
290    pub fn current_tick(&self) -> TickInstant {
291        self.context.current_tick()
292    }
293
294    /// Returns a [`DfirMetricsIntervals`] handle where each call to
295    /// [`DfirMetricsIntervals::take_interval`] ends the current interval and returns its metrics.
296    ///
297    /// The first call to `take_interval` returns metrics since this DFIR instance was created. Each subsequent call to
298    /// `take_interval` returns metrics since the previous call.
299    ///
300    /// Cloning the handle "forks" it from the original, as afterwards each interval may return different metrics
301    /// depending on when exactly `take_interval` is called.
302    pub fn metrics_intervals(&self) -> DfirMetricsIntervals {
303        DfirMetricsIntervals {
304            curr: self.metrics(),
305            prev: None,
306        }
307    }
308}
309
310impl<Tick: TickClosure> Dfir<Tick> {
311    /// Spawns all tasks buffered via [`Context::request_task`].
312    ///
313    /// This drains the buffer, so subsequent calls are no-ops until new tasks are requested.
314    fn spawn_tasks(&mut self) {
315        for task in self.context.tasks_to_spawn.drain(..) {
316            tokio::task::spawn_local(task);
317        }
318    }
319
320    /// Run a single tick. Returns `true` if any subgraph received input data.
321    ///
322    /// Checks both handoff buffers (via `work_done` flag set in generated recv port code)
323    /// and external events (via `can_start_tick` set by wakers/schedule_subgraph).
324    pub async fn run_tick(&mut self) -> bool {
325        self.spawn_tasks();
326        let had_external = self
327            .wake_state
328            .can_start_tick
329            .swap(false, Ordering::Relaxed);
330        let tick_had_work = self.tick_closure.call_tick(&mut self.context).await;
331        had_external || tick_had_work || self.wake_state.can_start_tick.load(Ordering::Relaxed)
332    }
333
334    /// Run a single tick synchronously. Panics if the tick yields (async suspension).
335    /// Returns `true` if work was done (see [`Self::run_tick`]).
336    pub fn run_tick_sync(&mut self) -> bool {
337        let mut fut = std::pin::pin!(self.run_tick());
338        let mut ctx = std::task::Context::from_waker(std::task::Waker::noop());
339        match fut.as_mut().poll(&mut ctx) {
340            std::task::Poll::Ready(result) => result,
341            std::task::Poll::Pending => {
342                panic!("Dfir::run_tick_sync: tick yielded asynchronously.")
343            }
344        }
345    }
346
347    /// Run ticks as long as work is available, then return.
348    pub async fn run_available(&mut self) {
349        // Always run at least one tick.
350        self.wake_state
351            .can_start_tick
352            .store(false, Ordering::Relaxed);
353        loop {
354            self.run_tick().await;
355            let can_start_tick = self
356                .wake_state
357                .can_start_tick
358                .swap(false, Ordering::Relaxed);
359            if !can_start_tick {
360                break;
361            }
362            // Yield between each tick to receive more events.
363            tokio::task::yield_now().await;
364        }
365    }
366
367    /// [`Self::run_available`] but panics if any tick yields asynchronously.
368    pub fn run_available_sync(&mut self) {
369        self.wake_state
370            .can_start_tick
371            .store(false, Ordering::Relaxed);
372        loop {
373            self.run_tick_sync();
374            let can_start_tick = self
375                .wake_state
376                .can_start_tick
377                .swap(false, Ordering::Relaxed);
378            if !can_start_tick {
379                break;
380            }
381        }
382    }
383
384    /// Run forever, processing ticks when work is available and yielding when idle.
385    pub async fn run(&mut self) -> crate::Never {
386        loop {
387            self.run_available().await;
388            // Wait for an external event to wake us.
389            std::future::poll_fn(|cx| {
390                // Register waker first to avoid race: if an event fires between
391                // the check and the register, the waker is already in place.
392                self.wake_state.task_waker.register(cx.waker());
393                if self.wake_state.can_start_tick.load(Ordering::Relaxed) {
394                    std::task::Poll::Ready(())
395                } else {
396                    std::task::Poll::Pending
397                }
398            })
399            .await;
400        }
401    }
402}
403
404impl<Tick: 'static + for<'a> AsyncFnMut(&'a mut Context) -> bool> Dfir<Tick> {
405    /// Type-erase the tick closure for use in heterogeneous collections.
406    ///
407    /// Wraps the concrete async closure in [`TickClosureErased`], which boxes the future
408    /// returned by each tick call. This adds one heap allocation per tick, but enables
409    /// storing multiple `Dfir`s with different closure types in a single `Vec`.
410    ///
411    /// Only needed for the sim runtime path. The trybuild and embedded paths keep the
412    /// concrete type and pay no erasure cost.
413    pub fn into_erased(self) -> DfirErased {
414        Dfir {
415            tick_closure: TickClosureErased(Box::new(self.tick_closure)),
416            wake_state: self.wake_state,
417            context: self.context,
418            #[cfg(feature = "meta")]
419            meta_graph: self.meta_graph,
420            #[cfg(feature = "meta")]
421            diagnostics: self.diagnostics,
422        }
423    }
424}