// File: dfir_rs/scheduled/context.rs

1//! Module for the inline DFIR runtime context and execution engine.
2//!
3//! Provides [`Context`] (the lightweight operator context) and
4//! [`Dfir`] (the dataflow execution wrapper).
5
6use std::future::Future;
7use std::pin::Pin;
8use std::rc::Rc;
9use std::sync::Arc;
10use std::sync::atomic::Ordering;
11use std::task::Wake;
12
13#[cfg(feature = "meta")]
14use dfir_lang::diagnostic::{Diagnostic, Diagnostics, SerdeSpan};
15#[cfg(feature = "meta")]
16use dfir_lang::graph::DfirGraph;
17
18use super::metrics::{DfirMetrics, DfirMetricsIntervals};
19use crate::scheduled::ticks::TickInstant;
20
/// Coordinates waking between [`Context`] (inside the tick closure) and [`Dfir`]
/// (the external runner). Shared via `Arc` between both.
///
/// When external data arrives (e.g., a tokio stream receives a message), the [`Context::waker`]
/// fires, which sets `can_start_tick` and wakes the [`Dfir::run`](Dfir::run) task so it starts a new tick.
/// Implements [`Wake`] directly so it can be used as a `Waker` without an extra wrapper.
#[doc(hidden)]
pub struct WakeState {
    /// Set to `true` when external data arrives, signaling that a new tick should run.
    /// Checked by [`Dfir::run_tick`](Dfir::run_tick) and [`Dfir::run_available`](Dfir::run_available).
    ///
    /// Accessed with `Relaxed` ordering throughout: `task_waker`
    /// ([`futures::task::AtomicWaker`]) supplies the happens-before edge between
    /// `wake` and `register`, per that type's documented usage pattern.
    can_start_tick: std::sync::atomic::AtomicBool,
    /// Wakes the [`Dfir::run`](Dfir::run) task from its idle `poll_fn` sleep.
    task_waker: futures::task::AtomicWaker,
}
35
36impl Default for WakeState {
37    fn default() -> Self {
38        Self {
39            can_start_tick: std::sync::atomic::AtomicBool::new(false),
40            task_waker: futures::task::AtomicWaker::new(),
41        }
42    }
43}
44
impl Wake for WakeState {
    /// Consuming wake: no state needs to be moved, so delegate to [`Self::wake_by_ref`].
    fn wake(self: Arc<Self>) {
        self.wake_by_ref();
    }

