1#![cfg_attr(
2 nightly,
3 feature(proc_macro_diagnostic, proc_macro_span, proc_macro_def_site)
4)]
5
6use dfir_lang::diagnostic::Level;
7use dfir_lang::graph::{
8 BuildDfirCodeOutput, FlatGraphBuilder, FlatGraphBuilderOutput, build_dfir_code,
9 build_dfir_code_inline, partition_graph,
10};
11use dfir_lang::parse::DfirCode;
12use proc_macro2::{Ident, Literal, Span};
13use quote::{format_ident, quote, quote_spanned};
14use syn::spanned::Spanned;
15use syn::{
16 Attribute, Fields, GenericParam, ItemEnum, Variant, WherePredicate, parse_macro_input,
17 parse_quote,
18};
19
20#[proc_macro]
27pub fn dfir_syntax(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
28 dfir_syntax_internal(input, Some(Level::Help))
29}
30
31#[proc_macro]
35pub fn dfir_syntax_noemit(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
36 dfir_syntax_internal(input, None)
37}
38
39fn root() -> proc_macro2::TokenStream {
40 use std::env::{VarError, var as env_var};
41
42 let root_crate_name = format!(
43 "{}_rs",
44 env!("CARGO_PKG_NAME").strip_suffix("_macro").unwrap()
45 );
46 let root_crate_ident = root_crate_name.replace('-', "_");
47 let root_crate = proc_macro_crate::crate_name(&root_crate_name)
48 .unwrap_or_else(|_| panic!("{root_crate_name} should be present in `Cargo.toml`"));
49 match root_crate {
50 proc_macro_crate::FoundCrate::Itself => {
51 if Err(VarError::NotPresent) == env_var("CARGO_BIN_NAME")
52 && Err(VarError::NotPresent) != env_var("CARGO_PRIMARY_PACKAGE")
53 && Ok(&*root_crate_ident) == env_var("CARGO_CRATE_NAME").as_deref()
54 {
55 quote! { crate }
57 } else {
58 let ident: Ident = Ident::new(&root_crate_ident, Span::call_site());
60 quote! { ::#ident }
61 }
62 }
63 proc_macro_crate::FoundCrate::Name(name) => {
64 let ident = Ident::new(&name, Span::call_site());
65 quote! { ::#ident }
66 }
67 }
68}
69
70fn dfir_syntax_internal(
71 input: proc_macro::TokenStream,
72 retain_diagnostic_level: Option<Level>,
73) -> proc_macro::TokenStream {
74 let input = parse_macro_input!(input as DfirCode);
75 let root = root();
76
77 let (code, mut diagnostics) = match build_dfir_code(input, &root) {
78 Ok(BuildDfirCodeOutput {
79 partitioned_graph: _,
80 code,
81 diagnostics,
82 }) => (code, diagnostics),
83 Err(diagnostics) => (quote! { #root::scheduled::graph::Dfir::new() }, diagnostics),
84 };
85
86 let diagnostic_tokens = retain_diagnostic_level.and_then(|level| {
87 diagnostics.retain_level(level);
88 diagnostics.try_emit_all().err()
89 });
90
91 quote! {
92 {
93 #diagnostic_tokens
94 #code
95 }
96 }
97 .into()
98}
99
100#[proc_macro]
103pub fn dfir_syntax_inline(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
104 dfir_syntax_inline_internal(input, Some(Level::Help))
105}
106
107#[proc_macro]
109pub fn dfir_syntax_inline_noemit(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
110 dfir_syntax_inline_internal(input, None)
111}
112
113fn dfir_syntax_inline_internal(
114 input: proc_macro::TokenStream,
115 retain_diagnostic_level: Option<Level>,
116) -> proc_macro::TokenStream {
117 let input = parse_macro_input!(input as DfirCode);
118 let root = root();
119
120 let (code, mut diagnostics) = match build_dfir_code_inline(input, &root) {
121 Ok(BuildDfirCodeOutput {
122 partitioned_graph: _,
123 code,
124 diagnostics,
125 }) => (code, diagnostics),
126 Err(diagnostics) => (quote! { async move || {} }, diagnostics),
127 };
128
129 let diagnostic_tokens = retain_diagnostic_level.and_then(|level| {
130 diagnostics.retain_level(level);
131 diagnostics.try_emit_all().err()
132 });
133
134 let out = quote! {
135 {
136 #diagnostic_tokens
137 #code
138 }
139 };
140 out.into()
141}
142
143#[proc_macro]
147pub fn dfir_parser(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
148 let input = parse_macro_input!(input as DfirCode);
149
150 let flat_graph_builder = FlatGraphBuilder::from_dfir(input);
151 let err_diagnostics = 'err: {
152 let (mut flat_graph, mut diagnostics) = match flat_graph_builder.build() {
153 Ok(FlatGraphBuilderOutput {
154 flat_graph,
155 uses: _,
156 diagnostics,
157 }) => (flat_graph, diagnostics),
158 Err(diagnostics) => {
159 break 'err diagnostics;
160 }
161 };
162
163 if let Err(diagnostic) = flat_graph.merge_modules() {
164 diagnostics.push(diagnostic);
165 break 'err diagnostics;
166 }
167
168 let flat_mermaid = flat_graph.mermaid_string_flat();
169
170 let part_graph = partition_graph(flat_graph).unwrap();
171 let part_mermaid = part_graph.to_mermaid(&Default::default());
172
173 let lit0 = Literal::string(&flat_mermaid);
174 let lit1 = Literal::string(&part_mermaid);
175
176 return quote! {
177 {
178 println!("{}\n\n{}\n", #lit0, #lit1);
179 }
180 }
181 .into();
182 };
183
184 err_diagnostics
185 .try_emit_all()
186 .err()
187 .unwrap_or_default()
188 .into()
189}
190
191fn wrap_localset(item: proc_macro::TokenStream, attribute: Attribute) -> proc_macro::TokenStream {
192 use quote::ToTokens;
193
194 let root = root();
195
196 let mut input: syn::ItemFn = match syn::parse(item) {
197 Ok(it) => it,
198 Err(e) => return e.into_compile_error().into(),
199 };
200
201 let statements = input.block.stmts;
202
203 input.block.stmts = parse_quote!(
204 #root::tokio::task::LocalSet::new().run_until(async {
205 #( #statements )*
206 }).await
207 );
208
209 input.attrs.push(attribute);
210
211 input.into_token_stream().into()
212}
213
214#[proc_macro]
216pub fn morphism(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
217 item
219}
220
221#[proc_macro]
223pub fn monotonic_fn(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
224 item
226}
227
228#[proc_macro_attribute]
229pub fn dfir_test(
230 args: proc_macro::TokenStream,
231 item: proc_macro::TokenStream,
232) -> proc_macro::TokenStream {
233 let root = root();
234 let args_2: proc_macro2::TokenStream = args.into();
235
236 wrap_localset(
237 item,
238 parse_quote!(
239 #[#root::tokio::test(flavor = "current_thread", #args_2)]
240 ),
241 )
242}
243
244#[proc_macro_attribute]
245pub fn dfir_main(
246 _: proc_macro::TokenStream,
247 item: proc_macro::TokenStream,
248) -> proc_macro::TokenStream {
249 let root = root();
250
251 wrap_localset(
252 item,
253 parse_quote!(
254 #[#root::tokio::main(flavor = "current_thread")]
255 ),
256 )
257}
258
259#[proc_macro_derive(DemuxEnum)]
260pub fn derive_demux_enum(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
261 let root = root();
262
263 let ItemEnum {
264 ident: item_ident,
265 generics,
266 variants,
267 ..
268 } = parse_macro_input!(item as ItemEnum);
269
270 let mut variants = variants.into_iter().collect::<Vec<_>>();
272 variants.sort_by(|a, b| a.ident.cmp(&b.ident));
273
274 let variant_output_types = variants
276 .iter()
277 .map(|variant| match &variant.fields {
278 Fields::Named(fields) => {
279 let field_types = fields.named.iter().map(|field| &field.ty);
280 quote! {
281 ( #( #field_types, )* )
282 }
283 }
284 Fields::Unnamed(fields) => {
285 let field_types = fields.unnamed.iter().map(|field| &field.ty);
286 quote! {
287 ( #( #field_types, )* )
288 }
289 }
290 Fields::Unit => quote!(()),
291 })
292 .collect::<Vec<_>>();
293
294 let variant_generics_sink = variants
295 .iter()
296 .map(|variant| format_ident!("__Sink{}", variant.ident))
297 .collect::<Vec<_>>();
298 let variant_generics_pinned_sink = variant_generics_sink.iter().map(|ident| {
299 quote_spanned! {ident.span()=>
300 ::std::pin::Pin::<&mut #ident>
301 }
302 });
303 let variant_generics_pinned_sink_all = quote! {
304 ( #( #variant_generics_pinned_sink, )* )
305 };
306 let variant_localvars_sink = variants
307 .iter()
308 .map(|variant| {
309 format_ident!(
310 "__sink_{}",
311 variant.ident.to_string().to_lowercase(),
312 span = variant.ident.span()
313 )
314 })
315 .collect::<Vec<_>>();
316
317 let mut full_generics_sink = generics.clone();
318 full_generics_sink.params.extend(
319 variant_generics_sink
320 .iter()
321 .map::<GenericParam, _>(|ident| parse_quote!(#ident)),
322 );
323 full_generics_sink.make_where_clause().predicates.extend(
324 variant_generics_sink
325 .iter()
326 .zip(variant_output_types.iter())
327 .map::<WherePredicate, _>(|(sink_generic, output_type)| {
328 parse_quote! {
329 #sink_generic: #root::futures::sink::Sink<#output_type, Error = #root::Never>
331 }
332 }),
333 );
334
335 let variant_pats_sink_start_send = variants.iter().zip(variant_localvars_sink.iter()).map(
336 |(variant, sinkvar)| {
337 let Variant { ident, fields, .. } = variant;
338 let (fields_pat, push_item) = field_pattern_item(fields);
339 quote! {
340 Self::#ident #fields_pat => ::std::pin::Pin::as_mut(#sinkvar).start_send(#push_item)
341 }
342 },
343 );
344
345 let (impl_generics_item, ty_generics, where_clause_item) = generics.split_for_impl();
346 let (impl_generics_sink, _ty_generics_sink, where_clause_sink) =
347 full_generics_sink.split_for_impl();
348
349 let variant_generics_push = variants
350 .iter()
351 .map(|variant| format_ident!("__Push{}", variant.ident))
352 .collect::<Vec<_>>();
353 let variant_generics_pinned_push = variant_generics_push.iter().map(|ident| {
354 quote_spanned! {ident.span()=>
355 ::std::pin::Pin::<&mut #ident>
356 }
357 });
358 let variant_generics_pinned_push_all = quote! {
359 ( #( #variant_generics_pinned_push, )* )
360 };
361 let variant_localvars_push = variants
362 .iter()
363 .map(|variant| {
364 format_ident!(
365 "__push_{}",
366 variant.ident.to_string().to_lowercase(),
367 span = variant.ident.span()
368 )
369 })
370 .collect::<Vec<_>>();
371
372 let mut full_generics_push = generics.clone();
373 full_generics_push.params.extend(
374 variant_generics_push
375 .iter()
376 .map::<GenericParam, _>(|ident| parse_quote!(#ident)),
377 );
378 full_generics_push.make_where_clause().predicates.extend(
380 variant_generics_push
381 .iter()
382 .zip(variant_output_types.iter())
383 .map::<WherePredicate, _>(|(push_generic, output_type)| {
384 parse_quote! {
385 #push_generic: #root::dfir_pipes::push::Push<#output_type, ()>
386 }
387 }),
388 );
389
390 let ctx_type = variant_generics_push
396 .iter()
397 .zip(variant_output_types.iter())
398 .rev()
399 .map(|(push_generic, output_type)| {
400 quote_spanned! {push_generic.span()=>
401 <#push_generic as #root::dfir_pipes::push::Push<#output_type, ()>>::Ctx<'__ctx>
402 }
403 })
404 .reduce(|rest, next| {
405 quote_spanned! {next.span()=>
406 <#next as #root::dfir_pipes::Context<'__ctx>>::Merged<#rest>
407 }
408 })
409 .unwrap_or_else(|| quote!(()));
410
411 let can_pend = variant_generics_push
412 .iter()
413 .zip(variant_output_types.iter())
414 .rev()
415 .map(|(push_generic, output_type)| {
416 quote_spanned! {push_generic.span()=>
417 <#push_generic as #root::dfir_pipes::push::Push<#output_type, ()>>::CanPend
418 }
419 })
420 .reduce(|rest, next| {
421 quote_spanned! {next.span()=>
422 <#next as #root::dfir_pipes::Toggle>::Or<#rest>
423 }
424 })
425 .unwrap_or_else(|| quote!(#root::dfir_pipes::No));
426
427 let push_poll_unwrap_context = |method_name: Ident| {
430 variant_localvars_push.split_last().map(|(lastvar, headvar)| {
431 quote! {
434 #(
435 let #headvar = {
436 let __ctx = <<#variant_generics_push as #root::dfir_pipes::push::Push<#variant_output_types, ()>>::Ctx<'_> as #root::dfir_pipes::Context<'_>>::unmerge_self(__ctx);
437 #root::dfir_pipes::push::Push::#method_name(::std::pin::Pin::as_mut(#headvar), __ctx)
438 };
439 let __ctx = <<#variant_generics_push as #root::dfir_pipes::push::Push<#variant_output_types, ()>>::Ctx<'_> as #root::dfir_pipes::Context<'_>>::unmerge_other(__ctx);
440 )*
441 let #lastvar = #root::dfir_pipes::push::Push::#method_name(::std::pin::Pin::as_mut(#lastvar), __ctx);
442 #(
444 if #variant_localvars_push.is_pending() {
445 return #root::dfir_pipes::push::PushStep::pending();
446 }
447 )*
448 }
449 })
450 };
451 let push_poll_ready_body = (push_poll_unwrap_context)(format_ident!("poll_ready"));
452 let push_poll_flush_body = (push_poll_unwrap_context)(format_ident!("poll_flush"));
453
454 let variant_pats_push_send =
455 variants
456 .iter()
457 .zip(variant_localvars_push.iter())
458 .map(|(variant, pushvar)| {
459 let Variant { ident, fields, .. } = variant;
460 let (fields_pat, push_item) = field_pattern_item(fields);
461 quote! {
462 Self::#ident #fields_pat => { #root::dfir_pipes::push::Push::start_send(#pushvar.as_mut(), #push_item, __meta); }
463 }
464 });
465
466 let (impl_generics_push, _ty_generics_push, where_clause_push) =
467 full_generics_push.split_for_impl();
468
469 let single_impl = (1 == variants.len()).then(|| {
470 let Variant { ident, fields, .. } = variants.first().unwrap();
471 let (fields_pat, push_item) = field_pattern_item(fields);
472 let out_type = variant_output_types.first().unwrap();
473 quote! {
474 impl #impl_generics_item #root::util::demux_enum::SingleVariant
475 for #item_ident #ty_generics #where_clause_item
476 {
477 type Output = #out_type;
478 fn single_variant(self) -> Self::Output {
479 match self {
480 Self::#ident #fields_pat => #push_item,
481 }
482 }
483 }
484 }
485 });
486
487 quote! {
488 impl #impl_generics_sink #root::util::demux_enum::DemuxEnumSink<#variant_generics_pinned_sink_all>
489 for #item_ident #ty_generics #where_clause_sink
490 {
491 type Error = #root::Never;
492
493 fn poll_ready(
494 ( #( #variant_localvars_sink, )* ): &mut #variant_generics_pinned_sink_all,
495 __cx: &mut ::std::task::Context<'_>,
496 ) -> ::std::task::Poll<::std::result::Result<(), Self::Error>> {
497 #(
499 let #variant_localvars_sink = #variant_localvars_sink.as_mut().poll_ready(__cx)?;
500 )*
501 #(
502 ::std::task::ready!(#variant_localvars_sink);
503 )*
504 ::std::task::Poll::Ready(::std::result::Result::Ok(()))
505 }
506
507 fn start_send(
508 self,
509 ( #( #variant_localvars_sink, )* ): &mut #variant_generics_pinned_sink_all,
510 ) -> ::std::result::Result<(), Self::Error> {
511 match self {
512 #( #variant_pats_sink_start_send, )*
513 }
514 }
515
516 fn poll_flush(
517 ( #( #variant_localvars_sink, )* ): &mut #variant_generics_pinned_sink_all,
518 __cx: &mut ::std::task::Context<'_>,
519 ) -> ::std::task::Poll<::std::result::Result<(), Self::Error>> {
520 #(
522 let #variant_localvars_sink = #variant_localvars_sink.as_mut().poll_flush(__cx)?;
523 )*
524 #(
525 ::std::task::ready!(#variant_localvars_sink);
526 )*
527 ::std::task::Poll::Ready(::std::result::Result::Ok(()))
528 }
529
530 fn poll_close(
531 ( #( #variant_localvars_sink, )* ): &mut #variant_generics_pinned_sink_all,
532 __cx: &mut ::std::task::Context<'_>,
533 ) -> ::std::task::Poll<::std::result::Result<(), Self::Error>> {
534 #(
536 let #variant_localvars_sink = #variant_localvars_sink.as_mut().poll_close(__cx)?;
537 )*
538 #(
539 ::std::task::ready!(#variant_localvars_sink);
540 )*
541 ::std::task::Poll::Ready(::std::result::Result::Ok(()))
542 }
543 }
544
545 impl #impl_generics_push #root::util::demux_enum::DemuxEnumPush<#variant_generics_pinned_push_all, ()>
546 for #item_ident #ty_generics #where_clause_push
547 {
548 type Ctx<'__ctx> = #ctx_type;
549 type CanPend = #can_pend;
550
551 fn poll_ready(
552 ( #( #variant_localvars_push, )* ): &mut #variant_generics_pinned_push_all,
553 __ctx: &mut Self::Ctx<'_>,
554 ) -> #root::dfir_pipes::push::PushStep<Self::CanPend> {
555 #push_poll_ready_body
556 #root::dfir_pipes::push::PushStep::Done
557 }
558
559 fn start_send(
560 self,
561 __meta: (),
562 ( #( #variant_localvars_push, )* ): &mut #variant_generics_pinned_push_all,
563 ) {
564 match self {
565 #( #variant_pats_push_send, )*
566 }
567 }
568
569 fn poll_flush(
570 ( #( #variant_localvars_push, )* ): &mut #variant_generics_pinned_push_all,
571 __ctx: &mut Self::Ctx<'_>,
572 ) -> #root::dfir_pipes::push::PushStep<Self::CanPend> {
573 #push_poll_flush_body
574 #root::dfir_pipes::push::PushStep::Done
575 }
576
577 fn size_hint(
578 ( #( #variant_localvars_push, )* ): &mut #variant_generics_pinned_push_all,
579 __size_hint: (usize, ::std::option::Option<usize>),
580 ) {
581 #(
582 #root::dfir_pipes::push::Push::size_hint(
583 ::std::pin::Pin::as_mut(#variant_localvars_push),
584 __size_hint,
585 );
586 )*
587 }
588 }
589
590 impl #impl_generics_item #root::util::demux_enum::DemuxEnumBase
591 for #item_ident #ty_generics #where_clause_item {}
592
593 #single_impl
594 }
595 .into()
596}
597
598fn field_pattern_item(fields: &Fields) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
600 let idents = fields
601 .iter()
602 .enumerate()
603 .map(|(i, field)| {
604 field
605 .ident
606 .clone()
607 .unwrap_or_else(|| format_ident!("_{}", i))
608 })
609 .collect::<Vec<_>>();
610 let (fields_pat, push_item) = match fields {
611 Fields::Named(_) => (quote!( { #( #idents, )* } ), quote!( ( #( #idents, )* ) )),
612 Fields::Unnamed(_) => (quote!( ( #( #idents ),* ) ), quote!( ( #( #idents, )* ) )),
613 Fields::Unit => (quote!(), quote!(())),
614 };
615 (fields_pat, push_item)
616}