dfir_rs/scheduled/handoff/
handoff_list.rs

1//! Module for variadic handoff port lists, [`PortList`].
2
3use std::any::Any;
4
5use ref_cast::RefCast;
6use sealed::sealed;
7use variadics::{Variadic, variadic_trait};
8
9use super::Handoff;
10use crate::scheduled::graph::HandoffData;
11use crate::scheduled::port::{Polarity, Port, PortCtx};
12use crate::scheduled::{HandoffId, HandoffTag, SubgraphId};
13use crate::util::slot_vec::SlotVec;
14
15/// Sealed trait for variadic lists of ports.
16///
17/// See the [`variadics`] crate for the strategy we use to implement variadics in Rust.
18#[sealed]
19pub trait PortList<S>: Variadic
20where
21    S: Polarity,
22{
23    /// Iteratively/recursively set the graph metadata for each port in this list.
24    ///
25    /// Specifically sets:
26    /// - `HandoffData::preds` and `HandoffData::succs` in the `handoffs` slice for the
27    ///   handoffs in this [`PortList`] (using `pred` and/or `succ`).
28    /// - `out_handoff_ids` will be extended with all the handoff IDs in this [`PortList`].
29    ///
30    /// `handoffs_are_preds`:
31    /// - `true`: Handoffs are predecessors (inputs) to subgraph `sg_id`.
32    /// - `false`: Handoffs are successors (outputs) from subgraph `sg_id`.
33    fn set_graph_meta(
34        &self,
35        handoffs: &mut SlotVec<HandoffTag, HandoffData>,
36        out_handoff_ids: &mut Vec<HandoffId>,
37        sg_id: SubgraphId,
38        handoffs_are_preds: bool,
39    );
40
41    /// The [`Variadic`] return type of [`Self::make_ctx`].
42    type Ctx<'a>: Variadic;
43    /// Iteratively/recursively construct a `Ctx` variadic list.
44    ///
45    /// (Note that unlike [`Self::set_graph_meta`], this does not mess with pred/succ handoffs for
46    /// teeing).
47    ///
48    /// # Safety
49    /// The handoffs in this port list (`self`) must come from the `handoffs` [`SlotVec`].
50    /// This ensure the types will match.
51    ///
52    /// Use [`Self::assert_is_from`] to check this.
53    unsafe fn make_ctx<'a>(&self, handoffs: &'a SlotVec<HandoffTag, HandoffData>) -> Self::Ctx<'a>;
54
55    /// Asserts that `self` is a valid port list from `handoffs`. Panics if not.
56    fn assert_is_from(&self, handoffs: &SlotVec<HandoffTag, HandoffData>);
57}
58#[sealed]
59impl<S, Rest, H> PortList<S> for (Port<S, H>, Rest)
60where
61    S: Polarity,
62    H: Handoff,
63    Rest: PortList<S>,
64{
65    fn set_graph_meta(
66        &self,
67        handoffs: &mut SlotVec<HandoffTag, HandoffData>,
68        out_handoff_ids: &mut Vec<HandoffId>,
69        sg_id: SubgraphId,
70        handoffs_are_preds: bool,
71    ) {
72        let (this, rest) = self;
73        let this_handoff = &mut handoffs[this.handoff_id];
74
75        // Set subgraph's info (`out_handoff_ids`) about neighbor handoffs.
76        // Use the "representative" handoff (pred or succ) for teeing handoffs, for the subgraph metadata.
77        // For regular Vec handoffs, `pred_handoffs` and `succ_handoffs` will just be the handoff itself.
78        out_handoff_ids.extend(if handoffs_are_preds {
79            this_handoff.pred_handoffs.iter().copied()
80        } else {
81            this_handoff.succ_handoffs.iter().copied()
82        });
83
84        // Set handoff's info (`preds`/`succs`) about neighbor subgraph (`sg_id`).
85        if handoffs_are_preds {
86            for succ_hoff in this_handoff.succ_handoffs.clone() {
87                handoffs[succ_hoff].succs.push(sg_id);
88            }
89        } else {
90            for pred_hoff in this_handoff.pred_handoffs.clone() {
91                handoffs[pred_hoff].preds.push(sg_id);
92            }
93        }
94        rest.set_graph_meta(handoffs, out_handoff_ids, sg_id, handoffs_are_preds);
95    }
96
97    type Ctx<'a> = (&'a PortCtx<S, H>, Rest::Ctx<'a>);
98    unsafe fn make_ctx<'a>(&self, handoffs: &'a SlotVec<HandoffTag, HandoffData>) -> Self::Ctx<'a> {
99        let (this, rest) = self;
100        let hoff_any: &dyn Any = &*handoffs.get(this.handoff_id).unwrap().handoff;
101        debug_assert!(hoff_any.is::<H>());
102
103        let handoff = unsafe {
104            // SAFETY: Caller must ensure `self` is from `handoffs`.
105            // TODO(shadaj): replace with `downcast_ref_unchecked` when it's stabilized
106            &*(hoff_any as *const dyn Any as *const H)
107        };
108
109        let ctx = RefCast::ref_cast(handoff);
110        let ctx_rest = unsafe {
111            // SAFETY: Same invariants hold, as we recurse through the list.
112            rest.make_ctx(handoffs)
113        };
114        (ctx, ctx_rest)
115    }
116
117    fn assert_is_from(&self, handoffs: &SlotVec<HandoffTag, HandoffData>) {
118        let (this, rest) = self;
119        let Some(hoff_data) = handoffs.get(this.handoff_id) else {
120            panic!("Handoff ID {} not found in `handoffs`.", this.handoff_id);
121        };
122        let hoff_any: &dyn Any = &*hoff_data.handoff;
123        assert!(
124            hoff_any.is::<H>(),
125            "Handoff ID {} is not of type {} in `handoffs`.",
126            this.handoff_id,
127            std::any::type_name::<H>(),
128        );
129        rest.assert_is_from(handoffs);
130    }
131}
132#[sealed]
133impl<S> PortList<S> for ()
134where
135    S: Polarity,
136{
137    fn set_graph_meta(
138        &self,
139        _handoffs: &mut SlotVec<HandoffTag, HandoffData>,
140        _out_handoff_ids: &mut Vec<HandoffId>,
141        _sg_id: SubgraphId,
142        _handoffs_are_preds: bool,
143    ) {
144    }
145
146    type Ctx<'a> = ();
147    unsafe fn make_ctx<'a>(
148        &self,
149        _handoffs: &'a SlotVec<HandoffTag, HandoffData>,
150    ) -> Self::Ctx<'a> {
151    }
152
153    fn assert_is_from(&self, _handoffs: &SlotVec<HandoffTag, HandoffData>) {}
154}
155
156/// Trait for splitting a list of ports into two.
157#[sealed]
158pub trait PortListSplit<S, A>: PortList<S>
159where
160    S: Polarity,
161    A: PortList<S>,
162{
163    /// The suffix, second half of the split.
164    type Suffix: PortList<S>;
165
166    /// Split the port list, returning the prefix and [`Self::Suffix`] as the two halves.
167    fn split_ctx(ctx: Self::Ctx<'_>) -> (A::Ctx<'_>, <Self::Suffix as PortList<S>>::Ctx<'_>);
168}
169#[sealed]
170impl<S, H, T, U> PortListSplit<S, (Port<S, H>, U)> for (Port<S, H>, T)
171where
172    S: Polarity,
173    H: Handoff,
174    T: PortListSplit<S, U>,
175    U: PortList<S>,
176{
177    type Suffix = T::Suffix;
178
179    fn split_ctx(
180        ctx: Self::Ctx<'_>,
181    ) -> (
182        <(Port<S, H>, U) as PortList<S>>::Ctx<'_>,
183        <Self::Suffix as PortList<S>>::Ctx<'_>,
184    ) {
185        let (x, t) = ctx;
186        let (u, v) = T::split_ctx(t);
187        ((x, u), v)
188    }
189}
190#[sealed]
191impl<S, T> PortListSplit<S, ()> for T
192where
193    S: Polarity,
194    T: PortList<S>,
195{
196    type Suffix = T;
197
198    fn split_ctx(ctx: Self::Ctx<'_>) -> ((), T::Ctx<'_>) {
199        ((), ctx)
200    }
201}
202
203variadic_trait! {
204    /// A variadic list of Handoff types, represented using a lisp-style tuple structure.
205    ///
206    /// This trait is sealed and not meant to be implemented or used directly. Instead tuple lists (which already implement this trait) should be used, for example:
207    /// ```ignore
208    /// type MyHandoffList = (VecHandoff<usize>, (VecHandoff<String>, (TeeingHandoff<u32>, ())));
209    /// ```
210    /// The [`var_expr!`](crate::var) macro simplifies usage of this kind:
211    /// ```ignore
212    /// type MyHandoffList = var_expr!(VecHandoff<usize>, VecHandoff<String>, TeeingHandoff<u32>);
213    /// ```
214    #[sealed]
215    pub variadic<T> HandoffList where T: 'static + Handoff {}
216}