dfir_lang/graph/ops/
partition.rs

1use std::collections::BTreeSet;
2
3use proc_macro2::Span;
4use quote::quote_spanned;
5use syn::spanned::Spanned;
6use syn::token::Colon;
7use syn::{parse_quote_spanned, Expr, Ident, LitInt, LitStr, Pat, PatType};
8
9use super::{
10    OperatorCategory, OperatorConstraints, OperatorInstance, OperatorWriteOutput, PortIndexValue,
11    PortListSpec, WriteContextArgs, RANGE_0, RANGE_1,
12};
13use crate::diagnostic::{Diagnostic, Level};
14use crate::pretty_span::PrettySpan;
15
16/// This operator takes the input pipeline and allows the user to determine which singular output
17/// pipeline each item should be delivered to.
18///
19/// > Arguments: A Rust closure, the first argument is a reference to the item and the second
20/// > argument corresponds to one of two modes, either named or indexed.
21///
22/// > Note: The closure has access to the [`context` object](surface_flows.mdx#the-context-object).
23///
24/// # Named mode
25/// With named ports, the closure's second argument must be a Rust 'slice pattern' of names, such as
26/// `[port_a, port_b, port_c]`, where each name is an output port. The closure should return the
27/// name of the desired output port.
28///
29/// ```dfir
30/// my_partition = source_iter(1..=100) -> partition(|val: &usize, [fzbz, fizz, buzz, rest]|
31///     match (val % 3, val % 5) {
32///         (0, 0) => fzbz,
33///         (0, _) => fizz,
34///         (_, 0) => buzz,
35///         (_, _) => rest,
36///     }
37/// );
38/// my_partition[fzbz] -> for_each(|v| println!("{}: fizzbuzz", v));
39/// my_partition[fizz] -> for_each(|v| println!("{}: fizz", v));
40/// my_partition[buzz] -> for_each(|v| println!("{}: buzz", v));
41/// my_partition[rest] -> for_each(|v| println!("{}", v));
42/// ```
43///
44/// # Indexed mode
45/// With indexed mode, the closure's second argument is a the number of output ports. This is a
46/// single usize value, useful for e.g. round robin partitioning. Each output pipeline port must be
47/// numbered with an index, starting from zero and with no gaps. The closure returns the index of
48/// the desired output port.
49///
50/// ```dfir
51/// my_partition = source_iter(1..=100) -> partition(|val, num_outputs| val % num_outputs);
52/// my_partition[0] -> for_each(|v| println!("0: {}", v));
53/// my_partition[1] -> for_each(|v| println!("1: {}", v));
54/// my_partition[2] -> for_each(|v| println!("2: {}", v));
55/// ```
56pub const PARTITION: OperatorConstraints = OperatorConstraints {
57    name: "partition",
58    categories: &[OperatorCategory::MultiOut],
59    hard_range_inn: RANGE_1,
60    soft_range_inn: RANGE_1,
61    hard_range_out: &(2..),
62    soft_range_out: &(2..),
63    num_args: 1,
64    persistence_args: RANGE_0,
65    type_args: RANGE_0,
66    is_external_input: false,
67    has_singleton_output: false,
68    flo_type: None,
69    ports_inn: None,
70    ports_out: Some(|| PortListSpec::Variadic),
71    input_delaytype_fn: |_| None,
72    write_fn: |wc @ &WriteContextArgs {
73                   root,
74                   op_span,
75                   ident,
76                   outputs,
77                   is_pull,
78                   op_name,
79                   op_inst: OperatorInstance { output_ports, .. },
80                   arguments,
81                   ..
82               },
83               diagnostics| {
84        assert!(!is_pull);
85
86        // Clone because we may modify the closure's arg2 to inject the type.
87        let mut func = arguments[0].clone();
88
89        let idx_ints = (0..output_ports.len())
90            .map(|i| LitInt::new(&format!("{}_usize", i), op_span))
91            .collect::<Vec<_>>();
92
93        let mut output_sort_permutation: Vec<_> = (0..outputs.len()).collect();
94        let (output_idents, arg2_val) = if let Some(port_idents) =
95            determine_indices_or_idents(output_ports, op_span, op_name, diagnostics)?
96        {
97            // All idents.
98            let (closure_idents, arg2_span) =
99                extract_closure_idents(&mut func, op_name).map_err(|err| diagnostics.push(err))?;
100            check_closure_ports_match(
101                &closure_idents,
102                &port_idents,
103                op_name,
104                arg2_span,
105                diagnostics,
106            )?;
107            output_sort_permutation.sort_by_key(|&i| {
108                closure_idents
109                    .iter()
110                    .position(|ident| ident == &port_idents[i])
111                    .expect(
112                        "Missing port, this should've been caught in the check above, this is a bug.",
113                    )
114            });
115            let arg2_val = quote_spanned! {arg2_span.span()=> [ #( #idx_ints ),* ] };
116
117            (closure_idents, arg2_val)
118        } else {
119            // All indices.
120            let numeric_idents = (0..output_ports.len())
121                .map(|i| wc.make_ident(format!("{}_push", i)))
122                .collect();
123            let len_lit = LitInt::new(&format!("{}_usize", output_ports.len()), op_span);
124            let arg2_val = quote_spanned! {op_span=> #len_lit };
125            (numeric_idents, arg2_val)
126        };
127
128        let err_str = LitStr::new(
129            &format!(
130                "Index `{{}}` returned by `{}(..)` closure is out-of-bounds.",
131                op_name
132            ),
133            op_span,
134        );
135        let ident_item = wc.make_ident("item");
136        let ident_index = wc.make_ident("index");
137        let ident_unknown = wc.make_ident("match_unknown");
138
139        let sorted_outputs = output_sort_permutation.into_iter().map(|i| &outputs[i]);
140
141        let write_iterator = quote_spanned! {op_span=>
142            let #ident = {
143                #root::pusherator::demux::Demux::new(
144                    |#ident_item, #root::var_args!( #( #output_idents ),* )| {
145                        #[allow(unused_imports)]
146                        use #root::pusherator::Pusherator;
147
148                        let #ident_index = {
149                            #[allow(clippy::redundant_closure_call)]
150                            (#func)(&#ident_item, #arg2_val)
151                        };
152                        match #ident_index {
153                            #(
154                                #idx_ints => #output_idents.give(#ident_item),
155                            )*
156                            #ident_unknown => panic!(#err_str, #ident_unknown),
157                        };
158                    },
159                    #root::var_expr!( #( #sorted_outputs ),* ),
160                )
161            };
162        };
163
164        Ok(OperatorWriteOutput {
165            write_iterator,
166            ..Default::default()
167        })
168    },
169};
170
171/// Returns `Ok(Some(idents))` if ports are idents, or `Ok(None)` if ports are indices.
172/// Returns `Err(())` if there are any errors (pushed to `diagnostics`).
173fn determine_indices_or_idents(
174    output_ports: &[PortIndexValue],
175    op_span: Span,
176    op_name: &'static str,
177    diagnostics: &mut Vec<Diagnostic>,
178) -> Result<Option<Vec<Ident>>, ()> {
179    // Port idents supplied via port connections in the surface syntax.
180    // Two modes, either all numeric `0, 1, 2, 3, ...` or all `Ident`s.
181    // If ports are `Idents` then the closure's 2nd argument, input array must have named
182    // values corresponding to the port idents.
183    let mut ports_numeric = BTreeSet::new();
184    let mut ports_idents = Vec::new();
185    // If any ports are elided we return `Err(())` early.
186    let mut err_elided = false;
187    for output_port in output_ports {
188        match output_port {
189            PortIndexValue::Elided(port_span) => {
190                err_elided = true;
191                diagnostics.push(Diagnostic::spanned(
192                    port_span.unwrap_or(op_span),
193                    Level::Error,
194                    format!(
195                        "Output ports from `{}` cannot be blank, must be named or indexed.",
196                        op_name
197                    ),
198                ));
199            }
200            PortIndexValue::Int(port_idx) => {
201                ports_numeric.insert(port_idx);
202
203                if port_idx.value < 0 {
204                    diagnostics.push(Diagnostic::spanned(
205                        port_idx.span,
206                        Level::Error,
207                        format!("Output ports from `{}` must be non-nonegative indices starting from zero.", op_name),
208                    ));
209                }
210            }
211            PortIndexValue::Path(port_path) => {
212                let port_ident = syn::parse2::<Ident>(quote_spanned!(op_span=> #port_path))
213                    .map_err(|err| diagnostics.push(err.into()))?;
214                ports_idents.push(port_ident);
215            }
216        }
217    }
218    if err_elided {
219        return Err(());
220    }
221
222    match (!ports_numeric.is_empty(), !ports_idents.is_empty()) {
223        (false, false) => {
224            // Had no ports or only elided ports.
225            assert!(diagnostics.iter().any(Diagnostic::is_error), "Empty input ports, expected an error diagnostic but none were emitted, this is a bug.");
226            Err(())
227        }
228        (true, true) => {
229            // Conflict.
230            let msg = &*format!(
231                "Output ports from `{}` must either be all integer indices or all identifiers.",
232                op_name
233            );
234            diagnostics.extend(
235                output_ports
236                    .iter()
237                    .map(|output_port| Diagnostic::spanned(output_port.span(), Level::Error, msg)),
238            );
239            Err(())
240        }
241        (true, false) => {
242            let max_port_idx = ports_numeric.last().unwrap().value;
243            if usize::try_from(max_port_idx).unwrap() >= ports_numeric.len() {
244                let mut expected = 0;
245                for port_numeric in ports_numeric {
246                    if expected != port_numeric.value {
247                        diagnostics.push(Diagnostic::spanned(
248                            port_numeric.span,
249                            Level::Error,
250                            format!(
251                                "Output port indices from `{}` must be consecutive from zero, missing {}.",
252                                op_name, expected
253                            ),
254                        ));
255                    }
256                    expected = port_numeric.value + 1;
257                }
258                // Can continue with code gen, port numbers will be treated as if they're
259                // consecutive from their ascending order.
260            }
261            Ok(None)
262        }
263        (false, true) => Ok(Some(ports_idents)),
264    }
265}
266
267// Returns a vec of closure idents and the arg2 span.
268fn extract_closure_idents(
269    func: &mut Expr,
270    op_name: &'static str,
271) -> Result<(Vec<Ident>, Span), Diagnostic> {
272    let Expr::Closure(func) = func else {
273        return Err(Diagnostic::spanned(
274            func.span(),
275            Level::Error,
276            "Argument must be a two-argument closure expression",
277        ));
278    };
279    if 2 != func.inputs.len() {
280        return Err(Diagnostic::spanned(
281            func.inputs.span(),
282            Level::Error,
283            &*format!(
284                "Closure provided to `{}(..)` must have two arguments: \
285                the first argument is the item, and for named ports the second argument must contain a Rust 'slice pattern' to determine the port names and order. \
286                For example, the second argument could be `[foo, bar, baz]` for ports `foo`, `bar`, and `baz`.",
287                op_name
288            ),
289        ));
290    }
291
292    // Port idents specified in the closure's second argument.
293    let mut arg2 = &mut func.inputs[1];
294    let mut already_has_type = false;
295    if let Pat::Type(pat_type) = arg2 {
296        arg2 = &mut *pat_type.pat;
297        already_has_type = true;
298    }
299
300    let arg2_span = arg2.span();
301    if let Pat::Ident(pat_ident) = arg2 {
302        arg2 = &mut *pat_ident
303            .subpat
304            .as_mut()
305            .ok_or_else(|| Diagnostic::spanned(
306                arg2_span,
307                Level::Error,
308                format!(
309                    "Second argument for the `{}` closure must contain a Rust 'slice pattern' to determine the port names and order. \
310                    For example: `arr @ [foo, bar, baz]` for ports `foo`, `bar`, and `baz`.",
311                    op_name
312                )
313            ))?
314            .1;
315    }
316    let Pat::Slice(pat_slice) = arg2 else {
317        return Err(Diagnostic::spanned(
318            arg2_span,
319            Level::Error,
320            format!(
321                "Second argument for the `{}` closure must have a Rust 'slice pattern' to determine the port names and order. \
322                For example: `[foo, bar, baz]` for ports `foo`, `bar`, and `baz`.",
323                op_name
324            )
325        ));
326    };
327
328    let idents = pat_slice
329        .elems
330        .iter()
331        .map(|pat| {
332            let Pat::Ident(pat_ident) = pat else {
333                panic!("TODO(mingwei) expected ident pat");
334            };
335            pat_ident.ident.clone()
336        })
337        .collect();
338
339    // Last step: set the type `[a, b, c]: [usize; 3]` if it is not already specified.
340    if !already_has_type {
341        let len = LitInt::new(&pat_slice.elems.len().to_string(), arg2_span);
342        *arg2 = Pat::Type(PatType {
343            attrs: vec![],
344            pat: Box::new(arg2.clone()),
345            colon_token: Colon { spans: [arg2_span] },
346            ty: parse_quote_spanned! {arg2_span=> [usize; #len] },
347        });
348    }
349
350    Ok((idents, arg2_span))
351}
352
353// Checks that the closure names and output port names match.
354fn check_closure_ports_match(
355    closure_idents: &[Ident],
356    port_idents: &[Ident],
357    op_name: &'static str,
358    arg2_span: Span,
359    diagnostics: &mut Vec<Diagnostic>,
360) -> Result<(), ()> {
361    let mut err = false;
362    for port_ident in port_idents {
363        if !closure_idents.contains(port_ident) {
364            // An output port is missing from the closure args.
365            err = true;
366            diagnostics.push(Diagnostic::spanned(
367                arg2_span,
368                Level::Error,
369                format!(
370                    "Argument specifying the output ports in `{0}(..)` does not contain extra port `{1}`: ({2}) (1/2).",
371                    op_name, port_ident, PrettySpan(port_ident.span()),
372                ),
373            ));
374            diagnostics.push(Diagnostic::spanned(
375                port_ident.span(),
376                Level::Error,
377                format!(
378                    "Port `{1}` not found in the arguments specified in `{0}(..)`'s closure: ({2}) (2/2).",
379                    op_name, port_ident, PrettySpan(arg2_span),
380                ),
381            ));
382        }
383    }
384    for closure_ident in closure_idents {
385        if !port_idents.contains(closure_ident) {
386            // A closure arg is missing from the output ports.
387            err = true;
388            diagnostics.push(Diagnostic::spanned(
389                closure_ident.span(),
390                Level::Error,
391                format!(
392                    "`{}(..)` closure argument `{}` missing corresponding output port.",
393                    op_name, closure_ident,
394                ),
395            ));
396        }
397    }
398    (!err).then_some(()).ok_or(())
399}