dfir_rs/scheduled/graph.rs

1//! Module for the [`Dfir`] struct and helper items.
2
3use std::any::Any;
4use std::borrow::Cow;
5use std::cell::Cell;
6use std::cmp::Ordering;
7use std::future::Future;
8use std::marker::PhantomData;
9
10#[cfg(feature = "meta")]
11use dfir_lang::diagnostic::{Diagnostic, SerdeSpan};
12#[cfg(feature = "meta")]
13use dfir_lang::graph::DfirGraph;
14use ref_cast::RefCast;
15use smallvec::SmallVec;
16use web_time::SystemTime;
17
18use super::context::Context;
19use super::handoff::handoff_list::PortList;
20use super::handoff::{Handoff, HandoffMeta, TeeingHandoff};
21use super::port::{RECV, RecvCtx, RecvPort, SEND, SendCtx, SendPort};
22use super::reactor::Reactor;
23use super::state::StateHandle;
24use super::subgraph::Subgraph;
25use super::{HandoffId, HandoffTag, LoopId, LoopTag, SubgraphId, SubgraphTag};
26use crate::Never;
27use crate::scheduled::ticks::{TickDuration, TickInstant};
28use crate::util::slot_vec::{SecondarySlotVec, SlotVec};
29
30/// A DFIR graph. Owns, schedules, and runs the compiled subgraphs.
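///
/// # Example
///
/// A minimal sketch of constructing and driving a graph by hand (marked `ignore`; the module
/// path is an assumption, and real graphs are usually built via the surface syntax macros):
///
/// ```ignore
/// use dfir_rs::scheduled::graph::Dfir;
///
/// let mut df = Dfir::new();
/// // Handoffs and subgraphs would be added here via `make_edge` / `add_subgraph*`.
/// df.run_available_sync();
/// ```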
31#[derive(Default)]
32pub struct Dfir<'a> {
33    pub(super) subgraphs: SlotVec<SubgraphTag, SubgraphData<'a>>,
34
35    pub(super) loop_data: SecondarySlotVec<LoopTag, LoopData>,
36
37    pub(super) context: Context,
38
39    handoffs: SlotVec<HandoffTag, HandoffData>,
40
41    #[cfg(feature = "meta")]
42    /// See [`Self::meta_graph()`].
43    meta_graph: Option<DfirGraph>,
44
45    #[cfg(feature = "meta")]
46    /// See [`Self::diagnostics()`].
47    diagnostics: Option<Vec<Diagnostic<SerdeSpan>>>,
48}
49
50/// Methods for [`TeeingHandoff`] teeing and dropping.
51impl Dfir<'_> {
52    /// Tees a [`TeeingHandoff`].
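    ///
    /// A minimal sketch (marked `ignore`), assuming the edge was created with
    /// [`Self::make_edge`] using a default-constructed [`TeeingHandoff`]:
    ///
    /// ```ignore
    /// let mut df = Dfir::new();
    /// let (_send, recv) = df.make_edge::<_, TeeingHandoff<u32>>("tee_edge");
    /// // Each call returns an additional receive port fed by the same sender.
    /// let recv2 = df.teeing_handoff_tee(&recv);
    /// ```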
53    pub fn teeing_handoff_tee<T>(
54        &mut self,
55        tee_parent_port: &RecvPort<TeeingHandoff<T>>,
56    ) -> RecvPort<TeeingHandoff<T>>
57    where
58        T: Clone,
59    {
60        // If we're teeing from a child, make sure to find the root.
61        let tee_root = self.handoffs[tee_parent_port.handoff_id].pred_handoffs[0];
62
63        // Set up teeing metadata.
64        let tee_root_data = &mut self.handoffs[tee_root];
65        let tee_root_data_name = tee_root_data.name.clone();
66
67        // Insert new handoff output.
68        let teeing_handoff =
69            <dyn Any>::downcast_ref::<TeeingHandoff<T>>(&*tee_root_data.handoff).unwrap();
70        let new_handoff = teeing_handoff.tee();
71
72        // Handoff ID of new tee output.
73        let new_hoff_id = self.handoffs.insert_with_key(|new_hoff_id| {
74            let new_name = Cow::Owned(format!("{} tee {:?}", tee_root_data_name, new_hoff_id));
75            let mut new_handoff_data = HandoffData::new(new_name, new_handoff, new_hoff_id);
76            // Set self's predecessor as `tee_root`.
77            new_handoff_data.pred_handoffs = vec![tee_root];
78            new_handoff_data
79        });
80
81        // Go to `tee_root`'s successors and insert self (the new tee output).
82        let tee_root_data = &mut self.handoffs[tee_root];
83        tee_root_data.succ_handoffs.push(new_hoff_id);
84
85        // Add our new handoff ID into the subgraph data, if the send side of `tee_root` has
86        // already been used to add a subgraph.
87        assert!(
88            tee_root_data.preds.len() <= 1,
89            "Tee send side should only have one sender (or none set yet)."
90        );
91        if let Some(&pred_sg_id) = tee_root_data.preds.first() {
92            self.subgraphs[pred_sg_id].succs.push(new_hoff_id);
93        }
94
95        let output_port = RecvPort {
96            handoff_id: new_hoff_id,
97            _marker: PhantomData,
98        };
99        output_port
100    }
101
102    /// Marks an output of a [`TeeingHandoff`] as dropped so that no more data will be sent to it.
103    ///
104    /// It is recommended not to use this method, and instead to simply avoid teeing a
105    /// [`TeeingHandoff`] when it is not needed.
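    ///
    /// A brief sketch (marked `ignore`), under the same assumptions as the teeing example above:
    ///
    /// ```ignore
    /// let mut df = Dfir::new();
    /// let (_send, recv) = df.make_edge::<_, TeeingHandoff<u32>>("tee_edge");
    /// let extra = df.teeing_handoff_tee(&recv);
    /// df.teeing_handoff_drop(extra); // `extra` will no longer receive any data.
    /// ```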
106    pub fn teeing_handoff_drop<T>(&mut self, tee_port: RecvPort<TeeingHandoff<T>>)
107    where
108        T: Clone,
109    {
110        let data = &self.handoffs[tee_port.handoff_id];
111        let teeing_handoff = <dyn Any>::downcast_ref::<TeeingHandoff<T>>(&*data.handoff).unwrap();
112        teeing_handoff.drop();
113
114        let tee_root = data.pred_handoffs[0];
115        let tee_root_data = &mut self.handoffs[tee_root];
116        // Remove this output from the send succ handoff list.
117        tee_root_data
118            .succ_handoffs
119            .retain(|&succ_hoff| succ_hoff != tee_port.handoff_id);
120        // Remove from subgraph successors if send port was already connected.
121        assert!(
122            tee_root_data.preds.len() <= 1,
123            "Tee send side should only have one sender (or none set yet)."
124        );
125        if let Some(&pred_sg_id) = tee_root_data.preds.first() {
126            self.subgraphs[pred_sg_id]
127                .succs
128                .retain(|&succ_hoff| succ_hoff != tee_port.handoff_id);
129        }
130    }
131}
132
133impl<'a> Dfir<'a> {
134    /// Create a new empty graph.
135    pub fn new() -> Self {
136        Default::default()
137    }
138
139    /// Assigns the meta graph via JSON string. Used internally by [`crate::dfir_syntax`] and other macros.
140    #[doc(hidden)]
141    pub fn __assign_meta_graph(&mut self, _meta_graph_json: &str) {
142        #[cfg(feature = "meta")]
143        {
144            let mut meta_graph: DfirGraph =
145                serde_json::from_str(_meta_graph_json).expect("Failed to deserialize graph.");
146
147            let mut op_inst_diagnostics = Vec::new();
148            meta_graph.insert_node_op_insts_all(&mut op_inst_diagnostics);
149            assert!(
150                op_inst_diagnostics.is_empty(),
151                "Expected no diagnostics, got: {:#?}",
152                op_inst_diagnostics
153            );
154
155            assert!(self.meta_graph.replace(meta_graph).is_none());
156        }
157    }
158    /// Assigns the diagnostics via JSON string.
159    #[doc(hidden)]
160    pub fn __assign_diagnostics(&mut self, _diagnostics_json: &'static str) {
161        #[cfg(feature = "meta")]
162        {
163            let diagnostics: Vec<Diagnostic<SerdeSpan>> = serde_json::from_str(_diagnostics_json)
164                .expect("Failed to deserialize diagnostics.");
165
166            assert!(self.diagnostics.replace(diagnostics).is_none());
167        }
168    }
169
170    /// Returns a handle to the meta graph, if set. The meta graph is a
171    /// representation of all the operators, subgraphs, and handoffs in this instance.
172    /// Will only be set if this graph was constructed using a surface syntax macro.
173    #[cfg(feature = "meta")]
174    #[cfg_attr(docsrs, doc(cfg(feature = "meta")))]
175    pub fn meta_graph(&self) -> Option<&DfirGraph> {
176        self.meta_graph.as_ref()
177    }
178
179    /// Returns any diagnostics generated by the surface syntax macro. Each diagnostic is a pair of
180    /// (1) a `Diagnostic` with span info reset and (2) the `ToString` version of the diagnostic
181    /// with original span info.
182    /// Will only be set if this graph was constructed using a surface syntax macro.
183    #[cfg(feature = "meta")]
184    #[cfg_attr(docsrs, doc(cfg(feature = "meta")))]
185    pub fn diagnostics(&self) -> Option<&[Diagnostic<SerdeSpan>]> {
186        self.diagnostics.as_deref()
187    }
188
189    /// Returns a reactor for externally scheduling subgraphs, possibly from another thread.
190    /// Reactor events are considered to be external events.
191    pub fn reactor(&self) -> Reactor {
192        Reactor::new(self.context.event_queue_send.clone())
193    }
194
195    /// Gets the current tick (local time) count.
196    pub fn current_tick(&self) -> TickInstant {
197        self.context.current_tick
198    }
199
200    /// Gets the current stratum number.
201    pub fn current_stratum(&self) -> usize {
202        self.context.current_stratum
203    }
204
205    /// Runs the dataflow until the next tick begins.
206    ///
207    /// Returns `true` if any work was done.
208    ///
209    /// Will receive events if it is the start of a tick when called.
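    ///
    /// A sketch (marked `ignore`) of waiting for external events and then processing exactly one
    /// tick, showing how this pairs with [`Self::recv_events_async`]:
    ///
    /// ```ignore
    /// # async fn example(df: &mut dfir_rs::scheduled::graph::Dfir<'_>) {
    /// df.recv_events_async().await;
    /// df.run_tick().await;
    /// # }
    /// ```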
210    #[tracing::instrument(level = "trace", skip(self), ret)]
211    pub async fn run_tick(&mut self) -> bool {
212        let mut work_done = false;
213        // While work is immediately available *on the current tick*.
214        while self.next_stratum(true) {
215            work_done = true;
216            // Do any work.
217            self.run_stratum().await;
218        }
219        work_done
220    }
221
222    /// [`Self::run_tick`] but panics if a subgraph yields asynchronously.
223    #[tracing::instrument(level = "trace", skip(self), ret)]
224    pub fn run_tick_sync(&mut self) -> bool {
225        let mut work_done = false;
226        // While work is immediately available *on the current tick*.
227        while self.next_stratum(true) {
228            work_done = true;
229            // Do any work.
230            run_sync(self.run_stratum());
231        }
232        work_done
233    }
234
235    /// Runs the dataflow until no more (externally-triggered) work is immediately available.
236    /// Runs at least one tick of dataflow, even if no external events have been received.
237    /// If the dataflow contains loops this method may run forever.
238    ///
239    /// Returns `true` if any work was done.
240    ///
241    /// Yields repeatedly to allow external events to happen.
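    ///
    /// A sketch (marked `ignore`) of the driver loop this is intended for, equivalent to the body
    /// of [`Self::run`]:
    ///
    /// ```ignore
    /// # async fn example(df: &mut dfir_rs::scheduled::graph::Dfir<'_>) {
    /// loop {
    ///     // Run everything currently runnable, then sleep until new external events arrive.
    ///     df.run_available().await;
    ///     df.recv_events_async().await;
    /// }
    /// # }
    /// ```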
242    #[tracing::instrument(level = "trace", skip(self), ret)]
243    pub async fn run_available(&mut self) -> bool {
244        let mut work_done = false;
245        // While work is immediately available.
246        while self.next_stratum(false) {
247            work_done = true;
248            // Do any work.
249            self.run_stratum().await;
250
251            // Yield between each stratum to receive more events.
252            // TODO(mingwei): really only need to yield at start of ticks though.
253            tokio::task::yield_now().await;
254        }
255        work_done
256    }
257
258    /// [`Self::run_available`] but panics if a subgraph yields asynchronously.
259    #[tracing::instrument(level = "trace", skip(self), ret)]
260    pub fn run_available_sync(&mut self) -> bool {
261        let mut work_done = false;
262        // While work is immediately available.
263        while self.next_stratum(false) {
264            work_done = true;
265            // Do any work.
266            run_sync(self.run_stratum());
267        }
268        work_done
269    }
270
271    /// Runs the current stratum of the dataflow until no more local work is available (does not receive events).
272    ///
273    /// Returns `true` if any work was done.
274    #[tracing::instrument(level = "trace", skip(self), fields(tick = u64::from(self.context.current_tick), stratum = self.context.current_stratum), ret)]
275    pub async fn run_stratum(&mut self) -> bool {
276        // Make sure to spawn tasks once dfir is running!
277        // This drains the task buffer, so it becomes a no-op after the first call.
278        self.context.spawn_tasks();
279
280        let mut work_done = false;
281
282        'pop: while let Some(sg_id) =
283            self.context.stratum_queues[self.context.current_stratum].pop_front()
284        {
285            {
286                let sg_data = &mut self.subgraphs[sg_id];
287                // This must be true for the subgraph to be enqueued.
288                assert!(sg_data.is_scheduled.take());
289
290                let _enter = tracing::info_span!(
291                    "run-subgraph",
292                    sg_id = sg_id.to_string(),
293                    sg_name = &*sg_data.name,
294                    sg_depth = sg_data.loop_depth,
295                    sg_loop_nonce = sg_data.last_loop_nonce.0,
296                    sg_iter_count = sg_data.last_loop_nonce.1,
297                )
298                .entered();
299
300                match sg_data.loop_depth.cmp(&self.context.loop_nonce_stack.len()) {
301                    Ordering::Greater => {
302                        // We have entered a loop.
303                        self.context.loop_nonce += 1;
304                        self.context.loop_nonce_stack.push(self.context.loop_nonce);
305                        tracing::trace!(loop_nonce = self.context.loop_nonce, "Entered loop.");
306                    }
307                    Ordering::Less => {
308                        // We have exited a loop.
309                        self.context.loop_nonce_stack.pop();
310                        tracing::trace!("Exited loop.");
311                    }
312                    Ordering::Equal => {}
313                }
314
315                self.context.subgraph_id = sg_id;
316                self.context.is_first_run_this_tick = sg_data
317                    .last_tick_run_in
318                    .is_none_or(|last_tick| last_tick < self.context.current_tick);
319
320                if let Some(loop_id) = sg_data.loop_id {
321                    // Loop execution - running loop block, from start to finish, containing
322                    // multiple iterations.
323                    // Loop iteration - a single iteration of a loop block, all subgraphs within
324                    // the loop should run (at most) once.
325
326                    // If the previous run of this subgraph had the same loop execution and
327                    // iteration count, then we need to increment the iteration count.
328                    let curr_loop_nonce = self.context.loop_nonce_stack.last().copied();
329
330                    let LoopData {
331                        iter_count: loop_iter_count,
332                        allow_another_iteration,
333                    } = &mut self.loop_data[loop_id];
334
335                    let (prev_loop_nonce, prev_iter_count) = sg_data.last_loop_nonce;
336
337                    // If the loop nonce is the same as the previous execution, then we are in
338                    // the same loop execution.
339                    // `curr_loop_nonce` is `None` for top-level loops, and top-level loops are
340                    // always in the same (singular) loop execution.
341                    let (curr_iter_count, new_loop_execution) =
342                        if curr_loop_nonce.is_none_or(|nonce| nonce == prev_loop_nonce) {
343                            // If the iteration count is the same as the previous execution, then
344                            // we are on the next iteration.
345                            if *loop_iter_count == prev_iter_count {
346                                // If another iteration is not allowed, skip this subgraph.
347                                if !std::mem::take(allow_another_iteration) {
348                                    tracing::debug!(
349                                        "Loop will not continue to next iteration, skipping."
350                                    );
351                                    continue 'pop;
352                                }
353                                // Increment `loop_iter_count` or set it to 0.
354                                loop_iter_count.map_or((0, true), |n| (n + 1, false))
355                            } else {
356                                // Otherwise update the local iteration count to match the loop.
357                                debug_assert!(
358                                    prev_iter_count < *loop_iter_count,
359                                    "Expect loop iteration count to be increasing."
360                                );
361                                (loop_iter_count.unwrap(), false)
362                            }
363                        } else {
364                            // The loop execution has already begun, but this is the first time this particular subgraph is running.
365                            (0, false)
366                        };
367
368                    if new_loop_execution {
369                        // Run state hooks.
370                        self.context.run_state_hooks_loop(loop_id);
371                    }
372                    tracing::debug!("Loop iteration count {}", curr_iter_count);
373
374                    *loop_iter_count = Some(curr_iter_count);
375                    self.context.loop_iter_count = curr_iter_count;
376                    sg_data.last_loop_nonce =
377                        (curr_loop_nonce.unwrap_or_default(), Some(curr_iter_count));
378                }
379
380                // Run subgraph state hooks.
381                self.context.run_state_hooks_subgraph(sg_id);
382
383                tracing::info!("Running subgraph.");
384                sg_data.last_tick_run_in = Some(self.context.current_tick);
385                Box::into_pin(sg_data.subgraph.run(&mut self.context, &mut self.handoffs)).await;
386            };
387
388            let sg_data = &self.subgraphs[sg_id];
389            for &handoff_id in sg_data.succs.iter() {
390                let handoff = &self.handoffs[handoff_id];
391                if !handoff.handoff.is_bottom() {
392                    for &succ_id in handoff.succs.iter() {
393                        let succ_sg_data = &self.subgraphs[succ_id];
394                        // If we have sent data to the next tick, then we can start the next tick.
395                        if succ_sg_data.stratum < self.context.current_stratum && !sg_data.is_lazy {
396                            self.context.can_start_tick = true;
397                        }
398                        // Add subgraph to stratum queue if it is not already scheduled.
399                        if !succ_sg_data.is_scheduled.replace(true) {
400                            self.context.stratum_queues[succ_sg_data.stratum].push_back(succ_id);
401                        }
402                        // Add stratum to stratum stack if it is within a loop.
403                        if 0 < succ_sg_data.loop_depth {
404                            // TODO(mingwei): handle duplicates
405                            self.context
406                                .stratum_stack
407                                .push(succ_sg_data.loop_depth, succ_sg_data.stratum);
408                        }
409                    }
410                }
411            }
412
413            let reschedule = self.context.reschedule_loop_block.take();
414            let allow_another = self.context.allow_another_iteration.take();
415
416            if reschedule {
417                // Re-enqueue the subgraph.
418                self.context.schedule_deferred.push(sg_id);
419                self.context
420                    .stratum_stack
421                    .push(sg_data.loop_depth, sg_data.stratum);
422            }
423            if (reschedule || allow_another)
424                && let Some(loop_id) = sg_data.loop_id
425            {
426                self.loop_data
427                    .get_mut(loop_id)
428                    .unwrap()
429                    .allow_another_iteration = true;
430            }
431
432            work_done = true;
433        }
434        work_done
435    }
436
437    /// Go to the next stratum which has work available, possibly the current stratum.
438    /// Returns `true` if more work is available, otherwise `false` if no work is immediately
439    /// available on any stratum.
440    ///
441    /// This will receive external events when at the start of a tick.
442    ///
443    /// If `current_tick_only` is set to `true`, will only return `true` if work is immediately
444    /// available on the *current tick*.
445    ///
446    /// If this returns false then the graph will be at the start of a tick (at stratum 0, can
447    /// receive more external events).
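    ///
    /// A sketch (marked `ignore`) of how this composes with [`Self::run_stratum`], mirroring the
    /// body of [`Self::run_available`]:
    ///
    /// ```ignore
    /// # async fn example(df: &mut dfir_rs::scheduled::graph::Dfir<'_>) {
    /// // Keep advancing to strata with immediately available work, running each one.
    /// while df.next_stratum(false) {
    ///     df.run_stratum().await;
    /// }
    /// # }
    /// ```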
448    #[tracing::instrument(level = "trace", skip(self), ret)]
449    pub fn next_stratum(&mut self, current_tick_only: bool) -> bool {
450        tracing::trace!(
451            events_received_tick = self.context.events_received_tick,
452            can_start_tick = self.context.can_start_tick,
453            "Starting `next_stratum` call.",
454        );
455
456        // The stratum we will stop searching at, i.e. where we will have made a full loop around.
457        let mut end_stratum = self.context.current_stratum;
458        let mut new_tick_started = false;
459
460        if 0 == self.context.current_stratum {
461            new_tick_started = true;
462
463            // Starting the tick, reset this to `false`.
464            tracing::trace!("Starting tick, setting `can_start_tick = false`.");
465            self.context.can_start_tick = false;
466            self.context.current_tick_start = SystemTime::now();
467
468            // Ensure external events are received before running the tick.
469            if !self.context.events_received_tick {
470                // Add any external jobs to ready queue.
471                self.try_recv_events();
472            }
473        }
474
475        loop {
476            tracing::trace!(
477                tick = u64::from(self.context.current_tick),
478                stratum = self.context.current_stratum,
479                "Looking for work on stratum."
480            );
481            // If current stratum has work, return true.
482            if !self.context.stratum_queues[self.context.current_stratum].is_empty() {
483                tracing::trace!(
484                    tick = u64::from(self.context.current_tick),
485                    stratum = self.context.current_stratum,
486                    "Work found on stratum."
487                );
488                return true;
489            }
490
491            if let Some(next_stratum) = self.context.stratum_stack.pop() {
492                self.context.current_stratum = next_stratum;
493
494                // Now schedule deferred subgraphs.
495                {
496                    for sg_id in self.context.schedule_deferred.drain(..) {
497                        let sg_data = &self.subgraphs[sg_id];
498                        tracing::info!(
499                            tick = u64::from(self.context.current_tick),
500                            stratum = self.context.current_stratum,
501                            sg_id = sg_id.to_string(),
502                            sg_name = &*sg_data.name,
503                            is_scheduled = sg_data.is_scheduled.get(),
504                            "Rescheduling deferred subgraph."
505                        );
506                        if !sg_data.is_scheduled.replace(true) {
507                            self.context.stratum_queues[sg_data.stratum].push_back(sg_id);
508                        }
509                    }
510                }
511            } else {
512                // Increment stratum counter.
513                self.context.current_stratum += 1;
514
515                if self.context.current_stratum >= self.context.stratum_queues.len() {
516                    new_tick_started = true;
517
518                    tracing::trace!(
519                        can_start_tick = self.context.can_start_tick,
520                        "End of tick {}, starting tick {}.",
521                        self.context.current_tick,
522                        self.context.current_tick + TickDuration::SINGLE_TICK,
523                    );
524                    self.context.run_state_hooks_tick();
525
526                    self.context.current_stratum = 0;
527                    self.context.current_tick += TickDuration::SINGLE_TICK;
528                    self.context.events_received_tick = false;
529
530                    if current_tick_only {
531                        tracing::trace!(
532                            "`current_tick_only` is `true`, returning `false` before receiving events."
533                        );
534                        return false;
535                    } else {
536                        self.try_recv_events();
537                        if std::mem::replace(&mut self.context.can_start_tick, false) {
538                            tracing::trace!(
539                                tick = u64::from(self.context.current_tick),
540                                "`can_start_tick` is `true`, continuing."
541                            );
542                            // Do a full loop more to find where events have been added.
543                            end_stratum = 0;
544                            continue;
545                        } else {
546                            tracing::trace!(
547                                "`can_start_tick` is `false`, re-setting `events_received_tick = false`, returning `false`."
548                            );
549                            self.context.events_received_tick = false;
550                            return false;
551                        }
552                    }
553                }
554            }
555
556            // After incrementing, exit if we made a full loop around the strata.
557            if new_tick_started && end_stratum == self.context.current_stratum {
558                tracing::trace!(
559                    "Made full loop around stratum, re-setting `current_stratum = 0`, returning `false`."
560                );
561                // Note: if current stratum had work, the very first loop iteration would've
562                // returned true. Therefore we can return false without checking.
563                // Also means nothing was done so we can reset the stratum to zero and wait for
564                // events.
565                self.context.events_received_tick = false;
566                self.context.current_stratum = 0;
567                return false;
568            }
569        }
570    }
571
572    /// Runs the dataflow graph forever.
573    ///
574    /// TODO(mingwei): Currently blocks forever, no notion of "completion."
575    #[tracing::instrument(level = "trace", skip(self), ret)]
576    pub async fn run(&mut self) -> Option<Never> {
577        loop {
578            // Run any work which is immediately available.
579            self.run_available().await;
580            // When no work is available yield until more events occur.
581            self.recv_events_async().await;
582        }
583    }
584
585    /// [`Self::run`] but panics if a subgraph yields asynchronously.
586    #[tracing::instrument(level = "trace", skip(self), ret)]
587    pub fn run_sync(&mut self) -> Option<Never> {
588        loop {
589            // Run any work which is immediately available.
590            self.run_available_sync();
591            // When no work is available block until more events occur.
592            self.recv_events();
593        }
594    }
595
596    /// Enqueues subgraphs triggered by events without blocking.
597    ///
598    /// Returns the number of subgraphs enqueued.
599    #[tracing::instrument(level = "trace", skip(self), fields(events_received_tick = self.context.events_received_tick), ret)]
600    pub fn try_recv_events(&mut self) -> usize {
601        let mut enqueued_count = 0;
602        while let Ok((sg_id, is_external)) = self.context.event_queue_recv.try_recv() {
603            let sg_data = &self.subgraphs[sg_id];
604            tracing::trace!(
605                sg_id = sg_id.to_string(),
606                is_external = is_external,
607                sg_stratum = sg_data.stratum,
608                "Event received."
609            );
610            if !sg_data.is_scheduled.replace(true) {
611                self.context.stratum_queues[sg_data.stratum].push_back(sg_id);
612                enqueued_count += 1;
613            }
614            if is_external {
615                // The next tick is triggered if we are at the start of the next tick (`!self.context.events_received_tick`),
616                // or if the subgraph's stratum is in the next tick (before the current stratum).
617                if !self.context.events_received_tick
618                    || sg_data.stratum < self.context.current_stratum
619                {
620                    tracing::trace!(
621                        current_stratum = self.context.current_stratum,
622                        sg_stratum = sg_data.stratum,
623                        "External event, setting `can_start_tick = true`."
624                    );
625                    self.context.can_start_tick = true;
626                }
627            }
628        }
629        self.context.events_received_tick = true;
630
631        enqueued_count
632    }
633
634    /// Enqueues subgraphs triggered by external events, blocking until at
635    /// least one subgraph is scheduled **from an external event**.
636    #[tracing::instrument(level = "trace", skip(self), fields(events_received_tick = self.context.events_received_tick), ret)]
637    pub fn recv_events(&mut self) -> Option<usize> {
638        let mut count = 0;
639        loop {
640            let (sg_id, is_external) = self.context.event_queue_recv.blocking_recv()?;
641            let sg_data = &self.subgraphs[sg_id];
642            tracing::trace!(
643                sg_id = sg_id.to_string(),
644                is_external = is_external,
645                sg_stratum = sg_data.stratum,
646                "Event received."
647            );
648            if !sg_data.is_scheduled.replace(true) {
649                self.context.stratum_queues[sg_data.stratum].push_back(sg_id);
650                count += 1;
651            }
652            if is_external {
653                // The next tick is triggered if we are at the start of the next tick (`!self.context.events_received_tick`),
654                // or if the subgraph's stratum is in the next tick (before the current stratum).
655                if !self.context.events_received_tick
656                    || sg_data.stratum < self.context.current_stratum
657                {
658                    tracing::trace!(
659                        current_stratum = self.context.current_stratum,
660                        sg_stratum = sg_data.stratum,
661                        "External event, setting `can_start_tick = true`."
662                    );
663                    self.context.can_start_tick = true;
664                }
665                break;
666            }
667        }
668        self.context.events_received_tick = true;
669
670        // Enqueue any other immediate events.
671        let extra_count = self.try_recv_events();
672        Some(count + extra_count)
673    }
674
675    /// Enqueues subgraphs triggered by external events asynchronously, waiting until at least one
676    /// subgraph is scheduled **from an external event**. Returns the number of subgraphs enqueued,
677    /// which may be zero if an external event scheduled an already-scheduled subgraph.
678    ///
679    /// Returns `None` if the event queue is closed, but that should not happen normally.
680    #[tracing::instrument(level = "trace", skip(self), fields(events_received_tick = self.context.events_received_tick), ret)]
681    pub async fn recv_events_async(&mut self) -> Option<usize> {
682        let mut count = 0;
683        loop {
684            tracing::trace!("Awaiting events (`event_queue_recv`).");
685            let (sg_id, is_external) = self.context.event_queue_recv.recv().await?;
686            let sg_data = &self.subgraphs[sg_id];
687            tracing::trace!(
688                sg_id = sg_id.to_string(),
689                is_external = is_external,
690                sg_stratum = sg_data.stratum,
691                "Event received."
692            );
693            if !sg_data.is_scheduled.replace(true) {
694                self.context.stratum_queues[sg_data.stratum].push_back(sg_id);
695                count += 1;
696            }
697            if is_external {
698                // The next tick is triggered if we are at the start of the next tick (`!self.context.events_received_tick`),
699                // or if the subgraph's stratum is in the next tick (before the current stratum).
700                if !self.context.events_received_tick
701                    || sg_data.stratum < self.context.current_stratum
702                {
703                    tracing::trace!(
704                        current_stratum = self.context.current_stratum,
705                        sg_stratum = sg_data.stratum,
706                        "External event, setting `can_start_tick = true`."
707                    );
708                    self.context.can_start_tick = true;
709                }
710                break;
711            }
712        }
713        self.context.events_received_tick = true;
714
715        // Enqueue any other immediate events.
716        let extra_count = self.try_recv_events();
717        Some(count + extra_count)
718    }
719
720    /// Schedules a subgraph to be run. See also: [`Context::schedule_subgraph`].
721    pub fn schedule_subgraph(&mut self, sg_id: SubgraphId) -> bool {
722        let sg_data = &self.subgraphs[sg_id];
723        let already_scheduled = sg_data.is_scheduled.replace(true);
724        if !already_scheduled {
725            self.context.stratum_queues[sg_data.stratum].push_back(sg_id);
726            true
727        } else {
728            false
729        }
730    }
731
732    /// Adds a new compiled subgraph with the specified inputs and outputs in stratum 0.
733    pub fn add_subgraph<Name, R, W, Func>(
734        &mut self,
735        name: Name,
736        recv_ports: R,
737        send_ports: W,
738        subgraph: Func,
739    ) -> SubgraphId
740    where
741        Name: Into<Cow<'static, str>>,
742        R: 'static + PortList<RECV>,
743        W: 'static + PortList<SEND>,
744        Func: 'a + for<'ctx> AsyncFnMut(&'ctx mut Context, R::Ctx<'ctx>, W::Ctx<'ctx>),
745    {
746        self.add_subgraph_stratified(name, 0, recv_ports, send_ports, false, subgraph)
747    }
748
749    /// Adds a new compiled subgraph with the specified inputs, outputs, and stratum number.
750    ///
751    /// TODO(mingwei): add example in doc.
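    ///
    /// In the meantime, a minimal sketch (marked `ignore`); it assumes the unit type `()` is the
    /// empty [`PortList`] and that the closure then receives unit port contexts:
    ///
    /// ```ignore
    /// let mut df = Dfir::new();
    /// let sg_id = df.add_subgraph_stratified(
    ///     "ticker",
    ///     0,     // stratum
    ///     (),    // no receive ports
    ///     (),    // no send ports
    ///     false, // not lazy
    ///     async move |_context, (), ()| {
    ///         println!("ran in stratum 0");
    ///     },
    /// );
    /// ```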
752    pub fn add_subgraph_stratified<Name, R, W, Func>(
753        &mut self,
754        name: Name,
755        stratum: usize,
756        recv_ports: R,
757        send_ports: W,
758        laziness: bool,
759        subgraph: Func,
760    ) -> SubgraphId
761    where
762        Name: Into<Cow<'static, str>>,
763        R: 'static + PortList<RECV>,
764        W: 'static + PortList<SEND>,
765        Func: 'a + for<'ctx> AsyncFnMut(&'ctx mut Context, R::Ctx<'ctx>, W::Ctx<'ctx>),
766    {
767        self.add_subgraph_full(
768            name, stratum, recv_ports, send_ports, laziness, None, subgraph,
769        )
770    }
771
772    /// Adds a new compiled subgraph with all options.
773    #[expect(clippy::too_many_arguments, reason = "Mainly for internal use.")]
774    pub fn add_subgraph_full<Name, R, W, Func>(
775        &mut self,
776        name: Name,
777        stratum: usize,
778        recv_ports: R,
779        send_ports: W,
780        laziness: bool,
781        loop_id: Option<LoopId>,
782        mut subgraph: Func,
783    ) -> SubgraphId
784    where
785        Name: Into<Cow<'static, str>>,
786        R: 'static + PortList<RECV>,
787        W: 'static + PortList<SEND>,
788        Func: 'a + for<'ctx> AsyncFnMut(&'ctx mut Context, R::Ctx<'ctx>, W::Ctx<'ctx>),
789    {
790        // SAFETY: Check that the send and recv ports are from `self.handoffs`.
791        recv_ports.assert_is_from(&self.handoffs);
792        send_ports.assert_is_from(&self.handoffs);
793
794        let loop_depth = loop_id
795            .and_then(|loop_id| self.context.loop_depth.get(loop_id))
796            .copied()
797            .unwrap_or(0);
798
799        let sg_id = self.subgraphs.insert_with_key(|sg_id| {
800            let (mut subgraph_preds, mut subgraph_succs) = Default::default();
801            recv_ports.set_graph_meta(&mut self.handoffs, &mut subgraph_preds, sg_id, true);
802            send_ports.set_graph_meta(&mut self.handoffs, &mut subgraph_succs, sg_id, false);
803
804            let subgraph =
805                async move |context: &mut Context,
806                            handoffs: &mut SlotVec<HandoffTag, HandoffData>| {
807                    let (recv, send) = unsafe {
808                        // SAFETY:
809                        // 1. We checked `assert_is_from` at assembly time, above.
810                        // 2. `SlotVec` is insert-only so no handoffs could have changed since then.
811                        (
812                            recv_ports.make_ctx(&*handoffs),
813                            send_ports.make_ctx(&*handoffs),
814                        )
815                    };
816                    (subgraph)(context, recv, send).await;
817                };
818            SubgraphData::new(
819                name.into(),
820                stratum,
821                subgraph,
822                subgraph_preds,
823                subgraph_succs,
824                true,
825                laziness,
826                loop_id,
827                loop_depth,
828            )
829        });
830        self.context.init_stratum(stratum);
831        self.context.stratum_queues[stratum].push_back(sg_id);
832
833        sg_id
834    }
835
836    /// Adds a new compiled subgraph with a variable number of inputs and outputs of the same respective handoff types.
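    ///
    /// A sketch (marked `ignore`) with one input and one output; [`TeeingHandoff`] is used only
    /// because it is already imported here, and is assumed to implement [`Handoff`] + `Default`:
    ///
    /// ```ignore
    /// let mut df = Dfir::new();
    /// let (_send_a, recv_a) = df.make_edge::<_, TeeingHandoff<u32>>("a");
    /// let (send_b, _recv_b) = df.make_edge::<_, TeeingHandoff<u32>>("b");
    /// df.add_subgraph_n_m(
    ///     "forward",
    ///     vec![recv_a],
    ///     vec![send_b],
    ///     async move |_context, _recvs, _sends| {
    ///         // An operator would move data from `_recvs` to `_sends` here.
    ///     },
    /// );
    /// ```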
837    pub fn add_subgraph_n_m<Name, R, W, Func>(
838        &mut self,
839        name: Name,
840        recv_ports: Vec<RecvPort<R>>,
841        send_ports: Vec<SendPort<W>>,
842        subgraph: Func,
843    ) -> SubgraphId
844    where
845        Name: Into<Cow<'static, str>>,
846        R: 'static + Handoff,
847        W: 'static + Handoff,
848        Func: 'a
849            + for<'ctx> AsyncFnMut(
850                &'ctx mut Context,
851                &'ctx [&'ctx RecvCtx<R>],
852                &'ctx [&'ctx SendCtx<W>],
853            ),
854    {
855        self.add_subgraph_stratified_n_m(name, 0, recv_ports, send_ports, subgraph)
856    }
857
858    /// Adds a new compiled subgraph with a variable number of inputs and outputs of the same respective handoff types.
859    pub fn add_subgraph_stratified_n_m<Name, R, W, Func>(
860        &mut self,
861        name: Name,
862        stratum: usize,
863        recv_ports: Vec<RecvPort<R>>,
864        send_ports: Vec<SendPort<W>>,
865        mut subgraph: Func,
866    ) -> SubgraphId
867    where
868        Name: Into<Cow<'static, str>>,
869        R: 'static + Handoff,
870        W: 'static + Handoff,
871        Func: 'a
872            + for<'ctx> AsyncFnMut(
873                &'ctx mut Context,
874                &'ctx [&'ctx RecvCtx<R>],
875                &'ctx [&'ctx SendCtx<W>],
876            ),
877    {
878        let sg_id = self.subgraphs.insert_with_key(|sg_id| {
879            let subgraph_preds = recv_ports.iter().map(|port| port.handoff_id).collect();
880            let subgraph_succs = send_ports.iter().map(|port| port.handoff_id).collect();
881
882            for recv_port in recv_ports.iter() {
883                self.handoffs[recv_port.handoff_id].succs.push(sg_id);
884            }
885            for send_port in send_ports.iter() {
886                self.handoffs[send_port.handoff_id].preds.push(sg_id);
887            }
888
889            let subgraph =
890                async move |context: &mut Context,
891                            handoffs: &mut SlotVec<HandoffTag, HandoffData>| {
892                    let recvs: Vec<&RecvCtx<R>> = recv_ports
893                        .iter()
894                        .map(|hid| hid.handoff_id)
895                        .map(|hid| handoffs.get(hid).unwrap())
896                        .map(|h_data| {
897                            <dyn Any>::downcast_ref(&*h_data.handoff)
898                                .expect("Attempted to cast handoff to wrong type.")
899                        })
900                        .map(RefCast::ref_cast)
901                        .collect();
902
903                    let sends: Vec<&SendCtx<W>> = send_ports
904                        .iter()
905                        .map(|hid| hid.handoff_id)
906                        .map(|hid| handoffs.get(hid).unwrap())
907                        .map(|h_data| {
908                            <dyn Any>::downcast_ref(&*h_data.handoff)
909                                .expect("Attempted to cast handoff to wrong type.")
910                        })
911                        .map(RefCast::ref_cast)
912                        .collect();
913
914                    (subgraph)(context, &recvs, &sends).await;
915                };
916            SubgraphData::new(
917                name.into(),
918                stratum,
919                subgraph,
920                subgraph_preds,
921                subgraph_succs,
922                true,
923                false,
924                None,
925                0,
926            )
927        });
928
929        self.context.init_stratum(stratum);
930        self.context.stratum_queues[stratum].push_back(sg_id);
931
932        sg_id
933    }
934
935    /// Creates a handoff edge and returns the corresponding send and receive ports.
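    ///
    /// A sketch (marked `ignore`); the type parameter `H` selects the handoff implementation
    /// (here the imported [`TeeingHandoff`], assumed to implement `Default`):
    ///
    /// ```ignore
    /// let mut df = Dfir::new();
    /// // The send port is wired into the upstream subgraph, the recv port into the downstream one.
    /// let (send_port, recv_port) = df.make_edge::<_, TeeingHandoff<String>>("my_edge");
    /// ```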
936    pub fn make_edge<Name, H>(&mut self, name: Name) -> (SendPort<H>, RecvPort<H>)
937    where
938        Name: Into<Cow<'static, str>>,
939        H: 'static + Handoff,
940    {
941        // Create and insert handoff.
942        let handoff = H::default();
943        let handoff_id = self
944            .handoffs
945            .insert_with_key(|hoff_id| HandoffData::new(name.into(), handoff, hoff_id));
946
947        // Make ports.
948        let input_port = SendPort {
949            handoff_id,
950            _marker: PhantomData,
951        };
952        let output_port = RecvPort {
953            handoff_id,
954            _marker: PhantomData,
955        };
956        (input_port, output_port)
957    }
958
959    /// Adds referenceable state into this instance. Returns a state handle which can be
960    /// used externally or by operators to access the state.
961    ///
962    /// This is part of the "state API".
963    pub fn add_state<T>(&mut self, state: T) -> StateHandle<T>
964    where
965        T: Any,
966    {
967        self.context.add_state(state)
968    }
969
970    /// Sets a hook to modify the state at the boundaries of the given `lifespan`, using the supplied closure.
971    ///
972    /// This is part of the "state API".
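    ///
    /// A sketch (marked `ignore`) combining [`Self::add_state`] with a per-tick reset hook:
    ///
    /// ```ignore
    /// let mut df = Dfir::new();
    /// let handle = df.add_state(0_usize);
    /// // Reset the counter at each tick boundary.
    /// df.set_state_lifespan_hook(handle, StateLifespan::Tick, |count| *count = 0);
    /// ```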
973    pub fn set_state_lifespan_hook<T>(
974        &mut self,
975        handle: StateHandle<T>,
976        lifespan: StateLifespan,
977        hook_fn: impl 'static + FnMut(&mut T),
978    ) where
979        T: Any,
980    {
981        self.context
982            .set_state_lifespan_hook(handle, lifespan, hook_fn)
983    }
984
985    /// Gets an exclusive (mut) ref to the internal context, setting the subgraph ID.
986    pub fn context_mut(&mut self, sg_id: SubgraphId) -> &mut Context {
987        self.context.subgraph_id = sg_id;
988        &mut self.context
989    }
990
991    /// Adds a new loop with the given parent (or `None` for top-level). Returns a loop ID which
992    /// is used in [`Self::add_subgraph_full`] or for nested loops.
993    ///
994    /// TODO(mingwei): add loop names to ensure traceability while debugging?
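    ///
    /// A sketch (marked `ignore`) of nesting loops:
    ///
    /// ```ignore
    /// let mut df = Dfir::new();
    /// let outer = df.add_loop(None);        // top-level loop
    /// let inner = df.add_loop(Some(outer)); // nested inside `outer`
    /// ```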
995    pub fn add_loop(&mut self, parent: Option<LoopId>) -> LoopId {
996        let depth = parent.map_or(0, |p| self.context.loop_depth[p] + 1);
997        let loop_id = self.context.loop_depth.insert(depth);
998        self.loop_data.insert(
999            loop_id,
1000            LoopData {
1001                iter_count: None,
1002                allow_another_iteration: true,
1003            },
1004        );
1005        loop_id
1006    }
1007}
1008
1009impl Dfir<'_> {
1010    /// Alias for [`Context::request_task`].
1011    pub fn request_task<Fut>(&mut self, future: Fut)
1012    where
1013        Fut: Future<Output = ()> + 'static,
1014    {
1015        self.context.request_task(future);
1016    }
1017
1018    /// Alias for [`Context::abort_tasks`].
1019    pub fn abort_tasks(&mut self) {
1020        self.context.abort_tasks()
1021    }
1022
1023    /// Alias for [`Context::join_tasks`].
1024    pub fn join_tasks(&mut self) -> impl use<'_> + Future {
1025        self.context.join_tasks()
1026    }
1027}
1028
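/// Polls `fut` once with a no-op waker and returns its output, panicking if the future is not
/// immediately ready. For example, `run_sync(async { 1 + 1 })` returns `2`, while a future that
/// awaits a pending event will panic.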
1029fn run_sync<Fut>(fut: Fut) -> Fut::Output
1030where
1031    Fut: Future,
1032{
1033    let mut fut = std::pin::pin!(fut);
1034    let mut ctx = std::task::Context::from_waker(std::task::Waker::noop());
1035    match fut.as_mut().poll(&mut ctx) {
1036        std::task::Poll::Ready(out) => out,
1037        std::task::Poll::Pending => panic!("Future did not resolve immediately."),
1038    }
1039}
1040
1041impl Drop for Dfir<'_> {
1042    fn drop(&mut self) {
1043        self.abort_tasks();
1044    }
1045}
1046
1047/// A handoff and its input and output [SubgraphId]s.
1048///
1049/// Internal use: used to track the dfir graph structure.
1050///
1051/// TODO(mingwei): restructure `PortList` so this can be crate-private.
1052#[doc(hidden)]
1053pub struct HandoffData {
1054    /// A friendly name for diagnostics.
1055    pub(super) name: Cow<'static, str>,
1056    /// The handoff object itself. Crate-visible for `handoff_list` internals.
1057    pub(super) handoff: Box<dyn HandoffMeta>,
1058    /// Preceding subgraphs (including the send side of a teeing handoff).
1059    pub(super) preds: SmallVec<[SubgraphId; 1]>,
1060    /// Successor subgraphs (including recv sides of teeing handoffs).
1061    pub(super) succs: SmallVec<[SubgraphId; 1]>,
1062
1063    /// Predecessor handoffs, used by teeing handoffs.
1064    /// Should be `self` on any teeing send sides (input).
1065    /// Should be the send `HandoffId` if this is teeing recv side (output).
1066    /// Should be just `self`'s `HandoffId` on other handoffs.
1067    /// This field is only used in initialization.
1068    pub(super) pred_handoffs: Vec<HandoffId>,
1069    /// Successor handoffs, used by teeing handoffs.
1070    /// Should be a list of outputs on the teeing send side (input).
1071    /// Should be `self` on any teeing recv sides (outputs).
1072    /// Should be just `self`'s `HandoffId` on other handoffs.
1073    /// This field is only used in initialization.
1074    pub(super) succ_handoffs: Vec<HandoffId>,
1075}
1076impl std::fmt::Debug for HandoffData {
1077    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
1078        f.debug_struct("HandoffData")
1079            .field("preds", &self.preds)
1080            .field("succs", &self.succs)
1081            .finish_non_exhaustive()
1082    }
1083}
1084impl HandoffData {
1085    /// New with `pred_handoffs` and `succ_handoffs` set to its own [`HandoffId`]: `vec![hoff_id]`.
1086    pub fn new(
1087        name: Cow<'static, str>,
1088        handoff: impl 'static + HandoffMeta,
1089        hoff_id: HandoffId,
1090    ) -> Self {
1091        let (preds, succs) = Default::default();
1092        Self {
1093            name,
1094            handoff: Box::new(handoff),
1095            preds,
1096            succs,
1097            pred_handoffs: vec![hoff_id],
1098            succ_handoffs: vec![hoff_id],
1099        }
1100    }
1101}
1102
1103/// A subgraph along with its predecessor and successor [HandoffId]s.
1104///
1105/// Used internally by the [Dfir] struct to represent the dataflow graph
1106/// structure and scheduled state.
1107pub(super) struct SubgraphData<'a> {
1108    /// A friendly name for diagnostics.
1109    pub(super) name: Cow<'static, str>,
1110    /// This subgraph's stratum number.
1111    ///
1112    /// Within loop blocks, corresponds to the topological sort of the DAG created when `next_loop()/next_tick()` are removed.
1113    pub(super) stratum: usize,
1114    /// The actual execution code of the subgraph.
1115    subgraph: Box<dyn 'a + Subgraph>,
1116
1117    #[expect(dead_code, reason = "may be useful in the future")]
1118    preds: Vec<HandoffId>,
1119    succs: Vec<HandoffId>,
1120
1121    /// Whether this subgraph is currently scheduled in the stratum queues (`Context::stratum_queues`).
1122    /// [`Cell`] allows modifying this field when iterating `Self::preds` or
1123    /// `Self::succs`, as all `SubgraphData` are owned by the same vec,
1124    /// `Dfir::subgraphs`.
1125    is_scheduled: Cell<bool>,
1126
1127    /// Keeps track of the last tick that this subgraph was run in.
1128    last_tick_run_in: Option<TickInstant>,
1129    /// An opaque ID tracking the last loop execution this subgraph was run in,
1130    /// as a `(loop_nonce, iter_count)` pair.
1131    last_loop_nonce: (usize, Option<usize>),
1132
1133    /// If this subgraph is marked as lazy, then sending data back to a lower stratum does not trigger a new tick to be run.
1134    is_lazy: bool,
1135
1136    /// The subgraph's loop ID, or `None` for the top level.
1137    loop_id: Option<LoopId>,
1138    /// The loop depth of the subgraph.
1139    loop_depth: usize,
1140}
1141impl<'a> SubgraphData<'a> {
1142    #[expect(clippy::too_many_arguments, reason = "internal use")]
1143    pub(crate) fn new(
1144        name: Cow<'static, str>,
1145        stratum: usize,
1146        subgraph: impl 'a + Subgraph,
1147        preds: Vec<HandoffId>,
1148        succs: Vec<HandoffId>,
1149        is_scheduled: bool,
1150        is_lazy: bool,
1151        loop_id: Option<LoopId>,
1152        loop_depth: usize,
1153    ) -> Self {
1154        Self {
1155            name,
1156            stratum,
1157            subgraph: Box::new(subgraph),
1158            preds,
1159            succs,
1160            is_scheduled: Cell::new(is_scheduled),
1161            last_tick_run_in: None,
1162            last_loop_nonce: (0, None),
1163            is_lazy,
1164            loop_id,
1165            loop_depth,
1166        }
1167    }
1168}
1169
1170pub(crate) struct LoopData {
1171    /// Count of iterations of this loop.
1172    iter_count: Option<usize>,
1173    /// If the loop has reason to do another iteration.
1174    allow_another_iteration: bool,
1175}
1176
1177/// Defines when state should be reset.
1178#[derive(Clone, Copy, Debug, Eq, PartialEq)]
1179pub enum StateLifespan {
1180    /// Always reset, associated with the subgraph.
1181    Subgraph(SubgraphId),
1182    /// Reset between loop executions.
1183    Loop(LoopId),
1184    /// Reset between ticks.
1185    Tick,
1186    /// Never reset.
1187    Static,
1188}