1use std::any::Any;
4use std::borrow::Cow;
5use std::cell::Cell;
6use std::cmp::Ordering;
7use std::future::Future;
8use std::marker::PhantomData;
9
10#[cfg(feature = "meta")]
11use dfir_lang::diagnostic::{Diagnostic, SerdeSpan};
12#[cfg(feature = "meta")]
13use dfir_lang::graph::DfirGraph;
14use ref_cast::RefCast;
15use smallvec::SmallVec;
16use web_time::SystemTime;
17
18use super::context::Context;
19use super::handoff::handoff_list::PortList;
20use super::handoff::{Handoff, HandoffMeta, TeeingHandoff};
21use super::port::{RECV, RecvCtx, RecvPort, SEND, SendCtx, SendPort};
22use super::reactor::Reactor;
23use super::state::StateHandle;
24use super::subgraph::Subgraph;
25use super::{HandoffId, HandoffTag, LoopId, LoopTag, SubgraphId, SubgraphTag};
26use crate::Never;
27use crate::scheduled::ticks::{TickDuration, TickInstant};
28use crate::util::slot_vec::{SecondarySlotVec, SlotVec};
29
/// A DFIR graph instance: owns the subgraphs, handoffs, and scheduler state
/// needed to run a compiled dataflow. `'a` bounds the lifetime of the subgraph
/// closures stored inside.
#[derive(Default)]
pub struct Dfir<'a> {
    /// All subgraphs (units of scheduled execution) in this graph.
    pub(super) subgraphs: SlotVec<SubgraphTag, SubgraphData<'a>>,

    /// Per-loop iteration bookkeeping, keyed by loop ID.
    pub(super) loop_data: SecondarySlotVec<LoopTag, LoopData>,

    /// Scheduler/runtime state shared with running subgraphs.
    pub(super) context: Context,

    /// All handoffs (buffers connecting subgraphs) in this graph.
    handoffs: SlotVec<HandoffTag, HandoffData>,

    /// Deserialized meta-graph, set via [`Self::__assign_meta_graph`] (at most once).
    #[cfg(feature = "meta")]
    meta_graph: Option<DfirGraph>,

    /// Deserialized diagnostics, set via [`Self::__assign_diagnostics`] (at most once).
    #[cfg(feature = "meta")]
    diagnostics: Option<Vec<Diagnostic<SerdeSpan>>>,
}
49
50impl Dfir<'_> {
52 pub fn teeing_handoff_tee<T>(
54 &mut self,
55 tee_parent_port: &RecvPort<TeeingHandoff<T>>,
56 ) -> RecvPort<TeeingHandoff<T>>
57 where
58 T: Clone,
59 {
60 let tee_root = self.handoffs[tee_parent_port.handoff_id].pred_handoffs[0];
62
63 let tee_root_data = &mut self.handoffs[tee_root];
65 let tee_root_data_name = tee_root_data.name.clone();
66
67 let teeing_handoff =
69 <dyn Any>::downcast_ref::<TeeingHandoff<T>>(&*tee_root_data.handoff).unwrap();
70 let new_handoff = teeing_handoff.tee();
71
72 let new_hoff_id = self.handoffs.insert_with_key(|new_hoff_id| {
74 let new_name = Cow::Owned(format!("{} tee {:?}", tee_root_data_name, new_hoff_id));
75 let mut new_handoff_data = HandoffData::new(new_name, new_handoff, new_hoff_id);
76 new_handoff_data.pred_handoffs = vec![tee_root];
78 new_handoff_data
79 });
80
81 let tee_root_data = &mut self.handoffs[tee_root];
83 tee_root_data.succ_handoffs.push(new_hoff_id);
84
85 assert!(
88 tee_root_data.preds.len() <= 1,
89 "Tee send side should only have one sender (or none set yet)."
90 );
91 if let Some(&pred_sg_id) = tee_root_data.preds.first() {
92 self.subgraphs[pred_sg_id].succs.push(new_hoff_id);
93 }
94
95 let output_port = RecvPort {
96 handoff_id: new_hoff_id,
97 _marker: PhantomData,
98 };
99 output_port
100 }
101
102 pub fn teeing_handoff_drop<T>(&mut self, tee_port: RecvPort<TeeingHandoff<T>>)
107 where
108 T: Clone,
109 {
110 let data = &self.handoffs[tee_port.handoff_id];
111 let teeing_handoff = <dyn Any>::downcast_ref::<TeeingHandoff<T>>(&*data.handoff).unwrap();
112 teeing_handoff.drop();
113
114 let tee_root = data.pred_handoffs[0];
115 let tee_root_data = &mut self.handoffs[tee_root];
116 tee_root_data
118 .succ_handoffs
119 .retain(|&succ_hoff| succ_hoff != tee_port.handoff_id);
120 assert!(
122 tee_root_data.preds.len() <= 1,
123 "Tee send side should only have one sender (or none set yet)."
124 );
125 if let Some(&pred_sg_id) = tee_root_data.preds.first() {
126 self.subgraphs[pred_sg_id]
127 .succs
128 .retain(|&succ_hoff| succ_hoff != tee_port.handoff_id);
129 }
130 }
131}
132
133impl<'a> Dfir<'a> {
134 pub fn new() -> Self {
136 Default::default()
137 }
138
    /// Assigns the serialized meta-graph (JSON) for this instance. Internal
    /// use by generated code; a no-op unless the `meta` feature is enabled.
    ///
    /// Panics if the JSON fails to deserialize, if inserting operator
    /// instances produces diagnostics, or if a meta-graph was already assigned.
    #[doc(hidden)]
    pub fn __assign_meta_graph(&mut self, _meta_graph_json: &str) {
        #[cfg(feature = "meta")]
        {
            let mut meta_graph: DfirGraph =
                serde_json::from_str(_meta_graph_json).expect("Failed to deserialize graph.");

            let mut op_inst_diagnostics = Vec::new();
            meta_graph.insert_node_op_insts_all(&mut op_inst_diagnostics);
            assert!(
                op_inst_diagnostics.is_empty(),
                "Expected no diagnostics, got: {:#?}",
                op_inst_diagnostics
            );

            // `replace` must return `None`: assigning twice is a bug.
            assert!(self.meta_graph.replace(meta_graph).is_none());
        }
    }
    /// Assigns the serialized compile diagnostics (JSON) for this instance.
    /// Internal use by generated code; a no-op unless the `meta` feature is
    /// enabled. Panics on deserialization failure or double assignment.
    #[doc(hidden)]
    pub fn __assign_diagnostics(&mut self, _diagnostics_json: &'static str) {
        #[cfg(feature = "meta")]
        {
            let diagnostics: Vec<Diagnostic<SerdeSpan>> = serde_json::from_str(_diagnostics_json)
                .expect("Failed to deserialize diagnostics.");

            // `replace` must return `None`: assigning twice is a bug.
            assert!(self.diagnostics.replace(diagnostics).is_none());
        }
    }
169
    /// Returns the meta-graph assigned via [`Self::__assign_meta_graph`], if any.
    #[cfg(feature = "meta")]
    #[cfg_attr(docsrs, doc(cfg(feature = "meta")))]
    pub fn meta_graph(&self) -> Option<&DfirGraph> {
        self.meta_graph.as_ref()
    }
178
    /// Returns the diagnostics assigned via [`Self::__assign_diagnostics`], if any.
    #[cfg(feature = "meta")]
    #[cfg_attr(docsrs, doc(cfg(feature = "meta")))]
    pub fn diagnostics(&self) -> Option<&[Diagnostic<SerdeSpan>]> {
        self.diagnostics.as_deref()
    }
188
    /// Returns a [`Reactor`] handle for scheduling subgraphs from outside the
    /// graph (e.g. from other threads/tasks), backed by a clone of the event
    /// queue sender.
    pub fn reactor(&self) -> Reactor {
        Reactor::new(self.context.event_queue_send.clone())
    }
194
    /// Returns the current tick (local logical clock) of this graph.
    pub fn current_tick(&self) -> TickInstant {
        self.context.current_tick
    }
199
    /// Returns the current stratum number being executed within the tick.
    pub fn current_stratum(&self) -> usize {
        self.context.current_stratum
    }
204
205 #[tracing::instrument(level = "trace", skip(self), ret)]
211 pub async fn run_tick(&mut self) -> bool {
212 let mut work_done = false;
213 while self.next_stratum(true) {
215 work_done = true;
216 self.run_stratum().await;
218 }
219 work_done
220 }
221
    /// Runs all immediately-available work, crossing tick boundaries as needed,
    /// until no stratum has work queued. Returns `true` if any work was done.
    #[tracing::instrument(level = "trace", skip(self), ret)]
    pub async fn run_available(&mut self) -> bool {
        let mut work_done = false;
        // `false`: allow `next_stratum` to roll over into the next tick.
        while self.next_stratum(false) {
            work_done = true;
            self.run_stratum().await;

            // Yield between strata so other tasks on the runtime can progress.
            tokio::task::yield_now().await;
        }
        work_done
    }
244
    /// Runs every subgraph queued on the current stratum to completion.
    /// Returns `true` if at least one subgraph ran.
    ///
    /// Also maintains loop bookkeeping (loop nonces and iteration counts) and
    /// schedules downstream subgraphs whose input handoffs received data.
    #[tracing::instrument(level = "trace", skip(self), fields(tick = u64::from(self.context.current_tick), stratum = self.context.current_stratum), ret)]
    pub async fn run_stratum(&mut self) -> bool {
        // Start any tasks that were requested before/between runs.
        self.context.spawn_tasks();

        let mut work_done = false;

        'pop: while let Some(sg_id) =
            self.context.stratum_queues[self.context.current_stratum].pop_front()
        {
            {
                let sg_data = &mut self.subgraphs[sg_id];
                // A subgraph in the queue must have its scheduled flag set;
                // `take()` clears it for this run.
                assert!(sg_data.is_scheduled.take());

                let _enter = tracing::info_span!(
                    "run-subgraph",
                    sg_id = sg_id.to_string(),
                    sg_name = &*sg_data.name,
                    sg_depth = sg_data.loop_depth,
                    sg_loop_nonce = sg_data.last_loop_nonce.0,
                    sg_iter_count = sg_data.last_loop_nonce.1,
                )
                .entered();

                // Keep the loop-nonce stack in sync with this subgraph's loop
                // depth: push a fresh nonce when entering a deeper loop, pop
                // when leaving one.
                match sg_data.loop_depth.cmp(&self.context.loop_nonce_stack.len()) {
                    Ordering::Greater => {
                        self.context.loop_nonce += 1;
                        self.context.loop_nonce_stack.push(self.context.loop_nonce);
                        tracing::trace!(loop_nonce = self.context.loop_nonce, "Entered loop.");
                    }
                    Ordering::Less => {
                        self.context.loop_nonce_stack.pop();
                        tracing::trace!("Exited loop.");
                    }
                    Ordering::Equal => {}
                }

                self.context.subgraph_id = sg_id;
                // First run this tick iff the subgraph hasn't run yet, or last
                // ran in an earlier tick.
                self.context.is_first_run_this_tick = sg_data
                    .last_tick_run_in
                    .is_none_or(|last_tick| last_tick < self.context.current_tick);

                if let Some(loop_id) = sg_data.loop_id {
                    // Nonce of the current loop execution (if inside a loop).
                    let curr_loop_nonce = self.context.loop_nonce_stack.last().copied();

                    let LoopData {
                        iter_count: loop_iter_count,
                        allow_another_iteration,
                    } = &mut self.loop_data[loop_id];

                    let (prev_loop_nonce, prev_iter_count) = sg_data.last_loop_nonce;

                    // Determine this subgraph's iteration count and whether
                    // this is the start of a brand-new loop execution:
                    // - same loop execution, same count: advance only if
                    //   another iteration was allowed, else skip the subgraph;
                    // - same loop execution, older count: catch up to the
                    //   loop's current count;
                    // - different loop execution: restart at iteration 0.
                    let (curr_iter_count, new_loop_execution) =
                        if curr_loop_nonce.is_none_or(|nonce| nonce == prev_loop_nonce) {
                            if *loop_iter_count == prev_iter_count {
                                if !std::mem::take(allow_another_iteration) {
                                    tracing::debug!(
                                        "Loop will not continue to next iteration, skipping."
                                    );
                                    continue 'pop;
                                }
                                // `None` means the loop has never run: start a
                                // new execution at iteration 0.
                                loop_iter_count.map_or((0, true), |n| (n + 1, false))
                            } else {
                                debug_assert!(
                                    prev_iter_count < *loop_iter_count,
                                    "Expect loop iteration count to be increasing."
                                );
                                (loop_iter_count.unwrap(), false)
                            }
                        } else {
                            (0, false)
                        };

                    if new_loop_execution {
                        // Reset loop-lifetime state at the start of a new execution.
                        self.context.run_state_hooks_loop(loop_id);
                    }
                    tracing::debug!("Loop iteration count {}", curr_iter_count);

                    *loop_iter_count = Some(curr_iter_count);
                    self.context.loop_iter_count = curr_iter_count;
                    sg_data.last_loop_nonce =
                        (curr_loop_nonce.unwrap_or_default(), Some(curr_iter_count));
                }

                self.context.run_state_hooks_subgraph(sg_id);

                tracing::info!("Running subgraph.");
                sg_data.last_tick_run_in = Some(self.context.current_tick);
                // Run the subgraph's (possibly async) body.
                Box::into_pin(sg_data.subgraph.run(&mut self.context, &mut self.handoffs)).await;
            };

            // Schedule successors of any output handoff that received data.
            let sg_data = &self.subgraphs[sg_id];
            for &handoff_id in sg_data.succs.iter() {
                let handoff = &self.handoffs[handoff_id];
                if !handoff.handoff.is_bottom() {
                    for &succ_id in handoff.succs.iter() {
                        let succ_sg_data = &self.subgraphs[succ_id];
                        // Data sent "backwards" to an earlier stratum implies
                        // the next tick can start (unless this subgraph is lazy).
                        if succ_sg_data.stratum < self.context.current_stratum && !sg_data.is_lazy {
                            self.context.can_start_tick = true;
                        }
                        if !succ_sg_data.is_scheduled.replace(true) {
                            self.context.stratum_queues[succ_sg_data.stratum].push_back(succ_id);
                        }
                        // Successors inside loops also note their stratum on
                        // the stratum stack so `next_stratum` revisits it.
                        if 0 < succ_sg_data.loop_depth {
                            self.context
                                .stratum_stack
                                .push(succ_sg_data.loop_depth, succ_sg_data.stratum);
                        }
                    }
                }
            }

            // Flags the subgraph body may have set during its run.
            let reschedule = self.context.reschedule_loop_block.take();
            let allow_another = self.context.allow_another_iteration.take();

            if reschedule {
                // Re-run this subgraph in a later loop iteration.
                self.context.schedule_deferred.push(sg_id);
                self.context
                    .stratum_stack
                    .push(sg_data.loop_depth, sg_data.stratum);
            }
            if (reschedule || allow_another)
                && let Some(loop_id) = sg_data.loop_id
            {
                self.loop_data
                    .get_mut(loop_id)
                    .unwrap()
                    .allow_another_iteration = true;
            }

            work_done = true;
        }
        work_done
    }
410
    /// Advances to the next stratum with queued work, possibly rolling over
    /// into the next tick. Returns `true` if work was found (run it with
    /// [`Self::run_stratum`]), `false` if there is no work available.
    ///
    /// If `current_tick_only` is `true`, returns `false` at the end of the
    /// current tick instead of starting a new one.
    #[tracing::instrument(level = "trace", skip(self), ret)]
    pub fn next_stratum(&mut self, current_tick_only: bool) -> bool {
        tracing::trace!(
            events_received_tick = self.context.events_received_tick,
            can_start_tick = self.context.can_start_tick,
            "Starting `next_stratum` call.",
        );

        // Where we started; used to detect a full wrap-around with no work.
        let mut end_stratum = self.context.current_stratum;
        let mut new_tick_started = false;

        if 0 == self.context.current_stratum {
            new_tick_started = true;

            tracing::trace!("Starting tick, setting `can_start_tick = false`.");
            self.context.can_start_tick = false;
            self.context.current_tick_start = SystemTime::now();

            // Ensure events are ingested at least once per tick.
            if !self.context.events_received_tick {
                self.try_recv_events();
            }
        }

        loop {
            tracing::trace!(
                tick = u64::from(self.context.current_tick),
                stratum = self.context.current_stratum,
                "Looking for work on stratum."
            );
            // Work queued on the current stratum: done.
            if !self.context.stratum_queues[self.context.current_stratum].is_empty() {
                tracing::trace!(
                    tick = u64::from(self.context.current_tick),
                    stratum = self.context.current_stratum,
                    "Work found on stratum."
                );
                return true;
            }

            if let Some(next_stratum) = self.context.stratum_stack.pop() {
                // A loop-related revisit was recorded: jump to that stratum
                // and enqueue any deferred subgraphs.
                self.context.current_stratum = next_stratum;

                {
                    for sg_id in self.context.schedule_deferred.drain(..) {
                        let sg_data = &self.subgraphs[sg_id];
                        tracing::info!(
                            tick = u64::from(self.context.current_tick),
                            stratum = self.context.current_stratum,
                            sg_id = sg_id.to_string(),
                            sg_name = &*sg_data.name,
                            is_scheduled = sg_data.is_scheduled.get(),
                            "Rescheduling deferred subgraph."
                        );
                        if !sg_data.is_scheduled.replace(true) {
                            self.context.stratum_queues[sg_data.stratum].push_back(sg_id);
                        }
                    }
                }
            } else {
                // No revisit requested: move to the next stratum in order.
                self.context.current_stratum += 1;

                // Past the last stratum: the tick is over.
                if self.context.current_stratum >= self.context.stratum_queues.len() {
                    new_tick_started = true;

                    tracing::trace!(
                        can_start_tick = self.context.can_start_tick,
                        "End of tick {}, starting tick {}.",
                        self.context.current_tick,
                        self.context.current_tick + TickDuration::SINGLE_TICK,
                    );
                    // Reset tick-lifetime state, then advance the tick.
                    self.context.run_state_hooks_tick();

                    self.context.current_stratum = 0;
                    self.context.current_tick += TickDuration::SINGLE_TICK;
                    self.context.events_received_tick = false;

                    if current_tick_only {
                        tracing::trace!(
                            "`current_tick_only` is `true`, returning `false` before receiving events."
                        );
                        return false;
                    } else {
                        self.try_recv_events();
                        // Only continue into the new tick if something
                        // (an external event or backwards data) justified it.
                        if std::mem::replace(&mut self.context.can_start_tick, false) {
                            tracing::trace!(
                                tick = u64::from(self.context.current_tick),
                                "`can_start_tick` is `true`, continuing."
                            );
                            end_stratum = 0;
                            continue;
                        } else {
                            tracing::trace!(
                                "`can_start_tick` is `false`, re-setting `events_received_tick = false`, returning `false`."
                            );
                            self.context.events_received_tick = false;
                            return false;
                        }
                    }
                }
            }

            // Wrapped all the way around without finding work: give up.
            if new_tick_started && end_stratum == self.context.current_stratum {
                tracing::trace!(
                    "Made full loop around stratum, re-setting `current_stratum = 0`, returning `false`."
                );
                self.context.events_received_tick = false;
                self.context.current_stratum = 0;
                return false;
            }
        }
    }
545
    /// Runs the graph forever: alternates between draining all available work
    /// and awaiting new external events. Never returns ([`Never`]).
    #[tracing::instrument(level = "trace", skip(self), ret)]
    pub async fn run(&mut self) -> Option<Never> {
        loop {
            self.run_available().await;
            // Blocks (async) until at least one external event arrives.
            self.recv_events_async().await;
        }
    }
558
    /// Non-blockingly drains the event queue, scheduling the named subgraphs.
    /// Returns how many subgraphs were newly enqueued (already-scheduled ones
    /// are not counted).
    ///
    /// Also sets `events_received_tick`, and sets `can_start_tick` when an
    /// external event arrives before this tick's events were ingested or
    /// targets an earlier stratum.
    #[tracing::instrument(level = "trace", skip(self), fields(events_received_tick = self.context.events_received_tick), ret)]
    pub fn try_recv_events(&mut self) -> usize {
        let mut enqueued_count = 0;
        while let Ok((sg_id, is_external)) = self.context.event_queue_recv.try_recv() {
            let sg_data = &self.subgraphs[sg_id];
            tracing::trace!(
                sg_id = sg_id.to_string(),
                is_external = is_external,
                sg_stratum = sg_data.stratum,
                "Event received."
            );
            // Enqueue only if not already scheduled (flag doubles as dedup).
            if !sg_data.is_scheduled.replace(true) {
                self.context.stratum_queues[sg_data.stratum].push_back(sg_id);
                enqueued_count += 1;
            }
            if is_external {
                // An external event justifies starting the next tick if this
                // tick hasn't ingested events yet, or the target stratum has
                // already passed.
                if !self.context.events_received_tick
                    || sg_data.stratum < self.context.current_stratum
                {
                    tracing::trace!(
                        current_stratum = self.context.current_stratum,
                        sg_stratum = sg_data.stratum,
                        "External event, setting `can_start_tick = true`."
                    );
                    self.context.can_start_tick = true;
                }
            }
        }
        self.context.events_received_tick = true;

        enqueued_count
    }
596
    /// Synchronous blocking receive: waits until at least one *external* event
    /// arrives, scheduling every event's subgraph along the way, then drains
    /// any remaining events non-blockingly.
    ///
    /// Returns the number of subgraphs newly enqueued, or `None` if the event
    /// channel has closed.
    #[tracing::instrument(level = "trace", skip(self), fields(events_received_tick = self.context.events_received_tick), ret)]
    pub fn recv_events(&mut self) -> Option<usize> {
        let mut count = 0;
        loop {
            let (sg_id, is_external) = self.context.event_queue_recv.blocking_recv()?;
            let sg_data = &self.subgraphs[sg_id];
            tracing::trace!(
                sg_id = sg_id.to_string(),
                is_external = is_external,
                sg_stratum = sg_data.stratum,
                "Event received."
            );
            // Enqueue only if not already scheduled.
            if !sg_data.is_scheduled.replace(true) {
                self.context.stratum_queues[sg_data.stratum].push_back(sg_id);
                count += 1;
            }
            if is_external {
                // Same `can_start_tick` rule as `try_recv_events`.
                if !self.context.events_received_tick
                    || sg_data.stratum < self.context.current_stratum
                {
                    tracing::trace!(
                        current_stratum = self.context.current_stratum,
                        sg_stratum = sg_data.stratum,
                        "External event, setting `can_start_tick = true`."
                    );
                    self.context.can_start_tick = true;
                }
                // Stop blocking once an external event has arrived.
                break;
            }
        }
        self.context.events_received_tick = true;

        // Drain whatever else is already queued without blocking.
        let extra_count = self.try_recv_events();
        Some(count + extra_count)
    }
637
    /// Async counterpart of [`Self::recv_events`]: awaits until at least one
    /// *external* event arrives, scheduling every event's subgraph along the
    /// way, then drains remaining events non-blockingly.
    ///
    /// Returns the number of subgraphs newly enqueued, or `None` if the event
    /// channel has closed.
    #[tracing::instrument(level = "trace", skip(self), fields(events_received_tick = self.context.events_received_tick), ret)]
    pub async fn recv_events_async(&mut self) -> Option<usize> {
        let mut count = 0;
        loop {
            tracing::trace!("Awaiting events (`event_queue_recv`).");
            let (sg_id, is_external) = self.context.event_queue_recv.recv().await?;
            let sg_data = &self.subgraphs[sg_id];
            tracing::trace!(
                sg_id = sg_id.to_string(),
                is_external = is_external,
                sg_stratum = sg_data.stratum,
                "Event received."
            );
            // Enqueue only if not already scheduled.
            if !sg_data.is_scheduled.replace(true) {
                self.context.stratum_queues[sg_data.stratum].push_back(sg_id);
                count += 1;
            }
            if is_external {
                // Same `can_start_tick` rule as `try_recv_events`.
                if !self.context.events_received_tick
                    || sg_data.stratum < self.context.current_stratum
                {
                    tracing::trace!(
                        current_stratum = self.context.current_stratum,
                        sg_stratum = sg_data.stratum,
                        "External event, setting `can_start_tick = true`."
                    );
                    self.context.can_start_tick = true;
                }
                // Stop awaiting once an external event has arrived.
                break;
            }
        }
        self.context.events_received_tick = true;

        // Drain whatever else is already queued without blocking.
        let extra_count = self.try_recv_events();
        Some(count + extra_count)
    }
682
683 pub fn schedule_subgraph(&mut self, sg_id: SubgraphId) -> bool {
685 let sg_data = &self.subgraphs[sg_id];
686 let already_scheduled = sg_data.is_scheduled.replace(true);
687 if !already_scheduled {
688 self.context.stratum_queues[sg_data.stratum].push_back(sg_id);
689 true
690 } else {
691 false
692 }
693 }
694
    /// Adds a new compiled subgraph on stratum 0, non-lazy, outside any loop.
    /// Convenience wrapper over [`Self::add_subgraph_stratified`].
    pub fn add_subgraph<Name, R, W, Func, Fut>(
        &mut self,
        name: Name,
        recv_ports: R,
        send_ports: W,
        subgraph: Func,
    ) -> SubgraphId
    where
        Name: Into<Cow<'static, str>>,
        R: 'static + PortList<RECV>,
        W: 'static + PortList<SEND>,
        Func: 'a + for<'ctx> FnMut(&'ctx mut Context, R::Ctx<'ctx>, W::Ctx<'ctx>) -> Fut,
        Fut: 'a + Future<Output = ()>,
    {
        self.add_subgraph_stratified(name, 0, recv_ports, send_ports, false, subgraph)
    }
712
    /// Adds a new compiled subgraph on the given `stratum`, with the given
    /// laziness, outside any loop. Wrapper over [`Self::add_subgraph_full`].
    pub fn add_subgraph_stratified<Name, R, W, Func, Fut>(
        &mut self,
        name: Name,
        stratum: usize,
        recv_ports: R,
        send_ports: W,
        laziness: bool,
        subgraph: Func,
    ) -> SubgraphId
    where
        Name: Into<Cow<'static, str>>,
        R: 'static + PortList<RECV>,
        W: 'static + PortList<SEND>,
        Func: 'a + for<'ctx> FnMut(&'ctx mut Context, R::Ctx<'ctx>, W::Ctx<'ctx>) -> Fut,
        Fut: 'a + Future<Output = ()>,
    {
        self.add_subgraph_full(
            name, stratum, recv_ports, send_ports, laziness, None, subgraph,
        )
    }
736
    /// Adds a new compiled subgraph with all options: stratum, laziness, and
    /// optional containing loop. Registers the subgraph with its ports' handoff
    /// metadata and enqueues it for its first run.
    #[expect(clippy::too_many_arguments, reason = "Mainly for internal use.")]
    pub fn add_subgraph_full<Name, R, W, Func, Fut>(
        &mut self,
        name: Name,
        stratum: usize,
        recv_ports: R,
        send_ports: W,
        laziness: bool,
        loop_id: Option<LoopId>,
        mut subgraph: Func,
    ) -> SubgraphId
    where
        Name: Into<Cow<'static, str>>,
        R: 'static + PortList<RECV>,
        W: 'static + PortList<SEND>,
        Func: 'a + for<'ctx> FnMut(&'ctx mut Context, R::Ctx<'ctx>, W::Ctx<'ctx>) -> Fut,
        Fut: 'a + Future<Output = ()>,
    {
        // Panics if any port does not belong to this graph's handoffs.
        recv_ports.assert_is_from(&self.handoffs);
        send_ports.assert_is_from(&self.handoffs);

        // Depth 0 when not in a loop.
        let loop_depth = loop_id
            .and_then(|loop_id| self.context.loop_depth.get(loop_id))
            .copied()
            .unwrap_or(0);

        let sg_id = self.subgraphs.insert_with_key(|sg_id| {
            // Record pred/succ handoff IDs and wire the ports' graph metadata.
            let (mut subgraph_preds, mut subgraph_succs) = Default::default();
            recv_ports.set_graph_meta(&mut self.handoffs, &mut subgraph_preds, sg_id, true);
            send_ports.set_graph_meta(&mut self.handoffs, &mut subgraph_succs, sg_id, false);

            // Adapt the user closure into the type-erased `Subgraph` runner.
            let subgraph =
                move |context: &mut Context, handoffs: &mut SlotVec<HandoffTag, HandoffData>| {
                    // SAFETY (review note): relies on `make_ctx`'s contract;
                    // the ports were checked to come from `handoffs` via
                    // `assert_is_from` above — confirm against `PortList` docs.
                    let (recv, send) = unsafe {
                        (
                            recv_ports.make_ctx(&*handoffs),
                            send_ports.make_ctx(&*handoffs),
                        )
                    };
                    (subgraph)(context, recv, send)
                };
            SubgraphData::new(
                name.into(),
                stratum,
                subgraph,
                subgraph_preds,
                subgraph_succs,
                true,
                laziness,
                loop_id,
                loop_depth,
            )
        });
        // Ensure the stratum queue exists, then schedule the first run.
        self.context.init_stratum(stratum);
        self.context.stratum_queues[stratum].push_back(sg_id);

        sg_id
    }
800
    /// Adds a subgraph with a variable number of input and output ports of
    /// uniform handoff types, on stratum 0. Convenience wrapper over
    /// [`Self::add_subgraph_stratified_n_m`].
    pub fn add_subgraph_n_m<Name, R, W, Func, Fut>(
        &mut self,
        name: Name,
        recv_ports: Vec<RecvPort<R>>,
        send_ports: Vec<SendPort<W>>,
        subgraph: Func,
    ) -> SubgraphId
    where
        Name: Into<Cow<'static, str>>,
        R: 'static + Handoff,
        W: 'static + Handoff,
        Func: 'a
            + for<'ctx> FnMut(
                &'ctx mut Context,
                &'ctx [&'ctx RecvCtx<R>],
                &'ctx [&'ctx SendCtx<W>],
            ) -> Fut,
        Fut: 'a + Future<Output = ()>,
    {
        self.add_subgraph_stratified_n_m(name, 0, recv_ports, send_ports, subgraph)
    }
823
    /// Adds a subgraph with variable numbers of uniform-typed input/output
    /// ports on the given `stratum`. The closure receives slices of receive
    /// and send contexts, rebuilt (via downcast) on each run.
    pub fn add_subgraph_stratified_n_m<Name, R, W, Func, Fut>(
        &mut self,
        name: Name,
        stratum: usize,
        recv_ports: Vec<RecvPort<R>>,
        send_ports: Vec<SendPort<W>>,
        mut subgraph: Func,
    ) -> SubgraphId
    where
        Name: Into<Cow<'static, str>>,
        R: 'static + Handoff,
        W: 'static + Handoff,
        Func: 'a
            + for<'ctx> FnMut(
                &'ctx mut Context,
                &'ctx [&'ctx RecvCtx<R>],
                &'ctx [&'ctx SendCtx<W>],
            ) -> Fut,
        Fut: 'a + Future<Output = ()>,
    {
        let sg_id = self.subgraphs.insert_with_key(|sg_id| {
            let subgraph_preds = recv_ports.iter().map(|port| port.handoff_id).collect();
            let subgraph_succs = send_ports.iter().map(|port| port.handoff_id).collect();

            // Register this subgraph on each connected handoff.
            for recv_port in recv_ports.iter() {
                self.handoffs[recv_port.handoff_id].succs.push(sg_id);
            }
            for send_port in send_ports.iter() {
                self.handoffs[send_port.handoff_id].preds.push(sg_id);
            }

            let subgraph =
                move |context: &mut Context, handoffs: &mut SlotVec<HandoffTag, HandoffData>| {
                    // Rebuild typed receive contexts from the type-erased
                    // handoffs; panics if a handoff has the wrong type.
                    let recvs: Vec<&RecvCtx<R>> = recv_ports
                        .iter()
                        .map(|hid| hid.handoff_id)
                        .map(|hid| handoffs.get(hid).unwrap())
                        .map(|h_data| {
                            <dyn Any>::downcast_ref(&*h_data.handoff)
                                .expect("Attempted to cast handoff to wrong type.")
                        })
                        .map(RefCast::ref_cast)
                        .collect();

                    // Likewise for the send contexts.
                    let sends: Vec<&SendCtx<W>> = send_ports
                        .iter()
                        .map(|hid| hid.handoff_id)
                        .map(|hid| handoffs.get(hid).unwrap())
                        .map(|h_data| {
                            <dyn Any>::downcast_ref(&*h_data.handoff)
                                .expect("Attempted to cast handoff to wrong type.")
                        })
                        .map(RefCast::ref_cast)
                        .collect();

                    (subgraph)(context, &recvs, &sends)
                };
            SubgraphData::new(
                name.into(),
                stratum,
                subgraph,
                subgraph_preds,
                subgraph_succs,
                true,
                false,
                None,
                0,
            )
        });

        // Ensure the stratum queue exists, then schedule the first run.
        self.context.init_stratum(stratum);
        self.context.stratum_queues[stratum].push_back(sg_id);

        sg_id
    }
