dfir_rs/scheduled/graph.rs

//! Module for the [`Dfir`] struct and helper items.

use std::any::Any;
use std::borrow::Cow;
use std::cell::Cell;
use std::cmp::Ordering;
use std::future::Future;
use std::marker::PhantomData;

#[cfg(feature = "meta")]
use dfir_lang::diagnostic::{Diagnostic, SerdeSpan};
#[cfg(feature = "meta")]
use dfir_lang::graph::DfirGraph;
use ref_cast::RefCast;
use smallvec::SmallVec;
use web_time::SystemTime;

use super::context::Context;
use super::handoff::handoff_list::PortList;
use super::handoff::{Handoff, HandoffMeta, TeeingHandoff};
use super::port::{RECV, RecvCtx, RecvPort, SEND, SendCtx, SendPort};
use super::reactor::Reactor;
use super::state::StateHandle;
use super::subgraph::Subgraph;
use super::{HandoffId, HandoffTag, LoopId, LoopTag, SubgraphId, SubgraphTag};
use crate::Never;
use crate::scheduled::ticks::{TickDuration, TickInstant};
use crate::util::slot_vec::{SecondarySlotVec, SlotVec};

/// A DFIR graph. Owns, schedules, and runs the compiled subgraphs.
#[derive(Default)]
pub struct Dfir<'a> {
    pub(super) subgraphs: SlotVec<SubgraphTag, SubgraphData<'a>>,

    pub(super) loop_data: SecondarySlotVec<LoopTag, LoopData>,

    pub(super) context: Context,

    handoffs: SlotVec<HandoffTag, HandoffData>,

    #[cfg(feature = "meta")]
    /// See [`Self::meta_graph()`].
    meta_graph: Option<DfirGraph>,

    #[cfg(feature = "meta")]
    /// See [`Self::diagnostics()`].
    diagnostics: Option<Vec<Diagnostic<SerdeSpan>>>,
}

/// Methods for [`TeeingHandoff`] teeing and dropping.
impl Dfir<'_> {
    /// Tees a [`TeeingHandoff`].
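    ///
    /// # Example
    ///
    /// A minimal sketch (not from the original docs; assumes the edge was
    /// created with a [`TeeingHandoff`] handoff type):
    ///
    /// ```rust,ignore
    /// let mut df = Dfir::new();
    /// let (send_port, recv_a) = df.make_edge::<_, TeeingHandoff<usize>>("tee_edge");
    /// // Create a second output port that receives a clone of every value sent.
    /// let recv_b = df.teeing_handoff_tee(&recv_a);
    /// ```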
    pub fn teeing_handoff_tee<T>(
        &mut self,
        tee_parent_port: &RecvPort<TeeingHandoff<T>>,
    ) -> RecvPort<TeeingHandoff<T>>
    where
        T: Clone,
    {
        // If we're teeing from a child, make sure to find the root.
        let tee_root = self.handoffs[tee_parent_port.handoff_id].pred_handoffs[0];

        // Set up teeing metadata.
        let tee_root_data = &mut self.handoffs[tee_root];
        let tee_root_data_name = tee_root_data.name.clone();

        // Insert new handoff output.
        let teeing_handoff = tee_root_data
            .handoff
            .any_ref()
            .downcast_ref::<TeeingHandoff<T>>()
            .unwrap();
        let new_handoff = teeing_handoff.tee();

        // Handoff ID of new tee output.
        let new_hoff_id = self.handoffs.insert_with_key(|new_hoff_id| {
            let new_name = Cow::Owned(format!("{} tee {:?}", tee_root_data_name, new_hoff_id));
            let mut new_handoff_data = HandoffData::new(new_name, new_handoff, new_hoff_id);
            // Set self's predecessor as `tee_root`.
            new_handoff_data.pred_handoffs = vec![tee_root];
            new_handoff_data
        });

        // Go to `tee_root`'s successors and insert self (the new tee output).
        let tee_root_data = &mut self.handoffs[tee_root];
        tee_root_data.succ_handoffs.push(new_hoff_id);

        // Add our new handoff id into the subgraph data if the send `tee_root` has already been
        // used to add a subgraph.
        assert!(
            tee_root_data.preds.len() <= 1,
            "Tee send side should only have one sender (or none set yet)."
        );
        if let Some(&pred_sg_id) = tee_root_data.preds.first() {
            self.subgraphs[pred_sg_id].succs.push(new_hoff_id);
        }

        let output_port = RecvPort {
            handoff_id: new_hoff_id,
            _marker: PhantomData,
        };
        output_port
    }

    /// Marks an output of a [`TeeingHandoff`] as dropped so that no more data will be sent to it.
    ///
    /// It is recommended not to use this method; instead, simply avoid teeing a
    /// [`TeeingHandoff`] when it is not needed.
    pub fn teeing_handoff_drop<T>(&mut self, tee_port: RecvPort<TeeingHandoff<T>>)
    where
        T: Clone,
    {
        let data = &self.handoffs[tee_port.handoff_id];
        let teeing_handoff = data
            .handoff
            .any_ref()
            .downcast_ref::<TeeingHandoff<T>>()
            .unwrap();
        teeing_handoff.drop();

        let tee_root = data.pred_handoffs[0];
        let tee_root_data = &mut self.handoffs[tee_root];
        // Remove this output from the send succ handoff list.
        tee_root_data
            .succ_handoffs
            .retain(|&succ_hoff| succ_hoff != tee_port.handoff_id);
        // Remove from subgraph successors if send port was already connected.
        assert!(
            tee_root_data.preds.len() <= 1,
            "Tee send side should only have one sender (or none set yet)."
        );
        if let Some(&pred_sg_id) = tee_root_data.preds.first() {
            self.subgraphs[pred_sg_id]
                .succs
                .retain(|&succ_hoff| succ_hoff != tee_port.handoff_id);
        }
    }
}

impl<'a> Dfir<'a> {
    /// Create a new empty graph.
    pub fn new() -> Self {
        Default::default()
    }

    /// Assigns the meta graph via a JSON string. Used internally by [`dfir_syntax`] and other macros.
    #[doc(hidden)]
    pub fn __assign_meta_graph(&mut self, _meta_graph_json: &str) {
        #[cfg(feature = "meta")]
        {
            let mut meta_graph: DfirGraph =
                serde_json::from_str(_meta_graph_json).expect("Failed to deserialize graph.");

            let mut op_inst_diagnostics = Vec::new();
            meta_graph.insert_node_op_insts_all(&mut op_inst_diagnostics);
            assert!(
                op_inst_diagnostics.is_empty(),
                "Expected no diagnostics, got: {:#?}",
                op_inst_diagnostics
            );

            assert!(self.meta_graph.replace(meta_graph).is_none());
        }
    }
    /// Assigns the diagnostics via a JSON string.
    #[doc(hidden)]
    pub fn __assign_diagnostics(&mut self, _diagnostics_json: &'static str) {
        #[cfg(feature = "meta")]
        {
            let diagnostics: Vec<Diagnostic<SerdeSpan>> = serde_json::from_str(_diagnostics_json)
                .expect("Failed to deserialize diagnostics.");

            assert!(self.diagnostics.replace(diagnostics).is_none());
        }
    }

    /// Return a handle to the meta graph, if set. The meta graph is a
    /// representation of all the operators, subgraphs, and handoffs in this instance.
    /// Will only be set if this graph was constructed using a surface syntax macro.
    #[cfg(feature = "meta")]
    #[cfg_attr(docsrs, doc(cfg(feature = "meta")))]
    pub fn meta_graph(&self) -> Option<&DfirGraph> {
        self.meta_graph.as_ref()
    }

    /// Returns any diagnostics generated by the surface syntax macro. Each diagnostic is a pair of
    /// (1) a `Diagnostic` with span info reset and (2) the `ToString` version of the diagnostic
    /// with original span info.
    /// Will only be set if this graph was constructed using a surface syntax macro.
    #[cfg(feature = "meta")]
    #[cfg_attr(docsrs, doc(cfg(feature = "meta")))]
    pub fn diagnostics(&self) -> Option<&[Diagnostic<SerdeSpan>]> {
        self.diagnostics.as_deref()
    }

    /// Returns a reactor for externally scheduling subgraphs, possibly from another thread.
    /// Reactor events are considered to be external events.
    pub fn reactor(&self) -> Reactor {
        Reactor::new(self.context.event_queue_send.clone())
    }

    /// Gets the current tick (local time) count.
    pub fn current_tick(&self) -> TickInstant {
        self.context.current_tick
    }

    /// Gets the current stratum number.
    pub fn current_stratum(&self) -> usize {
        self.context.current_stratum
    }

    /// Runs the dataflow until the next tick begins.
    /// Returns true if any work was done.
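    ///
    /// # Example
    ///
    /// A minimal sketch (not from the original docs; assumes `df` is a [`Dfir`]
    /// instance with subgraphs already added):
    ///
    /// ```rust,ignore
    /// // Drain all work that is immediately available, one tick at a time.
    /// while df.run_tick() {}
    /// ```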
    #[tracing::instrument(level = "trace", skip(self), ret)]
    pub fn run_tick(&mut self) -> bool {
        let mut work_done = false;
        // While work is immediately available *on the current tick*.
        while self.next_stratum(true) {
            work_done = true;
            // Do any work.
            self.run_stratum();
        }
        work_done
    }

    /// Runs the dataflow until no more (externally-triggered) work is immediately available.
    /// Runs at least one tick of dataflow, even if no external events have been received.
    /// If the dataflow contains loops, this method may run forever.
    /// Returns true if any work was done.
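    ///
    /// # Example
    ///
    /// A minimal sketch (not from the original docs) of a blocking event loop
    /// built from this method and [`Self::recv_events`], the synchronous
    /// analogue of what [`Self::run_async`] does:
    ///
    /// ```rust,ignore
    /// loop {
    ///     df.run_available(); // Run all immediately-available work.
    ///     df.recv_events();   // Block until an external event schedules more work.
    /// }
    /// ```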
    #[tracing::instrument(level = "trace", skip(self), ret)]
    pub fn run_available(&mut self) -> bool {
        let mut work_done = false;
        // While work is immediately available.
        while self.next_stratum(false) {
            work_done = true;
            // Do any work.
            self.run_stratum();
        }
        work_done
    }

    /// Runs the dataflow until no more (externally-triggered) work is immediately available.
    /// Runs at least one tick of dataflow, even if no external events have been received.
    /// If the dataflow contains loops, this method may run forever.
    /// Returns true if any work was done.
    /// Yields repeatedly to allow external events to happen.
    #[tracing::instrument(level = "trace", skip(self), ret)]
    pub async fn run_available_async(&mut self) -> bool {
        let mut work_done = false;
        // While work is immediately available.
        while self.next_stratum(false) {
            work_done = true;
            // Do any work.
            self.run_stratum();

            // Yield between each stratum to receive more events.
            // TODO(mingwei): really only need to yield at start of ticks though.
            tokio::task::yield_now().await;
        }
        work_done
    }

    /// Runs the current stratum of the dataflow until no more local work is available (does not receive events).
    /// Returns true if any work was done.
    #[tracing::instrument(level = "trace", skip(self), fields(tick = u64::from(self.context.current_tick), stratum = self.context.current_stratum), ret)]
    pub fn run_stratum(&mut self) -> bool {
        // Make sure to spawn tasks once dfir is running!
        // This drains the task buffer, so becomes a no-op after first call.
        self.context.spawn_tasks();

        let mut work_done = false;

        'pop: while let Some(sg_id) =
            self.context.stratum_queues[self.context.current_stratum].pop_front()
        {
            {
                let sg_data = &mut self.subgraphs[sg_id];
                // This must be true for the subgraph to be enqueued.
                assert!(sg_data.is_scheduled.take());

                match sg_data.loop_depth.cmp(&self.context.loop_nonce_stack.len()) {
                    Ordering::Greater => {
                        // We have entered a loop.
                        self.context.loop_nonce += 1;
                        self.context.loop_nonce_stack.push(self.context.loop_nonce);
                        tracing::trace!(loop_nonce = self.context.loop_nonce, "Entered loop.");
                    }
                    Ordering::Less => {
                        // We have exited a loop.
                        self.context.loop_nonce_stack.pop();
                        tracing::trace!("Exited loop.");
                    }
                    Ordering::Equal => {}
                }

                self.context.subgraph_id = sg_id;
                self.context.is_first_run_this_tick = sg_data
                    .last_tick_run_in
                    .is_none_or(|last_tick| last_tick < self.context.current_tick);

                if let Some(loop_id) = sg_data.loop_id {
                    // Loop execution - running loop block, from start to finish, containing
                    // multiple iterations.
                    // Loop iteration - a single iteration of a loop block, all subgraphs within
                    // the loop should run (at most) once.

                    // If the previous run of this subgraph had the same loop execution and
                    // iteration count, then we need to increment the iteration count.
                    let curr_loop_nonce = self.context.loop_nonce_stack.last().copied();

                    let LoopData {
                        iter_count: loop_iter_count,
                        allow_another_iteration,
                    } = &mut self.loop_data[loop_id];

                    let (prev_loop_nonce, prev_iter_count) = sg_data.last_loop_nonce;

                    // If the loop nonce is the same as the previous execution, then we are in
                    // the same loop execution.
                    // `curr_loop_nonce` is `None` for top-level loops, and top-level loops are
                    // always in the same (singular) loop execution.
                    let curr_iter_count =
                        if curr_loop_nonce.is_none_or(|nonce| nonce == prev_loop_nonce) {
                            // If the iteration count is the same as the previous execution, then
                            // we are on the next iteration.
                            if loop_iter_count.is_none_or(|n| n == prev_iter_count) {
                                // If not true, then we shall not run the next iteration.
                                if !std::mem::take(allow_another_iteration) {
                                    tracing::trace!(
                                        "Loop will not continue to next iteration, skipping."
                                    );
                                    continue 'pop;
                                }
                                // Increment `loop_iter_count` or set it to 0.
                                loop_iter_count.map_or(0, |n| n + 1)
                            } else {
                                // Otherwise update the local iteration count to match the loop.
                                debug_assert!(loop_iter_count.is_some_and(|n| prev_iter_count < n));
                                loop_iter_count.unwrap()
                            }
                        } else {
                            // We are in a new loop execution.
                            0
                        };
                    *loop_iter_count = Some(curr_iter_count);
                    self.context.loop_iter_count = curr_iter_count;
                    sg_data.last_loop_nonce =
                        (curr_loop_nonce.unwrap_or_default(), curr_iter_count);
                }

                tracing::info!(
                    sg_id = sg_id.to_string(),
                    sg_name = &*sg_data.name,
                    sg_depth = sg_data.loop_depth,
                    sg_loop_nonce = sg_data.last_loop_nonce.0,
                    sg_iter_count = sg_data.last_loop_nonce.1,
                    "Running subgraph."
                );
                sg_data.subgraph.run(&mut self.context, &mut self.handoffs);

                sg_data.last_tick_run_in = Some(self.context.current_tick);
            }

            let sg_data = &self.subgraphs[sg_id];
            for &handoff_id in sg_data.succs.iter() {
                let handoff = &self.handoffs[handoff_id];
                if !handoff.handoff.is_bottom() {
                    for &succ_id in handoff.succs.iter() {
                        let succ_sg_data = &self.subgraphs[succ_id];
                        // If we have sent data to the next tick, then we can start the next tick.
                        if succ_sg_data.stratum < self.context.current_stratum && !sg_data.is_lazy {
                            self.context.can_start_tick = true;
                        }
                        // Add subgraph to stratum queue if it is not already scheduled.
                        if !succ_sg_data.is_scheduled.replace(true) {
                            self.context.stratum_queues[succ_sg_data.stratum].push_back(succ_id);
                        }
                        // Add stratum to stratum stack if it is within a loop.
                        if 0 < succ_sg_data.loop_depth {
                            // TODO(mingwei): handle duplicates
                            self.context
                                .stratum_stack
                                .push(succ_sg_data.loop_depth, succ_sg_data.stratum);
                        }
                    }
                }
            }

            let reschedule = self.context.reschedule_loop_block.take();
            let allow_another = self.context.allow_another_iteration.take();

            if reschedule {
                // Re-enqueue the subgraph.
                self.context.schedule_deferred.push(sg_id);
                self.context
                    .stratum_stack
                    .push(sg_data.loop_depth, sg_data.stratum);
            }
            if reschedule || allow_another {
                if let Some(loop_id) = sg_data.loop_id {
                    self.loop_data
                        .get_mut(loop_id)
                        .unwrap()
                        .allow_another_iteration = true;
                }
            }

            work_done = true;
        }
        work_done
    }

    /// Go to the next stratum which has work available, possibly the current stratum.
    /// Return true if more work is available, otherwise false if no work is immediately
    /// available on any stratum.
    ///
    /// This will receive external events when at the start of a tick.
    ///
    /// If `current_tick_only` is set to `true`, will only return `true` if work is immediately
    /// available on the *current tick*.
    ///
    /// If this returns false then the graph will be at the start of a tick (at stratum 0, can
    /// receive more external events).
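    ///
    /// # Example
    ///
    /// [`Self::run_available`] is equivalent to this loop:
    ///
    /// ```rust,ignore
    /// while df.next_stratum(false) {
    ///     df.run_stratum();
    /// }
    /// ```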
    #[tracing::instrument(level = "trace", skip(self), ret)]
    pub fn next_stratum(&mut self, current_tick_only: bool) -> bool {
        tracing::trace!(
            events_received_tick = self.context.events_received_tick,
            can_start_tick = self.context.can_start_tick,
            "Starting `next_stratum` call.",
        );

        // The stratum we will stop searching at, i.e. made a full loop around.
        let mut end_stratum = self.context.current_stratum;
        let mut new_tick_started = false;

        if 0 == self.context.current_stratum {
            new_tick_started = true;

            // Starting the tick, reset this to `false`.
            tracing::trace!("Starting tick, setting `can_start_tick = false`.");
            self.context.can_start_tick = false;
            self.context.current_tick_start = SystemTime::now();

            // Ensure external events are received before running the tick.
            if !self.context.events_received_tick {
                // Add any external jobs to ready queue.
                self.try_recv_events();
            }
        }

        loop {
            tracing::trace!(
                tick = u64::from(self.context.current_tick),
                stratum = self.context.current_stratum,
                "Looking for work on stratum."
            );
            // If current stratum has work, return true.
            if !self.context.stratum_queues[self.context.current_stratum].is_empty() {
                tracing::trace!(
                    tick = u64::from(self.context.current_tick),
                    stratum = self.context.current_stratum,
                    "Work found on stratum."
                );
                return true;
            }

            if let Some(next_stratum) = self.context.stratum_stack.pop() {
                self.context.current_stratum = next_stratum;

                // Now schedule deferred subgraphs.
                {
                    for sg_id in self.context.schedule_deferred.drain(..) {
                        let sg_data = &self.subgraphs[sg_id];
                        tracing::info!(
                            tick = u64::from(self.context.current_tick),
                            stratum = self.context.current_stratum,
                            sg_id = sg_id.to_string(),
                            sg_name = &*sg_data.name,
                            is_scheduled = sg_data.is_scheduled.get(),
                            "Rescheduling deferred subgraph."
                        );
                        if !sg_data.is_scheduled.replace(true) {
                            self.context.stratum_queues[sg_data.stratum].push_back(sg_id);
                        }
                    }
                }
            } else {
                // Increment stratum counter.
                self.context.current_stratum += 1;

                if self.context.current_stratum >= self.context.stratum_queues.len() {
                    new_tick_started = true;

                    tracing::trace!(
                        can_start_tick = self.context.can_start_tick,
                        "End of tick {}, starting tick {}.",
                        self.context.current_tick,
                        self.context.current_tick + TickDuration::SINGLE_TICK,
                    );
                    self.context.reset_state_at_end_of_tick();

                    self.context.current_stratum = 0;
                    self.context.current_tick += TickDuration::SINGLE_TICK;
                    self.context.events_received_tick = false;

                    if current_tick_only {
                        tracing::trace!(
                            "`current_tick_only` is `true`, returning `false` before receiving events."
                        );
                        return false;
                    } else {
                        self.try_recv_events();
                        if std::mem::replace(&mut self.context.can_start_tick, false) {
                            tracing::trace!(
                                tick = u64::from(self.context.current_tick),
                                "`can_start_tick` is `true`, continuing."
                            );
                            // Do a full loop more to find where events have been added.
                            end_stratum = 0;
                            continue;
                        } else {
                            tracing::trace!(
                                "`can_start_tick` is `false`, re-setting `events_received_tick = false`, returning `false`."
                            );
                            self.context.events_received_tick = false;
                            return false;
                        }
                    }
                }
            }

            // After incrementing, exit if we made a full loop around the strata.
            if new_tick_started && end_stratum == self.context.current_stratum {
                tracing::trace!(
                    "Made full loop around stratum, re-setting `current_stratum = 0`, returning `false`."
                );
                // Note: if current stratum had work, the very first loop iteration would've
                // returned true. Therefore we can return false without checking.
                // Also means nothing was done so we can reset the stratum to zero and wait for
                // events.
                self.context.events_received_tick = false;
                self.context.current_stratum = 0;
                return false;
            }
        }
    }

    /// Runs the dataflow graph forever.
    ///
    /// TODO(mingwei): Currently blocks forever, no notion of "completion."
    #[tracing::instrument(level = "trace", skip(self), ret)]
    pub fn run(&mut self) -> Option<Never> {
        loop {
            self.run_tick();
        }
    }

    /// Runs the dataflow graph forever.
    ///
    /// TODO(mingwei): Currently blocks forever, no notion of "completion."
    #[tracing::instrument(level = "trace", skip(self), ret)]
    pub async fn run_async(&mut self) -> Option<Never> {
        loop {
            // Run any work which is immediately available.
            self.run_available_async().await;
            // When no work is available yield until more events occur.
            self.recv_events_async().await;
        }
    }

    /// Enqueues subgraphs triggered by events without blocking.
    ///
    /// Returns the number of subgraphs enqueued.
    #[tracing::instrument(level = "trace", skip(self), fields(events_received_tick = self.context.events_received_tick), ret)]
    pub fn try_recv_events(&mut self) -> usize {
        let mut enqueued_count = 0;
        while let Ok((sg_id, is_external)) = self.context.event_queue_recv.try_recv() {
            let sg_data = &self.subgraphs[sg_id];
            tracing::trace!(
                sg_id = sg_id.to_string(),
                is_external = is_external,
                sg_stratum = sg_data.stratum,
                "Event received."
            );
            if !sg_data.is_scheduled.replace(true) {
                self.context.stratum_queues[sg_data.stratum].push_back(sg_id);
                enqueued_count += 1;
            }
            if is_external {
                // The next tick is triggered if we are at the start of the next tick
                // (`!self.context.events_received_tick`), or if the stratum is in the next tick.
                if !self.context.events_received_tick
                    || sg_data.stratum < self.context.current_stratum
                {
                    tracing::trace!(
                        current_stratum = self.context.current_stratum,
                        sg_stratum = sg_data.stratum,
                        "External event, setting `can_start_tick = true`."
                    );
                    self.context.can_start_tick = true;
                }
            }
        }
        self.context.events_received_tick = true;

        enqueued_count
    }

    /// Enqueues subgraphs triggered by external events, blocking until at
    /// least one subgraph is scheduled **from an external event**.
    ///
    /// Returns `None` if the event queue is closed, but that should not happen normally.
    #[tracing::instrument(level = "trace", skip(self), fields(events_received_tick = self.context.events_received_tick), ret)]
    pub fn recv_events(&mut self) -> Option<usize> {
        let mut count = 0;
        loop {
            let (sg_id, is_external) = self.context.event_queue_recv.blocking_recv()?;
            let sg_data = &self.subgraphs[sg_id];
            tracing::trace!(
                sg_id = sg_id.to_string(),
                is_external = is_external,
                sg_stratum = sg_data.stratum,
                "Event received."
            );
            if !sg_data.is_scheduled.replace(true) {
                self.context.stratum_queues[sg_data.stratum].push_back(sg_id);
                count += 1;
            }
            if is_external {
                // The next tick is triggered if we are at the start of the next tick
                // (`!self.context.events_received_tick`), or if the stratum is in the next tick.
                if !self.context.events_received_tick
                    || sg_data.stratum < self.context.current_stratum
                {
                    tracing::trace!(
                        current_stratum = self.context.current_stratum,
                        sg_stratum = sg_data.stratum,
                        "External event, setting `can_start_tick = true`."
                    );
                    self.context.can_start_tick = true;
                }
                break;
            }
        }
        self.context.events_received_tick = true;

        // Enqueue any other immediate events.
        let extra_count = self.try_recv_events();
        Some(count + extra_count)
    }

    /// Enqueues subgraphs triggered by external events asynchronously, waiting until at least one
    /// subgraph is scheduled **from an external event**. Returns the number of subgraphs enqueued,
    /// which may be zero if an external event scheduled an already-scheduled subgraph.
    ///
    /// Returns `None` if the event queue is closed, but that should not happen normally.
    #[tracing::instrument(level = "trace", skip(self), fields(events_received_tick = self.context.events_received_tick), ret)]
    pub async fn recv_events_async(&mut self) -> Option<usize> {
        let mut count = 0;
        loop {
            tracing::trace!("Awaiting events (`event_queue_recv`).");
            let (sg_id, is_external) = self.context.event_queue_recv.recv().await?;
            let sg_data = &self.subgraphs[sg_id];
            tracing::trace!(
                sg_id = sg_id.to_string(),
                is_external = is_external,
                sg_stratum = sg_data.stratum,
                "Event received."
            );
            if !sg_data.is_scheduled.replace(true) {
                self.context.stratum_queues[sg_data.stratum].push_back(sg_id);
                count += 1;
            }
            if is_external {
                // The next tick is triggered if we are at the start of the next tick
                // (`!self.context.events_received_tick`), or if the stratum is in the next tick.
                if !self.context.events_received_tick
                    || sg_data.stratum < self.context.current_stratum
                {
                    tracing::trace!(
                        current_stratum = self.context.current_stratum,
                        sg_stratum = sg_data.stratum,
                        "External event, setting `can_start_tick = true`."
                    );
                    self.context.can_start_tick = true;
                }
                break;
            }
        }
        self.context.events_received_tick = true;

        // Enqueue any other immediate events.
        let extra_count = self.try_recv_events();
        Some(count + extra_count)
    }

    /// Schedules a subgraph to be run. See also: [`Context::schedule_subgraph`].
    pub fn schedule_subgraph(&mut self, sg_id: SubgraphId) -> bool {
        let sg_data = &self.subgraphs[sg_id];
        let already_scheduled = sg_data.is_scheduled.replace(true);
        if !already_scheduled {
            self.context.stratum_queues[sg_data.stratum].push_back(sg_id);
            true
        } else {
            false
        }
    }

    /// Adds a new compiled subgraph with the specified inputs and outputs in stratum 0.
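    ///
    /// # Example
    ///
    /// A minimal sketch (not from the original docs): wiring a producer to a
    /// consumer, assuming a `VecHandoff` handoff type and the `variadics`
    /// crate's `var_expr!`/`var_args!` macros for port lists:
    ///
    /// ```rust,ignore
    /// let mut df = Dfir::new();
    /// let (send_port, recv_port) = df.make_edge::<_, VecHandoff<usize>>("edge");
    /// df.add_subgraph(
    ///     "producer",
    ///     var_expr!(),
    ///     var_expr!(send_port),
    ///     |_context, var_args!(), var_args!(send)| {
    ///         send.give(Some(1));
    ///     },
    /// );
    /// df.add_subgraph(
    ///     "consumer",
    ///     var_expr!(recv_port),
    ///     var_expr!(),
    ///     |_context, var_args!(recv), var_args!()| {
    ///         for v in recv.take_inner() {
    ///             println!("got {}", v);
    ///         }
    ///     },
    /// );
    /// df.run_available();
    /// ```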
    pub fn add_subgraph<Name, R, W, F>(
        &mut self,
        name: Name,
        recv_ports: R,
        send_ports: W,
        subgraph: F,
    ) -> SubgraphId
    where
        Name: Into<Cow<'static, str>>,
        R: 'static + PortList<RECV>,
        W: 'static + PortList<SEND>,
        F: 'static + for<'ctx> FnMut(&'ctx mut Context, R::Ctx<'ctx>, W::Ctx<'ctx>),
    {
        self.add_subgraph_stratified(name, 0, recv_ports, send_ports, false, subgraph)
    }

    /// Adds a new compiled subgraph with the specified inputs, outputs, and stratum number.
    ///
    /// TODO(mingwei): add example in doc.
    pub fn add_subgraph_stratified<Name, R, W, F>(
        &mut self,
        name: Name,
        stratum: usize,
        recv_ports: R,
        send_ports: W,
        laziness: bool,
        subgraph: F,
    ) -> SubgraphId
    where
        Name: Into<Cow<'static, str>>,
        R: 'static + PortList<RECV>,
        W: 'static + PortList<SEND>,
        F: 'a + for<'ctx> FnMut(&'ctx mut Context, R::Ctx<'ctx>, W::Ctx<'ctx>),
    {
        self.add_subgraph_full(
            name, stratum, recv_ports, send_ports, laziness, None, subgraph,
        )
    }

    /// Adds a new compiled subgraph with all options.
    #[expect(clippy::too_many_arguments, reason = "Mainly for internal use.")]
    pub fn add_subgraph_full<Name, R, W, F>(
        &mut self,
        name: Name,
        stratum: usize,
        recv_ports: R,
        send_ports: W,
        laziness: bool,
        loop_id: Option<LoopId>,
        mut subgraph: F,
    ) -> SubgraphId
    where
        Name: Into<Cow<'static, str>>,
        R: 'static + PortList<RECV>,
        W: 'static + PortList<SEND>,
        F: 'a + for<'ctx> FnMut(&'ctx mut Context, R::Ctx<'ctx>, W::Ctx<'ctx>),
    {
        // SAFETY: Check that the send and recv ports are from `self.handoffs`.
        recv_ports.assert_is_from(&self.handoffs);
        send_ports.assert_is_from(&self.handoffs);

        let loop_depth = loop_id
            .and_then(|loop_id| self.context.loop_depth.get(loop_id))
            .copied()
            .unwrap_or(0);

        let sg_id = self.subgraphs.insert_with_key(|sg_id| {
            let (mut subgraph_preds, mut subgraph_succs) = Default::default();
            recv_ports.set_graph_meta(&mut self.handoffs, &mut subgraph_preds, sg_id, true);
            send_ports.set_graph_meta(&mut self.handoffs, &mut subgraph_succs, sg_id, false);

            let subgraph =
                move |context: &mut Context, handoffs: &mut SlotVec<HandoffTag, HandoffData>| {
                    let (recv, send) = unsafe {
                        // SAFETY:
                        // 1. We checked `assert_is_from` at assembly time, above.
                        // 2. `SlotVec` is insert-only so no handoffs could have changed since then.
                        (
                            recv_ports.make_ctx(&*handoffs),
                            send_ports.make_ctx(&*handoffs),
                        )
                    };
                    (subgraph)(context, recv, send);
                };
            SubgraphData::new(
                name.into(),
                stratum,
                subgraph,
                subgraph_preds,
                subgraph_succs,
                true,
                laziness,
                loop_id,
                loop_depth,
            )
        });
        self.context.init_stratum(stratum);
        self.context.stratum_queues[stratum].push_back(sg_id);

        sg_id
    }

    /// Adds a new compiled subgraph with a variable number of inputs and outputs of the same respective handoff types.
    pub fn add_subgraph_n_m<Name, R, W, F>(
        &mut self,
        name: Name,
        recv_ports: Vec<RecvPort<R>>,
        send_ports: Vec<SendPort<W>>,
        subgraph: F,
    ) -> SubgraphId
    where
        Name: Into<Cow<'static, str>>,
        R: 'static + Handoff,
        W: 'static + Handoff,
        F: 'static
            + for<'ctx> FnMut(&'ctx mut Context, &'ctx [&'ctx RecvCtx<R>], &'ctx [&'ctx SendCtx<W>]),
    {
        self.add_subgraph_stratified_n_m(name, 0, recv_ports, send_ports, subgraph)
    }

    /// Adds a new compiled subgraph with a variable number of inputs and outputs of the same respective handoff types.
    pub fn add_subgraph_stratified_n_m<Name, R, W, F>(
        &mut self,
        name: Name,
        stratum: usize,
        recv_ports: Vec<RecvPort<R>>,
        send_ports: Vec<SendPort<W>>,
        mut subgraph: F,
    ) -> SubgraphId
    where
        Name: Into<Cow<'static, str>>,
        R: 'static + Handoff,
        W: 'static + Handoff,
        F: 'static
            + for<'ctx> FnMut(&'ctx mut Context, &'ctx [&'ctx RecvCtx<R>], &'ctx [&'ctx SendCtx<W>]),
    {
        let sg_id = self.subgraphs.insert_with_key(|sg_id| {
            let subgraph_preds = recv_ports.iter().map(|port| port.handoff_id).collect();
            let subgraph_succs = send_ports.iter().map(|port| port.handoff_id).collect();

            for recv_port in recv_ports.iter() {
                self.handoffs[recv_port.handoff_id].succs.push(sg_id);
            }
            for send_port in send_ports.iter() {
                self.handoffs[send_port.handoff_id].preds.push(sg_id);
            }

            let subgraph =
                move |context: &mut Context, handoffs: &mut SlotVec<HandoffTag, HandoffData>| {
                    let recvs: Vec<&RecvCtx<R>> = recv_ports
                        .iter()
                        .map(|hid| hid.handoff_id)
                        .map(|hid| handoffs.get(hid).unwrap())
                        .map(|h_data| {
                            h_data
                                .handoff
                                .any_ref()
                                .downcast_ref()
                                .expect("Attempted to cast handoff to wrong type.")
                        })
                        .map(RefCast::ref_cast)
                        .collect();

                    let sends: Vec<&SendCtx<W>> = send_ports
                        .iter()
                        .map(|hid| hid.handoff_id)
                        .map(|hid| handoffs.get(hid).unwrap())
                        .map(|h_data| {
                            h_data
                                .handoff
                                .any_ref()
                                .downcast_ref()
                                .expect("Attempted to cast handoff to wrong type.")
                        })
                        .map(RefCast::ref_cast)
                        .collect();

                    (subgraph)(context, &recvs, &sends)
                };
            SubgraphData::new(
                name.into(),
                stratum,
                subgraph,
                subgraph_preds,
                subgraph_succs,
                true,
                false,
                None,
                0,
            )
        });

        self.context.init_stratum(stratum);
        self.context.stratum_queues[stratum].push_back(sg_id);

        sg_id
    }

    /// Creates a handoff edge and returns the corresponding send and receive ports.
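    ///
    /// # Example
    ///
    /// A minimal sketch (not from the original docs; assumes a `VecHandoff`
    /// handoff type):
    ///
    /// ```rust,ignore
    /// let mut df = Dfir::new();
    /// let (send_port, recv_port) = df.make_edge::<_, VecHandoff<String>>("my_edge");
    /// // Pass `send_port` to the upstream subgraph and `recv_port` to the downstream one.
    /// ```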
    pub fn make_edge<Name, H>(&mut self, name: Name) -> (SendPort<H>, RecvPort<H>)
    where
        Name: Into<Cow<'static, str>>,
        H: 'static + Handoff,
    {
        // Create and insert handoff.
        let handoff = H::default();
        let handoff_id = self
            .handoffs
            .insert_with_key(|hoff_id| HandoffData::new(name.into(), handoff, hoff_id));

        // Make ports.
        let input_port = SendPort {
            handoff_id,
            _marker: PhantomData,
        };
        let output_port = RecvPort {
            handoff_id,
            _marker: PhantomData,
        };
        (input_port, output_port)
    }

    /// Adds referenceable state into this instance. Returns a state handle which can be
    /// used externally or by operators to access the state.
    ///
    /// This is part of the "state API".
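    ///
    /// # Example
    ///
    /// A minimal sketch (not from the original docs; assumes subgraphs read the
    /// state back out through the [`Context`]'s state accessors):
    ///
    /// ```rust,ignore
    /// let mut df = Dfir::new();
    /// let handle = df.add_state(std::cell::RefCell::new(0_usize));
    /// // Inside a subgraph closure: access the state via `context` and `handle`.
    /// ```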
    pub fn add_state<T>(&mut self, state: T) -> StateHandle<T>
    where
        T: Any,
    {
        self.context.add_state(state)
    }

    /// Sets a hook to modify the state at the end of each tick, using the supplied closure.
    ///
    /// This is part of the "state API".
    pub fn set_state_tick_hook<T>(
        &mut self,
        handle: StateHandle<T>,
        tick_hook_fn: impl 'static + FnMut(&mut T),
    ) where
        T: Any,
    {
        self.context.set_state_tick_hook(handle, tick_hook_fn)
    }

    /// Gets an exclusive (mutable) reference to the internal context, setting the subgraph ID.
    pub fn context_mut(&mut self, sg_id: SubgraphId) -> &mut Context {
        self.context.subgraph_id = sg_id;
        &mut self.context
    }

    /// Adds a new loop with the given parent (or `None` for top-level). Returns a loop ID which
    /// is used in [`Self::add_subgraph_full`] or as the parent for nested loops.
    ///
    /// TODO(mingwei): add loop names to ensure traceability while debugging?
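    ///
    /// # Example
    ///
    /// A minimal sketch (not from the original docs) of creating a nested loop
    /// structure:
    ///
    /// ```rust,ignore
    /// let mut df = Dfir::new();
    /// let outer = df.add_loop(None);        // Top-level loop.
    /// let inner = df.add_loop(Some(outer)); // Nested inside `outer`.
    /// // Pass `Some(inner)` as the `loop_id` to `add_subgraph_full`.
    /// ```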
    pub fn add_loop(&mut self, parent: Option<LoopId>) -> LoopId {
        let depth = parent.map_or(0, |p| self.context.loop_depth[p] + 1);
        let loop_id = self.context.loop_depth.insert(depth);
        self.loop_data.insert(
            loop_id,
            LoopData {
                iter_count: None,
                allow_another_iteration: true,
            },
        );
        loop_id
    }
}

impl Dfir<'_> {
    /// Alias for [`Context::request_task`].
    pub fn request_task<Fut>(&mut self, future: Fut)
    where
        Fut: Future<Output = ()> + 'static,
    {
        self.context.request_task(future);
    }

    /// Alias for [`Context::abort_tasks`].
    pub fn abort_tasks(&mut self) {
        self.context.abort_tasks()
    }

    /// Alias for [`Context::join_tasks`].
    pub fn join_tasks(&mut self) -> impl use<'_> + Future {
        self.context.join_tasks()
    }
}

impl Drop for Dfir<'_> {
    fn drop(&mut self) {
        self.abort_tasks();
    }
}

/// A handoff and its input and output [SubgraphId]s.
///
/// Internal use: used to track the dfir graph structure.
///
/// TODO(mingwei): restructure `PortList` so this can be crate-private.
#[doc(hidden)]
pub struct HandoffData {
    /// A friendly name for diagnostics.
    pub(super) name: Cow<'static, str>,
    /// Crate-visible for `handoff_list` internals.
    pub(super) handoff: Box<dyn HandoffMeta>,
    /// Preceding subgraphs (including the send side of a teeing handoff).
    pub(super) preds: SmallVec<[SubgraphId; 1]>,
    /// Successor subgraphs (including recv sides of teeing handoffs).
    pub(super) succs: SmallVec<[SubgraphId; 1]>,

    /// Predecessor handoffs, used by teeing handoffs.
    /// Should be `self` on any teeing send sides (input).
    /// Should be the send `HandoffId` if this is a teeing recv side (output).
    /// Should be just `self`'s `HandoffId` on other handoffs.
    /// This field is only used in initialization.
    pub(super) pred_handoffs: Vec<HandoffId>,
    /// Successor handoffs, used by teeing handoffs.
    /// Should be a list of outputs on the teeing send side (input).
    /// Should be `self` on any teeing recv sides (outputs).
    /// Should be just `self`'s `HandoffId` on other handoffs.
    /// This field is only used in initialization.
    pub(super) succ_handoffs: Vec<HandoffId>,
}
impl std::fmt::Debug for HandoffData {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
        f.debug_struct("HandoffData")
            .field("preds", &self.preds)
            .field("succs", &self.succs)
            .finish_non_exhaustive()
    }
}
impl HandoffData {
    /// New with `pred_handoffs` and `succ_handoffs` set to its own [`HandoffId`]: `vec![hoff_id]`.
    pub fn new(
        name: Cow<'static, str>,
        handoff: impl 'static + HandoffMeta,
        hoff_id: HandoffId,
    ) -> Self {
        let (preds, succs) = Default::default();
        Self {
            name,
            handoff: Box::new(handoff),
            preds,
            succs,
            pred_handoffs: vec![hoff_id],
            succ_handoffs: vec![hoff_id],
        }
    }
}

/// A subgraph along with its predecessor and successor [SubgraphId]s.
///
/// Used internally by the [Dfir] struct to represent the dataflow graph
/// structure and scheduled state.
pub(super) struct SubgraphData<'a> {
    /// A friendly name for diagnostics.
    pub(super) name: Cow<'static, str>,
    /// This subgraph's stratum number.
    ///
    /// Within loop blocks, corresponds to the topological sort of the DAG created when `next_loop()/next_tick()` are removed.
    pub(super) stratum: usize,
    /// The actual execution code of the subgraph.
    subgraph: Box<dyn Subgraph + 'a>,

    #[expect(dead_code, reason = "may be useful in the future")]
    preds: Vec<HandoffId>,
    succs: Vec<HandoffId>,

    /// If this subgraph is scheduled in the context's `stratum_queues`.
    /// [`Cell`] allows modifying this field when iterating `Self::preds` or
    /// `Self::succs`, as all `SubgraphData` are owned by the same vec,
    /// `Dfir::subgraphs`.
    is_scheduled: Cell<bool>,

    /// The last tick in which this subgraph was run.
    last_tick_run_in: Option<TickInstant>,
    /// Tracks the last loop execution and iteration this subgraph was run in, as a
    /// `(loop_nonce, iter_count)` pair. The loop nonce is an otherwise meaningless ID.
    last_loop_nonce: (usize, usize),

    /// If this subgraph is marked as lazy, then sending data back to a lower stratum does not trigger a new tick to be run.
    is_lazy: bool,

    /// The subgraph's loop ID, or `None` for the top level.
    loop_id: Option<LoopId>,
    /// The loop depth of the subgraph.
    loop_depth: usize,
}
impl<'a> SubgraphData<'a> {
    #[expect(clippy::too_many_arguments, reason = "internal use")]
    pub(crate) fn new(
        name: Cow<'static, str>,
        stratum: usize,
        subgraph: impl Subgraph + 'a,
        preds: Vec<HandoffId>,
        succs: Vec<HandoffId>,
        is_scheduled: bool,
        is_lazy: bool,
        loop_id: Option<LoopId>,
        loop_depth: usize,
    ) -> Self {
        Self {
            name,
            stratum,
            subgraph: Box::new(subgraph),
            preds,
            succs,
            is_scheduled: Cell::new(is_scheduled),
            last_tick_run_in: None,
            last_loop_nonce: (0, 0),
            is_lazy,
            loop_id,
            loop_depth,
        }
    }
}

pub(crate) struct LoopData {
    /// Count of iterations of this loop.
    iter_count: Option<usize>,
    /// If the loop has reason to do another iteration.
    allow_another_iteration: bool,
}