dfir_lang/graph/ops/
persist_mut_keyed.rs
1use quote::quote_spanned;
2
3use super::{
4 DelayType, OpInstGenerics, OperatorCategory, OperatorConstraints, OperatorInstance,
5 OperatorWriteOutput, Persistence, RANGE_0, RANGE_1, WriteContextArgs,
6};
7use crate::diagnostic::{Diagnostic, Level};
8
9pub const PERSIST_MUT_KEYED: OperatorConstraints = OperatorConstraints {
27 name: "persist_mut_keyed",
28 categories: &[OperatorCategory::Persistence],
29 hard_range_inn: RANGE_1,
30 soft_range_inn: RANGE_1,
31 hard_range_out: RANGE_1,
32 soft_range_out: RANGE_1,
33 num_args: 0,
34 persistence_args: RANGE_1,
35 type_args: RANGE_0,
36 is_external_input: false,
37 has_singleton_output: false,
42 flo_type: None,
43 ports_inn: None,
44 ports_out: None,
45 input_delaytype_fn: |_| Some(DelayType::Stratum),
46 write_fn: |wc @ &WriteContextArgs {
47 root,
48 context,
49 df_ident,
50 op_span,
51 ident,
52 inputs,
53 is_pull,
54 op_name,
55 op_inst:
56 OperatorInstance {
57 generics:
58 OpInstGenerics {
59 persistence_args, ..
60 },
61 ..
62 },
63 ..
64 },
65 diagnostics| {
66 assert!(is_pull);
67
68 if [Persistence::Mutable] != persistence_args[..] {
69 diagnostics.push(Diagnostic::spanned(
70 op_span,
71 Level::Error,
72 format!(
73 "{} only supports `'{}`.",
74 op_name,
75 Persistence::Mutable.to_str_lowercase()
76 ),
77 ));
78 }
79
80 let persistdata_ident = wc.make_ident("persistdata");
81 let vec_ident = wc.make_ident("persistvec");
82 let write_prologue = quote_spanned! {op_span=>
83 let #persistdata_ident = #df_ident.add_state(::std::cell::RefCell::new(
84 #root::rustc_hash::FxHashMap::<_, #root::util::sparse_vec::SparseVec<_>>::default()
85 ));
86 };
87
88 let write_iterator = {
89 let input = &inputs[0];
90 quote_spanned! {op_span=>
91 let mut #vec_ident = unsafe {
92 #context.state_ref_unchecked(#persistdata_ident)
94 }.borrow_mut();
95
96 let #ident = {
97 #[inline(always)]
98 fn check_iter<K, V>(iter: impl Iterator<Item = #root::util::PersistenceKeyed::<K, V>>) -> impl Iterator<Item = #root::util::PersistenceKeyed::<K, V>> {
99 iter
100 }
101
102 if context.is_first_run_this_tick() {
103 for item in check_iter(#input) {
104 match item {
105 #root::util::PersistenceKeyed::Persist(k, v) => {
106 #vec_ident.entry(k).or_default().push(v);
107 },
108 #root::util::PersistenceKeyed::Delete(k) => {
109 #vec_ident.remove(&k);
110 }
111 }
112 }
113
114 #[allow(clippy::clone_on_copy)]
115 Some(#vec_ident.iter()
116 .flat_map(|(k, v)| v.iter().map(move |v| (k.clone(), v.clone()))))
117 .into_iter()
118 .flatten()
119 } else {
120 None.into_iter().flatten()
121 }
122 };
123 }
124 };
125
126 let write_iterator_after = quote_spanned! {op_span=>
127 #context.schedule_subgraph(#context.current_subgraph(), false);
128 };
129
130 Ok(OperatorWriteOutput {
131 write_prologue,
132 write_iterator,
133 write_iterator_after,
134 ..Default::default()
135 })
136 },
137};