dfir_lang/graph/ops/
enumerate.rs

1use quote::quote_spanned;
2
3use super::{
4    OpInstGenerics, OperatorCategory, OperatorConstraints, OperatorInstance,
5    OperatorWriteOutput, Persistence, WriteContextArgs, RANGE_0, RANGE_1,
6};
7use crate::diagnostic::{Diagnostic, Level};
8
9/// > 1 input stream of type `T`, 1 output stream of type `(usize, T)`
10///
11/// For each item passed in, enumerate it with its index: `(0, x_0)`, `(1, x_1)`, etc.
12///
13/// `enumerate` can also be provided with one generic lifetime persistence argument, either
14/// `'tick` or `'static`, to specify if indexing resets. If `'tick` (the default) is specified, indexing will
15/// restart at zero at the start of each tick. Otherwise `'static` will never reset
16/// and count monotonically upwards.
17///
18/// ```dfir
19/// source_iter(vec!["hello", "world"])
20///     -> enumerate()
21///     -> assert_eq([(0, "hello"), (1, "world")]);
22/// ```
23pub const ENUMERATE: OperatorConstraints = OperatorConstraints {
24    name: "enumerate",
25    categories: &[OperatorCategory::Map],
26    hard_range_inn: RANGE_1,
27    soft_range_inn: RANGE_1,
28    hard_range_out: RANGE_1,
29    soft_range_out: RANGE_1,
30    num_args: 0,
31    persistence_args: &(0..=1),
32    type_args: RANGE_0,
33    is_external_input: false,
34    has_singleton_output: false,
35    flo_type: None,
36    ports_inn: None,
37    ports_out: None,
38    input_delaytype_fn: |_| None,
39    write_fn: |wc @ &WriteContextArgs {
40                   root,
41                   op_span,
42                   context,
43                   df_ident,
44                   ident,
45                   inputs,
46                   outputs,
47                   is_pull,
48                   op_inst:
49                       OperatorInstance {
50                           generics:
51                               OpInstGenerics {
52                                   persistence_args, ..
53                               },
54                           ..
55                       },
56                   ..
57               },
58               diagnostics| {
59        let persistence = match persistence_args[..] {
60            [] => Persistence::Tick,
61            [Persistence::Mutable] => {
62                diagnostics.push(Diagnostic::spanned(
63                    op_span,
64                    Level::Error,
65                    "An implementation of 'mutable does not exist",
66                ));
67                Persistence::Tick
68            },
69            [a] => a,
70            _ => unreachable!(),
71        };
72
73        let input = &inputs[0];
74        let output = &outputs[0];
75
76        let counter_ident = wc.make_ident("counterdata");
77
78        let mut write_prologue = quote_spanned! {op_span=>
79            let #counter_ident = #df_ident.add_state(::std::cell::RefCell::new(0..));
80        };
81        if Persistence::Tick == persistence {
82            write_prologue.extend(quote_spanned! {op_span=>
83                #df_ident.set_state_tick_hook(#counter_ident, |rcell| { rcell.replace(0..); });
84            });
85        }
86
87        let map_fn = quote_spanned! {op_span=>
88            |item| {
89                let mut counter = unsafe {
90                    // SAFETY: handle from `#df_ident.add_state(..)`.
91                    #context.state_ref_unchecked(#counter_ident)
92                }.borrow_mut();
93                (counter.next().unwrap(), item)
94            }
95        };
96        let write_iterator = if is_pull {
97            quote_spanned! {op_span=>
98                let #ident = ::std::iter::Iterator::map(#input, #map_fn);
99            }
100        } else {
101            quote_spanned! {op_span=>
102                let #ident = #root::pusherator::map::Map::new(#map_fn, #output);
103            }
104        };
105
106        Ok(OperatorWriteOutput {
107            write_prologue,
108            write_iterator,
109            ..Default::default()
110        })
111    },
112};