dfir_macro/
lib.rs

// When the `nightly` cfg is set, enable the unstable `proc_macro` feature gates
// (`proc_macro_diagnostic`, `proc_macro_span`, `proc_macro_def_site`); on stable builds
// this attribute expands to nothing.
// NOTE(review): presumably needed for nightly-only diagnostic emission — confirm.
#![cfg_attr(
    nightly,
    feature(proc_macro_diagnostic, proc_macro_span, proc_macro_def_site)
)]
5
6use dfir_lang::diagnostic::{Diagnostic, Level};
7use dfir_lang::graph::{FlatGraphBuilder, build_hfcode, partition_graph};
8use dfir_lang::parse::DfirCode;
9use proc_macro2::{Ident, Literal, Span};
10use quote::{format_ident, quote};
11use syn::{
12    Attribute, Fields, GenericParam, ItemEnum, Variant, WherePredicate, parse_macro_input,
13    parse_quote,
14};
15
16/// Create a runnable graph instance using DFIR's custom syntax.
17///
18/// For example usage, take a look at the [`surface_*` tests in the `tests` folder](https://github.com/hydro-project/hydro/tree/main/dfir_rs/tests)
19/// or the [`examples` folder](https://github.com/hydro-project/hydro/tree/main/dfir_rs/examples)
20/// in the [Hydro repo](https://github.com/hydro-project/hydro).
21// TODO(mingwei): rustdoc examples inline.
22#[proc_macro]
23pub fn dfir_syntax(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
24    dfir_syntax_internal(input, Some(Level::Help))
25}
26
27/// [`dfir_syntax!`] but will not emit any diagnostics (errors, warnings, etc.).
28///
29/// Used for testing, users will want to use [`dfir_syntax!`] instead.
30#[proc_macro]
31pub fn dfir_syntax_noemit(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
32    dfir_syntax_internal(input, None)
33}
34
35fn root() -> proc_macro2::TokenStream {
36    use std::env::{VarError, var as env_var};
37
38    let root_crate =
39        proc_macro_crate::crate_name("dfir_rs").expect("dfir_rs should be present in `Cargo.toml`");
40    match root_crate {
41        proc_macro_crate::FoundCrate::Itself => {
42            if Err(VarError::NotPresent) == env_var("CARGO_BIN_NAME")
43                && Err(VarError::NotPresent) != env_var("CARGO_PRIMARY_PACKAGE")
44                && Ok("dfir_rs") == env_var("CARGO_CRATE_NAME").as_deref()
45            {
46                // In the crate itself, including unit tests.
47                quote! { crate }
48            } else {
49                // In an integration test, example, bench, etc.
50                quote! { ::dfir_rs }
51            }
52        }
53        proc_macro_crate::FoundCrate::Name(name) => {
54            let ident: Ident = Ident::new(&name, Span::call_site());
55            quote! { ::#ident }
56        }
57    }
58}
59
60fn dfir_syntax_internal(
61    input: proc_macro::TokenStream,
62    min_diagnostic_level: Option<Level>,
63) -> proc_macro::TokenStream {
64    let input = parse_macro_input!(input as DfirCode);
65    let root = root();
66    let (graph_code_opt, diagnostics) = build_hfcode(input, &root);
67    let tokens = graph_code_opt
68        .map(|(_graph, code)| code)
69        .unwrap_or_else(|| quote! { #root::scheduled::graph::Dfir::new() });
70
71    let diagnostics = diagnostics
72        .iter()
73        .filter(|diag: &&Diagnostic| Some(diag.level) <= min_diagnostic_level);
74
75    let diagnostic_tokens = Diagnostic::try_emit_all(diagnostics)
76        .err()
77        .unwrap_or_default();
78    quote! {
79        {
80            #diagnostic_tokens
81            #tokens
82        }
83    }
84    .into()
85}
86
87/// Parse DFIR syntax without emitting code.
88///
89/// Used for testing, users will want to use [`dfir_syntax!`] instead.
90#[proc_macro]
91pub fn dfir_parser(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
92    let input = parse_macro_input!(input as DfirCode);
93
94    let flat_graph_builder = FlatGraphBuilder::from_dfir(input);
95    let (mut flat_graph, _uses, mut diagnostics) = flat_graph_builder.build();
96    if !diagnostics.iter().any(Diagnostic::is_error) {
97        if let Err(diagnostic) = flat_graph.merge_modules() {
98            diagnostics.push(diagnostic);
99        } else {
100            let flat_mermaid = flat_graph.mermaid_string_flat();
101
102            let part_graph = partition_graph(flat_graph).unwrap();
103            let part_mermaid = part_graph.to_mermaid(&Default::default());
104
105            let lit0 = Literal::string(&flat_mermaid);
106            let lit1 = Literal::string(&part_mermaid);
107
108            return quote! {
109                {
110                    println!("{}\n\n{}\n", #lit0, #lit1);
111                }
112            }
113            .into();
114        }
115    }
116
117    Diagnostic::try_emit_all(diagnostics.iter())
118        .err()
119        .unwrap_or_default()
120        .into()
121}
122
123fn wrap_localset(item: proc_macro::TokenStream, attribute: Attribute) -> proc_macro::TokenStream {
124    use quote::ToTokens;
125
126    let root = root();
127
128    let mut input: syn::ItemFn = match syn::parse(item) {
129        Ok(it) => it,
130        Err(e) => return e.into_compile_error().into(),
131    };
132
133    let statements = input.block.stmts;
134
135    input.block.stmts = parse_quote!(
136        #root::tokio::task::LocalSet::new().run_until(async {
137            #( #statements )*
138        }).await
139    );
140
141    input.attrs.push(attribute);
142
143    input.into_token_stream().into()
144}
145
146/// Checks that the given closure is a morphism. For now does nothing.
147#[proc_macro]
148pub fn morphism(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
149    // TODO(mingwei): some sort of code analysis?
150    item
151}
152
153/// Checks that the given closure is a monotonic function. For now does nothing.
154#[proc_macro]
155pub fn monotonic_fn(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
156    // TODO(mingwei): some sort of code analysis?
157    item
158}
159
160#[proc_macro_attribute]
161pub fn dfir_test(
162    args: proc_macro::TokenStream,
163    item: proc_macro::TokenStream,
164) -> proc_macro::TokenStream {
165    let root = root();
166    let args_2: proc_macro2::TokenStream = args.into();
167
168    wrap_localset(
169        item,
170        parse_quote!(
171            #[#root::tokio::test(flavor = "current_thread", #args_2)]
172        ),
173    )
174}
175
176#[proc_macro_attribute]
177pub fn dfir_main(
178    _: proc_macro::TokenStream,
179    item: proc_macro::TokenStream,
180) -> proc_macro::TokenStream {
181    let root = root();
182
183    wrap_localset(
184        item,
185        parse_quote!(
186            #[#root::tokio::main(flavor = "current_thread")]
187        ),
188    )
189}
190
191#[proc_macro_derive(DemuxEnum)]
192pub fn derive_answer_fn(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
193    let root = root();
194
195    let ItemEnum {
196        ident: item_ident,
197        generics,
198        variants,
199        ..
200    } = parse_macro_input!(item as ItemEnum);
201
202    // Sort variants alphabetically.
203    let mut variants = variants.into_iter().collect::<Vec<_>>();
204    variants.sort_by(|a, b| a.ident.cmp(&b.ident));
205
206    let variant_pusherator_generics = variants
207        .iter()
208        .map(|variant| format_ident!("__Pusherator{}", variant.ident))
209        .collect::<Vec<_>>();
210    let variant_pusherator_localvars = variants
211        .iter()
212        .map(|variant| {
213            format_ident!(
214                "__pusherator_{}",
215                variant.ident.to_string().to_lowercase(),
216                span = variant.ident.span()
217            )
218        })
219        .collect::<Vec<_>>();
220    let variant_output_types = variants
221        .iter()
222        .map(|variant| match &variant.fields {
223            Fields::Named(fields) => {
224                let field_types = fields.named.iter().map(|field| &field.ty);
225                quote! {
226                    ( #( #field_types, )* )
227                }
228            }
229            Fields::Unnamed(fields) => {
230                let field_types = fields.unnamed.iter().map(|field| &field.ty);
231                quote! {
232                    ( #( #field_types, )* )
233                }
234            }
235            Fields::Unit => quote!(()),
236        })
237        .collect::<Vec<_>>();
238
239    let mut full_generics = generics.clone();
240    full_generics.params.extend(
241        variant_pusherator_generics
242            .iter()
243            .map::<GenericParam, _>(|ident| parse_quote!(#ident)),
244    );
245    full_generics.make_where_clause().predicates.extend(
246        variant_pusherator_generics
247            .iter()
248            .zip(variant_output_types.iter())
249            .map::<WherePredicate, _>(|(pusherator_generic, output_type)| {
250                parse_quote! {
251                    #pusherator_generic: #root::pusherator::Pusherator<Item = #output_type>
252                }
253            }),
254    );
255
256    let (impl_generics_item, ty_generics, where_clause_item) = generics.split_for_impl();
257    let (impl_generics, _ty_generics, where_clause) = full_generics.split_for_impl();
258
259    let variant_pats = variants
260        .iter()
261        .zip(variant_pusherator_localvars.iter())
262        .map(|(variant, pushvar)| {
263            let Variant { ident, fields, .. } = variant;
264            let (fields_pat, push_item) = field_pattern_item(fields);
265            quote! {
266                Self::#ident #fields_pat => #pushvar.give(#push_item)
267            }
268        });
269
270    let single_impl = (1 == variants.len()).then(|| {
271        let Variant { ident, fields, .. } = variants.first().unwrap();
272        let (fields_pat, push_item) = field_pattern_item(fields);
273        let out_type = variant_output_types.first().unwrap();
274        quote! {
275            impl #impl_generics_item #root::util::demux_enum::SingleVariant
276                for #item_ident #ty_generics #where_clause_item
277            {
278                type Output = #out_type;
279                fn single_variant(self) -> Self::Output {
280                    match self {
281                        Self::#ident #fields_pat => #push_item,
282                    }
283                }
284            }
285        }
286    });
287
288    quote! {
289        impl #impl_generics #root::util::demux_enum::DemuxEnum<( #( #variant_pusherator_generics, )* )>
290            for #item_ident #ty_generics #where_clause
291        {
292            fn demux_enum(
293                self,
294                ( #( #variant_pusherator_localvars, )* ):
295                    &mut ( #( #variant_pusherator_generics, )* )
296            ) {
297                match self {
298                    #( #variant_pats, )*
299                }
300            }
301        }
302
303        impl #impl_generics_item #root::util::demux_enum::DemuxEnumBase
304            for #item_ident #ty_generics #where_clause_item {}
305
306        #single_impl
307    }
308    .into()
309}
310
311/// (fields pattern, push item expr)
312fn field_pattern_item(fields: &Fields) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
313    let idents = fields
314        .iter()
315        .enumerate()
316        .map(|(i, field)| {
317            field
318                .ident
319                .clone()
320                .unwrap_or_else(|| format_ident!("_{}", i))
321        })
322        .collect::<Vec<_>>();
323    let (fields_pat, push_item) = match fields {
324        Fields::Named(_) => (quote!( { #( #idents, )* } ), quote!( ( #( #idents, )* ) )),
325        Fields::Unnamed(_) => (quote!( ( #( #idents ),* ) ), quote!( ( #( #idents, )* ) )),
326        Fields::Unit => (quote!(), quote!(())),
327    };
328    (fields_pat, push_item)
329}