Skip to main content

dfir_macro/
lib.rs

1#![cfg_attr(
2    nightly,
3    feature(proc_macro_diagnostic, proc_macro_span, proc_macro_def_site)
4)]
5
6use dfir_lang::diagnostic::Level;
7use dfir_lang::graph::{
8    BuildDfirCodeOutput, FlatGraphBuilder, FlatGraphBuilderOutput, build_dfir_code,
9    build_dfir_code_inline, partition_graph,
10};
11use dfir_lang::parse::DfirCode;
12use proc_macro2::{Ident, Literal, Span};
13use quote::{format_ident, quote, quote_spanned};
14use syn::spanned::Spanned;
15use syn::{
16    Attribute, Fields, GenericParam, ItemEnum, Variant, WherePredicate, parse_macro_input,
17    parse_quote,
18};
19
20/// Create a runnable graph instance using DFIR's custom syntax.
21///
22/// For example usage, take a look at the [`surface_*` tests in the `tests` folder](https://github.com/hydro-project/hydro/tree/main/dfir_rs/tests)
23/// or the [`examples` folder](https://github.com/hydro-project/hydro/tree/main/dfir_rs/examples)
24/// in the [Hydro repo](https://github.com/hydro-project/hydro).
25// TODO(mingwei): rustdoc examples inline.
26#[proc_macro]
27pub fn dfir_syntax(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
28    dfir_syntax_internal(input, Some(Level::Help))
29}
30
31/// [`dfir_syntax!`] but will not emit any diagnostics (errors, warnings, etc.).
32///
33/// Used for testing, users will want to use [`dfir_syntax!`] instead.
34#[proc_macro]
35pub fn dfir_syntax_noemit(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
36    dfir_syntax_internal(input, None)
37}
38
39fn root() -> proc_macro2::TokenStream {
40    use std::env::{VarError, var as env_var};
41
42    let root_crate_name = format!(
43        "{}_rs",
44        env!("CARGO_PKG_NAME").strip_suffix("_macro").unwrap()
45    );
46    let root_crate_ident = root_crate_name.replace('-', "_");
47    let root_crate = proc_macro_crate::crate_name(&root_crate_name)
48        .unwrap_or_else(|_| panic!("{root_crate_name} should be present in `Cargo.toml`"));
49    match root_crate {
50        proc_macro_crate::FoundCrate::Itself => {
51            if Err(VarError::NotPresent) == env_var("CARGO_BIN_NAME")
52                && Err(VarError::NotPresent) != env_var("CARGO_PRIMARY_PACKAGE")
53                && Ok(&*root_crate_ident) == env_var("CARGO_CRATE_NAME").as_deref()
54            {
55                // In the crate itself, including unit tests.
56                quote! { crate }
57            } else {
58                // In an integration test, example, bench, etc.
59                let ident: Ident = Ident::new(&root_crate_ident, Span::call_site());
60                quote! { ::#ident }
61            }
62        }
63        proc_macro_crate::FoundCrate::Name(name) => {
64            let ident = Ident::new(&name, Span::call_site());
65            quote! { ::#ident }
66        }
67    }
68}
69
70fn dfir_syntax_internal(
71    input: proc_macro::TokenStream,
72    retain_diagnostic_level: Option<Level>,
73) -> proc_macro::TokenStream {
74    let input = parse_macro_input!(input as DfirCode);
75    let root = root();
76
77    let (code, mut diagnostics) = match build_dfir_code(input, &root) {
78        Ok(BuildDfirCodeOutput {
79            partitioned_graph: _,
80            code,
81            diagnostics,
82        }) => (code, diagnostics),
83        Err(diagnostics) => (quote! { #root::scheduled::graph::Dfir::new() }, diagnostics),
84    };
85
86    let diagnostic_tokens = retain_diagnostic_level.and_then(|level| {
87        diagnostics.retain_level(level);
88        diagnostics.try_emit_all().err()
89    });
90
91    quote! {
92        {
93            #diagnostic_tokens
94            #code
95        }
96    }
97    .into()
98}
99
100/// Create an inline dataflow graph that runs immediately without the Dfir scheduler.
101/// Experimental: uses local Vec buffers instead of handoffs, runs subgraphs in topological order.
102#[proc_macro]
103pub fn dfir_syntax_inline(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
104    dfir_syntax_inline_internal(input, Some(Level::Help))
105}
106
107/// [`dfir_syntax_inline!`] but will not emit any diagnostics.
108#[proc_macro]
109pub fn dfir_syntax_inline_noemit(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
110    dfir_syntax_inline_internal(input, None)
111}
112
113fn dfir_syntax_inline_internal(
114    input: proc_macro::TokenStream,
115    retain_diagnostic_level: Option<Level>,
116) -> proc_macro::TokenStream {
117    let input = parse_macro_input!(input as DfirCode);
118    let root = root();
119
120    let (code, mut diagnostics) = match build_dfir_code_inline(input, &root) {
121        Ok(BuildDfirCodeOutput {
122            partitioned_graph: _,
123            code,
124            diagnostics,
125        }) => (code, diagnostics),
126        Err(diagnostics) => (quote! { async move || {} }, diagnostics),
127    };
128
129    let diagnostic_tokens = retain_diagnostic_level.and_then(|level| {
130        diagnostics.retain_level(level);
131        diagnostics.try_emit_all().err()
132    });
133
134    let out = quote! {
135        {
136            #diagnostic_tokens
137            #code
138        }
139    };
140    out.into()
141}
142
143/// Parse DFIR syntax without emitting code.
144///
145/// Used for testing, users will want to use [`dfir_syntax!`] instead.
146#[proc_macro]
147pub fn dfir_parser(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
148    let input = parse_macro_input!(input as DfirCode);
149
150    let flat_graph_builder = FlatGraphBuilder::from_dfir(input);
151    let err_diagnostics = 'err: {
152        let (mut flat_graph, mut diagnostics) = match flat_graph_builder.build() {
153            Ok(FlatGraphBuilderOutput {
154                flat_graph,
155                uses: _,
156                diagnostics,
157            }) => (flat_graph, diagnostics),
158            Err(diagnostics) => {
159                break 'err diagnostics;
160            }
161        };
162
163        if let Err(diagnostic) = flat_graph.merge_modules() {
164            diagnostics.push(diagnostic);
165            break 'err diagnostics;
166        }
167
168        let flat_mermaid = flat_graph.mermaid_string_flat();
169
170        let part_graph = partition_graph(flat_graph).unwrap();
171        let part_mermaid = part_graph.to_mermaid(&Default::default());
172
173        let lit0 = Literal::string(&flat_mermaid);
174        let lit1 = Literal::string(&part_mermaid);
175
176        return quote! {
177            {
178                println!("{}\n\n{}\n", #lit0, #lit1);
179            }
180        }
181        .into();
182    };
183
184    err_diagnostics
185        .try_emit_all()
186        .err()
187        .unwrap_or_default()
188        .into()
189}
190
191fn wrap_localset(item: proc_macro::TokenStream, attribute: Attribute) -> proc_macro::TokenStream {
192    use quote::ToTokens;
193
194    let root = root();
195
196    let mut input: syn::ItemFn = match syn::parse(item) {
197        Ok(it) => it,
198        Err(e) => return e.into_compile_error().into(),
199    };
200
201    let statements = input.block.stmts;
202
203    input.block.stmts = parse_quote!(
204        #root::tokio::task::LocalSet::new().run_until(async {
205            #( #statements )*
206        }).await
207    );
208
209    input.attrs.push(attribute);
210
211    input.into_token_stream().into()
212}
213
214/// Checks that the given closure is a morphism. For now does nothing.
215#[proc_macro]
216pub fn morphism(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
217    // TODO(mingwei): some sort of code analysis?
218    item
219}
220
221/// Checks that the given closure is a monotonic function. For now does nothing.
222#[proc_macro]
223pub fn monotonic_fn(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
224    // TODO(mingwei): some sort of code analysis?
225    item
226}
227
228#[proc_macro_attribute]
229pub fn dfir_test(
230    args: proc_macro::TokenStream,
231    item: proc_macro::TokenStream,
232) -> proc_macro::TokenStream {
233    let root = root();
234    let args_2: proc_macro2::TokenStream = args.into();
235
236    wrap_localset(
237        item,
238        parse_quote!(
239            #[#root::tokio::test(flavor = "current_thread", #args_2)]
240        ),
241    )
242}
243
244#[proc_macro_attribute]
245pub fn dfir_main(
246    _: proc_macro::TokenStream,
247    item: proc_macro::TokenStream,
248) -> proc_macro::TokenStream {
249    let root = root();
250
251    wrap_localset(
252        item,
253        parse_quote!(
254            #[#root::tokio::main(flavor = "current_thread")]
255        ),
256    )
257}
258
259#[proc_macro_derive(DemuxEnum)]
260pub fn derive_demux_enum(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
261    let root = root();
262
263    let ItemEnum {
264        ident: item_ident,
265        generics,
266        variants,
267        ..
268    } = parse_macro_input!(item as ItemEnum);
269
270    // Sort variants alphabetically.
271    let mut variants = variants.into_iter().collect::<Vec<_>>();
272    variants.sort_by(|a, b| a.ident.cmp(&b.ident));
273
274    // Return type for each variant.
275    let variant_output_types = variants
276        .iter()
277        .map(|variant| match &variant.fields {
278            Fields::Named(fields) => {
279                let field_types = fields.named.iter().map(|field| &field.ty);
280                quote! {
281                    ( #( #field_types, )* )
282                }
283            }
284            Fields::Unnamed(fields) => {
285                let field_types = fields.unnamed.iter().map(|field| &field.ty);
286                quote! {
287                    ( #( #field_types, )* )
288                }
289            }
290            Fields::Unit => quote!(()),
291        })
292        .collect::<Vec<_>>();
293
294    let variant_generics_sink = variants
295        .iter()
296        .map(|variant| format_ident!("__Sink{}", variant.ident))
297        .collect::<Vec<_>>();
298    let variant_generics_pinned_sink = variant_generics_sink.iter().map(|ident| {
299        quote_spanned! {ident.span()=>
300            ::std::pin::Pin::<&mut #ident>
301        }
302    });
303    let variant_generics_pinned_sink_all = quote! {
304        ( #( #variant_generics_pinned_sink, )* )
305    };
306    let variant_localvars_sink = variants
307        .iter()
308        .map(|variant| {
309            format_ident!(
310                "__sink_{}",
311                variant.ident.to_string().to_lowercase(),
312                span = variant.ident.span()
313            )
314        })
315        .collect::<Vec<_>>();
316
317    let mut full_generics_sink = generics.clone();
318    full_generics_sink.params.extend(
319        variant_generics_sink
320            .iter()
321            .map::<GenericParam, _>(|ident| parse_quote!(#ident)),
322    );
323    full_generics_sink.make_where_clause().predicates.extend(
324        variant_generics_sink
325            .iter()
326            .zip(variant_output_types.iter())
327            .map::<WherePredicate, _>(|(sink_generic, output_type)| {
328                parse_quote! {
329                    // TODO(mingwei): generic error types?
330                    #sink_generic: #root::futures::sink::Sink<#output_type, Error = #root::Never>
331                }
332            }),
333    );
334
335    let variant_pats_sink_start_send = variants.iter().zip(variant_localvars_sink.iter()).map(
336        |(variant, sinkvar)| {
337            let Variant { ident, fields, .. } = variant;
338            let (fields_pat, push_item) = field_pattern_item(fields);
339            quote! {
340                Self::#ident #fields_pat => ::std::pin::Pin::as_mut(#sinkvar).start_send(#push_item)
341            }
342        },
343    );
344
345    let (impl_generics_item, ty_generics, where_clause_item) = generics.split_for_impl();
346    let (impl_generics_sink, _ty_generics_sink, where_clause_sink) =
347        full_generics_sink.split_for_impl();
348
349    let variant_generics_push = variants
350        .iter()
351        .map(|variant| format_ident!("__Push{}", variant.ident))
352        .collect::<Vec<_>>();
353    let variant_generics_pinned_push = variant_generics_push.iter().map(|ident| {
354        quote_spanned! {ident.span()=>
355            ::std::pin::Pin::<&mut #ident>
356        }
357    });
358    let variant_generics_pinned_push_all = quote! {
359        ( #( #variant_generics_pinned_push, )* )
360    };
361    let variant_localvars_push = variants
362        .iter()
363        .map(|variant| {
364            format_ident!(
365                "__push_{}",
366                variant.ident.to_string().to_lowercase(),
367                span = variant.ident.span()
368            )
369        })
370        .collect::<Vec<_>>();
371
372    let mut full_generics_push = generics.clone();
373    full_generics_push.params.extend(
374        variant_generics_push
375            .iter()
376            .map::<GenericParam, _>(|ident| parse_quote!(#ident)),
377    );
378    // Each push just needs Push<Item = VariantOutput, Meta = ()>.
379    full_generics_push.make_where_clause().predicates.extend(
380        variant_generics_push
381            .iter()
382            .zip(variant_output_types.iter())
383            .map::<WherePredicate, _>(|(push_generic, output_type)| {
384                parse_quote! {
385                    #push_generic: #root::dfir_pipes::push::Push<#output_type, ()>
386                }
387            }),
388    );
389
390    // Build the recursive Merged Ctx type:
391    // For 0 pushes: `()
392    // For 1 push: `Push0::Ctx<'__ctx>`
393    // For 2 pushes: `<Push0::Ctx<'__ctx> as Context<'__ctx>>::Merged<Push1::Ctx<'__ctx>>`
394    // For 3 pushes: `<Push0::Ctx<'__ctx> as Context<'__ctx>>::Merged<<Push1::Ctx<'__ctx> as Context<'__ctx>>::Merged<Push2::Ctx<'__ctx>>>`
395    let ctx_type = variant_generics_push
396        .iter()
397        .zip(variant_output_types.iter())
398        .rev()
399        .map(|(push_generic, output_type)| {
400            quote_spanned! {push_generic.span()=>
401                <#push_generic as #root::dfir_pipes::push::Push<#output_type, ()>>::Ctx<'__ctx>
402            }
403        })
404        .reduce(|rest, next| {
405            quote_spanned! {next.span()=>
406                <#next as #root::dfir_pipes::Context<'__ctx>>::Merged<#rest>
407            }
408        })
409        .unwrap_or_else(|| quote!(()));
410
411    let can_pend = variant_generics_push
412        .iter()
413        .zip(variant_output_types.iter())
414        .rev()
415        .map(|(push_generic, output_type)| {
416            quote_spanned! {push_generic.span()=>
417                <#push_generic as #root::dfir_pipes::push::Push<#output_type, ()>>::CanPend
418            }
419        })
420        .reduce(|rest, next| {
421            quote_spanned! {next.span()=>
422                <#next as #root::dfir_pipes::Toggle>::Or<#rest>
423            }
424        })
425        .unwrap_or_else(|| quote!(#root::dfir_pipes::No));
426
427    // Generate `Ctx`: `unmerge_self` for each push, `unmerge_other` to get remaining `__ctx`.
428    // For the last push, just pass `__ctx` directly (no unmerge needed).
429    let push_poll_unwrap_context = |method_name: Ident| {
430        variant_localvars_push.split_last().map(|(lastvar, headvar)| {
431            // `#( ... )*` zips all iterators to shortest; `headvar` (all-but-last) is shortest, so
432            // `variant_generics_push` and `variant_output_types` are naturally truncated to match.
433            quote! {
434                #(
435                    let #headvar = {
436                        let __ctx = <<#variant_generics_push as #root::dfir_pipes::push::Push<#variant_output_types, ()>>::Ctx<'_> as #root::dfir_pipes::Context<'_>>::unmerge_self(__ctx);
437                        #root::dfir_pipes::push::Push::#method_name(::std::pin::Pin::as_mut(#headvar), __ctx)
438                    };
439                    let __ctx = <<#variant_generics_push as #root::dfir_pipes::push::Push<#variant_output_types, ()>>::Ctx<'_> as #root::dfir_pipes::Context<'_>>::unmerge_other(__ctx);
440                )*
441                let #lastvar = #root::dfir_pipes::push::Push::#method_name(::std::pin::Pin::as_mut(#lastvar), __ctx);
442                // If any are pending, return pending.
443                #(
444                    if #variant_localvars_push.is_pending() {
445                        return #root::dfir_pipes::push::PushStep::pending();
446                    }
447                )*
448            }
449        })
450    };
451    let push_poll_ready_body = (push_poll_unwrap_context)(format_ident!("poll_ready"));
452    let push_poll_flush_body = (push_poll_unwrap_context)(format_ident!("poll_flush"));
453
454    let variant_pats_push_send =
455        variants
456            .iter()
457            .zip(variant_localvars_push.iter())
458            .map(|(variant, pushvar)| {
459                let Variant { ident, fields, .. } = variant;
460                let (fields_pat, push_item) = field_pattern_item(fields);
461                quote! {
462                    Self::#ident #fields_pat => { #root::dfir_pipes::push::Push::start_send(#pushvar.as_mut(), #push_item, __meta); }
463                }
464            });
465
466    let (impl_generics_push, _ty_generics_push, where_clause_push) =
467        full_generics_push.split_for_impl();
468
469    let single_impl = (1 == variants.len()).then(|| {
470        let Variant { ident, fields, .. } = variants.first().unwrap();
471        let (fields_pat, push_item) = field_pattern_item(fields);
472        let out_type = variant_output_types.first().unwrap();
473        quote! {
474            impl #impl_generics_item #root::util::demux_enum::SingleVariant
475                for #item_ident #ty_generics #where_clause_item
476            {
477                type Output = #out_type;
478                fn single_variant(self) -> Self::Output {
479                    match self {
480                        Self::#ident #fields_pat => #push_item,
481                    }
482                }
483            }
484        }
485    });
486
487    quote! {
488        impl #impl_generics_sink #root::util::demux_enum::DemuxEnumSink<#variant_generics_pinned_sink_all>
489            for #item_ident #ty_generics #where_clause_sink
490        {
491            type Error = #root::Never;
492
493            fn poll_ready(
494                ( #( #variant_localvars_sink, )* ): &mut #variant_generics_pinned_sink_all,
495                __cx: &mut ::std::task::Context<'_>,
496            ) -> ::std::task::Poll<::std::result::Result<(), Self::Error>> {
497                // Ready all sinks simultaneously.
498                #(
499                    let #variant_localvars_sink = #variant_localvars_sink.as_mut().poll_ready(__cx)?;
500                )*
501                #(
502                    ::std::task::ready!(#variant_localvars_sink);
503                )*
504                ::std::task::Poll::Ready(::std::result::Result::Ok(()))
505            }
506
507            fn start_send(
508                self,
509                ( #( #variant_localvars_sink, )* ): &mut #variant_generics_pinned_sink_all,
510            ) -> ::std::result::Result<(), Self::Error> {
511                match self {
512                    #( #variant_pats_sink_start_send, )*
513                }
514            }
515
516            fn poll_flush(
517                ( #( #variant_localvars_sink, )* ): &mut #variant_generics_pinned_sink_all,
518                __cx: &mut ::std::task::Context<'_>,
519            ) -> ::std::task::Poll<::std::result::Result<(), Self::Error>> {
520                // Flush all sinks simultaneously.
521                #(
522                    let #variant_localvars_sink = #variant_localvars_sink.as_mut().poll_flush(__cx)?;
523                )*
524                #(
525                    ::std::task::ready!(#variant_localvars_sink);
526                )*
527                ::std::task::Poll::Ready(::std::result::Result::Ok(()))
528            }
529
530            fn poll_close(
531                ( #( #variant_localvars_sink, )* ): &mut #variant_generics_pinned_sink_all,
532                __cx: &mut ::std::task::Context<'_>,
533            ) -> ::std::task::Poll<::std::result::Result<(), Self::Error>> {
534                // Close all sinks simultaneously.
535                #(
536                    let #variant_localvars_sink = #variant_localvars_sink.as_mut().poll_close(__cx)?;
537                )*
538                #(
539                    ::std::task::ready!(#variant_localvars_sink);
540                )*
541                ::std::task::Poll::Ready(::std::result::Result::Ok(()))
542            }
543        }
544
545        impl #impl_generics_push #root::util::demux_enum::DemuxEnumPush<#variant_generics_pinned_push_all, ()>
546            for #item_ident #ty_generics #where_clause_push
547        {
548            type Ctx<'__ctx> = #ctx_type;
549            type CanPend = #can_pend;
550
551            fn poll_ready(
552                ( #( #variant_localvars_push, )* ): &mut #variant_generics_pinned_push_all,
553                __ctx: &mut Self::Ctx<'_>,
554            ) -> #root::dfir_pipes::push::PushStep<Self::CanPend> {
555                #push_poll_ready_body
556                #root::dfir_pipes::push::PushStep::Done
557            }
558
559            fn start_send(
560                self,
561                __meta: (),
562                ( #( #variant_localvars_push, )* ): &mut #variant_generics_pinned_push_all,
563            ) {
564                match self {
565                    #( #variant_pats_push_send, )*
566                }
567            }
568
569            fn poll_flush(
570                ( #( #variant_localvars_push, )* ): &mut #variant_generics_pinned_push_all,
571                __ctx: &mut Self::Ctx<'_>,
572            ) -> #root::dfir_pipes::push::PushStep<Self::CanPend> {
573                #push_poll_flush_body
574                #root::dfir_pipes::push::PushStep::Done
575            }
576
577            fn size_hint(
578                ( #( #variant_localvars_push, )* ): &mut #variant_generics_pinned_push_all,
579                __size_hint: (usize, ::std::option::Option<usize>),
580            ) {
581                #(
582                    #root::dfir_pipes::push::Push::size_hint(
583                        ::std::pin::Pin::as_mut(#variant_localvars_push),
584                        __size_hint,
585                    );
586                )*
587            }
588        }
589
590        impl #impl_generics_item #root::util::demux_enum::DemuxEnumBase
591            for #item_ident #ty_generics #where_clause_item {}
592
593        #single_impl
594    }
595    .into()
596}
597
598/// (fields pattern, push item expr)
599fn field_pattern_item(fields: &Fields) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
600    let idents = fields
601        .iter()
602        .enumerate()
603        .map(|(i, field)| {
604            field
605                .ident
606                .clone()
607                .unwrap_or_else(|| format_ident!("_{}", i))
608        })
609        .collect::<Vec<_>>();
610    let (fields_pat, push_item) = match fields {
611        Fields::Named(_) => (quote!( { #( #idents, )* } ), quote!( ( #( #idents, )* ) )),
612        Fields::Unnamed(_) => (quote!( ( #( #idents ),* ) ), quote!( ( #( #idents, )* ) )),
613        Fields::Unit => (quote!(), quote!(())),
614    };
615    (fields_pat, push_item)
616}