dfir_rs/scheduled/
context.rs

1//! Module for the user-facing [`Context`] object.
2//!
3//! Provides APIs for state and scheduling.
4
5use std::any::Any;
6use std::cell::Cell;
7use std::collections::VecDeque;
8use std::future::Future;
9use std::marker::PhantomData;
10use std::ops::DerefMut;
11use std::pin::Pin;
12
13use tokio::sync::mpsc::{self, UnboundedReceiver, UnboundedSender};
14use tokio::task::JoinHandle;
15use web_time::SystemTime;
16
17use super::state::StateHandle;
18use super::{LoopTag, StateId, SubgraphId};
19use crate::scheduled::ticks::TickInstant;
20use crate::util::priority_stack::PriorityStack;
21use crate::util::slot_vec::SlotVec;
22
23/// The main state and scheduler of the runtime instance. Provided as the `context` API to each
24/// subgraph/operator as it is run.
25///
26/// Each instance stores eactly one Context inline. Before the `Context` is provided to
27/// a running operator, the `subgraph_id` field must be updated.
28pub struct Context {
29    /// User-facing State API.
30    states: Vec<StateData>,
31
32    /// Priority stack for handling strata within loops. Prioritized by loop depth.
33    pub(super) stratum_stack: PriorityStack<usize>,
34
35    /// Stack of loop nonces. Used to identify when a new loop iteration begins.
36    pub(super) loop_nonce_stack: Vec<usize>,
37
38    /// TODO(mingwei):
39    /// used for loop iteration scheduling.
40    pub(super) schedule_deferred: Vec<SubgraphId>,
41
42    /// TODO(mingwei): separate scheduler into its own struct/trait?
43    /// Index is stratum, value is FIFO queue for that stratum.
44    pub(super) stratum_queues: Vec<VecDeque<SubgraphId>>,
45
46    /// Receive events, if second arg indicates if it is an external "important" event (true).
47    pub(super) event_queue_recv: UnboundedReceiver<(SubgraphId, bool)>,
48    /// If external events or data can justify starting the next tick.
49    pub(super) can_start_tick: bool,
50    /// If the events have been received for this tick.
51    pub(super) events_received_tick: bool,
52
53    // TODO(mingwei): as long as this is here, it's impossible to know when all work is done.
54    // Second field (bool) is for if the event is an external "important" event (true).
55    pub(super) event_queue_send: UnboundedSender<(SubgraphId, bool)>,
56
57    /// If the current subgraph wants to reschedule the current loop block (in the current tick).
58    pub(super) reschedule_loop_block: Cell<bool>,
59    pub(super) allow_another_iteration: Cell<bool>,
60
61    pub(super) current_tick: TickInstant,
62    pub(super) current_stratum: usize,
63
64    pub(super) current_tick_start: SystemTime,
65    pub(super) is_first_run_this_tick: bool,
66    pub(super) loop_iter_count: usize,
67
68    // Depth of loop (zero for top-level).
69    pub(super) loop_depth: SlotVec<LoopTag, usize>,
70
71    pub(super) loop_nonce: usize,
72
73    /// The SubgraphId of the currently running operator. When this context is
74    /// not being forwarded to a running operator, this field is meaningless.
75    pub(super) subgraph_id: SubgraphId,
76
77    tasks_to_spawn: Vec<Pin<Box<dyn Future<Output = ()> + 'static>>>,
78    /// Join handles for spawned tasks.
79    task_join_handles: Vec<JoinHandle<()>>,
80}
81/// Public APIs.
82impl Context {
83    /// Gets the current tick (local time) count.
84    pub fn current_tick(&self) -> TickInstant {
85        self.current_tick
86    }
87
88    /// Gets the timestamp of the beginning of the current tick.
89    pub fn current_tick_start(&self) -> SystemTime {
90        self.current_tick_start
91    }
92
93    /// Gets whether this is the first time this subgraph is being scheduled for this tick
94    pub fn is_first_run_this_tick(&self) -> bool {
95        self.is_first_run_this_tick
96    }
97
98    /// Gets the current loop iteration count.
99    pub fn loop_iter_count(&self) -> usize {
100        self.loop_iter_count
101    }
102
103    /// Gets the current stratum nubmer.
104    pub fn current_stratum(&self) -> usize {
105        self.current_stratum
106    }
107
108    /// Gets the ID of the current subgraph.
109    pub fn current_subgraph(&self) -> SubgraphId {
110        self.subgraph_id
111    }
112
113    /// Schedules a subgraph for the next tick.
114    ///
115    /// If `is_external` is `true`, the scheduling will trigger the next tick to begin. If it is
116    /// `false` then scheduling will be lazy and the next tick will not begin unless there is other
117    /// reason to.
118    pub fn schedule_subgraph(&self, sg_id: SubgraphId, is_external: bool) {
119        self.event_queue_send.send((sg_id, is_external)).unwrap()
120    }
121
122    /// Schedules the current loop block to be run again (_in this tick_).
123    pub fn reschedule_loop_block(&self) {
124        self.reschedule_loop_block.set(true);
125    }
126
127    /// Allow another iteration of the loop, if more data comes.
128    pub fn allow_another_iteration(&self) {
129        self.allow_another_iteration.set(true);
130    }
131
132    /// Returns a `Waker` for interacting with async Rust.
133    /// Waker events are considered to be extenral.
134    pub fn waker(&self) -> std::task::Waker {
135        use std::sync::Arc;
136
137        use futures::task::ArcWake;
138
139        struct ContextWaker {
140            subgraph_id: SubgraphId,
141            event_queue_send: UnboundedSender<(SubgraphId, bool)>,
142        }
143        impl ArcWake for ContextWaker {
144            fn wake_by_ref(arc_self: &Arc<Self>) {
145                let _recv_closed_error =
146                    arc_self.event_queue_send.send((arc_self.subgraph_id, true));
147            }
148        }
149
150        let context_waker = ContextWaker {
151            subgraph_id: self.subgraph_id,
152            event_queue_send: self.event_queue_send.clone(),
153        };
154        futures::task::waker(Arc::new(context_waker))
155    }
156
157    /// Returns a shared reference to the state.
158    ///
159    /// # Safety
160    /// `StateHandle<T>` must be from _this_ instance, created via [`Self::add_state`].
161    pub unsafe fn state_ref_unchecked<T>(&self, handle: StateHandle<T>) -> &'_ T
162    where
163        T: Any,
164    {
165        let state = self
166            .states
167            .get(handle.state_id.0)
168            .expect("Failed to find state with given handle.")
169            .state
170            .as_ref();
171
172        debug_assert!(state.is::<T>());
173
174        unsafe {
175            // SAFETY: `handle` is from this instance.
176            // TODO(shadaj): replace with `downcast_ref_unchecked` when it's stabilized
177            &*(state as *const dyn Any as *const T)
178        }
179    }
180
181    /// Returns a shared reference to the state.
182    pub fn state_ref<T>(&self, handle: StateHandle<T>) -> &'_ T
183    where
184        T: Any,
185    {
186        self.states
187            .get(handle.state_id.0)
188            .expect("Failed to find state with given handle.")
189            .state
190            .downcast_ref()
191            .expect("StateHandle wrong type T for casting.")
192    }
193
194    /// Returns an exclusive reference to the state.
195    pub fn state_mut<T>(&mut self, handle: StateHandle<T>) -> &'_ mut T
196    where
197        T: Any,
198    {
199        self.states
200            .get_mut(handle.state_id.0)
201            .expect("Failed to find state with given handle.")
202            .state
203            .downcast_mut()
204            .expect("StateHandle wrong type T for casting.")
205    }
206
207    /// Adds state to the context and returns the handle.
208    pub fn add_state<T>(&mut self, state: T) -> StateHandle<T>
209    where
210        T: Any,
211    {
212        let state_id = StateId(self.states.len());
213
214        let state_data = StateData {
215            state: Box::new(state),
216            tick_reset: None,
217        };
218        self.states.push(state_data);
219
220        StateHandle {
221            state_id,
222            _phantom: PhantomData,
223        }
224    }
225
226    /// Sets a hook to modify the state at the end of each tick, using the supplied closure.
227    pub fn set_state_tick_hook<T>(
228        &mut self,
229        handle: StateHandle<T>,
230        mut tick_hook_fn: impl 'static + FnMut(&mut T),
231    ) where
232        T: Any,
233    {
234        self.states
235            .get_mut(handle.state_id.0)
236            .expect("Failed to find state with given handle.")
237            .tick_reset = Some(Box::new(move |state| {
238            (tick_hook_fn)(state.downcast_mut::<T>().unwrap());
239        }));
240    }
241
242    /// Removes state from the context returns it as an owned heap value.
243    pub fn remove_state<T>(&mut self, handle: StateHandle<T>) -> Box<T>
244    where
245        T: Any,
246    {
247        self.states
248            .remove(handle.state_id.0)
249            .state
250            .downcast()
251            .expect("StateHandle wrong type T for casting.")
252    }
253
254    /// Prepares an async task to be launched by [`Self::spawn_tasks`].
255    pub fn request_task<Fut>(&mut self, future: Fut)
256    where
257        Fut: Future<Output = ()> + 'static,
258    {
259        self.tasks_to_spawn.push(Box::pin(future));
260    }
261
262    /// Launches all tasks requested with [`Self::request_task`] on the internal Tokio executor.
263    pub fn spawn_tasks(&mut self) {
264        for task in self.tasks_to_spawn.drain(..) {
265            self.task_join_handles.push(tokio::task::spawn_local(task));
266        }
267    }
268
269    /// Aborts all tasks spawned with [`Self::spawn_tasks`].
270    pub fn abort_tasks(&mut self) {
271        for task in self.task_join_handles.drain(..) {
272            task.abort();
273        }
274    }
275
276    /// Waits for all tasks spawned with [`Self::spawn_tasks`] to complete.
277    ///
278    /// Will probably just hang.
279    pub async fn join_tasks(&mut self) {
280        futures::future::join_all(self.task_join_handles.drain(..)).await;
281    }
282}
283
284impl Default for Context {
285    fn default() -> Self {
286        let stratum_queues = vec![Default::default()]; // Always initialize stratum #0.
287        let (event_queue_send, event_queue_recv) = mpsc::unbounded_channel();
288        let (stratum_stack, loop_depth) = Default::default();
289        Self {
290            states: Vec::new(),
291
292            stratum_stack,
293
294            loop_nonce_stack: Vec::new(),
295
296            schedule_deferred: Vec::new(),
297
298            stratum_queues,
299            event_queue_recv,
300            can_start_tick: false,
301            events_received_tick: false,
302
303            event_queue_send,
304            reschedule_loop_block: Cell::new(false),
305            allow_another_iteration: Cell::new(false),
306
307            current_stratum: 0,
308            current_tick: TickInstant::default(),
309
310            current_tick_start: SystemTime::now(),
311            is_first_run_this_tick: false,
312            loop_iter_count: 0,
313
314            loop_depth,
315            loop_nonce: 0,
316
317            // Will be re-set before use.
318            subgraph_id: SubgraphId::from_raw(0),
319
320            tasks_to_spawn: Vec::new(),
321            task_join_handles: Vec::new(),
322        }
323    }
324}
325/// Internal APIs.
326impl Context {
327    /// Makes sure stratum STRATUM is initialized.
328    pub(super) fn init_stratum(&mut self, stratum: usize) {
329        if self.stratum_queues.len() <= stratum {
330            self.stratum_queues
331                .resize_with(stratum + 1, Default::default);
332        }
333    }
334
335    /// Call this at the end of a tick,
336    pub(super) fn reset_state_at_end_of_tick(&mut self) {
337        for StateData { state, tick_reset } in self.states.iter_mut() {
338            if let Some(tick_reset) = tick_reset {
339                (tick_reset)(Box::deref_mut(state));
340            }
341        }
342    }
343}
344
/// Boxed per-state hook; receives the type-erased state so it can mutate it in place.
type TickResetFn = Box<dyn FnMut(&mut dyn Any)>;

/// One type-erased state value owned by the context, plus its optional end-of-tick hook.
struct StateData {
    /// The state value itself, boxed and type-erased.
    state: Box<dyn Any>,
    /// Hook applied to `state` by `reset_state_at_end_of_tick`, if one has been set.
    tick_reset: Option<TickResetFn>,
}