dfir_rs/scheduled/
state.rs1use std::any::{Any, TypeId};
4use std::marker::PhantomData;
5
6use super::StateId;
7
8#[must_use]
14#[derive(Debug, PartialEq, Eq, Hash)]
15pub struct StateHandle<T> {
16 pub(crate) state_id: StateId,
17 pub(crate) _phantom: PhantomData<*mut T>,
18}
19impl<T> Copy for StateHandle<T> {}
20impl<T> Clone for StateHandle<T> {
21 fn clone(&self) -> Self {
22 *self
23 }
24}
25
26#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
36pub struct StateHandleErased {
37 state_id: StateId,
38 type_id: TypeId,
39}
40
41impl<T> TryFrom<StateHandleErased> for StateHandle<T>
43where
44 T: Any,
45{
46 type Error = StateHandleErased;
47
48 fn try_from(value: StateHandleErased) -> Result<Self, Self::Error> {
49 if TypeId::of::<T>() == value.type_id {
50 Ok(Self {
51 state_id: value.state_id,
52 _phantom: PhantomData,
53 })
54 } else {
55 Err(value)
56 }
57 }
58}
59impl<T> From<StateHandle<T>> for StateHandleErased
61where
62 T: Any,
63{
64 fn from(value: StateHandle<T>) -> Self {
65 Self {
66 state_id: value.state_id,
67 type_id: TypeId::of::<T>(),
68 }
69 }
70}
71
#[cfg(test)]
mod test {
    use super::*;

    /// Round-trips a handle through type erasure and checks that the typed handle can be
    /// recovered only for the type parameter it was erased from.
    #[test]
    fn test_erasure() {
        let original: StateHandle<String> = StateHandle {
            state_id: StateId::from_raw(0),
            _phantom: PhantomData,
        };
        let erased: StateHandleErased = original.into();

        // Same `T` as before erasure: the original handle comes back.
        assert_eq!(Ok(original), StateHandle::<String>::try_from(erased));
        // Different `T`: the erased handle is returned unchanged as the error.
        assert_eq!(Err(erased), StateHandle::<&'static str>::try_from(erased));
    }
}