dfir_lang/graph/ops/persist.rs
1use quote::quote_spanned;
2
3use super::{
4 OpInstGenerics, OperatorCategory, OperatorConstraints, OperatorInstance, OperatorWriteOutput,
5 Persistence, RANGE_0, RANGE_1, WriteContextArgs,
6};
7use crate::diagnostic::{Diagnostic, Level};
8
9/// Stores each item as it passes through, and replays all item every tick.
10///
11/// ```dfir
12/// // Normally `source_iter(...)` only emits once, but `persist::<'static>()` will replay the `"hello"`
13/// // on every tick.
14/// source_iter(["hello"])
15/// -> persist::<'static>()
16/// -> assert_eq(["hello"]);
17/// ```
18///
19/// `persist()` can be used to introduce statefulness into stateless pipelines. In the example below, the
20/// join only stores data for single tick. The `persist::<'static>()` operator introduces statefulness
21/// across ticks. This can be useful for optimization transformations within the dfir
22/// compiler. Equivalently, we could specify that the join has `static` persistence (`my_join = join::<'static>()`).
23/// ```rustbook
24/// let (input_send, input_recv) = dfir_rs::util::unbounded_channel::<(&str, &str)>();
25/// let mut flow = dfir_rs::dfir_syntax! {
26/// source_iter([("hello", "world")]) -> persist::<'static>() -> [0]my_join;
27/// source_stream(input_recv) -> persist::<'static>() -> [1]my_join;
28/// my_join = join::<'tick>() -> for_each(|(k, (v1, v2))| println!("({}, ({}, {}))", k, v1, v2));
29/// };
30/// input_send.send(("hello", "oakland")).unwrap();
31/// flow.run_tick();
32/// input_send.send(("hello", "san francisco")).unwrap();
33/// flow.run_tick();
34/// // (hello, (world, oakland))
35/// // (hello, (world, oakland))
36/// // (hello, (world, san francisco))
37/// ```
38pub const PERSIST: OperatorConstraints = OperatorConstraints {
39 name: "persist",
40 categories: &[OperatorCategory::Persistence],
41 hard_range_inn: RANGE_1,
42 soft_range_inn: RANGE_1,
43 hard_range_out: RANGE_1,
44 soft_range_out: RANGE_1,
45 num_args: 0,
46 persistence_args: RANGE_1,
47 type_args: RANGE_0,
48 is_external_input: false,
49 has_singleton_output: true,
50 flo_type: None,
51 ports_inn: None,
52 ports_out: None,
53 input_delaytype_fn: |_| None,
54 write_fn: |wc @ &WriteContextArgs {
55 root,
56 context,
57 df_ident,
58 op_span,
59 ident,
60 is_pull,
61 inputs,
62 outputs,
63 singleton_output_ident,
64 op_name,
65 work_fn,
66 op_inst:
67 OperatorInstance {
68 generics:
69 OpInstGenerics {
70 persistence_args, ..
71 },
72 ..
73 },
74 ..
75 },
76 diagnostics| {
77 if [Persistence::Static] != persistence_args[..] {
78 diagnostics.push(Diagnostic::spanned(
79 op_span,
80 Level::Error,
81 format!("{} only supports `'static`.", op_name),
82 ));
83 }
84
85 let persistdata_ident = singleton_output_ident;
86 let vec_ident = wc.make_ident("persistvec");
87 let write_prologue = quote_spanned! {op_span=>
88 let #persistdata_ident = #df_ident.add_state(::std::cell::RefCell::new(
89 ::std::vec::Vec::new(),
90 ));
91 };
92
93 let write_iterator = if is_pull {
94 let input = &inputs[0];
95 quote_spanned! {op_span=>
96 let mut #vec_ident = unsafe {
97 // SAFETY: handle from `#df_ident.add_state(..)`.
98 #context.state_ref_unchecked(#persistdata_ident)
99 }.borrow_mut();
100
101 let #ident = {
102 if #context.is_first_run_this_tick() {
103 #work_fn(|| #vec_ident.extend(#input));
104 #vec_ident.iter().cloned()
105 } else {
106 let len = #vec_ident.len();
107 #work_fn(|| #vec_ident.extend(#input));
108 #vec_ident[len..].iter().cloned()
109 }
110 };
111 }
112 } else {
113 let output = &outputs[0];
114 quote_spanned! {op_span=>
115 let mut #vec_ident = unsafe {
116 // SAFETY: handle from `#df_ident.add_state(..)`.
117 #context.state_ref_unchecked(#persistdata_ident)
118 }.borrow_mut();
119
120 let #ident = {
121 fn constrain_types<'ctx, Push, Item>(vec: &'ctx mut Vec<Item>, mut output: Push, is_new_tick: bool) -> impl 'ctx + #root::pusherator::Pusherator<Item = Item>
122 where
123 Push: 'ctx + #root::pusherator::Pusherator<Item = Item>,
124 Item: ::std::clone::Clone,
125 {
126 if is_new_tick {
127 #work_fn(|| vec.iter().cloned().for_each(|item| {
128 #root::pusherator::Pusherator::give(&mut output, item);
129 }));
130 }
131 #root::pusherator::map::Map::new(|item| {
132 vec.push(item);
133 vec.last().unwrap().clone()
134 }, output)
135 }
136 constrain_types(&mut *#vec_ident, #output, #context.is_first_run_this_tick())
137 };
138 }
139 };
140
141 let write_iterator_after = quote_spanned! {op_span=>
142 #context.schedule_subgraph(#context.current_subgraph(), false);
143 };
144
145 Ok(OperatorWriteOutput {
146 write_prologue,
147 write_iterator,
148 write_iterator_after,
149 })
150 },
151};