900
901 pub fn make_edge<Name, H>(&mut self, name: Name) -> (SendPort<H>, RecvPort<H>)
903 where
904 Name: Into<Cow<'static, str>>,
905 H: 'static + Handoff,
906 {
907 let handoff = H::default();
909 let handoff_id = self
910 .handoffs
911 .insert_with_key(|hoff_id| HandoffData::new(name.into(), handoff, hoff_id));
912
913 let input_port = SendPort {
915 handoff_id,
916 _marker: PhantomData,
917 };
918 let output_port = RecvPort {
919 handoff_id,
920 _marker: PhantomData,
921 };
922 (input_port, output_port)
923 }
924
    /// Stores `state` in the graph's context, returning a typed handle for
    /// later access. Delegates to [`Context::add_state`].
    pub fn add_state<T>(&mut self, state: T) -> StateHandle<T>
    where
        T: Any,
    {
        self.context.add_state(state)
    }
935
    /// Registers `hook_fn` to run on the state behind `handle` at the end of
    /// the given `lifespan` (e.g. each tick or loop execution). Delegates to
    /// [`Context::set_state_lifespan_hook`].
    pub fn set_state_lifespan_hook<T>(
        &mut self,
        handle: StateHandle<T>,
        lifespan: StateLifespan,
        hook_fn: impl 'static + FnMut(&mut T),
    ) where
        T: Any,
    {
        self.context
            .set_state_lifespan_hook(handle, lifespan, hook_fn)
    }
