dfir_rs/scheduled/graph.rs

1//! Module for the [`Dfir`] struct and helper items.
2
3use std::any::Any;
4use std::borrow::Cow;
5use std::cell::Cell;
6use std::cmp::Ordering;
7use std::future::Future;
8use std::marker::PhantomData;
9
10#[cfg(feature = "meta")]
11use dfir_lang::diagnostic::{Diagnostic, SerdeSpan};
12#[cfg(feature = "meta")]
13use dfir_lang::graph::DfirGraph;
14use ref_cast::RefCast;
15use smallvec::SmallVec;
16use web_time::SystemTime;
17
18use super::context::Context;
19use super::handoff::handoff_list::PortList;
20use super::handoff::{Handoff, HandoffMeta, TeeingHandoff};
21use super::port::{RECV, RecvCtx, RecvPort, SEND, SendCtx, SendPort};
22use super::reactor::Reactor;
23use super::state::StateHandle;
24use super::subgraph::Subgraph;
25use super::{HandoffId, HandoffTag, LoopId, LoopTag, SubgraphId, SubgraphTag};
26use crate::Never;
27use crate::scheduled::ticks::{TickDuration, TickInstant};
28use crate::util::slot_vec::{SecondarySlotVec, SlotVec};
29
/// A DFIR graph. Owns, schedules, and runs the compiled subgraphs.
#[derive(Default)]
pub struct Dfir<'a> {
    /// All compiled subgraphs, keyed by subgraph ID.
    pub(super) subgraphs: SlotVec<SubgraphTag, SubgraphData<'a>>,

    /// Per-loop scheduling state (iteration count, whether another iteration may run),
    /// keyed by loop ID.
    pub(super) loop_data: SecondarySlotVec<LoopTag, LoopData>,

    /// Scheduler state shared with running subgraphs: current tick/stratum,
    /// stratum queues, event channels, state hooks, etc.
    pub(super) context: Context,

    /// All handoffs (the buffers connecting subgraphs), keyed by handoff ID.
    handoffs: SlotVec<HandoffTag, HandoffData>,

    #[cfg(feature = "meta")]
    /// See [`Self::meta_graph()`].
    meta_graph: Option<DfirGraph>,

    #[cfg(feature = "meta")]
    /// See [`Self::diagnostics()`].
    diagnostics: Option<Vec<Diagnostic<SerdeSpan>>>,
}
49
50/// Methods for [`TeeingHandoff`] teeing and dropping.
51impl Dfir<'_> {
52    /// Tees a [`TeeingHandoff`].
53    pub fn teeing_handoff_tee<T>(
54        &mut self,
55        tee_parent_port: &RecvPort<TeeingHandoff<T>>,
56    ) -> RecvPort<TeeingHandoff<T>>
57    where
58        T: Clone,
59    {
60        // If we're teeing from a child make sure to find root.
61        let tee_root = self.handoffs[tee_parent_port.handoff_id].pred_handoffs[0];
62
63        // Set up teeing metadata.
64        let tee_root_data = &mut self.handoffs[tee_root];
65        let tee_root_data_name = tee_root_data.name.clone();
66
67        // Insert new handoff output.
68        let teeing_handoff =
69            <dyn Any>::downcast_ref::<TeeingHandoff<T>>(&*tee_root_data.handoff).unwrap();
70        let new_handoff = teeing_handoff.tee();
71
72        // Handoff ID of new tee output.
73        let new_hoff_id = self.handoffs.insert_with_key(|new_hoff_id| {
74            let new_name = Cow::Owned(format!("{} tee {:?}", tee_root_data_name, new_hoff_id));
75            let mut new_handoff_data = HandoffData::new(new_name, new_handoff, new_hoff_id);
76            // Set self's predecessor as `tee_root`.
77            new_handoff_data.pred_handoffs = vec![tee_root];
78            new_handoff_data
79        });
80
81        // Go to `tee_root`'s successors and insert self (the new tee output).
82        let tee_root_data = &mut self.handoffs[tee_root];
83        tee_root_data.succ_handoffs.push(new_hoff_id);
84
85        // Add our new handoff id into the subgraph data if the send `tee_root` has already been
86        // used to add a subgraph.
87        assert!(
88            tee_root_data.preds.len() <= 1,
89            "Tee send side should only have one sender (or none set yet)."
90        );
91        if let Some(&pred_sg_id) = tee_root_data.preds.first() {
92            self.subgraphs[pred_sg_id].succs.push(new_hoff_id);
93        }
94
95        let output_port = RecvPort {
96            handoff_id: new_hoff_id,
97            _marker: PhantomData,
98        };
99        output_port
100    }
101
102    /// Marks an output of a [`TeeingHandoff`] as dropped so that no more data will be sent to it.
103    ///
104    /// It is recommended to not not use this method and instead simply avoid teeing a
105    /// [`TeeingHandoff`] when it is not needed.
106    pub fn teeing_handoff_drop<T>(&mut self, tee_port: RecvPort<TeeingHandoff<T>>)
107    where
108        T: Clone,
109    {
110        let data = &self.handoffs[tee_port.handoff_id];
111        let teeing_handoff = <dyn Any>::downcast_ref::<TeeingHandoff<T>>(&*data.handoff).unwrap();
112        teeing_handoff.drop();
113
114        let tee_root = data.pred_handoffs[0];
115        let tee_root_data = &mut self.handoffs[tee_root];
116        // Remove this output from the send succ handoff list.
117        tee_root_data
118            .succ_handoffs
119            .retain(|&succ_hoff| succ_hoff != tee_port.handoff_id);
120        // Remove from subgraph successors if send port was already connected.
121        assert!(
122            tee_root_data.preds.len() <= 1,
123            "Tee send side should only have one sender (or none set yet)."
124        );
125        if let Some(&pred_sg_id) = tee_root_data.preds.first() {
126            self.subgraphs[pred_sg_id]
127                .succs
128                .retain(|&succ_hoff| succ_hoff != tee_port.handoff_id);
129        }
130    }
131}
132
133impl<'a> Dfir<'a> {
134    /// Create a new empty graph.
135    pub fn new() -> Self {
136        Default::default()
137    }
138
    /// Assign the meta graph via JSON string. Used internally by the [`dfir_syntax`] and other macros.
    #[doc(hidden)]
    pub fn __assign_meta_graph(&mut self, _meta_graph_json: &str) {
        #[cfg(feature = "meta")]
        {
            // Panics on malformed JSON; only macro-generated input is expected here.
            let mut meta_graph: DfirGraph =
                serde_json::from_str(_meta_graph_json).expect("Failed to deserialize graph.");

            // Resolve operator instances; any diagnostics indicate a bug in the macro output.
            let mut op_inst_diagnostics = Vec::new();
            meta_graph.insert_node_op_insts_all(&mut op_inst_diagnostics);
            assert!(
                op_inst_diagnostics.is_empty(),
                "Expected no diagnostics, got: {:#?}",
                op_inst_diagnostics
            );

            // The meta graph may only be assigned once.
            assert!(self.meta_graph.replace(meta_graph).is_none());
        }
    }
    /// Assign the diagnostics via JSON string. Used internally by macros.
    #[doc(hidden)]
    pub fn __assign_diagnostics(&mut self, _diagnostics_json: &'static str) {
        #[cfg(feature = "meta")]
        {
            // Panics on malformed JSON; only macro-generated input is expected here.
            let diagnostics: Vec<Diagnostic<SerdeSpan>> = serde_json::from_str(_diagnostics_json)
                .expect("Failed to deserialize diagnostics.");

            // Diagnostics may only be assigned once.
            assert!(self.diagnostics.replace(diagnostics).is_none());
        }
    }
169
    /// Return a handle to the meta graph, if set. The meta graph is a
    /// representation of all the operators, subgraphs, and handoffs in this instance.
    /// Will only be set if this graph was constructed using a surface syntax macro.
    ///
    /// Returns `None` if never assigned (see [`Self::__assign_meta_graph`]).
    #[cfg(feature = "meta")]
    #[cfg_attr(docsrs, doc(cfg(feature = "meta")))]
    pub fn meta_graph(&self) -> Option<&DfirGraph> {
        self.meta_graph.as_ref()
    }
178
    /// Returns any diagnostics generated by the surface syntax macro. Each diagnostic is a pair of
    /// (1) a `Diagnostic` with span info reset and (2) the `ToString` version of the diagnostic
    /// with original span info.
    /// Will only be set if this graph was constructed using a surface syntax macro.
    ///
    /// Returns `None` if never assigned (see [`Self::__assign_diagnostics`]).
    #[cfg(feature = "meta")]
    #[cfg_attr(docsrs, doc(cfg(feature = "meta")))]
    pub fn diagnostics(&self) -> Option<&[Diagnostic<SerdeSpan>]> {
        self.diagnostics.as_deref()
    }
188
    /// Returns a reactor for externally scheduling subgraphs, possibly from another thread.
    /// Reactor events are considered to be external events.
    pub fn reactor(&self) -> Reactor {
        // The reactor holds a clone of the event-queue sender, so it outlives `&self`.
        Reactor::new(self.context.event_queue_send.clone())
    }
194
    /// Gets the current tick (local time) count.
    pub fn current_tick(&self) -> TickInstant {
        self.context.current_tick
    }
199
    /// Gets the current stratum number.
    pub fn current_stratum(&self) -> usize {
        self.context.current_stratum
    }
204
205    /// Runs the dataflow until the next tick begins.
206    /// Returns true if any work was done.
207    #[tracing::instrument(level = "trace", skip(self), ret)]
208    pub fn run_tick(&mut self) -> bool {
209        let mut work_done = false;
210        // While work is immediately available *on the current tick*.
211        while self.next_stratum(true) {
212            work_done = true;
213            // Do any work.
214            self.run_stratum();
215        }
216        work_done
217    }
218
219    /// Runs the dataflow until no more (externally-triggered) work is immediately available.
220    /// Runs at least one tick of dataflow, even if no external events have been received.
221    /// If the dataflow contains loops this method may run forever.
222    /// Returns true if any work was done.
223    #[tracing::instrument(level = "trace", skip(self), ret)]
224    pub fn run_available(&mut self) -> bool {
225        let mut work_done = false;
226        // While work is immediately available.
227        while self.next_stratum(false) {
228            work_done = true;
229            // Do any work.
230            self.run_stratum();
231        }
232        work_done
233    }
234
235    /// Runs the dataflow until no more (externally-triggered) work is immediately available.
236    /// Runs at least one tick of dataflow, even if no external events have been received.
237    /// If the dataflow contains loops this method may run forever.
238    /// Returns true if any work was done.
239    /// Yields repeatedly to allow external events to happen.
240    #[tracing::instrument(level = "trace", skip(self), ret)]
241    pub async fn run_available_async(&mut self) -> bool {
242        let mut work_done = false;
243        // While work is immediately available.
244        while self.next_stratum(false) {
245            work_done = true;
246            // Do any work.
247            self.run_stratum();
248
249            // Yield between each stratum to receive more events.
250            // TODO(mingwei): really only need to yield at start of ticks though.
251            tokio::task::yield_now().await;
252        }
253        work_done
254    }
255
    /// Runs the current stratum of the dataflow until no more local work is available (does not receive events).
    /// Returns true if any work was done.
    ///
    /// Pops subgraphs off the current stratum's queue one at a time, handling loop
    /// entry/exit bookkeeping, running the subgraph, then scheduling any successors
    /// reachable through non-empty handoffs.
    #[tracing::instrument(level = "trace", skip(self), fields(tick = u64::from(self.context.current_tick), stratum = self.context.current_stratum), ret)]
    pub fn run_stratum(&mut self) -> bool {
        // Make sure to spawn tasks once dfir is running!
        // This drains the task buffer, so becomes a no-op after first call.
        self.context.spawn_tasks();

        let mut work_done = false;

        'pop: while let Some(sg_id) =
            self.context.stratum_queues[self.context.current_stratum].pop_front()
        {
            {
                let sg_data = &mut self.subgraphs[sg_id];
                // This must be true for the subgraph to be enqueued.
                assert!(sg_data.is_scheduled.take());

                let _enter = tracing::info_span!(
                    "run-subgraph",
                    sg_id = sg_id.to_string(),
                    sg_name = &*sg_data.name,
                    sg_depth = sg_data.loop_depth,
                    sg_loop_nonce = sg_data.last_loop_nonce.0,
                    sg_iter_count = sg_data.last_loop_nonce.1,
                )
                .entered();

                // Compare the subgraph's loop depth to the depth of the nonce stack to
                // detect transitions into or out of a loop block.
                match sg_data.loop_depth.cmp(&self.context.loop_nonce_stack.len()) {
                    Ordering::Greater => {
                        // We have entered a loop.
                        self.context.loop_nonce += 1;
                        self.context.loop_nonce_stack.push(self.context.loop_nonce);
                        tracing::trace!(loop_nonce = self.context.loop_nonce, "Entered loop.");
                    }
                    Ordering::Less => {
                        // We have exited a loop.
                        self.context.loop_nonce_stack.pop();
                        tracing::trace!("Exited loop.");
                    }
                    Ordering::Equal => {}
                }

                self.context.subgraph_id = sg_id;
                // First run this tick iff the subgraph last ran in an earlier tick (or never).
                self.context.is_first_run_this_tick = sg_data
                    .last_tick_run_in
                    .is_none_or(|last_tick| last_tick < self.context.current_tick);

                if let Some(loop_id) = sg_data.loop_id {
                    // Loop execution - running loop block, from start to finish, containing
                    // multiple iterations.
                    // Loop iteration - a single iteration of a loop block, all subgraphs within
                    // the loop should run (at most) once.

                    // If the previous run of this subgraph had the same loop execution and
                    // iteration count, then we need to increment the iteration count.
                    let curr_loop_nonce = self.context.loop_nonce_stack.last().copied();

                    let LoopData {
                        iter_count: loop_iter_count,
                        allow_another_iteration,
                    } = &mut self.loop_data[loop_id];

                    let (prev_loop_nonce, prev_iter_count) = sg_data.last_loop_nonce;

                    // If the loop nonce is the same as the previous execution, then we are in
                    // the same loop execution.
                    // `curr_loop_nonce` is `None` for top-level loops, and top-level loops are
                    // always in the same (singular) loop execution.
                    let (curr_iter_count, new_loop_execution) =
                        if curr_loop_nonce.is_none_or(|nonce| nonce == prev_loop_nonce) {
                            // If the iteration count is the same as the previous execution, then
                            // we are on the next iteration.
                            if *loop_iter_count == prev_iter_count {
                                // If not true, then we shall not run the next iteration.
                                // (`take` resets the flag to `false` for the next check.)
                                if !std::mem::take(allow_another_iteration) {
                                    tracing::debug!(
                                        "Loop will not continue to next iteration, skipping."
                                    );
                                    continue 'pop;
                                }
                                // Increment `loop_iter_count` or set it to 0.
                                loop_iter_count.map_or((0, true), |n| (n + 1, false))
                            } else {
                                // Otherwise update the local iteration count to match the loop.
                                debug_assert!(
                                    prev_iter_count < *loop_iter_count,
                                    "Expect loop iteration count to be increasing."
                                );
                                (loop_iter_count.unwrap(), false)
                            }
                        } else {
                            // The loop execution has already begun, but this is the first time this particular subgraph is running.
                            (0, false)
                        };

                    if new_loop_execution {
                        // Run state hooks.
                        self.context.run_state_hooks_loop(loop_id);
                    }
                    tracing::debug!("Loop iteration count {}", curr_iter_count);

                    *loop_iter_count = Some(curr_iter_count);
                    self.context.loop_iter_count = curr_iter_count;
                    sg_data.last_loop_nonce =
                        (curr_loop_nonce.unwrap_or_default(), Some(curr_iter_count));
                }

                // Run subgraph state hooks.
                self.context.run_state_hooks_subgraph(sg_id);

                tracing::info!("Running subgraph.");
                sg_data.subgraph.run(&mut self.context, &mut self.handoffs);

                sg_data.last_tick_run_in = Some(self.context.current_tick);
            }

            // Schedule successor subgraphs of any output handoff that received data.
            let sg_data = &self.subgraphs[sg_id];
            for &handoff_id in sg_data.succs.iter() {
                let handoff = &self.handoffs[handoff_id];
                if !handoff.handoff.is_bottom() {
                    for &succ_id in handoff.succs.iter() {
                        let succ_sg_data = &self.subgraphs[succ_id];
                        // If we have sent data to the next tick, then we can start the next tick.
                        if succ_sg_data.stratum < self.context.current_stratum && !sg_data.is_lazy {
                            self.context.can_start_tick = true;
                        }
                        // Add subgraph to stratum queue if it is not already scheduled.
                        if !succ_sg_data.is_scheduled.replace(true) {
                            self.context.stratum_queues[succ_sg_data.stratum].push_back(succ_id);
                        }
                        // Add stratum to stratum stack if it is within a loop.
                        if 0 < succ_sg_data.loop_depth {
                            // TODO(mingwei): handle duplicates
                            self.context
                                .stratum_stack
                                .push(succ_sg_data.loop_depth, succ_sg_data.stratum);
                        }
                    }
                }
            }

            // Flags the just-run subgraph may have set on the context; `take` resets them.
            let reschedule = self.context.reschedule_loop_block.take();
            let allow_another = self.context.allow_another_iteration.take();

            if reschedule {
                // Re-enqueue the subgraph.
                self.context.schedule_deferred.push(sg_id);
                self.context
                    .stratum_stack
                    .push(sg_data.loop_depth, sg_data.stratum);
            }
            if (reschedule || allow_another)
                && let Some(loop_id) = sg_data.loop_id
            {
                self.loop_data
                    .get_mut(loop_id)
                    .unwrap()
                    .allow_another_iteration = true;
            }

            work_done = true;
        }
        work_done
    }
421
    /// Go to the next stratum which has work available, possibly the current stratum.
    /// Return true if more work is available, otherwise false if no work is immediately
    /// available on any strata.
    ///
    /// This will receive external events when at the start of a tick.
    ///
    /// If `current_tick_only` is set to `true`, will only return `true` if work is immediately
    /// available on the *current tick*.
    ///
    /// If this returns false then the graph will be at the start of a tick (at stratum 0, can
    /// receive more external events).
    #[tracing::instrument(level = "trace", skip(self), ret)]
    pub fn next_stratum(&mut self, current_tick_only: bool) -> bool {
        tracing::trace!(
            events_received_tick = self.context.events_received_tick,
            can_start_tick = self.context.can_start_tick,
            "Starting `next_stratum` call.",
        );

        // The stratum we will stop searching at, i.e. made a full loop around.
        let mut end_stratum = self.context.current_stratum;
        let mut new_tick_started = false;

        if 0 == self.context.current_stratum {
            new_tick_started = true;

            // Starting the tick, reset this to `false`.
            tracing::trace!("Starting tick, setting `can_start_tick = false`.");
            self.context.can_start_tick = false;
            self.context.current_tick_start = SystemTime::now();

            // Ensure external events are received before running the tick.
            if !self.context.events_received_tick {
                // Add any external jobs to ready queue.
                self.try_recv_events();
            }
        }

        // Rotate through the strata (and the stratum stack, for loops) looking for one
        // with queued work, wrapping around to the next tick when all strata are drained.
        loop {
            tracing::trace!(
                tick = u64::from(self.context.current_tick),
                stratum = self.context.current_stratum,
                "Looking for work on stratum."
            );
            // If current stratum has work, return true.
            if !self.context.stratum_queues[self.context.current_stratum].is_empty() {
                tracing::trace!(
                    tick = u64::from(self.context.current_tick),
                    stratum = self.context.current_stratum,
                    "Work found on stratum."
                );
                return true;
            }

            if let Some(next_stratum) = self.context.stratum_stack.pop() {
                self.context.current_stratum = next_stratum;

                // Now schedule deferred subgraphs.
                {
                    for sg_id in self.context.schedule_deferred.drain(..) {
                        let sg_data = &self.subgraphs[sg_id];
                        tracing::info!(
                            tick = u64::from(self.context.current_tick),
                            stratum = self.context.current_stratum,
                            sg_id = sg_id.to_string(),
                            sg_name = &*sg_data.name,
                            is_scheduled = sg_data.is_scheduled.get(),
                            "Rescheduling deferred subgraph."
                        );
                        // Only enqueue if not already scheduled, to avoid duplicates.
                        if !sg_data.is_scheduled.replace(true) {
                            self.context.stratum_queues[sg_data.stratum].push_back(sg_id);
                        }
                    }
                }
            } else {
                // Increment stratum counter.
                self.context.current_stratum += 1;

                // Past the last stratum: the tick is complete, roll over to the next tick.
                if self.context.current_stratum >= self.context.stratum_queues.len() {
                    new_tick_started = true;

                    tracing::trace!(
                        can_start_tick = self.context.can_start_tick,
                        "End of tick {}, starting tick {}.",
                        self.context.current_tick,
                        self.context.current_tick + TickDuration::SINGLE_TICK,
                    );
                    self.context.run_state_hooks_tick();

                    self.context.current_stratum = 0;
                    self.context.current_tick += TickDuration::SINGLE_TICK;
                    self.context.events_received_tick = false;

                    if current_tick_only {
                        tracing::trace!(
                            "`current_tick_only` is `true`, returning `false` before receiving events."
                        );
                        return false;
                    } else {
                        self.try_recv_events();
                        if std::mem::replace(&mut self.context.can_start_tick, false) {
                            tracing::trace!(
                                tick = u64::from(self.context.current_tick),
                                "`can_start_tick` is `true`, continuing."
                            );
                            // Do a full loop more to find where events have been added.
                            end_stratum = 0;
                            continue;
                        } else {
                            tracing::trace!(
                                "`can_start_tick` is `false`, re-setting `events_received_tick = false`, returning `false`."
                            );
                            self.context.events_received_tick = false;
                            return false;
                        }
                    }
                }
            }

            // After incrementing, exit if we made a full loop around the strata.
            if new_tick_started && end_stratum == self.context.current_stratum {
                tracing::trace!(
                    "Made full loop around stratum, re-setting `current_stratum = 0`, returning `false`."
                );
                // Note: if current stratum had work, the very first loop iteration would've
                // returned true. Therefore we can return false without checking.
                // Also means nothing was done so we can reset the stratum to zero and wait for
                // events.
                self.context.events_received_tick = false;
                self.context.current_stratum = 0;
                return false;
            }
        }
    }
556
    /// Runs the dataflow graph forever.
    ///
    /// TODO(mingwei): Currently blocks forever, no notion of "completion."
    #[tracing::instrument(level = "trace", skip(self), ret)]
    pub fn run(&mut self) -> Option<Never> {
        loop {
            // Never exits; the `Option<Never>` return value is never actually produced.
            // NOTE(review): `run_tick()` returns immediately when no work is queued, so
            // this loop spins rather than blocking on events (unlike `run_async`).
            self.run_tick();
        }
    }
566
    /// Runs the dataflow graph forever.
    ///
    /// TODO(mingwei): Currently blocks forever, no notion of "completion."
    #[tracing::instrument(level = "trace", skip(self), ret)]
    pub async fn run_async(&mut self) -> Option<Never> {
        loop {
            // Run any work which is immediately available.
            self.run_available_async().await;
            // When no work is available yield until more events occur.
            // (Return value ignored: the loop continues even if the queue closes.)
            self.recv_events_async().await;
        }
    }
579
    /// Enqueues subgraphs triggered by events without blocking.
    ///
    /// Returns the number of subgraphs newly enqueued (events for already-scheduled
    /// subgraphs are not counted).
    #[tracing::instrument(level = "trace", skip(self), fields(events_received_tick = self.context.events_received_tick), ret)]
    pub fn try_recv_events(&mut self) -> usize {
        let mut enqueued_count = 0;
        // Drain all currently-pending events without blocking.
        while let Ok((sg_id, is_external)) = self.context.event_queue_recv.try_recv() {
            let sg_data = &self.subgraphs[sg_id];
            tracing::trace!(
                sg_id = sg_id.to_string(),
                is_external = is_external,
                sg_stratum = sg_data.stratum,
                "Event received."
            );
            // Only enqueue if not already scheduled, to avoid duplicate queue entries.
            if !sg_data.is_scheduled.replace(true) {
                self.context.stratum_queues[sg_data.stratum].push_back(sg_id);
                enqueued_count += 1;
            }
            if is_external {
                // Next tick is triggered if we are at the start of the next tick (`!self.events_received_tick`).
                // Or if the stratum is in the next tick.
                if !self.context.events_received_tick
                    || sg_data.stratum < self.context.current_stratum
                {
                    tracing::trace!(
                        current_stratum = self.context.current_stratum,
                        sg_stratum = sg_data.stratum,
                        "External event, setting `can_start_tick = true`."
                    );
                    self.context.can_start_tick = true;
                }
            }
        }
        self.context.events_received_tick = true;

        enqueued_count
    }
617
    /// Enqueues subgraphs triggered by external events, blocking until at
    /// least one subgraph is scheduled **from an external event**.
    ///
    /// Returns `None` if the event queue is closed; otherwise returns the number of
    /// subgraphs newly enqueued (including any additional non-blocking events drained
    /// after the external one).
    #[tracing::instrument(level = "trace", skip(self), fields(events_received_tick = self.context.events_received_tick), ret)]
    pub fn recv_events(&mut self) -> Option<usize> {
        let mut count = 0;
        // Block on events until an *external* one arrives.
        loop {
            let (sg_id, is_external) = self.context.event_queue_recv.blocking_recv()?;
            let sg_data = &self.subgraphs[sg_id];
            tracing::trace!(
                sg_id = sg_id.to_string(),
                is_external = is_external,
                sg_stratum = sg_data.stratum,
                "Event received."
            );
            // Only enqueue if not already scheduled, to avoid duplicate queue entries.
            if !sg_data.is_scheduled.replace(true) {
                self.context.stratum_queues[sg_data.stratum].push_back(sg_id);
                count += 1;
            }
            if is_external {
                // Next tick is triggered if we are at the start of the next tick (`!self.events_received_tick`).
                // Or if the stratum is in the next tick.
                if !self.context.events_received_tick
                    || sg_data.stratum < self.context.current_stratum
                {
                    tracing::trace!(
                        current_stratum = self.context.current_stratum,
                        sg_stratum = sg_data.stratum,
                        "External event, setting `can_start_tick = true`."
                    );
                    self.context.can_start_tick = true;
                }
                break;
            }
        }
        self.context.events_received_tick = true;

        // Enqueue any other immediate events.
        let extra_count = self.try_recv_events();
        Some(count + extra_count)
    }
658
    /// Enqueues subgraphs triggered by external events asynchronously, waiting until at least one
    /// subgraph is scheduled **from an external event**. Returns the number of subgraphs enqueued,
    /// which may be zero if an external event scheduled an already-scheduled subgraph.
    ///
    /// Returns `None` if the event queue is closed, but that should not happen normally.
    #[tracing::instrument(level = "trace", skip(self), fields(events_received_tick = self.context.events_received_tick), ret)]
    pub async fn recv_events_async(&mut self) -> Option<usize> {
        let mut count = 0;
        // Await events until an *external* one arrives.
        loop {
            tracing::trace!("Awaiting events (`event_queue_recv`).");
            let (sg_id, is_external) = self.context.event_queue_recv.recv().await?;
            let sg_data = &self.subgraphs[sg_id];
            tracing::trace!(
                sg_id = sg_id.to_string(),
                is_external = is_external,
                sg_stratum = sg_data.stratum,
                "Event received."
            );
            // Only enqueue if not already scheduled, to avoid duplicate queue entries.
            if !sg_data.is_scheduled.replace(true) {
                self.context.stratum_queues[sg_data.stratum].push_back(sg_id);
                count += 1;
            }
            if is_external {
                // Next tick is triggered if we are at the start of the next tick (`!self.events_received_tick`).
                // Or if the stratum is in the next tick.
                if !self.context.events_received_tick
                    || sg_data.stratum < self.context.current_stratum
                {
                    tracing::trace!(
                        current_stratum = self.context.current_stratum,
                        sg_stratum = sg_data.stratum,
                        "External event, setting `can_start_tick = true`."
                    );
                    self.context.can_start_tick = true;
                }
                break;
            }
        }
        self.context.events_received_tick = true;

        // Enqueue any other immediate events.
        let extra_count = self.try_recv_events();
        Some(count + extra_count)
    }
703
704    /// Schedules a subgraph to be run. See also: [`Context::schedule_subgraph`].
705    pub fn schedule_subgraph(&mut self, sg_id: SubgraphId) -> bool {
706        let sg_data = &self.subgraphs[sg_id];
707        let already_scheduled = sg_data.is_scheduled.replace(true);
708        if !already_scheduled {
709            self.context.stratum_queues[sg_data.stratum].push_back(sg_id);
710            true
711        } else {
712            false
713        }
714    }
715
716    /// Adds a new compiled subgraph with the specified inputs and outputs in stratum 0.
717    pub fn add_subgraph<Name, R, W, F>(
718        &mut self,
719        name: Name,
720        recv_ports: R,
721        send_ports: W,
722        subgraph: F,
723    ) -> SubgraphId
724    where
725        Name: Into<Cow<'static, str>>,
726        R: 'static + PortList<RECV>,
727        W: 'static + PortList<SEND>,
728        F: 'static + for<'ctx> FnMut(&'ctx mut Context, R::Ctx<'ctx>, W::Ctx<'ctx>),
729    {
730        self.add_subgraph_stratified(name, 0, recv_ports, send_ports, false, subgraph)
731    }
732
733    /// Adds a new compiled subgraph with the specified inputs, outputs, and stratum number.
734    ///
735    /// TODO(mingwei): add example in doc.
736    pub fn add_subgraph_stratified<Name, R, W, F>(
737        &mut self,
738        name: Name,
739        stratum: usize,
740        recv_ports: R,
741        send_ports: W,
742        laziness: bool,
743        subgraph: F,
744    ) -> SubgraphId
745    where
746        Name: Into<Cow<'static, str>>,
747        R: 'static + PortList<RECV>,
748        W: 'static + PortList<SEND>,
749        F: 'a + for<'ctx> FnMut(&'ctx mut Context, R::Ctx<'ctx>, W::Ctx<'ctx>),
750    {
751        self.add_subgraph_full(
752            name, stratum, recv_ports, send_ports, laziness, None, subgraph,
753        )
754    }
755
    /// Adds a new compiled subgraph with all options.
    ///
    /// * `name`: friendly name for diagnostics.
    /// * `stratum`: the stratum number the subgraph is scheduled in.
    /// * `recv_ports`/`send_ports`: input/output handoff ports; must belong to this instance.
    /// * `laziness`: if `true`, sending data back to a lower stratum does not trigger a new tick.
    /// * `loop_id`: the enclosing loop (from [`Self::add_loop`]), or `None` for top-level.
    /// * `subgraph`: the subgraph's execution code.
    ///
    /// The new subgraph is immediately scheduled in its stratum.
    #[expect(clippy::too_many_arguments, reason = "Mainly for internal use.")]
    pub fn add_subgraph_full<Name, R, W, F>(
        &mut self,
        name: Name,
        stratum: usize,
        recv_ports: R,
        send_ports: W,
        laziness: bool,
        loop_id: Option<LoopId>,
        mut subgraph: F,
    ) -> SubgraphId
    where
        Name: Into<Cow<'static, str>>,
        R: 'static + PortList<RECV>,
        W: 'static + PortList<SEND>,
        F: 'a + for<'ctx> FnMut(&'ctx mut Context, R::Ctx<'ctx>, W::Ctx<'ctx>),
    {
        // Check that the send and recv ports are from `self.handoffs`. This justifies the
        // `unsafe` `make_ctx` calls in the scheduler closure below.
        recv_ports.assert_is_from(&self.handoffs);
        send_ports.assert_is_from(&self.handoffs);

        // Loop depth of the enclosing loop, or `0` for top-level.
        let loop_depth = loop_id
            .and_then(|loop_id| self.context.loop_depth.get(loop_id))
            .copied()
            .unwrap_or(0);

        let sg_id = self.subgraphs.insert_with_key(|sg_id| {
            // Register this subgraph with its input/output handoffs' graph metadata.
            let (mut subgraph_preds, mut subgraph_succs) = Default::default();
            recv_ports.set_graph_meta(&mut self.handoffs, &mut subgraph_preds, sg_id, true);
            send_ports.set_graph_meta(&mut self.handoffs, &mut subgraph_succs, sg_id, false);

            // Wrap the user closure so the scheduler can invoke it with only the context and
            // the handoffs; the captured ports are resolved into typed contexts on each call.
            let subgraph =
                move |context: &mut Context, handoffs: &mut SlotVec<HandoffTag, HandoffData>| {
                    let (recv, send) = unsafe {
                        // SAFETY:
                        // 1. We checked `assert_is_from` at assembly time, above.
                        // 2. `SlotVec` is insert-only so no handoffs could have changed since then.
                        (
                            recv_ports.make_ctx(&*handoffs),
                            send_ports.make_ctx(&*handoffs),
                        )
                    };
                    (subgraph)(context, recv, send);
                };
            SubgraphData::new(
                name.into(),
                stratum,
                subgraph,
                subgraph_preds,
                subgraph_succs,
                true,
                laziness,
                loop_id,
                loop_depth,
            )
        });
        // Ensure the stratum's queue exists, then schedule the new subgraph to run.
        self.context.init_stratum(stratum);
        self.context.stratum_queues[stratum].push_back(sg_id);

        sg_id
    }
818
819    /// Adds a new compiled subgraph with a variable number of inputs and outputs of the same respective handoff types.
820    pub fn add_subgraph_n_m<Name, R, W, F>(
821        &mut self,
822        name: Name,
823        recv_ports: Vec<RecvPort<R>>,
824        send_ports: Vec<SendPort<W>>,
825        subgraph: F,
826    ) -> SubgraphId
827    where
828        Name: Into<Cow<'static, str>>,
829        R: 'static + Handoff,
830        W: 'static + Handoff,
831        F: 'static
832            + for<'ctx> FnMut(&'ctx mut Context, &'ctx [&'ctx RecvCtx<R>], &'ctx [&'ctx SendCtx<W>]),
833    {
834        self.add_subgraph_stratified_n_m(name, 0, recv_ports, send_ports, subgraph)
835    }
836
    /// Adds a new compiled subgraph with a variable number of inputs and outputs of the same respective handoff types.
    ///
    /// Unlike [`Self::add_subgraph_stratified`], the number of ports need not be known at
    /// compile time, but all recv ports must share the handoff type `R` and all send ports
    /// the handoff type `W`.
    pub fn add_subgraph_stratified_n_m<Name, R, W, F>(
        &mut self,
        name: Name,
        stratum: usize,
        recv_ports: Vec<RecvPort<R>>,
        send_ports: Vec<SendPort<W>>,
        mut subgraph: F,
    ) -> SubgraphId
    where
        Name: Into<Cow<'static, str>>,
        R: 'static + Handoff,
        W: 'static + Handoff,
        F: 'static
            + for<'ctx> FnMut(&'ctx mut Context, &'ctx [&'ctx RecvCtx<R>], &'ctx [&'ctx SendCtx<W>]),
    {
        let sg_id = self.subgraphs.insert_with_key(|sg_id| {
            let subgraph_preds = recv_ports.iter().map(|port| port.handoff_id).collect();
            let subgraph_succs = send_ports.iter().map(|port| port.handoff_id).collect();

            // Register this subgraph as a successor of its inputs and predecessor of its outputs.
            for recv_port in recv_ports.iter() {
                self.handoffs[recv_port.handoff_id].succs.push(sg_id);
            }
            for send_port in send_ports.iter() {
                self.handoffs[send_port.handoff_id].preds.push(sg_id);
            }

            let subgraph =
                move |context: &mut Context, handoffs: &mut SlotVec<HandoffTag, HandoffData>| {
                    // Resolve each recv port to a typed `&RecvCtx<R>` by downcasting the
                    // type-erased handoff. Panics if a handoff has the wrong type.
                    let recvs: Vec<&RecvCtx<R>> = recv_ports
                        .iter()
                        .map(|hid| hid.handoff_id)
                        .map(|hid| handoffs.get(hid).unwrap())
                        .map(|h_data| {
                            <dyn Any>::downcast_ref(&*h_data.handoff)
                                .expect("Attempted to cast handoff to wrong type.")
                        })
                        .map(RefCast::ref_cast)
                        .collect();

                    // Likewise for the send ports, as typed `&SendCtx<W>`s.
                    let sends: Vec<&SendCtx<W>> = send_ports
                        .iter()
                        .map(|hid| hid.handoff_id)
                        .map(|hid| handoffs.get(hid).unwrap())
                        .map(|h_data| {
                            <dyn Any>::downcast_ref(&*h_data.handoff)
                                .expect("Attempted to cast handoff to wrong type.")
                        })
                        .map(RefCast::ref_cast)
                        .collect();

                    (subgraph)(context, &recvs, &sends)
                };
            SubgraphData::new(
                name.into(),
                stratum,
                subgraph,
                subgraph_preds,
                subgraph_succs,
                true,
                false,
                None,
                0,
            )
        });

        // Ensure the stratum's queue exists, then schedule the new subgraph to run.
        self.context.init_stratum(stratum);
        self.context.stratum_queues[stratum].push_back(sg_id);

        sg_id
    }
