dfir_rs/scheduled/
graph.rs

//! Module for the [`Dfir`] struct and helper items.

use std::any::Any;
use std::borrow::Cow;
use std::cell::Cell;
use std::cmp::Ordering;
use std::future::Future;
use std::marker::PhantomData;

#[cfg(feature = "meta")]
use dfir_lang::diagnostic::{Diagnostic, SerdeSpan};
#[cfg(feature = "meta")]
use dfir_lang::graph::DfirGraph;
use ref_cast::RefCast;
use smallvec::SmallVec;
use web_time::SystemTime;

use super::context::Context;
use super::handoff::handoff_list::PortList;
use super::handoff::{Handoff, HandoffMeta, TeeingHandoff};
use super::port::{RECV, RecvCtx, RecvPort, SEND, SendCtx, SendPort};
use super::reactor::Reactor;
use super::state::StateHandle;
use super::subgraph::Subgraph;
use super::{HandoffId, HandoffTag, LoopId, LoopTag, SubgraphId, SubgraphTag};
use crate::Never;
use crate::scheduled::ticks::{TickDuration, TickInstant};
use crate::util::slot_vec::{SecondarySlotVec, SlotVec};

/// A DFIR graph. Owns, schedules, and runs the compiled subgraphs.
#[derive(Default)]
pub struct Dfir<'a> {
    pub(super) subgraphs: SlotVec<SubgraphTag, SubgraphData<'a>>,

    pub(super) loop_data: SecondarySlotVec<LoopTag, LoopData>,

    pub(super) context: Context,

    handoffs: SlotVec<HandoffTag, HandoffData>,

    #[cfg(feature = "meta")]
    /// See [`Self::meta_graph()`].
    meta_graph: Option<DfirGraph>,

    #[cfg(feature = "meta")]
    /// See [`Self::diagnostics()`].
    diagnostics: Option<Vec<Diagnostic<SerdeSpan>>>,
}

/// Methods for [`TeeingHandoff`] teeing and dropping.
impl Dfir<'_> {
    /// Tees a [`TeeingHandoff`].
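    ///
    /// A rough sketch of usage (not compiled as a doc-test), given a graph `df` whose edge was
    /// created with a [`TeeingHandoff`]:
    ///
    /// ```rust,ignore
    /// // Tee the receive side to get a second output port carrying the same (cloned) data.
    /// let (send_port, recv_port) = df.make_edge::<_, TeeingHandoff<usize>>("tee edge");
    /// let recv_port_2 = df.teeing_handoff_tee(&recv_port);
    /// ```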
    pub fn teeing_handoff_tee<T>(
        &mut self,
        tee_parent_port: &RecvPort<TeeingHandoff<T>>,
    ) -> RecvPort<TeeingHandoff<T>>
    where
        T: Clone,
    {
        // If we're teeing from a child make sure to find root.
        let tee_root = self.handoffs[tee_parent_port.handoff_id].pred_handoffs[0];

        // Set up teeing metadata.
        let tee_root_data = &mut self.handoffs[tee_root];
        let tee_root_data_name = tee_root_data.name.clone();

        // Insert new handoff output.
        let teeing_handoff = tee_root_data
            .handoff
            .any_ref()
            .downcast_ref::<TeeingHandoff<T>>()
            .unwrap();
        let new_handoff = teeing_handoff.tee();

        // Handoff ID of new tee output.
        let new_hoff_id = self.handoffs.insert_with_key(|new_hoff_id| {
            let new_name = Cow::Owned(format!("{} tee {:?}", tee_root_data_name, new_hoff_id));
            let mut new_handoff_data = HandoffData::new(new_name, new_handoff, new_hoff_id);
            // Set self's predecessor as `tee_root`.
            new_handoff_data.pred_handoffs = vec![tee_root];
            new_handoff_data
        });

        // Go to `tee_root`'s successors and insert self (the new tee output).
        let tee_root_data = &mut self.handoffs[tee_root];
        tee_root_data.succ_handoffs.push(new_hoff_id);

        // Add our new handoff id into the subgraph data if the send `tee_root` has already been
        // used to add a subgraph.
        assert!(
            tee_root_data.preds.len() <= 1,
            "Tee send side should only have one sender (or none set yet)."
        );
        if let Some(&pred_sg_id) = tee_root_data.preds.first() {
            self.subgraphs[pred_sg_id].succs.push(new_hoff_id);
        }

        let output_port = RecvPort {
            handoff_id: new_hoff_id,
            _marker: PhantomData,
        };
        output_port
    }

    /// Marks an output of a [`TeeingHandoff`] as dropped so that no more data will be sent to it.
    ///
    /// It is recommended not to use this method; instead, simply avoid teeing a
    /// [`TeeingHandoff`] when it is not needed.
    pub fn teeing_handoff_drop<T>(&mut self, tee_port: RecvPort<TeeingHandoff<T>>)
    where
        T: Clone,
    {
        let data = &self.handoffs[tee_port.handoff_id];
        let teeing_handoff = data
            .handoff
            .any_ref()
            .downcast_ref::<TeeingHandoff<T>>()
            .unwrap();
        teeing_handoff.drop();

        let tee_root = data.pred_handoffs[0];
        let tee_root_data = &mut self.handoffs[tee_root];
        // Remove this output from the send succ handoff list.
        tee_root_data
            .succ_handoffs
            .retain(|&succ_hoff| succ_hoff != tee_port.handoff_id);
        // Remove from subgraph successors if send port was already connected.
        assert!(
            tee_root_data.preds.len() <= 1,
            "Tee send side should only have one sender (or none set yet)."
        );
        if let Some(&pred_sg_id) = tee_root_data.preds.first() {
            self.subgraphs[pred_sg_id]
                .succs
                .retain(|&succ_hoff| succ_hoff != tee_port.handoff_id);
        }
    }
}

impl<'a> Dfir<'a> {
    /// Create a new empty graph.
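    ///
    /// A rough end-to-end sketch (not compiled as a doc-test), assuming the `VecHandoff`
    /// handoff type from this crate. The turbofish pins down the handoff type of the empty
    /// port lists:
    ///
    /// ```rust,ignore
    /// let mut df = Dfir::new();
    /// let (send_port, recv_port) = df.make_edge::<_, VecHandoff<usize>>("edge");
    /// // Source subgraph: no inputs, one output.
    /// df.add_subgraph_n_m::<_, VecHandoff<usize>, _, _>(
    ///     "source",
    ///     vec![],
    ///     vec![send_port],
    ///     |_context, _recvs, sends| {
    ///         sends[0].give(Some(7));
    ///     },
    /// );
    /// // Sink subgraph: one input, no outputs.
    /// df.add_subgraph_n_m::<_, _, VecHandoff<usize>, _>(
    ///     "sink",
    ///     vec![recv_port],
    ///     vec![],
    ///     |_context, recvs, _sends| {
    ///         assert_eq!(vec![7], recvs[0].take_inner());
    ///     },
    /// );
    /// df.run_available();
    /// ```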
    pub fn new() -> Self {
        Default::default()
    }

    /// Assign the meta graph via JSON string. Used internally by the [`dfir_syntax`] macro and other macros.
    #[doc(hidden)]
    pub fn __assign_meta_graph(&mut self, _meta_graph_json: &str) {
        #[cfg(feature = "meta")]
        {
            let mut meta_graph: DfirGraph =
                serde_json::from_str(_meta_graph_json).expect("Failed to deserialize graph.");

            let mut op_inst_diagnostics = Vec::new();
            meta_graph.insert_node_op_insts_all(&mut op_inst_diagnostics);
            assert!(
                op_inst_diagnostics.is_empty(),
                "Expected no diagnostics, got: {:#?}",
                op_inst_diagnostics
            );

            assert!(self.meta_graph.replace(meta_graph).is_none());
        }
    }

    /// Assign the diagnostics via JSON string.
    #[doc(hidden)]
    pub fn __assign_diagnostics(&mut self, _diagnostics_json: &'static str) {
        #[cfg(feature = "meta")]
        {
            let diagnostics: Vec<Diagnostic<SerdeSpan>> = serde_json::from_str(_diagnostics_json)
                .expect("Failed to deserialize diagnostics.");

            assert!(self.diagnostics.replace(diagnostics).is_none());
        }
    }

    /// Return a handle to the meta graph, if set. The meta graph is a
    /// representation of all the operators, subgraphs, and handoffs in this instance.
    /// Will only be set if this graph was constructed using a surface syntax macro.
    #[cfg(feature = "meta")]
    #[cfg_attr(docsrs, doc(cfg(feature = "meta")))]
    pub fn meta_graph(&self) -> Option<&DfirGraph> {
        self.meta_graph.as_ref()
    }

    /// Returns any diagnostics generated by the surface syntax macro. Each diagnostic is a pair of
    /// (1) a `Diagnostic` with span info reset and (2) the `ToString` version of the diagnostic
    /// with original span info.
    /// Will only be set if this graph was constructed using a surface syntax macro.
    #[cfg(feature = "meta")]
    #[cfg_attr(docsrs, doc(cfg(feature = "meta")))]
    pub fn diagnostics(&self) -> Option<&[Diagnostic<SerdeSpan>]> {
        self.diagnostics.as_deref()
    }

    /// Returns a reactor for externally scheduling subgraphs, possibly from another thread.
    /// Reactor events are considered to be external events.
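    ///
    /// A rough sketch (not compiled as a doc-test); the `trigger` method on [`Reactor`] is an
    /// assumption made for illustration:
    ///
    /// ```rust,ignore
    /// let reactor = df.reactor();
    /// std::thread::spawn(move || {
    ///     // Schedule subgraph `sg_id` from another thread, as an external event.
    ///     reactor.trigger(sg_id).unwrap();
    /// });
    /// ```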
    pub fn reactor(&self) -> Reactor {
        Reactor::new(self.context.event_queue_send.clone())
    }

    /// Gets the current tick (local time) count.
    pub fn current_tick(&self) -> TickInstant {
        self.context.current_tick
    }

    /// Gets the current stratum number.
    pub fn current_stratum(&self) -> usize {
        self.context.current_stratum
    }

    /// Runs the dataflow until the next tick begins.
    /// Returns true if any work was done.
    #[tracing::instrument(level = "trace", skip(self), ret)]
    pub fn run_tick(&mut self) -> bool {
        let mut work_done = false;
        // While work is immediately available *on the current tick*.
        while self.next_stratum(true) {
            work_done = true;
            // Do any work.
            self.run_stratum();
        }
        work_done
    }

    /// Runs the dataflow until no more (externally-triggered) work is immediately available.
    /// Runs at least one tick of dataflow, even if no external events have been received.
    /// If the dataflow contains loops this method may run forever.
    /// Returns true if any work was done.
    #[tracing::instrument(level = "trace", skip(self), ret)]
    pub fn run_available(&mut self) -> bool {
        let mut work_done = false;
        // While work is immediately available.
        while self.next_stratum(false) {
            work_done = true;
            // Do any work.
            self.run_stratum();
        }
        work_done
    }

    /// Runs the dataflow until no more (externally-triggered) work is immediately available.
    /// Runs at least one tick of dataflow, even if no external events have been received.
    /// If the dataflow contains loops this method may run forever.
    /// Returns true if any work was done.
    /// Yields repeatedly to allow external events to happen.
    #[tracing::instrument(level = "trace", skip(self), ret)]
    pub async fn run_available_async(&mut self) -> bool {
        let mut work_done = false;
        // While work is immediately available.
        while self.next_stratum(false) {
            work_done = true;
            // Do any work.
            self.run_stratum();

            // Yield between each stratum to receive more events.
            // TODO(mingwei): really only need to yield at start of ticks though.
            tokio::task::yield_now().await;
        }
        work_done
    }

    /// Runs the current stratum of the dataflow until no more local work is available (does not receive events).
    /// Returns true if any work was done.
    #[tracing::instrument(level = "trace", skip(self), fields(tick = u64::from(self.context.current_tick), stratum = self.context.current_stratum), ret)]
    pub fn run_stratum(&mut self) -> bool {
        // Make sure to spawn tasks once dfir is running!
        // This drains the task buffer, so becomes a no-op after first call.
        self.context.spawn_tasks();

        let mut work_done = false;

        'pop: while let Some(sg_id) =
            self.context.stratum_queues[self.context.current_stratum].pop_front()
        {
            {
                let sg_data = &mut self.subgraphs[sg_id];
                // This must be true for the subgraph to be enqueued.
                assert!(sg_data.is_scheduled.take());

                let _enter = tracing::info_span!(
                    "run-subgraph",
                    sg_id = sg_id.to_string(),
                    sg_name = &*sg_data.name,
                    sg_depth = sg_data.loop_depth,
                    sg_loop_nonce = sg_data.last_loop_nonce.0,
                    sg_iter_count = sg_data.last_loop_nonce.1,
                )
                .entered();

                match sg_data.loop_depth.cmp(&self.context.loop_nonce_stack.len()) {
                    Ordering::Greater => {
                        // We have entered a loop.
                        self.context.loop_nonce += 1;
                        self.context.loop_nonce_stack.push(self.context.loop_nonce);
                        tracing::trace!(loop_nonce = self.context.loop_nonce, "Entered loop.");
                    }
                    Ordering::Less => {
                        // We have exited a loop.
                        self.context.loop_nonce_stack.pop();
                        tracing::trace!("Exited loop.");
                    }
                    Ordering::Equal => {}
                }

                self.context.subgraph_id = sg_id;
                self.context.is_first_run_this_tick = sg_data
                    .last_tick_run_in
                    .is_none_or(|last_tick| last_tick < self.context.current_tick);

                if let Some(loop_id) = sg_data.loop_id {
                    // Loop execution - running loop block, from start to finish, containing
                    // multiple iterations.
                    // Loop iteration - a single iteration of a loop block, all subgraphs within
                    // the loop should run (at most) once.

                    // If the previous run of this subgraph had the same loop execution and
                    // iteration count, then we need to increment the iteration count.
                    let curr_loop_nonce = self.context.loop_nonce_stack.last().copied();

                    let LoopData {
                        iter_count: loop_iter_count,
                        allow_another_iteration,
                    } = &mut self.loop_data[loop_id];

                    let (prev_loop_nonce, prev_iter_count) = sg_data.last_loop_nonce;

                    // If the loop nonce is the same as the previous execution, then we are in
                    // the same loop execution.
                    // `curr_loop_nonce` is `None` for top-level loops, and top-level loops are
                    // always in the same (singular) loop execution.
                    let (curr_iter_count, new_loop_execution) =
                        if curr_loop_nonce.is_none_or(|nonce| nonce == prev_loop_nonce) {
                            // If the iteration count is the same as the previous execution, then
                            // we are on the next iteration.
                            if *loop_iter_count == prev_iter_count {
                                // If not true, then we shall not run the next iteration.
                                if !std::mem::take(allow_another_iteration) {
                                    tracing::debug!(
                                        "Loop will not continue to next iteration, skipping."
                                    );
                                    continue 'pop;
                                }
                                // Increment `loop_iter_count` or set it to 0.
                                loop_iter_count.map_or((0, true), |n| (n + 1, false))
                            } else {
                                // Otherwise update the local iteration count to match the loop.
                                debug_assert!(
                                    prev_iter_count < *loop_iter_count,
                                    "Expect loop iteration count to be increasing."
                                );
                                (loop_iter_count.unwrap(), false)
                            }
                        } else {
                            // The loop execution has already begun, but this is the first time this particular subgraph is running.
                            (0, false)
                        };

                    if new_loop_execution {
                        // Run state hooks.
                        self.context.run_state_hooks_loop(loop_id);
                    }
                    tracing::debug!("Loop iteration count {}", curr_iter_count);

                    *loop_iter_count = Some(curr_iter_count);
                    self.context.loop_iter_count = curr_iter_count;
                    sg_data.last_loop_nonce =
                        (curr_loop_nonce.unwrap_or_default(), Some(curr_iter_count));
                }

                // Run subgraph state hooks.
                self.context.run_state_hooks_subgraph(sg_id);

                tracing::info!("Running subgraph.");
                sg_data.subgraph.run(&mut self.context, &mut self.handoffs);

                sg_data.last_tick_run_in = Some(self.context.current_tick);
            }

            let sg_data = &self.subgraphs[sg_id];
            for &handoff_id in sg_data.succs.iter() {
                let handoff = &self.handoffs[handoff_id];
                if !handoff.handoff.is_bottom() {
                    for &succ_id in handoff.succs.iter() {
                        let succ_sg_data = &self.subgraphs[succ_id];
                        // If we have sent data to the next tick, then we can start the next tick.
                        if succ_sg_data.stratum < self.context.current_stratum && !sg_data.is_lazy {
                            self.context.can_start_tick = true;
                        }
                        // Add subgraph to stratum queue if it is not already scheduled.
                        if !succ_sg_data.is_scheduled.replace(true) {
                            self.context.stratum_queues[succ_sg_data.stratum].push_back(succ_id);
                        }
                        // Add stratum to stratum stack if it is within a loop.
                        if 0 < succ_sg_data.loop_depth {
                            // TODO(mingwei): handle duplicates
                            self.context
                                .stratum_stack
                                .push(succ_sg_data.loop_depth, succ_sg_data.stratum);
                        }
                    }
                }
            }

            let reschedule = self.context.reschedule_loop_block.take();
            let allow_another = self.context.allow_another_iteration.take();

            if reschedule {
                // Re-enqueue the subgraph.
                self.context.schedule_deferred.push(sg_id);
                self.context
                    .stratum_stack
                    .push(sg_data.loop_depth, sg_data.stratum);
            }
            if reschedule || allow_another {
                if let Some(loop_id) = sg_data.loop_id {
                    self.loop_data
                        .get_mut(loop_id)
                        .unwrap()
                        .allow_another_iteration = true;
                }
            }

            work_done = true;
        }
        work_done
    }

    /// Go to the next stratum which has work available, possibly the current stratum.
    /// Return true if more work is available, otherwise false if no work is immediately
    /// available on any strata.
    ///
    /// This will receive external events when at the start of a tick.
    ///
    /// If `current_tick_only` is set to `true`, will only return `true` if work is immediately
    /// available on the *current tick*.
    ///
    /// If this returns false then the graph will be at the start of a tick (at stratum 0, can
    /// receive more external events).
    #[tracing::instrument(level = "trace", skip(self), ret)]
    pub fn next_stratum(&mut self, current_tick_only: bool) -> bool {
        tracing::trace!(
            events_received_tick = self.context.events_received_tick,
            can_start_tick = self.context.can_start_tick,
            "Starting `next_stratum` call.",
        );

        // The stratum we will stop searching at, i.e. made a full loop around.
        let mut end_stratum = self.context.current_stratum;
        let mut new_tick_started = false;

        if 0 == self.context.current_stratum {
            new_tick_started = true;

            // Starting the tick, reset this to `false`.
            tracing::trace!("Starting tick, setting `can_start_tick = false`.");
            self.context.can_start_tick = false;
            self.context.current_tick_start = SystemTime::now();

            // Ensure external events are received before running the tick.
            if !self.context.events_received_tick {
                // Add any external jobs to ready queue.
                self.try_recv_events();
            }
        }

        loop {
            tracing::trace!(
                tick = u64::from(self.context.current_tick),
                stratum = self.context.current_stratum,
                "Looking for work on stratum."
            );
            // If current stratum has work, return true.
            if !self.context.stratum_queues[self.context.current_stratum].is_empty() {
                tracing::trace!(
                    tick = u64::from(self.context.current_tick),
                    stratum = self.context.current_stratum,
                    "Work found on stratum."
                );
                return true;
            }

            if let Some(next_stratum) = self.context.stratum_stack.pop() {
                self.context.current_stratum = next_stratum;

                // Now schedule deferred subgraphs.
                {
                    for sg_id in self.context.schedule_deferred.drain(..) {
                        let sg_data = &self.subgraphs[sg_id];
                        tracing::info!(
                            tick = u64::from(self.context.current_tick),
                            stratum = self.context.current_stratum,
                            sg_id = sg_id.to_string(),
                            sg_name = &*sg_data.name,
                            is_scheduled = sg_data.is_scheduled.get(),
                            "Rescheduling deferred subgraph."
                        );
                        if !sg_data.is_scheduled.replace(true) {
                            self.context.stratum_queues[sg_data.stratum].push_back(sg_id);
                        }
                    }
                }
            } else {
                // Increment stratum counter.
                self.context.current_stratum += 1;

                if self.context.current_stratum >= self.context.stratum_queues.len() {
                    new_tick_started = true;

                    tracing::trace!(
                        can_start_tick = self.context.can_start_tick,
                        "End of tick {}, starting tick {}.",
                        self.context.current_tick,
                        self.context.current_tick + TickDuration::SINGLE_TICK,
                    );
                    self.context.run_state_hooks_tick();

                    self.context.current_stratum = 0;
                    self.context.current_tick += TickDuration::SINGLE_TICK;
                    self.context.events_received_tick = false;

                    if current_tick_only {
                        tracing::trace!(
                            "`current_tick_only` is `true`, returning `false` before receiving events."
                        );
                        return false;
                    } else {
                        self.try_recv_events();
                        if std::mem::replace(&mut self.context.can_start_tick, false) {
                            tracing::trace!(
                                tick = u64::from(self.context.current_tick),
                                "`can_start_tick` is `true`, continuing."
                            );
                            // Do a full loop more to find where events have been added.
                            end_stratum = 0;
                            continue;
                        } else {
                            tracing::trace!(
                                "`can_start_tick` is `false`, re-setting `events_received_tick = false`, returning `false`."
                            );
                            self.context.events_received_tick = false;
                            return false;
                        }
                    }
                }
            }

            // After incrementing, exit if we made a full loop around the strata.
            if new_tick_started && end_stratum == self.context.current_stratum {
                tracing::trace!(
                    "Made full loop around stratum, re-setting `current_stratum = 0`, returning `false`."
                );
                // Note: if current stratum had work, the very first loop iteration would've
                // returned true. Therefore we can return false without checking.
                // Also means nothing was done so we can reset the stratum to zero and wait for
                // events.
                self.context.events_received_tick = false;
                self.context.current_stratum = 0;
                return false;
            }
        }
    }

    /// Runs the dataflow graph forever.
    ///
    /// TODO(mingwei): Currently blocks forever, no notion of "completion."
    #[tracing::instrument(level = "trace", skip(self), ret)]
    pub fn run(&mut self) -> Option<Never> {
        loop {
            self.run_tick();
        }
    }

    /// Runs the dataflow graph forever.
    ///
    /// TODO(mingwei): Currently blocks forever, no notion of "completion."
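    ///
    /// A rough sketch of driving a graph on an async runtime (not compiled as a doc-test),
    /// assuming a tokio runtime:
    ///
    /// ```rust,ignore
    /// #[tokio::main]
    /// async fn main() {
    ///     let mut df = Dfir::new();
    ///     // ... add edges and subgraphs ...
    ///     df.run_async().await; // Alternates running available work and awaiting events.
    /// }
    /// ```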
    #[tracing::instrument(level = "trace", skip(self), ret)]
    pub async fn run_async(&mut self) -> Option<Never> {
        loop {
            // Run any work which is immediately available.
            self.run_available_async().await;
            // When no work is available yield until more events occur.
            self.recv_events_async().await;
        }
    }

    /// Enqueues subgraphs triggered by events without blocking.
    ///
    /// Returns the number of subgraphs enqueued.
    #[tracing::instrument(level = "trace", skip(self), fields(events_received_tick = self.context.events_received_tick), ret)]
    pub fn try_recv_events(&mut self) -> usize {
        let mut enqueued_count = 0;
        while let Ok((sg_id, is_external)) = self.context.event_queue_recv.try_recv() {
            let sg_data = &self.subgraphs[sg_id];
            tracing::trace!(
                sg_id = sg_id.to_string(),
                is_external = is_external,
                sg_stratum = sg_data.stratum,
                "Event received."
            );
            if !sg_data.is_scheduled.replace(true) {
                self.context.stratum_queues[sg_data.stratum].push_back(sg_id);
                enqueued_count += 1;
            }
            if is_external {
                // The next tick is triggered if we are at the start of the next tick
                // (`!self.context.events_received_tick`), or if the stratum is in the next tick.
                if !self.context.events_received_tick
                    || sg_data.stratum < self.context.current_stratum
                {
                    tracing::trace!(
                        current_stratum = self.context.current_stratum,
                        sg_stratum = sg_data.stratum,
                        "External event, setting `can_start_tick = true`."
                    );
                    self.context.can_start_tick = true;
                }
            }
        }
        self.context.events_received_tick = true;

        enqueued_count
    }

    /// Enqueues subgraphs triggered by external events, blocking until at
    /// least one subgraph is scheduled **from an external event**.
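    ///
    /// A rough sketch of a synchronous driver loop (not compiled as a doc-test), mirroring the
    /// body of [`Self::run_async`]:
    ///
    /// ```rust,ignore
    /// loop {
    ///     df.run_available(); // Run all immediately available work.
    ///     if df.recv_events().is_none() {
    ///         break; // Event queue closed, no more events can arrive.
    ///     }
    /// }
    /// ```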
    #[tracing::instrument(level = "trace", skip(self), fields(events_received_tick = self.context.events_received_tick), ret)]
    pub fn recv_events(&mut self) -> Option<usize> {
        let mut count = 0;
        loop {
            let (sg_id, is_external) = self.context.event_queue_recv.blocking_recv()?;
            let sg_data = &self.subgraphs[sg_id];
            tracing::trace!(
                sg_id = sg_id.to_string(),
                is_external = is_external,
                sg_stratum = sg_data.stratum,
                "Event received."
            );
            if !sg_data.is_scheduled.replace(true) {
                self.context.stratum_queues[sg_data.stratum].push_back(sg_id);
                count += 1;
            }
            if is_external {
                // The next tick is triggered if we are at the start of the next tick
                // (`!self.context.events_received_tick`), or if the stratum is in the next tick.
                if !self.context.events_received_tick
                    || sg_data.stratum < self.context.current_stratum
                {
                    tracing::trace!(
                        current_stratum = self.context.current_stratum,
                        sg_stratum = sg_data.stratum,
                        "External event, setting `can_start_tick = true`."
                    );
                    self.context.can_start_tick = true;
                }
                break;
            }
        }
        self.context.events_received_tick = true;

        // Enqueue any other immediate events.
        let extra_count = self.try_recv_events();
        Some(count + extra_count)
    }

    /// Enqueues subgraphs triggered by external events asynchronously, waiting until at least one
    /// subgraph is scheduled **from an external event**. Returns the number of subgraphs enqueued,
    /// which may be zero if an external event scheduled an already-scheduled subgraph.
    ///
    /// Returns `None` if the event queue is closed, but that should not happen normally.
    #[tracing::instrument(level = "trace", skip(self), fields(events_received_tick = self.context.events_received_tick), ret)]
    pub async fn recv_events_async(&mut self) -> Option<usize> {
        let mut count = 0;
        loop {
            tracing::trace!("Awaiting events (`event_queue_recv`).");
            let (sg_id, is_external) = self.context.event_queue_recv.recv().await?;
            let sg_data = &self.subgraphs[sg_id];
            tracing::trace!(
                sg_id = sg_id.to_string(),
                is_external = is_external,
                sg_stratum = sg_data.stratum,
                "Event received."
            );
            if !sg_data.is_scheduled.replace(true) {
                self.context.stratum_queues[sg_data.stratum].push_back(sg_id);
                count += 1;
            }
            if is_external {
                // The next tick is triggered if we are at the start of the next tick
                // (`!self.context.events_received_tick`), or if the stratum is in the next tick.
                if !self.context.events_received_tick
                    || sg_data.stratum < self.context.current_stratum
                {
                    tracing::trace!(
                        current_stratum = self.context.current_stratum,
                        sg_stratum = sg_data.stratum,
                        "External event, setting `can_start_tick = true`."
                    );
                    self.context.can_start_tick = true;
                }
                break;
            }
        }
        self.context.events_received_tick = true;

        // Enqueue any other immediate events.
        let extra_count = self.try_recv_events();
        Some(count + extra_count)
    }

    /// Schedules a subgraph to be run. See also: [`Context::schedule_subgraph`].
    pub fn schedule_subgraph(&mut self, sg_id: SubgraphId) -> bool {
        let sg_data = &self.subgraphs[sg_id];
        let already_scheduled = sg_data.is_scheduled.replace(true);
        if !already_scheduled {
            self.context.stratum_queues[sg_data.stratum].push_back(sg_id);
            true
        } else {
            false
        }
    }

    /// Adds a new compiled subgraph with the specified inputs and outputs in stratum 0.
    pub fn add_subgraph<Name, R, W, F>(
        &mut self,
        name: Name,
        recv_ports: R,
        send_ports: W,
        subgraph: F,
    ) -> SubgraphId
    where
        Name: Into<Cow<'static, str>>,
        R: 'static + PortList<RECV>,
        W: 'static + PortList<SEND>,
        F: 'static + for<'ctx> FnMut(&'ctx mut Context, R::Ctx<'ctx>, W::Ctx<'ctx>),
    {
        self.add_subgraph_stratified(name, 0, recv_ports, send_ports, false, subgraph)
    }

    /// Adds a new compiled subgraph with the specified inputs, outputs, and stratum number.
    ///
    /// TODO(mingwei): add example in doc.
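    ///
    /// In the meantime, a rough sketch (not compiled as a doc-test), assuming the `VecHandoff`
    /// handoff type and the `var_expr!`/`var_args!` variadic macros re-exported by this crate:
    ///
    /// ```rust,ignore
    /// let mut df = Dfir::new();
    /// let (send_port, recv_port) = df.make_edge::<_, VecHandoff<usize>>("edge");
    /// // A subgraph in stratum 0 which reads from `recv_port` and has no outputs.
    /// df.add_subgraph_stratified(
    ///     "printer",
    ///     0,                    // Stratum number.
    ///     var_expr!(recv_port), // Receive ports.
    ///     var_expr!(),          // Send ports (none).
    ///     false,                // Not lazy.
    ///     move |_context, var_args!(recv), var_args!()| {
    ///         for item in recv.take_inner() {
    ///             println!("{}", item);
    ///         }
    ///     },
    /// );
    /// ```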
    pub fn add_subgraph_stratified<Name, R, W, F>(
        &mut self,
        name: Name,
        stratum: usize,
        recv_ports: R,
        send_ports: W,
        laziness: bool,
        subgraph: F,
    ) -> SubgraphId
    where
        Name: Into<Cow<'static, str>>,
        R: 'static + PortList<RECV>,
        W: 'static + PortList<SEND>,
        F: 'a + for<'ctx> FnMut(&'ctx mut Context, R::Ctx<'ctx>, W::Ctx<'ctx>),
    {
        self.add_subgraph_full(
            name, stratum, recv_ports, send_ports, laziness, None, subgraph,
        )
    }

    /// Adds a new compiled subgraph with all options.
    #[expect(clippy::too_many_arguments, reason = "Mainly for internal use.")]
    pub fn add_subgraph_full<Name, R, W, F>(
        &mut self,
        name: Name,
        stratum: usize,
        recv_ports: R,
        send_ports: W,
        laziness: bool,
        loop_id: Option<LoopId>,
        mut subgraph: F,
    ) -> SubgraphId
    where
        Name: Into<Cow<'static, str>>,
        R: 'static + PortList<RECV>,
        W: 'static + PortList<SEND>,
        F: 'a + for<'ctx> FnMut(&'ctx mut Context, R::Ctx<'ctx>, W::Ctx<'ctx>),
    {
        // SAFETY: Check that the send and recv ports are from `self.handoffs`.
        recv_ports.assert_is_from(&self.handoffs);
        send_ports.assert_is_from(&self.handoffs);

        let loop_depth = loop_id
            .and_then(|loop_id| self.context.loop_depth.get(loop_id))
            .copied()
            .unwrap_or(0);

        let sg_id = self.subgraphs.insert_with_key(|sg_id| {
            let (mut subgraph_preds, mut subgraph_succs) = Default::default();
            recv_ports.set_graph_meta(&mut self.handoffs, &mut subgraph_preds, sg_id, true);
            send_ports.set_graph_meta(&mut self.handoffs, &mut subgraph_succs, sg_id, false);

            let subgraph =
                move |context: &mut Context, handoffs: &mut SlotVec<HandoffTag, HandoffData>| {
                    let (recv, send) = unsafe {
                        // SAFETY:
                        // 1. We checked `assert_is_from` at assembly time, above.
                        // 2. `SlotVec` is insert-only so no handoffs could have changed since then.
                        (
                            recv_ports.make_ctx(&*handoffs),
                            send_ports.make_ctx(&*handoffs),
                        )
                    };
                    (subgraph)(context, recv, send);
                };
            SubgraphData::new(
                name.into(),
                stratum,
                subgraph,
                subgraph_preds,
                subgraph_succs,
                true,
                laziness,
                loop_id,
                loop_depth,
            )
        });
        self.context.init_stratum(stratum);
        self.context.stratum_queues[stratum].push_back(sg_id);

        sg_id
    }

    /// Adds a new compiled subgraph with a variable number of inputs and outputs of the same respective handoff types.
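    ///
    /// A rough sketch (not compiled as a doc-test), given a graph `df` and assuming the
    /// `VecHandoff` handoff type: a subgraph which unions two inputs into one output.
    ///
    /// ```rust,ignore
    /// let (send_a, recv_a) = df.make_edge::<_, VecHandoff<usize>>("a");
    /// let (send_b, recv_b) = df.make_edge::<_, VecHandoff<usize>>("b");
    /// let (send_out, recv_out) = df.make_edge::<_, VecHandoff<usize>>("out");
    /// df.add_subgraph_n_m(
    ///     "union",
    ///     vec![recv_a, recv_b],
    ///     vec![send_out],
    ///     |_context, recvs, sends| {
    ///         for recv in recvs {
    ///             sends[0].give(recv.take_inner());
    ///         }
    ///     },
    /// );
    /// ```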
    pub fn add_subgraph_n_m<Name, R, W, F>(
        &mut self,
        name: Name,
        recv_ports: Vec<RecvPort<R>>,
        send_ports: Vec<SendPort<W>>,
        subgraph: F,
    ) -> SubgraphId
    where
        Name: Into<Cow<'static, str>>,
        R: 'static + Handoff,
        W: 'static + Handoff,
        F: 'static
            + for<'ctx> FnMut(&'ctx mut Context, &'ctx [&'ctx RecvCtx<R>], &'ctx [&'ctx SendCtx<W>]),
    {
        self.add_subgraph_stratified_n_m(name, 0, recv_ports, send_ports, subgraph)
    }

    /// Adds a new compiled subgraph with a variable number of inputs and outputs of the same respective handoff types.
    pub fn add_subgraph_stratified_n_m<Name, R, W, F>(
        &mut self,
        name: Name,
        stratum: usize,
        recv_ports: Vec<RecvPort<R>>,
        send_ports: Vec<SendPort<W>>,
        mut subgraph: F,
    ) -> SubgraphId
    where
        Name: Into<Cow<'static, str>>,
        R: 'static + Handoff,
        W: 'static + Handoff,
        F: 'static
            + for<'ctx> FnMut(&'ctx mut Context, &'ctx [&'ctx RecvCtx<R>], &'ctx [&'ctx SendCtx<W>]),
    {
        let sg_id = self.subgraphs.insert_with_key(|sg_id| {
            let subgraph_preds = recv_ports.iter().map(|port| port.handoff_id).collect();
            let subgraph_succs = send_ports.iter().map(|port| port.handoff_id).collect();

            for recv_port in recv_ports.iter() {
                self.handoffs[recv_port.handoff_id].succs.push(sg_id);
            }
            for send_port in send_ports.iter() {
                self.handoffs[send_port.handoff_id].preds.push(sg_id);
            }

            let subgraph =
                move |context: &mut Context, handoffs: &mut SlotVec<HandoffTag, HandoffData>| {
                    let recvs: Vec<&RecvCtx<R>> = recv_ports
                        .iter()
                        .map(|hid| hid.handoff_id)
                        .map(|hid| handoffs.get(hid).unwrap())
                        .map(|h_data| {
                            h_data
                                .handoff
                                .any_ref()
                                .downcast_ref()
                                .expect("Attempted to cast handoff to wrong type.")
                        })
                        .map(RefCast::ref_cast)
                        .collect();

                    let sends: Vec<&SendCtx<W>> = send_ports
                        .iter()
                        .map(|hid| hid.handoff_id)
                        .map(|hid| handoffs.get(hid).unwrap())
                        .map(|h_data| {
                            h_data
                                .handoff
                                .any_ref()
                                .downcast_ref()
                                .expect("Attempted to cast handoff to wrong type.")
                        })
                        .map(RefCast::ref_cast)
                        .collect();

                    (subgraph)(context, &recvs, &sends)
                };
            SubgraphData::new(
                name.into(),
                stratum,
                subgraph,
                subgraph_preds,
                subgraph_succs,
                true,
                false,
                None,
                0,
            )
        });

        self.context.init_stratum(stratum);
        self.context.stratum_queues[stratum].push_back(sg_id);

        sg_id
    }

    /// Creates a handoff edge and returns the corresponding send and receive ports.
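    ///
    /// A rough sketch (not compiled as a doc-test), given a graph `df` and assuming the
    /// `VecHandoff` handoff type:
    ///
    /// ```rust,ignore
    /// // `send_port` goes to the upstream (writing) subgraph,
    /// // `recv_port` goes to the downstream (reading) subgraph.
    /// let (send_port, recv_port) = df.make_edge::<_, VecHandoff<String>>("names");
    /// ```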
    pub fn make_edge<Name, H>(&mut self, name: Name) -> (SendPort<H>, RecvPort<H>)
    where
        Name: Into<Cow<'static, str>>,
        H: 'static + Handoff,
    {
        // Create and insert handoff.
        let handoff = H::default();
        let handoff_id = self
            .handoffs
            .insert_with_key(|hoff_id| HandoffData::new(name.into(), handoff, hoff_id));

        // Make ports.
        let input_port = SendPort {
            handoff_id,
            _marker: PhantomData,
        };
        let output_port = RecvPort {
            handoff_id,
            _marker: PhantomData,
        };
        (input_port, output_port)
    }

    /// Adds referenceable state into this instance. Returns a state handle which can be
    /// used externally or by operators to access the state.
    ///
    /// This is part of the "state API".
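    ///
    /// A rough sketch (not compiled as a doc-test); the `state_ref` accessor on [`Context`] is
    /// an assumption made for illustration:
    ///
    /// ```rust,ignore
    /// let handle = df.add_state(std::cell::RefCell::new(Vec::<usize>::new()));
    /// // Clear the state at the end of every tick.
    /// df.set_state_lifespan_hook(handle, StateLifespan::Tick, |v| v.get_mut().clear());
    /// // Inside a subgraph, access the state through the context:
    /// // context.state_ref(handle).borrow_mut().push(1);
    /// ```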
    pub fn add_state<T>(&mut self, state: T) -> StateHandle<T>
    where
        T: Any,
    {
        self.context.add_state(state)
    }

    /// Sets a hook to modify the state at the end of each tick, using the supplied closure.
    ///
    /// This is part of the "state API".
    pub fn set_state_lifespan_hook<T>(
        &mut self,
        handle: StateHandle<T>,
        lifespan: StateLifespan,
        hook_fn: impl 'static + FnMut(&mut T),
    ) where
        T: Any,
    {
        self.context
            .set_state_lifespan_hook(handle, lifespan, hook_fn)
    }

    /// Gets an exclusive (mut) ref to the internal context, setting the subgraph ID.
    pub fn context_mut(&mut self, sg_id: SubgraphId) -> &mut Context {
        self.context.subgraph_id = sg_id;
        &mut self.context
    }

    /// Adds a new loop with the given parent (or `None` for top-level). Returns a loop ID which
    /// is used in [`Self::add_subgraph_full`] or for nested loops.
    ///
    /// TODO(mingwei): add loop names to ensure traceability while debugging?
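    ///
    /// A rough sketch (not compiled as a doc-test), given a graph `df`:
    ///
    /// ```rust,ignore
    /// let outer = df.add_loop(None);        // Top-level loop.
    /// let inner = df.add_loop(Some(outer)); // Nested loop, one level deeper.
    /// // Pass `Some(inner)` as the `loop_id` argument of `add_subgraph_full` to place a
    /// // subgraph inside the inner loop.
    /// ```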
    pub fn add_loop(&mut self, parent: Option<LoopId>) -> LoopId {
        let depth = parent.map_or(0, |p| self.context.loop_depth[p] + 1);
        let loop_id = self.context.loop_depth.insert(depth);
        self.loop_data.insert(
            loop_id,
            LoopData {
                iter_count: None,
                allow_another_iteration: true,
            },
        );
        loop_id
    }
}

impl Dfir<'_> {
    /// Alias for [`Context::request_task`].
    pub fn request_task<Fut>(&mut self, future: Fut)
    where
        Fut: Future<Output = ()> + 'static,
    {
        self.context.request_task(future);
    }

    /// Alias for [`Context::abort_tasks`].
    pub fn abort_tasks(&mut self) {
        self.context.abort_tasks()
    }

    /// Alias for [`Context::join_tasks`].
    pub fn join_tasks(&mut self) -> impl use<'_> + Future {
        self.context.join_tasks()
    }
}

impl Drop for Dfir<'_> {
    fn drop(&mut self) {
        self.abort_tasks();
    }
}

/// A handoff and its input and output [SubgraphId]s.
///
/// Internal use: used to track the dfir graph structure.
///
/// TODO(mingwei): restructure `PortList` so this can be crate-private.
#[doc(hidden)]
pub struct HandoffData {
    /// A friendly name for diagnostics.
    pub(super) name: Cow<'static, str>,
    /// Crate-visible for `handoff_list` internals.
    pub(super) handoff: Box<dyn HandoffMeta>,
    /// Preceding subgraphs (including the send side of a teeing handoff).
    pub(super) preds: SmallVec<[SubgraphId; 1]>,
    /// Successor subgraphs (including recv sides of teeing handoffs).
    pub(super) succs: SmallVec<[SubgraphId; 1]>,

    /// Predecessor handoffs, used by teeing handoffs.
    /// Should be `self` on any teeing send sides (input).
    /// Should be the send `HandoffId` if this is teeing recv side (output).
    /// Should be just `self`'s `HandoffId` on other handoffs.
    /// This field is only used in initialization.
    pub(super) pred_handoffs: Vec<HandoffId>,
    /// Successor handoffs, used by teeing handoffs.
    /// Should be a list of outputs on the teeing send side (input).
    /// Should be `self` on any teeing recv sides (outputs).
    /// Should be just `self`'s `HandoffId` on other handoffs.
    /// This field is only used in initialization.
    pub(super) succ_handoffs: Vec<HandoffId>,
}

impl std::fmt::Debug for HandoffData {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
        f.debug_struct("HandoffData")
            .field("preds", &self.preds)
            .field("succs", &self.succs)
            .finish_non_exhaustive()
    }
}

impl HandoffData {
    /// New with `pred_handoffs` and `succ_handoffs` set to its own [`HandoffId`]: `vec![hoff_id]`.
    pub fn new(
        name: Cow<'static, str>,
        handoff: impl 'static + HandoffMeta,
        hoff_id: HandoffId,
    ) -> Self {
        let (preds, succs) = Default::default();
        Self {
            name,
            handoff: Box::new(handoff),
            preds,
            succs,
            pred_handoffs: vec![hoff_id],
            succ_handoffs: vec![hoff_id],
        }
    }
}

/// A subgraph along with its predecessor and successor [SubgraphId]s.
///
/// Used internally by the [Dfir] struct to represent the dataflow graph
/// structure and scheduled state.
pub(super) struct SubgraphData<'a> {
    /// A friendly name for diagnostics.
    pub(super) name: Cow<'static, str>,
    /// This subgraph's stratum number.
    ///
    /// Within loop blocks, corresponds to the topological sort of the DAG created when `next_loop()/next_tick()` are removed.
    pub(super) stratum: usize,
    /// The actual execution code of the subgraph.
    subgraph: Box<dyn Subgraph + 'a>,

    #[expect(dead_code, reason = "may be useful in the future")]
    preds: Vec<HandoffId>,
    succs: Vec<HandoffId>,

    /// If this subgraph is scheduled in [`dfir_rs::stratum_queues`].
    /// [`Cell`] allows modifying this field when iterating `Self::preds` or
    /// `Self::succs`, as all `SubgraphData` are owned by the same vec
    /// `dfir_rs::subgraphs`.
    is_scheduled: Cell<bool>,

    /// Keep track of the last tick that this subgraph was run in
    last_tick_run_in: Option<TickInstant>,
    /// A meaningless ID to track the last loop execution this subgraph was run in.
    /// `(loop_nonce, iter_count)` pair.
    last_loop_nonce: (usize, Option<usize>),

    /// If this subgraph is marked as lazy, then sending data back to a lower stratum does not trigger a new tick to be run.
    is_lazy: bool,

    /// The subgraph's loop ID, or `None` for the top level.
    loop_id: Option<LoopId>,
    /// The loop depth of the subgraph.
    loop_depth: usize,
}

impl<'a> SubgraphData<'a> {
    #[expect(clippy::too_many_arguments, reason = "internal use")]
    pub(crate) fn new(
        name: Cow<'static, str>,
        stratum: usize,
        subgraph: impl Subgraph + 'a,
        preds: Vec<HandoffId>,
        succs: Vec<HandoffId>,
        is_scheduled: bool,
        is_lazy: bool,
        loop_id: Option<LoopId>,
        loop_depth: usize,
    ) -> Self {
        Self {
            name,
            stratum,
            subgraph: Box::new(subgraph),
            preds,
            succs,
            is_scheduled: Cell::new(is_scheduled),
            last_tick_run_in: None,
            last_loop_nonce: (0, None),
            is_lazy,
            loop_id,
            loop_depth,
        }
    }
}

pub(crate) struct LoopData {
    /// Count of iterations of this loop.
    iter_count: Option<usize>,
    /// If the loop has reason to do another iteration.
    allow_another_iteration: bool,
}

/// Defines when state should be reset.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum StateLifespan {
    /// Always reset, associated with the subgraph.
    Subgraph(SubgraphId),
    /// Reset between loop executions.
    Loop(LoopId),
    /// Reset between ticks.
    Tick,
    /// Never reset.
    Static,
}