    /// Flags that a new tick should start, then wakes the idle [`Dfir::run`] task (if parked).
    ///
    /// Order matters: the flag is set *before* calling `task_waker.wake()` so the
    /// woken task is guaranteed to observe it. `Relaxed` is sound here because
    /// `AtomicWaker::wake`/`register` provide the required synchronization — this
    /// is the usage pattern documented on [`futures::task::AtomicWaker`].
    fn wake_by_ref(self: &Arc<Self>) {
        self.can_start_tick.store(true, Ordering::Relaxed);
        self.task_waker.wake();
    }
}
55
/// A lightweight context for inline codegen that avoids the overhead of the full
/// [`Context`] (no tokio channels, no scheduler queues, no loop machinery).
///
/// Exposes method names that operator-generated code calls on
/// `context` (for iterators: `is_first_run_this_tick`, `current_tick`, etc.).
#[doc(hidden)]
#[derive(Default)]
pub struct Context {
    /// Counter for number of ticks run. Incremented only by [`Self::__end_tick`].
    current_tick: TickInstant,
    /// Coordinates waking between [`Context`] (inside the tick closure) and [`Dfir`]
    /// (the external runner). Shared via `Arc` between both. Implements [`Wake`].
    wake_state: Arc<WakeState>,
    /// Live-updating DFIR runtime metrics via interior mutability.
    metrics: Rc<DfirMetrics>,
    /// Tasks buffered via [`Self::request_task`], spawned by [`Dfir::spawn_tasks`]
    /// once the runtime is running inside a tokio `LocalSet`.
    tasks_to_spawn: Vec<Pin<Box<dyn Future<Output = ()> + 'static>>>,
}
75
76impl Context {
77    /// Create a new inline context with shared wake state and metrics.
78    pub fn new(wake_state: Arc<WakeState>, metrics: Rc<DfirMetrics>) -> Self {
79        Self {
80            current_tick: TickInstant::default(),
81            wake_state,
82            metrics,
83            tasks_to_spawn: Vec::new(),
84        }
85    }
86
87    // --- Methods called as `df.xxx()` in operator prologues ---
88
89    /// Buffers an async task to be spawned later by [`Dfir::spawn_tasks`].
90    ///
91    /// Tasks are deferred because `write_prologue` runs during graph construction,
92    /// which may occur before a tokio `LocalSet` is entered. Buffered tasks are
93    /// drained and spawned via `tokio::task::spawn_local` at the start of
94    /// [`Dfir::run_tick`]. Tasks requested after that point remain buffered until
95    /// the next call to [`Dfir::run_tick`].
96    pub fn request_task<Fut>(&mut self, future: Fut)
97    where
98        Fut: Future<Output = ()> + 'static,
99    {
100        self.tasks_to_spawn.push(Box::pin(future));
101    }
102
103    // --- Methods called as `context.xxx()` in operator iterators ---
104
105    /// Always returns `true` in inline mode. The inline codegen runs the entire DAG
106    /// once per tick with no re-execution, so every subgraph is always on its first
107    /// (and only) run within each tick.
108    pub fn is_first_run_this_tick(&self) -> bool {
109        true
110    }
111
112    /// Gets the current tick count.
113    pub fn current_tick(&self) -> TickInstant {
114        self.current_tick
115    }
116
117    /// Returns a reference to the runtime metrics.
118    pub fn metrics(&self) -> &Rc<DfirMetrics> {
119        &self.metrics
120    }
121
122    /// Signals that external data has arrived and a new tick should be started.
123    pub fn schedule_subgraph(&self, is_external: bool) {
124        if is_external {
125            self.wake_state.wake_by_ref();
126        }
127    }
128
129    /// Returns a waker that signals external data has arrived.
130    pub fn waker(&self) -> std::task::Waker {
131        std::task::Waker::from(self.wake_state.clone())
132    }
133
134    /// Increments the tick counter.
135    /// Called by the generated tick closure at the end of each tick.
136    #[doc(hidden)]
137    pub fn __end_tick(&mut self) {
138        self.current_tick += crate::scheduled::ticks::TickDuration::SINGLE_TICK;
139    }
140}
141
/// A wrapper around an inline-codegen tick closure that provides [`Self::run`],
/// [`Self::run_available`], and [`Self::run_tick`] methods — mirroring the [`Dfir`](super::context::Dfir)
/// API.
///
/// # Design
///
/// The inline codegen generates an `async move |df: &mut Context|` closure that captures
/// dataflow-specific state (handoff buffers, source iterators) and receives the [`Context`]
/// (tick counter, metrics) by reference each tick. `Dfir` owns both the
/// closure and the context, and coordinates tick lifecycle and idle/wake behavior.
///
/// We use a single opaque closure rather than generating a bespoke struct per dataflow because:
/// - The closure naturally captures exactly the state it needs with correct lifetimes
/// - No codegen needed for struct definitions, field accessors, or initialization
/// - Rust's async closure machinery handles the complex state machine (suspend/resume across
///   `.await` points) that would be very difficult to replicate in a generated struct
///
/// The `Tick` type parameter is bounded by [`TickClosure`] (not `AsyncFnMut` directly) to
/// support type erasure via [`TickClosureErased`] / [`DfirErased`] for heterogeneous
/// collections (e.g., the sim runtime storing multiple locations in a `Vec`). The concrete
/// (non-erased) path used by trybuild and embedded has zero overhead.
#[doc(hidden)]
pub struct Dfir<Tick> {
    /// Async closure which runs a single tick when called.
    tick_closure: Tick,
    /// Coordinates waking between [`Context`] (inside the tick closure) and [`Dfir`]
    /// (the external runner). Shared via `Arc` between both. Implements [`Wake`].
    /// Same allocation as `context.wake_state` (cloned in [`Self::new`]).
    wake_state: Arc<WakeState>,
    /// The inline context, owned by `Dfir` and passed to the tick closure by reference.
    context: Context,
    /// See [`Self::meta_graph()`].
    #[cfg(feature = "meta")]
    meta_graph: Option<DfirGraph>,
    /// See [`Self::diagnostics()`].
    #[cfg(feature = "meta")]
    diagnostics: Option<Vec<Diagnostic<SerdeSpan>>>,
}
179
/// Trait for tick closures — abstracts over both concrete async closures
/// and type-erased boxed versions ([`TickClosureErased`]).
///
/// The `&mut Context` parameter is owned by [`Dfir`] and lent to the
/// closure each tick, avoiding shared-ownership overhead for the context.
#[doc(hidden)]
pub trait TickClosure {
    /// Call the tick closure. Returns `true` if any subgraph received input data.
    ///
    /// The returned future borrows both the closure and the context for `'a`,
    /// so it must be awaited before the next tick can start.
    fn call_tick<'a>(&'a mut self, ctx: &'a mut Context) -> impl Future<Output = bool> + 'a;
}
190
/// Blanket impl: any `AsyncFnMut(&mut Context) -> bool` closure is a [`TickClosure`].
/// This is the zero-overhead path used by the concrete (non-erased) codegen.
impl<F: for<'a> AsyncFnMut(&'a mut Context) -> bool> TickClosure for F {
    fn call_tick<'a>(&'a mut self, ctx: &'a mut Context) -> impl Future<Output = bool> + 'a {
        self(ctx)
    }
}
196
197/// No-op `TickClosure`.
198#[doc(hidden)]
199pub struct NullTickClosure;
200
201impl TickClosure for NullTickClosure {
202    fn call_tick<'a>(&'a mut self, _ctx: &'a mut Context) -> impl Future<Output = bool> + 'a {
203        std::future::ready(false)
204    }
205}
206
/// Type-erased tick function for use in heterogeneous collections (e.g., the sim runtime).
/// Wraps a boxed [`TickClosureErasedInner`] trait object; construct via [`Dfir::into_erased`].
#[doc(hidden)]
pub struct TickClosureErased(Box<dyn TickClosureErasedInner>);
210
/// Object-safe inner trait for [`TickClosureErased`]. Needed because `AsyncFnMut` is not
/// object-safe (GAT return type), but a trait with `&mut self -> Pin<Box<dyn Future + '_>>`
/// is — the returned future borrows from the trait object which owns the closure.
trait TickClosureErasedInner {
    /// Run one tick, boxing the returned future (one heap allocation per call).
    fn call_tick<'a>(
        &'a mut self,
        ctx: &'a mut Context,
    ) -> Pin<Box<dyn Future<Output = bool> + 'a>>;
}
220
221impl<F: for<'a> AsyncFnMut(&'a mut Context) -> bool> TickClosureErasedInner for F {
222    fn call_tick<'a>(
223        &'a mut self,
224        ctx: &'a mut Context,
225    ) -> Pin<Box<dyn Future<Output = bool> + 'a>> {
226        Box::pin(self(ctx))
227    }
228}
229
/// Lets a [`TickClosureErased`] be used anywhere a [`TickClosure`] is expected,
/// so [`Dfir`] code is generic over both the concrete and erased paths.
impl TickClosure for TickClosureErased {
    fn call_tick<'a>(&'a mut self, ctx: &'a mut Context) -> impl Future<Output = bool> + 'a {
        // Delegates to the boxed trait object; the boxed future satisfies the
        // `impl Future` return type directly.
        self.0.call_tick(ctx)
    }
}
235
/// Type alias for a type-erased [`Dfir`] that can be stored in heterogeneous collections.
/// Created via [`Dfir::into_erased`].
pub type DfirErased = Dfir<TickClosureErased>;
239
impl<Tick: TickClosure> Dfir<Tick> {
    /// Create a new `Dfir` from a tick closure, inline context,
    /// and meta graph / diagnostics JSON strings.
    ///
    /// # Panics
    ///
    /// With the `meta` feature enabled, panics if `meta_graph_json` or
    /// `diagnostics_json` is `Some` but fails to deserialize, or if inserting
    /// operator instances into the graph yields diagnostics. The JSON is
    /// macro-generated, so failure indicates a codegen bug rather than user error.
    #[doc(hidden)]
    pub fn new(
        tick_closure: Tick,
        context: Context,
        meta_graph_json: Option<&str>,
        diagnostics_json: Option<&str>,
    ) -> Self {
        // Without `meta`, the JSON arguments are accepted (stable signature) but unused.
        #[cfg(not(feature = "meta"))]
        let _ = (meta_graph_json, diagnostics_json);
        Self {
            tick_closure,
            // Clone the `Arc` so `Dfir` can check/clear the wake flag directly,
            // while `context` (moved in below) keeps its own handle for wakers.
            wake_state: context.wake_state.clone(),
            context,
            #[cfg(feature = "meta")]
            meta_graph: meta_graph_json.map(|json| {
                let mut meta_graph: DfirGraph =
                    serde_json::from_str(json).expect("Failed to deserialize graph.");
                let mut op_inst_diagnostics = Diagnostics::new();
                meta_graph.insert_node_op_insts_all(&mut op_inst_diagnostics);
                assert!(
                    op_inst_diagnostics.is_empty(),
                    "Expected no diagnostics, got: {:#?}",
                    op_inst_diagnostics
                );
                meta_graph
            }),
            #[cfg(feature = "meta")]
            diagnostics: diagnostics_json.map(|json| {
                serde_json::from_str(json).expect("Failed to deserialize diagnostics.")
            }),
        }
    }

