dfir_rs/scheduled/graph.rs

1//! Module for the [`Dfir`] struct and helper items.
2
3use std::any::Any;
4use std::borrow::Cow;
5use std::cell::Cell;
6use std::cmp::Ordering;
7use std::future::Future;
8use std::marker::PhantomData;
9use std::rc::Rc;
10
11#[cfg(feature = "meta")]
12use dfir_lang::diagnostic::{Diagnostic, SerdeSpan};
13#[cfg(feature = "meta")]
14use dfir_lang::graph::DfirGraph;
15use ref_cast::RefCast;
16use smallvec::SmallVec;
17use tracing::Instrument;
18use web_time::SystemTime;
19
20use super::context::Context;
21use super::handoff::handoff_list::PortList;
22use super::handoff::{Handoff, HandoffMeta, TeeingHandoff};
23use super::metrics::{DfirMetrics, DfirMetricsIntervals, InstrumentSubgraph};
24use super::port::{RECV, RecvCtx, RecvPort, SEND, SendCtx, SendPort};
25use super::reactor::Reactor;
26use super::state::StateHandle;
27use super::subgraph::Subgraph;
28use super::ticks::{TickDuration, TickInstant};
29use super::{HandoffId, HandoffTag, LoopId, LoopTag, SubgraphId, SubgraphTag};
30use crate::Never;
31use crate::util::slot_vec::{SecondarySlotVec, SlotVec};
32
33/// A DFIR graph. Owns, schedules, and runs the compiled subgraphs.
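///
/// # Example
///
/// A minimal sketch of wiring and running a graph by hand (marked `ignore`, not a vetted
/// doctest). The concrete `VecHandoff` type is an assumption not taken from this file; the
/// port and closure shapes follow [`Self::make_edge`] and [`Self::add_subgraph_n_m`] below.
///
/// ```ignore
/// use dfir_rs::scheduled::graph::Dfir;
/// use dfir_rs::scheduled::handoff::VecHandoff;
///
/// let mut df = Dfir::new();
/// // Two handoff edges with a single "relay" subgraph between them.
/// // (`send_in` and `recv_out` would be wired to further subgraphs.)
/// let (send_in, recv_in) = df.make_edge::<_, VecHandoff<usize>>("input");
/// let (send_out, recv_out) = df.make_edge::<_, VecHandoff<usize>>("output");
/// df.add_subgraph_n_m("relay", vec![recv_in], vec![send_out], async |_ctx, _recvs, _sends| {
///     // ... move items from `_recvs[0]` to `_sends[0]` ...
/// });
/// // Run all work that is immediately available.
/// df.run_available().await;
/// ```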
34#[derive(Default)]
35pub struct Dfir<'a> {
36    pub(super) subgraphs: SlotVec<SubgraphTag, SubgraphData<'a>>,
37
38    pub(super) loop_data: SecondarySlotVec<LoopTag, LoopData>,
39
40    pub(super) context: Context,
41
42    pub(super) handoffs: SlotVec<HandoffTag, HandoffData>,
43
44    /// Live-updating DFIR runtime metrics via interior mutability.
45    metrics: Rc<DfirMetrics>,
46
47    #[cfg(feature = "meta")]
48    /// See [`Self::meta_graph()`].
49    meta_graph: Option<DfirGraph>,
50
51    #[cfg(feature = "meta")]
52    /// See [`Self::diagnostics()`].
53    diagnostics: Option<Vec<Diagnostic<SerdeSpan>>>,
54}
55
56/// Methods for [`TeeingHandoff`] teeing and dropping.
57impl Dfir<'_> {
58    /// Tees a [`TeeingHandoff`].
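    ///
    /// Given a receive port of an existing [`TeeingHandoff`] edge, creates and returns an
    /// additional receive port which will observe the same items.
    ///
    /// A rough sketch (marked `ignore`; constructing the teeing edge via [`Self::make_edge`] is
    /// an assumption, not shown elsewhere in this file):
    ///
    /// ```ignore
    /// let (send_port, recv_a) = df.make_edge::<_, TeeingHandoff<usize>>("teed");
    /// // A second output that sees the same items as `recv_a`.
    /// let recv_b = df.teeing_handoff_tee(&recv_a);
    /// ```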
59    pub fn teeing_handoff_tee<T>(
60        &mut self,
61        tee_parent_port: &RecvPort<TeeingHandoff<T>>,
62    ) -> RecvPort<TeeingHandoff<T>>
63    where
64        T: Clone,
65    {
66        // If we're teeing from a child, make sure to find the root.
67        let tee_root = self.handoffs[tee_parent_port.handoff_id].pred_handoffs[0];
68
69        // Set up teeing metadata.
70        let tee_root_data = &mut self.handoffs[tee_root];
71        let tee_root_data_name = tee_root_data.name.clone();
72
73        // Insert new handoff output.
74        let teeing_handoff =
75            <dyn Any>::downcast_ref::<TeeingHandoff<T>>(&*tee_root_data.handoff).unwrap();
76        let new_handoff = teeing_handoff.tee();
77
78        // Handoff ID of new tee output.
79        let new_hoff_id = self.handoffs.insert_with_key(|new_hoff_id| {
80            let new_name = Cow::Owned(format!("{} tee {:?}", tee_root_data_name, new_hoff_id));
81            let mut new_handoff_data = HandoffData::new(new_name, new_handoff, new_hoff_id);
82            // Set self's predecessor as `tee_root`.
83            new_handoff_data.pred_handoffs = vec![tee_root];
84            new_handoff_data
85        });
86
87        // Go to `tee_root`'s successors and insert self (the new tee output).
88        let tee_root_data = &mut self.handoffs[tee_root];
89        tee_root_data.succ_handoffs.push(new_hoff_id);
90
91        // Add our new handoff id into the subgraph data if the send `tee_root` has already been
92        // used to add a subgraph.
93        assert!(
94            tee_root_data.preds.len() <= 1,
95            "Tee send side should only have one sender (or none set yet)."
96        );
97        if let Some(&pred_sg_id) = tee_root_data.preds.first() {
98            self.subgraphs[pred_sg_id].succs.push(new_hoff_id);
99        }
100
101        // Initialize handoff metrics struct.
102        Rc::make_mut(&mut self.metrics)
103            .handoffs
104            .insert(new_hoff_id, Default::default());
105
106        let output_port = RecvPort {
107            handoff_id: new_hoff_id,
108            _marker: PhantomData,
109        };
110        output_port
111    }
112
113    /// Marks an output of a [`TeeingHandoff`] as dropped so that no more data will be sent to it.
114    ///
115    /// It is recommended not to use this method; instead, simply avoid teeing a
116    /// [`TeeingHandoff`] when it is not needed.
117    pub fn teeing_handoff_drop<T>(&mut self, tee_port: RecvPort<TeeingHandoff<T>>)
118    where
119        T: Clone,
120    {
121        let data = &self.handoffs[tee_port.handoff_id];
122        let teeing_handoff = <dyn Any>::downcast_ref::<TeeingHandoff<T>>(&*data.handoff).unwrap();
123        teeing_handoff.drop();
124
125        let tee_root = data.pred_handoffs[0];
126        let tee_root_data = &mut self.handoffs[tee_root];
127        // Remove this output from the send succ handoff list.
128        tee_root_data
129            .succ_handoffs
130            .retain(|&succ_hoff| succ_hoff != tee_port.handoff_id);
131        // Remove from subgraph successors if send port was already connected.
132        assert!(
133            tee_root_data.preds.len() <= 1,
134            "Tee send side should only have one sender (or none set yet)."
135        );
136        if let Some(&pred_sg_id) = tee_root_data.preds.first() {
137            self.subgraphs[pred_sg_id]
138                .succs
139                .retain(|&succ_hoff| succ_hoff != tee_port.handoff_id);
140        }
141    }
142}
143
144impl<'a> Dfir<'a> {
145    /// Create a new empty graph.
146    pub fn new() -> Self {
147        Default::default()
148    }
149
150    /// Assign the meta graph via JSON string. Used internally by the [`crate::dfir_syntax`] and other macros.
151    #[doc(hidden)]
152    pub fn __assign_meta_graph(&mut self, _meta_graph_json: &str) {
153        #[cfg(feature = "meta")]
154        {
155            let mut meta_graph: DfirGraph =
156                serde_json::from_str(_meta_graph_json).expect("Failed to deserialize graph.");
157
158            let mut op_inst_diagnostics = Vec::new();
159            meta_graph.insert_node_op_insts_all(&mut op_inst_diagnostics);
160            assert!(
161                op_inst_diagnostics.is_empty(),
162                "Expected no diagnostics, got: {:#?}",
163                op_inst_diagnostics
164            );
165
166            assert!(self.meta_graph.replace(meta_graph).is_none());
167        }
168    }
169    /// Assign the diagnostics via JSON string.
170    #[doc(hidden)]
171    pub fn __assign_diagnostics(&mut self, _diagnostics_json: &'static str) {
172        #[cfg(feature = "meta")]
173        {
174            let diagnostics: Vec<Diagnostic<SerdeSpan>> = serde_json::from_str(_diagnostics_json)
175                .expect("Failed to deserialize diagnostics.");
176
177            assert!(self.diagnostics.replace(diagnostics).is_none());
178        }
179    }
180
181    /// Return a handle to the meta graph, if set. The meta graph is a
182    /// representation of all the operators, subgraphs, and handoffs in this instance.
183    /// Will only be set if this graph was constructed using a surface syntax macro.
184    #[cfg(feature = "meta")]
185    #[cfg_attr(docsrs, doc(cfg(feature = "meta")))]
186    pub fn meta_graph(&self) -> Option<&DfirGraph> {
187        self.meta_graph.as_ref()
188    }
189
190    /// Returns any diagnostics generated by the surface syntax macro. Each diagnostic is a pair of
191    /// (1) a `Diagnostic` with span info reset and (2) the `ToString` version of the diagnostic
192    /// with original span info.
193    /// Will only be set if this graph was constructed using a surface syntax macro.
194    #[cfg(feature = "meta")]
195    #[cfg_attr(docsrs, doc(cfg(feature = "meta")))]
196    pub fn diagnostics(&self) -> Option<&[Diagnostic<SerdeSpan>]> {
197        self.diagnostics.as_deref()
198    }
199
200    /// Returns a reactor for externally scheduling subgraphs, possibly from another thread.
201    /// Reactor events are considered to be external events.
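    ///
    /// A rough sketch (marked `ignore`; the `trigger` method name on [`Reactor`] is an
    /// assumption, not taken from this file):
    ///
    /// ```ignore
    /// let reactor = df.reactor();
    /// std::thread::spawn(move || {
    ///     // Wake the dataflow from another thread by scheduling a subgraph externally.
    ///     reactor.trigger(sg_id).unwrap();
    /// });
    /// ```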
202    pub fn reactor(&self) -> Reactor {
203        Reactor::new(self.context.event_queue_send.clone())
204    }
205
206    /// Gets the current tick (local time) count.
207    pub fn current_tick(&self) -> TickInstant {
208        self.context.current_tick
209    }
210
211    /// Gets the current stratum number.
212    pub fn current_stratum(&self) -> usize {
213        self.context.current_stratum
214    }
215
216    /// Runs the dataflow until the next tick begins.
217    ///
218    /// Returns `true` if any work was done.
219    ///
220    /// Receives external events if called at the start of a tick.
221    #[tracing::instrument(level = "trace", skip(self), ret)]
222    pub async fn run_tick(&mut self) -> bool {
223        let mut work_done = false;
224        // While work is immediately available *on the current tick*.
225        while self.next_stratum(true) {
226            work_done = true;
227            // Do any work.
228            self.run_stratum().await;
229        }
230        work_done
231    }
232
233    /// [`Self::run_tick`] but panics if a subgraph yields asynchronously.
234    #[tracing::instrument(level = "trace", skip(self), ret)]
235    pub fn run_tick_sync(&mut self) -> bool {
236        let mut work_done = false;
237        // While work is immediately available *on the current tick*.
238        while self.next_stratum(true) {
239            work_done = true;
240            // Do any work.
241            run_sync(self.run_stratum());
242        }
243        work_done
244    }
245
246    /// Runs the dataflow until no more (externally-triggered) work is immediately available.
247    /// Runs at least one tick of dataflow, even if no external events have been received.
248    /// If the dataflow contains loops, this method may run forever.
249    ///
250    /// Returns `true` if any work was done.
251    ///
252    /// Yields repeatedly to allow external events to happen.
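    ///
    /// A minimal driving loop, essentially what [`Self::run`] does:
    ///
    /// ```ignore
    /// loop {
    ///     // Drain all immediately-available work...
    ///     df.run_available().await;
    ///     // ...then wait until an external event schedules more work.
    ///     df.recv_events_async().await;
    /// }
    /// ```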
253    #[tracing::instrument(level = "trace", skip(self), ret)]
254    pub async fn run_available(&mut self) -> bool {
255        let mut work_done = false;
256        // While work is immediately available.
257        while self.next_stratum(false) {
258            work_done = true;
259            // Do any work.
260            self.run_stratum().await;
261
262            // Yield between each stratum to receive more events.
263            // TODO(mingwei): really only need to yield at start of ticks though.
264            tokio::task::yield_now().await;
265        }
266        work_done
267    }
268
269    /// [`Self::run_available`] but panics if a subgraph yields asynchronously.
270    #[tracing::instrument(level = "trace", skip(self), ret)]
271    pub fn run_available_sync(&mut self) -> bool {
272        let mut work_done = false;
273        // While work is immediately available.
274        while self.next_stratum(false) {
275            work_done = true;
276            // Do any work.
277            run_sync(self.run_stratum());
278        }
279        work_done
280    }
281
282    /// Runs the current stratum of the dataflow until no more local work is available (does not receive events).
283    ///
284    /// Returns `true` if any work was done.
285    #[tracing::instrument(level = "trace", skip(self), fields(tick = u64::from(self.context.current_tick), stratum = self.context.current_stratum), ret)]
286    pub async fn run_stratum(&mut self) -> bool {
287        // Make sure to spawn tasks once dfir is running!
288        // This drains the task buffer, so it becomes a no-op after the first call.
289        self.context.spawn_tasks();
290
291        let mut work_done = false;
292
293        'pop: while let Some(sg_id) =
294            self.context.stratum_queues[self.context.current_stratum].pop_front()
295        {
296            let sg_data = &mut self.subgraphs[sg_id];
297
298            // Update handoff metrics.
299            // NOTE(mingwei):
300            // We measure handoff metrics at `recv`, not `send`.
301            // This is done because we (a) know that all `recv`
302            // will be consumed(*) by the subgraph when it is run, and in contrast (b) do not know
303            // that all `send` will be consumed after a run. I.e. this subgraph may run multiple
304            // times before the next subgraph consumes the output; don't double count those.
305            // (*) - usually... always true for Hydro-generated DFIR at least.
306            for &handoff_id in sg_data.preds.iter() {
307                let handoff_metrics = &self.metrics.handoffs[handoff_id];
308                let handoff_data = &mut self.handoffs[handoff_id];
309                let handoff_len = handoff_data.handoff.len();
310                handoff_metrics
311                    .total_items_count
312                    .update(|x| x + handoff_len);
313                handoff_metrics.curr_items_count.set(handoff_len);
314            }
315
316            // Run the subgraph (and do bookkeeping).
317            {
318                // This must be true for the subgraph to be enqueued.
319                assert!(sg_data.is_scheduled.take());
320
321                let run_subgraph_span_guard = tracing::info_span!(
322                    "run-subgraph",
323                    sg_id = sg_id.to_string(),
324                    sg_name = &*sg_data.name,
325                    sg_depth = sg_data.loop_depth,
326                    sg_loop_nonce = sg_data.last_loop_nonce.0,
327                    sg_iter_count = sg_data.last_loop_nonce.1,
328                )
329                .entered();
330
331                match sg_data.loop_depth.cmp(&self.context.loop_nonce_stack.len()) {
332                    Ordering::Greater => {
333                        // We have entered a loop.
334                        self.context.loop_nonce += 1;
335                        self.context.loop_nonce_stack.push(self.context.loop_nonce);
336                        tracing::trace!(loop_nonce = self.context.loop_nonce, "Entered loop.");
337                    }
338                    Ordering::Less => {
339                        // We have exited a loop.
340                        self.context.loop_nonce_stack.pop();
341                        tracing::trace!("Exited loop.");
342                    }
343                    Ordering::Equal => {}
344                }
345
346                self.context.subgraph_id = sg_id;
347                self.context.is_first_run_this_tick = sg_data
348                    .last_tick_run_in
349                    .is_none_or(|last_tick| last_tick < self.context.current_tick);
350
351                if let Some(loop_id) = sg_data.loop_id {
352                    // Loop execution - running loop block, from start to finish, containing
353                    // multiple iterations.
354                    // Loop iteration - a single iteration of a loop block, all subgraphs within
355                    // the loop should run (at most) once.
356
357                    // If the previous run of this subgraph had the same loop execution and
358                    // iteration count, then we need to increment the iteration count.
359                    let curr_loop_nonce = self.context.loop_nonce_stack.last().copied();
360
361                    let LoopData {
362                        iter_count: loop_iter_count,
363                        allow_another_iteration,
364                    } = &mut self.loop_data[loop_id];
365
366                    let (prev_loop_nonce, prev_iter_count) = sg_data.last_loop_nonce;
367
368                    // If the loop nonce is the same as the previous execution, then we are in
369                    // the same loop execution.
370                    // `curr_loop_nonce` is `None` for top-level loops, and top-level loops are
371                    // always in the same (singular) loop execution.
372                    let (curr_iter_count, new_loop_execution) =
373                        if curr_loop_nonce.is_none_or(|nonce| nonce == prev_loop_nonce) {
374                            // If the iteration count is the same as the previous execution, then
375                            // we are on the next iteration.
376                            if *loop_iter_count == prev_iter_count {
377                                // If not true, then we shall not run the next iteration.
378                                if !std::mem::take(allow_another_iteration) {
379                                    tracing::debug!(
380                                        "Loop will not continue to next iteration, skipping."
381                                    );
382                                    continue 'pop;
383                                }
384                                // Increment `loop_iter_count` or set it to 0.
385                                loop_iter_count.map_or((0, true), |n| (n + 1, false))
386                            } else {
387                                // Otherwise update the local iteration count to match the loop.
388                                debug_assert!(
389                                    prev_iter_count < *loop_iter_count,
390                                    "Expect loop iteration count to be increasing."
391                                );
392                                (loop_iter_count.unwrap(), false)
393                            }
394                        } else {
395                            // The loop execution has already begun, but this is the first time this particular subgraph is running.
396                            (0, false)
397                        };
398
399                    if new_loop_execution {
400                        // Run state hooks.
401                        self.context.run_state_hooks_loop(loop_id);
402                    }
403                    tracing::debug!("Loop iteration count {}", curr_iter_count);
404
405                    *loop_iter_count = Some(curr_iter_count);
406                    self.context.loop_iter_count = curr_iter_count;
407                    sg_data.last_loop_nonce =
408                        (curr_loop_nonce.unwrap_or_default(), Some(curr_iter_count));
409                }
410
411                // Run subgraph state hooks.
412                self.context.run_state_hooks_subgraph(sg_id);
413
414                tracing::info!("Running subgraph.");
415                sg_data.last_tick_run_in = Some(self.context.current_tick);
416
417                let sg_metrics = &self.metrics.subgraphs[sg_id];
418                let sg_fut =
419                    Box::into_pin(sg_data.subgraph.run(&mut self.context, &mut self.handoffs));
420                // Update subgraph metrics.
421                let sg_fut = InstrumentSubgraph::new(sg_fut, sg_metrics);
422                // Pass along `run_subgraph_span` to the run subgraph future.
423                let sg_fut = sg_fut.instrument(run_subgraph_span_guard.exit());
424                let () = sg_fut.await;
425
426                sg_metrics.total_run_count.update(|x| x + 1);
427            };
428
429            // Schedule the following subgraphs if data was pushed to the corresponding handoff.
430            let sg_data = &self.subgraphs[sg_id];
431            for &handoff_id in sg_data.succs.iter() {
432                let handoff_data = &self.handoffs[handoff_id];
433                let handoff_len = handoff_data.handoff.len();
434                if 0 < handoff_len {
435                    for &succ_id in handoff_data.succs.iter() {
436                        let succ_sg_data = &self.subgraphs[succ_id];
437                        // If we have sent data to the next tick, then we can start the next tick.
438                        if succ_sg_data.stratum < self.context.current_stratum && !sg_data.is_lazy {
439                            self.context.can_start_tick = true;
440                        }
441                        // Add subgraph to stratum queue if it is not already scheduled.
442                        if !succ_sg_data.is_scheduled.replace(true) {
443                            self.context.stratum_queues[succ_sg_data.stratum].push_back(succ_id);
444                        }
445                        // Add stratum to stratum stack if it is within a loop.
446                        if 0 < succ_sg_data.loop_depth {
447                            // TODO(mingwei): handle duplicates
448                            self.context
449                                .stratum_stack
450                                .push(succ_sg_data.loop_depth, succ_sg_data.stratum);
451                        }
452                    }
453                }
454                let handoff_metrics = &self.metrics.handoffs[handoff_id];
455                handoff_metrics.curr_items_count.set(handoff_len);
456            }
457
458            let reschedule = self.context.reschedule_loop_block.take();
459            let allow_another = self.context.allow_another_iteration.take();
460
461            if reschedule {
462                // Re-enqueue the subgraph.
463                self.context.schedule_deferred.push(sg_id);
464                self.context
465                    .stratum_stack
466                    .push(sg_data.loop_depth, sg_data.stratum);
467            }
468            if (reschedule || allow_another)
469                && let Some(loop_id) = sg_data.loop_id
470            {
471                self.loop_data
472                    .get_mut(loop_id)
473                    .unwrap()
474                    .allow_another_iteration = true;
475            }
476
477            work_done = true;
478        }
479        work_done
480    }
481
482    /// Go to the next stratum which has work available, possibly the current stratum.
483    /// Returns `true` if more work is available, or `false` if no work is immediately
484    /// available on any stratum.
485    ///
486    /// This will receive external events when at the start of a tick.
487    ///
488    /// If `current_tick_only` is set to `true`, will only return `true` if work is immediately
489    /// available on the *current tick*.
490    ///
491    /// If this returns false then the graph will be at the start of a tick (at stratum 0, can
492    /// receive more external events).
493    #[tracing::instrument(level = "trace", skip(self), ret)]
494    pub fn next_stratum(&mut self, current_tick_only: bool) -> bool {
495        tracing::trace!(
496            events_received_tick = self.context.events_received_tick,
497            can_start_tick = self.context.can_start_tick,
498            "Starting `next_stratum` call.",
499        );
500
501        // The stratum we will stop searching at, i.e. where we will have made a full loop around.
502        let mut end_stratum = self.context.current_stratum;
503        let mut new_tick_started = false;
504
505        if 0 == self.context.current_stratum {
506            new_tick_started = true;
507
508            // Starting the tick, reset this to `false`.
509            tracing::trace!("Starting tick, setting `can_start_tick = false`.");
510            self.context.can_start_tick = false;
511            self.context.current_tick_start = SystemTime::now();
512
513            // Ensure external events are received before running the tick.
514            if !self.context.events_received_tick {
515                // Add any external jobs to ready queue.
516                self.try_recv_events();
517            }
518        }
519
520        loop {
521            tracing::trace!(
522                tick = u64::from(self.context.current_tick),
523                stratum = self.context.current_stratum,
524                "Looking for work on stratum."
525            );
526            // If current stratum has work, return true.
527            if !self.context.stratum_queues[self.context.current_stratum].is_empty() {
528                tracing::trace!(
529                    tick = u64::from(self.context.current_tick),
530                    stratum = self.context.current_stratum,
531                    "Work found on stratum."
532                );
533                return true;
534            }
535
536            if let Some(next_stratum) = self.context.stratum_stack.pop() {
537                self.context.current_stratum = next_stratum;
538
539                // Now schedule deferred subgraphs.
540                {
541                    for sg_id in self.context.schedule_deferred.drain(..) {
542                        let sg_data = &self.subgraphs[sg_id];
543                        tracing::info!(
544                            tick = u64::from(self.context.current_tick),
545                            stratum = self.context.current_stratum,
546                            sg_id = sg_id.to_string(),
547                            sg_name = &*sg_data.name,
548                            is_scheduled = sg_data.is_scheduled.get(),
549                            "Rescheduling deferred subgraph."
550                        );
551                        if !sg_data.is_scheduled.replace(true) {
552                            self.context.stratum_queues[sg_data.stratum].push_back(sg_id);
553                        }
554                    }
555                }
556            } else {
557                // Increment stratum counter.
558                self.context.current_stratum += 1;
559
560                if self.context.current_stratum >= self.context.stratum_queues.len() {
561                    new_tick_started = true;
562
563                    tracing::trace!(
564                        can_start_tick = self.context.can_start_tick,
565                        "End of tick {}, starting tick {}.",
566                        self.context.current_tick,
567                        self.context.current_tick + TickDuration::SINGLE_TICK,
568                    );
569                    self.context.run_state_hooks_tick();
570
571                    self.context.current_stratum = 0;
572                    self.context.current_tick += TickDuration::SINGLE_TICK;
573                    self.context.events_received_tick = false;
574
575                    if current_tick_only {
576                        tracing::trace!(
577                            "`current_tick_only` is `true`, returning `false` before receiving events."
578                        );
579                        return false;
580                    } else {
581                        self.try_recv_events();
582                        if std::mem::replace(&mut self.context.can_start_tick, false) {
583                            tracing::trace!(
584                                tick = u64::from(self.context.current_tick),
585                                "`can_start_tick` is `true`, continuing."
586                            );
587                            // Do a full loop more to find where events have been added.
588                            end_stratum = 0;
589                            continue;
590                        } else {
591                            tracing::trace!(
592                                "`can_start_tick` is `false`, re-setting `events_received_tick = false`, returning `false`."
593                            );
594                            self.context.events_received_tick = false;
595                            return false;
596                        }
597                    }
598                }
599            }
600
601            // After incrementing, exit if we made a full loop around the strata.
602            if new_tick_started && end_stratum == self.context.current_stratum {
603                tracing::trace!(
604                    "Made full loop around stratum, re-setting `current_stratum = 0`, returning `false`."
605                );
606                // Note: if current stratum had work, the very first loop iteration would've
607                // returned true. Therefore we can return false without checking.
608                // Also means nothing was done so we can reset the stratum to zero and wait for
609                // events.
610                self.context.events_received_tick = false;
611                self.context.current_stratum = 0;
612                return false;
613            }
614        }
615    }
616
617    /// Runs the dataflow graph forever.
618    ///
619    /// TODO(mingwei): Currently blocks forever, no notion of "completion."
620    #[tracing::instrument(level = "trace", skip(self), ret)]
621    pub async fn run(&mut self) -> Option<Never> {
622        loop {
623            // Run any work which is immediately available.
624            self.run_available().await;
625            // When no work is available yield until more events occur.
626            self.recv_events_async().await;
627        }
628    }
629
630    /// [`Self::run`] but panics if a subgraph yields asynchronously.
631    #[tracing::instrument(level = "trace", skip(self), ret)]
632    pub fn run_sync(&mut self) -> Option<Never> {
633        loop {
634            // Run any work which is immediately available.
635            self.run_available_sync();
636            // When no work is available block until more events occur.
637            self.recv_events();
638        }
639    }
640
641    /// Enqueues subgraphs triggered by events without blocking.
642    ///
643    /// Returns the number of subgraphs enqueued.
644    #[tracing::instrument(level = "trace", skip(self), fields(events_received_tick = self.context.events_received_tick), ret)]
645    pub fn try_recv_events(&mut self) -> usize {
646        let mut enqueued_count = 0;
647        while let Ok((sg_id, is_external)) = self.context.event_queue_recv.try_recv() {
648            let sg_data = &self.subgraphs[sg_id];
649            tracing::trace!(
650                sg_id = sg_id.to_string(),
651                is_external = is_external,
652                sg_stratum = sg_data.stratum,
653                "Event received."
654            );
655            if !sg_data.is_scheduled.replace(true) {
656                self.context.stratum_queues[sg_data.stratum].push_back(sg_id);
657                enqueued_count += 1;
658            }
659            if is_external {
660                // The next tick is triggered if we are at the start of the next tick
661                // (`!self.context.events_received_tick`), or if the stratum is in the next tick.
662                if !self.context.events_received_tick
663                    || sg_data.stratum < self.context.current_stratum
664                {
665                    tracing::trace!(
666                        current_stratum = self.context.current_stratum,
667                        sg_stratum = sg_data.stratum,
668                        "External event, setting `can_start_tick = true`."
669                    );
670                    self.context.can_start_tick = true;
671                }
672            }
673        }
674        self.context.events_received_tick = true;
675
676        enqueued_count
677    }
678
679    /// Enqueues subgraphs triggered by external events, blocking until at least one subgraph
680    /// is scheduled **from an external event**. Returns `None` if the event queue is closed.
681    #[tracing::instrument(level = "trace", skip(self), fields(events_received_tick = self.context.events_received_tick), ret)]
682    pub fn recv_events(&mut self) -> Option<usize> {
683        let mut count = 0;
684        loop {
685            let (sg_id, is_external) = self.context.event_queue_recv.blocking_recv()?;
686            let sg_data = &self.subgraphs[sg_id];
687            tracing::trace!(
688                sg_id = sg_id.to_string(),
689                is_external = is_external,
690                sg_stratum = sg_data.stratum,
691                "Event received."
692            );
693            if !sg_data.is_scheduled.replace(true) {
694                self.context.stratum_queues[sg_data.stratum].push_back(sg_id);
695                count += 1;
696            }
697            if is_external {
698                // The next tick is triggered if we are at the start of the next tick
699                // (`!self.context.events_received_tick`), or if the stratum is in the next tick.
700                if !self.context.events_received_tick
701                    || sg_data.stratum < self.context.current_stratum
702                {
703                    tracing::trace!(
704                        current_stratum = self.context.current_stratum,
705                        sg_stratum = sg_data.stratum,
706                        "External event, setting `can_start_tick = true`."
707                    );
708                    self.context.can_start_tick = true;
709                }
710                break;
711            }
712        }
713        self.context.events_received_tick = true;
714
715        // Enqueue any other immediate events.
716        let extra_count = self.try_recv_events();
717        Some(count + extra_count)
718    }
719
720    /// Enqueues subgraphs triggered by external events asynchronously, waiting until at least one
721    /// subgraph is scheduled **from an external event**. Returns the number of subgraphs enqueued,
722    /// which may be zero if an external event scheduled an already-scheduled subgraph.
723    ///
724    /// Returns `None` if the event queue is closed, but that should not happen normally.
725    #[tracing::instrument(level = "trace", skip(self), fields(events_received_tick = self.context.events_received_tick), ret)]
726    pub async fn recv_events_async(&mut self) -> Option<usize> {
727        let mut count = 0;
728        loop {
729            tracing::trace!("Awaiting events (`event_queue_recv`).");
730            let (sg_id, is_external) = self.context.event_queue_recv.recv().await?;
731            let sg_data = &self.subgraphs[sg_id];
732            tracing::trace!(
733                sg_id = sg_id.to_string(),
734                is_external = is_external,
735                sg_stratum = sg_data.stratum,
736                "Event received."
737            );
738            if !sg_data.is_scheduled.replace(true) {
739                self.context.stratum_queues[sg_data.stratum].push_back(sg_id);
740                count += 1;
741            }
742            if is_external {
743                // The next tick is triggered if we are at the start of the next tick
744                // (`!self.context.events_received_tick`), or if the stratum is in the next tick.
745                if !self.context.events_received_tick
746                    || sg_data.stratum < self.context.current_stratum
747                {
748                    tracing::trace!(
749                        current_stratum = self.context.current_stratum,
750                        sg_stratum = sg_data.stratum,
751                        "External event, setting `can_start_tick = true`."
752                    );
753                    self.context.can_start_tick = true;
754                }
755                break;
756            }
757        }
758        self.context.events_received_tick = true;
759
760        // Enqueue any other immediate events.
761        let extra_count = self.try_recv_events();
762        Some(count + extra_count)
763    }
764
765    /// Schedules a subgraph to be run. See also: [`Context::schedule_subgraph`].
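    ///
    /// Returns `true` if the subgraph was newly scheduled, or `false` if it was already scheduled.
    ///
    /// ```ignore
    /// let sg_id = df.add_subgraph(/* ... */);
    /// // Later, explicitly (re-)schedule it to run on its stratum:
    /// df.schedule_subgraph(sg_id);
    /// ```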
766    pub fn schedule_subgraph(&mut self, sg_id: SubgraphId) -> bool {
767        let sg_data = &self.subgraphs[sg_id];
768        let already_scheduled = sg_data.is_scheduled.replace(true);
769        if !already_scheduled {
770            self.context.stratum_queues[sg_data.stratum].push_back(sg_id);
771            true
772        } else {
773            false
774        }
775    }
776
777    /// Adds a new compiled subgraph with the specified inputs and outputs in stratum 0.
778    pub fn add_subgraph<Name, R, W, Func>(
779        &mut self,
780        name: Name,
781        recv_ports: R,
782        send_ports: W,
783        subgraph: Func,
784    ) -> SubgraphId
785    where
786        Name: Into<Cow<'static, str>>,
787        R: 'static + PortList<RECV>,
788        W: 'static + PortList<SEND>,
789        Func: 'a + for<'ctx> AsyncFnMut(&'ctx mut Context, R::Ctx<'ctx>, W::Ctx<'ctx>),
790    {
791        self.add_subgraph_stratified(name, 0, recv_ports, send_ports, false, subgraph)
792    }
793
794    /// Adds a new compiled subgraph with the specified inputs, outputs, and stratum number.
795    ///
796    /// TODO(mingwei): add example in doc.
797    pub fn add_subgraph_stratified<Name, R, W, Func>(
798        &mut self,
799        name: Name,
800        stratum: usize,
801        recv_ports: R,
802        send_ports: W,
803        laziness: bool,
804        subgraph: Func,
805    ) -> SubgraphId
806    where
807        Name: Into<Cow<'static, str>>,
808        R: 'static + PortList<RECV>,
809        W: 'static + PortList<SEND>,
810        Func: 'a + for<'ctx> AsyncFnMut(&'ctx mut Context, R::Ctx<'ctx>, W::Ctx<'ctx>),
811    {
812        self.add_subgraph_full(
813            name, stratum, recv_ports, send_ports, laziness, None, subgraph,
814        )
815    }
816
817    /// Adds a new compiled subgraph with all options.
818    #[expect(clippy::too_many_arguments, reason = "Mainly for internal use.")]
819    pub fn add_subgraph_full<Name, R, W, Func>(
820        &mut self,
821        name: Name,
822        stratum: usize,
823        recv_ports: R,
824        send_ports: W,
825        laziness: bool,
826        loop_id: Option<LoopId>,
827        mut subgraph: Func,
828    ) -> SubgraphId
829    where
830        Name: Into<Cow<'static, str>>,
831        R: 'static + PortList<RECV>,
832        W: 'static + PortList<SEND>,
833        Func: 'a + for<'ctx> AsyncFnMut(&'ctx mut Context, R::Ctx<'ctx>, W::Ctx<'ctx>),
834    {
835        // SAFETY: Check that the send and recv ports are from `self.handoffs`.
836        recv_ports.assert_is_from(&self.handoffs);
837        send_ports.assert_is_from(&self.handoffs);
838
839        let loop_depth = loop_id
840            .and_then(|loop_id| self.context.loop_depth.get(loop_id))
841            .copied()
842            .unwrap_or(0);
843
844        let sg_id = self.subgraphs.insert_with_key(|sg_id| {
845            let (mut subgraph_preds, mut subgraph_succs) = Default::default();
846            recv_ports.set_graph_meta(&mut self.handoffs, &mut subgraph_preds, sg_id, true);
847            send_ports.set_graph_meta(&mut self.handoffs, &mut subgraph_succs, sg_id, false);
848
849            let subgraph =
850                async move |context: &mut Context,
851                            handoffs: &mut SlotVec<HandoffTag, HandoffData>| {
852                    let (recv, send) = unsafe {
853                        // SAFETY:
854                        // 1. We checked `assert_is_from` at assembly time, above.
855                        // 2. `SlotVec` is insert-only so no handoffs could have changed since then.
856                        (
857                            recv_ports.make_ctx(&*handoffs),
858                            send_ports.make_ctx(&*handoffs),
859                        )
860                    };
861                    (subgraph)(context, recv, send).await;
862                };
863            SubgraphData::new(
864                name.into(),
865                stratum,
866                subgraph,
867                subgraph_preds,
868                subgraph_succs,
869                true,
870                laziness,
871                loop_id,
872                loop_depth,
873            )
874        });
875        self.context.init_stratum(stratum);
876        self.context.stratum_queues[stratum].push_back(sg_id);
877
878        // Initialize subgraph metrics struct.
879        Rc::make_mut(&mut self.metrics)
880            .subgraphs
881            .insert(sg_id, Default::default());
882
883        sg_id
884    }
885
886    /// Adds a new compiled subgraph with a variable number of inputs and outputs of the same respective handoff types.
887    pub fn add_subgraph_n_m<Name, R, W, Func>(
888        &mut self,
889        name: Name,
890        recv_ports: Vec<RecvPort<R>>,
891        send_ports: Vec<SendPort<W>>,
892        subgraph: Func,
893    ) -> SubgraphId
894    where
895        Name: Into<Cow<'static, str>>,
896        R: 'static + Handoff,
897        W: 'static + Handoff,
898        Func: 'a
899            + for<'ctx> AsyncFnMut(
900                &'ctx mut Context,
901                &'ctx [&'ctx RecvCtx<R>],
902                &'ctx [&'ctx SendCtx<W>],
903            ),
904    {
905        self.add_subgraph_stratified_n_m(name, 0, recv_ports, send_ports, subgraph)
906    }
907
908    /// Adds a new compiled subgraph, in the given stratum, with a variable number of inputs and outputs of the same respective handoff types.
909    pub fn add_subgraph_stratified_n_m<Name, R, W, Func>(
910        &mut self,
911        name: Name,
912        stratum: usize,
913        recv_ports: Vec<RecvPort<R>>,
914        send_ports: Vec<SendPort<W>>,
915        mut subgraph: Func,
916    ) -> SubgraphId
917    where
918        Name: Into<Cow<'static, str>>,
919        R: 'static + Handoff,
920        W: 'static + Handoff,
921        Func: 'a
922            + for<'ctx> AsyncFnMut(
923                &'ctx mut Context,
924                &'ctx [&'ctx RecvCtx<R>],
925                &'ctx [&'ctx SendCtx<W>],
926            ),
927    {
928        let sg_id = self.subgraphs.insert_with_key(|sg_id| {
929            let subgraph_preds = recv_ports.iter().map(|port| port.handoff_id).collect();
930            let subgraph_succs = send_ports.iter().map(|port| port.handoff_id).collect();
931
932            for recv_port in recv_ports.iter() {
933                self.handoffs[recv_port.handoff_id].succs.push(sg_id);
934            }
935            for send_port in send_ports.iter() {
936                self.handoffs[send_port.handoff_id].preds.push(sg_id);
937            }
938
939            let subgraph =
940                async move |context: &mut Context,
941                            handoffs: &mut SlotVec<HandoffTag, HandoffData>| {
942                    let recvs: Vec<&RecvCtx<R>> = recv_ports
943                        .iter()
944                        .map(|hid| hid.handoff_id)
945                        .map(|hid| handoffs.get(hid).unwrap())
946                        .map(|h_data| {
947                            <dyn Any>::downcast_ref(&*h_data.handoff)
948                                .expect("Attempted to cast handoff to wrong type.")
949                        })
950                        .map(RefCast::ref_cast)
951                        .collect();
952
953                    let sends: Vec<&SendCtx<W>> = send_ports
954                        .iter()
955                        .map(|hid| hid.handoff_id)
956                        .map(|hid| handoffs.get(hid).unwrap())
957                        .map(|h_data| {
958                            <dyn Any>::downcast_ref(&*h_data.handoff)
959                                .expect("Attempted to cast handoff to wrong type.")
960                        })
961                        .map(RefCast::ref_cast)
962                        .collect();
963
964                    (subgraph)(context, &recvs, &sends).await;
965                };
966            SubgraphData::new(
967                name.into(),
968                stratum,
969                subgraph,
970                subgraph_preds,
971                subgraph_succs,
972                true,
973                false,
974                None,
975                0,
976            )
977        });
978        self.context.init_stratum(stratum);
979        self.context.stratum_queues[stratum].push_back(sg_id);
980
981        // Initialize subgraph metrics struct.
982        Rc::make_mut(&mut self.metrics)
983            .subgraphs
984            .insert(sg_id, Default::default());
985
986        sg_id
987    }
988
989    /// Creates a handoff edge and returns the corresponding send and receive ports.
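    ///
    /// A short sketch (marked `ignore`; the concrete `VecHandoff` type is an assumption, not
    /// taken from this file):
    ///
    /// ```ignore
    /// let (send_port, recv_port) = df.make_edge::<_, VecHandoff<String>>("lines");
    /// // `send_port` is handed to the producing subgraph and `recv_port` to the consuming one,
    /// // e.g. via `add_subgraph` or `add_subgraph_n_m`.
    /// ```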
990    pub fn make_edge<Name, H>(&mut self, name: Name) -> (SendPort<H>, RecvPort<H>)
991    where
992        Name: Into<Cow<'static, str>>,
993        H: 'static + Handoff,
994    {
995        // Create and insert handoff.
996        let handoff = H::default();
997        let handoff_id = self
998            .handoffs
999            .insert_with_key(|hoff_id| HandoffData::new(name.into(), handoff, hoff_id));
1000
1001        // Initialize handoff metrics struct.
1002        Rc::make_mut(&mut self.metrics)
1003            .handoffs
1004            .insert(handoff_id, Default::default());
1005
1006        // Make ports.
1007        let input_port = SendPort {
1008            handoff_id,
1009            _marker: PhantomData,
1010        };
1011        let output_port = RecvPort {
1012            handoff_id,
1013            _marker: PhantomData,
1014        };
1015        (input_port, output_port)
1016    }
1017
1018    /// Adds referenceable state into this instance. Returns a state handle which can be
1019    /// used externally or by operators to access the state.
1020    ///
1021    /// This is part of the "state API".
1022    pub fn add_state<T>(&mut self, state: T) -> StateHandle<T>
1023    where
1024        T: Any,
1025    {
1026        self.context.add_state(state)
1027    }
1028
1029    /// Sets a hook to modify the state when the given [`StateLifespan`] ends (e.g. at the end of each tick), using the supplied closure.
1030    ///
1031    /// This is part of the "state API".
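    ///
    /// A sketch combining [`Self::add_state`] with a tick-scoped reset hook:
    ///
    /// ```ignore
    /// let handle = df.add_state(Vec::<usize>::new());
    /// // Clear the buffer at the end of every tick.
    /// df.set_state_lifespan_hook(handle, StateLifespan::Tick, |buf| buf.clear());
    /// ```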
1032    pub fn set_state_lifespan_hook<T>(
1033        &mut self,
1034        handle: StateHandle<T>,
1035        lifespan: StateLifespan,
1036        hook_fn: impl 'static + FnMut(&mut T),
1037    ) where
1038        T: Any,
1039    {
1040        self.context
1041            .set_state_lifespan_hook(handle, lifespan, hook_fn)
1042    }
1043
1044    /// Gets an exclusive (mutable) reference to the internal context, setting the subgraph ID.
1045    pub fn context_mut(&mut self, sg_id: SubgraphId) -> &mut Context {
1046        self.context.subgraph_id = sg_id;
1047        &mut self.context
1048    }
1049
1050    /// Adds a new loop with the given parent (or `None` for top-level). Returns a loop ID which
1051    /// is used in [`Self::add_subgraph_full`] or as the parent of nested loops.
1052    ///
1053    /// TODO(mingwei): add loop names to ensure traceability while debugging?
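    ///
    /// A sketch of declaring a nested loop structure:
    ///
    /// ```ignore
    /// let outer = df.add_loop(None);
    /// let inner = df.add_loop(Some(outer));
    /// // `outer` and `inner` can then be passed as the `loop_id` argument of
    /// // `add_subgraph_full` for subgraphs placed inside those loop blocks.
    /// ```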
1054    pub fn add_loop(&mut self, parent: Option<LoopId>) -> LoopId {
1055        let depth = parent.map_or(0, |p| self.context.loop_depth[p] + 1);
1056        let loop_id = self.context.loop_depth.insert(depth);
1057        self.loop_data.insert(
1058            loop_id,
1059            LoopData {
1060                iter_count: None,
1061                allow_another_iteration: true,
1062            },
1063        );
1064        loop_id
1065    }
1066
1067    /// Returns a reference-counted handle to the continually-updated runtime metrics for this DFIR instance.
1068    pub fn metrics(&self) -> Rc<DfirMetrics> {
1069        Rc::clone(&self.metrics)
1070    }
1071
1072    /// Returns an infinite iterator of [`DfirMetrics`], where each call to `next()` ends an interval.
1073    ///
1074    /// The first call to [`DfirMetricsIntervals::next`] returns metrics since this DFIR instance was created. Each
1075    /// subsequent call to [`DfirMetricsIntervals::next`] returns metrics since the previous call.
1076    ///
1077    /// Cloning the iterator "forks" it from the original, as afterwards each iterator will return different metrics
1078    /// based on when [`DfirMetricsIntervals::next`] is called.
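    ///
    /// A sketch of interval-based reporting (marked `ignore`):
    ///
    /// ```ignore
    /// let mut intervals = df.metrics_intervals();
    /// loop {
    ///     df.run_available().await;
    ///     // Metrics accumulated since the previous `next()` call.
    ///     let interval = intervals.next().unwrap();
    ///     // ... log or export `interval` ...
    /// }
    /// ```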
1079    pub fn metrics_intervals(&self) -> DfirMetricsIntervals {
1080        DfirMetricsIntervals {
1081            curr: self.metrics(),
1082            prev: None,
1083        }
1084    }
1085}
1086
1087impl Dfir<'_> {
1088    /// Alias for [`Context::request_task`].
1089    pub fn request_task<Fut>(&mut self, future: Fut)
1090    where
1091        Fut: Future<Output = ()> + 'static,
1092    {
1093        self.context.request_task(future);
1094    }
1095
1096    /// Alias for [`Context::abort_tasks`].
1097    pub fn abort_tasks(&mut self) {
1098        self.context.abort_tasks()
1099    }
1100
1101    /// Alias for [`Context::join_tasks`].
1102    pub fn join_tasks(&mut self) -> impl use<'_> + Future {
1103        self.context.join_tasks()
1104    }
1105}
1106
1107fn run_sync<Fut>(fut: Fut) -> Fut::Output
1108where
1109    Fut: Future,
1110{
1111    let mut fut = std::pin::pin!(fut);
1112    let mut ctx = std::task::Context::from_waker(std::task::Waker::noop());
1113    match fut.as_mut().poll(&mut ctx) {
1114        std::task::Poll::Ready(out) => out,
1115        std::task::Poll::Pending => panic!("Future did not resolve immediately."),
1116    }
1117}
1118
1119impl Drop for Dfir<'_> {
1120    fn drop(&mut self) {
1121        self.abort_tasks();
1122    }
1123}
1124
1125/// A handoff and its input and output [SubgraphId]s.
1126///
1127/// Internal use: used to track the dfir graph structure.
1128///
1129/// TODO(mingwei): restructure `PortList` so this can be crate-private.
1130#[doc(hidden)]
1131pub struct HandoffData {
1132    /// A friendly name for diagnostics.
1133    pub(super) name: Cow<'static, str>,
1134    /// Crate-visible for `handoff_list` internals.
1135    pub(super) handoff: Box<dyn HandoffMeta>,
1136    /// Preceding subgraphs (including the send side of a teeing handoff).
1137    pub(super) preds: SmallVec<[SubgraphId; 1]>,
1138    /// Successor subgraphs (including recv sides of teeing handoffs).
1139    pub(super) succs: SmallVec<[SubgraphId; 1]>,
1140
1141    /// Predecessor handoffs, used by teeing handoffs.
1142    /// Should be `self` on any teeing send sides (input).
1143    /// Should be the send `HandoffId` if this is a teeing recv side (output).
1144    /// Should be just `self`'s `HandoffId` on other handoffs.
1145    /// This field is only used in initialization.
1146    pub(super) pred_handoffs: Vec<HandoffId>,
1147    /// Successor handoffs, used by teeing handoffs.
1148    /// Should be a list of outputs on the teeing send side (input).
1149    /// Should be `self` on any teeing recv sides (outputs).
1150    /// Should be just `self`'s `HandoffId` on other handoffs.
1151    /// This field is only used in initialization.
1152    pub(super) succ_handoffs: Vec<HandoffId>,
1153}
1154
1155impl std::fmt::Debug for HandoffData {
1156    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
1157        f.debug_struct("HandoffData")
1158            .field("preds", &self.preds)
1159            .field("succs", &self.succs)
1160            .finish_non_exhaustive()
1161    }
1162}
1163
1164impl HandoffData {
1165    /// New with `pred_handoffs` and `succ_handoffs` set to its own [`HandoffId`]: `vec![hoff_id]`.
1166    pub fn new(
1167        name: Cow<'static, str>,
1168        handoff: impl 'static + HandoffMeta,
1169        hoff_id: HandoffId,
1170    ) -> Self {
1171        let (preds, succs) = Default::default();
1172        Self {
1173            name,
1174            handoff: Box::new(handoff),
1175            preds,
1176            succs,
1177            pred_handoffs: vec![hoff_id],
1178            succ_handoffs: vec![hoff_id],
1179        }
1180    }
1181}
1182
1183/// A subgraph along with its predecessor and successor [HandoffId]s.
1184///
1185/// Used internally by the [Dfir] struct to represent the dataflow graph
1186/// structure and scheduled state.
1187pub(super) struct SubgraphData<'a> {
1188    /// A friendly name for diagnostics.
1189    pub(super) name: Cow<'static, str>,
1190    /// This subgraph's stratum number.
1191    ///
1192    /// Within loop blocks, corresponds to the topological sort of the DAG created when `next_loop()/next_tick()` are removed.
1193    pub(super) stratum: usize,
1194    /// The actual execution code of the subgraph.
1195    subgraph: Box<dyn 'a + Subgraph>,
1196
1197    preds: Vec<HandoffId>,
1198    succs: Vec<HandoffId>,
1199
1200    /// If this subgraph is scheduled in [`Context::stratum_queues`].
1201    /// [`Cell`] allows modifying this field when iterating `Self::preds` or
1202    /// `Self::succs`, as all `SubgraphData` are owned by the same vec
1203    /// `Dfir::subgraphs`.
1204    is_scheduled: Cell<bool>,
1205
1206    /// Keeps track of the last tick that this subgraph was run in.
1207    last_tick_run_in: Option<TickInstant>,
1208    /// A meaningless ID to track the last loop execution this subgraph was run in.
1209    /// `(loop_nonce, iter_count)` pair.
1210    last_loop_nonce: (usize, Option<usize>),
1211
1212    /// If this subgraph is marked as lazy, then sending data back to a lower stratum does not trigger a new tick to be run.
1213    is_lazy: bool,
1214
1215    /// The subgraph's loop ID, or `None` for the top level.
1216    loop_id: Option<LoopId>,
1217    /// The loop depth of the subgraph.
1218    loop_depth: usize,
1219}
1220
1221impl<'a> SubgraphData<'a> {
1222    #[expect(clippy::too_many_arguments, reason = "internal use")]
1223    pub(crate) fn new(
1224        name: Cow<'static, str>,
1225        stratum: usize,
1226        subgraph: impl 'a + Subgraph,
1227        preds: Vec<HandoffId>,
1228        succs: Vec<HandoffId>,
1229        is_scheduled: bool,
1230        is_lazy: bool,
1231        loop_id: Option<LoopId>,
1232        loop_depth: usize,
1233    ) -> Self {
1234        Self {
1235            name,
1236            stratum,
1237            subgraph: Box::new(subgraph),
1238            preds,
1239            succs,
1240            is_scheduled: Cell::new(is_scheduled),
1241            last_tick_run_in: None,
1242            last_loop_nonce: (0, None),
1243            is_lazy,
1244            loop_id,
1245            loop_depth,
1246        }
1247    }
1248}
1249
1250pub(crate) struct LoopData {
1251    /// Count of iterations of this loop.
1252    iter_count: Option<usize>,
1253    /// Whether the loop has reason to do another iteration.
1254    allow_another_iteration: bool,
1255}
1256
1257/// Defines when state should be reset.
1258#[derive(Clone, Copy, Debug, Eq, PartialEq)]
1259pub enum StateLifespan {
1260    /// Always reset, associated with the subgraph.
1261    Subgraph(SubgraphId),
1262    /// Reset between loop executions.
1263    Loop(LoopId),
1264    /// Reset between ticks.
1265    Tick,
1266    /// Never reset.
1267    Static,
1268}