dfir_rs/scheduled/

//! Module for the [`Dfir`] struct and helper items.

use std::any::Any;
use std::borrow::Cow;
use std::cell::Cell;
use std::cmp::Ordering;
use std::future::Future;
use std::marker::PhantomData;
use std::rc::Rc;

#[cfg(feature = "meta")]
use dfir_lang::diagnostic::{Diagnostic, SerdeSpan};
#[cfg(feature = "meta")]
use dfir_lang::graph::DfirGraph;
use ref_cast::RefCast;
use smallvec::SmallVec;
use tracing::Instrument;
use web_time::SystemTime;

use super::context::Context;
use super::handoff::handoff_list::PortList;
use super::handoff::{Handoff, HandoffMeta, TeeingHandoff};
use super::metrics::{DfirMetrics, DfirMetricsState, InstrumentSubgraph};
use super::port::{RECV, RecvCtx, RecvPort, SEND, SendCtx, SendPort};
use super::reactor::Reactor;
use super::state::StateHandle;
use super::subgraph::Subgraph;
use super::ticks::{TickDuration, TickInstant};
use super::{HandoffId, HandoffTag, LoopId, LoopTag, SubgraphId, SubgraphTag};
use crate::Never;
use crate::util::slot_vec::{SecondarySlotVec, SlotVec};

/// A DFIR graph. Owns, schedules, and runs the compiled subgraphs.
#[derive(Default)]
pub struct Dfir<'a> {
    pub(super) subgraphs: SlotVec<SubgraphTag, SubgraphData<'a>>,

    pub(super) loop_data: SecondarySlotVec<LoopTag, LoopData>,

    pub(super) context: Context,

    pub(super) handoffs: SlotVec<HandoffTag, HandoffData>,

    metrics: Rc<DfirMetricsState>,

    /// See [`Self::meta_graph()`].
    #[cfg(feature = "meta")]
    meta_graph: Option<DfirGraph>,

    /// See [`Self::diagnostics()`].
    #[cfg(feature = "meta")]
    diagnostics: Option<Vec<Diagnostic<SerdeSpan>>>,
}
/// Methods for [`TeeingHandoff`] teeing and dropping.
impl Dfir<'_> {
    /// Tees a [`TeeingHandoff`].
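    ///
    /// A minimal sketch of usage (assumes `TeeingHandoff` implements [`Handoff`] so an edge can
    /// be created via [`Self::make_edge`]):
    ///
    /// ```rust,ignore
    /// let mut df = Dfir::new();
    /// let (_send_port, recv1) = df.make_edge::<_, TeeingHandoff<usize>>("tee edge");
    /// // Create a second output which receives clones of the same data.
    /// let recv2 = df.teeing_handoff_tee(&recv1);
    /// ```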
    pub fn teeing_handoff_tee<T>(
        &mut self,
        tee_parent_port: &RecvPort<TeeingHandoff<T>>,
    ) -> RecvPort<TeeingHandoff<T>>
    where
        T: Clone,
    {
        // If we're teeing from a child, make sure to find the root.
        let tee_root = self.handoffs[tee_parent_port.handoff_id].pred_handoffs[0];

        // Set up teeing metadata.
        let tee_root_data = &mut self.handoffs[tee_root];
        let tee_root_data_name = tee_root_data.name.clone();

        // Insert new handoff output.
        let teeing_handoff =
            <dyn Any>::downcast_ref::<TeeingHandoff<T>>(&*tee_root_data.handoff).unwrap();
        let new_handoff = teeing_handoff.tee();

        // Handoff ID of the new tee output.
        let new_hoff_id = self.handoffs.insert_with_key(|new_hoff_id| {
            let new_name = Cow::Owned(format!("{} tee {:?}", tee_root_data_name, new_hoff_id));
            let mut new_handoff_data = HandoffData::new(new_name, new_handoff, new_hoff_id);
            // Set self's predecessor as `tee_root`.
            new_handoff_data.pred_handoffs = vec![tee_root];
            new_handoff_data
        });

        // Go to `tee_root`'s successors and insert self (the new tee output).
        let tee_root_data = &mut self.handoffs[tee_root];
        tee_root_data.succ_handoffs.push(new_hoff_id);

        // Add the new handoff ID into the subgraph data if the send side of `tee_root` has
        // already been used to add a subgraph.
        assert!(
            tee_root_data.preds.len() <= 1,
            "Tee send side should only have one sender (or none set yet)."
        );
        if let Some(&pred_sg_id) = tee_root_data.preds.first() {
            self.subgraphs[pred_sg_id].succs.push(new_hoff_id);
        }

        // Initialize handoff metrics struct.
        Rc::make_mut(&mut self.metrics)
            .handoff_metrics
            .insert(new_hoff_id, Default::default());

        RecvPort {
            handoff_id: new_hoff_id,
            _marker: PhantomData,
        }
    }

    /// Marks an output of a [`TeeingHandoff`] as dropped so that no more data will be sent to it.
    ///
    /// It is recommended not to use this method and instead simply avoid teeing a
    /// [`TeeingHandoff`] when it is not needed.
    pub fn teeing_handoff_drop<T>(&mut self, tee_port: RecvPort<TeeingHandoff<T>>)
    where
        T: Clone,
    {
        let data = &self.handoffs[tee_port.handoff_id];
        let teeing_handoff = <dyn Any>::downcast_ref::<TeeingHandoff<T>>(&*data.handoff).unwrap();
        teeing_handoff.drop();

        let tee_root = data.pred_handoffs[0];
        let tee_root_data = &mut self.handoffs[tee_root];
        // Remove this output from the send side's successor handoff list.
        tee_root_data
            .succ_handoffs
            .retain(|&succ_hoff| succ_hoff != tee_port.handoff_id);
        // Remove from subgraph successors if the send port was already connected.
        assert!(
            tee_root_data.preds.len() <= 1,
            "Tee send side should only have one sender (or none set yet)."
        );
        if let Some(&pred_sg_id) = tee_root_data.preds.first() {
            self.subgraphs[pred_sg_id]
                .succs
                .retain(|&succ_hoff| succ_hoff != tee_port.handoff_id);
        }
    }
}

impl<'a> Dfir<'a> {
    /// Create a new empty graph.
    pub fn new() -> Self {
        Default::default()
    }

    /// Assign the meta graph via JSON string. Used internally by [`crate::dfir_syntax`] and
    /// other macros.
    #[doc(hidden)]
    pub fn __assign_meta_graph(&mut self, _meta_graph_json: &str) {
        #[cfg(feature = "meta")]
        {
            let mut meta_graph: DfirGraph =
                serde_json::from_str(_meta_graph_json).expect("Failed to deserialize graph.");

            let mut op_inst_diagnostics = Vec::new();
            meta_graph.insert_node_op_insts_all(&mut op_inst_diagnostics);
            assert!(
                op_inst_diagnostics.is_empty(),
                "Expected no diagnostics, got: {:#?}",
                op_inst_diagnostics
            );

            assert!(self.meta_graph.replace(meta_graph).is_none());
        }
    }

    /// Assign the diagnostics via JSON string.
    #[doc(hidden)]
    pub fn __assign_diagnostics(&mut self, _diagnostics_json: &'static str) {
        #[cfg(feature = "meta")]
        {
            let diagnostics: Vec<Diagnostic<SerdeSpan>> = serde_json::from_str(_diagnostics_json)
                .expect("Failed to deserialize diagnostics.");

            assert!(self.diagnostics.replace(diagnostics).is_none());
        }
    }

    /// Return a handle to the meta graph, if set. The meta graph is a
    /// representation of all the operators, subgraphs, and handoffs in this instance.
    /// Will only be set if this graph was constructed using a surface syntax macro.
    #[cfg(feature = "meta")]
    #[cfg_attr(docsrs, doc(cfg(feature = "meta")))]
    pub fn meta_graph(&self) -> Option<&DfirGraph> {
        self.meta_graph.as_ref()
    }

    /// Returns any diagnostics generated by the surface syntax macro. Each diagnostic is a pair of
    /// (1) a `Diagnostic` with span info reset and (2) the `ToString` version of the diagnostic
    /// with original span info.
    /// Will only be set if this graph was constructed using a surface syntax macro.
    #[cfg(feature = "meta")]
    #[cfg_attr(docsrs, doc(cfg(feature = "meta")))]
    pub fn diagnostics(&self) -> Option<&[Diagnostic<SerdeSpan>]> {
        self.diagnostics.as_deref()
    }

    /// Returns a reactor for externally scheduling subgraphs, possibly from another thread.
    /// Reactor events are considered to be external events.
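    ///
    /// A minimal sketch (assumes `Reactor::trigger` as the scheduling entry point):
    ///
    /// ```rust,ignore
    /// let reactor = df.reactor();
    /// std::thread::spawn(move || {
    ///     // Schedule a subgraph from another thread; this counts as an external event.
    ///     reactor.trigger(sg_id).unwrap();
    /// });
    /// ```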
    pub fn reactor(&self) -> Reactor {
        Reactor::new(self.context.event_queue_send.clone())
    }

    /// Gets the current tick (local time) count.
    pub fn current_tick(&self) -> TickInstant {
        self.context.current_tick
    }

    /// Gets the current stratum number.
    pub fn current_stratum(&self) -> usize {
        self.context.current_stratum
    }

    /// Runs the dataflow until the next tick begins.
    ///
    /// Returns `true` if any work was done.
    ///
    /// Will receive events if it is the start of a tick when called.
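    ///
    /// A minimal sketch of driving one tick at a time (assumes the [`crate::dfir_syntax`] macro
    /// and its `source_iter`/`for_each` operators):
    ///
    /// ```rust,ignore
    /// let mut df = dfir_rs::dfir_syntax! {
    ///     source_iter(0..3) -> for_each(|x| println!("{}", x));
    /// };
    /// // Runs all strata of the current tick, stopping at the start of the next tick.
    /// df.run_tick().await;
    /// ```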
    #[tracing::instrument(level = "trace", skip(self), ret)]
    pub async fn run_tick(&mut self) -> bool {
        let mut work_done = false;
        // While work is immediately available *on the current tick*.
        while self.next_stratum(true) {
            work_done = true;
            // Do any work.
            self.run_stratum().await;
        }
        work_done
    }

    /// [`Self::run_tick`] but panics if a subgraph yields asynchronously.
    #[tracing::instrument(level = "trace", skip(self), ret)]
    pub fn run_tick_sync(&mut self) -> bool {
        let mut work_done = false;
        // While work is immediately available *on the current tick*.
        while self.next_stratum(true) {
            work_done = true;
            // Do any work.
            run_sync(self.run_stratum());
        }
        work_done
    }

    /// Runs the dataflow until no more (externally-triggered) work is immediately available.
    /// Runs at least one tick of dataflow, even if no external events have been received.
    /// If the dataflow contains loops this method may run forever.
    ///
    /// Returns `true` if any work was done.
    ///
    /// Yields repeatedly to allow external events to happen.
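    ///
    /// A minimal sketch (assumes the [`crate::dfir_syntax`] macro):
    ///
    /// ```rust,ignore
    /// let mut df = dfir_rs::dfir_syntax! {
    ///     source_iter(0..3) -> for_each(|x| println!("{}", x));
    /// };
    /// // Runs until no more work is immediately available, yielding between strata.
    /// df.run_available().await;
    /// ```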
    #[tracing::instrument(level = "trace", skip(self), ret)]
    pub async fn run_available(&mut self) -> bool {
        let mut work_done = false;
        // While work is immediately available.
        while self.next_stratum(false) {
            work_done = true;
            // Do any work.
            self.run_stratum().await;

            // Yield between each stratum to receive more events.
            // TODO(mingwei): really only need to yield at start of ticks though.
            tokio::task::yield_now().await;
        }
        work_done
    }

    /// [`Self::run_available`] but panics if a subgraph yields asynchronously.
    #[tracing::instrument(level = "trace", skip(self), ret)]
    pub fn run_available_sync(&mut self) -> bool {
        let mut work_done = false;
        // While work is immediately available.
        while self.next_stratum(false) {
            work_done = true;
            // Do any work.
            run_sync(self.run_stratum());
        }
        work_done
    }

    /// Runs the current stratum of the dataflow until no more local work is available (does not receive events).
    ///
    /// Returns `true` if any work was done.
    #[tracing::instrument(level = "trace", skip(self), fields(tick = u64::from(self.context.current_tick), stratum = self.context.current_stratum), ret)]
    pub async fn run_stratum(&mut self) -> bool {
        // Make sure to spawn tasks once dfir is running!
        // This drains the task buffer, so becomes a no-op after the first call.
        self.context.spawn_tasks();

        let mut work_done = false;

        'pop: while let Some(sg_id) =
            self.context.stratum_queues[self.context.current_stratum].pop_front()
        {
            let sg_data = &mut self.subgraphs[sg_id];

            // Update handoff metrics.
            // NOTE(mingwei):
            // We measure handoff metrics at `recv`, not `send`.
            // This is done because we (a) know that all `recv`
            // will be consumed(*) by the subgraph when it is run, and in contrast (b) do not know
            // that all `send` will be consumed after a run. I.e. this subgraph may run multiple
            // times before the next subgraph consumes the output; don't double count those.
            // (*) - usually... always true for Hydro-generated DFIR at least.
            for &handoff_id in sg_data.preds.iter() {
                let handoff_metrics = &self.metrics.handoff_metrics[handoff_id];
                let handoff_data = &mut self.handoffs[handoff_id];
                let handoff_len = handoff_data.handoff.len();
                handoff_metrics
                    .total_items_count
                    .update(|x| x + handoff_len);
                handoff_metrics.curr_items_count.set(handoff_len);
            }

            // Run the subgraph (and do bookkeeping).
            {
                // This must be true for the subgraph to be enqueued.
                assert!(sg_data.is_scheduled.take());

                let run_subgraph_span_guard = tracing::info_span!(
                    "run-subgraph",
                    sg_id = sg_id.to_string(),
                    sg_name = &*sg_data.name,
                    sg_depth = sg_data.loop_depth,
                    sg_loop_nonce = sg_data.last_loop_nonce.0,
                    sg_iter_count = sg_data.last_loop_nonce.1,
                )
                .entered();

                match sg_data.loop_depth.cmp(&self.context.loop_nonce_stack.len()) {
                    Ordering::Greater => {
                        // We have entered a loop.
                        self.context.loop_nonce += 1;
                        self.context.loop_nonce_stack.push(self.context.loop_nonce);
                        tracing::trace!(loop_nonce = self.context.loop_nonce, "Entered loop.");
                    }
                    Ordering::Less => {
                        // We have exited a loop.
                        self.context.loop_nonce_stack.pop();
                        tracing::trace!("Exited loop.");
                    }
                    Ordering::Equal => {}
                }

                self.context.subgraph_id = sg_id;
                self.context.is_first_run_this_tick = sg_data
                    .last_tick_run_in
                    .is_none_or(|last_tick| last_tick < self.context.current_tick);

                if let Some(loop_id) = sg_data.loop_id {
                    // Loop execution - running a loop block, from start to finish, containing
                    // multiple iterations.
                    // Loop iteration - a single iteration of a loop block, in which all subgraphs
                    // within the loop should run (at most) once.

                    // If the previous run of this subgraph had the same loop execution and
                    // iteration count, then we need to increment the iteration count.
                    let curr_loop_nonce = self.context.loop_nonce_stack.last().copied();

                    let LoopData {
                        iter_count: loop_iter_count,
                        allow_another_iteration,
                    } = &mut self.loop_data[loop_id];

                    let (prev_loop_nonce, prev_iter_count) = sg_data.last_loop_nonce;

                    // If the loop nonce is the same as the previous execution, then we are in
                    // the same loop execution.
                    // `curr_loop_nonce` is `None` for top-level loops, and top-level loops are
                    // always in the same (singular) loop execution.
                    let (curr_iter_count, new_loop_execution) =
                        if curr_loop_nonce.is_none_or(|nonce| nonce == prev_loop_nonce) {
                            // If the iteration count is the same as the previous execution, then
                            // we are on the next iteration.
                            if *loop_iter_count == prev_iter_count {
                                // If another iteration is not allowed, skip running it.
                                if !std::mem::take(allow_another_iteration) {
                                    tracing::debug!(
                                        "Loop will not continue to next iteration, skipping."
                                    );
                                    continue 'pop;
                                }
                                // Increment `loop_iter_count`, or set it to 0 and flag a new
                                // loop execution.
                                loop_iter_count.map_or((0, true), |n| (n + 1, false))
                            } else {
                                // Otherwise update the local iteration count to match the loop.
                                debug_assert!(
                                    prev_iter_count < *loop_iter_count,
                                    "Expect loop iteration count to be increasing."
                                );
                                (loop_iter_count.unwrap(), false)
                            }
                        } else {
                            // The loop execution has already begun, but this is the first time
                            // this particular subgraph is running.
                            (0, false)
                        };

                    if new_loop_execution {
                        // Run state hooks.
                        self.context.run_state_hooks_loop(loop_id);
                    }
                    tracing::debug!("Loop iteration count {}", curr_iter_count);

                    *loop_iter_count = Some(curr_iter_count);
                    self.context.loop_iter_count = curr_iter_count;
                    sg_data.last_loop_nonce =
                        (curr_loop_nonce.unwrap_or_default(), Some(curr_iter_count));
                }

                // Run subgraph state hooks.
                self.context.run_state_hooks_subgraph(sg_id);

                tracing::info!("Running subgraph.");
                sg_data.last_tick_run_in = Some(self.context.current_tick);

                let sg_metrics = &self.metrics.subgraph_metrics[sg_id];
                let sg_fut =
                    Box::into_pin(sg_data.subgraph.run(&mut self.context, &mut self.handoffs));
                // Update subgraph metrics.
                let sg_fut = InstrumentSubgraph::new(sg_fut, sg_metrics);
                // Pass along `run_subgraph_span` to the run subgraph future.
                let sg_fut = sg_fut.instrument(run_subgraph_span_guard.exit());
                let () = sg_fut.await;

                sg_metrics.total_run_count.update(|x| x + 1);
            };

            // Schedule the following subgraphs if data was pushed to the corresponding handoff.
            let sg_data = &self.subgraphs[sg_id];
            for &handoff_id in sg_data.succs.iter() {
                let handoff_data = &self.handoffs[handoff_id];
                let handoff_len = handoff_data.handoff.len();
                if 0 < handoff_len {
                    for &succ_id in handoff_data.succs.iter() {
                        let succ_sg_data = &self.subgraphs[succ_id];
                        // If we have sent data to the next tick, then we can start the next tick.
                        if succ_sg_data.stratum < self.context.current_stratum && !sg_data.is_lazy {
                            self.context.can_start_tick = true;
                        }
                        // Add subgraph to stratum queue if it is not already scheduled.
                        if !succ_sg_data.is_scheduled.replace(true) {
                            self.context.stratum_queues[succ_sg_data.stratum].push_back(succ_id);
                        }
                        // Add stratum to stratum stack if it is within a loop.
                        if 0 < succ_sg_data.loop_depth {
                            // TODO(mingwei): handle duplicates
                            self.context
                                .stratum_stack
                                .push(succ_sg_data.loop_depth, succ_sg_data.stratum);
                        }
                    }
                }
                let handoff_metrics = &self.metrics.handoff_metrics[handoff_id];
                handoff_metrics.curr_items_count.set(handoff_len);
            }

            let reschedule = self.context.reschedule_loop_block.take();
            let allow_another = self.context.allow_another_iteration.take();

            if reschedule {
                // Re-enqueue the subgraph.
                self.context.schedule_deferred.push(sg_id);
                self.context
                    .stratum_stack
                    .push(sg_data.loop_depth, sg_data.stratum);
            }
            if (reschedule || allow_another)
                && let Some(loop_id) = sg_data.loop_id
            {
                self.loop_data
                    .get_mut(loop_id)
                    .unwrap()
                    .allow_another_iteration = true;
            }

            work_done = true;
        }
        work_done
    }

    /// Go to the next stratum which has work available, possibly the current stratum.
    /// Returns `true` if more work is available, otherwise `false` if no work is immediately
    /// available on any stratum.
    ///
    /// This will receive external events when at the start of a tick.
    ///
    /// If `current_tick_only` is set to `true`, will only return `true` if work is immediately
    /// available on the *current tick*.
    ///
    /// If this returns `false` then the graph will be at the start of a tick (at stratum 0, able
    /// to receive more external events).
    #[tracing::instrument(level = "trace", skip(self), ret)]
    pub fn next_stratum(&mut self, current_tick_only: bool) -> bool {
        tracing::trace!(
            events_received_tick = self.context.events_received_tick,
            can_start_tick = self.context.can_start_tick,
            "Starting `next_stratum` call.",
        );

        // The stratum we will stop searching at, i.e. where we have made a full loop around.
        let mut end_stratum = self.context.current_stratum;
        let mut new_tick_started = false;

        if 0 == self.context.current_stratum {
            new_tick_started = true;

            // Starting the tick, reset this to `false`.
            tracing::trace!("Starting tick, setting `can_start_tick = false`.");
            self.context.can_start_tick = false;
            self.context.current_tick_start = SystemTime::now();

            // Ensure external events are received before running the tick.
            if !self.context.events_received_tick {
                // Add any external jobs to ready queue.
                self.try_recv_events();
            }
        }

        loop {
            tracing::trace!(
                tick = u64::from(self.context.current_tick),
                stratum = self.context.current_stratum,
                "Looking for work on stratum."
            );
            // If current stratum has work, return true.
            if !self.context.stratum_queues[self.context.current_stratum].is_empty() {
                tracing::trace!(
                    tick = u64::from(self.context.current_tick),
                    stratum = self.context.current_stratum,
                    "Work found on stratum."
                );
                return true;
            }

            if let Some(next_stratum) = self.context.stratum_stack.pop() {
                self.context.current_stratum = next_stratum;

                // Now schedule deferred subgraphs.
                {
                    for sg_id in self.context.schedule_deferred.drain(..) {
                        let sg_data = &self.subgraphs[sg_id];
                        tracing::info!(
                            tick = u64::from(self.context.current_tick),
                            stratum = self.context.current_stratum,
                            sg_id = sg_id.to_string(),
                            sg_name = &*sg_data.name,
                            is_scheduled = sg_data.is_scheduled.get(),
                            "Rescheduling deferred subgraph."
                        );
                        if !sg_data.is_scheduled.replace(true) {
                            self.context.stratum_queues[sg_data.stratum].push_back(sg_id);
                        }
                    }
                }
            } else {
                // Increment stratum counter.
                self.context.current_stratum += 1;

                if self.context.current_stratum >= self.context.stratum_queues.len() {
                    new_tick_started = true;

                    tracing::trace!(
                        can_start_tick = self.context.can_start_tick,
                        "End of tick {}, starting tick {}.",
                        self.context.current_tick,
                        self.context.current_tick + TickDuration::SINGLE_TICK,
                    );
                    self.context.run_state_hooks_tick();

                    self.context.current_stratum = 0;
                    self.context.current_tick += TickDuration::SINGLE_TICK;
                    self.context.events_received_tick = false;

                    if current_tick_only {
                        tracing::trace!(
                            "`current_tick_only` is `true`, returning `false` before receiving events."
                        );
                        return false;
                    } else {
                        self.try_recv_events();
                        if std::mem::replace(&mut self.context.can_start_tick, false) {
                            tracing::trace!(
                                tick = u64::from(self.context.current_tick),
                                "`can_start_tick` is `true`, continuing."
                            );
                            // Do one more full loop to find where events have been added.
                            end_stratum = 0;
                            continue;
                        } else {
                            tracing::trace!(
                                "`can_start_tick` is `false`, re-setting `events_received_tick = false`, returning `false`."
                            );
                            self.context.events_received_tick = false;
                            return false;
                        }
                    }
                }
            }

            // After incrementing, exit if we made a full loop around the strata.
            if new_tick_started && end_stratum == self.context.current_stratum {
                tracing::trace!(
                    "Made full loop around strata, re-setting `current_stratum = 0`, returning `false`."
                );
                // Note: if the current stratum had work, the very first loop iteration would've
                // returned true. Therefore we can return false without checking.
                // Also means nothing was done, so we can reset the stratum to zero and wait for
                // events.
                self.context.events_received_tick = false;
                self.context.current_stratum = 0;
                return false;
            }
        }
    }

    /// Runs the dataflow graph forever.
    ///
    /// TODO(mingwei): Currently blocks forever, no notion of "completion."
    #[tracing::instrument(level = "trace", skip(self), ret)]
    pub async fn run(&mut self) -> Option<Never> {
        loop {
            // Run any work which is immediately available.
            self.run_available().await;
            // When no work is available, yield until more events occur.
            self.recv_events_async().await;
        }
    }

    /// [`Self::run`] but panics if a subgraph yields asynchronously.
    #[tracing::instrument(level = "trace", skip(self), ret)]
    pub fn run_sync(&mut self) -> Option<Never> {
        loop {
            // Run any work which is immediately available.
            self.run_available_sync();
            // When no work is available, block until more events occur.
            self.recv_events();
        }
    }

    /// Enqueues subgraphs triggered by events without blocking.
    ///
    /// Returns the number of subgraphs enqueued.
    #[tracing::instrument(level = "trace", skip(self), fields(events_received_tick = self.context.events_received_tick), ret)]
    pub fn try_recv_events(&mut self) -> usize {
        let mut enqueued_count = 0;
        while let Ok((sg_id, is_external)) = self.context.event_queue_recv.try_recv() {
            let sg_data = &self.subgraphs[sg_id];
            tracing::trace!(
                sg_id = sg_id.to_string(),
                is_external = is_external,
                sg_stratum = sg_data.stratum,
                "Event received."
            );
            if !sg_data.is_scheduled.replace(true) {
                self.context.stratum_queues[sg_data.stratum].push_back(sg_id);
                enqueued_count += 1;
            }
            if is_external {
                // The next tick is triggered if we are at the start of the next tick
                // (`!self.context.events_received_tick`), or if the stratum is in the next tick.
                if !self.context.events_received_tick
                    || sg_data.stratum < self.context.current_stratum
                {
                    tracing::trace!(
                        current_stratum = self.context.current_stratum,
                        sg_stratum = sg_data.stratum,
                        "External event, setting `can_start_tick = true`."
                    );
                    self.context.can_start_tick = true;
                }
            }
        }
        self.context.events_received_tick = true;

        enqueued_count
    }

    /// Enqueues subgraphs triggered by external events, blocking until at
    /// least one subgraph is scheduled **from an external event**.
    ///
    /// Returns the number of subgraphs enqueued, or `None` if the event queue is closed.
    #[tracing::instrument(level = "trace", skip(self), fields(events_received_tick = self.context.events_received_tick), ret)]
    pub fn recv_events(&mut self) -> Option<usize> {
        let mut count = 0;
        loop {
            let (sg_id, is_external) = self.context.event_queue_recv.blocking_recv()?;
            let sg_data = &self.subgraphs[sg_id];
            tracing::trace!(
                sg_id = sg_id.to_string(),
                is_external = is_external,
                sg_stratum = sg_data.stratum,
                "Event received."
            );
            if !sg_data.is_scheduled.replace(true) {
                self.context.stratum_queues[sg_data.stratum].push_back(sg_id);
                count += 1;
            }
            if is_external {
                // The next tick is triggered if we are at the start of the next tick
                // (`!self.context.events_received_tick`), or if the stratum is in the next tick.
                if !self.context.events_received_tick
                    || sg_data.stratum < self.context.current_stratum
                {
                    tracing::trace!(
                        current_stratum = self.context.current_stratum,
                        sg_stratum = sg_data.stratum,
                        "External event, setting `can_start_tick = true`."
                    );
                    self.context.can_start_tick = true;
                }
                break;
            }
        }
        self.context.events_received_tick = true;

        // Enqueue any other immediate events.
        let extra_count = self.try_recv_events();
        Some(count + extra_count)
    }

    /// Enqueues subgraphs triggered by external events asynchronously, waiting until at least one
    /// subgraph is scheduled **from an external event**. Returns the number of subgraphs enqueued,
    /// which may be zero if an external event scheduled an already-scheduled subgraph.
    ///
    /// Returns `None` if the event queue is closed, but that should not happen normally.
    #[tracing::instrument(level = "trace", skip(self), fields(events_received_tick = self.context.events_received_tick), ret)]
    pub async fn recv_events_async(&mut self) -> Option<usize> {
        let mut count = 0;
        loop {
            tracing::trace!("Awaiting events (`event_queue_recv`).");
            let (sg_id, is_external) = self.context.event_queue_recv.recv().await?;
            let sg_data = &self.subgraphs[sg_id];
            tracing::trace!(
                sg_id = sg_id.to_string(),
                is_external = is_external,
                sg_stratum = sg_data.stratum,
                "Event received."
            );
            if !sg_data.is_scheduled.replace(true) {
                self.context.stratum_queues[sg_data.stratum].push_back(sg_id);
                count += 1;
            }
            if is_external {
                // The next tick is triggered if we are at the start of the next tick
                // (`!self.context.events_received_tick`), or if the stratum is in the next tick.
                if !self.context.events_received_tick
                    || sg_data.stratum < self.context.current_stratum
                {
                    tracing::trace!(
                        current_stratum = self.context.current_stratum,
                        sg_stratum = sg_data.stratum,
                        "External event, setting `can_start_tick = true`."
                    );
                    self.context.can_start_tick = true;
                }
                break;
            }
        }
        self.context.events_received_tick = true;

        // Enqueue any other immediate events.
        let extra_count = self.try_recv_events();
        Some(count + extra_count)
    }

    /// Schedules a subgraph to be run. See also: [`Context::schedule_subgraph`].
    pub fn schedule_subgraph(&mut self, sg_id: SubgraphId) -> bool {
        let sg_data = &self.subgraphs[sg_id];
        let already_scheduled = sg_data.is_scheduled.replace(true);
        if !already_scheduled {
            self.context.stratum_queues[sg_data.stratum].push_back(sg_id);
            true
        } else {
            false
        }
    }

    /// Adds a new compiled subgraph with the specified inputs and outputs in stratum 0.
    pub fn add_subgraph<Name, R, W, Func>(
        &mut self,
        name: Name,
        recv_ports: R,
        send_ports: W,
        subgraph: Func,
    ) -> SubgraphId
    where
        Name: Into<Cow<'static, str>>,
        R: 'static + PortList<RECV>,
        W: 'static + PortList<SEND>,
        Func: 'a + for<'ctx> AsyncFnMut(&'ctx mut Context, R::Ctx<'ctx>, W::Ctx<'ctx>),
    {
        self.add_subgraph_stratified(name, 0, recv_ports, send_ports, false, subgraph)
    }

    /// Adds a new compiled subgraph with the specified inputs, outputs, and stratum number.
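    ///
    /// A minimal sketch (assumes the crate's `tl!` tuple-list macro for building [`PortList`]
    /// values, and a `VecHandoff` handoff type):
    ///
    /// ```rust,ignore
    /// let mut df = Dfir::new();
    /// let (send_port, recv_port) = df.make_edge::<_, VecHandoff<usize>>("edge");
    /// // Producer in stratum 0.
    /// df.add_subgraph_stratified(
    ///     "producer",
    ///     0,
    ///     tl!(),
    ///     tl!(send_port),
    ///     false,
    ///     async move |_context, tl!(), tl!(send)| {
    ///         send.give(Some(7));
    ///     },
    /// );
    /// // Consumer in stratum 1, which runs after stratum 0 within each tick.
    /// df.add_subgraph_stratified(
    ///     "consumer",
    ///     1,
    ///     tl!(recv_port),
    ///     tl!(),
    ///     false,
    ///     async move |_context, tl!(recv), tl!()| {
    ///         for v in recv.take_inner() {
    ///             println!("{}", v);
    ///         }
    ///     },
    /// );
    /// ```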
    pub fn add_subgraph_stratified<Name, R, W, Func>(
        &mut self,
        name: Name,
        stratum: usize,
        recv_ports: R,
        send_ports: W,
        laziness: bool,
        subgraph: Func,
    ) -> SubgraphId
    where
        Name: Into<Cow<'static, str>>,
        R: 'static + PortList<RECV>,
        W: 'static + PortList<SEND>,
        Func: 'a + for<'ctx> AsyncFnMut(&'ctx mut Context, R::Ctx<'ctx>, W::Ctx<'ctx>),
    {
        self.add_subgraph_full(
            name, stratum, recv_ports, send_ports, laziness, None, subgraph,
        )
    }

    /// Adds a new compiled subgraph with all options.
    #[expect(clippy::too_many_arguments, reason = "Mainly for internal use.")]
    pub fn add_subgraph_full<Name, R, W, Func>(
        &mut self,
        name: Name,
        stratum: usize,
        recv_ports: R,
        send_ports: W,
        laziness: bool,
        loop_id: Option<LoopId>,
        mut subgraph: Func,
    ) -> SubgraphId
    where
        Name: Into<Cow<'static, str>>,
        R: 'static + PortList<RECV>,
        W: 'static + PortList<SEND>,
        Func: 'a + for<'ctx> AsyncFnMut(&'ctx mut Context, R::Ctx<'ctx>, W::Ctx<'ctx>),
    {
        // Check that the send and recv ports are from `self.handoffs`, to uphold the safety
        // requirements of `make_ctx` below.
        recv_ports.assert_is_from(&self.handoffs);
        send_ports.assert_is_from(&self.handoffs);

        let loop_depth = loop_id
            .and_then(|loop_id| self.context.loop_depth.get(loop_id))
            .copied()
            .unwrap_or(0);

        let sg_id = self.subgraphs.insert_with_key(|sg_id| {
            let (mut subgraph_preds, mut subgraph_succs) = Default::default();
            recv_ports.set_graph_meta(&mut self.handoffs, &mut subgraph_preds, sg_id, true);
            send_ports.set_graph_meta(&mut self.handoffs, &mut subgraph_succs, sg_id, false);

            let subgraph =
                async move |context: &mut Context,
                            handoffs: &mut SlotVec<HandoffTag, HandoffData>| {
                    let (recv, send) = unsafe {
                        // SAFETY:
                        // 1. We checked `assert_is_from` at assembly time, above.
                        // 2. `SlotVec` is insert-only so no handoffs could have changed since then.
                        (
                            recv_ports.make_ctx(&*handoffs),
                            send_ports.make_ctx(&*handoffs),
                        )
                    };
                    (subgraph)(context, recv, send).await;
                };
            SubgraphData::new(
                name.into(),
                stratum,
                subgraph,
                subgraph_preds,
                subgraph_succs,
                true,
                laziness,
                loop_id,
                loop_depth,
            )
        });
        self.context.init_stratum(stratum);
        self.context.stratum_queues[stratum].push_back(sg_id);

        // Initialize subgraph metrics struct.
        Rc::make_mut(&mut self.metrics)
            .subgraph_metrics
            .insert(sg_id, Default::default());

        sg_id
    }

    /// Adds a new compiled subgraph with a variable number of inputs and outputs of the same respective handoff types.
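    ///
    /// A minimal sketch (assumes a `VecHandoff` handoff type):
    ///
    /// ```rust,ignore
    /// let mut df = Dfir::new();
    /// let (send_port, recv_port) = df.make_edge::<_, VecHandoff<usize>>("edge");
    /// df.add_subgraph_n_m(
    ///     "printer",
    ///     vec![recv_port],
    ///     Vec::<SendPort<VecHandoff<usize>>>::new(),
    ///     async move |_context, recvs, _sends| {
    ///         for recv in recvs {
    ///             for v in recv.take_inner() {
    ///                 println!("{}", v);
    ///             }
    ///         }
    ///     },
    /// );
    /// ```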
    pub fn add_subgraph_n_m<Name, R, W, Func>(
        &mut self,
        name: Name,
        recv_ports: Vec<RecvPort<R>>,
        send_ports: Vec<SendPort<W>>,
        subgraph: Func,
    ) -> SubgraphId
    where
        Name: Into<Cow<'static, str>>,
        R: 'static + Handoff,
        W: 'static + Handoff,
        Func: 'a
            + for<'ctx> AsyncFnMut(
                &'ctx mut Context,
                &'ctx [&'ctx RecvCtx<R>],
                &'ctx [&'ctx SendCtx<W>],
            ),
    {
        self.add_subgraph_stratified_n_m(name, 0, recv_ports, send_ports, subgraph)
    }

    /// Adds a new compiled subgraph, in the given stratum, with a variable number of inputs and
    /// outputs of the same respective handoff types.
    pub fn add_subgraph_stratified_n_m<Name, R, W, Func>(
        &mut self,
        name: Name,
        stratum: usize,
        recv_ports: Vec<RecvPort<R>>,
        send_ports: Vec<SendPort<W>>,
        mut subgraph: Func,
    ) -> SubgraphId
    where
        Name: Into<Cow<'static, str>>,
        R: 'static + Handoff,
        W: 'static + Handoff,
        Func: 'a
            + for<'ctx> AsyncFnMut(
                &'ctx mut Context,
                &'ctx [&'ctx RecvCtx<R>],
                &'ctx [&'ctx SendCtx<W>],
            ),
    {
        let sg_id = self.subgraphs.insert_with_key(|sg_id| {
            let subgraph_preds = recv_ports.iter().map(|port| port.handoff_id).collect();
            let subgraph_succs = send_ports.iter().map(|port| port.handoff_id).collect();

            for recv_port in recv_ports.iter() {
                self.handoffs[recv_port.handoff_id].succs.push(sg_id);
            }
            for send_port in send_ports.iter() {
                self.handoffs[send_port.handoff_id].preds.push(sg_id);
            }

            let subgraph =
                async move |context: &mut Context,
                            handoffs: &mut SlotVec<HandoffTag, HandoffData>| {
                    let recvs: Vec<&RecvCtx<R>> = recv_ports
                        .iter()
                        .map(|port| port.handoff_id)
                        .map(|hid| handoffs.get(hid).unwrap())
                        .map(|h_data| {
                            <dyn Any>::downcast_ref(&*h_data.handoff)
                                .expect("Attempted to cast handoff to wrong type.")
                        })
                        .map(RefCast::ref_cast)
                        .collect();

                    let sends: Vec<&SendCtx<W>> = send_ports
                        .iter()
                        .map(|port| port.handoff_id)
                        .map(|hid| handoffs.get(hid).unwrap())
                        .map(|h_data| {
                            <dyn Any>::downcast_ref(&*h_data.handoff)
                                .expect("Attempted to cast handoff to wrong type.")
                        })
                        .map(RefCast::ref_cast)
                        .collect();

                    (subgraph)(context, &recvs, &sends).await;
                };
            SubgraphData::new(
                name.into(),
                stratum,
                subgraph,
                subgraph_preds,
                subgraph_succs,
                true,
                false,
                None,
                0,
            )
        });
        self.context.init_stratum(stratum);
        self.context.stratum_queues[stratum].push_back(sg_id);

        // Initialize subgraph metrics struct.
        Rc::make_mut(&mut self.metrics)
            .subgraph_metrics
            .insert(sg_id, Default::default());

        sg_id
    }

    /// Creates a handoff edge and returns the corresponding send and receive ports.
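    ///
    /// A minimal sketch (assumes a `VecHandoff` handoff type):
    ///
    /// ```rust,ignore
    /// let mut df = Dfir::new();
    /// let (send_port, recv_port) = df.make_edge::<_, VecHandoff<String>>("my edge");
    /// // `send_port` is handed to the upstream subgraph, `recv_port` to the downstream one.
    /// ```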
    pub fn make_edge<Name, H>(&mut self, name: Name) -> (SendPort<H>, RecvPort<H>)
    where
        Name: Into<Cow<'static, str>>,
        H: 'static + Handoff,
    {
        // Create and insert handoff.
        let handoff = H::default();
        let handoff_id = self
            .handoffs
            .insert_with_key(|hoff_id| HandoffData::new(name.into(), handoff, hoff_id));

        // Initialize handoff metrics struct.
        Rc::make_mut(&mut self.metrics)
            .handoff_metrics
            .insert(handoff_id, Default::default());

        // Make ports.
        let input_port = SendPort {
            handoff_id,
            _marker: PhantomData,
        };
        let output_port = RecvPort {
            handoff_id,
            _marker: PhantomData,
        };
        (input_port, output_port)
    }

    /// Adds referenceable state into this instance. Returns a state handle which can be
    /// used externally or by operators to access the state.
    ///
    /// This is part of the "state API".
    pub fn add_state<T>(&mut self, state: T) -> StateHandle<T>
    where
        T: Any,
    {
        self.context.add_state(state)
    }

    /// Sets a hook to modify the state when its lifespan ends (e.g. at the end of each tick),
    /// using the supplied closure.
    ///
    /// This is part of the "state API".
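    ///
    /// A minimal sketch (the counter state here is illustrative):
    ///
    /// ```rust,ignore
    /// let mut df = Dfir::new();
    /// let handle = df.add_state(std::cell::RefCell::new(0_usize));
    /// // Reset the counter at the end of every tick.
    /// df.set_state_lifespan_hook(handle, StateLifespan::Tick, |count| {
    ///     *count.get_mut() = 0;
    /// });
    /// ```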
    pub fn set_state_lifespan_hook<T>(
        &mut self,
        handle: StateHandle<T>,
        lifespan: StateLifespan,
        hook_fn: impl 'static + FnMut(&mut T),
    ) where
        T: Any,
    {
        self.context
            .set_state_lifespan_hook(handle, lifespan, hook_fn)
    }

    /// Gets an exclusive (mut) ref to the internal context, setting the subgraph ID.
    pub fn context_mut(&mut self, sg_id: SubgraphId) -> &mut Context {
        self.context.subgraph_id = sg_id;
        &mut self.context
    }

    /// Adds a new loop with the given parent (or `None` for top-level). Returns a loop ID which
    /// is used in [`Self::add_subgraph_full`] or for nested loops.
    ///
    /// TODO(mingwei): add loop names to ensure traceability while debugging?
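    ///
    /// A minimal sketch of nesting:
    ///
    /// ```rust,ignore
    /// let mut df = Dfir::new();
    /// let outer = df.add_loop(None);        // top-level loop
    /// let inner = df.add_loop(Some(outer)); // nested inside `outer`
    /// ```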
    pub fn add_loop(&mut self, parent: Option<LoopId>) -> LoopId {
        let depth = parent.map_or(0, |p| self.context.loop_depth[p] + 1);
        let loop_id = self.context.loop_depth.insert(depth);
        self.loop_data.insert(
            loop_id,
            LoopData {
                iter_count: None,
                allow_another_iteration: true,
            },
        );
        loop_id
    }

    /// Returns DFIR runtime metrics accumulated since runtime creation.
    pub fn metrics(&self) -> DfirMetrics {
        DfirMetrics {
            curr: Rc::clone(&self.metrics),
            prev: None,
        }
    }
}

impl Dfir<'_> {
    /// Alias for [`Context::request_task`].
    pub fn request_task<Fut>(&mut self, future: Fut)
    where
        Fut: Future<Output = ()> + 'static,
    {
        self.context.request_task(future);
    }

    /// Alias for [`Context::abort_tasks`].
    pub fn abort_tasks(&mut self) {
        self.context.abort_tasks()
    }

    /// Alias for [`Context::join_tasks`].
    pub fn join_tasks(&mut self) -> impl use<'_> + Future {
        self.context.join_tasks()
    }
}

/// Polls `fut` exactly once with a no-op waker, returning its output.
///
/// Panics if the future is not immediately ready, i.e. if it yields asynchronously.
fn run_sync<Fut>(fut: Fut) -> Fut::Output
where
    Fut: Future,
{
    let mut fut = std::pin::pin!(fut);
    let mut ctx = std::task::Context::from_waker(std::task::Waker::noop());
    match fut.as_mut().poll(&mut ctx) {
        std::task::Poll::Ready(out) => out,
        std::task::Poll::Pending => panic!("Future did not resolve immediately."),
    }
}

impl Drop for Dfir<'_> {
    fn drop(&mut self) {
        self.abort_tasks();
    }
}

/// A handoff and its input and output [SubgraphId]s.
///
/// Internal use: used to track the dfir graph structure.
///
/// TODO(mingwei): restructure `PortList` so this can be crate-private.
#[doc(hidden)]
pub struct HandoffData {
    /// A friendly name for diagnostics.
    pub(super) name: Cow<'static, str>,
    /// Crate-visible for `handoff_list` internals.
    pub(super) handoff: Box<dyn HandoffMeta>,
    /// Preceding subgraphs (including the send side of a teeing handoff).
    pub(super) preds: SmallVec<[SubgraphId; 1]>,
    /// Successor subgraphs (including recv sides of teeing handoffs).
    pub(super) succs: SmallVec<[SubgraphId; 1]>,

    /// Predecessor handoffs, used by teeing handoffs.
    /// Should be `self`'s `HandoffId` on any teeing send side (input).
    /// Should be the send side's `HandoffId` if this is a teeing recv side (output).
    /// Should be just `self`'s `HandoffId` on other handoffs.
    /// This field is only used in initialization.
    pub(super) pred_handoffs: Vec<HandoffId>,
    /// Successor handoffs, used by teeing handoffs.
    /// Should be a list of the outputs on the teeing send side (input).
    /// Should be `self`'s `HandoffId` on any teeing recv side (output).
    /// Should be just `self`'s `HandoffId` on other handoffs.
    /// This field is only used in initialization.
    pub(super) succ_handoffs: Vec<HandoffId>,
}

impl std::fmt::Debug for HandoffData {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
        f.debug_struct("HandoffData")
            .field("preds", &self.preds)
            .field("succs", &self.succs)
            .finish_non_exhaustive()
    }
}

impl HandoffData {
    /// New with `pred_handoffs` and `succ_handoffs` set to its own [`HandoffId`]: `vec![hoff_id]`.
    pub fn new(
        name: Cow<'static, str>,
        handoff: impl 'static + HandoffMeta,
        hoff_id: HandoffId,
    ) -> Self {
        let (preds, succs) = Default::default();
        Self {
            name,
            handoff: Box::new(handoff),
            preds,
            succs,
            pred_handoffs: vec![hoff_id],
            succ_handoffs: vec![hoff_id],
        }
    }
}

/// A subgraph along with its predecessor and successor [SubgraphId]s.
///
/// Used internally by the [Dfir] struct to represent the dataflow graph
/// structure and scheduled state.
pub(super) struct SubgraphData<'a> {
    /// A friendly name for diagnostics.
    pub(super) name: Cow<'static, str>,
    /// This subgraph's stratum number.
    ///
    /// Within loop blocks, corresponds to the topological sort of the DAG created when `next_loop()`/`next_tick()` are removed.
    pub(super) stratum: usize,
    /// The actual execution code of the subgraph.
    subgraph: Box<dyn 'a + Subgraph>,

    preds: Vec<HandoffId>,
    succs: Vec<HandoffId>,

    /// If this subgraph is scheduled in `Context::stratum_queues`.
    /// [`Cell`] allows modifying this field when iterating `Self::preds` or
    /// `Self::succs`, as all `SubgraphData` are owned by the same vec,
    /// `Dfir::subgraphs`.
    is_scheduled: Cell<bool>,

    /// Keep track of the last tick that this subgraph was run in.
    last_tick_run_in: Option<TickInstant>,
    /// A meaningless ID to track the last loop execution this subgraph was run in.
    /// `(loop_nonce, iter_count)` pair.
    last_loop_nonce: (usize, Option<usize>),

    /// If this subgraph is marked as lazy, then sending data back to a lower stratum does not trigger a new tick to be run.
    is_lazy: bool,

    /// The subgraph's loop ID, or `None` for the top level.
    loop_id: Option<LoopId>,
    /// The loop depth of the subgraph.
    loop_depth: usize,
}

impl<'a> SubgraphData<'a> {
    #[expect(clippy::too_many_arguments, reason = "internal use")]
    pub(crate) fn new(
        name: Cow<'static, str>,
        stratum: usize,
        subgraph: impl 'a + Subgraph,
        preds: Vec<HandoffId>,
        succs: Vec<HandoffId>,
        is_scheduled: bool,
        is_lazy: bool,
        loop_id: Option<LoopId>,
        loop_depth: usize,
    ) -> Self {
        Self {
            name,
            stratum,
            subgraph: Box::new(subgraph),
            preds,
            succs,
            is_scheduled: Cell::new(is_scheduled),
            last_tick_run_in: None,
            last_loop_nonce: (0, None),
            is_lazy,
            loop_id,
            loop_depth,
        }
    }
}

pub(crate) struct LoopData {
    /// Count of iterations of this loop.
    iter_count: Option<usize>,
    /// If the loop has reason to do another iteration.
    allow_another_iteration: bool,
}

/// Defines when state should be reset.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum StateLifespan {
    /// Always reset, associated with the subgraph.
    Subgraph(SubgraphId),
    /// Reset between loop executions.
    Loop(LoopId),
    /// Reset between ticks.
    Tick,
    /// Never reset.
    Static,
}