    /// Return a handle to the meta graph, if set.
    #[cfg(feature = "meta")]
    #[cfg_attr(docsrs, doc(cfg(feature = "meta")))]
    pub fn meta_graph(&self) -> Option<&DfirGraph> {
        self.meta_graph.as_ref()
    }

    /// Returns any diagnostics generated by the surface syntax macro.
    #[cfg(feature = "meta")]
    #[cfg_attr(docsrs, doc(cfg(feature = "meta")))]
    pub fn diagnostics(&self) -> Option<&[Diagnostic<SerdeSpan>]> {
        self.diagnostics.as_deref()
    }

    /// Returns a reference-counted handle to the continually-updated runtime metrics for this DFIR instance.
    pub fn metrics(&self) -> Rc<DfirMetrics> {
        Rc::clone(self.context.metrics())
    }

    /// Gets the current tick (local time) count.
    pub fn current_tick(&self) -> TickInstant {
        self.context.current_tick()
    }

    /// Returns a [`DfirMetricsIntervals`] handle where each call to
    /// [`DfirMetricsIntervals::take_interval`] ends the current interval and returns its metrics.
    ///
    /// The first call to `take_interval` returns metrics since this DFIR instance was created. Each subsequent call to
    /// `take_interval` returns metrics since the previous call.
    ///
    /// Cloning the handle "forks" it from the original, as afterwards each interval may return different metrics
    /// depending on when exactly `take_interval` is called.
    pub fn metrics_intervals(&self) -> DfirMetricsIntervals {
        DfirMetricsIntervals {
            curr: self.metrics(),
            // `prev: None` means the first interval spans from instance creation.
            prev: None,
        }
    }
}
315
impl<Tick: TickClosure> Dfir<Tick> {
    /// Spawns all tasks buffered via [`Context::request_task`].
    ///
    /// This drains the buffer, so subsequent calls are no-ops until new tasks are requested.
    ///
    /// NOTE(review): `tokio::task::spawn_local` panics when called outside a tokio
    /// `LocalSet`/local runtime — callers of the async run methods must provide one.
    fn spawn_tasks(&mut self) {
        for task in self.context.tasks_to_spawn.drain(..) {
            tokio::task::spawn_local(task);
        }
    }

