dfir_lang/graph/ops/
demux_enum.rs
1use proc_macro2::Ident;
2use quote::{quote, quote_spanned, ToTokens};
3use syn::spanned::Spanned;
4use syn::{PathArguments, PathSegment, Token, Type, TypePath};
5
6use super::{
7 OpInstGenerics, OperatorCategory, OperatorConstraints, OperatorInstance,
8 OperatorWriteOutput, PortIndexValue, PortListSpec, WriteContextArgs, RANGE_0, RANGE_1,
9};
10use crate::diagnostic::{Diagnostic, Level};
11use crate::graph::change_spans;
12
13pub const DEMUX_ENUM: OperatorConstraints = OperatorConstraints {
44 name: "demux_enum",
45 categories: &[OperatorCategory::MultiOut],
46 hard_range_inn: RANGE_1,
47 soft_range_inn: RANGE_1,
48 hard_range_out: &(..),
49 soft_range_out: &(..),
50 num_args: 0,
51 persistence_args: RANGE_0,
52 type_args: RANGE_1,
53 is_external_input: false,
54 has_singleton_output: false,
55 flo_type: None,
56 ports_inn: None,
57 ports_out: Some(|| PortListSpec::Variadic),
58 input_delaytype_fn: |_| None,
59 write_fn: |&WriteContextArgs {
60 root,
61 op_span,
62 ident,
63 inputs,
64 outputs,
65 is_pull,
66 op_name,
67 op_inst:
68 OperatorInstance {
69 output_ports,
70 generics: OpInstGenerics { type_args, .. },
71 ..
72 },
73 ..
74 },
75 diagnostics| {
76 let enum_type = &type_args[0];
77
78 let port_idents: Vec<_> = output_ports
80 .iter()
81 .filter_map(|output_port| {
82 let PortIndexValue::Path(port_expr) = output_port else {
83 diagnostics.push(Diagnostic::spanned(
84 output_port.span(),
85 Level::Error,
86 format!(
87 "Output port from `{}(..)` must be specified and must be a valid identifier.",
88 op_name,
89 ),
90 ));
91 return None;
92 };
93 let port_ident = syn::parse2::<Ident>(quote! { #port_expr })
94 .map_err(|err| diagnostics.push(err.into()))
95 .ok()?;
96
97 Some(port_ident)
98 })
99 .collect();
100
101 let enum_type_turbofish = ensure_turbofish(enum_type);
106 let port_variant_check_match_arms = port_idents
107 .iter()
108 .map(|port_ident| {
109 let enum_type_turbofish =
110 change_spans(enum_type_turbofish.to_token_stream(), port_ident.span());
111 quote_spanned! {port_ident.span()=>
112 #enum_type_turbofish::#port_ident { .. } => ()
113 }
114 })
115 .collect::<Vec<_>>();
116 let root_span = change_spans(root.clone(), enum_type.span());
117 let write_prologue = quote_spanned! {enum_type.span()=>
118 let _ = |__val: #enum_type| {
119 fn check_impl_demux_enum<T: ?Sized + #root_span::util::demux_enum::DemuxEnumBase>(_: &T) {}
120 check_impl_demux_enum(&__val);
121 match __val {
122 #(
123 #port_variant_check_match_arms,
124 )*
125 };
126 };
127 };
128
129 let write_iterator = if 1 == outputs.len() {
130 let map_fn = quote_spanned! {enum_type.span()=>
132 <#enum_type as #root::util::demux_enum::SingleVariant>::single_variant
133 };
134 if is_pull {
135 let input = &inputs[0];
136 quote_spanned! {op_span=>
137 let #ident = #input.map(#map_fn);
138 }
139 } else {
140 let output = &outputs[0];
141 quote_spanned! {op_span=>
142 let #ident = #root::pusherator::map::Map::new(#map_fn, #output);
143 }
144 }
145 } else {
146 assert!(!is_pull);
147
148 let mut sort_permute: Vec<_> = (0..port_idents.len()).collect();
149 sort_permute.sort_by_key(|&i| &port_idents[i]);
150
151 let sorted_outputs = sort_permute.iter().map(|&i| &outputs[i]);
152
153 quote_spanned! {op_span=>
154 let #ident = {
155 let mut __outputs = ( #( #sorted_outputs, )* );
156 #root::pusherator::for_each::ForEach::new(move |__item: #enum_type| {
157 #root::util::demux_enum::DemuxEnum::demux_enum(
158 __item,
159 &mut __outputs,
160 );
161 })
162 };
163 }
164 };
165
166 Ok(OperatorWriteOutput {
167 write_prologue,
168 write_iterator,
169 ..Default::default()
170 })
171 },
172};
173
174fn ensure_turbofish(ty: &Type) -> Type {
177 let mut ty = ty.clone();
178 if let Type::Path(TypePath { qself: _, path }) = &mut ty {
180 if let Some(PathSegment {
182 ident: _,
183 arguments: PathArguments::AngleBracketed(angle_bracketed),
184 }) = path.segments.last_mut()
185 {
186 angle_bracketed.colon2_token = Some(<Token![::]>::default());
188 }
189 };
190 ty
191}