950
    /// Returns mutable access to the context, with `subgraph_id` set to
    /// `sg_id` so context operations attribute to that subgraph.
    pub fn context_mut(&mut self, sg_id: SubgraphId) -> &mut Context {
        self.context.subgraph_id = sg_id;
        &mut self.context
    }
956
957 pub fn add_loop(&mut self, parent: Option<LoopId>) -> LoopId {
962 let depth = parent.map_or(0, |p| self.context.loop_depth[p] + 1);
963 let loop_id = self.context.loop_depth.insert(depth);
964 self.loop_data.insert(
965 loop_id,
966 LoopData {
967 iter_count: None,
968 allow_another_iteration: true,
969 },
970 );
971 loop_id
972 }
973}
974
/// Task-management helpers, all delegating to the [`Context`].
impl Dfir<'_> {
    /// Requests that `future` be spawned as a task the next time the graph
    /// runs (tasks are spawned in `run_stratum`).
    pub fn request_task<Fut>(&mut self, future: Fut)
    where
        Fut: Future<Output = ()> + 'static,
    {
        self.context.request_task(future);
    }

    /// Aborts all tasks spawned via [`Self::request_task`].
    pub fn abort_tasks(&mut self) {
        self.context.abort_tasks()
    }

    /// Returns a future that joins (awaits completion of) all spawned tasks.
    pub fn join_tasks(&mut self) -> impl use<'_> + Future {
        self.context.join_tasks()
    }
}
994
/// Generates a synchronous wrapper method `$synch` around the async method
/// `$asynch` (returning `$ret`), by polling the future manually.
///
/// Panics if the future pends without having woken its waker — i.e. it is
/// genuinely waiting on an external event and cannot complete synchronously.
macro_rules! make_sync {
    ($synch:ident, $asynch:ident, $ret:ty) => {
        #[doc = concat!("A synchronous wrapper around [`Self::", stringify!($asynch), "`].")]
        pub fn $synch(&mut self) -> $ret {
            use std::sync::Arc;

            // Minimal waker that just records whether it was woken.
            #[derive(Default)]
            struct BoolWaker {
                woke: std::sync::atomic::AtomicBool,
            }
            impl BoolWaker {
                fn new() -> Arc<Self> {
                    Arc::new(Self::default())
                }

                fn woke(&self) -> bool {
                    self.woke.load(std::sync::atomic::Ordering::Relaxed)
                }
            }
            impl futures::task::ArcWake for BoolWaker {
                fn wake_by_ref(arc_self: &Arc<Self>) {
                    arc_self.woke.store(true, std::sync::atomic::Ordering::Relaxed);
                }
            }

            let mut fut = std::pin::pin!(self.$asynch());
            loop {
                // Fresh waker per poll so `woke` reflects only this poll.
                let bool_waker = BoolWaker::new();
                let waker = futures::task::waker(Arc::clone(&bool_waker));
                let mut ctx = std::task::Context::from_waker(&waker);
                if let std::task::Poll::Ready(out) = fut.as_mut().poll(&mut ctx) {
                    return out;
                }
                // Pending but woken during the poll: safe to poll again.
                // Pending and NOT woken: the future awaits an external event,
                // which a synchronous wrapper cannot service.
                if !bool_waker.woke() {
                    panic!(
                        "Future has pending work: DFIR graph has an async subgraph which failed to run synchronously."
                    )
                }
            }
        }
    };
}
/// Synchronous wrappers for the async run methods (see [`make_sync`]).
impl Dfir<'_> {
    make_sync!(run_available_sync, run_available, bool);
    make_sync!(run_tick_sync, run_tick, bool);
    make_sync!(run_sync, run, Option<Never>);
}
1045
impl Drop for Dfir<'_> {
    /// Aborts all spawned tasks when the graph is dropped, so they do not
    /// outlive the graph that owns their state.
    fn drop(&mut self) {
        self.abort_tasks();
    }
}
1051
/// Runtime metadata for a single handoff: the type-erased buffer plus its
/// links to subgraphs and related handoffs.
#[doc(hidden)]
pub struct HandoffData {
    /// Debug name (used in the `tee` name format and tracing output).
    pub(super) name: Cow<'static, str>,
    /// The type-erased handoff buffer itself.
    pub(super) handoff: Box<dyn HandoffMeta>,
    /// Subgraph(s) sending into this handoff.
    pub(super) preds: SmallVec<[SubgraphId; 1]>,
    /// Subgraph(s) receiving from this handoff.
    pub(super) succs: SmallVec<[SubgraphId; 1]>,

