1use std::any::Any;
4use std::borrow::Cow;
5use std::cell::Cell;
6use std::cmp::Ordering;
7use std::future::Future;
8use std::marker::PhantomData;
9use std::rc::Rc;
10
11#[cfg(feature = "meta")]
12use dfir_lang::diagnostic::{Diagnostic, SerdeSpan};
13#[cfg(feature = "meta")]
14use dfir_lang::graph::DfirGraph;
15use ref_cast::RefCast;
16use smallvec::SmallVec;
17use tracing::Instrument;
18use web_time::SystemTime;
19
20use super::context::Context;
21use super::handoff::handoff_list::PortList;
22use super::handoff::{Handoff, HandoffMeta, TeeingHandoff};
23use super::metrics::{DfirMetrics, DfirMetricsState, InstrumentSubgraph};
24use super::port::{RECV, RecvCtx, RecvPort, SEND, SendCtx, SendPort};
25use super::reactor::Reactor;
26use super::state::StateHandle;
27use super::subgraph::Subgraph;
28use super::ticks::{TickDuration, TickInstant};
29use super::{HandoffId, HandoffTag, LoopId, LoopTag, SubgraphId, SubgraphTag};
30use crate::Never;
31use crate::util::slot_vec::{SecondarySlotVec, SlotVec};
32
/// A DFIR dataflow graph: subgraphs connected by handoffs, run stratum by stratum and
/// tick by tick by the scheduler methods on this type.
///
/// NOTE: Rust drops fields in declaration order, so the order below is also the drop order.
#[derive(Default)]
pub struct Dfir<'a> {
    /// All subgraphs, keyed by [`SubgraphId`].
    pub(super) subgraphs: SlotVec<SubgraphTag, SubgraphData<'a>>,

    /// Per-loop scheduling state (iteration count, continuation flag), keyed by [`LoopId`].
    pub(super) loop_data: SecondarySlotVec<LoopTag, LoopData>,

    /// Runtime state passed to running subgraphs: current tick/stratum, per-stratum
    /// queues, event queue, loop bookkeeping, state hooks and tasks.
    pub(super) context: Context,

    /// All handoffs (edges between subgraphs), keyed by [`HandoffId`].
    pub(super) handoffs: SlotVec<HandoffTag, HandoffData>,

    /// Per-subgraph and per-handoff metrics; shared via `Rc` with [`DfirMetrics`] values
    /// returned by `metrics()`.
    metrics: Rc<DfirMetricsState>,

    /// Meta graph assigned by `__assign_meta_graph`, if any.
    #[cfg(feature = "meta")]
    meta_graph: Option<DfirGraph>,

    /// Diagnostics assigned by `__assign_diagnostics`, if any.
    #[cfg(feature = "meta")]
    diagnostics: Option<Vec<Diagnostic<SerdeSpan>>>,
}
54
55impl Dfir<'_> {
57 pub fn teeing_handoff_tee<T>(
59 &mut self,
60 tee_parent_port: &RecvPort<TeeingHandoff<T>>,
61 ) -> RecvPort<TeeingHandoff<T>>
62 where
63 T: Clone,
64 {
65 let tee_root = self.handoffs[tee_parent_port.handoff_id].pred_handoffs[0];
67
68 let tee_root_data = &mut self.handoffs[tee_root];
70 let tee_root_data_name = tee_root_data.name.clone();
71
72 let teeing_handoff =
74 <dyn Any>::downcast_ref::<TeeingHandoff<T>>(&*tee_root_data.handoff).unwrap();
75 let new_handoff = teeing_handoff.tee();
76
77 let new_hoff_id = self.handoffs.insert_with_key(|new_hoff_id| {
79 let new_name = Cow::Owned(format!("{} tee {:?}", tee_root_data_name, new_hoff_id));
80 let mut new_handoff_data = HandoffData::new(new_name, new_handoff, new_hoff_id);
81 new_handoff_data.pred_handoffs = vec![tee_root];
83 new_handoff_data
84 });
85
86 let tee_root_data = &mut self.handoffs[tee_root];
88 tee_root_data.succ_handoffs.push(new_hoff_id);
89
90 assert!(
93 tee_root_data.preds.len() <= 1,
94 "Tee send side should only have one sender (or none set yet)."
95 );
96 if let Some(&pred_sg_id) = tee_root_data.preds.first() {
97 self.subgraphs[pred_sg_id].succs.push(new_hoff_id);
98 }
99
100 Rc::make_mut(&mut self.metrics)
102 .handoff_metrics
103 .insert(new_hoff_id, Default::default());
104
105 let output_port = RecvPort {
106 handoff_id: new_hoff_id,
107 _marker: PhantomData,
108 };
109 output_port
110 }
111
112 pub fn teeing_handoff_drop<T>(&mut self, tee_port: RecvPort<TeeingHandoff<T>>)
117 where
118 T: Clone,
119 {
120 let data = &self.handoffs[tee_port.handoff_id];
121 let teeing_handoff = <dyn Any>::downcast_ref::<TeeingHandoff<T>>(&*data.handoff).unwrap();
122 teeing_handoff.drop();
123
124 let tee_root = data.pred_handoffs[0];
125 let tee_root_data = &mut self.handoffs[tee_root];
126 tee_root_data
128 .succ_handoffs
129 .retain(|&succ_hoff| succ_hoff != tee_port.handoff_id);
130 assert!(
132 tee_root_data.preds.len() <= 1,
133 "Tee send side should only have one sender (or none set yet)."
134 );
135 if let Some(&pred_sg_id) = tee_root_data.preds.first() {
136 self.subgraphs[pred_sg_id]
137 .succs
138 .retain(|&succ_hoff| succ_hoff != tee_port.handoff_id);
139 }
140 }
141}
142
143impl<'a> Dfir<'a> {
144 pub fn new() -> Self {
146 Default::default()
147 }
148
149 #[doc(hidden)]
151 pub fn __assign_meta_graph(&mut self, _meta_graph_json: &str) {
152 #[cfg(feature = "meta")]
153 {
154 let mut meta_graph: DfirGraph =
155 serde_json::from_str(_meta_graph_json).expect("Failed to deserialize graph.");
156
157 let mut op_inst_diagnostics = Vec::new();
158 meta_graph.insert_node_op_insts_all(&mut op_inst_diagnostics);
159 assert!(
160 op_inst_diagnostics.is_empty(),
161 "Expected no diagnostics, got: {:#?}",
162 op_inst_diagnostics
163 );
164
165 assert!(self.meta_graph.replace(meta_graph).is_none());
166 }
167 }
168 #[doc(hidden)]
170 pub fn __assign_diagnostics(&mut self, _diagnostics_json: &'static str) {
171 #[cfg(feature = "meta")]
172 {
173 let diagnostics: Vec<Diagnostic<SerdeSpan>> = serde_json::from_str(_diagnostics_json)
174 .expect("Failed to deserialize diagnostics.");
175
176 assert!(self.diagnostics.replace(diagnostics).is_none());
177 }
178 }
179
180 #[cfg(feature = "meta")]
184 #[cfg_attr(docsrs, doc(cfg(feature = "meta")))]
185 pub fn meta_graph(&self) -> Option<&DfirGraph> {
186 self.meta_graph.as_ref()
187 }
188
189 #[cfg(feature = "meta")]
194 #[cfg_attr(docsrs, doc(cfg(feature = "meta")))]
195 pub fn diagnostics(&self) -> Option<&[Diagnostic<SerdeSpan>]> {
196 self.diagnostics.as_deref()
197 }
198
199 pub fn reactor(&self) -> Reactor {
202 Reactor::new(self.context.event_queue_send.clone())
203 }
204
205 pub fn current_tick(&self) -> TickInstant {
207 self.context.current_tick
208 }
209
210 pub fn current_stratum(&self) -> usize {
212 self.context.current_stratum
213 }
214
215 #[tracing::instrument(level = "trace", skip(self), ret)]
221 pub async fn run_tick(&mut self) -> bool {
222 let mut work_done = false;
223 while self.next_stratum(true) {
225 work_done = true;
226 self.run_stratum().await;
228 }
229 work_done
230 }
231
232 #[tracing::instrument(level = "trace", skip(self), ret)]
234 pub fn run_tick_sync(&mut self) -> bool {
235 let mut work_done = false;
236 while self.next_stratum(true) {
238 work_done = true;
239 run_sync(self.run_stratum());
241 }
242 work_done
243 }
244
245 #[tracing::instrument(level = "trace", skip(self), ret)]
253 pub async fn run_available(&mut self) -> bool {
254 let mut work_done = false;
255 while self.next_stratum(false) {
257 work_done = true;
258 self.run_stratum().await;
260
261 tokio::task::yield_now().await;
264 }
265 work_done
266 }
267
268 #[tracing::instrument(level = "trace", skip(self), ret)]
270 pub fn run_available_sync(&mut self) -> bool {
271 let mut work_done = false;
272 while self.next_stratum(false) {
274 work_done = true;
275 run_sync(self.run_stratum());
277 }
278 work_done
279 }
280
    /// Runs subgraphs popped from the current stratum's queue until that queue is empty.
    /// Subgraphs pushed onto the current stratum's queue while it runs are also run.
    ///
    /// Returns `true` if at least one subgraph actually ran. A loop subgraph that is skipped
    /// because its loop may not start another iteration (`continue 'pop`) does not count.
    #[tracing::instrument(level = "trace", skip(self), fields(tick = u64::from(self.context.current_tick), stratum = self.context.current_stratum), ret)]
    pub async fn run_stratum(&mut self) -> bool {
        // NOTE(review): presumably starts tasks queued via `request_task` — confirm in `Context`.
        self.context.spawn_tasks();

        let mut work_done = false;

        'pop: while let Some(sg_id) =
            self.context.stratum_queues[self.context.current_stratum].pop_front()
        {
            let sg_data = &mut self.subgraphs[sg_id];

            // Input-handoff metrics: record how many items are waiting in each predecessor
            // handoff just before this subgraph runs.
            for &handoff_id in sg_data.preds.iter() {
                let handoff_metrics = &self.metrics.handoff_metrics[handoff_id];
                let handoff_data = &mut self.handoffs[handoff_id];
                let handoff_len = handoff_data.handoff.len();
                handoff_metrics
                    .total_items_count
                    .update(|x| x + handoff_len);
                handoff_metrics.curr_items_count.set(handoff_len);
            }

            {
                // Every queued subgraph must be flagged as scheduled. Clearing the flag here
                // allows it to be re-queued during or after this run.
                assert!(sg_data.is_scheduled.take());

                let run_subgraph_span_guard = tracing::info_span!(
                    "run-subgraph",
                    sg_id = sg_id.to_string(),
                    sg_name = &*sg_data.name,
                    sg_depth = sg_data.loop_depth,
                    sg_loop_nonce = sg_data.last_loop_nonce.0,
                    sg_iter_count = sg_data.last_loop_nonce.1,
                )
                .entered();

                // Keep the loop-nonce stack in step with loop depth.
                // - Going deeper pushes a freshly incremented nonce.
                // - Going shallower pops one entry.
                match sg_data.loop_depth.cmp(&self.context.loop_nonce_stack.len()) {
                    Ordering::Greater => {
                        self.context.loop_nonce += 1;
                        self.context.loop_nonce_stack.push(self.context.loop_nonce);
                        tracing::trace!(loop_nonce = self.context.loop_nonce, "Entered loop.");
                    }
                    Ordering::Less => {
                        self.context.loop_nonce_stack.pop();
                        tracing::trace!("Exited loop.");
                    }
                    Ordering::Equal => {}
                }

                self.context.subgraph_id = sg_id;
                // First run this tick if it never ran, or last ran in an earlier tick.
                self.context.is_first_run_this_tick = sg_data
                    .last_tick_run_in
                    .is_none_or(|last_tick| last_tick < self.context.current_tick);

                if let Some(loop_id) = sg_data.loop_id {
                    // `None` when the nonce stack is empty (depth-0 loop subgraphs).
                    let curr_loop_nonce = self.context.loop_nonce_stack.last().copied();

                    let LoopData {
                        iter_count: loop_iter_count,
                        allow_another_iteration,
                    } = &mut self.loop_data[loop_id];

                    // Nonce and iteration count recorded the last time this subgraph ran.
                    let (prev_loop_nonce, prev_iter_count) = sg_data.last_loop_nonce;

                    let (curr_iter_count, new_loop_execution) =
                        if curr_loop_nonce.is_none_or(|nonce| nonce == prev_loop_nonce) {
                            // Same loop execution as this subgraph's last run.
                            if *loop_iter_count == prev_iter_count {
                                // This subgraph already ran in the loop's current iteration.
                                // Running again means advancing to the next iteration, which
                                // is only allowed if the flag is set; the flag is consumed here.
                                if !std::mem::take(allow_another_iteration) {
                                    tracing::debug!(
                                        "Loop will not continue to next iteration, skipping."
                                    );
                                    continue 'pop;
                                }
                                // No iteration yet => iteration 0 of a new loop execution;
                                // otherwise advance by one.
                                loop_iter_count.map_or((0, true), |n| (n + 1, false))
                            } else {
                                // The loop was already advanced (by another subgraph); catch up.
                                debug_assert!(
                                    prev_iter_count < *loop_iter_count,
                                    "Expect loop iteration count to be increasing."
                                );
                                (loop_iter_count.unwrap(), false)
                            }
                        } else {
                            // Nonce differs from this subgraph's last run: start at iteration 0.
                            (0, false)
                        };

                    if new_loop_execution {
                        self.context.run_state_hooks_loop(loop_id);
                    }
                    tracing::debug!("Loop iteration count {}", curr_iter_count);

                    *loop_iter_count = Some(curr_iter_count);
                    self.context.loop_iter_count = curr_iter_count;
                    sg_data.last_loop_nonce =
                        (curr_loop_nonce.unwrap_or_default(), Some(curr_iter_count));
                }

                self.context.run_state_hooks_subgraph(sg_id);

                tracing::info!("Running subgraph.");
                sg_data.last_tick_run_in = Some(self.context.current_tick);

                // Run the subgraph future, wrapped for metrics and re-attached to the span.
                let sg_metrics = &self.metrics.subgraph_metrics[sg_id];
                let sg_fut =
                    Box::into_pin(sg_data.subgraph.run(&mut self.context, &mut self.handoffs));
                let sg_fut = InstrumentSubgraph::new(sg_fut, sg_metrics);
                let sg_fut = sg_fut.instrument(run_subgraph_span_guard.exit());
                let () = sg_fut.await;

                sg_metrics.total_run_count.update(|x| x + 1);
            };

            let sg_data = &self.subgraphs[sg_id];
            // Schedule readers of every non-empty output handoff, and record its size.
            for &handoff_id in sg_data.succs.iter() {
                let handoff_data = &self.handoffs[handoff_id];
                let handoff_len = handoff_data.handoff.len();
                if 0 < handoff_len {
                    for &succ_id in handoff_data.succs.iter() {
                        let succ_sg_data = &self.subgraphs[succ_id];
                        // Data sent to an earlier stratum needs another tick to be consumed,
                        // unless this (sending) subgraph is lazy.
                        if succ_sg_data.stratum < self.context.current_stratum && !sg_data.is_lazy {
                            self.context.can_start_tick = true;
                        }
                        if !succ_sg_data.is_scheduled.replace(true) {
                            self.context.stratum_queues[succ_sg_data.stratum].push_back(succ_id);
                        }
                        // Loop successors also push their stratum on `stratum_stack`, which
                        // `next_stratum` pops before advancing to the next stratum.
                        if 0 < succ_sg_data.loop_depth {
                            self.context
                                .stratum_stack
                                .push(succ_sg_data.loop_depth, succ_sg_data.stratum);
                        }
                    }
                }
                let handoff_metrics = &self.metrics.handoff_metrics[handoff_id];
                handoff_metrics.curr_items_count.set(handoff_len);
            }

            // Flags the subgraph may have set on the context during its run; both are reset.
            let reschedule = self.context.reschedule_loop_block.take();
            let allow_another = self.context.allow_another_iteration.take();

            if reschedule {
                // Deferred: `next_stratum` re-queues these when it pops `stratum_stack`.
                self.context.schedule_deferred.push(sg_id);
                self.context
                    .stratum_stack
                    .push(sg_data.loop_depth, sg_data.stratum);
            }
            if (reschedule || allow_another)
                && let Some(loop_id) = sg_data.loop_id
            {
                self.loop_data
                    .get_mut(loop_id)
                    .unwrap()
                    .allow_another_iteration = true;
            }

            work_done = true;
        }
        work_done
    }
480
    /// Moves `current_stratum` (and, at the end of a tick, `current_tick`) forward until a
    /// stratum with queued subgraphs is found.
    ///
    /// Returns `true` if `current_stratum` now has queued work. Returns `false` when:
    /// - the tick ended and `current_tick_only` is `true` (the tick is still advanced);
    /// - the tick ended and received events did not set `can_start_tick`;
    /// - a full loop back to the starting stratum in a new tick found no work.
    ///
    /// When called at stratum 0, this starts a tick. It clears `can_start_tick`, records the
    /// tick start time, and receives pending events if none were received this tick.
    #[tracing::instrument(level = "trace", skip(self), ret)]
    pub fn next_stratum(&mut self, current_tick_only: bool) -> bool {
        tracing::trace!(
            events_received_tick = self.context.events_received_tick,
            can_start_tick = self.context.can_start_tick,
            "Starting `next_stratum` call.",
        );

        // Stratum where the search began; returning to it in a new tick means no work anywhere.
        let mut end_stratum = self.context.current_stratum;
        let mut new_tick_started = false;

        if 0 == self.context.current_stratum {
            new_tick_started = true;

            tracing::trace!("Starting tick, setting `can_start_tick = false`.");
            self.context.can_start_tick = false;
            self.context.current_tick_start = SystemTime::now();

            if !self.context.events_received_tick {
                self.try_recv_events();
            }
        }

        loop {
            tracing::trace!(
                tick = u64::from(self.context.current_tick),
                stratum = self.context.current_stratum,
                "Looking for work on stratum."
            );
            if !self.context.stratum_queues[self.context.current_stratum].is_empty() {
                tracing::trace!(
                    tick = u64::from(self.context.current_tick),
                    stratum = self.context.current_stratum,
                    "Work found on stratum."
                );
                return true;
            }

            // Strata pushed on `stratum_stack` (by loop subgraphs in `run_stratum`) take
            // priority over simply advancing to the next stratum.
            if let Some(next_stratum) = self.context.stratum_stack.pop() {
                self.context.current_stratum = next_stratum;

                {
                    // Re-queue subgraphs deferred by `reschedule_loop_block` in `run_stratum`.
                    for sg_id in self.context.schedule_deferred.drain(..) {
                        let sg_data = &self.subgraphs[sg_id];
                        tracing::info!(
                            tick = u64::from(self.context.current_tick),
                            stratum = self.context.current_stratum,
                            sg_id = sg_id.to_string(),
                            sg_name = &*sg_data.name,
                            is_scheduled = sg_data.is_scheduled.get(),
                            "Rescheduling deferred subgraph."
                        );
                        if !sg_data.is_scheduled.replace(true) {
                            self.context.stratum_queues[sg_data.stratum].push_back(sg_id);
                        }
                    }
                }
            } else {
                self.context.current_stratum += 1;

                // Past the last stratum: end this tick and wrap around to stratum 0.
                if self.context.current_stratum >= self.context.stratum_queues.len() {
                    new_tick_started = true;

                    tracing::trace!(
                        can_start_tick = self.context.can_start_tick,
                        "End of tick {}, starting tick {}.",
                        self.context.current_tick,
                        self.context.current_tick + TickDuration::SINGLE_TICK,
                    );
                    self.context.run_state_hooks_tick();

                    self.context.current_stratum = 0;
                    self.context.current_tick += TickDuration::SINGLE_TICK;
                    self.context.events_received_tick = false;

                    if current_tick_only {
                        tracing::trace!(
                            "`current_tick_only` is `true`, returning `false` before receiving events."
                        );
                        return false;
                    } else {
                        self.try_recv_events();
                        // Consume `can_start_tick`: continue into the new tick only if set.
                        if std::mem::replace(&mut self.context.can_start_tick, false) {
                            tracing::trace!(
                                tick = u64::from(self.context.current_tick),
                                "`can_start_tick` is `true`, continuing."
                            );
                            end_stratum = 0;
                            continue;
                        } else {
                            tracing::trace!(
                                "`can_start_tick` is `false`, re-setting `events_received_tick = false`, returning `false`."
                            );
                            self.context.events_received_tick = false;
                            return false;
                        }
                    }
                }
            }

            if new_tick_started && end_stratum == self.context.current_stratum {
                tracing::trace!(
                    "Made full loop around stratum, re-setting `current_stratum = 0`, returning `false`."
                );
                self.context.events_received_tick = false;
                self.context.current_stratum = 0;
                return false;
            }
        }
    }
615
616 #[tracing::instrument(level = "trace", skip(self), ret)]
620 pub async fn run(&mut self) -> Option<Never> {
621 loop {
622 self.run_available().await;
624 self.recv_events_async().await;
626 }
627 }
628
629 #[tracing::instrument(level = "trace", skip(self), ret)]
631 pub fn run_sync(&mut self) -> Option<Never> {
632 loop {
633 self.run_available_sync();
635 self.recv_events();
637 }
638 }
639
640 #[tracing::instrument(level = "trace", skip(self), fields(events_received_tick = self.context.events_received_tick), ret)]
644 pub fn try_recv_events(&mut self) -> usize {
645 let mut enqueued_count = 0;
646 while let Ok((sg_id, is_external)) = self.context.event_queue_recv.try_recv() {
647 let sg_data = &self.subgraphs[sg_id];
648 tracing::trace!(
649 sg_id = sg_id.to_string(),
650 is_external = is_external,
651 sg_stratum = sg_data.stratum,
652 "Event received."
653 );
654 if !sg_data.is_scheduled.replace(true) {
655 self.context.stratum_queues[sg_data.stratum].push_back(sg_id);
656 enqueued_count += 1;
657 }
658 if is_external {
659 if !self.context.events_received_tick
662 || sg_data.stratum < self.context.current_stratum
663 {
664 tracing::trace!(
665 current_stratum = self.context.current_stratum,
666 sg_stratum = sg_data.stratum,
667 "External event, setting `can_start_tick = true`."
668 );
669 self.context.can_start_tick = true;
670 }
671 }
672 }
673 self.context.events_received_tick = true;
674
675 enqueued_count
676 }
677
678 #[tracing::instrument(level = "trace", skip(self), fields(events_received_tick = self.context.events_received_tick), ret)]
681 pub fn recv_events(&mut self) -> Option<usize> {
682 let mut count = 0;
683 loop {
684 let (sg_id, is_external) = self.context.event_queue_recv.blocking_recv()?;
685 let sg_data = &self.subgraphs[sg_id];
686 tracing::trace!(
687 sg_id = sg_id.to_string(),
688 is_external = is_external,
689 sg_stratum = sg_data.stratum,
690 "Event received."
691 );
692 if !sg_data.is_scheduled.replace(true) {
693 self.context.stratum_queues[sg_data.stratum].push_back(sg_id);
694 count += 1;
695 }
696 if is_external {
697 if !self.context.events_received_tick
700 || sg_data.stratum < self.context.current_stratum
701 {
702 tracing::trace!(
703 current_stratum = self.context.current_stratum,
704 sg_stratum = sg_data.stratum,
705 "External event, setting `can_start_tick = true`."
706 );
707 self.context.can_start_tick = true;
708 }
709 break;
710 }
711 }
712 self.context.events_received_tick = true;
713
714 let extra_count = self.try_recv_events();
716 Some(count + extra_count)
717 }
718
719 #[tracing::instrument(level = "trace", skip(self), fields(events_received_tick = self.context.events_received_tick), ret)]
725 pub async fn recv_events_async(&mut self) -> Option<usize> {
726 let mut count = 0;
727 loop {
728 tracing::trace!("Awaiting events (`event_queue_recv`).");
729 let (sg_id, is_external) = self.context.event_queue_recv.recv().await?;
730 let sg_data = &self.subgraphs[sg_id];
731 tracing::trace!(
732 sg_id = sg_id.to_string(),
733 is_external = is_external,
734 sg_stratum = sg_data.stratum,
735 "Event received."
736 );
737 if !sg_data.is_scheduled.replace(true) {
738 self.context.stratum_queues[sg_data.stratum].push_back(sg_id);
739 count += 1;
740 }
741 if is_external {
742 if !self.context.events_received_tick
745 || sg_data.stratum < self.context.current_stratum
746 {
747 tracing::trace!(
748 current_stratum = self.context.current_stratum,
749 sg_stratum = sg_data.stratum,
750 "External event, setting `can_start_tick = true`."
751 );
752 self.context.can_start_tick = true;
753 }
754 break;
755 }
756 }
757 self.context.events_received_tick = true;
758
759 let extra_count = self.try_recv_events();
761 Some(count + extra_count)
762 }
763
764 pub fn schedule_subgraph(&mut self, sg_id: SubgraphId) -> bool {
766 let sg_data = &self.subgraphs[sg_id];
767 let already_scheduled = sg_data.is_scheduled.replace(true);
768 if !already_scheduled {
769 self.context.stratum_queues[sg_data.stratum].push_back(sg_id);
770 true
771 } else {
772 false
773 }
774 }
775
776 pub fn add_subgraph<Name, R, W, Func>(
778 &mut self,
779 name: Name,
780 recv_ports: R,
781 send_ports: W,
782 subgraph: Func,
783 ) -> SubgraphId
784 where
785 Name: Into<Cow<'static, str>>,
786 R: 'static + PortList<RECV>,
787 W: 'static + PortList<SEND>,
788 Func: 'a + for<'ctx> AsyncFnMut(&'ctx mut Context, R::Ctx<'ctx>, W::Ctx<'ctx>),
789 {
790 self.add_subgraph_stratified(name, 0, recv_ports, send_ports, false, subgraph)
791 }
792
793 pub fn add_subgraph_stratified<Name, R, W, Func>(
797 &mut self,
798 name: Name,
799 stratum: usize,
800 recv_ports: R,
801 send_ports: W,
802 laziness: bool,
803 subgraph: Func,
804 ) -> SubgraphId
805 where
806 Name: Into<Cow<'static, str>>,
807 R: 'static + PortList<RECV>,
808 W: 'static + PortList<SEND>,
809 Func: 'a + for<'ctx> AsyncFnMut(&'ctx mut Context, R::Ctx<'ctx>, W::Ctx<'ctx>),
810 {
811 self.add_subgraph_full(
812 name, stratum, recv_ports, send_ports, laziness, None, subgraph,
813 )
814 }
815
    /// Adds a subgraph with full control over stratum, laziness and enclosing loop.
    ///
    /// Steps:
    /// 1. Checks both port lists with `assert_is_from` against this graph's handoffs.
    /// 2. Takes the loop depth from `loop_id` (0 when `loop_id` is `None` or has no
    ///    recorded depth).
    /// 3. Wires the ports into the handoff graph metadata via `set_graph_meta`.
    /// 4. Stores the subgraph already marked as scheduled and pushes it onto its stratum's
    ///    queue.
    /// 5. Creates its metrics entry.
    ///
    /// Returns the new subgraph's id.
    #[expect(clippy::too_many_arguments, reason = "Mainly for internal use.")]
    pub fn add_subgraph_full<Name, R, W, Func>(
        &mut self,
        name: Name,
        stratum: usize,
        recv_ports: R,
        send_ports: W,
        laziness: bool,
        loop_id: Option<LoopId>,
        mut subgraph: Func,
    ) -> SubgraphId
    where
        Name: Into<Cow<'static, str>>,
        R: 'static + PortList<RECV>,
        W: 'static + PortList<SEND>,
        Func: 'a + for<'ctx> AsyncFnMut(&'ctx mut Context, R::Ctx<'ctx>, W::Ctx<'ctx>),
    {
        recv_ports.assert_is_from(&self.handoffs);
        send_ports.assert_is_from(&self.handoffs);

        let loop_depth = loop_id
            .and_then(|loop_id| self.context.loop_depth.get(loop_id))
            .copied()
            .unwrap_or(0);

        let sg_id = self.subgraphs.insert_with_key(|sg_id| {
            // Collect this subgraph's input/output handoff ids while registering it on
            // the handoffs (`true` for receive ports, `false` for send ports).
            let (mut subgraph_preds, mut subgraph_succs) = Default::default();
            recv_ports.set_graph_meta(&mut self.handoffs, &mut subgraph_preds, sg_id, true);
            send_ports.set_graph_meta(&mut self.handoffs, &mut subgraph_succs, sg_id, false);

            // Type-erased wrapper: builds typed port contexts from the handoffs on every run.
            let subgraph =
                async move |context: &mut Context,
                            handoffs: &mut SlotVec<HandoffTag, HandoffData>| {
                    // SAFETY: NOTE(review): `make_ctx`'s contract is defined elsewhere.
                    // Presumably it relies on the ports belonging to `handoffs`, which
                    // `assert_is_from` checked above — confirm.
                    let (recv, send) = unsafe {
                        (
                            recv_ports.make_ctx(&*handoffs),
                            send_ports.make_ctx(&*handoffs),
                        )
                    };
                    (subgraph)(context, recv, send).await;
                };
            SubgraphData::new(
                name.into(),
                stratum,
                subgraph,
                subgraph_preds,
                subgraph_succs,
                true,
                laziness,
                loop_id,
                loop_depth,
            )
        });
        // Initially scheduled (`is_scheduled = true` above), so it must also be queued.
        self.context.init_stratum(stratum);
        self.context.stratum_queues[stratum].push_back(sg_id);

        Rc::make_mut(&mut self.metrics)
            .subgraph_metrics
            .insert(sg_id, Default::default());

        sg_id
    }
884
885 pub fn add_subgraph_n_m<Name, R, W, Func>(
887 &mut self,
888 name: Name,
889 recv_ports: Vec<RecvPort<R>>,
890 send_ports: Vec<SendPort<W>>,
891 subgraph: Func,
892 ) -> SubgraphId
893 where
894 Name: Into<Cow<'static, str>>,
895 R: 'static + Handoff,
896 W: 'static + Handoff,
897 Func: 'a
898 + for<'ctx> AsyncFnMut(
899 &'ctx mut Context,
900 &'ctx [&'ctx RecvCtx<R>],
901 &'ctx [&'ctx SendCtx<W>],
902 ),
903 {
904 self.add_subgraph_stratified_n_m(name, 0, recv_ports, send_ports, subgraph)
905 }
906
    /// Adds a subgraph at `stratum` with variable-length lists of input ports (all `R`) and
    /// output ports (all `W`).
    ///
    /// Graph wiring:
    /// - Registers the subgraph as a successor of each input handoff.
    /// - Registers it as a predecessor of each output handoff.
    /// - Stores it as scheduled and pushes it onto its stratum's queue.
    /// - Creates its metrics entry.
    ///
    /// The subgraph is not lazy and has no enclosing loop (depth 0).
    ///
    /// On every run, the handoffs are downcast to `R`/`W` and passed as slices. This panics
    /// with "Attempted to cast handoff to wrong type." if a handoff has a different type.
    pub fn add_subgraph_stratified_n_m<Name, R, W, Func>(
        &mut self,
        name: Name,
        stratum: usize,
        recv_ports: Vec<RecvPort<R>>,
        send_ports: Vec<SendPort<W>>,
        mut subgraph: Func,
    ) -> SubgraphId
    where
        Name: Into<Cow<'static, str>>,
        R: 'static + Handoff,
        W: 'static + Handoff,
        Func: 'a
            + for<'ctx> AsyncFnMut(
                &'ctx mut Context,
                &'ctx [&'ctx RecvCtx<R>],
                &'ctx [&'ctx SendCtx<W>],
            ),
    {
        let sg_id = self.subgraphs.insert_with_key(|sg_id| {
            let subgraph_preds = recv_ports.iter().map(|port| port.handoff_id).collect();
            let subgraph_succs = send_ports.iter().map(|port| port.handoff_id).collect();

            // Wire the handoff side of the graph metadata.
            for recv_port in recv_ports.iter() {
                self.handoffs[recv_port.handoff_id].succs.push(sg_id);
            }
            for send_port in send_ports.iter() {
                self.handoffs[send_port.handoff_id].preds.push(sg_id);
            }

            // Type-erased wrapper: rebuilds the typed port slices on every run.
            let subgraph =
                async move |context: &mut Context,
                            handoffs: &mut SlotVec<HandoffTag, HandoffData>| {
                    // Note: the first `hid` below is a port; it is mapped to its handoff id.
                    let recvs: Vec<&RecvCtx<R>> = recv_ports
                        .iter()
                        .map(|hid| hid.handoff_id)
                        .map(|hid| handoffs.get(hid).unwrap())
                        .map(|h_data| {
                            <dyn Any>::downcast_ref(&*h_data.handoff)
                                .expect("Attempted to cast handoff to wrong type.")
                        })
                        .map(RefCast::ref_cast)
                        .collect();

                    let sends: Vec<&SendCtx<W>> = send_ports
                        .iter()
                        .map(|hid| hid.handoff_id)
                        .map(|hid| handoffs.get(hid).unwrap())
                        .map(|h_data| {
                            <dyn Any>::downcast_ref(&*h_data.handoff)
                                .expect("Attempted to cast handoff to wrong type.")
                        })
                        .map(RefCast::ref_cast)
                        .collect();

                    (subgraph)(context, &recvs, &sends).await;
                };
            SubgraphData::new(
                name.into(),
                stratum,
                subgraph,
                subgraph_preds,
                subgraph_succs,
                true,
                false,
                None,
                0,
            )
        });
        // Initially scheduled (`is_scheduled = true` above), so it must also be queued.
        self.context.init_stratum(stratum);
        self.context.stratum_queues[stratum].push_back(sg_id);

        Rc::make_mut(&mut self.metrics)
            .subgraph_metrics
            .insert(sg_id, Default::default());

        sg_id
    }
