// dfir_lang/graph/ops/persist.rs

1use quote::quote_spanned;
2
3use super::{
4    OpInstGenerics, OperatorCategory, OperatorConstraints, OperatorInstance, OperatorWriteOutput,
5    Persistence, RANGE_0, RANGE_1, WriteContextArgs,
6};
7use crate::diagnostic::{Diagnostic, Level};
8
9/// Stores each item as it passes through, and replays all item every tick.
10///
11/// ```dfir
12/// // Normally `source_iter(...)` only emits once, but `persist::<'static>()` will replay the `"hello"`
13/// // on every tick.
14/// source_iter(["hello"])
15///     -> persist::<'static>()
16///     -> assert_eq(["hello"]);
17/// ```
18///
19/// `persist()` can be used to introduce statefulness into stateless pipelines. In the example below, the
20/// join only stores data for single tick. The `persist::<'static>()` operator introduces statefulness
21/// across ticks. This can be useful for optimization transformations within the dfir
22/// compiler. Equivalently, we could specify that the join has `static` persistence (`my_join = join::<'static>()`).
23/// ```rustbook
24/// let (input_send, input_recv) = dfir_rs::util::unbounded_channel::<(&str, &str)>();
25/// let mut flow = dfir_rs::dfir_syntax! {
26///     source_iter([("hello", "world")]) -> persist::<'static>() -> [0]my_join;
27///     source_stream(input_recv) -> persist::<'static>() -> [1]my_join;
28///     my_join = join::<'tick>() -> for_each(|(k, (v1, v2))| println!("({}, ({}, {}))", k, v1, v2));
29/// };
30/// input_send.send(("hello", "oakland")).unwrap();
31/// flow.run_tick();
32/// input_send.send(("hello", "san francisco")).unwrap();
33/// flow.run_tick();
34/// // (hello, (world, oakland))
35/// // (hello, (world, oakland))
36/// // (hello, (world, san francisco))
37/// ```
38pub const PERSIST: OperatorConstraints = OperatorConstraints {
39    name: "persist",
40    categories: &[OperatorCategory::Persistence],
41    hard_range_inn: RANGE_1,
42    soft_range_inn: RANGE_1,
43    hard_range_out: RANGE_1,
44    soft_range_out: RANGE_1,
45    num_args: 0,
46    persistence_args: RANGE_1,
47    type_args: RANGE_0,
48    is_external_input: false,
49    has_singleton_output: true,
50    flo_type: None,
51    ports_inn: None,
52    ports_out: None,
53    input_delaytype_fn: |_| None,
54    write_fn: |wc @ &WriteContextArgs {
55                   root,
56                   context,
57                   df_ident,
58                   op_span,
59                   ident,
60                   is_pull,
61                   inputs,
62                   outputs,
63                   singleton_output_ident,
64                   op_name,
65                   work_fn,
66                   op_inst:
67                       OperatorInstance {
68                           generics:
69                               OpInstGenerics {
70                                   persistence_args, ..
71                               },
72                           ..
73                       },
74                   ..
75               },
76               diagnostics| {
77        if [Persistence::Static] != persistence_args[..] {
78            diagnostics.push(Diagnostic::spanned(
79                op_span,
80                Level::Error,
81                format!("{} only supports `'static`.", op_name),
82            ));
83        }
84
85        let persistdata_ident = singleton_output_ident;
86        let vec_ident = wc.make_ident("persistvec");
87        let write_prologue = quote_spanned! {op_span=>
88            let #persistdata_ident = #df_ident.add_state(::std::cell::RefCell::new(
89                ::std::vec::Vec::new(),
90            ));
91        };
92
93        let write_iterator = if is_pull {
94            let input = &inputs[0];
95            quote_spanned! {op_span=>
96                let mut #vec_ident = unsafe {
97                    // SAFETY: handle from `#df_ident.add_state(..)`.
98                    #context.state_ref_unchecked(#persistdata_ident)
99                }.borrow_mut();
100
101                let #ident = {
102                    if #context.is_first_run_this_tick() {
103                        #work_fn(|| #vec_ident.extend(#input));
104                        #vec_ident.iter().cloned()
105                    } else {
106                        let len = #vec_ident.len();
107                        #work_fn(|| #vec_ident.extend(#input));
108                        #vec_ident[len..].iter().cloned()
109                    }
110                };
111            }
112        } else {
113            let output = &outputs[0];
114            quote_spanned! {op_span=>
115                let mut #vec_ident = unsafe {
116                    // SAFETY: handle from `#df_ident.add_state(..)`.
117                    #context.state_ref_unchecked(#persistdata_ident)
118                }.borrow_mut();
119
120                let #ident = {
121                    fn constrain_types<'ctx, Push, Item>(vec: &'ctx mut Vec<Item>, mut output: Push, is_new_tick: bool) -> impl 'ctx + #root::pusherator::Pusherator<Item = Item>
122                    where
123                        Push: 'ctx + #root::pusherator::Pusherator<Item = Item>,
124                        Item: ::std::clone::Clone,
125                    {
126                        if is_new_tick {
127                            #work_fn(|| vec.iter().cloned().for_each(|item| {
128                                #root::pusherator::Pusherator::give(&mut output, item);
129                            }));
130                        }
131                        #root::pusherator::map::Map::new(|item| {
132                            vec.push(item);
133                            vec.last().unwrap().clone()
134                        }, output)
135                    }
136                    constrain_types(&mut *#vec_ident, #output, #context.is_first_run_this_tick())
137                };
138            }
139        };
140
141        let write_iterator_after = quote_spanned! {op_span=>
142            #context.schedule_subgraph(#context.current_subgraph(), false);
143        };
144
145        Ok(OperatorWriteOutput {
146            write_prologue,
147            write_iterator,
148            write_iterator_after,
149        })
150    },
151};