use quote::quote_spanned;
use super::{
OpInstGenerics, OperatorCategory,
OperatorConstraints, OperatorInstance, OperatorWriteOutput, Persistence, WriteContextArgs,
RANGE_0, RANGE_1,
};
use crate::diagnostic::{Diagnostic, Level};
/// Stores each item as it passes through, and replays all items every tick.
///
/// ```dfir
/// // Normally `source_iter(...)` only emits once, but `persist::<'static>()` will replay the `"hello"`
/// // on every tick.
/// source_iter(["hello"])
/// -> persist::<'static>()
/// -> assert_eq(["hello"]);
/// ```
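///
/// As a rough illustration of these semantics, the following plain-Rust sketch (no dfir involved)
/// models the operator's stored state as a `Vec` that grows across ticks. `PersistSketch` and `run`
/// are hypothetical names used only for illustration. The `is_first_run_this_tick` flag is assumed
/// to mirror the pull-side behavior of the generated code: the first run of a tick emits every
/// stored item, while later runs within the same tick emit only the newly arrived items.
///
/// ```rust
/// // Hypothetical stand-in for the operator's persisted state (not a dfir type).
/// struct PersistSketch<T> {
///     stored: Vec<T>,
/// }
///
/// impl<T: Clone> PersistSketch<T> {
///     fn new() -> Self {
///         Self { stored: Vec::new() }
///     }
///
///     // Append this run's input, then emit all stored items on the first run of a
///     // tick, or only the items appended by this run otherwise.
///     fn run(&mut self, input: impl IntoIterator<Item = T>, is_first_run_this_tick: bool) -> Vec<T> {
///         let len = self.stored.len();
///         self.stored.extend(input);
///         let start = if is_first_run_this_tick { 0 } else { len };
///         self.stored[start..].to_vec()
///     }
/// }
///
/// fn main() {
///     let mut persist = PersistSketch::new();
///     // Tick 1: the source emits "hello" once; it is both stored and emitted.
///     assert_eq!(persist.run(["hello"], true), vec!["hello"]);
///     // Tick 2: no new input, yet the stored item is replayed.
///     assert_eq!(persist.run(std::iter::empty(), true), vec!["hello"]);
///     // A later run within tick 2 emits only the newly arrived item.
///     assert_eq!(persist.run(["world"], false), vec!["world"]);
/// }
/// ```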
///
/// `persist()` can be used to introduce statefulness into stateless pipelines. In the example below, the
/// join only stores data for a single tick. The `persist::<'static>()` operators introduce statefulness
/// across ticks. This can be useful for optimization transformations within the dfir
/// compiler. Equivalently, we could specify that the join has `'static` persistence (`my_join = join::<'static>()`).
/// ```rustbook
/// let (input_send, input_recv) = dfir_rs::util::unbounded_channel::<(&str, &str)>();
/// let mut flow = dfir_rs::dfir_syntax! {
/// source_iter([("hello", "world")]) -> persist::<'static>() -> [0]my_join;
/// source_stream(input_recv) -> persist::<'static>() -> [1]my_join;
/// my_join = join::<'tick>() -> for_each(|(k, (v1, v2))| println!("({}, ({}, {}))", k, v1, v2));
/// };
/// input_send.send(("hello", "oakland")).unwrap();
/// flow.run_tick();
/// input_send.send(("hello", "san francisco")).unwrap();
/// flow.run_tick();
/// // (hello, (world, oakland))
/// // (hello, (world, oakland))
/// // (hello, (world, san francisco))
/// ```
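///
/// The output above can be reproduced with a plain-Rust sketch under a simplified model: both
/// inputs are kept in growing `Vec`s (standing in for the two `persist::<'static>()` operators), and
/// each tick recomputes a stateless hash join over everything persisted so far (standing in for
/// `join::<'tick>()`). `join_tick` is a hypothetical helper for illustration, not a dfir API.
///
/// ```rust
/// use std::collections::HashMap;
///
/// // Hypothetical helper (not a dfir API): a stateless hash join over the current contents of
/// // both persisted inputs, emitting `(key, (lhs_value, rhs_value))` tuples.
/// fn join_tick<'a>(
///     lhs: &[(&'a str, &'a str)],
///     rhs: &[(&'a str, &'a str)],
/// ) -> Vec<(&'a str, (&'a str, &'a str))> {
///     let mut index: HashMap<&str, Vec<&str>> = HashMap::new();
///     for &(k, v1) in lhs {
///         index.entry(k).or_default().push(v1);
///     }
///     let mut out = Vec::new();
///     for &(k, v2) in rhs {
///         if let Some(v1s) = index.get(k) {
///             for &v1 in v1s {
///                 out.push((k, (v1, v2)));
///             }
///         }
///     }
///     out
/// }
///
/// fn main() {
///     // Port 0: `source_iter` emits once, but `persist` keeps the item for every later tick.
///     let lhs = vec![("hello", "world")];
///     // Port 1: items from the channel accumulate across ticks.
///     let mut rhs = Vec::new();
///     let mut printed = Vec::new();
///     for arrival in [("hello", "oakland"), ("hello", "san francisco")] {
///         rhs.push(arrival); // one new item arrives per tick
///         printed.extend(join_tick(&lhs, &rhs)); // the join sees all persisted data
///     }
///     assert_eq!(
///         printed,
///         vec![
///             ("hello", ("world", "oakland")),
///             ("hello", ("world", "oakland")),
///             ("hello", ("world", "san francisco")),
///         ]
///     );
/// }
/// ```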
pub const PERSIST: OperatorConstraints = OperatorConstraints {
name: "persist",
categories: &[OperatorCategory::Persistence],
hard_range_inn: RANGE_1,
soft_range_inn: RANGE_1,
hard_range_out: RANGE_1,
soft_range_out: RANGE_1,
num_args: 0,
persistence_args: RANGE_1,
type_args: RANGE_0,
is_external_input: false,
has_singleton_output: true,
flo_type: None,
ports_inn: None,
ports_out: None,
input_delaytype_fn: |_| None,
write_fn: |wc @ &WriteContextArgs {
root,
context,
hydroflow,
op_span,
ident,
is_pull,
inputs,
outputs,
singleton_output_ident,
op_name,
op_inst:
OperatorInstance {
generics:
OpInstGenerics {
persistence_args, ..
},
..
},
..
},
diagnostics| {
if [Persistence::Static] != persistence_args[..] {
diagnostics.push(Diagnostic::spanned(
op_span,
Level::Error,
format!("{} only supports `'static`.", op_name),
));
}
// Reuse the singleton output ident as the handle for the persisted `Vec` state.
let persistdata_ident = singleton_output_ident;
let vec_ident = wc.make_ident("persistvec");
let write_prologue = quote_spanned! {op_span=>
let #persistdata_ident = #hydroflow.add_state(::std::cell::RefCell::new(
::std::vec::Vec::new(),
));
};
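// Pull: on the first run of a tick, emit every stored item (including this run's input);
// on later runs within the same tick, emit only the newly appended items.
let write_iterator = if is_pull {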
let input = &inputs[0];
quote_spanned! {op_span=>
let mut #vec_ident = #context.state_ref(#persistdata_ident).borrow_mut();
let #ident = {
if #context.is_first_run_this_tick() {
#vec_ident.extend(#input);
#vec_ident.iter().cloned()
} else {
let len = #vec_ident.len();
#vec_ident.extend(#input);
#vec_ident[len..].iter().cloned()
}
};
}
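} else {
// Push: at the start of a tick, replay every stored item downstream, then store
// each incoming item and forward a clone of it.
let output = &outputs[0];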
quote_spanned! {op_span=>
let mut #vec_ident = #context.state_ref(#persistdata_ident).borrow_mut();
let #ident = {
fn constrain_types<'ctx, Push, Item>(vec: &'ctx mut Vec<Item>, mut output: Push, is_new_tick: bool) -> impl 'ctx + #root::pusherator::Pusherator<Item = Item>
where
Push: 'ctx + #root::pusherator::Pusherator<Item = Item>,
Item: ::std::clone::Clone,
{
if is_new_tick {
vec.iter().cloned().for_each(|item| {
#root::pusherator::Pusherator::give(&mut output, item);
});
}
#root::pusherator::map::Map::new(|item| {
vec.push(item);
vec.last().unwrap().clone()
}, output)
}
constrain_types(&mut *#vec_ident, #output, #context.is_first_run_this_tick())
};
}
};
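// Schedule this subgraph again for the next tick so the stored items are replayed
// even when no new input arrives.
let write_iterator_after = quote_spanned! {op_span=>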
#context.schedule_subgraph(#context.current_subgraph(), false);
};
Ok(OperatorWriteOutput {
write_prologue,
write_iterator,
write_iterator_after,
})
},
};