908
909    /// Creates a handoff edge and returns the corresponding send and receive ports.
910    pub fn make_edge<Name, H>(&mut self, name: Name) -> (SendPort<H>, RecvPort<H>)
911    where
912        Name: Into<Cow<'static, str>>,
913        H: 'static + Handoff,
914    {
915        // Create and insert handoff.
916        let handoff = H::default();
917        let handoff_id = self
918            .handoffs
919            .insert_with_key(|hoff_id| HandoffData::new(name.into(), handoff, hoff_id));
920
921        // Make ports.
922        let input_port = SendPort {
923            handoff_id,
924            _marker: PhantomData,
925        };
926        let output_port = RecvPort {
927            handoff_id,
928            _marker: PhantomData,
929        };
930        (input_port, output_port)
931    }
932
933    /// Adds referenceable state into this instance. Returns a state handle which can be
934    /// used externally or by operators to access the state.
935    ///
936    /// This is part of the "state API".
937    pub fn add_state<T>(&mut self, state: T) -> StateHandle<T>
938    where
939        T: Any,
940    {
941        self.context.add_state(state)
942    }
943
944    /// Sets a hook to modify the state at the end of each tick, using the supplied closure.
945    ///
946    /// This is part of the "state API".
947    pub fn set_state_lifespan_hook<T>(
948        &mut self,
949        handle: StateHandle<T>,
950        lifespan: StateLifespan,
951        hook_fn: impl 'static + FnMut(&mut T),
952    ) where
953        T: Any,
954    {
955        self.context
956            .set_state_lifespan_hook(handle, lifespan, hook_fn)
957    }
958
959    /// Gets a exclusive (mut) ref to the internal context, setting the subgraph ID.
960    pub fn context_mut(&mut self, sg_id: SubgraphId) -> &mut Context {
961        self.context.subgraph_id = sg_id;
962        &mut self.context
963    }
964
965    /// Adds a new loop with the given parent (or `None` for top-level). Returns a loop ID which
966    /// is used in [`Self::add_subgraph_stratified`] or for nested loops.
967    ///
968    /// TODO(mingwei): add loop names to ensure traceability while debugging?
969    pub fn add_loop(&mut self, parent: Option<LoopId>) -> LoopId {
970        let depth = parent.map_or(0, |p| self.context.loop_depth[p] + 1);
971        let loop_id = self.context.loop_depth.insert(depth);
972        self.loop_data.insert(
973            loop_id,
974            LoopData {
975                iter_count: None,
976                allow_another_iteration: true,
977            },
978        );
979        loop_id
980    }
981}
982
983impl Dfir<'_> {
984    /// Alias for [`Context::request_task`].
985    pub fn request_task<Fut>(&mut self, future: Fut)
986    where
987        Fut: Future<Output = ()> + 'static,
988    {
989        self.context.request_task(future);
990    }
991
992    /// Alias for [`Context::abort_tasks`].
993    pub fn abort_tasks(&mut self) {
994        self.context.abort_tasks()
995    }
996
997    /// Alias for [`Context::join_tasks`].
998    pub fn join_tasks(&mut self) -> impl use<'_> + Future {
999        self.context.join_tasks()
1000    }
1001}
1002
impl Drop for Dfir<'_> {
    fn drop(&mut self) {
        // Abort any outstanding tasks when the graph is dropped, so they do not keep
        // running after the graph that requested them is gone.
        self.abort_tasks();
    }
}
1008
/// A handoff and its input and output [SubgraphId]s.
///
/// Internal use: used to track the dfir graph structure.
///
/// TODO(mingwei): restructure `PortList` so this can be crate-private.
#[doc(hidden)]
pub struct HandoffData {
    /// A friendly name for diagnostics.
    pub(super) name: Cow<'static, str>,
    /// The type-erased handoff. Crate-visible for `handoff_list` internals.
    pub(super) handoff: Box<dyn HandoffMeta>,
    /// Preceding subgraphs (including the send side of a teeing handoff).
    pub(super) preds: SmallVec<[SubgraphId; 1]>,
    /// Successor subgraphs (including recv sides of teeing handoffs).
    pub(super) succs: SmallVec<[SubgraphId; 1]>,

