dfir_lang/graph/ops/
persist_mut_keyed.rs

1use quote::quote_spanned;
2
3use super::{
4    DelayType, OpInstGenerics, OperatorCategory, OperatorConstraints, OperatorInstance,
5    OperatorWriteOutput, Persistence, WriteContextArgs, RANGE_0, RANGE_1,
6};
7use crate::diagnostic::{Diagnostic, Level};
8
9/// `persist_mut_keyed()` is similar to `persist_mut()` except that it also enables key-based deletions
10/// `persist_mut()` expects an input of type [`PersistenceKeyed<T>`](https://docs.rs/dfir_rs/latest/dfir_rs/util/enum.PersistenceKeyed.html),
11/// and it is this enumeration that enables the user to communicate deletion.
12/// Deletions/persists happen in the order they are received in the stream.
13/// For example, `[Persist(1), Delete(1), Persist(1)]` will result in a a single `1` value being stored.
14///
15/// ```dfir
16/// use dfir_rs::util::PersistenceKeyed;
17///
18/// source_iter([
19///         PersistenceKeyed::Persist(0, 1),
20///         PersistenceKeyed::Persist(1, 1),
21///         PersistenceKeyed::Delete(1),
22///     ])
23///     -> persist_mut_keyed::<'static>()
24///     -> assert_eq([(0, 1)]);
25/// ```
26pub const PERSIST_MUT_KEYED: OperatorConstraints = OperatorConstraints {
27    name: "persist_mut_keyed",
28    categories: &[OperatorCategory::Persistence],
29    hard_range_inn: RANGE_1,
30    soft_range_inn: RANGE_1,
31    hard_range_out: RANGE_1,
32    soft_range_out: RANGE_1,
33    num_args: 0,
34    persistence_args: RANGE_1,
35    type_args: RANGE_0,
36    is_external_input: false,
37    // If this is set to true, the state will need to be cleared using `#context.set_state_tick_hook`
38    // to prevent reading uncleared data if this subgraph doesn't run.
39    // https://github.com/hydro-project/hydro/issues/1298
40    // If `'tick` lifetimes are added.
41    has_singleton_output: false,
42    flo_type: None,
43    ports_inn: None,
44    ports_out: None,
45    input_delaytype_fn: |_| Some(DelayType::Stratum),
46    write_fn: |wc @ &WriteContextArgs {
47                   root,
48                   context,
49                   df_ident,
50                   op_span,
51                   ident,
52                   inputs,
53                   is_pull,
54                   op_name,
55                   op_inst:
56                       OperatorInstance {
57                           generics:
58                               OpInstGenerics {
59                                   persistence_args, ..
60                               },
61                           ..
62                       },
63                   ..
64               },
65               diagnostics| {
66        assert!(is_pull);
67
68        if [Persistence::Static] != persistence_args[..] {
69            diagnostics.push(Diagnostic::spanned(
70                op_span,
71                Level::Error,
72                format!("{} only supports `'static`.", op_name),
73            ));
74        }
75
76        let persistdata_ident = wc.make_ident("persistdata");
77        let vec_ident = wc.make_ident("persistvec");
78        let write_prologue = quote_spanned! {op_span=>
79            let #persistdata_ident = #df_ident.add_state(::std::cell::RefCell::new(
80                #root::rustc_hash::FxHashMap::<_, #root::util::sparse_vec::SparseVec<_>>::default()
81            ));
82        };
83
84        let write_iterator = {
85            let input = &inputs[0];
86            quote_spanned! {op_span=>
87                let mut #vec_ident = unsafe {
88                    // SAFETY: handle from `#df_ident.add_state(..)`.
89                    #context.state_ref_unchecked(#persistdata_ident)
90                }.borrow_mut();
91
92                let #ident = {
93                    #[inline(always)]
94                    fn check_iter<K, V>(iter: impl Iterator<Item = #root::util::PersistenceKeyed::<K, V>>) -> impl Iterator<Item = #root::util::PersistenceKeyed::<K, V>> {
95                        iter
96                    }
97
98                    if context.is_first_run_this_tick() {
99                        for item in check_iter(#input) {
100                            match item {
101                                #root::util::PersistenceKeyed::Persist(k, v) => {
102                                    #vec_ident.entry(k).or_default().push(v);
103                                },
104                                #root::util::PersistenceKeyed::Delete(k) => {
105                                    #vec_ident.remove(&k);
106                                }
107                            }
108                        }
109
110                        #[allow(clippy::clone_on_copy)]
111                        Some(#vec_ident.iter()
112                            .flat_map(|(k, v)| v.iter().map(move |v| (k.clone(), v.clone()))))
113                        .into_iter()
114                        .flatten()
115                    } else {
116                        None.into_iter().flatten()
117                    }
118                };
119            }
120        };
121
122        let write_iterator_after = quote_spanned! {op_span=>
123            #context.schedule_subgraph(#context.current_subgraph(), false);
124        };
125
126        Ok(OperatorWriteOutput {
127            write_prologue,
128            write_iterator,
129            write_iterator_after,
130        })
131    },
132};