987
988 pub fn make_edge<Name, H>(&mut self, name: Name) -> (SendPort<H>, RecvPort<H>)
990 where
991 Name: Into<Cow<'static, str>>,
992 H: 'static + Handoff,
993 {
994 let handoff = H::default();
996 let handoff_id = self
997 .handoffs
998 .insert_with_key(|hoff_id| HandoffData::new(name.into(), handoff, hoff_id));
999
1000 Rc::make_mut(&mut self.metrics)
1002 .handoff_metrics
1003 .insert(handoff_id, Default::default());
1004
1005 let input_port = SendPort {
1007 handoff_id,
1008 _marker: PhantomData,
1009 };
1010 let output_port = RecvPort {
1011 handoff_id,
1012 _marker: PhantomData,
1013 };
1014 (input_port, output_port)
1015 }
1016
1017 pub fn add_state<T>(&mut self, state: T) -> StateHandle<T>
1022 where
1023 T: Any,
1024 {
1025 self.context.add_state(state)
1026 }
1027
1028 pub fn set_state_lifespan_hook<T>(
1032 &mut self,
1033 handle: StateHandle<T>,
1034 lifespan: StateLifespan,
1035 hook_fn: impl 'static + FnMut(&mut T),
1036 ) where
1037 T: Any,
1038 {
1039 self.context
1040 .set_state_lifespan_hook(handle, lifespan, hook_fn)
1041 }
1042
1043 pub fn context_mut(&mut self, sg_id: SubgraphId) -> &mut Context {
1045 self.context.subgraph_id = sg_id;
1046 &mut self.context
1047 }
1048
1049 pub fn add_loop(&mut self, parent: Option<LoopId>) -> LoopId {
1054 let depth = parent.map_or(0, |p| self.context.loop_depth[p] + 1);
1055 let loop_id = self.context.loop_depth.insert(depth);
1056 self.loop_data.insert(
1057 loop_id,
1058 LoopData {
1059 iter_count: None,
1060 allow_another_iteration: true,
1061 },
1062 );
1063 loop_id
1064 }
1065
1066 pub fn metrics(&self) -> DfirMetrics {
1068 DfirMetrics {
1069 curr: Rc::clone(&self.metrics),
1070 prev: None,
1071 }
1072 }
1073}
1074
1075impl Dfir<'_> {
1076 pub fn request_task<Fut>(&mut self, future: Fut)
1078 where
1079 Fut: Future<Output = ()> + 'static,
1080 {
1081 self.context.request_task(future);
1082 }
1083
1084 pub fn abort_tasks(&mut self) {
1086 self.context.abort_tasks()
1087 }
1088
1089 pub fn join_tasks(&mut self) -> impl use<'_> + Future {
1091 self.context.join_tasks()
1092 }
1093}
1094
/// Polls `fut` exactly once with a no-op waker and returns its output.
///
/// Only suitable for futures that complete without ever returning `Pending`.
///
/// # Panics
///
/// Panics with "Future did not resolve immediately." if the first poll returns `Pending`.
fn run_sync<Fut>(fut: Fut) -> Fut::Output
where
    Fut: Future,
{
    use std::task::{Context as TaskContext, Poll, Waker};

    let mut pinned = std::pin::pin!(fut);
    let mut task_cx = TaskContext::from_waker(Waker::noop());
    let Poll::Ready(output) = pinned.as_mut().poll(&mut task_cx) else {
        panic!("Future did not resolve immediately.");
    };
    output
}
1106
1107impl Drop for Dfir<'_> {
1108 fn drop(&mut self) {
1109 self.abort_tasks();
1110 }
1111}
1112
/// Per-handoff graph metadata, stored in `Dfir`'s handoff slot-vec.
///
/// NOTE: Rust drops fields in declaration order.
#[doc(hidden)]
pub struct HandoffData {
    /// Handoff name; tees are named `"{root name} tee {id:?}"`.
    pub(super) name: Cow<'static, str>,
    /// The type-erased handoff; callers downcast it through `dyn Any` to the concrete type.
    pub(super) handoff: Box<dyn HandoffMeta>,
    /// Subgraphs that send into this handoff.
    pub(super) preds: SmallVec<[SubgraphId; 1]>,
    /// Subgraphs that receive from this handoff; they are scheduled after a predecessor
    /// runs while this handoff is non-empty.
    pub(super) succs: SmallVec<[SubgraphId; 1]>,

