dfir_lang/graph/ops/
demux.rs
1use std::collections::HashMap;
2
3use proc_macro2::{Ident, TokenTree};
4use quote::{quote_spanned, ToTokens};
5use syn::spanned::Spanned;
6use syn::{Expr, Pat};
7
8use super::{
9 OperatorCategory, OperatorConstraints, OperatorInstance, OperatorWriteOutput,
10 PortIndexValue, PortListSpec, WriteContextArgs, RANGE_0, RANGE_1,
11};
12use crate::diagnostic::{Diagnostic, Level};
13use crate::pretty_span::PrettySpan;
14
15pub const DEMUX: OperatorConstraints = OperatorConstraints {
45 name: "demux",
46 categories: &[OperatorCategory::MultiOut],
47 hard_range_inn: RANGE_1,
48 soft_range_inn: RANGE_1,
49 hard_range_out: &(2..),
50 soft_range_out: &(2..),
51 num_args: 1,
52 persistence_args: RANGE_0,
53 type_args: RANGE_0,
54 is_external_input: false,
55 has_singleton_output: false,
56 flo_type: None,
57 ports_inn: None,
58 ports_out: Some(|| PortListSpec::Variadic),
59 input_delaytype_fn: |_| None,
60 write_fn: |&WriteContextArgs {
61 root,
62 op_span,
63 ident,
64 outputs,
65 is_pull,
66 op_name,
67 op_inst: OperatorInstance { output_ports, .. },
68 arguments,
69 ..
70 },
71 diagnostics| {
72 assert!(!is_pull);
73 let func = &arguments[0];
74 let Expr::Closure(func) = func else {
75 diagnostics.push(Diagnostic::spanned(
76 func.span(),
77 Level::Error,
78 "Argument must be a two-argument closure expression",
79 ));
80 return Err(());
81 };
82 if 2 != func.inputs.len() {
83 diagnostics.push(Diagnostic::spanned(
84 func.inputs.span(),
85 Level::Error,
86 &*format!(
87 "Closure provided to `{}(..)` must have two arguments: \
88 the first argument is the item, and the second argument lists ports. \
89 E.g. the second argument could be `var_args!(port_a, port_b, ..)`.",
90 op_name
91 ),
92 ));
93 return Err(());
94 }
95
96 let arg2 = &func.inputs[1];
98 let closure_idents = extract_closure_idents(arg2);
99
100 let port_idents: Vec<_> = output_ports
102 .iter()
103 .filter_map(|output_port| {
104 let PortIndexValue::Path(port_expr) = output_port else {
105 diagnostics.push(Diagnostic::spanned(
106 output_port.span(),
107 Level::Error,
108 format!(
109 "Output port from `{}(..)` must be specified and must be a valid identifier.",
110 op_name,
111 ),
112 ));
113 return None;
114 };
115 let port_ident = syn::parse2::<Ident>(quote_spanned! {op_span=> #port_expr })
116 .map_err(|err| diagnostics.push(err.into()))
117 .ok()?;
118
119 if !closure_idents.contains_key(&port_ident) {
120 diagnostics.push(Diagnostic::spanned(
122 arg2.span(),
123 Level::Error,
124 format!(
125 "Argument specifying the output ports in `{0}(..)` does not contain extra port `{1}`: ({2}) (1/2).",
126 op_name, port_ident, PrettySpan(output_port.span()),
127 ),
128 ));
129 diagnostics.push(Diagnostic::spanned(
130 output_port.span(),
131 Level::Error,
132 format!(
133 "Port `{1}` not found in the arguments specified in `{0}(..)`'s closure: ({2}) (2/2).",
134 op_name, port_ident, PrettySpan(arg2.span()),
135 ),
136 ));
137 return None;
138 }
139
140 Some(port_ident)
141 })
142 .collect();
143
144 for closure_ident in closure_idents.keys() {
145 if !port_idents.contains(closure_ident) {
146 diagnostics.push(Diagnostic::spanned(
147 closure_ident.span(),
148 Level::Error,
149 format!(
150 "`{}(..)` closure argument `{}` missing corresponding output port.",
151 op_name, closure_ident,
152 ),
153 ));
154 }
155 }
156
157 if diagnostics.iter().any(Diagnostic::is_error) {
158 return Err(());
159 }
160
161 assert_eq!(outputs.len(), port_idents.len());
162 assert_eq!(outputs.len(), closure_idents.len());
163
164 let mut sort_permute: Vec<_> = (0..outputs.len()).collect();
165 sort_permute.sort_by_key(|&i| closure_idents[&port_idents[i]]);
166
167 let sorted_outputs = sort_permute.iter().map(|&i| &outputs[i]);
168
169 let write_iterator = quote_spanned! {op_span=>
170 let #ident = {
171 #[allow(unused_imports)] use #root::pusherator::Pusherator;
172 #root::pusherator::demux::Demux::new(#func, #root::var_expr!( #( #sorted_outputs ),* ))
173 };
174 };
175
176 Ok(OperatorWriteOutput {
177 write_iterator,
178 ..Default::default()
179 })
180 },
181};
182
183fn extract_closure_idents(arg2: &Pat) -> HashMap<Ident, usize> {
184 let tokens = if let Pat::Macro(pat_macro) = arg2 {
185 pat_macro.mac.tokens.clone()
186 } else {
187 arg2.to_token_stream()
188 };
189
190 let mut idents = HashMap::new();
191 let mut stack: Vec<_> = tokens.into_iter().collect();
192 stack.reverse();
193 while let Some(tt) = stack.pop() {
194 match tt {
195 TokenTree::Group(group) => {
196 let a = stack.len();
197 stack.extend(group.stream());
198 let b = stack.len();
199 stack[a..b].reverse();
200 }
201 TokenTree::Ident(ident) => {
202 idents.insert(ident, idents.len());
203 }
204 TokenTree::Punct(_) => (),
205 TokenTree::Literal(_) => (),
206 }
207 }
208 idents
209}