1use std::any::Any;
4use std::borrow::Cow;
5use std::cell::Cell;
6use std::cmp::Ordering;
7use std::future::Future;
8use std::marker::PhantomData;
9
10#[cfg(feature = "meta")]
11use dfir_lang::diagnostic::{Diagnostic, SerdeSpan};
12#[cfg(feature = "meta")]
13use dfir_lang::graph::DfirGraph;
14use ref_cast::RefCast;
15use smallvec::SmallVec;
16use web_time::SystemTime;
17
18use super::context::Context;
19use super::handoff::handoff_list::PortList;
20use super::handoff::{Handoff, HandoffMeta, TeeingHandoff};
21use super::port::{RECV, RecvCtx, RecvPort, SEND, SendCtx, SendPort};
22use super::reactor::Reactor;
23use super::state::StateHandle;
24use super::subgraph::Subgraph;
25use super::{HandoffId, HandoffTag, LoopId, LoopTag, SubgraphId, SubgraphTag};
26use crate::Never;
27use crate::scheduled::ticks::{TickDuration, TickInstant};
28use crate::util::slot_vec::{SecondarySlotVec, SlotVec};
29
/// A DFIR graph: owns its subgraphs and handoffs and schedules/runs the subgraphs stratum by
/// stratum, tick by tick.
#[derive(Default)]
pub struct Dfir<'a> {
    /// All subgraphs in this graph, keyed by `SubgraphId`.
    pub(super) subgraphs: SlotVec<SubgraphTag, SubgraphData<'a>>,

    /// Per-loop iteration state, keyed by `LoopId`.
    pub(super) loop_data: SecondarySlotVec<LoopTag, LoopData>,

    /// Scheduler state (current tick/stratum, stratum queues, event queue, ...); also passed
    /// mutably to each subgraph when it runs.
    pub(super) context: Context,

    /// All handoffs (edges between subgraphs), keyed by `HandoffId`.
    handoffs: SlotVec<HandoffTag, HandoffData>,

    /// Graph metadata, assigned once via `__assign_meta_graph`.
    #[cfg(feature = "meta")]
    meta_graph: Option<DfirGraph>,

    /// Diagnostics, assigned once via `__assign_diagnostics`.
    #[cfg(feature = "meta")]
    diagnostics: Option<Vec<Diagnostic<SerdeSpan>>>,
}
49
50impl Dfir<'_> {
52 pub fn teeing_handoff_tee<T>(
54 &mut self,
55 tee_parent_port: &RecvPort<TeeingHandoff<T>>,
56 ) -> RecvPort<TeeingHandoff<T>>
57 where
58 T: Clone,
59 {
60 let tee_root = self.handoffs[tee_parent_port.handoff_id].pred_handoffs[0];
62
63 let tee_root_data = &mut self.handoffs[tee_root];
65 let tee_root_data_name = tee_root_data.name.clone();
66
67 let teeing_handoff =
69 <dyn Any>::downcast_ref::<TeeingHandoff<T>>(&*tee_root_data.handoff).unwrap();
70 let new_handoff = teeing_handoff.tee();
71
72 let new_hoff_id = self.handoffs.insert_with_key(|new_hoff_id| {
74 let new_name = Cow::Owned(format!("{} tee {:?}", tee_root_data_name, new_hoff_id));
75 let mut new_handoff_data = HandoffData::new(new_name, new_handoff, new_hoff_id);
76 new_handoff_data.pred_handoffs = vec![tee_root];
78 new_handoff_data
79 });
80
81 let tee_root_data = &mut self.handoffs[tee_root];
83 tee_root_data.succ_handoffs.push(new_hoff_id);
84
85 assert!(
88 tee_root_data.preds.len() <= 1,
89 "Tee send side should only have one sender (or none set yet)."
90 );
91 if let Some(&pred_sg_id) = tee_root_data.preds.first() {
92 self.subgraphs[pred_sg_id].succs.push(new_hoff_id);
93 }
94
95 let output_port = RecvPort {
96 handoff_id: new_hoff_id,
97 _marker: PhantomData,
98 };
99 output_port
100 }
101
102 pub fn teeing_handoff_drop<T>(&mut self, tee_port: RecvPort<TeeingHandoff<T>>)
107 where
108 T: Clone,
109 {
110 let data = &self.handoffs[tee_port.handoff_id];
111 let teeing_handoff = <dyn Any>::downcast_ref::<TeeingHandoff<T>>(&*data.handoff).unwrap();
112 teeing_handoff.drop();
113
114 let tee_root = data.pred_handoffs[0];
115 let tee_root_data = &mut self.handoffs[tee_root];
116 tee_root_data
118 .succ_handoffs
119 .retain(|&succ_hoff| succ_hoff != tee_port.handoff_id);
120 assert!(
122 tee_root_data.preds.len() <= 1,
123 "Tee send side should only have one sender (or none set yet)."
124 );
125 if let Some(&pred_sg_id) = tee_root_data.preds.first() {
126 self.subgraphs[pred_sg_id]
127 .succs
128 .retain(|&succ_hoff| succ_hoff != tee_port.handoff_id);
129 }
130 }
131}
132
133impl<'a> Dfir<'a> {
134 pub fn new() -> Self {
136 Default::default()
137 }
138
139 #[doc(hidden)]
141 pub fn __assign_meta_graph(&mut self, _meta_graph_json: &str) {
142 #[cfg(feature = "meta")]
143 {
144 let mut meta_graph: DfirGraph =
145 serde_json::from_str(_meta_graph_json).expect("Failed to deserialize graph.");
146
147 let mut op_inst_diagnostics = Vec::new();
148 meta_graph.insert_node_op_insts_all(&mut op_inst_diagnostics);
149 assert!(
150 op_inst_diagnostics.is_empty(),
151 "Expected no diagnostics, got: {:#?}",
152 op_inst_diagnostics
153 );
154
155 assert!(self.meta_graph.replace(meta_graph).is_none());
156 }
157 }
158 #[doc(hidden)]
160 pub fn __assign_diagnostics(&mut self, _diagnostics_json: &'static str) {
161 #[cfg(feature = "meta")]
162 {
163 let diagnostics: Vec<Diagnostic<SerdeSpan>> = serde_json::from_str(_diagnostics_json)
164 .expect("Failed to deserialize diagnostics.");
165
166 assert!(self.diagnostics.replace(diagnostics).is_none());
167 }
168 }
169
170 #[cfg(feature = "meta")]
174 #[cfg_attr(docsrs, doc(cfg(feature = "meta")))]
175 pub fn meta_graph(&self) -> Option<&DfirGraph> {
176 self.meta_graph.as_ref()
177 }
178
179 #[cfg(feature = "meta")]
184 #[cfg_attr(docsrs, doc(cfg(feature = "meta")))]
185 pub fn diagnostics(&self) -> Option<&[Diagnostic<SerdeSpan>]> {
186 self.diagnostics.as_deref()
187 }
188
189 pub fn reactor(&self) -> Reactor {
192 Reactor::new(self.context.event_queue_send.clone())
193 }
194
195 pub fn current_tick(&self) -> TickInstant {
197 self.context.current_tick
198 }
199
200 pub fn current_stratum(&self) -> usize {
202 self.context.current_stratum
203 }
204
205 #[tracing::instrument(level = "trace", skip(self), ret)]
211 pub async fn run_tick(&mut self) -> bool {
212 let mut work_done = false;
213 while self.next_stratum(true) {
215 work_done = true;
216 self.run_stratum().await;
218 }
219 work_done
220 }
221
222 #[tracing::instrument(level = "trace", skip(self), ret)]
224 pub fn run_tick_sync(&mut self) -> bool {
225 let mut work_done = false;
226 while self.next_stratum(true) {
228 work_done = true;
229 run_sync(self.run_stratum());
231 }
232 work_done
233 }
234
235 #[tracing::instrument(level = "trace", skip(self), ret)]
243 pub async fn run_available(&mut self) -> bool {
244 let mut work_done = false;
245 while self.next_stratum(false) {
247 work_done = true;
248 self.run_stratum().await;
250
251 tokio::task::yield_now().await;
254 }
255 work_done
256 }
257
258 #[tracing::instrument(level = "trace", skip(self), ret)]
260 pub fn run_available_sync(&mut self) -> bool {
261 let mut work_done = false;
262 while self.next_stratum(false) {
264 work_done = true;
265 run_sync(self.run_stratum());
267 }
268 work_done
269 }
270
271 #[tracing::instrument(level = "trace", skip(self), fields(tick = u64::from(self.context.current_tick), stratum = self.context.current_stratum), ret)]
275 pub async fn run_stratum(&mut self) -> bool {
276 self.context.spawn_tasks();
279
280 let mut work_done = false;
281
282 'pop: while let Some(sg_id) =
283 self.context.stratum_queues[self.context.current_stratum].pop_front()
284 {
285 {
286 let sg_data = &mut self.subgraphs[sg_id];
287 assert!(sg_data.is_scheduled.take());
289
290 let _enter = tracing::info_span!(
291 "run-subgraph",
292 sg_id = sg_id.to_string(),
293 sg_name = &*sg_data.name,
294 sg_depth = sg_data.loop_depth,
295 sg_loop_nonce = sg_data.last_loop_nonce.0,
296 sg_iter_count = sg_data.last_loop_nonce.1,
297 )
298 .entered();
299
300 match sg_data.loop_depth.cmp(&self.context.loop_nonce_stack.len()) {
301 Ordering::Greater => {
302 self.context.loop_nonce += 1;
304 self.context.loop_nonce_stack.push(self.context.loop_nonce);
305 tracing::trace!(loop_nonce = self.context.loop_nonce, "Entered loop.");
306 }
307 Ordering::Less => {
308 self.context.loop_nonce_stack.pop();
310 tracing::trace!("Exited loop.");
311 }
312 Ordering::Equal => {}
313 }
314
315 self.context.subgraph_id = sg_id;
316 self.context.is_first_run_this_tick = sg_data
317 .last_tick_run_in
318 .is_none_or(|last_tick| last_tick < self.context.current_tick);
319
320 if let Some(loop_id) = sg_data.loop_id {
321 let curr_loop_nonce = self.context.loop_nonce_stack.last().copied();
329
330 let LoopData {
331 iter_count: loop_iter_count,
332 allow_another_iteration,
333 } = &mut self.loop_data[loop_id];
334
335 let (prev_loop_nonce, prev_iter_count) = sg_data.last_loop_nonce;
336
337 let (curr_iter_count, new_loop_execution) =
342 if curr_loop_nonce.is_none_or(|nonce| nonce == prev_loop_nonce) {
343 if *loop_iter_count == prev_iter_count {
346 if !std::mem::take(allow_another_iteration) {
348 tracing::debug!(
349 "Loop will not continue to next iteration, skipping."
350 );
351 continue 'pop;
352 }
353 loop_iter_count.map_or((0, true), |n| (n + 1, false))
355 } else {
356 debug_assert!(
358 prev_iter_count < *loop_iter_count,
359 "Expect loop iteration count to be increasing."
360 );
361 (loop_iter_count.unwrap(), false)
362 }
363 } else {
364 (0, false)
366 };
367
368 if new_loop_execution {
369 self.context.run_state_hooks_loop(loop_id);
371 }
372 tracing::debug!("Loop iteration count {}", curr_iter_count);
373
374 *loop_iter_count = Some(curr_iter_count);
375 self.context.loop_iter_count = curr_iter_count;
376 sg_data.last_loop_nonce =
377 (curr_loop_nonce.unwrap_or_default(), Some(curr_iter_count));
378 }
379
380 self.context.run_state_hooks_subgraph(sg_id);
382
383 tracing::info!("Running subgraph.");
384 sg_data.last_tick_run_in = Some(self.context.current_tick);
385 Box::into_pin(sg_data.subgraph.run(&mut self.context, &mut self.handoffs)).await;
386 };
387
388 let sg_data = &self.subgraphs[sg_id];
389 for &handoff_id in sg_data.succs.iter() {
390 let handoff = &self.handoffs[handoff_id];
391 if !handoff.handoff.is_bottom() {
392 for &succ_id in handoff.succs.iter() {
393 let succ_sg_data = &self.subgraphs[succ_id];
394 if succ_sg_data.stratum < self.context.current_stratum && !sg_data.is_lazy {
396 self.context.can_start_tick = true;
397 }
398 if !succ_sg_data.is_scheduled.replace(true) {
400 self.context.stratum_queues[succ_sg_data.stratum].push_back(succ_id);
401 }
402 if 0 < succ_sg_data.loop_depth {
404 self.context
406 .stratum_stack
407 .push(succ_sg_data.loop_depth, succ_sg_data.stratum);
408 }
409 }
410 }
411 }
412
413 let reschedule = self.context.reschedule_loop_block.take();
414 let allow_another = self.context.allow_another_iteration.take();
415
416 if reschedule {
417 self.context.schedule_deferred.push(sg_id);
419 self.context
420 .stratum_stack
421 .push(sg_data.loop_depth, sg_data.stratum);
422 }
423 if (reschedule || allow_another)
424 && let Some(loop_id) = sg_data.loop_id
425 {
426 self.loop_data
427 .get_mut(loop_id)
428 .unwrap()
429 .allow_another_iteration = true;
430 }
431
432 work_done = true;
433 }
434 work_done
435 }
436
437 #[tracing::instrument(level = "trace", skip(self), ret)]
449 pub fn next_stratum(&mut self, current_tick_only: bool) -> bool {
450 tracing::trace!(
451 events_received_tick = self.context.events_received_tick,
452 can_start_tick = self.context.can_start_tick,
453 "Starting `next_stratum` call.",
454 );
455
456 let mut end_stratum = self.context.current_stratum;
458 let mut new_tick_started = false;
459
460 if 0 == self.context.current_stratum {
461 new_tick_started = true;
462
463 tracing::trace!("Starting tick, setting `can_start_tick = false`.");
465 self.context.can_start_tick = false;
466 self.context.current_tick_start = SystemTime::now();
467
468 if !self.context.events_received_tick {
470 self.try_recv_events();
472 }
473 }
474
475 loop {
476 tracing::trace!(
477 tick = u64::from(self.context.current_tick),
478 stratum = self.context.current_stratum,
479 "Looking for work on stratum."
480 );
481 if !self.context.stratum_queues[self.context.current_stratum].is_empty() {
483 tracing::trace!(
484 tick = u64::from(self.context.current_tick),
485 stratum = self.context.current_stratum,
486 "Work found on stratum."
487 );
488 return true;
489 }
490
491 if let Some(next_stratum) = self.context.stratum_stack.pop() {
492 self.context.current_stratum = next_stratum;
493
494 {
496 for sg_id in self.context.schedule_deferred.drain(..) {
497 let sg_data = &self.subgraphs[sg_id];
498 tracing::info!(
499 tick = u64::from(self.context.current_tick),
500 stratum = self.context.current_stratum,
501 sg_id = sg_id.to_string(),
502 sg_name = &*sg_data.name,
503 is_scheduled = sg_data.is_scheduled.get(),
504 "Rescheduling deferred subgraph."
505 );
506 if !sg_data.is_scheduled.replace(true) {
507 self.context.stratum_queues[sg_data.stratum].push_back(sg_id);
508 }
509 }
510 }
511 } else {
512 self.context.current_stratum += 1;
514
515 if self.context.current_stratum >= self.context.stratum_queues.len() {
516 new_tick_started = true;
517
518 tracing::trace!(
519 can_start_tick = self.context.can_start_tick,
520 "End of tick {}, starting tick {}.",
521 self.context.current_tick,
522 self.context.current_tick + TickDuration::SINGLE_TICK,
523 );
524 self.context.run_state_hooks_tick();
525
526 self.context.current_stratum = 0;
527 self.context.current_tick += TickDuration::SINGLE_TICK;
528 self.context.events_received_tick = false;
529
530 if current_tick_only {
531 tracing::trace!(
532 "`current_tick_only` is `true`, returning `false` before receiving events."
533 );
534 return false;
535 } else {
536 self.try_recv_events();
537 if std::mem::replace(&mut self.context.can_start_tick, false) {
538 tracing::trace!(
539 tick = u64::from(self.context.current_tick),
540 "`can_start_tick` is `true`, continuing."
541 );
542 end_stratum = 0;
544 continue;
545 } else {
546 tracing::trace!(
547 "`can_start_tick` is `false`, re-setting `events_received_tick = false`, returning `false`."
548 );
549 self.context.events_received_tick = false;
550 return false;
551 }
552 }
553 }
554 }
555
556 if new_tick_started && end_stratum == self.context.current_stratum {
558 tracing::trace!(
559 "Made full loop around stratum, re-setting `current_stratum = 0`, returning `false`."
560 );
561 self.context.events_received_tick = false;
566 self.context.current_stratum = 0;
567 return false;
568 }
569 }
570 }
571
572 #[tracing::instrument(level = "trace", skip(self), ret)]
576 pub async fn run(&mut self) -> Option<Never> {
577 loop {
578 self.run_available().await;
580 self.recv_events_async().await;
582 }
583 }
584
585 #[tracing::instrument(level = "trace", skip(self), ret)]
587 pub fn run_sync(&mut self) -> Option<Never> {
588 loop {
589 self.run_available_sync();
591 self.recv_events();
593 }
594 }
595
596 #[tracing::instrument(level = "trace", skip(self), fields(events_received_tick = self.context.events_received_tick), ret)]
600 pub fn try_recv_events(&mut self) -> usize {
601 let mut enqueued_count = 0;
602 while let Ok((sg_id, is_external)) = self.context.event_queue_recv.try_recv() {
603 let sg_data = &self.subgraphs[sg_id];
604 tracing::trace!(
605 sg_id = sg_id.to_string(),
606 is_external = is_external,
607 sg_stratum = sg_data.stratum,
608 "Event received."
609 );
610 if !sg_data.is_scheduled.replace(true) {
611 self.context.stratum_queues[sg_data.stratum].push_back(sg_id);
612 enqueued_count += 1;
613 }
614 if is_external {
615 if !self.context.events_received_tick
618 || sg_data.stratum < self.context.current_stratum
619 {
620 tracing::trace!(
621 current_stratum = self.context.current_stratum,
622 sg_stratum = sg_data.stratum,
623 "External event, setting `can_start_tick = true`."
624 );
625 self.context.can_start_tick = true;
626 }
627 }
628 }
629 self.context.events_received_tick = true;
630
631 enqueued_count
632 }
633
634 #[tracing::instrument(level = "trace", skip(self), fields(events_received_tick = self.context.events_received_tick), ret)]
637 pub fn recv_events(&mut self) -> Option<usize> {
638 let mut count = 0;
639 loop {
640 let (sg_id, is_external) = self.context.event_queue_recv.blocking_recv()?;
641 let sg_data = &self.subgraphs[sg_id];
642 tracing::trace!(
643 sg_id = sg_id.to_string(),
644 is_external = is_external,
645 sg_stratum = sg_data.stratum,
646 "Event received."
647 );
648 if !sg_data.is_scheduled.replace(true) {
649 self.context.stratum_queues[sg_data.stratum].push_back(sg_id);
650 count += 1;
651 }
652 if is_external {
653 if !self.context.events_received_tick
656 || sg_data.stratum < self.context.current_stratum
657 {
658 tracing::trace!(
659 current_stratum = self.context.current_stratum,
660 sg_stratum = sg_data.stratum,
661 "External event, setting `can_start_tick = true`."
662 );
663 self.context.can_start_tick = true;
664 }
665 break;
666 }
667 }
668 self.context.events_received_tick = true;
669
670 let extra_count = self.try_recv_events();
672 Some(count + extra_count)
673 }
674
675 #[tracing::instrument(level = "trace", skip(self), fields(events_received_tick = self.context.events_received_tick), ret)]
681 pub async fn recv_events_async(&mut self) -> Option<usize> {
682 let mut count = 0;
683 loop {
684 tracing::trace!("Awaiting events (`event_queue_recv`).");
685 let (sg_id, is_external) = self.context.event_queue_recv.recv().await?;
686 let sg_data = &self.subgraphs[sg_id];
687 tracing::trace!(
688 sg_id = sg_id.to_string(),
689 is_external = is_external,
690 sg_stratum = sg_data.stratum,
691 "Event received."
692 );
693 if !sg_data.is_scheduled.replace(true) {
694 self.context.stratum_queues[sg_data.stratum].push_back(sg_id);
695 count += 1;
696 }
697 if is_external {
698 if !self.context.events_received_tick
701 || sg_data.stratum < self.context.current_stratum
702 {
703 tracing::trace!(
704 current_stratum = self.context.current_stratum,
705 sg_stratum = sg_data.stratum,
706 "External event, setting `can_start_tick = true`."
707 );
708 self.context.can_start_tick = true;
709 }
710 break;
711 }
712 }
713 self.context.events_received_tick = true;
714
715 let extra_count = self.try_recv_events();
717 Some(count + extra_count)
718 }
719
720 pub fn schedule_subgraph(&mut self, sg_id: SubgraphId) -> bool {
722 let sg_data = &self.subgraphs[sg_id];
723 let already_scheduled = sg_data.is_scheduled.replace(true);
724 if !already_scheduled {
725 self.context.stratum_queues[sg_data.stratum].push_back(sg_id);
726 true
727 } else {
728 false
729 }
730 }
731
732 pub fn add_subgraph<Name, R, W, Func>(
734 &mut self,
735 name: Name,
736 recv_ports: R,
737 send_ports: W,
738 subgraph: Func,
739 ) -> SubgraphId
740 where
741 Name: Into<Cow<'static, str>>,
742 R: 'static + PortList<RECV>,
743 W: 'static + PortList<SEND>,
744 Func: 'a + for<'ctx> AsyncFnMut(&'ctx mut Context, R::Ctx<'ctx>, W::Ctx<'ctx>),
745 {
746 self.add_subgraph_stratified(name, 0, recv_ports, send_ports, false, subgraph)
747 }
748
749 pub fn add_subgraph_stratified<Name, R, W, Func>(
753 &mut self,
754 name: Name,
755 stratum: usize,
756 recv_ports: R,
757 send_ports: W,
758 laziness: bool,
759 subgraph: Func,
760 ) -> SubgraphId
761 where
762 Name: Into<Cow<'static, str>>,
763 R: 'static + PortList<RECV>,
764 W: 'static + PortList<SEND>,
765 Func: 'a + for<'ctx> AsyncFnMut(&'ctx mut Context, R::Ctx<'ctx>, W::Ctx<'ctx>),
766 {
767 self.add_subgraph_full(
768 name, stratum, recv_ports, send_ports, laziness, None, subgraph,
769 )
770 }
771
772 #[expect(clippy::too_many_arguments, reason = "Mainly for internal use.")]
774 pub fn add_subgraph_full<Name, R, W, Func>(
775 &mut self,
776 name: Name,
777 stratum: usize,
778 recv_ports: R,
779 send_ports: W,
780 laziness: bool,
781 loop_id: Option<LoopId>,
782 mut subgraph: Func,
783 ) -> SubgraphId
784 where
785 Name: Into<Cow<'static, str>>,
786 R: 'static + PortList<RECV>,
787 W: 'static + PortList<SEND>,
788 Func: 'a + for<'ctx> AsyncFnMut(&'ctx mut Context, R::Ctx<'ctx>, W::Ctx<'ctx>),
789 {
790 recv_ports.assert_is_from(&self.handoffs);
792 send_ports.assert_is_from(&self.handoffs);
793
794 let loop_depth = loop_id
795 .and_then(|loop_id| self.context.loop_depth.get(loop_id))
796 .copied()
797 .unwrap_or(0);
798
799 let sg_id = self.subgraphs.insert_with_key(|sg_id| {
800 let (mut subgraph_preds, mut subgraph_succs) = Default::default();
801 recv_ports.set_graph_meta(&mut self.handoffs, &mut subgraph_preds, sg_id, true);
802 send_ports.set_graph_meta(&mut self.handoffs, &mut subgraph_succs, sg_id, false);
803
804 let subgraph =
805 async move |context: &mut Context,
806 handoffs: &mut SlotVec<HandoffTag, HandoffData>| {
807 let (recv, send) = unsafe {
808 (
812 recv_ports.make_ctx(&*handoffs),
813 send_ports.make_ctx(&*handoffs),
814 )
815 };
816 (subgraph)(context, recv, send).await;
817 };
818 SubgraphData::new(
819 name.into(),
820 stratum,
821 subgraph,
822 subgraph_preds,
823 subgraph_succs,
824 true,
825 laziness,
826 loop_id,
827 loop_depth,
828 )
829 });
830 self.context.init_stratum(stratum);
831 self.context.stratum_queues[stratum].push_back(sg_id);
832
833 sg_id
834 }
835
836 pub fn add_subgraph_n_m<Name, R, W, Func>(
838 &mut self,
839 name: Name,
840 recv_ports: Vec<RecvPort<R>>,
841 send_ports: Vec<SendPort<W>>,
842 subgraph: Func,
843 ) -> SubgraphId
844 where
845 Name: Into<Cow<'static, str>>,
846 R: 'static + Handoff,
847 W: 'static + Handoff,
848 Func: 'a
849 + for<'ctx> AsyncFnMut(
850 &'ctx mut Context,
851 &'ctx [&'ctx RecvCtx<R>],
852 &'ctx [&'ctx SendCtx<W>],
853 ),
854 {
855 self.add_subgraph_stratified_n_m(name, 0, recv_ports, send_ports, subgraph)
856 }
857
858 pub fn add_subgraph_stratified_n_m<Name, R, W, Func>(
860 &mut self,
861 name: Name,
862 stratum: usize,
863 recv_ports: Vec<RecvPort<R>>,
864 send_ports: Vec<SendPort<W>>,
865 mut subgraph: Func,
866 ) -> SubgraphId
867 where
868 Name: Into<Cow<'static, str>>,
869 R: 'static + Handoff,
870 W: 'static + Handoff,
871 Func: 'a
872 + for<'ctx> AsyncFnMut(
873 &'ctx mut Context,
874 &'ctx [&'ctx RecvCtx<R>],
875 &'ctx [&'ctx SendCtx<W>],
876 ),
877 {
878 let sg_id = self.subgraphs.insert_with_key(|sg_id| {
879 let subgraph_preds = recv_ports.iter().map(|port| port.handoff_id).collect();
880 let subgraph_succs = send_ports.iter().map(|port| port.handoff_id).collect();
881
882 for recv_port in recv_ports.iter() {
883 self.handoffs[recv_port.handoff_id].succs.push(sg_id);
884 }
885 for send_port in send_ports.iter() {
886 self.handoffs[send_port.handoff_id].preds.push(sg_id);
887 }
888
889 let subgraph =
890 async move |context: &mut Context,
891 handoffs: &mut SlotVec<HandoffTag, HandoffData>| {
892 let recvs: Vec<&RecvCtx<R>> = recv_ports
893 .iter()
894 .map(|hid| hid.handoff_id)
895 .map(|hid| handoffs.get(hid).unwrap())
896 .map(|h_data| {
897 <dyn Any>::downcast_ref(&*h_data.handoff)
898 .expect("Attempted to cast handoff to wrong type.")
899 })
900 .map(RefCast::ref_cast)
901 .collect();
902
903 let sends: Vec<&SendCtx<W>> = send_ports
904 .iter()
905 .map(|hid| hid.handoff_id)
906 .map(|hid| handoffs.get(hid).unwrap())
907 .map(|h_data| {
908 <dyn Any>::downcast_ref(&*h_data.handoff)
909 .expect("Attempted to cast handoff to wrong type.")
910 })
911 .map(RefCast::ref_cast)
912 .collect();
913
914 (subgraph)(context, &recvs, &sends).await;
915 };
916 SubgraphData::new(
917 name.into(),
918 stratum,
919 subgraph,
920 subgraph_preds,
921 subgraph_succs,
922 true,
923 false,
924 None,
925 0,
926 )
927 });
928
929 self.context.init_stratum(stratum);
930 self.context.stratum_queues[stratum].push_back(sg_id);
931
932 sg_id
933 }
934
935 pub fn make_edge<Name, H>(&mut self, name: Name) -> (SendPort<H>, RecvPort<H>)
937 where
938 Name: Into<Cow<'static, str>>,
939 H: 'static + Handoff,
940 {
941 let handoff = H::default();
943 let handoff_id = self
944 .handoffs
945 .insert_with_key(|hoff_id| HandoffData::new(name.into(), handoff, hoff_id));
946
947 let input_port = SendPort {
949 handoff_id,
950 _marker: PhantomData,
951 };
952 let output_port = RecvPort {
953 handoff_id,
954 _marker: PhantomData,
955 };
956 (input_port, output_port)
957 }
958
959 pub fn add_state<T>(&mut self, state: T) -> StateHandle<T>
964 where
965 T: Any,
966 {
967 self.context.add_state(state)
968 }
969
970 pub fn set_state_lifespan_hook<T>(
974 &mut self,
975 handle: StateHandle<T>,
976 lifespan: StateLifespan,
977 hook_fn: impl 'static + FnMut(&mut T),
978 ) where
979 T: Any,
980 {
981 self.context
982 .set_state_lifespan_hook(handle, lifespan, hook_fn)
983 }
984
985 pub fn context_mut(&mut self, sg_id: SubgraphId) -> &mut Context {
987 self.context.subgraph_id = sg_id;
988 &mut self.context
989 }
990
991 pub fn add_loop(&mut self, parent: Option<LoopId>) -> LoopId {
996 let depth = parent.map_or(0, |p| self.context.loop_depth[p] + 1);
997 let loop_id = self.context.loop_depth.insert(depth);
998 self.loop_data.insert(
999 loop_id,
1000 LoopData {
1001 iter_count: None,
1002 allow_another_iteration: true,
1003 },
1004 );
1005 loop_id
1006 }
1007}
1008
1009impl Dfir<'_> {
1010 pub fn request_task<Fut>(&mut self, future: Fut)
1012 where
1013 Fut: Future<Output = ()> + 'static,
1014 {
1015 self.context.request_task(future);
1016 }
1017
1018 pub fn abort_tasks(&mut self) {
1020 self.context.abort_tasks()
1021 }
1022
1023 pub fn join_tasks(&mut self) -> impl use<'_> + Future {
1025 self.context.join_tasks()
1026 }
1027}
1028
/// Drives `fut` with a single poll using a no-op waker and returns its output.
///
/// # Panics
/// If the future is still pending after that first poll.
fn run_sync<Fut>(fut: Fut) -> Fut::Output
where
    Fut: Future,
{
    let waker = std::task::Waker::noop();
    let mut task_cx = std::task::Context::from_waker(waker);
    let mut pinned = std::pin::pin!(fut);
    let std::task::Poll::Ready(output) = pinned.as_mut().poll(&mut task_cx) else {
        panic!("Future did not resolve immediately.");
    };
    output
}
1040
1041impl Drop for Dfir<'_> {
1042 fn drop(&mut self) {
1043 self.abort_tasks();
1044 }
1045}
1046
/// Bookkeeping for a single handoff (an edge between subgraphs) in a `Dfir` graph.
#[doc(hidden)]
pub struct HandoffData {
    /// Human-readable name of the handoff.
    pub(super) name: Cow<'static, str>,
    /// The type-erased handoff; downcast (via `Any`) to its concrete type when used.
    pub(super) handoff: Box<dyn HandoffMeta>,
    /// Subgraphs that send into this handoff.
    pub(super) preds: SmallVec<[SubgraphId; 1]>,
    /// Subgraphs that receive from this handoff. They are scheduled after a sender runs if the
    /// handoff is not `is_bottom()`.
    pub(super) succs: SmallVec<[SubgraphId; 1]>,