    /// Predecessor handoffs; initialized to `[self]` by `new`, and for a tee
    /// set to `[root]` (see `teeing_handoff_tee`).
    pub(super) pred_handoffs: Vec<HandoffId>,
    /// Successor handoffs; initialized to `[self]` by `new`, extended with tee
    /// IDs on the tee root.
    pub(super) succ_handoffs: Vec<HandoffId>,
}
1081impl std::fmt::Debug for HandoffData {
1082 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
1083 f.debug_struct("HandoffData")
1084 .field("preds", &self.preds)
1085 .field("succs", &self.succs)
1086 .finish_non_exhaustive()
1087 }
1088}
1089impl HandoffData {
1090 pub fn new(
1092 name: Cow<'static, str>,
1093 handoff: impl 'static + HandoffMeta,
1094 hoff_id: HandoffId,
1095 ) -> Self {
1096 let (preds, succs) = Default::default();
1097 Self {
1098 name,
1099 handoff: Box::new(handoff),
1100 preds,
1101 succs,
1102 pred_handoffs: vec![hoff_id],
1103 succ_handoffs: vec![hoff_id],
1104 }
1105 }
1106}
1107
/// Runtime metadata and state for one subgraph.
pub(super) struct SubgraphData<'a> {
    /// Debug name (shown in tracing spans).
    pub(super) name: Cow<'static, str>,
    /// Stratum this subgraph runs on.
    pub(super) stratum: usize,
    /// The type-erased runnable body.
    subgraph: Box<dyn 'a + Subgraph<'a>>,

    #[expect(dead_code, reason = "may be useful in the future")]
    preds: Vec<HandoffId>,
    /// Output handoffs, checked after each run to schedule successors.
    succs: Vec<HandoffId>,

    /// Whether this subgraph is currently in a stratum queue; doubles as a
    /// dedup flag when scheduling.
    is_scheduled: Cell<bool>,

    /// Tick of the most recent run (`None` if never run); used to compute
    /// `is_first_run_this_tick`.
    last_tick_run_in: Option<TickInstant>,
    /// `(loop_nonce, iter_count)` of the most recent run, for loop bookkeeping
    /// in `run_stratum`.
    last_loop_nonce: (usize, Option<usize>),

    /// Lazy subgraphs do not set `can_start_tick` when sending to earlier strata.
    is_lazy: bool,

    /// Containing loop, if any.
    loop_id: Option<LoopId>,
    /// Nesting depth: 0 = top level, 1 = inside one loop, etc.
    loop_depth: usize,
}
impl<'a> SubgraphData<'a> {
    /// Creates subgraph metadata with no run history (`last_tick_run_in:
    /// None`, `last_loop_nonce: (0, None)`).
    #[expect(clippy::too_many_arguments, reason = "internal use")]
    pub(crate) fn new(
        name: Cow<'static, str>,
        stratum: usize,
        subgraph: impl 'a + Subgraph<'a>,
        preds: Vec<HandoffId>,
        succs: Vec<HandoffId>,
        is_scheduled: bool,
        is_lazy: bool,
        loop_id: Option<LoopId>,
        loop_depth: usize,
    ) -> Self {
        Self {
            name,
            stratum,
            subgraph: Box::new(subgraph),
            preds,
            succs,
            is_scheduled: Cell::new(is_scheduled),
            last_tick_run_in: None,
            last_loop_nonce: (0, None),
            is_lazy,
            loop_id,
            loop_depth,
        }
    }
}
1174
/// Per-loop iteration state, managed by `run_stratum`.
pub(crate) struct LoopData {
    /// Current iteration count; `None` before the loop's first iteration.
    iter_count: Option<usize>,
    /// Whether the loop may proceed to another iteration (cleared via
    /// `mem::take` each time it is consumed).
    allow_another_iteration: bool,
}
1181
/// When a state hook registered via [`Dfir::set_state_lifespan_hook`] fires.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum StateLifespan {
    /// Tied to a particular subgraph (hooks run via `run_state_hooks_subgraph`).
    Subgraph(SubgraphId),
    /// Reset at each new execution of the given loop.
    Loop(LoopId),
    /// Reset at the end of every tick.
    Tick,
    /// Never reset for the lifetime of the graph.
    Static,
}