    /// Initialized to `[own id]` by `HandoffData::new`. For a tee, `pred_handoffs[0]` is
    /// the tee root.
    pub(super) pred_handoffs: Vec<HandoffId>,
    /// Initialized to `[own id]` by `HandoffData::new`. A tee root also lists every live
    /// tee created from it.
    pub(super) succ_handoffs: Vec<HandoffId>,
}
1142
1143impl std::fmt::Debug for HandoffData {
1144 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
1145 f.debug_struct("HandoffData")
1146 .field("preds", &self.preds)
1147 .field("succs", &self.succs)
1148 .finish_non_exhaustive()
1149 }
1150}
1151
1152impl HandoffData {
1153 pub fn new(
1155 name: Cow<'static, str>,
1156 handoff: impl 'static + HandoffMeta,
1157 hoff_id: HandoffId,
1158 ) -> Self {
1159 let (preds, succs) = Default::default();
1160 Self {
1161 name,
1162 handoff: Box::new(handoff),
1163 preds,
1164 succs,
1165 pred_handoffs: vec![hoff_id],
1166 succ_handoffs: vec![hoff_id],
1167 }
1168 }
1169}
1170
/// Scheduler bookkeeping for one subgraph, stored in `Dfir`'s subgraph slot-vec.
pub(super) struct SubgraphData<'a> {
    /// Subgraph name (used in tracing spans/events).
    pub(super) name: Cow<'static, str>,
    /// Index of the stratum queue this subgraph is pushed onto when scheduled.
    pub(super) stratum: usize,
    /// The type-erased subgraph body, run with the context and the handoff slot-vec.
    subgraph: Box<dyn 'a + Subgraph>,

    /// Ids of handoffs this subgraph reads from.
    preds: Vec<HandoffId>,
    /// Ids of handoffs this subgraph writes to; readers are scheduled after it runs.
    succs: Vec<HandoffId>,

    /// Whether the subgraph is currently queued; set when pushed onto a stratum queue and
    /// cleared when popped in `run_stratum`.
    is_scheduled: Cell<bool>,

    /// Tick in which the subgraph last ran (`None` if it never ran); used to compute
    /// `is_first_run_this_tick`.
    last_tick_run_in: Option<TickInstant>,
    /// `(loop nonce, iteration count)` recorded the last time this subgraph ran in a loop.
    last_loop_nonce: (usize, Option<usize>),

    /// When `true`, sending data to an earlier-stratum successor does not set
    /// `can_start_tick`.
    is_lazy: bool,

    /// Loop this subgraph belongs to, if any.
    loop_id: Option<LoopId>,
    /// Loop nesting depth, compared against the loop-nonce stack length in `run_stratum`.
    loop_depth: usize,
}
1208
1209impl<'a> SubgraphData<'a> {
1210 #[expect(clippy::too_many_arguments, reason = "internal use")]
1211 pub(crate) fn new(
1212 name: Cow<'static, str>,
1213 stratum: usize,
1214 subgraph: impl 'a + Subgraph,
1215 preds: Vec<HandoffId>,
1216 succs: Vec<HandoffId>,
1217 is_scheduled: bool,
1218 is_lazy: bool,
1219 loop_id: Option<LoopId>,
1220 loop_depth: usize,
1221 ) -> Self {
1222 Self {
1223 name,
1224 stratum,
1225 subgraph: Box::new(subgraph),
1226 preds,
1227 succs,
1228 is_scheduled: Cell::new(is_scheduled),
1229 last_tick_run_in: None,
1230 last_loop_nonce: (0, None),
1231 is_lazy,
1232 loop_id,
1233 loop_depth,
1234 }
1235 }
1236}
1237
/// Per-loop scheduling state, keyed by `LoopId` in `Dfir::loop_data`.
pub(crate) struct LoopData {
    /// Current iteration count of the loop; `None` before its first iteration.
    iter_count: Option<usize>,
    /// Whether the loop may advance to another iteration. `run_stratum` consumes it with
    /// `mem::take` before advancing, and sets it again when a subgraph requests a reschedule
    /// or another iteration.
    allow_another_iteration: bool,
}
1244
/// When a state lifespan hook (see `set_state_lifespan_hook`) is triggered.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum StateLifespan {
    /// Tied to a subgraph; `run_stratum` calls `run_state_hooks_subgraph` before each run
    /// of that subgraph.
    Subgraph(SubgraphId),
    /// Tied to a loop; `run_stratum` calls `run_state_hooks_loop` when a new execution of
    /// that loop begins.
    Loop(LoopId),
    /// Tied to ticks; `next_stratum` calls `run_state_hooks_tick` at the end of each tick.
    Tick,
    /// NOTE(review): presumably never reset — confirm in `Context`.
    Static,
}