dfir_rs/scheduled/
graph.rs

//! Module for the [`Dfir`] struct and helper items.

use std::any::Any;
use std::borrow::Cow;
use std::cell::Cell;
use std::cmp::Ordering;
use std::future::Future;
use std::marker::PhantomData;

#[cfg(feature = "meta")]
use dfir_lang::diagnostic::{Diagnostic, SerdeSpan};
#[cfg(feature = "meta")]
use dfir_lang::graph::DfirGraph;
use ref_cast::RefCast;
use smallvec::SmallVec;
use web_time::SystemTime;

use super::context::Context;
use super::handoff::handoff_list::PortList;
use super::handoff::{Handoff, HandoffMeta, TeeingHandoff};
use super::port::{RECV, RecvCtx, RecvPort, SEND, SendCtx, SendPort};
use super::reactor::Reactor;
use super::state::StateHandle;
use super::subgraph::Subgraph;
use super::{HandoffId, HandoffTag, LoopId, LoopTag, SubgraphId, SubgraphTag};
use crate::Never;
use crate::scheduled::ticks::{TickDuration, TickInstant};
use crate::util::slot_vec::{SecondarySlotVec, SlotVec};

/// A DFIR graph. Owns, schedules, and runs the compiled subgraphs.
#[derive(Default)]
pub struct Dfir<'a> {
    pub(super) subgraphs: SlotVec<SubgraphTag, SubgraphData<'a>>,

    pub(super) loop_data: SecondarySlotVec<LoopTag, LoopData>,

    pub(super) context: Context,

    handoffs: SlotVec<HandoffTag, HandoffData>,

    #[cfg(feature = "meta")]
    /// See [`Self::meta_graph()`].
    meta_graph: Option<DfirGraph>,

    #[cfg(feature = "meta")]
    /// See [`Self::diagnostics()`].
    diagnostics: Option<Vec<Diagnostic<SerdeSpan>>>,
}

/// Methods for [`TeeingHandoff`] teeing and dropping.
impl Dfir<'_> {
    /// Tees a [`TeeingHandoff`].
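    ///
    /// A minimal sketch (untested; assumes a `TeeingHandoff` edge created with
    /// [`Self::make_edge`]):
    ///
    /// ```rust,ignore
    /// use dfir_rs::scheduled::graph::Dfir;
    /// use dfir_rs::scheduled::handoff::TeeingHandoff;
    ///
    /// let mut df = Dfir::new();
    /// let (send_port, recv_a) = df.make_edge::<_, TeeingHandoff<u32>>("tee");
    /// let recv_b = df.teeing_handoff_tee(&recv_a);
    /// // Items sent via `send_port` are now delivered to both `recv_a` and `recv_b`.
    /// ```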
    pub fn teeing_handoff_tee<T>(
        &mut self,
        tee_parent_port: &RecvPort<TeeingHandoff<T>>,
    ) -> RecvPort<TeeingHandoff<T>>
    where
        T: Clone,
    {
        // If we're teeing from a child, make sure to find the root.
        let tee_root = self.handoffs[tee_parent_port.handoff_id].pred_handoffs[0];

        // Set up teeing metadata.
        let tee_root_data = &mut self.handoffs[tee_root];
        let tee_root_data_name = tee_root_data.name.clone();

        // Insert new handoff output.
        let teeing_handoff =
            <dyn Any>::downcast_ref::<TeeingHandoff<T>>(&*tee_root_data.handoff).unwrap();
        let new_handoff = teeing_handoff.tee();

        // Handoff ID of new tee output.
        let new_hoff_id = self.handoffs.insert_with_key(|new_hoff_id| {
            let new_name = Cow::Owned(format!("{} tee {:?}", tee_root_data_name, new_hoff_id));
            let mut new_handoff_data = HandoffData::new(new_name, new_handoff, new_hoff_id);
            // Set self's predecessor as `tee_root`.
            new_handoff_data.pred_handoffs = vec![tee_root];
            new_handoff_data
        });

        // Go to `tee_root`'s successors and insert self (the new tee output).
        let tee_root_data = &mut self.handoffs[tee_root];
        tee_root_data.succ_handoffs.push(new_hoff_id);

        // Add our new handoff id into the subgraph data if the send side of `tee_root` has
        // already been used to add a subgraph.
        assert!(
            tee_root_data.preds.len() <= 1,
            "Tee send side should only have one sender (or none set yet)."
        );
        if let Some(&pred_sg_id) = tee_root_data.preds.first() {
            self.subgraphs[pred_sg_id].succs.push(new_hoff_id);
        }

        let output_port = RecvPort {
            handoff_id: new_hoff_id,
            _marker: PhantomData,
        };
        output_port
    }

    /// Marks an output of a [`TeeingHandoff`] as dropped so that no more data will be sent to it.
    ///
    /// It is recommended to not use this method and instead simply avoid teeing a
    /// [`TeeingHandoff`] when it is not needed.
    pub fn teeing_handoff_drop<T>(&mut self, tee_port: RecvPort<TeeingHandoff<T>>)
    where
        T: Clone,
    {
        let data = &self.handoffs[tee_port.handoff_id];
        let teeing_handoff = <dyn Any>::downcast_ref::<TeeingHandoff<T>>(&*data.handoff).unwrap();
        teeing_handoff.drop();

        let tee_root = data.pred_handoffs[0];
        let tee_root_data = &mut self.handoffs[tee_root];
        // Remove this output from the send succ handoff list.
        tee_root_data
            .succ_handoffs
            .retain(|&succ_hoff| succ_hoff != tee_port.handoff_id);
        // Remove from subgraph successors if send port was already connected.
        assert!(
            tee_root_data.preds.len() <= 1,
            "Tee send side should only have one sender (or none set yet)."
        );
        if let Some(&pred_sg_id) = tee_root_data.preds.first() {
            self.subgraphs[pred_sg_id]
                .succs
                .retain(|&succ_hoff| succ_hoff != tee_port.handoff_id);
        }
    }
}

impl<'a> Dfir<'a> {
    /// Create a new empty graph.
    pub fn new() -> Self {
        Default::default()
    }

    /// Assign the meta graph via JSON string. Used internally by the [`dfir_syntax`] macro and others.
    #[doc(hidden)]
    pub fn __assign_meta_graph(&mut self, _meta_graph_json: &str) {
        #[cfg(feature = "meta")]
        {
            let mut meta_graph: DfirGraph =
                serde_json::from_str(_meta_graph_json).expect("Failed to deserialize graph.");

            let mut op_inst_diagnostics = Vec::new();
            meta_graph.insert_node_op_insts_all(&mut op_inst_diagnostics);
            assert!(
                op_inst_diagnostics.is_empty(),
                "Expected no diagnostics, got: {:#?}",
                op_inst_diagnostics
            );

            assert!(self.meta_graph.replace(meta_graph).is_none());
        }
    }
    /// Assign the diagnostics via JSON string.
    #[doc(hidden)]
    pub fn __assign_diagnostics(&mut self, _diagnostics_json: &'static str) {
        #[cfg(feature = "meta")]
        {
            let diagnostics: Vec<Diagnostic<SerdeSpan>> = serde_json::from_str(_diagnostics_json)
                .expect("Failed to deserialize diagnostics.");

            assert!(self.diagnostics.replace(diagnostics).is_none());
        }
    }

    /// Return a handle to the meta graph, if set. The meta graph is a
    /// representation of all the operators, subgraphs, and handoffs in this instance.
    /// Will only be set if this graph was constructed using a surface syntax macro.
    #[cfg(feature = "meta")]
    #[cfg_attr(docsrs, doc(cfg(feature = "meta")))]
    pub fn meta_graph(&self) -> Option<&DfirGraph> {
        self.meta_graph.as_ref()
    }

    /// Returns any diagnostics generated by the surface syntax macro. Each diagnostic is a pair of
    /// (1) a `Diagnostic` with span info reset and (2) the `ToString` version of the diagnostic
    /// with original span info.
    /// Will only be set if this graph was constructed using a surface syntax macro.
    #[cfg(feature = "meta")]
    #[cfg_attr(docsrs, doc(cfg(feature = "meta")))]
    pub fn diagnostics(&self) -> Option<&[Diagnostic<SerdeSpan>]> {
        self.diagnostics.as_deref()
    }

    /// Returns a reactor for externally scheduling subgraphs, possibly from another thread.
    /// Reactor events are considered to be external events.
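    ///
    /// A minimal sketch (untested; assumes `Reactor` exposes a `trigger` method taking the
    /// [`SubgraphId`] to schedule):
    ///
    /// ```rust,ignore
    /// let reactor = df.reactor();
    /// // `sg_id` is a `SubgraphId` obtained when the subgraph was added.
    /// std::thread::spawn(move || {
    ///     // Wake the graph from another thread; this counts as an external event.
    ///     reactor.trigger(sg_id).expect("event queue closed");
    /// });
    /// ```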
    pub fn reactor(&self) -> Reactor {
        Reactor::new(self.context.event_queue_send.clone())
    }

    /// Gets the current tick (local time) count.
    pub fn current_tick(&self) -> TickInstant {
        self.context.current_tick
    }

    /// Gets the current stratum number.
    pub fn current_stratum(&self) -> usize {
        self.context.current_stratum
    }

    /// Runs the dataflow until the next tick begins.
    ///
    /// Returns `true` if any work was done.
    ///
    /// Will receive events if it is the start of a tick when called.
    #[tracing::instrument(level = "trace", skip(self), ret)]
    pub async fn run_tick(&mut self) -> bool {
        let mut work_done = false;
        // While work is immediately available *on the current tick*.
        while self.next_stratum(true) {
            work_done = true;
            // Do any work.
            self.run_stratum().await;
        }
        work_done
    }

    /// Runs the dataflow until no more (externally-triggered) work is immediately available.
    /// Runs at least one tick of dataflow, even if no external events have been received.
    /// If the dataflow contains loops this method may run forever.
    ///
    /// Returns `true` if any work was done.
    ///
    /// Yields repeatedly to allow external events to happen.
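    ///
    /// A minimal sketch of driving a graph built with the surface syntax (untested; must be
    /// called from an async context):
    ///
    /// ```rust,ignore
    /// let mut df = dfir_rs::dfir_syntax! {
    ///     source_iter(0..3) -> for_each(|n| println!("{}", n));
    /// };
    /// df.run_available().await; // Prints 0, 1, 2, then returns `true`.
    /// ```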
    #[tracing::instrument(level = "trace", skip(self), ret)]
    pub async fn run_available(&mut self) -> bool {
        let mut work_done = false;
        // While work is immediately available.
        while self.next_stratum(false) {
            work_done = true;
            // Do any work.
            self.run_stratum().await;

            // Yield between each stratum to receive more events.
            // TODO(mingwei): really only need to yield at start of ticks though.
            tokio::task::yield_now().await;
        }
        work_done
    }

    /// Runs the current stratum of the dataflow until no more local work is available (does not receive events).
    ///
    /// Returns `true` if any work was done.
    #[tracing::instrument(level = "trace", skip(self), fields(tick = u64::from(self.context.current_tick), stratum = self.context.current_stratum), ret)]
    pub async fn run_stratum(&mut self) -> bool {
        // Make sure to spawn tasks once dfir is running!
        // This drains the task buffer, so becomes a no-op after first call.
        self.context.spawn_tasks();

        let mut work_done = false;

        'pop: while let Some(sg_id) =
            self.context.stratum_queues[self.context.current_stratum].pop_front()
        {
            {
                let sg_data = &mut self.subgraphs[sg_id];
                // This must be true for the subgraph to be enqueued.
                assert!(sg_data.is_scheduled.take());

                let _enter = tracing::info_span!(
                    "run-subgraph",
                    sg_id = sg_id.to_string(),
                    sg_name = &*sg_data.name,
                    sg_depth = sg_data.loop_depth,
                    sg_loop_nonce = sg_data.last_loop_nonce.0,
                    sg_iter_count = sg_data.last_loop_nonce.1,
                )
                .entered();

                match sg_data.loop_depth.cmp(&self.context.loop_nonce_stack.len()) {
                    Ordering::Greater => {
                        // We have entered a loop.
                        self.context.loop_nonce += 1;
                        self.context.loop_nonce_stack.push(self.context.loop_nonce);
                        tracing::trace!(loop_nonce = self.context.loop_nonce, "Entered loop.");
                    }
                    Ordering::Less => {
                        // We have exited a loop.
                        self.context.loop_nonce_stack.pop();
                        tracing::trace!("Exited loop.");
                    }
                    Ordering::Equal => {}
                }

                self.context.subgraph_id = sg_id;
                self.context.is_first_run_this_tick = sg_data
                    .last_tick_run_in
                    .is_none_or(|last_tick| last_tick < self.context.current_tick);

                if let Some(loop_id) = sg_data.loop_id {
                    // Loop execution - running a loop block from start to finish, containing
                    // multiple iterations.
                    // Loop iteration - a single iteration of a loop block, in which all
                    // subgraphs within the loop should run (at most) once.

                    // If the previous run of this subgraph had the same loop execution and
                    // iteration count, then we need to increment the iteration count.
                    let curr_loop_nonce = self.context.loop_nonce_stack.last().copied();

                    let LoopData {
                        iter_count: loop_iter_count,
                        allow_another_iteration,
                    } = &mut self.loop_data[loop_id];

                    let (prev_loop_nonce, prev_iter_count) = sg_data.last_loop_nonce;

                    // If the loop nonce is the same as the previous execution, then we are in
                    // the same loop execution.
                    // `curr_loop_nonce` is `None` for top-level loops, and top-level loops are
                    // always in the same (singular) loop execution.
                    let (curr_iter_count, new_loop_execution) =
                        if curr_loop_nonce.is_none_or(|nonce| nonce == prev_loop_nonce) {
                            // If the iteration count is the same as the previous execution, then
                            // we are on the next iteration.
                            if *loop_iter_count == prev_iter_count {
                                // If another iteration was not requested, skip this subgraph.
                                if !std::mem::take(allow_another_iteration) {
                                    tracing::debug!(
                                        "Loop will not continue to next iteration, skipping."
                                    );
                                    continue 'pop;
                                }
                                // Increment `loop_iter_count` or set it to 0.
                                loop_iter_count.map_or((0, true), |n| (n + 1, false))
                            } else {
                                // Otherwise update the local iteration count to match the loop.
                                debug_assert!(
                                    prev_iter_count < *loop_iter_count,
                                    "Expect loop iteration count to be increasing."
                                );
                                (loop_iter_count.unwrap(), false)
                            }
                        } else {
                            // The loop execution has already begun, but this is the first time this particular subgraph is running.
                            (0, false)
                        };

                    if new_loop_execution {
                        // Run state hooks.
                        self.context.run_state_hooks_loop(loop_id);
                    }
                    tracing::debug!("Loop iteration count {}", curr_iter_count);

                    *loop_iter_count = Some(curr_iter_count);
                    self.context.loop_iter_count = curr_iter_count;
                    sg_data.last_loop_nonce =
                        (curr_loop_nonce.unwrap_or_default(), Some(curr_iter_count));
                }

                // Run subgraph state hooks.
                self.context.run_state_hooks_subgraph(sg_id);

                tracing::info!("Running subgraph.");
                sg_data.last_tick_run_in = Some(self.context.current_tick);
                Box::into_pin(sg_data.subgraph.run(&mut self.context, &mut self.handoffs)).await;
            };

            let sg_data = &self.subgraphs[sg_id];
            for &handoff_id in sg_data.succs.iter() {
                let handoff = &self.handoffs[handoff_id];
                if !handoff.handoff.is_bottom() {
                    for &succ_id in handoff.succs.iter() {
                        let succ_sg_data = &self.subgraphs[succ_id];
                        // If we have sent data to the next tick, then we can start the next tick.
                        if succ_sg_data.stratum < self.context.current_stratum && !sg_data.is_lazy {
                            self.context.can_start_tick = true;
                        }
                        // Add subgraph to stratum queue if it is not already scheduled.
                        if !succ_sg_data.is_scheduled.replace(true) {
                            self.context.stratum_queues[succ_sg_data.stratum].push_back(succ_id);
                        }
                        // Add stratum to stratum stack if it is within a loop.
                        if 0 < succ_sg_data.loop_depth {
                            // TODO(mingwei): handle duplicates
                            self.context
                                .stratum_stack
                                .push(succ_sg_data.loop_depth, succ_sg_data.stratum);
                        }
                    }
                }
            }

            let reschedule = self.context.reschedule_loop_block.take();
            let allow_another = self.context.allow_another_iteration.take();

            if reschedule {
                // Re-enqueue the subgraph.
                self.context.schedule_deferred.push(sg_id);
                self.context
                    .stratum_stack
                    .push(sg_data.loop_depth, sg_data.stratum);
            }
            if (reschedule || allow_another)
                && let Some(loop_id) = sg_data.loop_id
            {
                self.loop_data
                    .get_mut(loop_id)
                    .unwrap()
                    .allow_another_iteration = true;
            }

            work_done = true;
        }
        work_done
    }

    /// Go to the next stratum which has work available, possibly the current stratum.
    /// Returns `true` if more work is available, or `false` if no work is immediately
    /// available on any stratum.
    ///
    /// This will receive external events when at the start of a tick.
    ///
    /// If `current_tick_only` is set to `true`, will only return `true` if work is immediately
    /// available on the *current tick*.
    ///
    /// If this returns false then the graph will be at the start of a tick (at stratum 0, can
    /// receive more external events).
    #[tracing::instrument(level = "trace", skip(self), ret)]
    pub fn next_stratum(&mut self, current_tick_only: bool) -> bool {
        tracing::trace!(
            events_received_tick = self.context.events_received_tick,
            can_start_tick = self.context.can_start_tick,
            "Starting `next_stratum` call.",
        );

        // The stratum we will stop searching at, i.e. made a full loop around.
        let mut end_stratum = self.context.current_stratum;
        let mut new_tick_started = false;

        if 0 == self.context.current_stratum {
            new_tick_started = true;

            // Starting the tick, reset this to `false`.
            tracing::trace!("Starting tick, setting `can_start_tick = false`.");
            self.context.can_start_tick = false;
            self.context.current_tick_start = SystemTime::now();

            // Ensure external events are received before running the tick.
            if !self.context.events_received_tick {
                // Add any external jobs to ready queue.
                self.try_recv_events();
            }
        }

        loop {
            tracing::trace!(
                tick = u64::from(self.context.current_tick),
                stratum = self.context.current_stratum,
                "Looking for work on stratum."
            );
            // If current stratum has work, return true.
            if !self.context.stratum_queues[self.context.current_stratum].is_empty() {
                tracing::trace!(
                    tick = u64::from(self.context.current_tick),
                    stratum = self.context.current_stratum,
                    "Work found on stratum."
                );
                return true;
            }

            if let Some(next_stratum) = self.context.stratum_stack.pop() {
                self.context.current_stratum = next_stratum;

                // Now schedule deferred subgraphs.
                {
                    for sg_id in self.context.schedule_deferred.drain(..) {
                        let sg_data = &self.subgraphs[sg_id];
                        tracing::info!(
                            tick = u64::from(self.context.current_tick),
                            stratum = self.context.current_stratum,
                            sg_id = sg_id.to_string(),
                            sg_name = &*sg_data.name,
                            is_scheduled = sg_data.is_scheduled.get(),
                            "Rescheduling deferred subgraph."
                        );
                        if !sg_data.is_scheduled.replace(true) {
                            self.context.stratum_queues[sg_data.stratum].push_back(sg_id);
                        }
                    }
                }
            } else {
                // Increment stratum counter.
                self.context.current_stratum += 1;

                if self.context.current_stratum >= self.context.stratum_queues.len() {
                    new_tick_started = true;

                    tracing::trace!(
                        can_start_tick = self.context.can_start_tick,
                        "End of tick {}, starting tick {}.",
                        self.context.current_tick,
                        self.context.current_tick + TickDuration::SINGLE_TICK,
                    );
                    self.context.run_state_hooks_tick();

                    self.context.current_stratum = 0;
                    self.context.current_tick += TickDuration::SINGLE_TICK;
                    self.context.events_received_tick = false;

                    if current_tick_only {
                        tracing::trace!(
                            "`current_tick_only` is `true`, returning `false` before receiving events."
                        );
                        return false;
                    } else {
                        self.try_recv_events();
                        if std::mem::replace(&mut self.context.can_start_tick, false) {
                            tracing::trace!(
                                tick = u64::from(self.context.current_tick),
                                "`can_start_tick` is `true`, continuing."
                            );
                            // Do a full loop more to find where events have been added.
                            end_stratum = 0;
                            continue;
                        } else {
                            tracing::trace!(
                                "`can_start_tick` is `false`, re-setting `events_received_tick = false`, returning `false`."
                            );
                            self.context.events_received_tick = false;
                            return false;
                        }
                    }
                }
            }

            // After incrementing, exit if we made a full loop around the strata.
            if new_tick_started && end_stratum == self.context.current_stratum {
                tracing::trace!(
                    "Made full loop around stratum, re-setting `current_stratum = 0`, returning `false`."
                );
                // Note: if current stratum had work, the very first loop iteration would've
                // returned true. Therefore we can return false without checking.
                // Also means nothing was done so we can reset the stratum to zero and wait for
                // events.
                self.context.events_received_tick = false;
                self.context.current_stratum = 0;
                return false;
            }
        }
    }

    /// Runs the dataflow graph forever.
    ///
    /// TODO(mingwei): Currently blocks forever, no notion of "completion."
    #[tracing::instrument(level = "trace", skip(self), ret)]
    pub async fn run(&mut self) -> Option<Never> {
        loop {
            // Run any work which is immediately available.
            self.run_available().await;
            // When no work is available yield until more events occur.
            self.recv_events_async().await;
        }
    }

    /// Enqueues subgraphs triggered by events without blocking.
    ///
    /// Returns the number of subgraphs enqueued.
    #[tracing::instrument(level = "trace", skip(self), fields(events_received_tick = self.context.events_received_tick), ret)]
    pub fn try_recv_events(&mut self) -> usize {
        let mut enqueued_count = 0;
        while let Ok((sg_id, is_external)) = self.context.event_queue_recv.try_recv() {
            let sg_data = &self.subgraphs[sg_id];
            tracing::trace!(
                sg_id = sg_id.to_string(),
                is_external = is_external,
                sg_stratum = sg_data.stratum,
                "Event received."
            );
            if !sg_data.is_scheduled.replace(true) {
                self.context.stratum_queues[sg_data.stratum].push_back(sg_id);
                enqueued_count += 1;
            }
            if is_external {
                // The next tick is triggered if we are at the start of the next tick
                // (`!self.context.events_received_tick`), or if the stratum is in the next tick.
                if !self.context.events_received_tick
                    || sg_data.stratum < self.context.current_stratum
                {
                    tracing::trace!(
                        current_stratum = self.context.current_stratum,
                        sg_stratum = sg_data.stratum,
                        "External event, setting `can_start_tick = true`."
                    );
                    self.context.can_start_tick = true;
                }
            }
        }
        self.context.events_received_tick = true;

        enqueued_count
    }

    /// Enqueues subgraphs triggered by external events, blocking until at
    /// least one subgraph is scheduled **from an external event**.
    #[tracing::instrument(level = "trace", skip(self), fields(events_received_tick = self.context.events_received_tick), ret)]
    pub fn recv_events(&mut self) -> Option<usize> {
        let mut count = 0;
        loop {
            let (sg_id, is_external) = self.context.event_queue_recv.blocking_recv()?;
            let sg_data = &self.subgraphs[sg_id];
            tracing::trace!(
                sg_id = sg_id.to_string(),
                is_external = is_external,
                sg_stratum = sg_data.stratum,
                "Event received."
            );
            if !sg_data.is_scheduled.replace(true) {
                self.context.stratum_queues[sg_data.stratum].push_back(sg_id);
                count += 1;
            }
            if is_external {
                // The next tick is triggered if we are at the start of the next tick
                // (`!self.context.events_received_tick`), or if the stratum is in the next tick.
                if !self.context.events_received_tick
                    || sg_data.stratum < self.context.current_stratum
                {
                    tracing::trace!(
                        current_stratum = self.context.current_stratum,
                        sg_stratum = sg_data.stratum,
                        "External event, setting `can_start_tick = true`."
                    );
                    self.context.can_start_tick = true;
                }
                break;
            }
        }
        self.context.events_received_tick = true;

        // Enqueue any other immediate events.
        let extra_count = self.try_recv_events();
        Some(count + extra_count)
    }

    /// Enqueues subgraphs triggered by external events asynchronously, waiting until at least one
    /// subgraph is scheduled **from an external event**. Returns the number of subgraphs enqueued,
    /// which may be zero if an external event scheduled an already-scheduled subgraph.
    ///
    /// Returns `None` if the event queue is closed, but that should not happen normally.
    #[tracing::instrument(level = "trace", skip(self), fields(events_received_tick = self.context.events_received_tick), ret)]
    pub async fn recv_events_async(&mut self) -> Option<usize> {
        let mut count = 0;
        loop {
            tracing::trace!("Awaiting events (`event_queue_recv`).");
            let (sg_id, is_external) = self.context.event_queue_recv.recv().await?;
            let sg_data = &self.subgraphs[sg_id];
            tracing::trace!(
                sg_id = sg_id.to_string(),
                is_external = is_external,
                sg_stratum = sg_data.stratum,
                "Event received."
            );
            if !sg_data.is_scheduled.replace(true) {
                self.context.stratum_queues[sg_data.stratum].push_back(sg_id);
                count += 1;
            }
            if is_external {
                // The next tick is triggered if we are at the start of the next tick
                // (`!self.context.events_received_tick`), or if the stratum is in the next tick.
                if !self.context.events_received_tick
                    || sg_data.stratum < self.context.current_stratum
                {
                    tracing::trace!(
                        current_stratum = self.context.current_stratum,
                        sg_stratum = sg_data.stratum,
                        "External event, setting `can_start_tick = true`."
                    );
                    self.context.can_start_tick = true;
                }
                break;
            }
        }
        self.context.events_received_tick = true;

        // Enqueue any other immediate events.
        let extra_count = self.try_recv_events();
        Some(count + extra_count)
    }

    /// Schedules a subgraph to be run. See also: [`Context::schedule_subgraph`].
    pub fn schedule_subgraph(&mut self, sg_id: SubgraphId) -> bool {
        let sg_data = &self.subgraphs[sg_id];
        let already_scheduled = sg_data.is_scheduled.replace(true);
        if !already_scheduled {
            self.context.stratum_queues[sg_data.stratum].push_back(sg_id);
            true
        } else {
            false
        }
    }

    /// Adds a new compiled subgraph with the specified inputs and outputs in stratum 0.
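    ///
    /// A minimal sketch (untested; assumes `VecHandoff` and the `variadics` crate's
    /// `var_expr!`/`var_args!` macros for the port lists):
    ///
    /// ```rust,ignore
    /// use dfir_rs::scheduled::graph::Dfir;
    /// use dfir_rs::scheduled::handoff::VecHandoff;
    /// use variadics::{var_args, var_expr};
    ///
    /// let mut df = Dfir::new();
    /// let (send_port, recv_port) = df.make_edge::<_, VecHandoff<usize>>("edge");
    ///
    /// df.add_subgraph(
    ///     "source",
    ///     var_expr!(),          // No receive ports.
    ///     var_expr!(send_port), // One send port.
    ///     move |_context, var_args!(), var_args!(send)| {
    ///         send.give(Some(42));
    ///         async {} // Subgraphs return a future; nothing async to do here.
    ///     },
    /// );
    /// df.add_subgraph(
    ///     "sink",
    ///     var_expr!(recv_port),
    ///     var_expr!(),
    ///     move |_context, var_args!(recv), var_args!()| {
    ///         let items = recv.take_inner();
    ///         async move { println!("{:?}", items) }
    ///     },
    /// );
    /// ```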
    pub fn add_subgraph<Name, R, W, Func, Fut>(
        &mut self,
        name: Name,
        recv_ports: R,
        send_ports: W,
        subgraph: Func,
    ) -> SubgraphId
    where
        Name: Into<Cow<'static, str>>,
        R: 'static + PortList<RECV>,
        W: 'static + PortList<SEND>,
        Func: 'a + for<'ctx> FnMut(&'ctx mut Context, R::Ctx<'ctx>, W::Ctx<'ctx>) -> Fut,
        Fut: 'a + Future<Output = ()>,
    {
        self.add_subgraph_stratified(name, 0, recv_ports, send_ports, false, subgraph)
    }

    /// Adds a new compiled subgraph with the specified inputs, outputs, and stratum number.
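    ///
    /// A minimal sketch (untested), continuing the [`Self::add_subgraph`] example: the stratum
    /// number orders subgraphs within a tick, so a subgraph in stratum 1 runs after stratum 0.
    ///
    /// ```rust,ignore
    /// df.add_subgraph_stratified(
    ///     "sink-stratum-1",
    ///     1,                    // Stratum number.
    ///     var_expr!(recv_port),
    ///     var_expr!(),
    ///     false,                // Not lazy: writing to an earlier stratum triggers a new tick.
    ///     move |_context, var_args!(recv), var_args!()| {
    ///         let items = recv.take_inner();
    ///         async move { println!("{:?}", items) }
    ///     },
    /// );
    /// ```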
    pub fn add_subgraph_stratified<Name, R, W, Func, Fut>(
        &mut self,
        name: Name,
        stratum: usize,
        recv_ports: R,
        send_ports: W,
        laziness: bool,
        subgraph: Func,
    ) -> SubgraphId
    where
        Name: Into<Cow<'static, str>>,
        R: 'static + PortList<RECV>,
        W: 'static + PortList<SEND>,
        Func: 'a + for<'ctx> FnMut(&'ctx mut Context, R::Ctx<'ctx>, W::Ctx<'ctx>) -> Fut,
        Fut: 'a + Future<Output = ()>,
    {
        self.add_subgraph_full(
            name, stratum, recv_ports, send_ports, laziness, None, subgraph,
        )
    }

    /// Adds a new compiled subgraph with all options.
    #[expect(clippy::too_many_arguments, reason = "Mainly for internal use.")]
    pub fn add_subgraph_full<Name, R, W, Func, Fut>(
        &mut self,
        name: Name,
        stratum: usize,
        recv_ports: R,
        send_ports: W,
        laziness: bool,
        loop_id: Option<LoopId>,
        mut subgraph: Func,
    ) -> SubgraphId
    where
        Name: Into<Cow<'static, str>>,
        R: 'static + PortList<RECV>,
        W: 'static + PortList<SEND>,
        Func: 'a + for<'ctx> FnMut(&'ctx mut Context, R::Ctx<'ctx>, W::Ctx<'ctx>) -> Fut,
        Fut: 'a + Future<Output = ()>,
    {
        // SAFETY: Check that the send and recv ports are from `self.handoffs`.
        recv_ports.assert_is_from(&self.handoffs);
        send_ports.assert_is_from(&self.handoffs);

        let loop_depth = loop_id
            .and_then(|loop_id| self.context.loop_depth.get(loop_id))
            .copied()
            .unwrap_or(0);

        let sg_id = self.subgraphs.insert_with_key(|sg_id| {
            let (mut subgraph_preds, mut subgraph_succs) = Default::default();
            recv_ports.set_graph_meta(&mut self.handoffs, &mut subgraph_preds, sg_id, true);
            send_ports.set_graph_meta(&mut self.handoffs, &mut subgraph_succs, sg_id, false);

            let subgraph =
                move |context: &mut Context, handoffs: &mut SlotVec<HandoffTag, HandoffData>| {
                    let (recv, send) = unsafe {
                        // SAFETY:
                        // 1. We checked `assert_is_from` at assembly time, above.
                        // 2. `SlotVec` is insert-only so no handoffs could have changed since then.
                        (
                            recv_ports.make_ctx(&*handoffs),
                            send_ports.make_ctx(&*handoffs),
                        )
                    };
                    (subgraph)(context, recv, send)
                };
            SubgraphData::new(
                name.into(),
                stratum,
                subgraph,
                subgraph_preds,
                subgraph_succs,
                true,
                laziness,
                loop_id,
                loop_depth,
            )
        });
        self.context.init_stratum(stratum);
        self.context.stratum_queues[stratum].push_back(sg_id);

        sg_id
    }

    /// Adds a new compiled subgraph with a variable number of inputs and outputs of the same respective handoff types.
    pub fn add_subgraph_n_m<Name, R, W, Func, Fut>(
        &mut self,
        name: Name,
        recv_ports: Vec<RecvPort<R>>,
        send_ports: Vec<SendPort<W>>,
        subgraph: Func,
    ) -> SubgraphId
    where
        Name: Into<Cow<'static, str>>,
        R: 'static + Handoff,
        W: 'static + Handoff,
        Func: 'a
            + for<'ctx> FnMut(
                &'ctx mut Context,
                &'ctx [&'ctx RecvCtx<R>],
                &'ctx [&'ctx SendCtx<W>],
            ) -> Fut,
        Fut: 'a + Future<Output = ()>,
    {
        self.add_subgraph_stratified_n_m(name, 0, recv_ports, send_ports, subgraph)
    }

    /// Adds a new compiled subgraph with a variable number of inputs and outputs of the same respective handoff types.
    pub fn add_subgraph_stratified_n_m<Name, R, W, Func, Fut>(
        &mut self,
        name: Name,
        stratum: usize,
        recv_ports: Vec<RecvPort<R>>,
        send_ports: Vec<SendPort<W>>,
        mut subgraph: Func,
    ) -> SubgraphId
    where
        Name: Into<Cow<'static, str>>,
        R: 'static + Handoff,
        W: 'static + Handoff,
        Func: 'a
            + for<'ctx> FnMut(
                &'ctx mut Context,
                &'ctx [&'ctx RecvCtx<R>],
                &'ctx [&'ctx SendCtx<W>],
            ) -> Fut,
        Fut: 'a + Future<Output = ()>,
    {
        let sg_id = self.subgraphs.insert_with_key(|sg_id| {
            let subgraph_preds = recv_ports.iter().map(|port| port.handoff_id).collect();
            let subgraph_succs = send_ports.iter().map(|port| port.handoff_id).collect();

            for recv_port in recv_ports.iter() {
                self.handoffs[recv_port.handoff_id].succs.push(sg_id);
            }
            for send_port in send_ports.iter() {
                self.handoffs[send_port.handoff_id].preds.push(sg_id);
            }

            let subgraph =
                move |context: &mut Context, handoffs: &mut SlotVec<HandoffTag, HandoffData>| {
                    let recvs: Vec<&RecvCtx<R>> = recv_ports
                        .iter()
                        .map(|hid| hid.handoff_id)
                        .map(|hid| handoffs.get(hid).unwrap())
                        .map(|h_data| {
                            <dyn Any>::downcast_ref(&*h_data.handoff)
                                .expect("Attempted to cast handoff to wrong type.")
                        })
                        .map(RefCast::ref_cast)
                        .collect();

                    let sends: Vec<&SendCtx<W>> = send_ports
                        .iter()
                        .map(|hid| hid.handoff_id)
                        .map(|hid| handoffs.get(hid).unwrap())
                        .map(|h_data| {
                            <dyn Any>::downcast_ref(&*h_data.handoff)
                                .expect("Attempted to cast handoff to wrong type.")
                        })
                        .map(RefCast::ref_cast)
                        .collect();

                    (subgraph)(context, &recvs, &sends)
                };
            SubgraphData::new(
                name.into(),
                stratum,
                subgraph,
                subgraph_preds,
                subgraph_succs,
                true,
                false,
                None,
                0,
            )
        });

        self.context.init_stratum(stratum);
        self.context.stratum_queues[stratum].push_back(sg_id);

        sg_id
    }

    /// Creates a handoff edge and returns the corresponding send and receive ports.
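    ///
    /// A minimal sketch (untested; assumes `VecHandoff`):
    ///
    /// ```rust,ignore
    /// use dfir_rs::scheduled::handoff::VecHandoff;
    ///
    /// // `send_port` is passed to the upstream subgraph, `recv_port` to the downstream one.
    /// let (send_port, recv_port) = df.make_edge::<_, VecHandoff<String>>("my_edge");
    /// ```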
    pub fn make_edge<Name, H>(&mut self, name: Name) -> (SendPort<H>, RecvPort<H>)
    where
        Name: Into<Cow<'static, str>>,
        H: 'static + Handoff,
    {
        // Create and insert handoff.
        let handoff = H::default();
        let handoff_id = self
            .handoffs
            .insert_with_key(|hoff_id| HandoffData::new(name.into(), handoff, hoff_id));

        // Make ports.
        let input_port = SendPort {
            handoff_id,
            _marker: PhantomData,
        };
        let output_port = RecvPort {
            handoff_id,
            _marker: PhantomData,
        };
        (input_port, output_port)
    }

    /// Adds referenceable state into this instance. Returns a state handle which can be
    /// used externally or by operators to access the state.
    ///
    /// This is part of the "state API".
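    ///
    /// A minimal sketch (untested; assumes `Context` accessors along the lines of `state_ref`):
    ///
    /// ```rust,ignore
    /// use std::cell::RefCell;
    ///
    /// let mut df = Dfir::new();
    /// let handle = df.add_state(RefCell::new(Vec::<u32>::new()));
    /// // Later, inside a subgraph, operators access the state through the context, e.g.:
    /// // context.state_ref(handle).borrow_mut().push(1);
    /// ```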
    pub fn add_state<T>(&mut self, state: T) -> StateHandle<T>
    where
        T: Any,
    {
        self.context.add_state(state)
    }

    /// Sets a hook to modify the state at the end of the given lifespan (e.g. each tick), using
    /// the supplied closure.
    ///
    /// This is part of the "state API".
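    ///
    /// A minimal sketch (untested), continuing the [`Self::add_state`] example:
    ///
    /// ```rust,ignore
    /// // Clear the accumulated `Vec` at the end of every tick.
    /// df.set_state_lifespan_hook(handle, StateLifespan::Tick, |v| v.get_mut().clear());
    /// ```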
    pub fn set_state_lifespan_hook<T>(
        &mut self,
        handle: StateHandle<T>,
        lifespan: StateLifespan,
        hook_fn: impl 'static + FnMut(&mut T),
    ) where
        T: Any,
    {
        self.context
            .set_state_lifespan_hook(handle, lifespan, hook_fn)
    }

    /// Gets an exclusive (mut) ref to the internal context, setting the subgraph ID.
    pub fn context_mut(&mut self, sg_id: SubgraphId) -> &mut Context {
        self.context.subgraph_id = sg_id;
        &mut self.context
    }

    /// Adds a new loop with the given parent (or `None` for top-level). Returns a loop ID which
    /// is passed to [`Self::add_subgraph_full`] or used for nested loops.
    ///
    /// TODO(mingwei): add loop names to ensure traceability while debugging?
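    ///
    /// A minimal sketch of nesting (untested):
    ///
    /// ```rust,ignore
    /// let outer = df.add_loop(None);        // Top-level loop.
    /// let inner = df.add_loop(Some(outer)); // Nested inside `outer`.
    /// // Pass `Some(inner)` as the `loop_id` argument of `add_subgraph_full` for subgraphs
    /// // inside the inner loop block.
    /// ```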
    pub fn add_loop(&mut self, parent: Option<LoopId>) -> LoopId {
        let depth = parent.map_or(0, |p| self.context.loop_depth[p] + 1);
        let loop_id = self.context.loop_depth.insert(depth);
        self.loop_data.insert(
            loop_id,
            LoopData {
                iter_count: None,
                allow_another_iteration: true,
            },
        );
        loop_id
    }
}

impl Dfir<'_> {
    /// Alias for [`Context::request_task`].
    pub fn request_task<Fut>(&mut self, future: Fut)
    where
        Fut: Future<Output = ()> + 'static,
    {
        self.context.request_task(future);
    }

    /// Alias for [`Context::abort_tasks`].
    pub fn abort_tasks(&mut self) {
        self.context.abort_tasks()
    }

    /// Alias for [`Context::join_tasks`].
    pub fn join_tasks(&mut self) -> impl use<'_> + Future {
        self.context.join_tasks()
    }
}

macro_rules! make_sync {
    ($synch:ident, $asynch:ident, $ret:ty) => {
        #[doc = concat!("A synchronous wrapper around [`Self::", stringify!($asynch), "`].")]
        ///
        /// This method will panic if the graph contains any async subgraph which returns `Poll::Pending`.
        pub fn $synch(&mut self) -> $ret {
            use std::sync::Arc;

            #[derive(Default)]
            struct BoolWaker {
                woke: std::sync::atomic::AtomicBool,
            }
            impl BoolWaker {
                fn new() -> Arc<Self> {
                    Arc::new(Self::default())
                }

                fn woke(&self) -> bool {
                    self.woke.load(std::sync::atomic::Ordering::Relaxed)
                }
            }
            impl futures::task::ArcWake for BoolWaker {
                fn wake_by_ref(arc_self: &Arc<Self>) {
                    arc_self.woke.store(true, std::sync::atomic::Ordering::Relaxed);
                }
            }

            let mut fut = std::pin::pin!(self.$asynch());
            loop {
                let bool_waker = BoolWaker::new();
                let waker = futures::task::waker(Arc::clone(&bool_waker));
                let mut ctx = std::task::Context::from_waker(&waker);
                if let std::task::Poll::Ready(out) = fut.as_mut().poll(&mut ctx) {
                    return out;
                }
                // If the waker was woken, poll again; otherwise the future is genuinely
                // pending and cannot make progress synchronously.
                if !bool_waker.woke() {
                    panic!(
                        "Future has pending work: DFIR graph has an async subgraph which failed to run synchronously."
                    )
                }
            }
        }
    };
}
impl Dfir<'_> {
    make_sync!(run_available_sync, run_available, bool);
    make_sync!(run_tick_sync, run_tick, bool);
    make_sync!(run_sync, run, Option<Never>);
}

impl Drop for Dfir<'_> {
    fn drop(&mut self) {
        self.abort_tasks();
    }
}

/// A handoff and its input and output [SubgraphId]s.
///
/// Internal use: used to track the dfir graph structure.
///
/// TODO(mingwei): restructure `PortList` so this can be crate-private.
#[doc(hidden)]
pub struct HandoffData {
    /// A friendly name for diagnostics.
    pub(super) name: Cow<'static, str>,
    /// Crate-visible for `handoff_list` internals.
    pub(super) handoff: Box<dyn HandoffMeta>,
    /// Preceding subgraphs (including the send side of a teeing handoff).
    pub(super) preds: SmallVec<[SubgraphId; 1]>,
    /// Successor subgraphs (including recv sides of teeing handoffs).
    pub(super) succs: SmallVec<[SubgraphId; 1]>,

    /// Predecessor handoffs, used by teeing handoffs.
    /// Should be `self` on any teeing send side (input).
    /// Should be the send `HandoffId` if this is a teeing recv side (output).
    /// Should be just `self`'s `HandoffId` on other handoffs.
    /// This field is only used in initialization.
    pub(super) pred_handoffs: Vec<HandoffId>,
    /// Successor handoffs, used by teeing handoffs.
    /// Should be a list of outputs on the teeing send side (input).
    /// Should be `self` on any teeing recv sides (outputs).
    /// Should be just `self`'s `HandoffId` on other handoffs.
    /// This field is only used in initialization.
    pub(super) succ_handoffs: Vec<HandoffId>,
}
impl std::fmt::Debug for HandoffData {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
        f.debug_struct("HandoffData")
            .field("preds", &self.preds)
            .field("succs", &self.succs)
            .finish_non_exhaustive()
    }
}
impl HandoffData {
    /// New with `pred_handoffs` and `succ_handoffs` set to its own [`HandoffId`]: `vec![hoff_id]`.
    pub fn new(
        name: Cow<'static, str>,
        handoff: impl 'static + HandoffMeta,
        hoff_id: HandoffId,
    ) -> Self {
        let (preds, succs) = Default::default();
        Self {
            name,
            handoff: Box::new(handoff),
            preds,
            succs,
            pred_handoffs: vec![hoff_id],
            succ_handoffs: vec![hoff_id],
        }
    }
}

/// A subgraph along with its predecessor and successor [SubgraphId]s.
///
/// Used internally by the [Dfir] struct to represent the dataflow graph
/// structure and scheduled state.
pub(super) struct SubgraphData<'a> {
    /// A friendly name for diagnostics.
    pub(super) name: Cow<'static, str>,
    /// This subgraph's stratum number.
    ///
    /// Within loop blocks, corresponds to the topological sort of the DAG created when `next_loop()`/`next_tick()` are removed.
    pub(super) stratum: usize,
    /// The actual execution code of the subgraph.
    subgraph: Box<dyn 'a + Subgraph<'a>>,

    #[expect(dead_code, reason = "may be useful in the future")]
    preds: Vec<HandoffId>,
    succs: Vec<HandoffId>,

    /// If this subgraph is scheduled in the `Context`'s stratum queues.
    /// [`Cell`] allows modifying this field when iterating `Self::preds` or
    /// `Self::succs`, as all `SubgraphData` are owned by the same
    /// `Dfir::subgraphs` vec.
    is_scheduled: Cell<bool>,

    /// Keep track of the last tick that this subgraph was run in.
    last_tick_run_in: Option<TickInstant>,
    /// A meaningless ID to track the last loop execution this subgraph was run in.
    /// `(loop_nonce, iter_count)` pair.
    last_loop_nonce: (usize, Option<usize>),

    /// If this subgraph is marked as lazy, then sending data back to a lower stratum does not trigger a new tick to be run.
    is_lazy: bool,

    /// The subgraph's loop ID, or `None` for the top level.
    loop_id: Option<LoopId>,
    /// The loop depth of the subgraph.
    loop_depth: usize,
}
impl<'a> SubgraphData<'a> {
    #[expect(clippy::too_many_arguments, reason = "internal use")]
    pub(crate) fn new(
        name: Cow<'static, str>,
        stratum: usize,
        subgraph: impl 'a + Subgraph<'a>,
        preds: Vec<HandoffId>,
        succs: Vec<HandoffId>,
        is_scheduled: bool,
        is_lazy: bool,
        loop_id: Option<LoopId>,
        loop_depth: usize,
    ) -> Self {
        Self {
            name,
            stratum,
            subgraph: Box::new(subgraph),
            preds,
            succs,
            is_scheduled: Cell::new(is_scheduled),
            last_tick_run_in: None,
            last_loop_nonce: (0, None),
            is_lazy,
            loop_id,
            loop_depth,
        }
    }
}

pub(crate) struct LoopData {
    /// Count of iterations of this loop.
    iter_count: Option<usize>,
    /// If the loop has reason to do another iteration.
    allow_another_iteration: bool,
}

/// Defines when state should be reset.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum StateLifespan {
    /// Always reset, associated with the subgraph.
    Subgraph(SubgraphId),
    /// Reset between loop executions.
    Loop(LoopId),
    /// Reset between ticks.
    Tick,
    /// Never reset.
    Static,
}