// dfir_lang/graph/ops/persist_mut.rs

1use quote::quote_spanned;
2
3use super::{
4    DelayType, OpInstGenerics, OperatorCategory, OperatorConstraints, OperatorInstance,
5    OperatorWriteOutput, Persistence, RANGE_0, RANGE_1, WriteContextArgs,
6};
7use crate::diagnostic::{Diagnostic, Level};
8
9/// `persist_mut()` is similar to `persist()` except that it also enables deletions.
10/// `persist_mut()` expects an input of type [`Persistence<T>`](https://docs.rs/dfir_rs/latest/dfir_rs/util/enum.Persistence.html),
11/// and it is this enumeration that enables the user to communicate deletion.
12/// Deletions/persists happen in the order they are received in the stream.
13/// For example, `[Persist(1), Delete(1), Persist(1)]` will result in a a single `1` value being stored.
14///
15/// ```dfir
16/// use dfir_rs::util::Persistence;
17///
18/// source_iter([
19///         Persistence::Persist(1),
20///         Persistence::Persist(2),
21///         Persistence::Delete(1),
22///     ])
23///     -> persist_mut::<'mutable>()
24///     -> assert_eq([2]);
25/// ```
26pub const PERSIST_MUT: OperatorConstraints = OperatorConstraints {
27    name: "persist_mut",
28    categories: &[OperatorCategory::Persistence],
29    hard_range_inn: RANGE_1,
30    soft_range_inn: RANGE_1,
31    hard_range_out: RANGE_1,
32    soft_range_out: RANGE_1,
33    num_args: 0,
34    persistence_args: RANGE_1,
35    type_args: RANGE_0,
36    is_external_input: false,
37    // If this is set to true, the state will need to be cleared using `#context.set_state_lifespan_hook`
38    // to prevent reading uncleared data if this subgraph doesn't run.
39    // https://github.com/hydro-project/hydro/issues/1298
40    // If `'tick` lifetimes are added.
41    has_singleton_output: false,
42    flo_type: None,
43    ports_inn: None,
44    ports_out: None,
45    input_delaytype_fn: |_| Some(DelayType::Stratum),
46    write_fn: |wc @ &WriteContextArgs {
47                   root,
48                   context,
49                   df_ident,
50                   op_span,
51                   ident,
52                   inputs,
53                   is_pull,
54                   op_name,
55                   op_inst:
56                       OperatorInstance {
57                           generics:
58                               OpInstGenerics {
59                                   persistence_args, ..
60                               },
61                           ..
62                       },
63                   ..
64               },
65               diagnostics| {
66        assert!(is_pull);
67
68        if [Persistence::Mutable] != persistence_args[..] {
69            diagnostics.push(Diagnostic::spanned(
70                op_span,
71                Level::Error,
72                format!(
73                    "{} only supports `'{}`.",
74                    op_name,
75                    Persistence::Mutable.to_str_lowercase()
76                ),
77            ));
78        }
79
80        let persistdata_ident = wc.make_ident("persistdata");
81        let vec_ident = wc.make_ident("persistvec");
82        let write_prologue = quote_spanned! {op_span=>
83            let #persistdata_ident = #df_ident.add_state(::std::cell::RefCell::new(
84                #root::util::sparse_vec::SparseVec::default(),
85            ));
86        };
87
88        let write_iterator = {
89            let input = &inputs[0];
90            quote_spanned! {op_span=>
91                let mut #vec_ident = unsafe {
92                    // SAFETY: handle from `#df_ident.add_state(..)`.
93                    #context.state_ref_unchecked(#persistdata_ident)
94                }.borrow_mut();
95
96                let #ident = {
97                    #[inline(always)]
98                    fn check_iter<T: ::std::hash::Hash + ::std::cmp::Eq>(iter: impl Iterator<Item = #root::util::Persistence::<T>>) -> impl Iterator<Item = #root::util::Persistence::<T>> {
99                        iter
100                    }
101
102                    if context.is_first_run_this_tick() {
103                        for item in check_iter(#input) {
104                            match item {
105                                #root::util::Persistence::Persist(v) => #vec_ident.push(v),
106                                #root::util::Persistence::Delete(v) => #vec_ident.delete(&v),
107                            }
108                        }
109
110                        Some(#vec_ident.iter().cloned()).into_iter().flatten()
111                    } else {
112                        None.into_iter().flatten()
113                    }
114                };
115            }
116        };
117
118        let write_iterator_after = quote_spanned! {op_span=>
119            #context.schedule_subgraph(#context.current_subgraph(), false);
120        };
121
122        Ok(OperatorWriteOutput {
123            write_prologue,
124            write_iterator,
125            write_iterator_after,
126            ..Default::default()
127        })
128    },
129};