dfir_lang/graph/ops/
persist_mut_keyed.rs

1use quote::quote_spanned;
2
3use super::{
4    DelayType, OpInstGenerics, OperatorCategory, OperatorConstraints, OperatorInstance,
5    OperatorWriteOutput, Persistence, RANGE_0, RANGE_1, WriteContextArgs,
6};
7use crate::diagnostic::{Diagnostic, Level};
8
9/// `persist_mut_keyed()` is similar to `persist_mut()` except that it also enables key-based deletions
10/// `persist_mut()` expects an input of type [`PersistenceKeyed<T>`](https://docs.rs/dfir_rs/latest/dfir_rs/util/enum.PersistenceKeyed.html),
11/// and it is this enumeration that enables the user to communicate deletion.
12/// Deletions/persists happen in the order they are received in the stream.
13/// For example, `[Persist(1), Delete(1), Persist(1)]` will result in a a single `1` value being stored.
14///
15/// ```dfir
16/// use dfir_rs::util::PersistenceKeyed;
17///
18/// source_iter([
19///         PersistenceKeyed::Persist(0, 1),
20///         PersistenceKeyed::Persist(1, 1),
21///         PersistenceKeyed::Delete(1),
22///     ])
23///     -> persist_mut_keyed::<'mutable>()
24///     -> assert_eq([(0, 1)]);
25/// ```
26pub const PERSIST_MUT_KEYED: OperatorConstraints = OperatorConstraints {
27    name: "persist_mut_keyed",
28    categories: &[OperatorCategory::Persistence],
29    hard_range_inn: RANGE_1,
30    soft_range_inn: RANGE_1,
31    hard_range_out: RANGE_1,
32    soft_range_out: RANGE_1,
33    num_args: 0,
34    persistence_args: RANGE_1,
35    type_args: RANGE_0,
36    is_external_input: false,
37    // If this is set to true, the state will need to be cleared using `#context.set_state_lifespan_hook`
38    // to prevent reading uncleared data if this subgraph doesn't run.
39    // https://github.com/hydro-project/hydro/issues/1298
40    // If `'tick` lifetimes are added.
41    has_singleton_output: false,
42    flo_type: None,
43    ports_inn: None,
44    ports_out: None,
45    input_delaytype_fn: |_| Some(DelayType::Stratum),
46    write_fn: |wc @ &WriteContextArgs {
47                   root,
48                   context,
49                   df_ident,
50                   op_span,
51                   ident,
52                   inputs,
53                   is_pull,
54                   op_name,
55                   op_inst:
56                       OperatorInstance {
57                           generics:
58                               OpInstGenerics {
59                                   persistence_args, ..
60                               },
61                           ..
62                       },
63                   ..
64               },
65               diagnostics| {
66        assert!(is_pull);
67
68        if [Persistence::Mutable] != persistence_args[..] {
69            diagnostics.push(Diagnostic::spanned(
70                op_span,
71                Level::Error,
72                format!(
73                    "{} only supports `'{}`.",
74                    op_name,
75                    Persistence::Mutable.to_str_lowercase()
76                ),
77            ));
78        }
79
80        let persistdata_ident = wc.make_ident("persistdata");
81        let vec_ident = wc.make_ident("persistvec");
82        let write_prologue = quote_spanned! {op_span=>
83            let #persistdata_ident = #df_ident.add_state(::std::cell::RefCell::new(
84                #root::rustc_hash::FxHashMap::<_, #root::util::sparse_vec::SparseVec<_>>::default()
85            ));
86        };
87
88        let write_iterator = {
89            let input = &inputs[0];
90            quote_spanned! {op_span=>
91                let mut #vec_ident = unsafe {
92                    // SAFETY: handle from `#df_ident.add_state(..)`.
93                    #context.state_ref_unchecked(#persistdata_ident)
94                }.borrow_mut();
95
96                let #ident = {
97                    #[inline(always)]
98                    fn check_iter<K, V>(iter: impl Iterator<Item = #root::util::PersistenceKeyed::<K, V>>) -> impl Iterator<Item = #root::util::PersistenceKeyed::<K, V>> {
99                        iter
100                    }
101
102                    if context.is_first_run_this_tick() {
103                        for item in check_iter(#input) {
104                            match item {
105                                #root::util::PersistenceKeyed::Persist(k, v) => {
106                                    #vec_ident.entry(k).or_default().push(v);
107                                },
108                                #root::util::PersistenceKeyed::Delete(k) => {
109                                    #vec_ident.remove(&k);
110                                }
111                            }
112                        }
113
114                        #[allow(clippy::clone_on_copy)]
115                        Some(#vec_ident.iter()
116                            .flat_map(|(k, v)| v.iter().map(move |v| (k.clone(), v.clone()))))
117                        .into_iter()
118                        .flatten()
119                    } else {
120                        None.into_iter().flatten()
121                    }
122                };
123            }
124        };
125
126        let write_iterator_after = quote_spanned! {op_span=>
127            #context.schedule_subgraph(#context.current_subgraph(), false);
128        };
129
130        Ok(OperatorWriteOutput {
131            write_prologue,
132            write_iterator,
133            write_iterator_after,
134            ..Default::default()
135        })
136    },
137};