    /// Predecessor handoffs. Initialized to contain only this handoff's own id; a tee instead
    /// holds its tee root, so `pred_handoffs[0]` identifies the tee root.
    pub(super) pred_handoffs: Vec<HandoffId>,
    /// Successor handoffs. Initialized to contain only this handoff's own id; a tee root also
    /// lists the tees created from it.
    pub(super) succ_handoffs: Vec<HandoffId>,
}
1076impl std::fmt::Debug for HandoffData {
1077 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
1078 f.debug_struct("HandoffData")
1079 .field("preds", &self.preds)
1080 .field("succs", &self.succs)
1081 .finish_non_exhaustive()
1082 }
1083}
1084impl HandoffData {
1085 pub fn new(
1087 name: Cow<'static, str>,
1088 handoff: impl 'static + HandoffMeta,
1089 hoff_id: HandoffId,
1090 ) -> Self {
1091 let (preds, succs) = Default::default();
1092 Self {
1093 name,
1094 handoff: Box::new(handoff),
1095 preds,
1096 succs,
1097 pred_handoffs: vec![hoff_id],
1098 succ_handoffs: vec![hoff_id],
1099 }
1100 }
1101}
1102
/// Bookkeeping for a single subgraph in a `Dfir` graph.
pub(super) struct SubgraphData<'a> {
    /// Human-readable name, used in tracing output.
    pub(super) name: Cow<'static, str>,
    /// The stratum whose queue this subgraph is pushed into when scheduled.
    pub(super) stratum: usize,
    /// The subgraph's code, run by `Dfir::run_stratum`.
    subgraph: Box<dyn 'a + Subgraph>,

    /// Handoffs this subgraph receives from.
    #[expect(dead_code, reason = "may be useful in the future")]
    preds: Vec<HandoffId>,
    /// Handoffs this subgraph sends into; their receivers are scheduled after this subgraph runs.
    succs: Vec<HandoffId>,

    /// Whether this subgraph is currently in a stratum queue. Set when it is enqueued and
    /// cleared when it is popped, which prevents duplicate enqueueing.
    is_scheduled: Cell<bool>,

    /// The tick this subgraph last ran in, or `None` if it has never run.
    last_tick_run_in: Option<TickInstant>,
    /// `(loop nonce, iteration count)` recorded at this subgraph's last run inside a loop.
    last_loop_nonce: (usize, Option<usize>),

    /// If `true`, sending into a subgraph in an earlier stratum does not set `can_start_tick`.
    is_lazy: bool,

    /// The loop this subgraph belongs to, if any.
    loop_id: Option<LoopId>,
    /// Loop nesting depth (`0` when not in a loop).
    loop_depth: usize,
}
1141impl<'a> SubgraphData<'a> {
1142 #[expect(clippy::too_many_arguments, reason = "internal use")]
1143 pub(crate) fn new(
1144 name: Cow<'static, str>,
1145 stratum: usize,
1146 subgraph: impl 'a + Subgraph,
1147 preds: Vec<HandoffId>,
1148 succs: Vec<HandoffId>,
1149 is_scheduled: bool,
1150 is_lazy: bool,
1151 loop_id: Option<LoopId>,
1152 loop_depth: usize,
1153 ) -> Self {
1154 Self {
1155 name,
1156 stratum,
1157 subgraph: Box::new(subgraph),
1158 preds,
1159 succs,
1160 is_scheduled: Cell::new(is_scheduled),
1161 last_tick_run_in: None,
1162 last_loop_nonce: (0, None),
1163 is_lazy,
1164 loop_id,
1165 loop_depth,
1166 }
1167 }
1168}
1169
/// Per-loop iteration state, stored per `LoopId` in `Dfir::loop_data`.
pub(crate) struct LoopData {
    /// The loop's current iteration count, or `None` before its first iteration has started.
    iter_count: Option<usize>,
    /// Whether the loop may proceed to another iteration. It is consumed (reset to `false`) when
    /// a subgraph would start the next iteration. It is set again when a subgraph requests
    /// rescheduling or another iteration.
    allow_another_iteration: bool,
}
1176
/// The lifespan passed to `Dfir::set_state_lifespan_hook`; presumably it determines when the
/// state's hook runs (NOTE(review): confirm against `Context::set_state_lifespan_hook`).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum StateLifespan {
    /// Tied to the given subgraph (presumably the hook runs before each run of that subgraph).
    Subgraph(SubgraphId),
    /// Tied to the given loop (presumably the hook runs at each new execution of that loop).
    Loop(LoopId),
    /// Presumably the hook runs at the end of each tick.
    Tick,
    /// Presumably the state is never reset.
    Static,
}