1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
use std::collections::HashMap;

use proc_macro2::{Ident, TokenTree};
use quote::{quote_spanned, ToTokens};
use syn::spanned::Spanned;
use syn::{Expr, Pat};

use super::{
    OperatorCategory, OperatorConstraints, OperatorInstance, OperatorWriteOutput,
    PortIndexValue, PortListSpec, WriteContextArgs, RANGE_0, RANGE_1,
};
use crate::diagnostic::{Diagnostic, Level};
use crate::pretty_span::PrettySpan;

// TODO(mingwei): Preprocess rustdoc links in mdbook or in the `operator_docgen` macro.
/// > Arguments: A Rust closure, the first argument is a received item and the
/// > second argument is a variadic [`var_args!` tuple list](https://hydro.run/rustdoc/dfir_rs/macro.var_args.html)
/// > where each item name is an output port.
///
/// Takes the input stream and allows the user to determine which items to
/// deliver to any number of output streams.
///
/// > Note: Downstream operators may need explicit type annotations.
///
/// > Note: The [`Pusherator`](https://hydro.run/rustdoc/pusherator/trait.Pusherator.html)
/// > trait is automatically imported to enable the [`.give(...)` method](https://hydro.run/rustdoc/pusherator/trait.Pusherator.html#tymethod.give).
///
/// > Note: The closure has access to the [`context` object](surface_flows.mdx#the-context-object).
///
/// ```dfir
/// my_demux = source_iter(1..=100) -> demux(|v, var_args!(fzbz, fizz, buzz, rest)|
///     match (v % 3, v % 5) {
///         (0, 0) => fzbz.give(v),
///         (0, _) => fizz.give(v),
///         (_, 0) => buzz.give(v),
///         (_, _) => rest.give(v),
///     }
/// );
/// my_demux[fzbz] -> for_each(|v| println!("{}: fizzbuzz", v));
/// my_demux[fizz] -> for_each(|v| println!("{}: fizz", v));
/// my_demux[buzz] -> for_each(|v| println!("{}: buzz", v));
/// my_demux[rest] -> for_each(|v| println!("{}", v));
/// ```
pub const DEMUX: OperatorConstraints = OperatorConstraints {
    name: "demux",
    categories: &[OperatorCategory::MultiOut],
    hard_range_inn: RANGE_1,
    soft_range_inn: RANGE_1,
    // Output count ranges start at two and have no upper bound.
    hard_range_out: &(2..),
    soft_range_out: &(2..),
    // The single argument is the demux closure, validated in `write_fn` below.
    num_args: 1,
    persistence_args: RANGE_0,
    type_args: RANGE_0,
    is_external_input: false,
    has_singleton_output: false,
    flo_type: None,
    ports_inn: None,
    ports_out: Some(|| PortListSpec::Variadic),
    input_delaytype_fn: |_| None,
    // Code generation for `demux`:
    //   1. Check that the argument is a closure taking exactly two arguments.
    //   2. Collect the identifiers named in the closure's second argument.
    //   3. Check that each connected output port is an identifier that appears
    //      in the closure's second argument, and that each identifier in that
    //      argument has a connected output port.
    //   4. Emit a `Demux::new(..)` call whose outputs are ordered to match the
    //      closure's second argument.
    // Every validation failure pushes to `diagnostics`; `Err(())` is returned
    // if the argument is malformed or if `diagnostics` contains any error.
    write_fn: |&WriteContextArgs {
                   root,
                   op_span,
                   ident,
                   outputs,
                   is_pull,
                   op_name,
                   op_inst: OperatorInstance { output_ports, .. },
                   arguments,
                   ..
               },
               diagnostics| {
        // Only a pusherator-based implementation is emitted below, so being
        // invoked on the pull side is treated as an internal error (panic).
        assert!(!is_pull);
        let func = &arguments[0];
        let Expr::Closure(func) = func else {
            diagnostics.push(Diagnostic::spanned(
                func.span(),
                Level::Error,
                "Argument must be a two-argument closure expression",
            ));
            return Err(());
        };
        if 2 != func.inputs.len() {
            diagnostics.push(Diagnostic::spanned(
                func.inputs.span(),
                Level::Error,
                // Pass the `String` directly, like the other diagnostics here.
                format!(
                    "Closure provided to `{}(..)` must have two arguments: \
                    the first argument is the item, and the second argument lists ports. \
                    E.g. the second argument could be `var_args!(port_a, port_b, ..)`.",
                    op_name
                ),
            ));
            return Err(());
        }

        // Port idents specified in the closure's second argument, each mapped
        // to the ordering value assigned by `extract_closure_idents`.
        let arg2 = &func.inputs[1];
        let closure_idents = extract_closure_idents(arg2);

        // Port idents supplied via port connections in the surface syntax.
        // Ports that fail validation push diagnostics and are dropped here.
        let port_idents: Vec<_> = output_ports
            .iter()
            .filter_map(|output_port| {
                let PortIndexValue::Path(port_expr) = output_port else {
                    diagnostics.push(Diagnostic::spanned(
                        output_port.span(),
                        Level::Error,
                        format!(
                            "Output port from `{}(..)` must be specified and must be a valid identifier.",
                            op_name,
                        ),
                    ));
                    return None;
                };
                // Re-parse the port path as a single `Ident`; a parse failure
                // is recorded as a diagnostic and the port is skipped.
                let port_ident = syn::parse2::<Ident>(quote_spanned! {op_span=> #port_expr })
                    .map_err(|err| diagnostics.push(err.into()))
                    .ok()?;

                if !closure_idents.contains_key(&port_ident) {
                    // The mismatch involves two source locations, so it is
                    // reported as a pair of diagnostics, one per location, each
                    // mentioning the other's span in its message.
                    // TODO(mingwei): Use MultiSpan when `proc_macro2` supports it.
                    diagnostics.push(Diagnostic::spanned(
                        arg2.span(),
                        Level::Error,
                        format!(
                            "Argument specifying the output ports in `{0}(..)` does not contain extra port `{1}`: ({2}) (1/2).",
                            op_name, port_ident, PrettySpan(output_port.span()),
                        ),
                    ));
                    diagnostics.push(Diagnostic::spanned(
                        output_port.span(),
                        Level::Error,
                        format!(
                            "Port `{1}` not found in the arguments specified in `{0}(..)`'s closure: ({2}) (2/2).",
                            op_name, port_ident, PrettySpan(arg2.span()),
                        ),
                    ));
                    return None;
                }

                Some(port_ident)
            })
            .collect();

        // Report closure arguments that have no matching output port. They are
        // visited in the order recorded for the closure's second argument
        // rather than in `HashMap` iteration order, so that the diagnostics are
        // emitted in a deterministic order.
        let mut unmatched_idents: Vec<(&Ident, usize)> = closure_idents
            .iter()
            .filter(|(closure_ident, _)| !port_idents.contains(*closure_ident))
            .map(|(closure_ident, &position)| (closure_ident, position))
            .collect();
        unmatched_idents.sort_by_key(|&(_, position)| position);
        for (closure_ident, _) in unmatched_idents {
            diagnostics.push(Diagnostic::spanned(
                closure_ident.span(),
                Level::Error,
                format!(
                    "`{}(..)` closure argument `{}` missing corresponding output port.",
                    op_name, closure_ident,
                ),
            ));
        }

        // NOTE: this inspects the whole `diagnostics` collection, not only the
        // entries pushed above.
        if diagnostics.iter().any(Diagnostic::is_error) {
            return Err(());
        }

        // With no errors, the validated ports and the closure's port idents are
        // expected to line up with `outputs`.
        assert_eq!(outputs.len(), port_idents.len());
        assert_eq!(outputs.len(), closure_idents.len());

        // Permutation of output indices ordered by where each output's port
        // ident appears in the closure's second argument, so the outputs are
        // handed to `var_expr!` in the same order as the closure names them.
        let mut sort_permute: Vec<_> = (0..outputs.len()).collect();
        sort_permute.sort_by_key(|&i| closure_idents[&port_idents[i]]);

        let sorted_outputs = sort_permute.iter().map(|&i| &outputs[i]);

        let write_iterator = quote_spanned! {op_span=>
            let #ident = {
                #[allow(unused_imports)] use #root::pusherator::Pusherator;
                #root::pusherator::demux::Demux::new(#func, #root::var_expr!( #( #sorted_outputs ),* ))
            };
        };

        Ok(OperatorWriteOutput {
            write_iterator,
            ..Default::default()
        })
    },
};

fn extract_closure_idents(arg2: &Pat) -> HashMap<Ident, usize> {
    let tokens = if let Pat::Macro(pat_macro) = arg2 {
        pat_macro.mac.tokens.clone()
    } else {
        arg2.to_token_stream()
    };

    let mut idents = HashMap::new();
    let mut stack: Vec<_> = tokens.into_iter().collect();
    stack.reverse();
    while let Some(tt) = stack.pop() {
        match tt {
            TokenTree::Group(group) => {
                let a = stack.len();
                stack.extend(group.stream());
                let b = stack.len();
                stack[a..b].reverse();
            }
            TokenTree::Ident(ident) => {
                idents.insert(ident, idents.len());
            }
            TokenTree::Punct(_) => (),
            TokenTree::Literal(_) => (),
        }
    }
    idents
}