1use std::any::Any;
4use std::borrow::Cow;
5use std::cell::Cell;
6use std::cmp::Ordering;
7use std::future::Future;
8use std::marker::PhantomData;
9use std::rc::Rc;
10
11#[cfg(feature = "meta")]
12use dfir_lang::diagnostic::{Diagnostic, SerdeSpan};
13#[cfg(feature = "meta")]
14use dfir_lang::graph::DfirGraph;
15use ref_cast::RefCast;
16use smallvec::SmallVec;
17use tracing::Instrument;
18use web_time::SystemTime;
19
20use super::context::Context;
21use super::handoff::handoff_list::PortList;
22use super::handoff::{Handoff, HandoffMeta, TeeingHandoff};
23use super::metrics::{DfirMetrics, DfirMetricsIntervals, InstrumentSubgraph};
24use super::port::{RECV, RecvCtx, RecvPort, SEND, SendCtx, SendPort};
25use super::reactor::Reactor;
26use super::state::StateHandle;
27use super::subgraph::Subgraph;
28use super::ticks::{TickDuration, TickInstant};
29use super::{HandoffId, HandoffTag, LoopId, LoopTag, SubgraphId, SubgraphTag};
30use crate::Never;
31use crate::util::slot_vec::{SecondarySlotVec, SlotVec};
32
/// A DFIR dataflow graph: owns its subgraphs, the handoffs (edges) between them, the runtime
/// [`Context`] used for scheduling, and per-graph metrics.
#[derive(Default)]
pub struct Dfir<'a> {
    /// All subgraphs, keyed by [`SubgraphId`].
    pub(super) subgraphs: SlotVec<SubgraphTag, SubgraphData<'a>>,

    /// Per-loop scheduling state (current iteration count and whether another iteration is
    /// allowed), keyed by [`LoopId`].
    pub(super) loop_data: SecondarySlotVec<LoopTag, LoopData>,

    /// Runtime context: stratum queues, tick/stratum counters, event queue, state and tasks.
    pub(super) context: Context,

    /// All handoffs, keyed by [`HandoffId`].
    pub(super) handoffs: SlotVec<HandoffTag, HandoffData>,

    /// Subgraph and handoff metrics. New entries are registered through `Rc::make_mut`, which
    /// clones the metrics first if a handle returned by [`Dfir::metrics`] is still alive.
    metrics: Rc<DfirMetrics>,

    /// Graph metadata assigned once by `__assign_meta_graph` (`meta` feature only).
    #[cfg(feature = "meta")]
    meta_graph: Option<DfirGraph>,

    /// Diagnostics assigned once by `__assign_diagnostics` (`meta` feature only).
    #[cfg(feature = "meta")]
    diagnostics: Option<Vec<Diagnostic<SerdeSpan>>>,
}
55
56impl Dfir<'_> {
58 pub fn teeing_handoff_tee<T>(
60 &mut self,
61 tee_parent_port: &RecvPort<TeeingHandoff<T>>,
62 ) -> RecvPort<TeeingHandoff<T>>
63 where
64 T: Clone,
65 {
66 let tee_root = self.handoffs[tee_parent_port.handoff_id].pred_handoffs[0];
68
69 let tee_root_data = &mut self.handoffs[tee_root];
71 let tee_root_data_name = tee_root_data.name.clone();
72
73 let teeing_handoff =
75 <dyn Any>::downcast_ref::<TeeingHandoff<T>>(&*tee_root_data.handoff).unwrap();
76 let new_handoff = teeing_handoff.tee();
77
78 let new_hoff_id = self.handoffs.insert_with_key(|new_hoff_id| {
80 let new_name = Cow::Owned(format!("{} tee {:?}", tee_root_data_name, new_hoff_id));
81 let mut new_handoff_data = HandoffData::new(new_name, new_handoff, new_hoff_id);
82 new_handoff_data.pred_handoffs = vec![tee_root];
84 new_handoff_data
85 });
86
87 let tee_root_data = &mut self.handoffs[tee_root];
89 tee_root_data.succ_handoffs.push(new_hoff_id);
90
91 assert!(
94 tee_root_data.preds.len() <= 1,
95 "Tee send side should only have one sender (or none set yet)."
96 );
97 if let Some(&pred_sg_id) = tee_root_data.preds.first() {
98 self.subgraphs[pred_sg_id].succs.push(new_hoff_id);
99 }
100
101 Rc::make_mut(&mut self.metrics)
103 .handoffs
104 .insert(new_hoff_id, Default::default());
105
106 let output_port = RecvPort {
107 handoff_id: new_hoff_id,
108 _marker: PhantomData,
109 };
110 output_port
111 }
112
113 pub fn teeing_handoff_drop<T>(&mut self, tee_port: RecvPort<TeeingHandoff<T>>)
118 where
119 T: Clone,
120 {
121 let data = &self.handoffs[tee_port.handoff_id];
122 let teeing_handoff = <dyn Any>::downcast_ref::<TeeingHandoff<T>>(&*data.handoff).unwrap();
123 teeing_handoff.drop();
124
125 let tee_root = data.pred_handoffs[0];
126 let tee_root_data = &mut self.handoffs[tee_root];
127 tee_root_data
129 .succ_handoffs
130 .retain(|&succ_hoff| succ_hoff != tee_port.handoff_id);
131 assert!(
133 tee_root_data.preds.len() <= 1,
134 "Tee send side should only have one sender (or none set yet)."
135 );
136 if let Some(&pred_sg_id) = tee_root_data.preds.first() {
137 self.subgraphs[pred_sg_id]
138 .succs
139 .retain(|&succ_hoff| succ_hoff != tee_port.handoff_id);
140 }
141 }
142}
143
144impl<'a> Dfir<'a> {
145 pub fn new() -> Self {
147 Default::default()
148 }
149
150 #[doc(hidden)]
152 pub fn __assign_meta_graph(&mut self, _meta_graph_json: &str) {
153 #[cfg(feature = "meta")]
154 {
155 let mut meta_graph: DfirGraph =
156 serde_json::from_str(_meta_graph_json).expect("Failed to deserialize graph.");
157
158 let mut op_inst_diagnostics = Vec::new();
159 meta_graph.insert_node_op_insts_all(&mut op_inst_diagnostics);
160 assert!(
161 op_inst_diagnostics.is_empty(),
162 "Expected no diagnostics, got: {:#?}",
163 op_inst_diagnostics
164 );
165
166 assert!(self.meta_graph.replace(meta_graph).is_none());
167 }
168 }
169 #[doc(hidden)]
171 pub fn __assign_diagnostics(&mut self, _diagnostics_json: &'static str) {
172 #[cfg(feature = "meta")]
173 {
174 let diagnostics: Vec<Diagnostic<SerdeSpan>> = serde_json::from_str(_diagnostics_json)
175 .expect("Failed to deserialize diagnostics.");
176
177 assert!(self.diagnostics.replace(diagnostics).is_none());
178 }
179 }
180
181 #[cfg(feature = "meta")]
185 #[cfg_attr(docsrs, doc(cfg(feature = "meta")))]
186 pub fn meta_graph(&self) -> Option<&DfirGraph> {
187 self.meta_graph.as_ref()
188 }
189
190 #[cfg(feature = "meta")]
195 #[cfg_attr(docsrs, doc(cfg(feature = "meta")))]
196 pub fn diagnostics(&self) -> Option<&[Diagnostic<SerdeSpan>]> {
197 self.diagnostics.as_deref()
198 }
199
200 pub fn reactor(&self) -> Reactor {
203 Reactor::new(self.context.event_queue_send.clone())
204 }
205
206 pub fn current_tick(&self) -> TickInstant {
208 self.context.current_tick
209 }
210
211 pub fn current_stratum(&self) -> usize {
213 self.context.current_stratum
214 }
215
216 #[tracing::instrument(level = "trace", skip(self), ret)]
222 pub async fn run_tick(&mut self) -> bool {
223 let mut work_done = false;
224 while self.next_stratum(true) {
226 work_done = true;
227 self.run_stratum().await;
229 }
230 work_done
231 }
232
233 #[tracing::instrument(level = "trace", skip(self), ret)]
235 pub fn run_tick_sync(&mut self) -> bool {
236 let mut work_done = false;
237 while self.next_stratum(true) {
239 work_done = true;
240 run_sync(self.run_stratum());
242 }
243 work_done
244 }
245
246 #[tracing::instrument(level = "trace", skip(self), ret)]
254 pub async fn run_available(&mut self) -> bool {
255 let mut work_done = false;
256 while self.next_stratum(false) {
258 work_done = true;
259 self.run_stratum().await;
261
262 tokio::task::yield_now().await;
265 }
266 work_done
267 }
268
269 #[tracing::instrument(level = "trace", skip(self), ret)]
271 pub fn run_available_sync(&mut self) -> bool {
272 let mut work_done = false;
273 while self.next_stratum(false) {
275 work_done = true;
276 run_sync(self.run_stratum());
278 }
279 work_done
280 }
281
/// Runs every subgraph queued in the current stratum; returns `true` if at least one ran.
///
/// For each subgraph popped from the current stratum's queue this:
/// 1. records input-handoff item counts in the metrics;
/// 2. adjusts the loop-nonce stack to the subgraph's loop depth;
/// 3. for loop subgraphs, decides the iteration number, skipping the subgraph if its loop may
///    not advance to another iteration;
/// 4. runs state hooks and then the subgraph itself;
/// 5. schedules consumers of any non-empty output handoff;
/// 6. handles reschedule / another-iteration requests left on the context.
#[tracing::instrument(level = "trace", skip(self), fields(tick = u64::from(self.context.current_tick), stratum = self.context.current_stratum), ret)]
pub async fn run_stratum(&mut self) -> bool {
    // NOTE(review): presumably spawns tasks queued via `request_task` — confirm in `Context`.
    self.context.spawn_tasks();

    let mut work_done = false;

    'pop: while let Some(sg_id) =
        self.context.stratum_queues[self.context.current_stratum].pop_front()
    {
        let sg_data = &mut self.subgraphs[sg_id];

        // Record how many items are waiting in each input handoff before the run.
        for &handoff_id in sg_data.preds.iter() {
            let handoff_metrics = &self.metrics.handoffs[handoff_id];
            let handoff_data = &mut self.handoffs[handoff_id];
            let handoff_len = handoff_data.handoff.len();
            handoff_metrics
                .total_items_count
                .update(|x| x + handoff_len);
            handoff_metrics.curr_items_count.set(handoff_len);
        }

        {
            // A queued subgraph must be flagged as scheduled. Clearing the flag lets it be
            // re-enqueued later.
            assert!(sg_data.is_scheduled.take());

            let run_subgraph_span_guard = tracing::info_span!(
                "run-subgraph",
                sg_id = sg_id.to_string(),
                sg_name = &*sg_data.name,
                sg_depth = sg_data.loop_depth,
                sg_loop_nonce = sg_data.last_loop_nonce.0,
                sg_iter_count = sg_data.last_loop_nonce.1,
            )
            .entered();

            // Move the loop-nonce stack one level toward this subgraph's loop depth.
            // Entering a loop pushes a fresh nonce; exiting pops one.
            match sg_data.loop_depth.cmp(&self.context.loop_nonce_stack.len()) {
                Ordering::Greater => {
                    self.context.loop_nonce += 1;
                    self.context.loop_nonce_stack.push(self.context.loop_nonce);
                    tracing::trace!(loop_nonce = self.context.loop_nonce, "Entered loop.");
                }
                Ordering::Less => {
                    self.context.loop_nonce_stack.pop();
                    tracing::trace!("Exited loop.");
                }
                Ordering::Equal => {}
            }

            self.context.subgraph_id = sg_id;
            // First run this tick if it never ran, or last ran in an earlier tick.
            self.context.is_first_run_this_tick = sg_data
                .last_tick_run_in
                .is_none_or(|last_tick| last_tick < self.context.current_tick);

            if let Some(loop_id) = sg_data.loop_id {
                // Nonce of the innermost active loop execution (`None` if the stack is empty).
                let curr_loop_nonce = self.context.loop_nonce_stack.last().copied();

                let LoopData {
                    iter_count: loop_iter_count,
                    allow_another_iteration,
                } = &mut self.loop_data[loop_id];

                let (prev_loop_nonce, prev_iter_count) = sg_data.last_loop_nonce;

                // Decide which iteration this run belongs to, and whether it begins a new
                // execution of the loop.
                let (curr_iter_count, new_loop_execution) =
                    if curr_loop_nonce.is_none_or(|nonce| nonce == prev_loop_nonce) {
                        // Same loop execution as this subgraph's previous run.
                        if *loop_iter_count == prev_iter_count {
                            // This subgraph already ran in the loop's latest iteration.
                            // Running again advances the loop, which must be allowed; the
                            // permission is consumed here.
                            if !std::mem::take(allow_another_iteration) {
                                tracing::debug!(
                                    "Loop will not continue to next iteration, skipping."
                                );
                                // Skipped: not counted as work done.
                                continue 'pop;
                            }
                            // Very first iteration is 0 and starts a new execution;
                            // otherwise advance by one.
                            loop_iter_count.map_or((0, true), |n| (n + 1, false))
                        } else {
                            // The loop's count was already advanced (by another subgraph of
                            // this loop); catch up to it.
                            debug_assert!(
                                prev_iter_count < *loop_iter_count,
                                "Expect loop iteration count to be increasing."
                            );
                            (loop_iter_count.unwrap(), false)
                        }
                    } else {
                        // The enclosing loop execution changed since this subgraph last ran.
                        (0, false)
                    };

                if new_loop_execution {
                    self.context.run_state_hooks_loop(loop_id);
                }
                tracing::debug!("Loop iteration count {}", curr_iter_count);

                *loop_iter_count = Some(curr_iter_count);
                self.context.loop_iter_count = curr_iter_count;
                sg_data.last_loop_nonce =
                    (curr_loop_nonce.unwrap_or_default(), Some(curr_iter_count));
            }

            self.context.run_state_hooks_subgraph(sg_id);

            tracing::info!("Running subgraph.");
            sg_data.last_tick_run_in = Some(self.context.current_tick);

            let sg_metrics = &self.metrics.subgraphs[sg_id];
            let sg_fut =
                Box::into_pin(sg_data.subgraph.run(&mut self.context, &mut self.handoffs));
            let sg_fut = InstrumentSubgraph::new(sg_fut, sg_metrics);
            // Exit the entered guard and attach the span to the future instead, so it is
            // entered on each poll rather than held across the `.await`.
            let sg_fut = sg_fut.instrument(run_subgraph_span_guard.exit());
            let () = sg_fut.await;

            sg_metrics.total_run_count.update(|x| x + 1);
        };

        // Schedule consumers of every output handoff that now contains data.
        let sg_data = &self.subgraphs[sg_id];
        for &handoff_id in sg_data.succs.iter() {
            let handoff_data = &self.handoffs[handoff_id];
            let handoff_len = handoff_data.handoff.len();
            if 0 < handoff_len {
                for &succ_id in handoff_data.succs.iter() {
                    let succ_sg_data = &self.subgraphs[succ_id];
                    // Data sent to an earlier stratum requires another tick (unless the sender
                    // is lazy).
                    if succ_sg_data.stratum < self.context.current_stratum && !sg_data.is_lazy {
                        self.context.can_start_tick = true;
                    }
                    if !succ_sg_data.is_scheduled.replace(true) {
                        self.context.stratum_queues[succ_sg_data.stratum].push_back(succ_id);
                    }
                    // Consumers inside a loop also need their stratum revisited via the stack.
                    if 0 < succ_sg_data.loop_depth {
                        self.context
                            .stratum_stack
                            .push(succ_sg_data.loop_depth, succ_sg_data.stratum);
                    }
                }
            }
            let handoff_metrics = &self.metrics.handoffs[handoff_id];
            handoff_metrics.curr_items_count.set(handoff_len);
        }

        // Consume the per-run request flags the subgraph may have set on the context.
        let reschedule = self.context.reschedule_loop_block.take();
        let allow_another = self.context.allow_another_iteration.take();

        if reschedule {
            // Re-enqueued by `next_stratum` when it next pops a stratum from `stratum_stack`.
            self.context.schedule_deferred.push(sg_id);
            self.context
                .stratum_stack
                .push(sg_data.loop_depth, sg_data.stratum);
        }
        if (reschedule || allow_another)
            && let Some(loop_id) = sg_data.loop_id
        {
            self.loop_data
                .get_mut(loop_id)
                .unwrap()
                .allow_another_iteration = true;
        }

        work_done = true;
    }
    work_done
}
481
/// Advances `current_stratum` to the next stratum with queued work.
///
/// Returns `true` when such work is found; the current stratum is left pointing at it.
/// Otherwise returns `false`.
///
/// Strata pushed onto `stratum_stack` are visited before advancing sequentially. Passing the
/// last stratum ends the tick: tick hooks run and the tick counter increments. Then:
/// - if `current_tick_only` is `true`, returns `false` without receiving events;
/// - otherwise, receives pending events and continues into the new tick only if
///   `can_start_tick` was set.
#[tracing::instrument(level = "trace", skip(self), ret)]
pub fn next_stratum(&mut self, current_tick_only: bool) -> bool {
    tracing::trace!(
        events_received_tick = self.context.events_received_tick,
        can_start_tick = self.context.can_start_tick,
        "Starting `next_stratum` call.",
    );

    // Stratum at which the search counts as having looped all the way around.
    let mut end_stratum = self.context.current_stratum;
    let mut new_tick_started = false;

    if 0 == self.context.current_stratum {
        new_tick_started = true;

        // Beginning a tick: reset `can_start_tick` and record the tick's start time.
        tracing::trace!("Starting tick, setting `can_start_tick = false`.");
        self.context.can_start_tick = false;
        self.context.current_tick_start = SystemTime::now();

        // Receive pending events before running the tick, unless already done this tick.
        if !self.context.events_received_tick {
            self.try_recv_events();
        }
    }

    loop {
        tracing::trace!(
            tick = u64::from(self.context.current_tick),
            stratum = self.context.current_stratum,
            "Looking for work on stratum."
        );
        if !self.context.stratum_queues[self.context.current_stratum].is_empty() {
            tracing::trace!(
                tick = u64::from(self.context.current_tick),
                stratum = self.context.current_stratum,
                "Work found on stratum."
            );
            return true;
        }

        // Strata pushed onto the stack (by `run_stratum`) take priority over moving forward.
        if let Some(next_stratum) = self.context.stratum_stack.pop() {
            self.context.current_stratum = next_stratum;

            // Re-enqueue subgraphs that requested rescheduling in `run_stratum`.
            {
                for sg_id in self.context.schedule_deferred.drain(..) {
                    let sg_data = &self.subgraphs[sg_id];
                    tracing::info!(
                        tick = u64::from(self.context.current_tick),
                        stratum = self.context.current_stratum,
                        sg_id = sg_id.to_string(),
                        sg_name = &*sg_data.name,
                        is_scheduled = sg_data.is_scheduled.get(),
                        "Rescheduling deferred subgraph."
                    );
                    if !sg_data.is_scheduled.replace(true) {
                        self.context.stratum_queues[sg_data.stratum].push_back(sg_id);
                    }
                }
            }
        } else {
            self.context.current_stratum += 1;

            if self.context.current_stratum >= self.context.stratum_queues.len() {
                // Past the last stratum: end this tick and move to the next one.
                new_tick_started = true;

                tracing::trace!(
                    can_start_tick = self.context.can_start_tick,
                    "End of tick {}, starting tick {}.",
                    self.context.current_tick,
                    self.context.current_tick + TickDuration::SINGLE_TICK,
                );
                self.context.run_state_hooks_tick();

                self.context.current_stratum = 0;
                self.context.current_tick += TickDuration::SINGLE_TICK;
                self.context.events_received_tick = false;

                if current_tick_only {
                    tracing::trace!(
                        "`current_tick_only` is `true`, returning `false` before receiving events."
                    );
                    return false;
                } else {
                    self.try_recv_events();
                    // Consume `can_start_tick` (resetting it to `false`).
                    if std::mem::replace(&mut self.context.can_start_tick, false) {
                        tracing::trace!(
                            tick = u64::from(self.context.current_tick),
                            "`can_start_tick` is `true`, continuing."
                        );
                        // Search the whole new tick, starting from stratum 0.
                        end_stratum = 0;
                        continue;
                    } else {
                        tracing::trace!(
                            "`can_start_tick` is `false`, re-setting `events_received_tick = false`, returning `false`."
                        );
                        self.context.events_received_tick = false;
                        return false;
                    }
                }
            }
        }

        // Back at the starting stratum after a tick boundary: no work anywhere.
        if new_tick_started && end_stratum == self.context.current_stratum {
            tracing::trace!(
                "Made full loop around stratum, re-setting `current_stratum = 0`, returning `false`."
            );
            self.context.events_received_tick = false;
            // Reset so the next call starts from the beginning of the tick.
            self.context.current_stratum = 0;
            return false;
        }
    }
}
616
617 #[tracing::instrument(level = "trace", skip(self), ret)]
621 pub async fn run(&mut self) -> Option<Never> {
622 loop {
623 self.run_available().await;
625 self.recv_events_async().await;
627 }
628 }
629
630 #[tracing::instrument(level = "trace", skip(self), ret)]
632 pub fn run_sync(&mut self) -> Option<Never> {
633 loop {
634 self.run_available_sync();
636 self.recv_events();
638 }
639 }
640
641 #[tracing::instrument(level = "trace", skip(self), fields(events_received_tick = self.context.events_received_tick), ret)]
645 pub fn try_recv_events(&mut self) -> usize {
646 let mut enqueued_count = 0;
647 while let Ok((sg_id, is_external)) = self.context.event_queue_recv.try_recv() {
648 let sg_data = &self.subgraphs[sg_id];
649 tracing::trace!(
650 sg_id = sg_id.to_string(),
651 is_external = is_external,
652 sg_stratum = sg_data.stratum,
653 "Event received."
654 );
655 if !sg_data.is_scheduled.replace(true) {
656 self.context.stratum_queues[sg_data.stratum].push_back(sg_id);
657 enqueued_count += 1;
658 }
659 if is_external {
660 if !self.context.events_received_tick
663 || sg_data.stratum < self.context.current_stratum
664 {
665 tracing::trace!(
666 current_stratum = self.context.current_stratum,
667 sg_stratum = sg_data.stratum,
668 "External event, setting `can_start_tick = true`."
669 );
670 self.context.can_start_tick = true;
671 }
672 }
673 }
674 self.context.events_received_tick = true;
675
676 enqueued_count
677 }
678
679 #[tracing::instrument(level = "trace", skip(self), fields(events_received_tick = self.context.events_received_tick), ret)]
682 pub fn recv_events(&mut self) -> Option<usize> {
683 let mut count = 0;
684 loop {
685 let (sg_id, is_external) = self.context.event_queue_recv.blocking_recv()?;
686 let sg_data = &self.subgraphs[sg_id];
687 tracing::trace!(
688 sg_id = sg_id.to_string(),
689 is_external = is_external,
690 sg_stratum = sg_data.stratum,
691 "Event received."
692 );
693 if !sg_data.is_scheduled.replace(true) {
694 self.context.stratum_queues[sg_data.stratum].push_back(sg_id);
695 count += 1;
696 }
697 if is_external {
698 if !self.context.events_received_tick
701 || sg_data.stratum < self.context.current_stratum
702 {
703 tracing::trace!(
704 current_stratum = self.context.current_stratum,
705 sg_stratum = sg_data.stratum,
706 "External event, setting `can_start_tick = true`."
707 );
708 self.context.can_start_tick = true;
709 }
710 break;
711 }
712 }
713 self.context.events_received_tick = true;
714
715 let extra_count = self.try_recv_events();
717 Some(count + extra_count)
718 }
719
720 #[tracing::instrument(level = "trace", skip(self), fields(events_received_tick = self.context.events_received_tick), ret)]
726 pub async fn recv_events_async(&mut self) -> Option<usize> {
727 let mut count = 0;
728 loop {
729 tracing::trace!("Awaiting events (`event_queue_recv`).");
730 let (sg_id, is_external) = self.context.event_queue_recv.recv().await?;
731 let sg_data = &self.subgraphs[sg_id];
732 tracing::trace!(
733 sg_id = sg_id.to_string(),
734 is_external = is_external,
735 sg_stratum = sg_data.stratum,
736 "Event received."
737 );
738 if !sg_data.is_scheduled.replace(true) {
739 self.context.stratum_queues[sg_data.stratum].push_back(sg_id);
740 count += 1;
741 }
742 if is_external {
743 if !self.context.events_received_tick
746 || sg_data.stratum < self.context.current_stratum
747 {
748 tracing::trace!(
749 current_stratum = self.context.current_stratum,
750 sg_stratum = sg_data.stratum,
751 "External event, setting `can_start_tick = true`."
752 );
753 self.context.can_start_tick = true;
754 }
755 break;
756 }
757 }
758 self.context.events_received_tick = true;
759
760 let extra_count = self.try_recv_events();
762 Some(count + extra_count)
763 }
764
765 pub fn schedule_subgraph(&mut self, sg_id: SubgraphId) -> bool {
767 let sg_data = &self.subgraphs[sg_id];
768 let already_scheduled = sg_data.is_scheduled.replace(true);
769 if !already_scheduled {
770 self.context.stratum_queues[sg_data.stratum].push_back(sg_id);
771 true
772 } else {
773 false
774 }
775 }
776
777 pub fn add_subgraph<Name, R, W, Func>(
779 &mut self,
780 name: Name,
781 recv_ports: R,
782 send_ports: W,
783 subgraph: Func,
784 ) -> SubgraphId
785 where
786 Name: Into<Cow<'static, str>>,
787 R: 'static + PortList<RECV>,
788 W: 'static + PortList<SEND>,
789 Func: 'a + for<'ctx> AsyncFnMut(&'ctx mut Context, R::Ctx<'ctx>, W::Ctx<'ctx>),
790 {
791 self.add_subgraph_stratified(name, 0, recv_ports, send_ports, false, subgraph)
792 }
793
794 pub fn add_subgraph_stratified<Name, R, W, Func>(
798 &mut self,
799 name: Name,
800 stratum: usize,
801 recv_ports: R,
802 send_ports: W,
803 laziness: bool,
804 subgraph: Func,
805 ) -> SubgraphId
806 where
807 Name: Into<Cow<'static, str>>,
808 R: 'static + PortList<RECV>,
809 W: 'static + PortList<SEND>,
810 Func: 'a + for<'ctx> AsyncFnMut(&'ctx mut Context, R::Ctx<'ctx>, W::Ctx<'ctx>),
811 {
812 self.add_subgraph_full(
813 name, stratum, recv_ports, send_ports, laziness, None, subgraph,
814 )
815 }
816
/// Adds a subgraph with full control over stratum, laziness and loop membership.
///
/// Setup steps:
/// - The loop depth is read from the context for `loop_id`, or is 0 when `loop_id` is `None`.
/// - Port/handoff wiring is recorded via `set_graph_meta` (NOTE(review): presumably links this
///   subgraph as the consumer of `recv_ports` and producer of `send_ports` — confirm).
/// - The subgraph is created already marked as scheduled and pushed onto `stratum`'s queue.
/// - A metrics entry is registered for it.
///
/// # Panics
/// Via `assert_is_from`, presumably if a port does not belong to this graph's handoffs
/// (TODO confirm).
#[expect(clippy::too_many_arguments, reason = "Mainly for internal use.")]
pub fn add_subgraph_full<Name, R, W, Func>(
    &mut self,
    name: Name,
    stratum: usize,
    recv_ports: R,
    send_ports: W,
    laziness: bool,
    loop_id: Option<LoopId>,
    mut subgraph: Func,
) -> SubgraphId
where
    Name: Into<Cow<'static, str>>,
    R: 'static + PortList<RECV>,
    W: 'static + PortList<SEND>,
    Func: 'a + for<'ctx> AsyncFnMut(&'ctx mut Context, R::Ctx<'ctx>, W::Ctx<'ctx>),
{
    // Validated here; the `unsafe` `make_ctx` calls below rely on this check.
    recv_ports.assert_is_from(&self.handoffs);
    send_ports.assert_is_from(&self.handoffs);

    let loop_depth = loop_id
        .and_then(|loop_id| self.context.loop_depth.get(loop_id))
        .copied()
        .unwrap_or(0);

    let sg_id = self.subgraphs.insert_with_key(|sg_id| {
        let (mut subgraph_preds, mut subgraph_succs) = Default::default();
        recv_ports.set_graph_meta(&mut self.handoffs, &mut subgraph_preds, sg_id, true);
        send_ports.set_graph_meta(&mut self.handoffs, &mut subgraph_succs, sg_id, false);

        // Adapter that builds the typed port contexts on each run, then calls the user body.
        let subgraph =
            async move |context: &mut Context,
                        handoffs: &mut SlotVec<HandoffTag, HandoffData>| {
                let (recv, send) = unsafe {
                    // SAFETY: both port lists were checked with `assert_is_from` against
                    // `self.handoffs` when this subgraph was added (above).
                    // NOTE(review): this also relies on those handoff entries never being
                    // removed or replaced afterwards — confirm `SlotVec` is insert-only.
                    (
                        recv_ports.make_ctx(&*handoffs),
                        send_ports.make_ctx(&*handoffs),
                    )
                };
                (subgraph)(context, recv, send).await;
            };
        SubgraphData::new(
            name.into(),
            stratum,
            subgraph,
            subgraph_preds,
            subgraph_succs,
            true,
            laziness,
            loop_id,
            loop_depth,
        )
    });
    self.context.init_stratum(stratum);
    self.context.stratum_queues[stratum].push_back(sg_id);

    Rc::make_mut(&mut self.metrics)
        .subgraphs
        .insert(sg_id, Default::default());

    sg_id
}
885
886 pub fn add_subgraph_n_m<Name, R, W, Func>(
888 &mut self,
889 name: Name,
890 recv_ports: Vec<RecvPort<R>>,
891 send_ports: Vec<SendPort<W>>,
892 subgraph: Func,
893 ) -> SubgraphId
894 where
895 Name: Into<Cow<'static, str>>,
896 R: 'static + Handoff,
897 W: 'static + Handoff,
898 Func: 'a
899 + for<'ctx> AsyncFnMut(
900 &'ctx mut Context,
901 &'ctx [&'ctx RecvCtx<R>],
902 &'ctx [&'ctx SendCtx<W>],
903 ),
904 {
905 self.add_subgraph_stratified_n_m(name, 0, recv_ports, send_ports, subgraph)
906 }
907
908 pub fn add_subgraph_stratified_n_m<Name, R, W, Func>(
910 &mut self,
911 name: Name,
912 stratum: usize,
913 recv_ports: Vec<RecvPort<R>>,
914 send_ports: Vec<SendPort<W>>,
915 mut subgraph: Func,
916 ) -> SubgraphId
917 where
918 Name: Into<Cow<'static, str>>,
919 R: 'static + Handoff,
920 W: 'static + Handoff,
921 Func: 'a
922 + for<'ctx> AsyncFnMut(
923 &'ctx mut Context,
924 &'ctx [&'ctx RecvCtx<R>],
925 &'ctx [&'ctx SendCtx<W>],
926 ),
927 {
928 let sg_id = self.subgraphs.insert_with_key(|sg_id| {
929 let subgraph_preds = recv_ports.iter().map(|port| port.handoff_id).collect();
930 let subgraph_succs = send_ports.iter().map(|port| port.handoff_id).collect();
931
932 for recv_port in recv_ports.iter() {
933 self.handoffs[recv_port.handoff_id].succs.push(sg_id);
934 }
935 for send_port in send_ports.iter() {
936 self.handoffs[send_port.handoff_id].preds.push(sg_id);
937 }
938
939 let subgraph =
940 async move |context: &mut Context,
941 handoffs: &mut SlotVec<HandoffTag, HandoffData>| {
942 let recvs: Vec<&RecvCtx<R>> = recv_ports
943 .iter()
944 .map(|hid| hid.handoff_id)
945 .map(|hid| handoffs.get(hid).unwrap())
946 .map(|h_data| {
947 <dyn Any>::downcast_ref(&*h_data.handoff)
948 .expect("Attempted to cast handoff to wrong type.")
949 })
950 .map(RefCast::ref_cast)
951 .collect();
952
953 let sends: Vec<&SendCtx<W>> = send_ports
954 .iter()
955 .map(|hid| hid.handoff_id)
956 .map(|hid| handoffs.get(hid).unwrap())
957 .map(|h_data| {
958 <dyn Any>::downcast_ref(&*h_data.handoff)
959 .expect("Attempted to cast handoff to wrong type.")
960 })
961 .map(RefCast::ref_cast)
962 .collect();
963
964 (subgraph)(context, &recvs, &sends).await;
965 };
966 SubgraphData::new(
967 name.into(),
968 stratum,
969 subgraph,
970 subgraph_preds,
971 subgraph_succs,
972 true,
973 false,
974 None,
975 0,
976 )
977 });
978 self.context.init_stratum(stratum);
979 self.context.stratum_queues[stratum].push_back(sg_id);
980
981 Rc::make_mut(&mut self.metrics)
983 .subgraphs
984 .insert(sg_id, Default::default());
985
986 sg_id
987 }
988
989 pub fn make_edge<Name, H>(&mut self, name: Name) -> (SendPort<H>, RecvPort<H>)
991 where
992 Name: Into<Cow<'static, str>>,
993 H: 'static + Handoff,
994 {
995 let handoff = H::default();
997 let handoff_id = self
998 .handoffs
999 .insert_with_key(|hoff_id| HandoffData::new(name.into(), handoff, hoff_id));
1000
1001 Rc::make_mut(&mut self.metrics)
1003 .handoffs
1004 .insert(handoff_id, Default::default());
1005
1006 let input_port = SendPort {
1008 handoff_id,
1009 _marker: PhantomData,
1010 };
1011 let output_port = RecvPort {
1012 handoff_id,
1013 _marker: PhantomData,
1014 };
1015 (input_port, output_port)
1016 }
1017
1018 pub fn add_state<T>(&mut self, state: T) -> StateHandle<T>
1023 where
1024 T: Any,
1025 {
1026 self.context.add_state(state)
1027 }
1028
1029 pub fn set_state_lifespan_hook<T>(
1033 &mut self,
1034 handle: StateHandle<T>,
1035 lifespan: StateLifespan,
1036 hook_fn: impl 'static + FnMut(&mut T),
1037 ) where
1038 T: Any,
1039 {
1040 self.context
1041 .set_state_lifespan_hook(handle, lifespan, hook_fn)
1042 }
1043
1044 pub fn context_mut(&mut self, sg_id: SubgraphId) -> &mut Context {
1046 self.context.subgraph_id = sg_id;
1047 &mut self.context
1048 }
1049
1050 pub fn add_loop(&mut self, parent: Option<LoopId>) -> LoopId {
1055 let depth = parent.map_or(0, |p| self.context.loop_depth[p] + 1);
1056 let loop_id = self.context.loop_depth.insert(depth);
1057 self.loop_data.insert(
1058 loop_id,
1059 LoopData {
1060 iter_count: None,
1061 allow_another_iteration: true,
1062 },
1063 );
1064 loop_id
1065 }
1066
1067 pub fn metrics(&self) -> Rc<DfirMetrics> {
1069 Rc::clone(&self.metrics)
1070 }
1071
1072 pub fn metrics_intervals(&self) -> DfirMetricsIntervals {
1080 DfirMetricsIntervals {
1081 curr: self.metrics(),
1082 prev: None,
1083 }
1084 }
1085}
1086
1087impl Dfir<'_> {
1088 pub fn request_task<Fut>(&mut self, future: Fut)
1090 where
1091 Fut: Future<Output = ()> + 'static,
1092 {
1093 self.context.request_task(future);
1094 }
1095
1096 pub fn abort_tasks(&mut self) {
1098 self.context.abort_tasks()
1099 }
1100
1101 pub fn join_tasks(&mut self) -> impl use<'_> + Future {
1103 self.context.join_tasks()
1104 }
1105}
1106
/// Polls `fut` exactly once with a no-op waker and returns its output.
///
/// # Panics
/// If the future is not ready on that first poll.
fn run_sync<Fut>(fut: Fut) -> Fut::Output
where
    Fut: Future,
{
    let waker = std::task::Waker::noop();
    let mut cx = std::task::Context::from_waker(waker);
    let mut pinned = std::pin::pin!(fut);
    if let std::task::Poll::Ready(output) = pinned.as_mut().poll(&mut cx) {
        output
    } else {
        panic!("Future did not resolve immediately.")
    }
}
1118
1119impl Drop for Dfir<'_> {
1120 fn drop(&mut self) {
1121 self.abort_tasks();
1122 }
1123}
1124
/// A handoff (graph edge) together with its connectivity.
#[doc(hidden)]
pub struct HandoffData {
    /// Handoff name (tees are named `"{root name} tee {id}"`).
    pub(super) name: Cow<'static, str>,
    /// The type-erased handoff; downcast through `dyn Any` to recover the concrete type.
    pub(super) handoff: Box<dyn HandoffMeta>,
    /// Subgraphs that send into this handoff.
    pub(super) preds: SmallVec<[SubgraphId; 1]>,
    /// Subgraphs that receive from this handoff.
    pub(super) succs: SmallVec<[SubgraphId; 1]>,

