1use std::collections::BTreeSet;
2
3use proc_macro2::Span;
4use quote::quote_spanned;
5use syn::spanned::Spanned;
6use syn::token::Colon;
7use syn::{parse_quote_spanned, Expr, Ident, LitInt, LitStr, Pat, PatType};
8
9use super::{
10 OperatorCategory, OperatorConstraints, OperatorInstance, OperatorWriteOutput, PortIndexValue,
11 PortListSpec, WriteContextArgs, RANGE_0, RANGE_1,
12};
13use crate::diagnostic::{Diagnostic, Level};
14use crate::pretty_span::PrettySpan;
15
/// `partition` operator: sends each item from its input to exactly one of two or more outputs.
///
/// Takes a single argument: a closure that is called as `(&item, arg2)` and must return the
/// `usize` index of the output that `item` is pushed to.
/// - With integer-indexed output ports (`[0]`, `[1]`, ...), `arg2` is the number of output
///   ports, as a `usize` literal.
/// - With named output ports (`[foo]`, `[bar]`, ...), the closure's second parameter must be a
///   slice pattern listing every port name (e.g. `[foo, bar, baz]`), and `arg2` is the array
///   `[0_usize, 1_usize, ...]`, so each name is bound to that port's index.
///
/// If the closure returns an index that matches no output, the generated code panics with an
/// out-of-bounds message.
pub const PARTITION: OperatorConstraints = OperatorConstraints {
    name: "partition",
    categories: &[OperatorCategory::MultiOut],
    hard_range_inn: RANGE_1,
    soft_range_inn: RANGE_1,
    // Partitioning needs at least two outputs.
    hard_range_out: &(2..),
    soft_range_out: &(2..),
    // The single argument is the routing closure.
    num_args: 1,
    persistence_args: RANGE_0,
    type_args: RANGE_0,
    is_external_input: false,
    has_singleton_output: false,
    flo_type: None,
    ports_inn: None,
    // Output ports may be named or indexed; `determine_indices_or_idents` validates them.
    ports_out: Some(|| PortListSpec::Variadic),
    input_delaytype_fn: |_| None,
    write_fn: |wc @ &WriteContextArgs {
                   root,
                   op_span,
                   ident,
                   outputs,
                   is_pull,
                   op_name,
                   op_inst: OperatorInstance { output_ports, .. },
                   arguments,
                   ..
               },
               diagnostics| {
        // Only push-based (pusherator) code generation is implemented for this operator.
        assert!(!is_pull);

        // Cloned because `extract_closure_idents` may add a type annotation to the closure's
        // second parameter.
        let mut func = arguments[0].clone();

        // One literal per output port: `0_usize, 1_usize, ...`. Used both as the generated
        // match arms and, for named ports, as the index array passed to the closure.
        let idx_ints = (0..output_ports.len())
            .map(|i| LitInt::new(&format!("{}_usize", i), op_span))
            .collect::<Vec<_>>();

        // Order in which `outputs` are handed to `Demux`. Identity, unless named ports are
        // reordered below to follow the closure's slice pattern.
        let mut output_sort_permutation: Vec<_> = (0..outputs.len()).collect();
        let (output_idents, arg2_val) = if let Some(port_idents) =
            determine_indices_or_idents(output_ports, op_span, op_name, diagnostics)?
        {
            // Named ports: the closure's slice pattern defines the port names and their indices.
            let (closure_idents, arg2_span) =
                extract_closure_idents(&mut func, op_name).map_err(|err| diagnostics.push(err))?;
            check_closure_ports_match(
                &closure_idents,
                &port_idents,
                op_name,
                arg2_span,
                diagnostics,
            )?;
            // Reorder outputs so position `k` holds the output whose port name is the `k`-th
            // name in the closure's slice pattern.
            output_sort_permutation.sort_by_key(|&i| {
                closure_idents
                    .iter()
                    .position(|ident| ident == &port_idents[i])
                    .expect(
                        "Missing port, this should've been caught in the check above, this is a bug.",
                    )
            });
            // The closure receives `[0_usize, 1_usize, ...]`, binding each port name to its index.
            let arg2_val = quote_spanned! {arg2_span.span()=> [ #( #idx_ints ),* ] };

            (closure_idents, arg2_val)
        } else {
            // Integer-indexed ports: one generated pusher ident per port (via `wc.make_ident`).
            // The closure is passed the number of outputs.
            let numeric_idents = (0..output_ports.len())
                .map(|i| wc.make_ident(format!("{}_push", i)))
                .collect();
            let len_lit = LitInt::new(&format!("{}_usize", output_ports.len()), op_span);
            let arg2_val = quote_spanned! {op_span=> #len_lit };
            (numeric_idents, arg2_val)
        };

        // `{{}}` yields a literal `{}` placeholder, filled in by the generated `panic!` with
        // the unmatched index.
        let err_str = LitStr::new(
            &format!(
                "Index `{{}}` returned by `{}(..)` closure is out-of-bounds.",
                op_name
            ),
            op_span,
        );
        let ident_item = wc.make_ident("item");
        let ident_index = wc.make_ident("index");
        let ident_unknown = wc.make_ident("match_unknown");

        let sorted_outputs = output_sort_permutation.into_iter().map(|i| &outputs[i]);

        // Generated code: a `Demux` whose closure calls the user function to pick an index,
        // then pushes the item to the output at that index (panicking on an unknown index).
        let write_iterator = quote_spanned! {op_span=>
            let #ident = {
                #root::pusherator::demux::Demux::new(
                    |#ident_item, #root::var_args!( #( #output_idents ),* )| {
                        #[allow(unused_imports)]
                        use #root::pusherator::Pusherator;

                        let #ident_index = {
                            #[allow(clippy::redundant_closure_call)]
                            (#func)(&#ident_item, #arg2_val)
                        };
                        match #ident_index {
                            #(
                                #idx_ints => #output_idents.give(#ident_item),
                            )*
                            #ident_unknown => panic!(#err_str, #ident_unknown),
                        };
                    },
                    #root::var_expr!( #( #sorted_outputs ),* ),
                )
            };
        };

        Ok(OperatorWriteOutput {
            write_iterator,
            ..Default::default()
        })
    },
};
170
171fn determine_indices_or_idents(
174 output_ports: &[PortIndexValue],
175 op_span: Span,
176 op_name: &'static str,
177 diagnostics: &mut Vec<Diagnostic>,
178) -> Result<Option<Vec<Ident>>, ()> {
179 let mut ports_numeric = BTreeSet::new();
184 let mut ports_idents = Vec::new();
185 let mut err_elided = false;
187 for output_port in output_ports {
188 match output_port {
189 PortIndexValue::Elided(port_span) => {
190 err_elided = true;
191 diagnostics.push(Diagnostic::spanned(
192 port_span.unwrap_or(op_span),
193 Level::Error,
194 format!(
195 "Output ports from `{}` cannot be blank, must be named or indexed.",
196 op_name
197 ),
198 ));
199 }
200 PortIndexValue::Int(port_idx) => {
201 ports_numeric.insert(port_idx);
202
203 if port_idx.value < 0 {
204 diagnostics.push(Diagnostic::spanned(
205 port_idx.span,
206 Level::Error,
207 format!("Output ports from `{}` must be non-nonegative indices starting from zero.", op_name),
208 ));
209 }
210 }
211 PortIndexValue::Path(port_path) => {
212 let port_ident = syn::parse2::<Ident>(quote_spanned!(op_span=> #port_path))
213 .map_err(|err| diagnostics.push(err.into()))?;
214 ports_idents.push(port_ident);
215 }
216 }
217 }
218 if err_elided {
219 return Err(());
220 }
221
222 match (!ports_numeric.is_empty(), !ports_idents.is_empty()) {
223 (false, false) => {
224 assert!(diagnostics.iter().any(Diagnostic::is_error), "Empty input ports, expected an error diagnostic but none were emitted, this is a bug.");
226 Err(())
227 }
228 (true, true) => {
229 let msg = &*format!(
231 "Output ports from `{}` must either be all integer indices or all identifiers.",
232 op_name
233 );
234 diagnostics.extend(
235 output_ports
236 .iter()
237 .map(|output_port| Diagnostic::spanned(output_port.span(), Level::Error, msg)),
238 );
239 Err(())
240 }
241 (true, false) => {
242 let max_port_idx = ports_numeric.last().unwrap().value;
243 if usize::try_from(max_port_idx).unwrap() >= ports_numeric.len() {
244 let mut expected = 0;
245 for port_numeric in ports_numeric {
246 if expected != port_numeric.value {
247 diagnostics.push(Diagnostic::spanned(
248 port_numeric.span,
249 Level::Error,
250 format!(
251 "Output port indices from `{}` must be consecutive from zero, missing {}.",
252 op_name, expected
253 ),
254 ));
255 }
256 expected = port_numeric.value + 1;
257 }
258 }
261 Ok(None)
262 }
263 (false, true) => Ok(Some(ports_idents)),
264 }
265}
266
267fn extract_closure_idents(
269 func: &mut Expr,
270 op_name: &'static str,
271) -> Result<(Vec<Ident>, Span), Diagnostic> {
272 let Expr::Closure(func) = func else {
273 return Err(Diagnostic::spanned(
274 func.span(),
275 Level::Error,
276 "Argument must be a two-argument closure expression",
277 ));
278 };
279 if 2 != func.inputs.len() {
280 return Err(Diagnostic::spanned(
281 func.inputs.span(),
282 Level::Error,
283 &*format!(
284 "Closure provided to `{}(..)` must have two arguments: \
285 the first argument is the item, and for named ports the second argument must contain a Rust 'slice pattern' to determine the port names and order. \
286 For example, the second argument could be `[foo, bar, baz]` for ports `foo`, `bar`, and `baz`.",
287 op_name
288 ),
289 ));
290 }
291
292 let mut arg2 = &mut func.inputs[1];
294 let mut already_has_type = false;
295 if let Pat::Type(pat_type) = arg2 {
296 arg2 = &mut *pat_type.pat;
297 already_has_type = true;
298 }
299
300 let arg2_span = arg2.span();
301 if let Pat::Ident(pat_ident) = arg2 {
302 arg2 = &mut *pat_ident
303 .subpat
304 .as_mut()
305 .ok_or_else(|| Diagnostic::spanned(
306 arg2_span,
307 Level::Error,
308 format!(
309 "Second argument for the `{}` closure must contain a Rust 'slice pattern' to determine the port names and order. \
310 For example: `arr @ [foo, bar, baz]` for ports `foo`, `bar`, and `baz`.",
311 op_name
312 )
313 ))?
314 .1;
315 }
316 let Pat::Slice(pat_slice) = arg2 else {
317 return Err(Diagnostic::spanned(
318 arg2_span,
319 Level::Error,
320 format!(
321 "Second argument for the `{}` closure must have a Rust 'slice pattern' to determine the port names and order. \
322 For example: `[foo, bar, baz]` for ports `foo`, `bar`, and `baz`.",
323 op_name
324 )
325 ));
326 };
327
328 let idents = pat_slice
329 .elems
330 .iter()
331 .map(|pat| {
332 let Pat::Ident(pat_ident) = pat else {
333 panic!("TODO(mingwei) expected ident pat");
334 };
335 pat_ident.ident.clone()
336 })
337 .collect();
338
339 if !already_has_type {
341 let len = LitInt::new(&pat_slice.elems.len().to_string(), arg2_span);
342 *arg2 = Pat::Type(PatType {
343 attrs: vec![],
344 pat: Box::new(arg2.clone()),
345 colon_token: Colon { spans: [arg2_span] },
346 ty: parse_quote_spanned! {arg2_span=> [usize; #len] },
347 });
348 }
349
350 Ok((idents, arg2_span))
351}
352
353fn check_closure_ports_match(
355 closure_idents: &[Ident],
356 port_idents: &[Ident],
357 op_name: &'static str,
358 arg2_span: Span,
359 diagnostics: &mut Vec<Diagnostic>,
360) -> Result<(), ()> {
361 let mut err = false;
362 for port_ident in port_idents {
363 if !closure_idents.contains(port_ident) {
364 err = true;
366 diagnostics.push(Diagnostic::spanned(
367 arg2_span,
368 Level::Error,
369 format!(
370 "Argument specifying the output ports in `{0}(..)` does not contain extra port `{1}`: ({2}) (1/2).",
371 op_name, port_ident, PrettySpan(port_ident.span()),
372 ),
373 ));
374 diagnostics.push(Diagnostic::spanned(
375 port_ident.span(),
376 Level::Error,
377 format!(
378 "Port `{1}` not found in the arguments specified in `{0}(..)`'s closure: ({2}) (2/2).",
379 op_name, port_ident, PrettySpan(arg2_span),
380 ),
381 ));
382 }
383 }
384 for closure_ident in closure_idents {
385 if !port_idents.contains(closure_ident) {
386 err = true;
388 diagnostics.push(Diagnostic::spanned(
389 closure_ident.span(),
390 Level::Error,
391 format!(
392 "`{}(..)` closure argument `{}` missing corresponding output port.",
393 op_name, closure_ident,
394 ),
395 ));
396 }
397 }
398 (!err).then_some(()).ok_or(())
399}