    /// Predecessor handoffs, used by teeing handoffs.
    /// Should be `self` on any teeing send sides (input).
    /// Should be the send `HandoffId` if this is teeing recv side (output).
    /// Should be just `self`'s `HandoffId` on other handoffs.
    /// This field is only used in initialization.
    pub(super) pred_handoffs: Vec<HandoffId>,
    /// Successor handoffs, used by teeing handoffs.
    /// Should be a list of outputs on the teeing send side (input).
    /// Should be `self` on any teeing recv sides (outputs).
    /// Should be just `self`'s `HandoffId` on other handoffs.
    /// This field is only used in initialization.
    pub(super) succ_handoffs: Vec<HandoffId>,
}
1038impl std::fmt::Debug for HandoffData {
1039    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
1040        f.debug_struct("HandoffData")
1041            .field("preds", &self.preds)
1042            .field("succs", &self.succs)
1043            .finish_non_exhaustive()
1044    }
1045}
1046impl HandoffData {
1047    /// New with `pred_handoffs` and `succ_handoffs` set to its own [`HandoffId`]: `vec![hoff_id]`.
1048    pub fn new(
1049        name: Cow<'static, str>,
1050        handoff: impl 'static + HandoffMeta,
1051        hoff_id: HandoffId,
1052    ) -> Self {
1053        let (preds, succs) = Default::default();
1054        Self {
1055            name,
1056            handoff: Box::new(handoff),
1057            preds,
1058            succs,
1059            pred_handoffs: vec![hoff_id],
1060            succ_handoffs: vec![hoff_id],
1061        }
1062    }
1063}
1064
/// A subgraph along with its predecessor and successor [SubgraphId]s.
///
/// Used internally by the [Dfir] struct to represent the dataflow graph
/// structure and scheduled state.
pub(super) struct SubgraphData<'a> {
    /// A friendly name for diagnostics.
    pub(super) name: Cow<'static, str>,
    /// This subgraph's stratum number.
    ///
    /// Within loop blocks, corresponds to the topological sort of the DAG created when `next_loop()/next_tick()` are removed.
    pub(super) stratum: usize,
    /// The actual execution code of the subgraph.
    subgraph: Box<dyn Subgraph + 'a>,

    #[expect(dead_code, reason = "may be useful in the future")]
    preds: Vec<HandoffId>,
    succs: Vec<HandoffId>,

    /// If this subgraph is scheduled in the stratum queues (`Context::stratum_queues`).
    /// [`Cell`] allows modifying this field when iterating `Self::preds` or
    /// `Self::succs`, as all `SubgraphData` are owned by the same vec
    /// (`Dfir::subgraphs`).
    is_scheduled: Cell<bool>,

    /// Keep track of the last tick that this subgraph was run in.
    last_tick_run_in: Option<TickInstant>,
    /// A meaningless ID to track the last loop execution this subgraph was run in.
    /// `(loop_nonce, iter_count)` pair.
    last_loop_nonce: (usize, Option<usize>),

    /// If this subgraph is marked as lazy, then sending data back to a lower stratum does not trigger a new tick to be run.
    is_lazy: bool,

    /// The subgraph's loop ID, or `None` for the top level.
    loop_id: Option<LoopId>,
    /// The loop depth of the subgraph (`0` for top-level).
    loop_depth: usize,
}
1103impl<'a> SubgraphData<'a> {
1104    #[expect(clippy::too_many_arguments, reason = "internal use")]
1105    pub(crate) fn new(
1106        name: Cow<'static, str>,
1107        stratum: usize,
1108        subgraph: impl Subgraph + 'a,
1109        preds: Vec<HandoffId>,
1110        succs: Vec<HandoffId>,
1111        is_scheduled: bool,
1112        is_lazy: bool,
1113        loop_id: Option<LoopId>,
1114        loop_depth: usize,
1115    ) -> Self {
1116        Self {
1117            name,
1118            stratum,
1119            subgraph: Box::new(subgraph),
1120            preds,
1121            succs,
1122            is_scheduled: Cell::new(is_scheduled),
1123            last_tick_run_in: None,
1124            last_loop_nonce: (0, None),
1125            is_lazy,
1126            loop_id,
1127            loop_depth,
1128        }
1129    }
1130}
1131
/// Per-loop scheduling state, stored in `Dfir::loop_data` keyed by [`LoopId`].
pub(crate) struct LoopData {
    /// Count of iterations of this loop.
    iter_count: Option<usize>,
    /// If the loop has reason to do another iteration.
    allow_another_iteration: bool,
}
1138
/// Defines when state should be reset.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum StateLifespan {
    /// Always reset, associated with the subgraph.
    Subgraph(SubgraphId),
    /// Reset between loop executions.
    Loop(LoopId),
    /// Reset between ticks.
    Tick,
    /// Never reset.
    Static,
}