    /// Run a single tick. Returns `true` if more work may be pending: an external
    /// event was consumed, any subgraph received input data, or a new external
    /// event arrived while the tick was running.
    ///
    /// Checks both handoff buffers (via `work_done` flag set in generated recv port code)
    /// and external events (via `can_start_tick` set by wakers/schedule_subgraph).
    pub async fn run_tick(&mut self) -> bool {
        self.spawn_tasks();
        // Atomically consume the "external event" flag. Events firing *during*
        // the tick re-set it and are caught by the trailing `load` below.
        let had_external = self
            .wake_state
            .can_start_tick
            .swap(false, Ordering::Relaxed);
        let tick_had_work = self.tick_closure.call_tick(&mut self.context).await;
        had_external || tick_had_work || self.wake_state.can_start_tick.load(Ordering::Relaxed)
    }

    /// Run a single tick synchronously. Panics if the tick yields (async suspension).
    /// Returns `true` if work was done (see [`Self::run_tick`]).
    pub fn run_tick_sync(&mut self) -> bool {
        let mut fut = std::pin::pin!(self.run_tick());
        // Poll exactly once with a no-op waker: a purely-synchronous tick must
        // complete on the first poll; anything else is a usage error.
        let mut ctx = std::task::Context::from_waker(std::task::Waker::noop());
        match fut.as_mut().poll(&mut ctx) {
            std::task::Poll::Ready(result) => result,
            std::task::Poll::Pending => {
                panic!("Dfir::run_tick_sync: tick yielded asynchronously.")
            }
        }
    }

