use std::any::Any;
use std::borrow::Cow;
use std::cell::Cell;
use std::cmp::Ordering;
use std::future::Future;
use std::marker::PhantomData;

#[cfg(feature = "meta")]
use dfir_lang::diagnostic::{Diagnostic, SerdeSpan};
#[cfg(feature = "meta")]
use dfir_lang::graph::DfirGraph;
use ref_cast::RefCast;
use smallvec::SmallVec;
use web_time::SystemTime;

use super::context::Context;
use super::handoff::handoff_list::PortList;
use super::handoff::{Handoff, HandoffMeta, TeeingHandoff};
use super::port::{RECV, RecvCtx, RecvPort, SEND, SendCtx, SendPort};
use super::reactor::Reactor;
use super::state::StateHandle;
use super::subgraph::Subgraph;
use super::{HandoffId, HandoffTag, LoopId, LoopTag, SubgraphId, SubgraphTag};
use crate::Never;
use crate::scheduled::ticks::{TickDuration, TickInstant};
use crate::util::slot_vec::{SecondarySlotVec, SlotVec};

/// A runnable dataflow graph: owns the compiled subgraphs, the handoffs that
/// connect them, and the scheduler state needed to run them.
#[derive(Default)]
pub struct Dfir<'a> {
    /// The scheduled subgraphs (compiled operator chains).
    pub(super) subgraphs: SlotVec<SubgraphTag, SubgraphData<'a>>,

    /// Per-loop scheduling state, keyed by [`LoopId`].
    pub(super) loop_data: SecondarySlotVec<LoopTag, LoopData>,

    /// Scheduler context shared with running subgraphs.
    pub(super) context: Context,

    /// The handoffs (buffers) connecting subgraphs.
    handoffs: SlotVec<HandoffTag, HandoffData>,

    /// Metadata graph, if assigned via [`Self::__assign_meta_graph`].
    #[cfg(feature = "meta")]
    meta_graph: Option<DfirGraph>,

    /// Compile-time diagnostics, if assigned via [`Self::__assign_diagnostics`].
    #[cfg(feature = "meta")]
    diagnostics: Option<Vec<Diagnostic<SerdeSpan>>>,
}

impl Dfir<'_> {
    /// Creates a new tee output from an existing [`TeeingHandoff`] receive port,
    /// returning the new receive port.
    pub fn teeing_handoff_tee<T>(
        &mut self,
        tee_parent_port: &RecvPort<TeeingHandoff<T>>,
    ) -> RecvPort<TeeingHandoff<T>>
    where
        T: Clone,
    {
        // Resolve the root handoff of the tee: all tees share the same first predecessor.
        let tee_root = self.handoffs[tee_parent_port.handoff_id].pred_handoffs[0];

        let tee_root_data = &mut self.handoffs[tee_root];
        let tee_root_data_name = tee_root_data.name.clone();

        let teeing_handoff =
            <dyn Any>::downcast_ref::<TeeingHandoff<T>>(&*tee_root_data.handoff).unwrap();
        let new_handoff = teeing_handoff.tee();

        let new_hoff_id = self.handoffs.insert_with_key(|new_hoff_id| {
            let new_name = Cow::Owned(format!("{} tee {:?}", tee_root_data_name, new_hoff_id));
            let mut new_handoff_data = HandoffData::new(new_name, new_handoff, new_hoff_id);
            new_handoff_data.pred_handoffs = vec![tee_root];
            new_handoff_data
        });

        let tee_root_data = &mut self.handoffs[tee_root];
        tee_root_data.succ_handoffs.push(new_hoff_id);

        assert!(
            tee_root_data.preds.len() <= 1,
            "Tee send side should only have one sender (or none set yet)."
        );
        if let Some(&pred_sg_id) = tee_root_data.preds.first() {
            self.subgraphs[pred_sg_id].succs.push(new_hoff_id);
        }

        let output_port = RecvPort {
            handoff_id: new_hoff_id,
            _marker: PhantomData,
        };
        output_port
    }
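
    // Usage sketch (not from the original source): tee a receive port so two
    // subgraphs can read the same stream. Assumes `TeeingHandoff<T>` implements
    // `Handoff` (and `Default`) so it can be created via `make_edge`.
    //
    //     let mut df = Dfir::new();
    //     let (_send, recv_a) = df.make_edge::<_, TeeingHandoff<usize>>("tee edge");
    //     let recv_b = df.teeing_handoff_tee(&recv_a);
    //     // `recv_a` and `recv_b` can now feed two independent subgraphs.
    //     // A tee that is no longer needed can be released with
    //     // `df.teeing_handoff_drop(recv_b)`.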

    /// Marks an output of a [`TeeingHandoff`] as dropped so that no more data is
    /// sent to it, and unlinks it from its sending subgraph.
    pub fn teeing_handoff_drop<T>(&mut self, tee_port: RecvPort<TeeingHandoff<T>>)
    where
        T: Clone,
    {
        let data = &self.handoffs[tee_port.handoff_id];
        let teeing_handoff = <dyn Any>::downcast_ref::<TeeingHandoff<T>>(&*data.handoff).unwrap();
        teeing_handoff.drop();

        let tee_root = data.pred_handoffs[0];
        let tee_root_data = &mut self.handoffs[tee_root];
        tee_root_data
            .succ_handoffs
            .retain(|&succ_hoff| succ_hoff != tee_port.handoff_id);
        assert!(
            tee_root_data.preds.len() <= 1,
            "Tee send side should only have one sender (or none set yet)."
        );
        if let Some(&pred_sg_id) = tee_root_data.preds.first() {
            self.subgraphs[pred_sg_id]
                .succs
                .retain(|&succ_hoff| succ_hoff != tee_port.handoff_id);
        }
    }
}

impl<'a> Dfir<'a> {
    /// Creates a new, empty graph.
    pub fn new() -> Self {
        Default::default()
    }

    /// Assigns the (JSON-serialized) meta graph describing this instance.
    /// Only has an effect when the `meta` feature is enabled.
    #[doc(hidden)]
    pub fn __assign_meta_graph(&mut self, _meta_graph_json: &str) {
        #[cfg(feature = "meta")]
        {
            let mut meta_graph: DfirGraph =
                serde_json::from_str(_meta_graph_json).expect("Failed to deserialize graph.");

            let mut op_inst_diagnostics = Vec::new();
            meta_graph.insert_node_op_insts_all(&mut op_inst_diagnostics);
            assert!(
                op_inst_diagnostics.is_empty(),
                "Expected no diagnostics, got: {:#?}",
                op_inst_diagnostics
            );

            assert!(self.meta_graph.replace(meta_graph).is_none());
        }
    }

    /// Assigns the (JSON-serialized) compile diagnostics for this instance.
    /// Only has an effect when the `meta` feature is enabled.
    #[doc(hidden)]
    pub fn __assign_diagnostics(&mut self, _diagnostics_json: &'static str) {
        #[cfg(feature = "meta")]
        {
            let diagnostics: Vec<Diagnostic<SerdeSpan>> = serde_json::from_str(_diagnostics_json)
                .expect("Failed to deserialize diagnostics.");

            assert!(self.diagnostics.replace(diagnostics).is_none());
        }
    }

    /// Returns the meta graph, if assigned.
    #[cfg(feature = "meta")]
    #[cfg_attr(docsrs, doc(cfg(feature = "meta")))]
    pub fn meta_graph(&self) -> Option<&DfirGraph> {
        self.meta_graph.as_ref()
    }

    /// Returns the diagnostics, if assigned.
    #[cfg(feature = "meta")]
    #[cfg_attr(docsrs, doc(cfg(feature = "meta")))]
    pub fn diagnostics(&self) -> Option<&[Diagnostic<SerdeSpan>]> {
        self.diagnostics.as_deref()
    }

    /// Returns a [`Reactor`] for externally scheduling subgraphs, e.g. from another thread.
    pub fn reactor(&self) -> Reactor {
        Reactor::new(self.context.event_queue_send.clone())
    }

    /// Returns the current tick (local time) of execution.
    pub fn current_tick(&self) -> TickInstant {
        self.context.current_tick
    }

    /// Returns the current stratum of execution.
    pub fn current_stratum(&self) -> usize {
        self.context.current_stratum
    }

    /// Runs the dataflow until the current tick is complete.
    /// Returns `true` if any work was done.
    #[tracing::instrument(level = "trace", skip(self), ret)]
    pub fn run_tick(&mut self) -> bool {
        let mut work_done = false;
        // `true` to restrict `next_stratum` to the current tick.
        while self.next_stratum(true) {
            work_done = true;
            self.run_stratum();
        }
        work_done
    }

    /// Runs the dataflow until no more work is immediately available, possibly
    /// crossing tick boundaries. Returns `true` if any work was done.
    #[tracing::instrument(level = "trace", skip(self), ret)]
    pub fn run_available(&mut self) -> bool {
        let mut work_done = false;
        // `false` to allow continuing into the next tick.
        while self.next_stratum(false) {
            work_done = true;
            self.run_stratum();
        }
        work_done
    }
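
    // Usage sketch (not from the original source): a typical synchronous drive
    // loop alternates `run_available` with a blocking wait for new events.
    //
    //     let mut df = Dfir::new();
    //     // ... wire up handoffs and subgraphs ...
    //     loop {
    //         df.run_available(); // Drain all immediately runnable work.
    //         if df.recv_events().is_none() {
    //             break; // All event senders (reactors) have been dropped.
    //         }
    //     }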

    /// Runs the dataflow until no more work is immediately available, possibly
    /// crossing tick boundaries, yielding to the async runtime between strata.
    /// Returns `true` if any work was done.
    #[tracing::instrument(level = "trace", skip(self), ret)]
    pub async fn run_available_async(&mut self) -> bool {
        let mut work_done = false;
        // `false` to allow continuing into the next tick.
        while self.next_stratum(false) {
            work_done = true;
            self.run_stratum();

            // Yield between strata so the async runtime can make progress
            // (e.g. deliver more events).
            tokio::task::yield_now().await;
        }
        work_done
    }

    /// Runs the current stratum until no more work is immediately available on it.
    /// Returns `true` if any work was done.
    #[tracing::instrument(level = "trace", skip(self), fields(tick = u64::from(self.context.current_tick), stratum = self.context.current_stratum), ret)]
    pub fn run_stratum(&mut self) -> bool {
        // Spawn any pending async tasks now that the graph is running.
        self.context.spawn_tasks();

        let mut work_done = false;

        'pop: while let Some(sg_id) =
            self.context.stratum_queues[self.context.current_stratum].pop_front()
        {
            {
                let sg_data = &mut self.subgraphs[sg_id];
                // This must be true for the subgraph to have been enqueued.
                assert!(sg_data.is_scheduled.take());

                let _enter = tracing::info_span!(
                    "run-subgraph",
                    sg_id = sg_id.to_string(),
                    sg_name = &*sg_data.name,
                    sg_depth = sg_data.loop_depth,
                    sg_loop_nonce = sg_data.last_loop_nonce.0,
                    sg_iter_count = sg_data.last_loop_nonce.1,
                )
                .entered();

                // Maintain the loop nonce stack as loop nesting levels are entered or exited.
                match sg_data.loop_depth.cmp(&self.context.loop_nonce_stack.len()) {
                    Ordering::Greater => {
                        // We have entered a loop.
                        self.context.loop_nonce += 1;
                        self.context.loop_nonce_stack.push(self.context.loop_nonce);
                        tracing::trace!(loop_nonce = self.context.loop_nonce, "Entered loop.");
                    }
                    Ordering::Less => {
                        // We have exited a loop.
                        self.context.loop_nonce_stack.pop();
                        tracing::trace!("Exited loop.");
                    }
                    Ordering::Equal => {}
                }

                self.context.subgraph_id = sg_id;
                self.context.is_first_run_this_tick = sg_data
                    .last_tick_run_in
                    .is_none_or(|last_tick| last_tick < self.context.current_tick);

                if let Some(loop_id) = sg_data.loop_id {
                    // This subgraph runs inside a loop block.
                    let curr_loop_nonce = self.context.loop_nonce_stack.last().copied();

                    let LoopData {
                        iter_count: loop_iter_count,
                        allow_another_iteration,
                    } = &mut self.loop_data[loop_id];

                    let (prev_loop_nonce, prev_iter_count) = sg_data.last_loop_nonce;

                    // Determine this subgraph's iteration count, and whether this is the
                    // start of a new loop execution.
                    let (curr_iter_count, new_loop_execution) =
                        if curr_loop_nonce.is_none_or(|nonce| nonce == prev_loop_nonce) {
                            // Same loop execution as the last time this subgraph ran.
                            if *loop_iter_count == prev_iter_count {
                                // Only continue if another iteration was requested since then.
                                if !std::mem::take(allow_another_iteration) {
                                    tracing::debug!(
                                        "Loop will not continue to next iteration, skipping."
                                    );
                                    continue 'pop;
                                }
                                loop_iter_count.map_or((0, true), |n| (n + 1, false))
                            } else {
                                debug_assert!(
                                    prev_iter_count < *loop_iter_count,
                                    "Expect loop iteration count to be increasing."
                                );
                                // Catch up to the loop's current iteration count.
                                (loop_iter_count.unwrap(), false)
                            }
                        } else {
                            // New loop execution; reset the iteration count.
                            (0, false)
                        };

                    if new_loop_execution {
                        // Run loop-lifespan state hooks at the start of a new loop execution.
                        self.context.run_state_hooks_loop(loop_id);
                    }
                    tracing::debug!("Loop iteration count {}", curr_iter_count);

                    *loop_iter_count = Some(curr_iter_count);
                    self.context.loop_iter_count = curr_iter_count;
                    sg_data.last_loop_nonce =
                        (curr_loop_nonce.unwrap_or_default(), Some(curr_iter_count));
                }

                // Run subgraph-lifespan state hooks before running the subgraph.
                self.context.run_state_hooks_subgraph(sg_id);

                tracing::info!("Running subgraph.");
                sg_data.subgraph.run(&mut self.context, &mut self.handoffs);

                sg_data.last_tick_run_in = Some(self.context.current_tick);
            }

            let sg_data = &self.subgraphs[sg_id];
            for &handoff_id in sg_data.succs.iter() {
                let handoff = &self.handoffs[handoff_id];
                if !handoff.handoff.is_bottom() {
                    for &succ_id in handoff.succs.iter() {
                        let succ_sg_data = &self.subgraphs[succ_id];
                        // Sending data to an earlier stratum means it will be handled next
                        // tick, so allow the next tick to start (unless this subgraph is lazy).
                        if succ_sg_data.stratum < self.context.current_stratum && !sg_data.is_lazy {
                            self.context.can_start_tick = true;
                        }
                        if !succ_sg_data.is_scheduled.replace(true) {
                            self.context.stratum_queues[succ_sg_data.stratum].push_back(succ_id);
                        }
                        if 0 < succ_sg_data.loop_depth {
                            // Successors inside loops are also tracked on the stratum stack.
                            self.context
                                .stratum_stack
                                .push(succ_sg_data.loop_depth, succ_sg_data.stratum);
                        }
                    }
                }
            }

            let reschedule = self.context.reschedule_loop_block.take();
            let allow_another = self.context.allow_another_iteration.take();

            if reschedule {
                // Re-enqueue this subgraph (deferred) and its stratum for another iteration.
                self.context.schedule_deferred.push(sg_id);
                self.context
                    .stratum_stack
                    .push(sg_data.loop_depth, sg_data.stratum);
            }
            if (reschedule || allow_another)
                && let Some(loop_id) = sg_data.loop_id
            {
                self.loop_data
                    .get_mut(loop_id)
                    .unwrap()
                    .allow_another_iteration = true;
            }

            work_done = true;
        }
        work_done
    }

    /// Advances to the next stratum which has work available, which may be the
    /// current stratum. Returns `true` if work is available (i.e. `run_stratum`
    /// should be called), `false` otherwise.
    ///
    /// If `current_tick_only` is `true`, only work within the current tick is
    /// considered; otherwise the next tick may be started if events allow it.
    #[tracing::instrument(level = "trace", skip(self), ret)]
    pub fn next_stratum(&mut self, current_tick_only: bool) -> bool {
        tracing::trace!(
            events_received_tick = self.context.events_received_tick,
            can_start_tick = self.context.can_start_tick,
            "Starting `next_stratum` call.",
        );

        // The stratum at which the search stops once a full loop has been made.
        let mut end_stratum = self.context.current_stratum;
        let mut new_tick_started = false;

        if 0 == self.context.current_stratum {
            new_tick_started = true;

            // Starting the tick: consume `can_start_tick` and record the start time.
            tracing::trace!("Starting tick, setting `can_start_tick = false`.");
            self.context.can_start_tick = false;
            self.context.current_tick_start = SystemTime::now();

            // Ensure external events are received before running the tick.
            if !self.context.events_received_tick {
                // Add any external jobs to the stratum queues.
                self.try_recv_events();
            }
        }

        loop {
            tracing::trace!(
                tick = u64::from(self.context.current_tick),
                stratum = self.context.current_stratum,
                "Looking for work on stratum."
            );
            // If the current stratum has work, return true.
            if !self.context.stratum_queues[self.context.current_stratum].is_empty() {
                tracing::trace!(
                    tick = u64::from(self.context.current_tick),
                    stratum = self.context.current_stratum,
                    "Work found on stratum."
                );
                return true;
            }

            if let Some(next_stratum) = self.context.stratum_stack.pop() {
                self.context.current_stratum = next_stratum;

                // Enqueue any subgraphs whose scheduling was deferred until this point.
                {
                    for sg_id in self.context.schedule_deferred.drain(..) {
                        let sg_data = &self.subgraphs[sg_id];
                        tracing::info!(
                            tick = u64::from(self.context.current_tick),
                            stratum = self.context.current_stratum,
                            sg_id = sg_id.to_string(),
                            sg_name = &*sg_data.name,
                            is_scheduled = sg_data.is_scheduled.get(),
                            "Rescheduling deferred subgraph."
                        );
                        if !sg_data.is_scheduled.replace(true) {
                            self.context.stratum_queues[sg_data.stratum].push_back(sg_id);
                        }
                    }
                }
            } else {
                // No stratum on the stack; advance to the next stratum in order.
                self.context.current_stratum += 1;
                // If we have reached the end of the tick, consider starting the next tick.
                if self.context.current_stratum >= self.context.stratum_queues.len() {
                    new_tick_started = true;

                    tracing::trace!(
                        can_start_tick = self.context.can_start_tick,
                        "End of tick {}, starting tick {}.",
                        self.context.current_tick,
                        self.context.current_tick + TickDuration::SINGLE_TICK,
                    );
                    self.context.run_state_hooks_tick();

                    self.context.current_stratum = 0;
                    self.context.current_tick += TickDuration::SINGLE_TICK;
                    self.context.events_received_tick = false;

                    if current_tick_only {
                        tracing::trace!(
                            "`current_tick_only` is `true`, returning `false` before receiving events."
                        );
                        return false;
                    } else {
                        self.try_recv_events();
                        if std::mem::replace(&mut self.context.can_start_tick, false) {
                            tracing::trace!(
                                tick = u64::from(self.context.current_tick),
                                "`can_start_tick` is `true`, continuing."
                            );
                            // Do one more full loop to find where events were added.
                            end_stratum = 0;
                            continue;
                        } else {
                            tracing::trace!(
                                "`can_start_tick` is `false`, re-setting `events_received_tick = false`, returning `false`."
                            );
                            self.context.events_received_tick = false;
                            return false;
                        }
                    }
                }
            }

            // After a full loop around the strata with no work found, stop and reset.
            if new_tick_started && end_stratum == self.context.current_stratum {
                tracing::trace!(
                    "Made full loop around stratum, re-setting `current_stratum = 0`, returning `false`."
                );
                self.context.events_received_tick = false;
                self.context.current_stratum = 0;
                return false;
            }
        }
    }

    /// Runs the dataflow graph forever, blocking the current thread. Never returns.
    #[tracing::instrument(level = "trace", skip(self), ret)]
    pub fn run(&mut self) -> Option<Never> {
        loop {
            self.run_tick();
        }
    }

    /// Runs the dataflow graph forever asynchronously, waiting for events whenever
    /// no work is immediately available. Never returns.
    #[tracing::instrument(level = "trace", skip(self), ret)]
    pub async fn run_async(&mut self) -> Option<Never> {
        loop {
            self.run_available_async().await;
            // Wait for more events before running again.
            self.recv_events_async().await;
        }
    }

    /// Enqueues subgraphs triggered by pending events, without blocking.
    /// Returns the number of subgraphs enqueued.
    #[tracing::instrument(level = "trace", skip(self), fields(events_received_tick = self.context.events_received_tick), ret)]
    pub fn try_recv_events(&mut self) -> usize {
        let mut enqueued_count = 0;
        while let Ok((sg_id, is_external)) = self.context.event_queue_recv.try_recv() {
            let sg_data = &self.subgraphs[sg_id];
            tracing::trace!(
                sg_id = sg_id.to_string(),
                is_external = is_external,
                sg_stratum = sg_data.stratum,
                "Event received."
            );
            if !sg_data.is_scheduled.replace(true) {
                self.context.stratum_queues[sg_data.stratum].push_back(sg_id);
                enqueued_count += 1;
            }
            if is_external {
                // External events may allow the next tick to start.
                if !self.context.events_received_tick
                    || sg_data.stratum < self.context.current_stratum
                {
                    tracing::trace!(
                        current_stratum = self.context.current_stratum,
                        sg_stratum = sg_data.stratum,
                        "External event, setting `can_start_tick = true`."
                    );
                    self.context.can_start_tick = true;
                }
            }
        }
        self.context.events_received_tick = true;

        enqueued_count
    }

    /// Enqueues subgraphs triggered by events, blocking until at least one
    /// *external* event is received. Returns the number of subgraphs enqueued,
    /// or `None` if the event queue has been closed (all senders dropped).
    #[tracing::instrument(level = "trace", skip(self), fields(events_received_tick = self.context.events_received_tick), ret)]
    pub fn recv_events(&mut self) -> Option<usize> {
        let mut count = 0;
        loop {
            let (sg_id, is_external) = self.context.event_queue_recv.blocking_recv()?;
            let sg_data = &self.subgraphs[sg_id];
            tracing::trace!(
                sg_id = sg_id.to_string(),
                is_external = is_external,
                sg_stratum = sg_data.stratum,
                "Event received."
            );
            if !sg_data.is_scheduled.replace(true) {
                self.context.stratum_queues[sg_data.stratum].push_back(sg_id);
                count += 1;
            }
            if is_external {
                // External events may allow the next tick to start.
                if !self.context.events_received_tick
                    || sg_data.stratum < self.context.current_stratum
                {
                    tracing::trace!(
                        current_stratum = self.context.current_stratum,
                        sg_stratum = sg_data.stratum,
                        "External event, setting `can_start_tick = true`."
                    );
                    self.context.can_start_tick = true;
                }
                break;
            }
        }
        self.context.events_received_tick = true;

        // Drain any additional pending events without blocking.
        let extra_count = self.try_recv_events();
        Some(count + extra_count)
    }

    /// Enqueues subgraphs triggered by events asynchronously, waiting until at
    /// least one *external* event is received. Returns the number of subgraphs
    /// enqueued, or `None` if the event queue has been closed (all senders dropped).
    #[tracing::instrument(level = "trace", skip(self), fields(events_received_tick = self.context.events_received_tick), ret)]
    pub async fn recv_events_async(&mut self) -> Option<usize> {
        let mut count = 0;
        loop {
            tracing::trace!("Awaiting events (`event_queue_recv`).");
            let (sg_id, is_external) = self.context.event_queue_recv.recv().await?;
            let sg_data = &self.subgraphs[sg_id];
            tracing::trace!(
                sg_id = sg_id.to_string(),
                is_external = is_external,
                sg_stratum = sg_data.stratum,
                "Event received."
            );
            if !sg_data.is_scheduled.replace(true) {
                self.context.stratum_queues[sg_data.stratum].push_back(sg_id);
                count += 1;
            }
            if is_external {
                // External events may allow the next tick to start.
                if !self.context.events_received_tick
                    || sg_data.stratum < self.context.current_stratum
                {
                    tracing::trace!(
                        current_stratum = self.context.current_stratum,
                        sg_stratum = sg_data.stratum,
                        "External event, setting `can_start_tick = true`."
                    );
                    self.context.can_start_tick = true;
                }
                break;
            }
        }
        self.context.events_received_tick = true;

        // Drain any additional pending events without blocking.
        let extra_count = self.try_recv_events();
        Some(count + extra_count)
    }

    /// Schedules a subgraph to be run. Returns `true` if the subgraph was not
    /// already scheduled.
    pub fn schedule_subgraph(&mut self, sg_id: SubgraphId) -> bool {
        let sg_data = &self.subgraphs[sg_id];
        let already_scheduled = sg_data.is_scheduled.replace(true);
        if !already_scheduled {
            self.context.stratum_queues[sg_data.stratum].push_back(sg_id);
            true
        } else {
            false
        }
    }

    /// Adds a new compiled subgraph with the specified receive and send ports, at stratum 0.
    pub fn add_subgraph<Name, R, W, F>(
        &mut self,
        name: Name,
        recv_ports: R,
        send_ports: W,
        subgraph: F,
    ) -> SubgraphId
    where
        Name: Into<Cow<'static, str>>,
        R: 'static + PortList<RECV>,
        W: 'static + PortList<SEND>,
        F: 'static + for<'ctx> FnMut(&'ctx mut Context, R::Ctx<'ctx>, W::Ctx<'ctx>),
    {
        self.add_subgraph_stratified(name, 0, recv_ports, send_ports, false, subgraph)
    }

    /// Adds a new compiled subgraph with the specified receive and send ports, in
    /// the given stratum, with optional laziness.
    pub fn add_subgraph_stratified<Name, R, W, F>(
        &mut self,
        name: Name,
        stratum: usize,
        recv_ports: R,
        send_ports: W,
        laziness: bool,
        subgraph: F,
    ) -> SubgraphId
    where
        Name: Into<Cow<'static, str>>,
        R: 'static + PortList<RECV>,
        W: 'static + PortList<SEND>,
        F: 'a + for<'ctx> FnMut(&'ctx mut Context, R::Ctx<'ctx>, W::Ctx<'ctx>),
    {
        self.add_subgraph_full(
            name, stratum, recv_ports, send_ports, laziness, None, subgraph,
        )
    }

    /// Adds a new compiled subgraph with all options: stratum, laziness, and an
    /// optional containing loop.
    #[expect(clippy::too_many_arguments, reason = "Mainly for internal use.")]
    pub fn add_subgraph_full<Name, R, W, F>(
        &mut self,
        name: Name,
        stratum: usize,
        recv_ports: R,
        send_ports: W,
        laziness: bool,
        loop_id: Option<LoopId>,
        mut subgraph: F,
    ) -> SubgraphId
    where
        Name: Into<Cow<'static, str>>,
        R: 'static + PortList<RECV>,
        W: 'static + PortList<SEND>,
        F: 'a + for<'ctx> FnMut(&'ctx mut Context, R::Ctx<'ctx>, W::Ctx<'ctx>),
    {
        recv_ports.assert_is_from(&self.handoffs);
        send_ports.assert_is_from(&self.handoffs);

        let loop_depth = loop_id
            .and_then(|loop_id| self.context.loop_depth.get(loop_id))
            .copied()
            .unwrap_or(0);

        let sg_id = self.subgraphs.insert_with_key(|sg_id| {
            let (mut subgraph_preds, mut subgraph_succs) = Default::default();
            recv_ports.set_graph_meta(&mut self.handoffs, &mut subgraph_preds, sg_id, true);
            send_ports.set_graph_meta(&mut self.handoffs, &mut subgraph_succs, sg_id, false);

            let subgraph =
                move |context: &mut Context, handoffs: &mut SlotVec<HandoffTag, HandoffData>| {
                    let (recv, send) = unsafe {
                        // SAFETY: The ports are checked to belong to `self.handoffs` via
                        // `assert_is_from` above.
                        (
                            recv_ports.make_ctx(&*handoffs),
                            send_ports.make_ctx(&*handoffs),
                        )
                    };
                    (subgraph)(context, recv, send);
                };
            SubgraphData::new(
                name.into(),
                stratum,
                subgraph,
                subgraph_preds,
                subgraph_succs,
                true,
                laziness,
                loop_id,
                loop_depth,
            )
        });
        self.context.init_stratum(stratum);
        self.context.stratum_queues[stratum].push_back(sg_id);

        sg_id
    }

    /// Adds a new compiled subgraph with a variable number of inputs and outputs of
    /// the same respective handoff types, at stratum 0.
    pub fn add_subgraph_n_m<Name, R, W, F>(
        &mut self,
        name: Name,
        recv_ports: Vec<RecvPort<R>>,
        send_ports: Vec<SendPort<W>>,
        subgraph: F,
    ) -> SubgraphId
    where
        Name: Into<Cow<'static, str>>,
        R: 'static + Handoff,
        W: 'static + Handoff,
        F: 'static
            + for<'ctx> FnMut(&'ctx mut Context, &'ctx [&'ctx RecvCtx<R>], &'ctx [&'ctx SendCtx<W>]),
    {
        self.add_subgraph_stratified_n_m(name, 0, recv_ports, send_ports, subgraph)
    }

    /// Adds a new compiled subgraph with a variable number of inputs and outputs of
    /// the same respective handoff types, in the given stratum.
    pub fn add_subgraph_stratified_n_m<Name, R, W, F>(
        &mut self,
        name: Name,
        stratum: usize,
        recv_ports: Vec<RecvPort<R>>,
        send_ports: Vec<SendPort<W>>,
        mut subgraph: F,
    ) -> SubgraphId
    where
        Name: Into<Cow<'static, str>>,
        R: 'static + Handoff,
        W: 'static + Handoff,
        F: 'static
            + for<'ctx> FnMut(&'ctx mut Context, &'ctx [&'ctx RecvCtx<R>], &'ctx [&'ctx SendCtx<W>]),
    {
        let sg_id = self.subgraphs.insert_with_key(|sg_id| {
            let subgraph_preds = recv_ports.iter().map(|port| port.handoff_id).collect();
            let subgraph_succs = send_ports.iter().map(|port| port.handoff_id).collect();

            for recv_port in recv_ports.iter() {
                self.handoffs[recv_port.handoff_id].succs.push(sg_id);
            }
            for send_port in send_ports.iter() {
                self.handoffs[send_port.handoff_id].preds.push(sg_id);
            }

            let subgraph =
                move |context: &mut Context, handoffs: &mut SlotVec<HandoffTag, HandoffData>| {
                    let recvs: Vec<&RecvCtx<R>> = recv_ports
                        .iter()
                        .map(|hid| hid.handoff_id)
                        .map(|hid| handoffs.get(hid).unwrap())
                        .map(|h_data| {
                            <dyn Any>::downcast_ref(&*h_data.handoff)
                                .expect("Attempted to cast handoff to wrong type.")
                        })
                        .map(RefCast::ref_cast)
                        .collect();

                    let sends: Vec<&SendCtx<W>> = send_ports
                        .iter()
                        .map(|hid| hid.handoff_id)
                        .map(|hid| handoffs.get(hid).unwrap())
                        .map(|h_data| {
                            <dyn Any>::downcast_ref(&*h_data.handoff)
                                .expect("Attempted to cast handoff to wrong type.")
                        })
                        .map(RefCast::ref_cast)
                        .collect();

                    (subgraph)(context, &recvs, &sends)
                };
            SubgraphData::new(
                name.into(),
                stratum,
                subgraph,
                subgraph_preds,
                subgraph_succs,
                true,
                false,
                None,
                0,
            )
        });

        self.context.init_stratum(stratum);
        self.context.stratum_queues[stratum].push_back(sg_id);

        sg_id
    }
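
    // Usage sketch (not from the original source): wire an edge into a sink
    // subgraph using `make_edge` and `add_subgraph_n_m`. `VecHandoff` and its
    // `take_inner`/`give` methods are assumed to be this crate's vec-based
    // handoff API.
    //
    //     let mut df = Dfir::new();
    //     let (send_port, recv_port) = df.make_edge::<_, VecHandoff<usize>>("ints");
    //     df.add_subgraph_n_m(
    //         "printer",
    //         vec![recv_port],
    //         Vec::<SendPort<VecHandoff<usize>>>::new(),
    //         |_context, recvs, _sends| {
    //             for v in recvs[0].take_inner() {
    //                 println!("{}", v);
    //             }
    //         },
    //     );
    //     // A producer subgraph holding `send_port` would call `send.give(...)`;
    //     // its output is delivered when the graph runs:
    //     df.run_available();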

    /// Creates a handoff edge and returns its corresponding send and receive ports.
    pub fn make_edge<Name, H>(&mut self, name: Name) -> (SendPort<H>, RecvPort<H>)
    where
        Name: Into<Cow<'static, str>>,
        H: 'static + Handoff,
    {
        let handoff = H::default();
        let handoff_id = self
            .handoffs
            .insert_with_key(|hoff_id| HandoffData::new(name.into(), handoff, hoff_id));

        let input_port = SendPort {
            handoff_id,
            _marker: PhantomData,
        };
        let output_port = RecvPort {
            handoff_id,
            _marker: PhantomData,
        };
        (input_port, output_port)
    }

    /// Adds referenceable state to this instance. Returns a handle which can be
    /// used to access the state through the [`Context`].
    pub fn add_state<T>(&mut self, state: T) -> StateHandle<T>
    where
        T: Any,
    {
        self.context.add_state(state)
    }

    /// Registers a hook which runs on the state according to the given
    /// [`StateLifespan`], e.g. to reset the state each tick.
    pub fn set_state_lifespan_hook<T>(
        &mut self,
        handle: StateHandle<T>,
        lifespan: StateLifespan,
        hook_fn: impl 'static + FnMut(&mut T),
    ) where
        T: Any,
    {
        self.context
            .set_state_lifespan_hook(handle, lifespan, hook_fn)
    }
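
    // Usage sketch (not from the original source): keep a per-tick scratch buffer
    // as graph state and clear it at tick boundaries via a lifespan hook.
    //
    //     let mut df = Dfir::new();
    //     let buffer = df.add_state(std::cell::RefCell::new(Vec::<u32>::new()));
    //     df.set_state_lifespan_hook(buffer, StateLifespan::Tick, |buf| {
    //         buf.get_mut().clear();
    //     });
    //     // Subgraphs can then access the buffer through the `Context` via `buffer`.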

    /// Returns a mutable reference to the internal [`Context`], with its current
    /// subgraph ID set to `sg_id`.
    pub fn context_mut(&mut self, sg_id: SubgraphId) -> &mut Context {
        self.context.subgraph_id = sg_id;
        &mut self.context
    }

    /// Adds a new loop context with the given parent (or `None` for a top-level
    /// loop). Returns the new [`LoopId`].
    pub fn add_loop(&mut self, parent: Option<LoopId>) -> LoopId {
        let depth = parent.map_or(0, |p| self.context.loop_depth[p] + 1);
        let loop_id = self.context.loop_depth.insert(depth);
        self.loop_data.insert(
            loop_id,
            LoopData {
                iter_count: None,
                allow_another_iteration: true,
            },
        );
        loop_id
    }
}

impl Dfir<'_> {
    /// Requests an async task to be spawned on the graph's task set. Tasks are
    /// spawned when the graph runs (see [`Self::run_stratum`], which calls
    /// `spawn_tasks`).
    pub fn request_task<Fut>(&mut self, future: Fut)
    where
        Fut: Future<Output = ()> + 'static,
    {
        self.context.request_task(future);
    }

    /// Aborts all previously spawned tasks.
    pub fn abort_tasks(&mut self) {
        self.context.abort_tasks()
    }

    /// Waits for all spawned tasks to complete.
    pub fn join_tasks(&mut self) -> impl use<'_> + Future {
        self.context.join_tasks()
    }
}
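
// Usage sketch (not from the original source): hand a background future to the
// graph and let the graph manage its lifetime.
//
//     let mut df = Dfir::new();
//     df.request_task(async {
//         // ... background work driven by the surrounding async runtime ...
//     });
//     df.run_available(); // Tasks are spawned from `run_stratum` as strata run.
//     // Dropping `df` aborts any outstanding tasks (see the `Drop` impl below).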

impl Drop for Dfir<'_> {
    fn drop(&mut self) {
        self.abort_tasks();
    }
}

/// A handoff along with the subgraphs and handoffs it is connected to.
///
/// Internal use; hidden from documentation.
#[doc(hidden)]
pub struct HandoffData {
    /// Human-readable name for diagnostics.
    pub(super) name: Cow<'static, str>,
    /// The handoff itself, stored type-erased behind [`HandoffMeta`].
    pub(super) handoff: Box<dyn HandoffMeta>,
    /// Subgraph(s) which write into this handoff.
    pub(super) preds: SmallVec<[SubgraphId; 1]>,
    /// Subgraph(s) which read from this handoff.
    pub(super) succs: SmallVec<[SubgraphId; 1]>,

    /// Predecessor handoffs: self-pointing for regular handoffs; for a tee output,
    /// points to the tee root.
    pub(super) pred_handoffs: Vec<HandoffId>,
    /// Successor handoffs: self-pointing for regular handoffs; for a tee root,
    /// lists its tee outputs.
    pub(super) succ_handoffs: Vec<HandoffId>,
}
impl std::fmt::Debug for HandoffData {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
        f.debug_struct("HandoffData")
            .field("preds", &self.preds)
            .field("succs", &self.succs)
            .finish_non_exhaustive()
    }
}
impl HandoffData {
    pub fn new(
        name: Cow<'static, str>,
        handoff: impl 'static + HandoffMeta,
        hoff_id: HandoffId,
    ) -> Self {
        let (preds, succs) = Default::default();
        Self {
            name,
            handoff: Box::new(handoff),
            preds,
            succs,
            pred_handoffs: vec![hoff_id],
            succ_handoffs: vec![hoff_id],
        }
    }
}

/// A compiled subgraph along with its scheduling metadata.
pub(super) struct SubgraphData<'a> {
    /// Human-readable name for diagnostics.
    pub(super) name: Cow<'static, str>,
    /// This subgraph's stratum number.
    pub(super) stratum: usize,
    /// The executable code of the subgraph.
    subgraph: Box<dyn Subgraph + 'a>,

    /// Handoffs read by this subgraph.
    #[expect(dead_code, reason = "may be useful in the future")]
    preds: Vec<HandoffId>,
    /// Handoffs written by this subgraph.
    succs: Vec<HandoffId>,

    /// Whether this subgraph is currently enqueued in a stratum queue.
    /// A [`Cell`] so it can be flipped through a shared reference.
    is_scheduled: Cell<bool>,

    /// The last tick in which this subgraph was run, if any.
    last_tick_run_in: Option<TickInstant>,
    /// The loop nonce and iteration count from the last loop execution in which
    /// this subgraph ran, used to detect new loop executions and iterations.
    last_loop_nonce: (usize, Option<usize>),

    /// If lazy, sending data back to a lower stratum does not allow a new tick to start.
    is_lazy: bool,

    /// The loop this subgraph belongs to, if any.
    loop_id: Option<LoopId>,
    /// Nesting depth of that loop (0 for top-level subgraphs).
    loop_depth: usize,
}
impl<'a> SubgraphData<'a> {
    #[expect(clippy::too_many_arguments, reason = "internal use")]
    pub(crate) fn new(
        name: Cow<'static, str>,
        stratum: usize,
        subgraph: impl Subgraph + 'a,
        preds: Vec<HandoffId>,
        succs: Vec<HandoffId>,
        is_scheduled: bool,
        is_lazy: bool,
        loop_id: Option<LoopId>,
        loop_depth: usize,
    ) -> Self {
        Self {
            name,
            stratum,
            subgraph: Box::new(subgraph),
            preds,
            succs,
            is_scheduled: Cell::new(is_scheduled),
            last_tick_run_in: None,
            last_loop_nonce: (0, None),
            is_lazy,
            loop_id,
            loop_depth,
        }
    }
}

/// Scheduling state for a single loop context.
pub(crate) struct LoopData {
    /// Iteration count of the loop's current execution, or `None` if it has not yet run.
    iter_count: Option<usize>,
    /// Whether the loop is allowed to run another iteration.
    allow_another_iteration: bool,
}

/// Lifespans for state stored in a [`Dfir`] graph, used with
/// [`Dfir::set_state_lifespan_hook`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum StateLifespan {
    /// Scoped to a particular subgraph.
    Subgraph(SubgraphId),
    /// Scoped to a loop context.
    Loop(LoopId),
    /// Scoped to each tick.
    Tick,
    /// Lives as long as the graph itself.
    Static,
}