dfir_lang/graph/ops/
persist_mut_keyed.rs
1use quote::quote_spanned;
2
3use super::{
4 DelayType, OpInstGenerics, OperatorCategory, OperatorConstraints, OperatorInstance,
5 OperatorWriteOutput, Persistence, WriteContextArgs, RANGE_0, RANGE_1,
6};
7use crate::diagnostic::{Diagnostic, Level};
8
9pub const PERSIST_MUT_KEYED: OperatorConstraints = OperatorConstraints {
27 name: "persist_mut_keyed",
28 categories: &[OperatorCategory::Persistence],
29 hard_range_inn: RANGE_1,
30 soft_range_inn: RANGE_1,
31 hard_range_out: RANGE_1,
32 soft_range_out: RANGE_1,
33 num_args: 0,
34 persistence_args: RANGE_1,
35 type_args: RANGE_0,
36 is_external_input: false,
37 has_singleton_output: false,
42 flo_type: None,
43 ports_inn: None,
44 ports_out: None,
45 input_delaytype_fn: |_| Some(DelayType::Stratum),
46 write_fn: |wc @ &WriteContextArgs {
47 root,
48 context,
49 df_ident,
50 op_span,
51 ident,
52 inputs,
53 is_pull,
54 op_name,
55 op_inst:
56 OperatorInstance {
57 generics:
58 OpInstGenerics {
59 persistence_args, ..
60 },
61 ..
62 },
63 ..
64 },
65 diagnostics| {
66 assert!(is_pull);
67
68 if [Persistence::Static] != persistence_args[..] {
69 diagnostics.push(Diagnostic::spanned(
70 op_span,
71 Level::Error,
72 format!("{} only supports `'static`.", op_name),
73 ));
74 }
75
76 let persistdata_ident = wc.make_ident("persistdata");
77 let vec_ident = wc.make_ident("persistvec");
78 let write_prologue = quote_spanned! {op_span=>
79 let #persistdata_ident = #df_ident.add_state(::std::cell::RefCell::new(
80 #root::rustc_hash::FxHashMap::<_, #root::util::sparse_vec::SparseVec<_>>::default()
81 ));
82 };
83
84 let write_iterator = {
85 let input = &inputs[0];
86 quote_spanned! {op_span=>
87 let mut #vec_ident = unsafe {
88 #context.state_ref_unchecked(#persistdata_ident)
90 }.borrow_mut();
91
92 let #ident = {
93 #[inline(always)]
94 fn check_iter<K, V>(iter: impl Iterator<Item = #root::util::PersistenceKeyed::<K, V>>) -> impl Iterator<Item = #root::util::PersistenceKeyed::<K, V>> {
95 iter
96 }
97
98 if context.is_first_run_this_tick() {
99 for item in check_iter(#input) {
100 match item {
101 #root::util::PersistenceKeyed::Persist(k, v) => {
102 #vec_ident.entry(k).or_default().push(v);
103 },
104 #root::util::PersistenceKeyed::Delete(k) => {
105 #vec_ident.remove(&k);
106 }
107 }
108 }
109
110 #[allow(clippy::clone_on_copy)]
111 Some(#vec_ident.iter()
112 .flat_map(|(k, v)| v.iter().map(move |v| (k.clone(), v.clone()))))
113 .into_iter()
114 .flatten()
115 } else {
116 None.into_iter().flatten()
117 }
118 };
119 }
120 };
121
122 let write_iterator_after = quote_spanned! {op_span=>
123 #context.schedule_subgraph(#context.current_subgraph(), false);
124 };
125
126 Ok(OperatorWriteOutput {
127 write_prologue,
128 write_iterator,
129 write_iterator_after,
130 })
131 },
132};