Skip to main content

dfir_macro/
lib.rs

1#![cfg_attr(
2    nightly,
3    feature(proc_macro_diagnostic, proc_macro_span, proc_macro_def_site)
4)]
5
6use dfir_lang::diagnostic::Level;
7use dfir_lang::graph::{
8    BuildDfirCodeOutput, FlatGraphBuilder, FlatGraphBuilderOutput, build_dfir_code, partition_graph,
9};
10use dfir_lang::parse::DfirCode;
11use proc_macro2::{Ident, Literal, Span};
12use quote::{format_ident, quote, quote_spanned};
13use syn::spanned::Spanned;
14use syn::{
15    Attribute, Fields, GenericParam, ItemEnum, Variant, WherePredicate, parse_macro_input,
16    parse_quote,
17};
18
19/// Create a runnable graph instance using DFIR's custom syntax.
20///
21/// For example usage, take a look at the [`surface_*` tests in the `tests` folder](https://github.com/hydro-project/hydro/tree/main/dfir_rs/tests)
22/// or the [`examples` folder](https://github.com/hydro-project/hydro/tree/main/dfir_rs/examples)
23/// in the [Hydro repo](https://github.com/hydro-project/hydro).
24// TODO(mingwei): rustdoc examples inline.
25#[proc_macro]
26pub fn dfir_syntax(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
27    dfir_syntax_internal(input, Some(Level::Help))
28}
29
30/// [`dfir_syntax!`] but will not emit any diagnostics (errors, warnings, etc.).
31///
32/// Used for testing, users will want to use [`dfir_syntax!`] instead.
33#[proc_macro]
34pub fn dfir_syntax_noemit(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
35    dfir_syntax_internal(input, None)
36}
37
38fn root() -> proc_macro2::TokenStream {
39    use std::env::{VarError, var as env_var};
40
41    let root_crate_name = format!(
42        "{}_rs",
43        env!("CARGO_PKG_NAME").strip_suffix("_macro").unwrap()
44    );
45    let root_crate_ident = root_crate_name.replace('-', "_");
46    let root_crate = proc_macro_crate::crate_name(&root_crate_name)
47        .unwrap_or_else(|_| panic!("{root_crate_name} should be present in `Cargo.toml`"));
48    match root_crate {
49        proc_macro_crate::FoundCrate::Itself => {
50            if Err(VarError::NotPresent) == env_var("CARGO_BIN_NAME")
51                && Err(VarError::NotPresent) != env_var("CARGO_PRIMARY_PACKAGE")
52                && Ok(&*root_crate_ident) == env_var("CARGO_CRATE_NAME").as_deref()
53            {
54                // In the crate itself, including unit tests.
55                quote! { crate }
56            } else {
57                // In an integration test, example, bench, etc.
58                let ident: Ident = Ident::new(&root_crate_ident, Span::call_site());
59                quote! { ::#ident }
60            }
61        }
62        proc_macro_crate::FoundCrate::Name(name) => {
63            let ident = Ident::new(&name, Span::call_site());
64            quote! { ::#ident }
65        }
66    }
67}
68
69fn dfir_syntax_internal(
70    input: proc_macro::TokenStream,
71    retain_diagnostic_level: Option<Level>,
72) -> proc_macro::TokenStream {
73    let input = parse_macro_input!(input as DfirCode);
74    let root = root();
75
76    let (code, mut diagnostics) = match build_dfir_code(input, &root) {
77        Ok(BuildDfirCodeOutput {
78            partitioned_graph: _,
79            code,
80            diagnostics,
81        }) => (code, diagnostics),
82        Err(diagnostics) => (
83            quote! {
84                {
85                    #root::scheduled::context::Dfir::new(
86                        #root::scheduled::context::NullTickClosure,
87                        <#root::scheduled::context::Context as ::std::default::Default>::default(),
88                        None,
89                        None,
90                    )
91                }
92            },
93            diagnostics,
94        ),
95    };
96
97    let diagnostic_tokens = retain_diagnostic_level.and_then(|level| {
98        diagnostics.retain_level(level);
99        diagnostics.try_emit_all().err()
100    });
101
102    quote! {
103        {
104            #diagnostic_tokens
105            #code
106        }
107    }
108    .into()
109}
110
111/// Parse DFIR syntax without emitting code.
112///
113/// Used for testing, users will want to use [`dfir_syntax!`] instead.
114#[proc_macro]
115pub fn dfir_parser(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
116    let input = parse_macro_input!(input as DfirCode);
117
118    let flat_graph_builder = FlatGraphBuilder::from_dfir(input);
119    let err_diagnostics = 'err: {
120        let (mut flat_graph, mut diagnostics) = match flat_graph_builder.build() {
121            Ok(FlatGraphBuilderOutput {
122                flat_graph,
123                uses: _,
124                diagnostics,
125            }) => (flat_graph, diagnostics),
126            Err(diagnostics) => {
127                break 'err diagnostics;
128            }
129        };
130
131        if let Err(diagnostic) = flat_graph.merge_modules() {
132            diagnostics.push(diagnostic);
133            break 'err diagnostics;
134        }
135
136        let flat_mermaid = flat_graph.mermaid_string_flat();
137
138        let part_mermaid = partition_graph(flat_graph)
139            .map(|part_graph| part_graph.to_mermaid(&Default::default()))
140            .unwrap_or_else(|err| format!("failed to partition: {err}"));
141
142        let lit0 = Literal::string(&flat_mermaid);
143        let lit1 = Literal::string(&part_mermaid);
144
145        return quote! {
146            {
147                println!("{}\n\n{}\n", #lit0, #lit1);
148            }
149        }
150        .into();
151    };
152
153    err_diagnostics
154        .try_emit_all()
155        .err()
156        .unwrap_or_default()
157        .into()
158}
159
160fn wrap_localset(item: proc_macro::TokenStream, attribute: Attribute) -> proc_macro::TokenStream {
161    use quote::ToTokens;
162
163    let root = root();
164
165    let mut input: syn::ItemFn = match syn::parse(item) {
166        Ok(it) => it,
167        Err(e) => return e.into_compile_error().into(),
168    };
169
170    let statements = input.block.stmts;
171
172    input.block.stmts = parse_quote!(
173        #root::tokio::task::LocalSet::new().run_until(async {
174            #( #statements )*
175        }).await
176    );
177
178    input.attrs.push(attribute);
179
180    input.into_token_stream().into()
181}
182
183/// Checks that the given closure is a morphism. For now does nothing.
184#[proc_macro]
185pub fn morphism(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
186    // TODO(mingwei): some sort of code analysis?
187    item
188}
189
190/// Checks that the given closure is a monotonic function. For now does nothing.
191#[proc_macro]
192pub fn monotonic_fn(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
193    // TODO(mingwei): some sort of code analysis?
194    item
195}
196
197#[proc_macro_attribute]
198pub fn dfir_test(
199    args: proc_macro::TokenStream,
200    item: proc_macro::TokenStream,
201) -> proc_macro::TokenStream {
202    let root = root();
203    let args_2: proc_macro2::TokenStream = args.into();
204
205    wrap_localset(
206        item,
207        parse_quote!(
208            #[#root::tokio::test(flavor = "current_thread", #args_2)]
209        ),
210    )
211}
212
213#[proc_macro_attribute]
214pub fn dfir_main(
215    _: proc_macro::TokenStream,
216    item: proc_macro::TokenStream,
217) -> proc_macro::TokenStream {
218    let root = root();
219
220    wrap_localset(
221        item,
222        parse_quote!(
223            #[#root::tokio::main(flavor = "current_thread")]
224        ),
225    )
226}
227
228#[proc_macro_derive(DemuxEnum)]
229pub fn derive_demux_enum(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
230    let root = root();
231
232    let ItemEnum {
233        ident: item_ident,
234        generics,
235        variants,
236        ..
237    } = parse_macro_input!(item as ItemEnum);
238
239    // Sort variants alphabetically.
240    let mut variants = variants.into_iter().collect::<Vec<_>>();
241    variants.sort_by(|a, b| a.ident.cmp(&b.ident));
242
243    // Return type for each variant.
244    let variant_output_types = variants
245        .iter()
246        .map(|variant| match &variant.fields {
247            Fields::Named(fields) => {
248                let field_types = fields.named.iter().map(|field| &field.ty);
249                quote! {
250                    ( #( #field_types, )* )
251                }
252            }
253            Fields::Unnamed(fields) => {
254                let field_types = fields.unnamed.iter().map(|field| &field.ty);
255                quote! {
256                    ( #( #field_types, )* )
257                }
258            }
259            Fields::Unit => quote!(()),
260        })
261        .collect::<Vec<_>>();
262
263    let variant_generics_sink = variants
264        .iter()
265        .map(|variant| format_ident!("__Sink{}", variant.ident))
266        .collect::<Vec<_>>();
267    let variant_generics_pinned_sink = variant_generics_sink.iter().map(|ident| {
268        quote_spanned! {ident.span()=>
269            ::std::pin::Pin::<&mut #ident>
270        }
271    });
272    let variant_generics_pinned_sink_all = quote! {
273        ( #( #variant_generics_pinned_sink, )* )
274    };
275    let variant_localvars_sink = variants
276        .iter()
277        .map(|variant| {
278            format_ident!(
279                "__sink_{}",
280                variant.ident.to_string().to_lowercase(),
281                span = variant.ident.span()
282            )
283        })
284        .collect::<Vec<_>>();
285
286    let mut full_generics_sink = generics.clone();
287    full_generics_sink.params.extend(
288        variant_generics_sink
289            .iter()
290            .map::<GenericParam, _>(|ident| parse_quote!(#ident)),
291    );
292    full_generics_sink.make_where_clause().predicates.extend(
293        variant_generics_sink
294            .iter()
295            .zip(variant_output_types.iter())
296            .map::<WherePredicate, _>(|(sink_generic, output_type)| {
297                parse_quote! {
298                    // TODO(mingwei): generic error types?
299                    #sink_generic: #root::futures::sink::Sink<#output_type, Error = #root::Never>
300                }
301            }),
302    );
303
304    let variant_pats_sink_start_send = variants.iter().zip(variant_localvars_sink.iter()).map(
305        |(variant, sinkvar)| {
306            let Variant { ident, fields, .. } = variant;
307            let (fields_pat, push_item) = field_pattern_item(fields);
308            quote! {
309                Self::#ident #fields_pat => ::std::pin::Pin::as_mut(#sinkvar).start_send(#push_item)
310            }
311        },
312    );
313
314    let (impl_generics_item, ty_generics, where_clause_item) = generics.split_for_impl();
315    let (impl_generics_sink, _ty_generics_sink, where_clause_sink) =
316        full_generics_sink.split_for_impl();
317
318    let variant_generics_push = variants
319        .iter()
320        .map(|variant| format_ident!("__Push{}", variant.ident))
321        .collect::<Vec<_>>();
322    let variant_generics_pinned_push = variant_generics_push.iter().map(|ident| {
323        quote_spanned! {ident.span()=>
324            ::std::pin::Pin::<&mut #ident>
325        }
326    });
327    let variant_generics_pinned_push_all = quote! {
328        ( #( #variant_generics_pinned_push, )* )
329    };
330    let variant_localvars_push = variants
331        .iter()
332        .map(|variant| {
333            format_ident!(
334                "__push_{}",
335                variant.ident.to_string().to_lowercase(),
336                span = variant.ident.span()
337            )
338        })
339        .collect::<Vec<_>>();
340
341    let mut full_generics_push = generics.clone();
342    full_generics_push.params.extend(
343        variant_generics_push
344            .iter()
345            .map::<GenericParam, _>(|ident| parse_quote!(#ident)),
346    );
347    // Each push just needs Push<Item = VariantOutput, Meta = ()>.
348    full_generics_push.make_where_clause().predicates.extend(
349        variant_generics_push
350            .iter()
351            .zip(variant_output_types.iter())
352            .map::<WherePredicate, _>(|(push_generic, output_type)| {
353                parse_quote! {
354                    #push_generic: #root::dfir_pipes::push::Push<#output_type, ()>
355                }
356            }),
357    );
358
359    // Build the recursive Merged Ctx type:
360    // For 0 pushes: `()
361    // For 1 push: `Push0::Ctx<'__ctx>`
362    // For 2 pushes: `<Push0::Ctx<'__ctx> as Context<'__ctx>>::Merged<Push1::Ctx<'__ctx>>`
363    // For 3 pushes: `<Push0::Ctx<'__ctx> as Context<'__ctx>>::Merged<<Push1::Ctx<'__ctx> as Context<'__ctx>>::Merged<Push2::Ctx<'__ctx>>>`
364    let ctx_type = variant_generics_push
365        .iter()
366        .zip(variant_output_types.iter())
367        .rev()
368        .map(|(push_generic, output_type)| {
369            quote_spanned! {push_generic.span()=>
370                <#push_generic as #root::dfir_pipes::push::Push<#output_type, ()>>::Ctx<'__ctx>
371            }
372        })
373        .reduce(|rest, next| {
374            quote_spanned! {next.span()=>
375                <#next as #root::dfir_pipes::Context<'__ctx>>::Merged<#rest>
376            }
377        })
378        .unwrap_or_else(|| quote!(()));
379
380    let can_pend = variant_generics_push
381        .iter()
382        .zip(variant_output_types.iter())
383        .rev()
384        .map(|(push_generic, output_type)| {
385            quote_spanned! {push_generic.span()=>
386                <#push_generic as #root::dfir_pipes::push::Push<#output_type, ()>>::CanPend
387            }
388        })
389        .reduce(|rest, next| {
390            quote_spanned! {next.span()=>
391                <#next as #root::dfir_pipes::Toggle>::Or<#rest>
392            }
393        })
394        .unwrap_or_else(|| quote!(#root::dfir_pipes::No));
395
396    // Generate `Ctx`: `unmerge_self` for each push, `unmerge_other` to get remaining `__ctx`.
397    // For the last push, just pass `__ctx` directly (no unmerge needed).
398    let push_poll_unwrap_context = |method_name: Ident| {
399        variant_localvars_push.split_last().map(|(lastvar, headvar)| {
400            // `#( ... )*` zips all iterators to shortest; `headvar` (all-but-last) is shortest, so
401            // `variant_generics_push` and `variant_output_types` are naturally truncated to match.
402            quote! {
403                #(
404                    let #headvar = {
405                        let __ctx = <<#variant_generics_push as #root::dfir_pipes::push::Push<#variant_output_types, ()>>::Ctx<'_> as #root::dfir_pipes::Context<'_>>::unmerge_self(__ctx);
406                        #root::dfir_pipes::push::Push::#method_name(::std::pin::Pin::as_mut(#headvar), __ctx)
407                    };
408                    let __ctx = <<#variant_generics_push as #root::dfir_pipes::push::Push<#variant_output_types, ()>>::Ctx<'_> as #root::dfir_pipes::Context<'_>>::unmerge_other(__ctx);
409                )*
410                let #lastvar = #root::dfir_pipes::push::Push::#method_name(::std::pin::Pin::as_mut(#lastvar), __ctx);
411                // If any are pending, return pending.
412                #(
413                    if #variant_localvars_push.is_pending() {
414                        return #root::dfir_pipes::push::PushStep::pending();
415                    }
416                )*
417            }
418        })
419    };
420    let push_poll_ready_body = (push_poll_unwrap_context)(format_ident!("poll_ready"));
421    let push_poll_flush_body = (push_poll_unwrap_context)(format_ident!("poll_flush"));
422
423    let variant_pats_push_send =
424        variants
425            .iter()
426            .zip(variant_localvars_push.iter())
427            .map(|(variant, pushvar)| {
428                let Variant { ident, fields, .. } = variant;
429                let (fields_pat, push_item) = field_pattern_item(fields);
430                quote! {
431                    Self::#ident #fields_pat => { #root::dfir_pipes::push::Push::start_send(#pushvar.as_mut(), #push_item, __meta); }
432                }
433            });
434
435    let (impl_generics_push, _ty_generics_push, where_clause_push) =
436        full_generics_push.split_for_impl();
437
438    let single_impl = (1 == variants.len()).then(|| {
439        let Variant { ident, fields, .. } = variants.first().unwrap();
440        let (fields_pat, push_item) = field_pattern_item(fields);
441        let out_type = variant_output_types.first().unwrap();
442        quote! {
443            impl #impl_generics_item #root::util::demux_enum::SingleVariant
444                for #item_ident #ty_generics #where_clause_item
445            {
446                type Output = #out_type;
447                fn single_variant(self) -> Self::Output {
448                    match self {
449                        Self::#ident #fields_pat => #push_item,
450                    }
451                }
452            }
453        }
454    });
455
456    quote! {
457        impl #impl_generics_sink #root::util::demux_enum::DemuxEnumSink<#variant_generics_pinned_sink_all>
458            for #item_ident #ty_generics #where_clause_sink
459        {
460            type Error = #root::Never;
461
462            fn poll_ready(
463                ( #( #variant_localvars_sink, )* ): &mut #variant_generics_pinned_sink_all,
464                __cx: &mut ::std::task::Context<'_>,
465            ) -> ::std::task::Poll<::std::result::Result<(), Self::Error>> {
466                // Ready all sinks simultaneously.
467                #(
468                    let #variant_localvars_sink = #variant_localvars_sink.as_mut().poll_ready(__cx)?;
469                )*
470                #(
471                    ::std::task::ready!(#variant_localvars_sink);
472                )*
473                ::std::task::Poll::Ready(::std::result::Result::Ok(()))
474            }
475
476            fn start_send(
477                self,
478                ( #( #variant_localvars_sink, )* ): &mut #variant_generics_pinned_sink_all,
479            ) -> ::std::result::Result<(), Self::Error> {
480                match self {
481                    #( #variant_pats_sink_start_send, )*
482                }
483            }
484
485            fn poll_flush(
486                ( #( #variant_localvars_sink, )* ): &mut #variant_generics_pinned_sink_all,
487                __cx: &mut ::std::task::Context<'_>,
488            ) -> ::std::task::Poll<::std::result::Result<(), Self::Error>> {
489                // Flush all sinks simultaneously.
490                #(
491                    let #variant_localvars_sink = #variant_localvars_sink.as_mut().poll_flush(__cx)?;
492                )*
493                #(
494                    ::std::task::ready!(#variant_localvars_sink);
495                )*
496                ::std::task::Poll::Ready(::std::result::Result::Ok(()))
497            }
498
499            fn poll_close(
500                ( #( #variant_localvars_sink, )* ): &mut #variant_generics_pinned_sink_all,
501                __cx: &mut ::std::task::Context<'_>,
502            ) -> ::std::task::Poll<::std::result::Result<(), Self::Error>> {
503                // Close all sinks simultaneously.
504                #(
505                    let #variant_localvars_sink = #variant_localvars_sink.as_mut().poll_close(__cx)?;
506                )*
507                #(
508                    ::std::task::ready!(#variant_localvars_sink);
509                )*
510                ::std::task::Poll::Ready(::std::result::Result::Ok(()))
511            }
512        }
513
514        impl #impl_generics_push #root::util::demux_enum::DemuxEnumPush<#variant_generics_pinned_push_all, ()>
515            for #item_ident #ty_generics #where_clause_push
516        {
517            type Ctx<'__ctx> = #ctx_type;
518            type CanPend = #can_pend;
519
520            fn poll_ready(
521                ( #( #variant_localvars_push, )* ): &mut #variant_generics_pinned_push_all,
522                __ctx: &mut Self::Ctx<'_>,
523            ) -> #root::dfir_pipes::push::PushStep<Self::CanPend> {
524                #push_poll_ready_body
525                #root::dfir_pipes::push::PushStep::Done
526            }
527
528            fn start_send(
529                self,
530                __meta: (),
531                ( #( #variant_localvars_push, )* ): &mut #variant_generics_pinned_push_all,
532            ) {
533                match self {
534                    #( #variant_pats_push_send, )*
535                }
536            }
537
538            fn poll_flush(
539                ( #( #variant_localvars_push, )* ): &mut #variant_generics_pinned_push_all,
540                __ctx: &mut Self::Ctx<'_>,
541            ) -> #root::dfir_pipes::push::PushStep<Self::CanPend> {
542                #push_poll_flush_body
543                #root::dfir_pipes::push::PushStep::Done
544            }
545
546            fn size_hint(
547                ( #( #variant_localvars_push, )* ): &mut #variant_generics_pinned_push_all,
548                __size_hint: (usize, ::std::option::Option<usize>),
549            ) {
550                #(
551                    #root::dfir_pipes::push::Push::size_hint(
552                        ::std::pin::Pin::as_mut(#variant_localvars_push),
553                        __size_hint,
554                    );
555                )*
556            }
557        }
558
559        impl #impl_generics_item #root::util::demux_enum::DemuxEnumBase
560            for #item_ident #ty_generics #where_clause_item {}
561
562        #single_impl
563    }
564    .into()
565}
566
567/// (fields pattern, push item expr)
568fn field_pattern_item(fields: &Fields) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
569    let idents = fields
570        .iter()
571        .enumerate()
572        .map(|(i, field)| {
573            field
574                .ident
575                .clone()
576                .unwrap_or_else(|| format_ident!("_{}", i))
577        })
578        .collect::<Vec<_>>();
579    let (fields_pat, push_item) = match fields {
580        Fields::Named(_) => (quote!( { #( #idents, )* } ), quote!( ( #( #idents, )* ) )),
581        Fields::Unnamed(_) => (quote!( ( #( #idents ),* ) ), quote!( ( #( #idents, )* ) )),
582        Fields::Unit => (quote!(), quote!(())),
583    };
584    (fields_pat, push_item)
585}