dfir_rs/scheduled/
state.rs

1//! Module for [`StateHandle`], part of the "state API".
2
3use std::any::{Any, TypeId};
4use std::marker::PhantomData;
5
6use super::StateId;
7
8/// A handle into a particular [`Dfir`](super::graph::Dfir) instance, referring to data
9/// inserted by [`add_state`](super::graph::Dfir::add_state).
10///
11/// If you need to store state handles in a data structure see [`StateHandleErased`] which hides
12/// the generic type parameter.
13#[must_use]
14#[derive(Debug, PartialEq, Eq, Hash)]
15pub struct StateHandle<T> {
16    pub(crate) state_id: StateId,
17    pub(crate) _phantom: PhantomData<*mut T>,
18}
19impl<T> Copy for StateHandle<T> {}
20impl<T> Clone for StateHandle<T> {
21    fn clone(&self) -> Self {
22        *self
23    }
24}
25
26/// A state handle with the generic type parameter erased, allowing it to be stored in omogenous
27/// data structures. The type is tracked internally as data via [`TypeId`].
28///
29/// Use [`StateHandleErased::from(state_handle)`](StateHandleErased::from) to create an instance
30/// from a typed [`StateHandle<T>`].
31///
32/// Use [`StateHandle::<T>::try_from()`](StateHandle::try_from) to convert the `StateHandleErased`
33/// back into a `StateHandle<T>` of the given type `T`. If `T` is the wrong type then the original
34/// `StateHandleErased` will be returned as the `Err`.
35#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
36pub struct StateHandleErased {
37    state_id: StateId,
38    type_id: TypeId,
39}
40
41/// See [`StateHandleErased`].
42impl<T> TryFrom<StateHandleErased> for StateHandle<T>
43where
44    T: Any,
45{
46    type Error = StateHandleErased;
47
48    fn try_from(value: StateHandleErased) -> Result<Self, Self::Error> {
49        if TypeId::of::<T>() == value.type_id {
50            Ok(Self {
51                state_id: value.state_id,
52                _phantom: PhantomData,
53            })
54        } else {
55            Err(value)
56        }
57    }
58}
59/// See [`StateHandleErased`].
60impl<T> From<StateHandle<T>> for StateHandleErased
61where
62    T: Any,
63{
64    fn from(value: StateHandle<T>) -> Self {
65        Self {
66            state_id: value.state_id,
67            type_id: TypeId::of::<T>(),
68        }
69    }
70}
71
#[cfg(test)]
mod test {
    use super::*;

    /// Round-trips a typed handle through `StateHandleErased` and checks both the matching-type
    /// and mismatched-type conversions back.
    #[test]
    fn test_erasure() {
        let typed: StateHandle<String> = StateHandle {
            state_id: StateId(0),
            _phantom: PhantomData,
        };
        let erased: StateHandleErased = typed.into();

        // Same type: the conversion yields the original handle.
        assert_eq!(Ok(typed), StateHandle::<String>::try_from(erased));
        // Different type: the conversion hands back the erased handle unchanged.
        assert_eq!(
            Err(erased),
            StateHandle::<&'static str>::try_from(erased)
        );
    }
}