dfir_lang/graph/ops/
persist_mut.rs

1use quote::quote_spanned;
2
3use super::{
4    DelayType, OpInstGenerics, OperatorCategory, OperatorConstraints, OperatorInstance,
5    OperatorWriteOutput, Persistence, WriteContextArgs, RANGE_0, RANGE_1,
6};
7use crate::diagnostic::{Diagnostic, Level};
8
9/// `persist_mut()` is similar to `persist()` except that it also enables deletions.
10/// `persist_mut()` expects an input of type [`Persistence<T>`](https://docs.rs/dfir_rs/latest/dfir_rs/util/enum.Persistence.html),
11/// and it is this enumeration that enables the user to communicate deletion.
12/// Deletions/persists happen in the order they are received in the stream.
13/// For example, `[Persist(1), Delete(1), Persist(1)]` will result in a a single `1` value being stored.
14///
15/// ```dfir
16/// use dfir_rs::util::Persistence;
17///
18/// source_iter([
19///         Persistence::Persist(1),
20///         Persistence::Persist(2),
21///         Persistence::Delete(1),
22///     ])
23///     -> persist_mut::<'static>()
24///     -> assert_eq([2]);
25/// ```
26pub const PERSIST_MUT: OperatorConstraints = OperatorConstraints {
27    name: "persist_mut",
28    categories: &[OperatorCategory::Persistence],
29    hard_range_inn: RANGE_1,
30    soft_range_inn: RANGE_1,
31    hard_range_out: RANGE_1,
32    soft_range_out: RANGE_1,
33    num_args: 0,
34    persistence_args: RANGE_1,
35    type_args: RANGE_0,
36    is_external_input: false,
37    // If this is set to true, the state will need to be cleared using `#context.set_state_tick_hook`
38    // to prevent reading uncleared data if this subgraph doesn't run.
39    // https://github.com/hydro-project/hydro/issues/1298
40    // If `'tick` lifetimes are added.
41    has_singleton_output: false,
42    flo_type: None,
43    ports_inn: None,
44    ports_out: None,
45    input_delaytype_fn: |_| Some(DelayType::Stratum),
46    write_fn: |wc @ &WriteContextArgs {
47                   root,
48                   context,
49                   df_ident,
50                   op_span,
51                   ident,
52                   inputs,
53                   is_pull,
54                   op_name,
55                   op_inst:
56                       OperatorInstance {
57                           generics:
58                               OpInstGenerics {
59                                   persistence_args, ..
60                               },
61                           ..
62                       },
63                   ..
64               },
65               diagnostics| {
66        assert!(is_pull);
67
68        if [Persistence::Static] != persistence_args[..] {
69            diagnostics.push(Diagnostic::spanned(
70                op_span,
71                Level::Error,
72                format!("{} only supports `'static`.", op_name),
73            ));
74        }
75
76        let persistdata_ident = wc.make_ident("persistdata");
77        let vec_ident = wc.make_ident("persistvec");
78        let write_prologue = quote_spanned! {op_span=>
79            let #persistdata_ident = #df_ident.add_state(::std::cell::RefCell::new(
80                #root::util::sparse_vec::SparseVec::default(),
81            ));
82        };
83
84        let write_iterator = {
85            let input = &inputs[0];
86            quote_spanned! {op_span=>
87                let mut #vec_ident = unsafe {
88                    // SAFETY: handle from `#df_ident.add_state(..)`.
89                    #context.state_ref_unchecked(#persistdata_ident)
90                }.borrow_mut();
91
92                let #ident = {
93                    #[inline(always)]
94                    fn check_iter<T: ::std::hash::Hash + ::std::cmp::Eq>(iter: impl Iterator<Item = #root::util::Persistence::<T>>) -> impl Iterator<Item = #root::util::Persistence::<T>> {
95                        iter
96                    }
97
98                    if context.is_first_run_this_tick() {
99                        for item in check_iter(#input) {
100                            match item {
101                                #root::util::Persistence::Persist(v) => #vec_ident.push(v),
102                                #root::util::Persistence::Delete(v) => #vec_ident.delete(&v),
103                            }
104                        }
105
106                        Some(#vec_ident.iter().cloned()).into_iter().flatten()
107                    } else {
108                        None.into_iter().flatten()
109                    }
110                };
111            }
112        };
113
114        let write_iterator_after = quote_spanned! {op_span=>
115            #context.schedule_subgraph(#context.current_subgraph(), false);
116        };
117
118        Ok(OperatorWriteOutput {
119            write_prologue,
120            write_iterator,
121            write_iterator_after,
122        })
123    },
124};