    /// Run ticks as long as work is available, then return.
    pub async fn run_available(&mut self) {
        // Always run at least one tick.
        self.wake_state
            .can_start_tick
            .store(false, Ordering::Relaxed);
        loop {
            // Return value intentionally ignored: continuation is decided solely
            // by consuming the `can_start_tick` flag below.
            self.run_tick().await;
            let can_start_tick = self
                .wake_state
                .can_start_tick
                .swap(false, Ordering::Relaxed);
            if !can_start_tick {
                break;
            }
            // Yield between each tick to receive more events.
            tokio::task::yield_now().await;
        }
    }

    /// [`Self::run_available`] but panics if any tick yields asynchronously.
    pub fn run_available_sync(&mut self) {
        // Always run at least one tick (mirrors `run_available`, minus yielding).
        self.wake_state
            .can_start_tick
            .store(false, Ordering::Relaxed);
        loop {
            self.run_tick_sync();
            let can_start_tick = self
                .wake_state
                .can_start_tick
                .swap(false, Ordering::Relaxed);
            if !can_start_tick {
                break;
            }
        }
    }

    /// Run forever, processing ticks when work is available and yielding when idle.
    pub async fn run(&mut self) -> crate::Never {
        loop {
            self.run_available().await;
            // Wait for an external event to wake us.
            std::future::poll_fn(|cx| {
                // Register waker first to avoid race: if an event fires between
                // the check and the register, the waker is already in place.
                // (This register-then-check order is the documented AtomicWaker pattern.)
                self.wake_state.task_waker.register(cx.waker());
                if self.wake_state.can_start_tick.load(Ordering::Relaxed) {
                    std::task::Poll::Ready(())
                } else {
                    std::task::Poll::Pending
                }
            })
            .await;
        }
    }
}
409
impl<Tick: 'static + for<'a> AsyncFnMut(&'a mut Context) -> bool> Dfir<Tick> {
    /// Type-erase the tick closure for use in heterogeneous collections.
    ///
    /// Wraps the concrete async closure in [`TickClosureErased`], which boxes the future
    /// returned by each tick call. This adds one heap allocation per tick, but enables
    /// storing multiple `Dfir`s with different closure types in a single `Vec`.
    ///
    /// Only needed for the sim runtime path. The trybuild and embedded paths keep the
    /// concrete type and pay no erasure cost.
    pub fn into_erased(self) -> DfirErased {
        // Every field moves across unchanged; only the closure type is erased.
        Dfir {
            tick_closure: TickClosureErased(Box::new(self.tick_closure)),
            wake_state: self.wake_state,
            context: self.context,
            #[cfg(feature = "meta")]
            meta_graph: self.meta_graph,
            #[cfg(feature = "meta")]
            diagnostics: self.diagnostics,
        }
    }
}