Skip to main content

dfir_macro/
lib.rs

1use dfir_lang::diagnostic::Level;
2use dfir_lang::graph::{
3    BuildDfirCodeOutput, FlatGraphBuilder, FlatGraphBuilderOutput, build_dfir_code, partition_graph,
4};
5use dfir_lang::parse::DfirCode;
6use proc_macro2::{Ident, Literal, Span};
7use quote::{format_ident, quote, quote_spanned};
8use syn::spanned::Spanned;
9use syn::{
10    Attribute, Fields, GenericParam, ItemEnum, Variant, WherePredicate, parse_macro_input,
11    parse_quote,
12};
13
14/// Create a runnable graph instance using DFIR's custom syntax.
15///
16/// For example usage, take a look at the [`surface_*` tests in the `tests` folder](https://github.com/hydro-project/hydro/tree/main/dfir_rs/tests)
17/// or the [`examples` folder](https://github.com/hydro-project/hydro/tree/main/dfir_rs/examples)
18/// in the [Hydro repo](https://github.com/hydro-project/hydro).
19// TODO(mingwei): rustdoc examples inline.
20#[proc_macro]
21pub fn dfir_syntax(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
22    dfir_syntax_internal(input, Some(Level::Help))
23}
24
25/// [`dfir_syntax!`] but will not emit any diagnostics (errors, warnings, etc.).
26///
27/// Used for testing, users will want to use [`dfir_syntax!`] instead.
28#[proc_macro]
29pub fn dfir_syntax_noemit(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
30    dfir_syntax_internal(input, None)
31}
32
33fn root() -> proc_macro2::TokenStream {
34    use std::env::{VarError, var as env_var};
35
36    let root_crate_name = format!(
37        "{}_rs",
38        env!("CARGO_PKG_NAME").strip_suffix("_macro").unwrap()
39    );
40    let root_crate_ident = root_crate_name.replace('-', "_");
41    let root_crate = proc_macro_crate::crate_name(&root_crate_name)
42        .unwrap_or_else(|_| panic!("{root_crate_name} should be present in `Cargo.toml`"));
43    match root_crate {
44        proc_macro_crate::FoundCrate::Itself => {
45            if Err(VarError::NotPresent) == env_var("CARGO_BIN_NAME")
46                && Err(VarError::NotPresent) != env_var("CARGO_PRIMARY_PACKAGE")
47                && Ok(&*root_crate_ident) == env_var("CARGO_CRATE_NAME").as_deref()
48            {
49                // In the crate itself, including unit tests.
50                quote! { crate }
51            } else {
52                // In an integration test, example, bench, etc.
53                let ident: Ident = Ident::new(&root_crate_ident, Span::call_site());
54                quote! { ::#ident }
55            }
56        }
57        proc_macro_crate::FoundCrate::Name(name) => {
58            let ident = Ident::new(&name, Span::call_site());
59            quote! { ::#ident }
60        }
61    }
62}
63
64fn dfir_syntax_internal(
65    input: proc_macro::TokenStream,
66    retain_diagnostic_level: Option<Level>,
67) -> proc_macro::TokenStream {
68    let input = parse_macro_input!(input as DfirCode);
69    let root = root();
70
71    let (code, mut diagnostics) = match build_dfir_code(input, &root) {
72        Ok(BuildDfirCodeOutput {
73            partitioned_graph: _,
74            code,
75            diagnostics,
76        }) => (code, diagnostics),
77        Err(diagnostics) => (
78            quote! {
79                {
80                    #root::scheduled::context::Dfir::new(
81                        #root::scheduled::context::NullTickClosure,
82                        <#root::scheduled::context::Context as ::std::default::Default>::default(),
83                        None,
84                        None,
85                    )
86                }
87            },
88            diagnostics,
89        ),
90    };
91
92    let diagnostic_tokens = retain_diagnostic_level.and_then(|level| {
93        diagnostics.retain_level(level);
94        diagnostics.try_emit_all().err()
95    });
96
97    quote! {
98        {
99            #diagnostic_tokens
100            #code
101        }
102    }
103    .into()
104}
105
106/// Parse DFIR syntax without emitting code.
107///
108/// Used for testing, users will want to use [`dfir_syntax!`] instead.
109#[proc_macro]
110pub fn dfir_parser(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
111    let input = parse_macro_input!(input as DfirCode);
112
113    let flat_graph_builder = FlatGraphBuilder::from_dfir(input);
114    let err_diagnostics = 'err: {
115        let (mut flat_graph, mut diagnostics) = match flat_graph_builder.build() {
116            Ok(FlatGraphBuilderOutput {
117                flat_graph,
118                uses: _,
119                diagnostics,
120            }) => (flat_graph, diagnostics),
121            Err(diagnostics) => {
122                break 'err diagnostics;
123            }
124        };
125
126        if let Err(diagnostic) = flat_graph.merge_modules() {
127            diagnostics.push(diagnostic);
128            break 'err diagnostics;
129        }
130
131        let flat_mermaid = flat_graph.mermaid_string_flat();
132
133        let part_mermaid = partition_graph(flat_graph)
134            .map(|part_graph| part_graph.to_mermaid(&Default::default()))
135            .unwrap_or_else(|err| format!("failed to partition: {err}"));
136
137        let lit0 = Literal::string(&flat_mermaid);
138        let lit1 = Literal::string(&part_mermaid);
139
140        return quote! {
141            {
142                println!("{}\n\n{}\n", #lit0, #lit1);
143            }
144        }
145        .into();
146    };
147
148    err_diagnostics
149        .try_emit_all()
150        .err()
151        .unwrap_or_default()
152        .into()
153}
154
155fn wrap_localset(item: proc_macro::TokenStream, attribute: Attribute) -> proc_macro::TokenStream {
156    use quote::ToTokens;
157
158    let root = root();
159
160    let mut input: syn::ItemFn = match syn::parse(item) {
161        Ok(it) => it,
162        Err(e) => return e.into_compile_error().into(),
163    };
164
165    let statements = input.block.stmts;
166
167    input.block.stmts = parse_quote!(
168        #root::tokio::task::LocalSet::new().run_until(async {
169            #( #statements )*
170        }).await
171    );
172
173    input.attrs.push(attribute);
174
175    input.into_token_stream().into()
176}
177
178/// Checks that the given closure is a morphism. For now does nothing.
179#[proc_macro]
180pub fn morphism(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
181    // TODO(mingwei): some sort of code analysis?
182    item
183}
184
185/// Checks that the given closure is a monotonic function. For now does nothing.
186#[proc_macro]
187pub fn monotonic_fn(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
188    // TODO(mingwei): some sort of code analysis?
189    item
190}
191
192#[proc_macro_attribute]
193pub fn dfir_test(
194    args: proc_macro::TokenStream,
195    item: proc_macro::TokenStream,
196) -> proc_macro::TokenStream {
197    let root = root();
198    let args_2: proc_macro2::TokenStream = args.into();
199
200    wrap_localset(
201        item,
202        parse_quote!(
203            #[#root::tokio::test(flavor = "current_thread", #args_2)]
204        ),
205    )
206}
207
208#[proc_macro_attribute]
209pub fn dfir_main(
210    _: proc_macro::TokenStream,
211    item: proc_macro::TokenStream,
212) -> proc_macro::TokenStream {
213    let root = root();
214
215    wrap_localset(
216        item,
217        parse_quote!(
218            #[#root::tokio::main(flavor = "current_thread")]
219        ),
220    )
221}
222
223#[proc_macro_derive(DemuxEnum)]
224pub fn derive_demux_enum(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
225    let root = root();
226
227    let ItemEnum {
228        ident: item_ident,
229        generics,
230        variants,
231        ..
232    } = parse_macro_input!(item as ItemEnum);
233
234    // Sort variants alphabetically.
235    let mut variants = variants.into_iter().collect::<Vec<_>>();
236    variants.sort_by(|a, b| a.ident.cmp(&b.ident));
237
238    // Return type for each variant.
239    let variant_output_types = variants
240        .iter()
241        .map(|variant| match &variant.fields {
242            Fields::Named(fields) => {
243                let field_types = fields.named.iter().map(|field| &field.ty);
244                quote! {
245                    ( #( #field_types, )* )
246                }
247            }
248            Fields::Unnamed(fields) => {
249                let field_types = fields.unnamed.iter().map(|field| &field.ty);
250                quote! {
251                    ( #( #field_types, )* )
252                }
253            }
254            Fields::Unit => quote!(()),
255        })
256        .collect::<Vec<_>>();
257
258    let variant_generics_sink = variants
259        .iter()
260        .map(|variant| format_ident!("__Sink{}", variant.ident))
261        .collect::<Vec<_>>();
262    let variant_generics_pinned_sink = variant_generics_sink.iter().map(|ident| {
263        quote_spanned! {ident.span()=>
264            ::std::pin::Pin::<&mut #ident>
265        }
266    });
267    let variant_generics_pinned_sink_all = quote! {
268        ( #( #variant_generics_pinned_sink, )* )
269    };
270    let variant_localvars_sink = variants
271        .iter()
272        .map(|variant| {
273            format_ident!(
274                "__sink_{}",
275                variant.ident.to_string().to_lowercase(),
276                span = variant.ident.span()
277            )
278        })
279        .collect::<Vec<_>>();
280
281    let mut full_generics_sink = generics.clone();
282    full_generics_sink.params.extend(
283        variant_generics_sink
284            .iter()
285            .map::<GenericParam, _>(|ident| parse_quote!(#ident)),
286    );
287    full_generics_sink.make_where_clause().predicates.extend(
288        variant_generics_sink
289            .iter()
290            .zip(variant_output_types.iter())
291            .map::<WherePredicate, _>(|(sink_generic, output_type)| {
292                parse_quote! {
293                    // TODO(mingwei): generic error types?
294                    #sink_generic: #root::futures::sink::Sink<#output_type, Error = #root::Never>
295                }
296            }),
297    );
298
299    let variant_pats_sink_start_send = variants.iter().zip(variant_localvars_sink.iter()).map(
300        |(variant, sinkvar)| {
301            let Variant { ident, fields, .. } = variant;
302            let (fields_pat, push_item) = field_pattern_item(fields);
303            quote! {
304                Self::#ident #fields_pat => ::std::pin::Pin::as_mut(#sinkvar).start_send(#push_item)
305            }
306        },
307    );
308
309    let (impl_generics_item, ty_generics, where_clause_item) = generics.split_for_impl();
310    let (impl_generics_sink, _ty_generics_sink, where_clause_sink) =
311        full_generics_sink.split_for_impl();
312
313    let variant_generics_push = variants
314        .iter()
315        .map(|variant| format_ident!("__Push{}", variant.ident))
316        .collect::<Vec<_>>();
317    let variant_generics_pinned_push = variant_generics_push.iter().map(|ident| {
318        quote_spanned! {ident.span()=>
319            ::std::pin::Pin::<&mut #ident>
320        }
321    });
322    let variant_generics_pinned_push_all = quote! {
323        ( #( #variant_generics_pinned_push, )* )
324    };
325    let variant_localvars_push = variants
326        .iter()
327        .map(|variant| {
328            format_ident!(
329                "__push_{}",
330                variant.ident.to_string().to_lowercase(),
331                span = variant.ident.span()
332            )
333        })
334        .collect::<Vec<_>>();
335
336    let mut full_generics_push = generics.clone();
337    full_generics_push.params.extend(
338        variant_generics_push
339            .iter()
340            .map::<GenericParam, _>(|ident| parse_quote!(#ident)),
341    );
342    // Each push just needs Push<Item = VariantOutput, Meta = ()>.
343    full_generics_push.make_where_clause().predicates.extend(
344        variant_generics_push
345            .iter()
346            .zip(variant_output_types.iter())
347            .map::<WherePredicate, _>(|(push_generic, output_type)| {
348                parse_quote! {
349                    #push_generic: #root::dfir_pipes::push::Push<#output_type, ()>
350                }
351            }),
352    );
353
354    // Build the recursive Merged Ctx type:
355    // For 0 pushes: `()
356    // For 1 push: `Push0::Ctx<'__ctx>`
357    // For 2 pushes: `<Push0::Ctx<'__ctx> as Context<'__ctx>>::Merged<Push1::Ctx<'__ctx>>`
358    // For 3 pushes: `<Push0::Ctx<'__ctx> as Context<'__ctx>>::Merged<<Push1::Ctx<'__ctx> as Context<'__ctx>>::Merged<Push2::Ctx<'__ctx>>>`
359    let ctx_type = variant_generics_push
360        .iter()
361        .zip(variant_output_types.iter())
362        .rev()
363        .map(|(push_generic, output_type)| {
364            quote_spanned! {push_generic.span()=>
365                <#push_generic as #root::dfir_pipes::push::Push<#output_type, ()>>::Ctx<'__ctx>
366            }
367        })
368        .reduce(|rest, next| {
369            quote_spanned! {next.span()=>
370                <#next as #root::dfir_pipes::Context<'__ctx>>::Merged<#rest>
371            }
372        })
373        .unwrap_or_else(|| quote!(()));
374
375    let can_pend = variant_generics_push
376        .iter()
377        .zip(variant_output_types.iter())
378        .rev()
379        .map(|(push_generic, output_type)| {
380            quote_spanned! {push_generic.span()=>
381                <#push_generic as #root::dfir_pipes::push::Push<#output_type, ()>>::CanPend
382            }
383        })
384        .reduce(|rest, next| {
385            quote_spanned! {next.span()=>
386                <#next as #root::dfir_pipes::Toggle>::Or<#rest>
387            }
388        })
389        .unwrap_or_else(|| quote!(#root::dfir_pipes::No));
390
391    // Generate `Ctx`: `unmerge_self` for each push, `unmerge_other` to get remaining `__ctx`.
392    // For the last push, just pass `__ctx` directly (no unmerge needed).
393    let push_poll_unwrap_context = |method_name: Ident| {
394        variant_localvars_push.split_last().map(|(lastvar, headvar)| {
395            // `#( ... )*` zips all iterators to shortest; `headvar` (all-but-last) is shortest, so
396            // `variant_generics_push` and `variant_output_types` are naturally truncated to match.
397            quote! {
398                #(
399                    let #headvar = {
400                        let __ctx = <<#variant_generics_push as #root::dfir_pipes::push::Push<#variant_output_types, ()>>::Ctx<'_> as #root::dfir_pipes::Context<'_>>::unmerge_self(__ctx);
401                        #root::dfir_pipes::push::Push::#method_name(::std::pin::Pin::as_mut(#headvar), __ctx)
402                    };
403                    let __ctx = <<#variant_generics_push as #root::dfir_pipes::push::Push<#variant_output_types, ()>>::Ctx<'_> as #root::dfir_pipes::Context<'_>>::unmerge_other(__ctx);
404                )*
405                let #lastvar = #root::dfir_pipes::push::Push::#method_name(::std::pin::Pin::as_mut(#lastvar), __ctx);
406                // If any are pending, return pending.
407                #(
408                    if #variant_localvars_push.is_pending() {
409                        return #root::dfir_pipes::push::PushStep::pending();
410                    }
411                )*
412            }
413        })
414    };
415    let push_poll_ready_body = (push_poll_unwrap_context)(format_ident!("poll_ready"));
416    let push_poll_finalize_body = (push_poll_unwrap_context)(format_ident!("poll_finalize"));
417
418    let variant_pats_push_send =
419        variants
420            .iter()
421            .zip(variant_localvars_push.iter())
422            .map(|(variant, pushvar)| {
423                let Variant { ident, fields, .. } = variant;
424                let (fields_pat, push_item) = field_pattern_item(fields);
425                quote! {
426                    Self::#ident #fields_pat => { #root::dfir_pipes::push::Push::start_send(#pushvar.as_mut(), #push_item, __meta); }
427                }
428            });
429
430    let (impl_generics_push, _ty_generics_push, where_clause_push) =
431        full_generics_push.split_for_impl();
432
433    let single_impl = (1 == variants.len()).then(|| {
434        let Variant { ident, fields, .. } = variants.first().unwrap();
435        let (fields_pat, push_item) = field_pattern_item(fields);
436        let out_type = variant_output_types.first().unwrap();
437        quote! {
438            impl #impl_generics_item #root::util::demux_enum::SingleVariant
439                for #item_ident #ty_generics #where_clause_item
440            {
441                type Output = #out_type;
442                fn single_variant(self) -> Self::Output {
443                    match self {
444                        Self::#ident #fields_pat => #push_item,
445                    }
446                }
447            }
448        }
449    });
450
451    quote! {
452        impl #impl_generics_sink #root::util::demux_enum::DemuxEnumSink<#variant_generics_pinned_sink_all>
453            for #item_ident #ty_generics #where_clause_sink
454        {
455            type Error = #root::Never;
456
457            fn poll_ready(
458                ( #( #variant_localvars_sink, )* ): &mut #variant_generics_pinned_sink_all,
459                __cx: &mut ::std::task::Context<'_>,
460            ) -> ::std::task::Poll<::std::result::Result<(), Self::Error>> {
461                // Ready all sinks simultaneously.
462                #(
463                    let #variant_localvars_sink = #variant_localvars_sink.as_mut().poll_ready(__cx)?;
464                )*
465                #(
466                    ::std::task::ready!(#variant_localvars_sink);
467                )*
468                ::std::task::Poll::Ready(::std::result::Result::Ok(()))
469            }
470
471            fn start_send(
472                self,
473                ( #( #variant_localvars_sink, )* ): &mut #variant_generics_pinned_sink_all,
474            ) -> ::std::result::Result<(), Self::Error> {
475                match self {
476                    #( #variant_pats_sink_start_send, )*
477                }
478            }
479
480            fn poll_flush(
481                ( #( #variant_localvars_sink, )* ): &mut #variant_generics_pinned_sink_all,
482                __cx: &mut ::std::task::Context<'_>,
483            ) -> ::std::task::Poll<::std::result::Result<(), Self::Error>> {
484                // Flush all sinks simultaneously.
485                #(
486                    let #variant_localvars_sink = #variant_localvars_sink.as_mut().poll_flush(__cx)?;
487                )*
488                #(
489                    ::std::task::ready!(#variant_localvars_sink);
490                )*
491                ::std::task::Poll::Ready(::std::result::Result::Ok(()))
492            }
493
494            fn poll_close(
495                ( #( #variant_localvars_sink, )* ): &mut #variant_generics_pinned_sink_all,
496                __cx: &mut ::std::task::Context<'_>,
497            ) -> ::std::task::Poll<::std::result::Result<(), Self::Error>> {
498                // Close all sinks simultaneously.
499                #(
500                    let #variant_localvars_sink = #variant_localvars_sink.as_mut().poll_close(__cx)?;
501                )*
502                #(
503                    ::std::task::ready!(#variant_localvars_sink);
504                )*
505                ::std::task::Poll::Ready(::std::result::Result::Ok(()))
506            }
507        }
508
509        impl #impl_generics_push #root::util::demux_enum::DemuxEnumPush<#variant_generics_pinned_push_all, ()>
510            for #item_ident #ty_generics #where_clause_push
511        {
512            type Ctx<'__ctx> = #ctx_type;
513            type CanPend = #can_pend;
514
515            fn poll_ready(
516                ( #( #variant_localvars_push, )* ): &mut #variant_generics_pinned_push_all,
517                __ctx: &mut Self::Ctx<'_>,
518            ) -> #root::dfir_pipes::push::PushStep<Self::CanPend> {
519                #push_poll_ready_body
520                #root::dfir_pipes::push::PushStep::Done
521            }
522
523            fn start_send(
524                self,
525                __meta: (),
526                ( #( #variant_localvars_push, )* ): &mut #variant_generics_pinned_push_all,
527            ) {
528                match self {
529                    #( #variant_pats_push_send, )*
530                }
531            }
532
533            fn poll_finalize(
534                ( #( #variant_localvars_push, )* ): &mut #variant_generics_pinned_push_all,
535                __ctx: &mut Self::Ctx<'_>,
536            ) -> #root::dfir_pipes::push::PushStep<Self::CanPend> {
537                #push_poll_finalize_body
538                #root::dfir_pipes::push::PushStep::Done
539            }
540
541            fn size_hint(
542                ( #( #variant_localvars_push, )* ): &mut #variant_generics_pinned_push_all,
543                __size_hint: (usize, ::std::option::Option<usize>),
544            ) {
545                #(
546                    #root::dfir_pipes::push::Push::size_hint(
547                        ::std::pin::Pin::as_mut(#variant_localvars_push),
548                        __size_hint,
549                    );
550                )*
551            }
552        }
553
554        impl #impl_generics_item #root::util::demux_enum::DemuxEnumBase
555            for #item_ident #ty_generics #where_clause_item {}
556
557        #single_impl
558    }
559    .into()
560}
561
562/// (fields pattern, push item expr)
563fn field_pattern_item(fields: &Fields) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
564    let idents = fields
565        .iter()
566        .enumerate()
567        .map(|(i, field)| {
568            field
569                .ident
570                .clone()
571                .unwrap_or_else(|| format_ident!("_{}", i))
572        })
573        .collect::<Vec<_>>();
574    let (fields_pat, push_item) = match fields {
575        Fields::Named(_) => (quote!( { #( #idents, )* } ), quote!( ( #( #idents, )* ) )),
576        Fields::Unnamed(_) => (quote!( ( #( #idents ),* ) ), quote!( ( #( #idents, )* ) )),
577        Fields::Unit => (quote!(), quote!(())),
578    };
579    (fields_pat, push_item)
580}