    /// Tee-root link: initialized to `[self]`. For a tee created by `teeing_handoff_tee` it is
    /// `[root]`; `pred_handoffs[0]` is treated as the tee root.
    pub(super) pred_handoffs: Vec<HandoffId>,
    /// Initialized to `[self]`. On a tee root, tees created from it are appended (and removed
    /// again by `teeing_handoff_drop`).
    pub(super) succ_handoffs: Vec<HandoffId>,
}
1154
1155impl std::fmt::Debug for HandoffData {
1156 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
1157 f.debug_struct("HandoffData")
1158 .field("preds", &self.preds)
1159 .field("succs", &self.succs)
1160 .finish_non_exhaustive()
1161 }
1162}
1163
1164impl HandoffData {
1165 pub fn new(
1167 name: Cow<'static, str>,
1168 handoff: impl 'static + HandoffMeta,
1169 hoff_id: HandoffId,
1170 ) -> Self {
1171 let (preds, succs) = Default::default();
1172 Self {
1173 name,
1174 handoff: Box::new(handoff),
1175 preds,
1176 succs,
1177 pred_handoffs: vec![hoff_id],
1178 succ_handoffs: vec![hoff_id],
1179 }
1180 }
1181}
1182
/// A subgraph body together with its wiring and scheduling state.
pub(super) struct SubgraphData<'a> {
    /// Human-readable name, used in tracing spans and logs.
    pub(super) name: Cow<'static, str>,
    /// Index of the stratum queue this subgraph is pushed onto when scheduled.
    pub(super) stratum: usize,
    /// The subgraph body, run from `run_stratum`.
    subgraph: Box<dyn 'a + Subgraph>,

    /// Input handoffs (their item counts are recorded before each run).
    preds: Vec<HandoffId>,
    /// Output handoffs (consumers of non-empty ones are scheduled after each run).
    succs: Vec<HandoffId>,

    /// Whether this subgraph is currently queued. Prevents double-enqueueing; cleared when it
    /// is popped to run.
    is_scheduled: Cell<bool>,

    /// Tick of the most recent run, or `None` if never run.
    last_tick_run_in: Option<TickInstant>,
    /// `(loop nonce, iteration count)` recorded at the last run inside a loop.
    last_loop_nonce: (usize, Option<usize>),

    /// If `true`, sending data to an earlier stratum does not set `can_start_tick`.
    is_lazy: bool,

    /// Loop this subgraph belongs to, if any.
    loop_id: Option<LoopId>,
    /// Loop nesting depth (0 when not in a loop).
    loop_depth: usize,
}
1220
1221impl<'a> SubgraphData<'a> {
1222 #[expect(clippy::too_many_arguments, reason = "internal use")]
1223 pub(crate) fn new(
1224 name: Cow<'static, str>,
1225 stratum: usize,
1226 subgraph: impl 'a + Subgraph,
1227 preds: Vec<HandoffId>,
1228 succs: Vec<HandoffId>,
1229 is_scheduled: bool,
1230 is_lazy: bool,
1231 loop_id: Option<LoopId>,
1232 loop_depth: usize,
1233 ) -> Self {
1234 Self {
1235 name,
1236 stratum,
1237 subgraph: Box::new(subgraph),
1238 preds,
1239 succs,
1240 is_scheduled: Cell::new(is_scheduled),
1241 last_tick_run_in: None,
1242 last_loop_nonce: (0, None),
1243 is_lazy,
1244 loop_id,
1245 loop_depth,
1246 }
1247 }
1248}
1249
/// Scheduling state for one loop.
pub(crate) struct LoopData {
    /// Latest iteration index reached by the loop, or `None` before its first iteration.
    iter_count: Option<usize>,
    /// Whether a subgraph may advance the loop to another iteration.
    /// - Starts `true`.
    /// - Consumed (reset to `false`) when a subgraph advances the loop.
    /// - Set again when a subgraph requests a reschedule or another iteration.
    allow_another_iteration: bool,
}
1256
/// Lifespan selector passed to `set_state_lifespan_hook`.
///
/// NOTE(review): the variant descriptions below are inferred from the hook call sites in
/// `run_stratum`/`next_stratum` (`run_state_hooks_subgraph`, `run_state_hooks_loop`,
/// `run_state_hooks_tick`) — confirm against `Context`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum StateLifespan {
    /// Presumably: hook runs before each run of the given subgraph.
    Subgraph(SubgraphId),
    /// Presumably: hook runs when a new execution of the given loop begins.
    Loop(LoopId),
    /// Presumably: hook runs at the end of each tick.
    Tick,
    /// Presumably: hook never runs (state lives as long as the graph).
    Static,
}