dfir_macro/
lib.rs

1#![cfg_attr(
2    nightly,
3    feature(proc_macro_diagnostic, proc_macro_span, proc_macro_def_site)
4)]
5
6use dfir_lang::diagnostic::{Diagnostic, Level};
7use dfir_lang::graph::{FlatGraphBuilder, build_hfcode, partition_graph};
8use dfir_lang::parse::DfirCode;
9use proc_macro2::{Ident, Literal, Span};
10use quote::{format_ident, quote, quote_spanned};
11use syn::{
12    Attribute, Fields, GenericParam, ItemEnum, Variant, WherePredicate, parse_macro_input,
13    parse_quote,
14};
15
16/// Create a runnable graph instance using DFIR's custom syntax.
17///
18/// For example usage, take a look at the [`surface_*` tests in the `tests` folder](https://github.com/hydro-project/hydro/tree/main/dfir_rs/tests)
19/// or the [`examples` folder](https://github.com/hydro-project/hydro/tree/main/dfir_rs/examples)
20/// in the [Hydro repo](https://github.com/hydro-project/hydro).
21// TODO(mingwei): rustdoc examples inline.
22#[proc_macro]
23pub fn dfir_syntax(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
24    dfir_syntax_internal(input, Some(Level::Help))
25}
26
27/// [`dfir_syntax!`] but will not emit any diagnostics (errors, warnings, etc.).
28///
29/// Used for testing, users will want to use [`dfir_syntax!`] instead.
30#[proc_macro]
31pub fn dfir_syntax_noemit(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
32    dfir_syntax_internal(input, None)
33}
34
35fn root() -> proc_macro2::TokenStream {
36    use std::env::{VarError, var as env_var};
37
38    let root_crate_name = format!(
39        "{}_rs",
40        env!("CARGO_PKG_NAME").strip_suffix("_macro").unwrap()
41    );
42    let root_crate_ident = root_crate_name.replace('-', "_");
43    let root_crate = proc_macro_crate::crate_name(&root_crate_name)
44        .unwrap_or_else(|_| panic!("{root_crate_name} should be present in `Cargo.toml`"));
45    match root_crate {
46        proc_macro_crate::FoundCrate::Itself => {
47            if Err(VarError::NotPresent) == env_var("CARGO_BIN_NAME")
48                && Err(VarError::NotPresent) != env_var("CARGO_PRIMARY_PACKAGE")
49                && Ok(&*root_crate_ident) == env_var("CARGO_CRATE_NAME").as_deref()
50            {
51                // In the crate itself, including unit tests.
52                quote! { crate }
53            } else {
54                // In an integration test, example, bench, etc.
55                let ident: Ident = Ident::new(&root_crate_ident, Span::call_site());
56                quote! { ::#ident }
57            }
58        }
59        proc_macro_crate::FoundCrate::Name(name) => {
60            let ident = Ident::new(&name, Span::call_site());
61            quote! { ::#ident }
62        }
63    }
64}
65
66fn dfir_syntax_internal(
67    input: proc_macro::TokenStream,
68    min_diagnostic_level: Option<Level>,
69) -> proc_macro::TokenStream {
70    let input = parse_macro_input!(input as DfirCode);
71    let root = root();
72    let (graph_code_opt, diagnostics) = build_hfcode(input, &root);
73    let tokens = graph_code_opt
74        .map(|(_graph, code)| code)
75        .unwrap_or_else(|| quote! { #root::scheduled::graph::Dfir::new() });
76
77    let diagnostics = diagnostics
78        .iter()
79        .filter(|diag: &&Diagnostic| Some(diag.level) <= min_diagnostic_level);
80
81    let diagnostic_tokens = Diagnostic::try_emit_all(diagnostics)
82        .err()
83        .unwrap_or_default();
84    quote! {
85        {
86            #diagnostic_tokens
87            #tokens
88        }
89    }
90    .into()
91}
92
93/// Parse DFIR syntax without emitting code.
94///
95/// Used for testing, users will want to use [`dfir_syntax!`] instead.
96#[proc_macro]
97pub fn dfir_parser(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
98    let input = parse_macro_input!(input as DfirCode);
99
100    let flat_graph_builder = FlatGraphBuilder::from_dfir(input);
101    let (mut flat_graph, _uses, mut diagnostics) = flat_graph_builder.build();
102    if !diagnostics.iter().any(Diagnostic::is_error) {
103        if let Err(diagnostic) = flat_graph.merge_modules() {
104            diagnostics.push(diagnostic);
105        } else {
106            let flat_mermaid = flat_graph.mermaid_string_flat();
107
108            let part_graph = partition_graph(flat_graph).unwrap();
109            let part_mermaid = part_graph.to_mermaid(&Default::default());
110
111            let lit0 = Literal::string(&flat_mermaid);
112            let lit1 = Literal::string(&part_mermaid);
113
114            return quote! {
115                {
116                    println!("{}\n\n{}\n", #lit0, #lit1);
117                }
118            }
119            .into();
120        }
121    }
122
123    Diagnostic::try_emit_all(diagnostics.iter())
124        .err()
125        .unwrap_or_default()
126        .into()
127}
128
129fn wrap_localset(item: proc_macro::TokenStream, attribute: Attribute) -> proc_macro::TokenStream {
130    use quote::ToTokens;
131
132    let root = root();
133
134    let mut input: syn::ItemFn = match syn::parse(item) {
135        Ok(it) => it,
136        Err(e) => return e.into_compile_error().into(),
137    };
138
139    let statements = input.block.stmts;
140
141    input.block.stmts = parse_quote!(
142        #root::tokio::task::LocalSet::new().run_until(async {
143            #( #statements )*
144        }).await
145    );
146
147    input.attrs.push(attribute);
148
149    input.into_token_stream().into()
150}
151
152/// Checks that the given closure is a morphism. For now does nothing.
153#[proc_macro]
154pub fn morphism(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
155    // TODO(mingwei): some sort of code analysis?
156    item
157}
158
159/// Checks that the given closure is a monotonic function. For now does nothing.
160#[proc_macro]
161pub fn monotonic_fn(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
162    // TODO(mingwei): some sort of code analysis?
163    item
164}
165
166#[proc_macro_attribute]
167pub fn dfir_test(
168    args: proc_macro::TokenStream,
169    item: proc_macro::TokenStream,
170) -> proc_macro::TokenStream {
171    let root = root();
172    let args_2: proc_macro2::TokenStream = args.into();
173
174    wrap_localset(
175        item,
176        parse_quote!(
177            #[#root::tokio::test(flavor = "current_thread", #args_2)]
178        ),
179    )
180}
181
182#[proc_macro_attribute]
183pub fn dfir_main(
184    _: proc_macro::TokenStream,
185    item: proc_macro::TokenStream,
186) -> proc_macro::TokenStream {
187    let root = root();
188
189    wrap_localset(
190        item,
191        parse_quote!(
192            #[#root::tokio::main(flavor = "current_thread")]
193        ),
194    )
195}
196
197#[proc_macro_derive(DemuxEnum)]
198pub fn derive_demux_enum(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
199    let root = root();
200
201    let ItemEnum {
202        ident: item_ident,
203        generics,
204        variants,
205        ..
206    } = parse_macro_input!(item as ItemEnum);
207
208    // Sort variants alphabetically.
209    let mut variants = variants.into_iter().collect::<Vec<_>>();
210    variants.sort_by(|a, b| a.ident.cmp(&b.ident));
211
212    // Return type for each variant.
213    let variant_output_types = variants
214        .iter()
215        .map(|variant| match &variant.fields {
216            Fields::Named(fields) => {
217                let field_types = fields.named.iter().map(|field| &field.ty);
218                quote! {
219                    ( #( #field_types, )* )
220                }
221            }
222            Fields::Unnamed(fields) => {
223                let field_types = fields.unnamed.iter().map(|field| &field.ty);
224                quote! {
225                    ( #( #field_types, )* )
226                }
227            }
228            Fields::Unit => quote!(()),
229        })
230        .collect::<Vec<_>>();
231
232    let variant_generics_sink = variants
233        .iter()
234        .map(|variant| format_ident!("__Sink{}", variant.ident))
235        .collect::<Vec<_>>();
236    let variant_generics_pinned_sink = variant_generics_sink.iter().map(|ident| {
237        quote_spanned! {ident.span()=>
238            ::std::pin::Pin::<&mut #ident>
239        }
240    });
241    let variant_generics_pinned_sink_all = quote! {
242        ( #( #variant_generics_pinned_sink, )* )
243    };
244    let variant_localvars_sink = variants
245        .iter()
246        .map(|variant| {
247            format_ident!(
248                "__sink_{}",
249                variant.ident.to_string().to_lowercase(),
250                span = variant.ident.span()
251            )
252        })
253        .collect::<Vec<_>>();
254
255    let mut full_generics_sink = generics.clone();
256    full_generics_sink.params.extend(
257        variant_generics_sink
258            .iter()
259            .map::<GenericParam, _>(|ident| parse_quote!(#ident)),
260    );
261    full_generics_sink.make_where_clause().predicates.extend(
262        variant_generics_sink
263            .iter()
264            .zip(variant_output_types.iter())
265            .map::<WherePredicate, _>(|(sink_generic, output_type)| {
266                parse_quote! {
267                    // TODO(mingwei): generic error types?
268                    #sink_generic: #root::futures::sink::Sink<#output_type, Error = #root::Never>
269                }
270            }),
271    );
272
273    let variant_pats_sink_start_send =
274        variants
275            .iter()
276            .zip(variant_localvars_sink.iter())
277            .map(|(variant, sinkvar)| {
278                let Variant { ident, fields, .. } = variant;
279                let (fields_pat, push_item) = field_pattern_item(fields);
280                quote! {
281                    Self::#ident #fields_pat => #sinkvar.as_mut().start_send(#push_item)
282                }
283            });
284
285    let (impl_generics_item, ty_generics, where_clause_item) = generics.split_for_impl();
286    let (impl_generics_sink, _ty_generics_sink, where_clause_sink) =
287        full_generics_sink.split_for_impl();
288
289    let single_impl = (1 == variants.len()).then(|| {
290        let Variant { ident, fields, .. } = variants.first().unwrap();
291        let (fields_pat, push_item) = field_pattern_item(fields);
292        let out_type = variant_output_types.first().unwrap();
293        quote! {
294            impl #impl_generics_item #root::util::demux_enum::SingleVariant
295                for #item_ident #ty_generics #where_clause_item
296            {
297                type Output = #out_type;
298                fn single_variant(self) -> Self::Output {
299                    match self {
300                        Self::#ident #fields_pat => #push_item,
301                    }
302                }
303            }
304        }
305    });
306
307    quote! {
308        impl #impl_generics_sink #root::util::demux_enum::DemuxEnumSink<#variant_generics_pinned_sink_all>
309            for #item_ident #ty_generics #where_clause_sink
310        {
311            type Error = #root::Never;
312
313            fn poll_ready(
314                ( #( #variant_localvars_sink, )* ): &mut #variant_generics_pinned_sink_all,
315                __cx: &mut ::std::task::Context<'_>,
316            ) -> ::std::task::Poll<::std::result::Result<(), Self::Error>> {
317                // Ready all sinks simultaneously.
318                #(
319                    let #variant_localvars_sink = #variant_localvars_sink.as_mut().poll_ready(__cx)?;
320                )*
321                #(
322                    ::std::task::ready!(#variant_localvars_sink);
323                )*
324                ::std::task::Poll::Ready(::std::result::Result::Ok(()))
325            }
326
327            fn start_send(
328                self,
329                ( #( #variant_localvars_sink, )* ): &mut #variant_generics_pinned_sink_all,
330            ) -> ::std::result::Result<(), Self::Error> {
331                match self {
332                    #( #variant_pats_sink_start_send, )*
333                }
334            }
335
336            fn poll_flush(
337                ( #( #variant_localvars_sink, )* ): &mut #variant_generics_pinned_sink_all,
338                __cx: &mut ::std::task::Context<'_>,
339            ) -> ::std::task::Poll<::std::result::Result<(), Self::Error>> {
340                // Flush all sinks simultaneously.
341                #(
342                    let #variant_localvars_sink = #variant_localvars_sink.as_mut().poll_flush(__cx)?;
343                )*
344                #(
345                    ::std::task::ready!(#variant_localvars_sink);
346                )*
347                ::std::task::Poll::Ready(::std::result::Result::Ok(()))
348            }
349
350            fn poll_close(
351                ( #( #variant_localvars_sink, )* ): &mut #variant_generics_pinned_sink_all,
352                __cx: &mut ::std::task::Context<'_>,
353            ) -> ::std::task::Poll<::std::result::Result<(), Self::Error>> {
354                // Close all sinks simultaneously.
355                #(
356                    let #variant_localvars_sink = #variant_localvars_sink.as_mut().poll_close(__cx)?;
357                )*
358                #(
359                    ::std::task::ready!(#variant_localvars_sink);
360                )*
361                ::std::task::Poll::Ready(::std::result::Result::Ok(()))
362            }
363        }
364
365        impl #impl_generics_item #root::util::demux_enum::DemuxEnumBase
366            for #item_ident #ty_generics #where_clause_item {}
367
368        #single_impl
369    }
370    .into()
371}
372
373/// (fields pattern, push item expr)
374fn field_pattern_item(fields: &Fields) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
375    let idents = fields
376        .iter()
377        .enumerate()
378        .map(|(i, field)| {
379            field
380                .ident
381                .clone()
382                .unwrap_or_else(|| format_ident!("_{}", i))
383        })
384        .collect::<Vec<_>>();
385    let (fields_pat, push_item) = match fields {
386        Fields::Named(_) => (quote!( { #( #idents, )* } ), quote!( ( #( #idents, )* ) )),
387        Fields::Unnamed(_) => (quote!( ( #( #idents ),* ) ), quote!( ( #( #idents, )* ) )),
388        Fields::Unit => (quote!(), quote!(())),
389    };
390    (fields_pat, push_item)
391}