dfir_macro/
lib.rs

1#![cfg_attr(
2    nightly,
3    feature(proc_macro_diagnostic, proc_macro_span, proc_macro_def_site)
4)]
5
6use dfir_lang::diagnostic::{Diagnostic, Level};
7use dfir_lang::graph::{FlatGraphBuilder, build_hfcode, partition_graph};
8use dfir_lang::parse::DfirCode;
9use proc_macro2::{Ident, Literal, Span};
10use quote::{format_ident, quote};
11use syn::{
12    Attribute, Fields, GenericParam, ItemEnum, Variant, WherePredicate, parse_macro_input,
13    parse_quote,
14};
15
16/// Create a runnable graph instance using DFIR's custom syntax.
17///
18/// For example usage, take a look at the [`surface_*` tests in the `tests` folder](https://github.com/hydro-project/hydro/tree/main/dfir_rs/tests)
19/// or the [`examples` folder](https://github.com/hydro-project/hydro/tree/main/dfir_rs/examples)
20/// in the [Hydro repo](https://github.com/hydro-project/hydro).
21// TODO(mingwei): rustdoc examples inline.
22#[proc_macro]
23pub fn dfir_syntax(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
24    dfir_syntax_internal(input, Some(Level::Help))
25}
26
27/// [`dfir_syntax!`] but will not emit any diagnostics (errors, warnings, etc.).
28///
29/// Used for testing, users will want to use [`dfir_syntax!`] instead.
30#[proc_macro]
31pub fn dfir_syntax_noemit(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
32    dfir_syntax_internal(input, None)
33}
34
35fn root() -> proc_macro2::TokenStream {
36    use std::env::{VarError, var as env_var};
37
38    let root_crate_name = format!(
39        "{}_rs",
40        env!("CARGO_PKG_NAME").strip_suffix("_macro").unwrap()
41    );
42    let root_crate_ident = root_crate_name.replace('-', "_");
43    let root_crate = proc_macro_crate::crate_name(&root_crate_name)
44        .unwrap_or_else(|_| panic!("{root_crate_name} should be present in `Cargo.toml`"));
45    match root_crate {
46        proc_macro_crate::FoundCrate::Itself => {
47            if Err(VarError::NotPresent) == env_var("CARGO_BIN_NAME")
48                && Err(VarError::NotPresent) != env_var("CARGO_PRIMARY_PACKAGE")
49                && Ok(&*root_crate_ident) == env_var("CARGO_CRATE_NAME").as_deref()
50            {
51                // In the crate itself, including unit tests.
52                quote! { crate }
53            } else {
54                // In an integration test, example, bench, etc.
55                let ident: Ident = Ident::new(&root_crate_ident, Span::call_site());
56                quote! { ::#ident }
57            }
58        }
59        proc_macro_crate::FoundCrate::Name(name) => {
60            let ident = Ident::new(&name, Span::call_site());
61            quote! { ::#ident }
62        }
63    }
64}
65
66fn dfir_syntax_internal(
67    input: proc_macro::TokenStream,
68    min_diagnostic_level: Option<Level>,
69) -> proc_macro::TokenStream {
70    let input = parse_macro_input!(input as DfirCode);
71    let root = root();
72    let (graph_code_opt, diagnostics) = build_hfcode(input, &root);
73    let tokens = graph_code_opt
74        .map(|(_graph, code)| code)
75        .unwrap_or_else(|| quote! { #root::scheduled::graph::Dfir::new() });
76
77    let diagnostics = diagnostics
78        .iter()
79        .filter(|diag: &&Diagnostic| Some(diag.level) <= min_diagnostic_level);
80
81    let diagnostic_tokens = Diagnostic::try_emit_all(diagnostics)
82        .err()
83        .unwrap_or_default();
84    quote! {
85        {
86            #diagnostic_tokens
87            #tokens
88        }
89    }
90    .into()
91}
92
93/// Parse DFIR syntax without emitting code.
94///
95/// Used for testing, users will want to use [`dfir_syntax!`] instead.
96#[proc_macro]
97pub fn dfir_parser(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
98    let input = parse_macro_input!(input as DfirCode);
99
100    let flat_graph_builder = FlatGraphBuilder::from_dfir(input);
101    let (mut flat_graph, _uses, mut diagnostics) = flat_graph_builder.build();
102    if !diagnostics.iter().any(Diagnostic::is_error) {
103        if let Err(diagnostic) = flat_graph.merge_modules() {
104            diagnostics.push(diagnostic);
105        } else {
106            let flat_mermaid = flat_graph.mermaid_string_flat();
107
108            let part_graph = partition_graph(flat_graph).unwrap();
109            let part_mermaid = part_graph.to_mermaid(&Default::default());
110
111            let lit0 = Literal::string(&flat_mermaid);
112            let lit1 = Literal::string(&part_mermaid);
113
114            return quote! {
115                {
116                    println!("{}\n\n{}\n", #lit0, #lit1);
117                }
118            }
119            .into();
120        }
121    }
122
123    Diagnostic::try_emit_all(diagnostics.iter())
124        .err()
125        .unwrap_or_default()
126        .into()
127}
128
129fn wrap_localset(item: proc_macro::TokenStream, attribute: Attribute) -> proc_macro::TokenStream {
130    use quote::ToTokens;
131
132    let root = root();
133
134    let mut input: syn::ItemFn = match syn::parse(item) {
135        Ok(it) => it,
136        Err(e) => return e.into_compile_error().into(),
137    };
138
139    let statements = input.block.stmts;
140
141    input.block.stmts = parse_quote!(
142        #root::tokio::task::LocalSet::new().run_until(async {
143            #( #statements )*
144        }).await
145    );
146
147    input.attrs.push(attribute);
148
149    input.into_token_stream().into()
150}
151
152/// Checks that the given closure is a morphism. For now does nothing.
153#[proc_macro]
154pub fn morphism(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
155    // TODO(mingwei): some sort of code analysis?
156    item
157}
158
159/// Checks that the given closure is a monotonic function. For now does nothing.
160#[proc_macro]
161pub fn monotonic_fn(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
162    // TODO(mingwei): some sort of code analysis?
163    item
164}
165
166#[proc_macro_attribute]
167pub fn dfir_test(
168    args: proc_macro::TokenStream,
169    item: proc_macro::TokenStream,
170) -> proc_macro::TokenStream {
171    let root = root();
172    let args_2: proc_macro2::TokenStream = args.into();
173
174    wrap_localset(
175        item,
176        parse_quote!(
177            #[#root::tokio::test(flavor = "current_thread", #args_2)]
178        ),
179    )
180}
181
182#[proc_macro_attribute]
183pub fn dfir_main(
184    _: proc_macro::TokenStream,
185    item: proc_macro::TokenStream,
186) -> proc_macro::TokenStream {
187    let root = root();
188
189    wrap_localset(
190        item,
191        parse_quote!(
192            #[#root::tokio::main(flavor = "current_thread")]
193        ),
194    )
195}
196
197#[proc_macro_derive(DemuxEnum)]
198pub fn derive_answer_fn(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
199    let root = root();
200
201    let ItemEnum {
202        ident: item_ident,
203        generics,
204        variants,
205        ..
206    } = parse_macro_input!(item as ItemEnum);
207
208    // Sort variants alphabetically.
209    let mut variants = variants.into_iter().collect::<Vec<_>>();
210    variants.sort_by(|a, b| a.ident.cmp(&b.ident));
211
212    let variant_pusherator_generics = variants
213        .iter()
214        .map(|variant| format_ident!("__Pusherator{}", variant.ident))
215        .collect::<Vec<_>>();
216    let variant_pusherator_localvars = variants
217        .iter()
218        .map(|variant| {
219            format_ident!(
220                "__pusherator_{}",
221                variant.ident.to_string().to_lowercase(),
222                span = variant.ident.span()
223            )
224        })
225        .collect::<Vec<_>>();
226    let variant_output_types = variants
227        .iter()
228        .map(|variant| match &variant.fields {
229            Fields::Named(fields) => {
230                let field_types = fields.named.iter().map(|field| &field.ty);
231                quote! {
232                    ( #( #field_types, )* )
233                }
234            }
235            Fields::Unnamed(fields) => {
236                let field_types = fields.unnamed.iter().map(|field| &field.ty);
237                quote! {
238                    ( #( #field_types, )* )
239                }
240            }
241            Fields::Unit => quote!(()),
242        })
243        .collect::<Vec<_>>();
244
245    let mut full_generics = generics.clone();
246    full_generics.params.extend(
247        variant_pusherator_generics
248            .iter()
249            .map::<GenericParam, _>(|ident| parse_quote!(#ident)),
250    );
251    full_generics.make_where_clause().predicates.extend(
252        variant_pusherator_generics
253            .iter()
254            .zip(variant_output_types.iter())
255            .map::<WherePredicate, _>(|(pusherator_generic, output_type)| {
256                parse_quote! {
257                    #pusherator_generic: #root::pusherator::Pusherator<Item = #output_type>
258                }
259            }),
260    );
261
262    let (impl_generics_item, ty_generics, where_clause_item) = generics.split_for_impl();
263    let (impl_generics, _ty_generics, where_clause) = full_generics.split_for_impl();
264
265    let variant_pats = variants
266        .iter()
267        .zip(variant_pusherator_localvars.iter())
268        .map(|(variant, pushvar)| {
269            let Variant { ident, fields, .. } = variant;
270            let (fields_pat, push_item) = field_pattern_item(fields);
271            quote! {
272                Self::#ident #fields_pat => #pushvar.give(#push_item)
273            }
274        });
275
276    let single_impl = (1 == variants.len()).then(|| {
277        let Variant { ident, fields, .. } = variants.first().unwrap();
278        let (fields_pat, push_item) = field_pattern_item(fields);
279        let out_type = variant_output_types.first().unwrap();
280        quote! {
281            impl #impl_generics_item #root::util::demux_enum::SingleVariant
282                for #item_ident #ty_generics #where_clause_item
283            {
284                type Output = #out_type;
285                fn single_variant(self) -> Self::Output {
286                    match self {
287                        Self::#ident #fields_pat => #push_item,
288                    }
289                }
290            }
291        }
292    });
293
294    quote! {
295        impl #impl_generics #root::util::demux_enum::DemuxEnum<( #( #variant_pusherator_generics, )* )>
296            for #item_ident #ty_generics #where_clause
297        {
298            fn demux_enum(
299                self,
300                ( #( #variant_pusherator_localvars, )* ):
301                    &mut ( #( #variant_pusherator_generics, )* )
302            ) {
303                match self {
304                    #( #variant_pats, )*
305                }
306            }
307        }
308
309        impl #impl_generics_item #root::util::demux_enum::DemuxEnumBase
310            for #item_ident #ty_generics #where_clause_item {}
311
312        #single_impl
313    }
314    .into()
315}
316
317/// (fields pattern, push item expr)
318fn field_pattern_item(fields: &Fields) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
319    let idents = fields
320        .iter()
321        .enumerate()
322        .map(|(i, field)| {
323            field
324                .ident
325                .clone()
326                .unwrap_or_else(|| format_ident!("_{}", i))
327        })
328        .collect::<Vec<_>>();
329    let (fields_pat, push_item) = match fields {
330        Fields::Named(_) => (quote!( { #( #idents, )* } ), quote!( ( #( #idents, )* ) )),
331        Fields::Unnamed(_) => (quote!( ( #( #idents ),* ) ), quote!( ( #( #idents, )* ) )),
332        Fields::Unit => (quote!(), quote!(())),
333    };
334    (fields_pat, push_item)
335}