dfir_lang/graph/ops/demux_enum.rs

1use proc_macro2::Ident;
2use quote::{quote, quote_spanned, ToTokens};
3use syn::spanned::Spanned;
4use syn::{PathArguments, PathSegment, Token, Type, TypePath};
5
6use super::{
7    OpInstGenerics, OperatorCategory, OperatorConstraints, OperatorInstance,
8    OperatorWriteOutput, PortIndexValue, PortListSpec, WriteContextArgs, RANGE_0, RANGE_1,
9};
10use crate::diagnostic::{Diagnostic, Level};
11use crate::graph::change_spans;
12
13/// > Generic Argument: A enum type which has `#[derive(DemuxEnum)]`. Must match the items in the input stream.
14///
15/// Takes an input stream of enum instances and splits them into their variants.
16///
17/// ```rustdoc
18/// #[derive(DemuxEnum)]
19/// enum Shape {
20///     Square(f64),
21///     Rectangle(f64, f64),
22///     Circle { r: f64 },
23///     Triangle { w: f64, h: f64 }
24/// }
25///
26/// let mut df = dfir_syntax! {
27///     my_demux = source_iter([
28///         Shape::Square(9.0),
29///         Shape::Rectangle(10.0, 8.0),
30///         Shape::Circle { r: 5.0 },
31///         Shape::Triangle { w: 12.0, h: 13.0 },
32///     ]) -> demux_enum::<Shape>();
33///
34///     my_demux[Square] -> map(|s| s * s) -> out;
35///     my_demux[Circle] -> map(|(r,)| std::f64::consts::PI * r * r) -> out;
36///     my_demux[Rectangle] -> map(|(w, h)| w * h) -> out;
37///     my_demux[Circle] -> map(|(w, h)| 0.5 * w * h) -> out;
38///
39///     out = union() -> for_each(|area| println!("Area: {}", area));
40/// };
41/// df.run_available();
42/// ```
43pub const DEMUX_ENUM: OperatorConstraints = OperatorConstraints {
44    name: "demux_enum",
45    categories: &[OperatorCategory::MultiOut],
46    hard_range_inn: RANGE_1,
47    soft_range_inn: RANGE_1,
48    hard_range_out: &(..),
49    soft_range_out: &(..),
50    num_args: 0,
51    persistence_args: RANGE_0,
52    type_args: RANGE_1,
53    is_external_input: false,
54    has_singleton_output: false,
55    flo_type: None,
56    ports_inn: None,
57    ports_out: Some(|| PortListSpec::Variadic),
58    input_delaytype_fn: |_| None,
59    write_fn: |&WriteContextArgs {
60                   root,
61                   op_span,
62                   ident,
63                   inputs,
64                   outputs,
65                   is_pull,
66                   op_name,
67                   op_inst:
68                       OperatorInstance {
69                           output_ports,
70                           generics: OpInstGenerics { type_args, .. },
71                           ..
72                       },
73                   ..
74               },
75               diagnostics| {
76        let enum_type = &type_args[0];
77
78        // Port idents supplied via port connections in the surface syntax.
79        let port_idents: Vec<_> = output_ports
80                    .iter()
81                    .filter_map(|output_port| {
82                        let PortIndexValue::Path(port_expr) = output_port else {
83                            diagnostics.push(Diagnostic::spanned(
84                                output_port.span(),
85                                Level::Error,
86                                format!(
87                                    "Output port from `{}(..)` must be specified and must be a valid identifier.",
88                                    op_name,
89                                ),
90                            ));
91                            return None;
92                        };
93                        let port_ident = syn::parse2::<Ident>(quote! { #port_expr })
94                            .map_err(|err| diagnostics.push(err.into()))
95                            .ok()?;
96
97                        Some(port_ident)
98                    })
99                    .collect();
100
101        // The entire purpose of this closure and match statement is to generate readable error messages:
102        // "missing match arm: `Variant(_)` not covered."
103        // Or "no variant named `Variant` found for enum `Shape`"
104        // Note this uses the `enum_type`'s span.
105        let enum_type_turbofish = ensure_turbofish(enum_type);
106        let port_variant_check_match_arms = port_idents
107            .iter()
108            .map(|port_ident| {
109                let enum_type_turbofish =
110                    change_spans(enum_type_turbofish.to_token_stream(), port_ident.span());
111                quote_spanned! {port_ident.span()=>
112                    #enum_type_turbofish::#port_ident { .. } => ()
113                }
114            })
115            .collect::<Vec<_>>();
116        let root_span = change_spans(root.clone(), enum_type.span());
117        let write_prologue = quote_spanned! {enum_type.span()=>
118            let _ = |__val: #enum_type| {
119                fn check_impl_demux_enum<T: ?Sized + #root_span::util::demux_enum::DemuxEnumBase>(_: &T) {}
120                check_impl_demux_enum(&__val);
121                match __val {
122                    #(
123                        #port_variant_check_match_arms,
124                    )*
125                };
126            };
127        };
128
129        let write_iterator = if 1 == outputs.len() {
130            // Use `enum_type`'s span.
131            let map_fn = quote_spanned! {enum_type.span()=>
132                <#enum_type as #root::util::demux_enum::SingleVariant>::single_variant
133            };
134            if is_pull {
135                let input = &inputs[0];
136                quote_spanned! {op_span=>
137                    let #ident = #input.map(#map_fn);
138                }
139            } else {
140                let output = &outputs[0];
141                quote_spanned! {op_span=>
142                    let #ident = #root::pusherator::map::Map::new(#map_fn, #output);
143                }
144            }
145        } else {
146            assert!(!is_pull);
147
148            let mut sort_permute: Vec<_> = (0..port_idents.len()).collect();
149            sort_permute.sort_by_key(|&i| &port_idents[i]);
150
151            let sorted_outputs = sort_permute.iter().map(|&i| &outputs[i]);
152
153            quote_spanned! {op_span=>
154                let #ident = {
155                    let mut __outputs = ( #( #sorted_outputs, )* );
156                    #root::pusherator::for_each::ForEach::new(move |__item: #enum_type| {
157                        #root::util::demux_enum::DemuxEnum::demux_enum(
158                            __item,
159                            &mut __outputs,
160                        );
161                    })
162                };
163            }
164        };
165
166        Ok(OperatorWriteOutput {
167            write_prologue,
168            write_iterator,
169            ..Default::default()
170        })
171    },
172};
173
174/// Ensure enum type has double colon turbofish syntax.
175/// `my_mod::MyType<MyGeneric>` becomes `my_mod::MyType::<MyGeneric>`.
176fn ensure_turbofish(ty: &Type) -> Type {
177    let mut ty = ty.clone();
178    // If type is path.
179    if let Type::Path(TypePath { qself: _, path }) = &mut ty {
180        // If path ends in angle bracketed generics.
181        if let Some(PathSegment {
182            ident: _,
183            arguments: PathArguments::AngleBracketed(angle_bracketed),
184        }) = path.segments.last_mut()
185        {
186            // Ensure the final turbofish double-colon is set.
187            angle_bracketed.colon2_token = Some(<Token![::]>::default());
188        }
189    };
190    ty
191}