dfir_lang/graph/ops/
demux.rs

1use std::collections::HashMap;
2
3use proc_macro2::{Ident, TokenTree};
4use quote::{quote_spanned, ToTokens};
5use syn::spanned::Spanned;
6use syn::{Expr, Pat};
7
8use super::{
9    OperatorCategory, OperatorConstraints, OperatorInstance, OperatorWriteOutput,
10    PortIndexValue, PortListSpec, WriteContextArgs, RANGE_0, RANGE_1,
11};
12use crate::diagnostic::{Diagnostic, Level};
13use crate::pretty_span::PrettySpan;
14
15// TODO(mingwei): Preprocess rustdoc links in mdbook or in the `operator_docgen` macro.
16/// > Arguments: A Rust closure, the first argument is a received item and the
17/// > second argument is a variadic [`var_args!` tuple list](https://hydro.run/rustdoc/dfir_rs/macro.var_args.html)
18/// > where each item name is an output port.
19///
20/// Takes the input stream and allows the user to determine which items to
21/// deliver to any number of output streams.
22///
23/// > Note: Downstream operators may need explicit type annotations.
24///
25/// > Note: The [`Pusherator`](https://hydro.run/rustdoc/pusherator/trait.Pusherator.html)
26/// > trait is automatically imported to enable the [`.give(...)` method](https://hydro.run/rustdoc/pusherator/trait.Pusherator.html#tymethod.give).
27///
28/// > Note: The closure has access to the [`context` object](surface_flows.mdx#the-context-object).
29///
30/// ```dfir
31/// my_demux = source_iter(1..=100) -> demux(|v, var_args!(fzbz, fizz, buzz, rest)|
32///     match (v % 3, v % 5) {
33///         (0, 0) => fzbz.give(v),
34///         (0, _) => fizz.give(v),
35///         (_, 0) => buzz.give(v),
36///         (_, _) => rest.give(v),
37///     }
38/// );
39/// my_demux[fzbz] -> for_each(|v| println!("{}: fizzbuzz", v));
40/// my_demux[fizz] -> for_each(|v| println!("{}: fizz", v));
41/// my_demux[buzz] -> for_each(|v| println!("{}: buzz", v));
42/// my_demux[rest] -> for_each(|v| println!("{}", v));
43/// ```
44pub const DEMUX: OperatorConstraints = OperatorConstraints {
45    name: "demux",
46    categories: &[OperatorCategory::MultiOut],
47    hard_range_inn: RANGE_1,
48    soft_range_inn: RANGE_1,
49    hard_range_out: &(2..),
50    soft_range_out: &(2..),
51    num_args: 1,
52    persistence_args: RANGE_0,
53    type_args: RANGE_0,
54    is_external_input: false,
55    has_singleton_output: false,
56    flo_type: None,
57    ports_inn: None,
58    ports_out: Some(|| PortListSpec::Variadic),
59    input_delaytype_fn: |_| None,
60    write_fn: |&WriteContextArgs {
61                   root,
62                   op_span,
63                   ident,
64                   outputs,
65                   is_pull,
66                   op_name,
67                   op_inst: OperatorInstance { output_ports, .. },
68                   arguments,
69                   ..
70               },
71               diagnostics| {
72        assert!(!is_pull);
73        let func = &arguments[0];
74        let Expr::Closure(func) = func else {
75            diagnostics.push(Diagnostic::spanned(
76                func.span(),
77                Level::Error,
78                "Argument must be a two-argument closure expression",
79            ));
80            return Err(());
81        };
82        if 2 != func.inputs.len() {
83            diagnostics.push(Diagnostic::spanned(
84                func.inputs.span(),
85                Level::Error,
86                &*format!(
87                    "Closure provided to `{}(..)` must have two arguments: \
88                    the first argument is the item, and the second argument lists ports. \
89                    E.g. the second argument could be `var_args!(port_a, port_b, ..)`.",
90                    op_name
91                ),
92            ));
93            return Err(());
94        }
95
96        // Port idents specified in the closure's second argument.
97        let arg2 = &func.inputs[1];
98        let closure_idents = extract_closure_idents(arg2);
99
100        // Port idents supplied via port connections in the surface syntax.
101        let port_idents: Vec<_> = output_ports
102            .iter()
103            .filter_map(|output_port| {
104                let PortIndexValue::Path(port_expr) = output_port else {
105                    diagnostics.push(Diagnostic::spanned(
106                        output_port.span(),
107                        Level::Error,
108                        format!(
109                            "Output port from `{}(..)` must be specified and must be a valid identifier.",
110                            op_name,
111                        ),
112                    ));
113                    return None;
114                };
115                let port_ident = syn::parse2::<Ident>(quote_spanned! {op_span=> #port_expr })
116                    .map_err(|err| diagnostics.push(err.into()))
117                    .ok()?;
118
119                if !closure_idents.contains_key(&port_ident) {
120                    // TODO(mingwei): Use MultiSpan when `proc_macro2` supports it.
121                    diagnostics.push(Diagnostic::spanned(
122                        arg2.span(),
123                        Level::Error,
124                        format!(
125                            "Argument specifying the output ports in `{0}(..)` does not contain extra port `{1}`: ({2}) (1/2).",
126                            op_name, port_ident, PrettySpan(output_port.span()),
127                        ),
128                    ));
129                    diagnostics.push(Diagnostic::spanned(
130                        output_port.span(),
131                        Level::Error,
132                        format!(
133                            "Port `{1}` not found in the arguments specified in `{0}(..)`'s closure: ({2}) (2/2).",
134                            op_name, port_ident, PrettySpan(arg2.span()),
135                        ),
136                    ));
137                    return None;
138                }
139
140                Some(port_ident)
141            })
142            .collect();
143
144        for closure_ident in closure_idents.keys() {
145            if !port_idents.contains(closure_ident) {
146                diagnostics.push(Diagnostic::spanned(
147                    closure_ident.span(),
148                    Level::Error,
149                    format!(
150                        "`{}(..)` closure argument `{}` missing corresponding output port.",
151                        op_name, closure_ident,
152                    ),
153                ));
154            }
155        }
156
157        if diagnostics.iter().any(Diagnostic::is_error) {
158            return Err(());
159        }
160
161        assert_eq!(outputs.len(), port_idents.len());
162        assert_eq!(outputs.len(), closure_idents.len());
163
164        let mut sort_permute: Vec<_> = (0..outputs.len()).collect();
165        sort_permute.sort_by_key(|&i| closure_idents[&port_idents[i]]);
166
167        let sorted_outputs = sort_permute.iter().map(|&i| &outputs[i]);
168
169        let write_iterator = quote_spanned! {op_span=>
170            let #ident = {
171                #[allow(unused_imports)] use #root::pusherator::Pusherator;
172                #root::pusherator::demux::Demux::new(#func, #root::var_expr!( #( #sorted_outputs ),* ))
173            };
174        };
175
176        Ok(OperatorWriteOutput {
177            write_iterator,
178            ..Default::default()
179        })
180    },
181};
182
183fn extract_closure_idents(arg2: &Pat) -> HashMap<Ident, usize> {
184    let tokens = if let Pat::Macro(pat_macro) = arg2 {
185        pat_macro.mac.tokens.clone()
186    } else {
187        arg2.to_token_stream()
188    };
189
190    let mut idents = HashMap::new();
191    let mut stack: Vec<_> = tokens.into_iter().collect();
192    stack.reverse();
193    while let Some(tt) = stack.pop() {
194        match tt {
195            TokenTree::Group(group) => {
196                let a = stack.len();
197                stack.extend(group.stream());
198                let b = stack.len();
199                stack[a..b].reverse();
200            }
201            TokenTree::Ident(ident) => {
202                idents.insert(ident, idents.len());
203            }
204            TokenTree::Punct(_) => (),
205            TokenTree::Literal(_) => (),
206        }
207    }
208    idents
209}