dfir_lang/graph/ops/
unique.rs

1use quote::quote_spanned;
2
3use super::{
4    OpInstGenerics, OperatorCategory, OperatorConstraints, OperatorInstance, OperatorWriteOutput,
5    Persistence, RANGE_0, RANGE_1, WriteContextArgs,
6};
7use crate::diagnostic::{Diagnostic, Level};
8
9/// Takes one stream as input and filters out any duplicate occurrences. The output
10/// contains all unique values from the input.
11///
12/// ```dfir
13/// source_iter(vec![1, 1, 2, 3, 2, 1, 3])
14///     -> unique()
15///     -> assert_eq([1, 2, 3]);
16/// ```
17///
18/// `unique` can also be provided with one generic lifetime persistence argument, either
19/// `'tick` or `'static`, to specify how data persists. The default is `'tick`.
20/// With `'tick`, uniqueness is only considered within the current tick, so across multiple ticks
21/// duplicate values may be emitted.
22/// With `'static`, values will be remembered across ticks and no duplicates will ever be emitted.
23///
24/// ```rustbook
25/// let (input_send, input_recv) = dfir_rs::util::unbounded_channel::<usize>();
26/// let mut flow = dfir_rs::dfir_syntax! {
27///     source_stream(input_recv)
28///         -> unique::<'tick>()
29///         -> for_each(|n| println!("{}", n));
30/// };
31///
32/// input_send.send(3).unwrap();
33/// input_send.send(3).unwrap();
34/// input_send.send(4).unwrap();
35/// input_send.send(3).unwrap();
36/// flow.run_available();
37/// // 3, 4
38///
39/// input_send.send(3).unwrap();
40/// input_send.send(5).unwrap();
41/// flow.run_available();
42/// // 3, 5
43/// // Note: 3 is emitted again.
44/// ```
45pub const UNIQUE: OperatorConstraints = OperatorConstraints {
46    name: "unique",
47    categories: &[OperatorCategory::Persistence],
48    hard_range_inn: RANGE_1,
49    soft_range_inn: RANGE_1,
50    hard_range_out: RANGE_1,
51    soft_range_out: RANGE_1,
52    num_args: 0,
53    persistence_args: &(0..=1),
54    type_args: RANGE_0,
55    is_external_input: false,
56    has_singleton_output: false,
57    flo_type: None,
58    ports_inn: None,
59    ports_out: None,
60    input_delaytype_fn: |_| None,
61    write_fn: |wc @ &WriteContextArgs {
62                   root,
63                   op_span,
64                   context,
65                   df_ident,
66                   loop_id,
67                   ident,
68                   inputs,
69                   outputs,
70                   is_pull,
71                   op_inst:
72                       OperatorInstance {
73                           generics:
74                               OpInstGenerics {
75                                   persistence_args, ..
76                               },
77                           ..
78                       },
79                   ..
80               },
81               diagnostics| {
82        let persistence = persistence_args.first().copied().unwrap_or_else(|| {
83            if loop_id.is_some() {
84                Persistence::None
85            } else {
86                Persistence::Tick
87            }
88        });
89
90        let input = &inputs[0];
91        let output = &outputs[0];
92
93        let uniquedata_ident = wc.make_ident("uniquedata");
94
95        let (write_prologue, get_set) = match persistence {
96            Persistence::None => (
97                Default::default(),
98                quote_spanned! {op_span=>
99                    let mut set = #root::rustc_hash::FxHashSet::default();
100                },
101            ),
102            Persistence::Tick => {
103                let write_prologue = quote_spanned! {op_span=>
104                    let #uniquedata_ident = #df_ident.add_state(::std::cell::RefCell::new(
105                        #root::util::monotonic_map::MonotonicMap::<_, #root::rustc_hash::FxHashSet<_>>::default(),
106                    ));
107                };
108                let get_set = quote_spanned! {op_span=>
109                    let mut borrow = unsafe {
110                        // SAFETY: handle from `#df_ident.add_state(..)`.
111                        #context.state_ref_unchecked(#uniquedata_ident)
112                    }.borrow_mut();
113                    let set = borrow.get_mut_clear((#context.current_tick(), #context.current_stratum()));
114                };
115                (write_prologue, get_set)
116            }
117            Persistence::Static => {
118                let write_prologue = quote_spanned! {op_span=>
119                    let #uniquedata_ident = #df_ident.add_state(::std::cell::RefCell::new(#root::rustc_hash::FxHashSet::default()));
120                };
121                let get_set = quote_spanned! {op_span=>
122                    let mut set = unsafe {
123                        // SAFETY: handle from `#df_ident.add_state(..)`.
124                        #context.state_ref_unchecked(#uniquedata_ident)
125                    }.borrow_mut();
126                };
127                (write_prologue, get_set)
128            }
129            Persistence::Mutable => {
130                diagnostics.push(Diagnostic::spanned(
131                    op_span,
132                    Level::Error,
133                    "An implementation of 'mutable does not exist",
134                ));
135                return Err(());
136            }
137        };
138
139        let filter_fn = quote_spanned! {op_span=>
140            |item| {
141                #get_set
142                if !set.contains(item) {
143                    set.insert(::std::clone::Clone::clone(item));
144                    true
145                } else {
146                    false
147                }
148            }
149        };
150        let write_iterator = if is_pull {
151            quote_spanned! {op_span=>
152                let #ident = #input.filter(#filter_fn);
153            }
154        } else {
155            quote_spanned! {op_span=>
156                let #ident = #root::pusherator::filter::Filter::new(#filter_fn, #output);
157            }
158        };
159
160        Ok(OperatorWriteOutput {
161            write_prologue,
162            write_iterator,
163            ..Default::default()
164        })
165    },
166};