1#![cfg_attr(
2 nightly,
3 feature(proc_macro_diagnostic, proc_macro_span, proc_macro_def_site)
4)]
5
6use dfir_lang::diagnostic::Level;
7use dfir_lang::graph::{
8 BuildDfirCodeOutput, FlatGraphBuilder, FlatGraphBuilderOutput, build_dfir_code, partition_graph,
9};
10use dfir_lang::parse::DfirCode;
11use proc_macro2::{Ident, Literal, Span};
12use quote::{format_ident, quote, quote_spanned};
13use syn::spanned::Spanned;
14use syn::{
15 Attribute, Fields, GenericParam, ItemEnum, Variant, WherePredicate, parse_macro_input,
16 parse_quote,
17};
18
19#[proc_macro]
26pub fn dfir_syntax(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
27 dfir_syntax_internal(input, Some(Level::Help))
28}
29
30#[proc_macro]
34pub fn dfir_syntax_noemit(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
35 dfir_syntax_internal(input, None)
36}
37
38fn root() -> proc_macro2::TokenStream {
39 use std::env::{VarError, var as env_var};
40
41 let root_crate_name = format!(
42 "{}_rs",
43 env!("CARGO_PKG_NAME").strip_suffix("_macro").unwrap()
44 );
45 let root_crate_ident = root_crate_name.replace('-', "_");
46 let root_crate = proc_macro_crate::crate_name(&root_crate_name)
47 .unwrap_or_else(|_| panic!("{root_crate_name} should be present in `Cargo.toml`"));
48 match root_crate {
49 proc_macro_crate::FoundCrate::Itself => {
50 if Err(VarError::NotPresent) == env_var("CARGO_BIN_NAME")
51 && Err(VarError::NotPresent) != env_var("CARGO_PRIMARY_PACKAGE")
52 && Ok(&*root_crate_ident) == env_var("CARGO_CRATE_NAME").as_deref()
53 {
54 quote! { crate }
56 } else {
57 let ident: Ident = Ident::new(&root_crate_ident, Span::call_site());
59 quote! { ::#ident }
60 }
61 }
62 proc_macro_crate::FoundCrate::Name(name) => {
63 let ident = Ident::new(&name, Span::call_site());
64 quote! { ::#ident }
65 }
66 }
67}
68
69fn dfir_syntax_internal(
70 input: proc_macro::TokenStream,
71 retain_diagnostic_level: Option<Level>,
72) -> proc_macro::TokenStream {
73 let input = parse_macro_input!(input as DfirCode);
74 let root = root();
75
76 let (code, mut diagnostics) = match build_dfir_code(input, &root) {
77 Ok(BuildDfirCodeOutput {
78 partitioned_graph: _,
79 code,
80 diagnostics,
81 }) => (code, diagnostics),
82 Err(diagnostics) => (
83 quote! {
84 {
85 #root::scheduled::context::Dfir::new(
86 #root::scheduled::context::NullTickClosure,
87 <#root::scheduled::context::Context as ::std::default::Default>::default(),
88 None,
89 None,
90 )
91 }
92 },
93 diagnostics,
94 ),
95 };
96
97 let diagnostic_tokens = retain_diagnostic_level.and_then(|level| {
98 diagnostics.retain_level(level);
99 diagnostics.try_emit_all().err()
100 });
101
102 quote! {
103 {
104 #diagnostic_tokens
105 #code
106 }
107 }
108 .into()
109}
110
111#[proc_macro]
115pub fn dfir_parser(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
116 let input = parse_macro_input!(input as DfirCode);
117
118 let flat_graph_builder = FlatGraphBuilder::from_dfir(input);
119 let err_diagnostics = 'err: {
120 let (mut flat_graph, mut diagnostics) = match flat_graph_builder.build() {
121 Ok(FlatGraphBuilderOutput {
122 flat_graph,
123 uses: _,
124 diagnostics,
125 }) => (flat_graph, diagnostics),
126 Err(diagnostics) => {
127 break 'err diagnostics;
128 }
129 };
130
131 if let Err(diagnostic) = flat_graph.merge_modules() {
132 diagnostics.push(diagnostic);
133 break 'err diagnostics;
134 }
135
136 let flat_mermaid = flat_graph.mermaid_string_flat();
137
138 let part_mermaid = partition_graph(flat_graph)
139 .map(|part_graph| part_graph.to_mermaid(&Default::default()))
140 .unwrap_or_else(|err| format!("failed to partition: {err}"));
141
142 let lit0 = Literal::string(&flat_mermaid);
143 let lit1 = Literal::string(&part_mermaid);
144
145 return quote! {
146 {
147 println!("{}\n\n{}\n", #lit0, #lit1);
148 }
149 }
150 .into();
151 };
152
153 err_diagnostics
154 .try_emit_all()
155 .err()
156 .unwrap_or_default()
157 .into()
158}
159
160fn wrap_localset(item: proc_macro::TokenStream, attribute: Attribute) -> proc_macro::TokenStream {
161 use quote::ToTokens;
162
163 let root = root();
164
165 let mut input: syn::ItemFn = match syn::parse(item) {
166 Ok(it) => it,
167 Err(e) => return e.into_compile_error().into(),
168 };
169
170 let statements = input.block.stmts;
171
172 input.block.stmts = parse_quote!(
173 #root::tokio::task::LocalSet::new().run_until(async {
174 #( #statements )*
175 }).await
176 );
177
178 input.attrs.push(attribute);
179
180 input.into_token_stream().into()
181}
182
183#[proc_macro]
185pub fn morphism(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
186 item
188}
189
190#[proc_macro]
192pub fn monotonic_fn(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
193 item
195}
196
197#[proc_macro_attribute]
198pub fn dfir_test(
199 args: proc_macro::TokenStream,
200 item: proc_macro::TokenStream,
201) -> proc_macro::TokenStream {
202 let root = root();
203 let args_2: proc_macro2::TokenStream = args.into();
204
205 wrap_localset(
206 item,
207 parse_quote!(
208 #[#root::tokio::test(flavor = "current_thread", #args_2)]
209 ),
210 )
211}
212
213#[proc_macro_attribute]
214pub fn dfir_main(
215 _: proc_macro::TokenStream,
216 item: proc_macro::TokenStream,
217) -> proc_macro::TokenStream {
218 let root = root();
219
220 wrap_localset(
221 item,
222 parse_quote!(
223 #[#root::tokio::main(flavor = "current_thread")]
224 ),
225 )
226}
227
228#[proc_macro_derive(DemuxEnum)]
229pub fn derive_demux_enum(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
230 let root = root();
231
232 let ItemEnum {
233 ident: item_ident,
234 generics,
235 variants,
236 ..
237 } = parse_macro_input!(item as ItemEnum);
238
239 let mut variants = variants.into_iter().collect::<Vec<_>>();
241 variants.sort_by(|a, b| a.ident.cmp(&b.ident));
242
243 let variant_output_types = variants
245 .iter()
246 .map(|variant| match &variant.fields {
247 Fields::Named(fields) => {
248 let field_types = fields.named.iter().map(|field| &field.ty);
249 quote! {
250 ( #( #field_types, )* )
251 }
252 }
253 Fields::Unnamed(fields) => {
254 let field_types = fields.unnamed.iter().map(|field| &field.ty);
255 quote! {
256 ( #( #field_types, )* )
257 }
258 }
259 Fields::Unit => quote!(()),
260 })
261 .collect::<Vec<_>>();
262
263 let variant_generics_sink = variants
264 .iter()
265 .map(|variant| format_ident!("__Sink{}", variant.ident))
266 .collect::<Vec<_>>();
267 let variant_generics_pinned_sink = variant_generics_sink.iter().map(|ident| {
268 quote_spanned! {ident.span()=>
269 ::std::pin::Pin::<&mut #ident>
270 }
271 });
272 let variant_generics_pinned_sink_all = quote! {
273 ( #( #variant_generics_pinned_sink, )* )
274 };
275 let variant_localvars_sink = variants
276 .iter()
277 .map(|variant| {
278 format_ident!(
279 "__sink_{}",
280 variant.ident.to_string().to_lowercase(),
281 span = variant.ident.span()
282 )
283 })
284 .collect::<Vec<_>>();
285
286 let mut full_generics_sink = generics.clone();
287 full_generics_sink.params.extend(
288 variant_generics_sink
289 .iter()
290 .map::<GenericParam, _>(|ident| parse_quote!(#ident)),
291 );
292 full_generics_sink.make_where_clause().predicates.extend(
293 variant_generics_sink
294 .iter()
295 .zip(variant_output_types.iter())
296 .map::<WherePredicate, _>(|(sink_generic, output_type)| {
297 parse_quote! {
298 #sink_generic: #root::futures::sink::Sink<#output_type, Error = #root::Never>
300 }
301 }),
302 );
303
304 let variant_pats_sink_start_send = variants.iter().zip(variant_localvars_sink.iter()).map(
305 |(variant, sinkvar)| {
306 let Variant { ident, fields, .. } = variant;
307 let (fields_pat, push_item) = field_pattern_item(fields);
308 quote! {
309 Self::#ident #fields_pat => ::std::pin::Pin::as_mut(#sinkvar).start_send(#push_item)
310 }
311 },
312 );
313
314 let (impl_generics_item, ty_generics, where_clause_item) = generics.split_for_impl();
315 let (impl_generics_sink, _ty_generics_sink, where_clause_sink) =
316 full_generics_sink.split_for_impl();
317
318 let variant_generics_push = variants
319 .iter()
320 .map(|variant| format_ident!("__Push{}", variant.ident))
321 .collect::<Vec<_>>();
322 let variant_generics_pinned_push = variant_generics_push.iter().map(|ident| {
323 quote_spanned! {ident.span()=>
324 ::std::pin::Pin::<&mut #ident>
325 }
326 });
327 let variant_generics_pinned_push_all = quote! {
328 ( #( #variant_generics_pinned_push, )* )
329 };
330 let variant_localvars_push = variants
331 .iter()
332 .map(|variant| {
333 format_ident!(
334 "__push_{}",
335 variant.ident.to_string().to_lowercase(),
336 span = variant.ident.span()
337 )
338 })
339 .collect::<Vec<_>>();
340
341 let mut full_generics_push = generics.clone();
342 full_generics_push.params.extend(
343 variant_generics_push
344 .iter()
345 .map::<GenericParam, _>(|ident| parse_quote!(#ident)),
346 );
347 full_generics_push.make_where_clause().predicates.extend(
349 variant_generics_push
350 .iter()
351 .zip(variant_output_types.iter())
352 .map::<WherePredicate, _>(|(push_generic, output_type)| {
353 parse_quote! {
354 #push_generic: #root::dfir_pipes::push::Push<#output_type, ()>
355 }
356 }),
357 );
358
359 let ctx_type = variant_generics_push
365 .iter()
366 .zip(variant_output_types.iter())
367 .rev()
368 .map(|(push_generic, output_type)| {
369 quote_spanned! {push_generic.span()=>
370 <#push_generic as #root::dfir_pipes::push::Push<#output_type, ()>>::Ctx<'__ctx>
371 }
372 })
373 .reduce(|rest, next| {
374 quote_spanned! {next.span()=>
375 <#next as #root::dfir_pipes::Context<'__ctx>>::Merged<#rest>
376 }
377 })
378 .unwrap_or_else(|| quote!(()));
379
380 let can_pend = variant_generics_push
381 .iter()
382 .zip(variant_output_types.iter())
383 .rev()
384 .map(|(push_generic, output_type)| {
385 quote_spanned! {push_generic.span()=>
386 <#push_generic as #root::dfir_pipes::push::Push<#output_type, ()>>::CanPend
387 }
388 })
389 .reduce(|rest, next| {
390 quote_spanned! {next.span()=>
391 <#next as #root::dfir_pipes::Toggle>::Or<#rest>
392 }
393 })
394 .unwrap_or_else(|| quote!(#root::dfir_pipes::No));
395
396 let push_poll_unwrap_context = |method_name: Ident| {
399 variant_localvars_push.split_last().map(|(lastvar, headvar)| {
400 quote! {
403 #(
404 let #headvar = {
405 let __ctx = <<#variant_generics_push as #root::dfir_pipes::push::Push<#variant_output_types, ()>>::Ctx<'_> as #root::dfir_pipes::Context<'_>>::unmerge_self(__ctx);
406 #root::dfir_pipes::push::Push::#method_name(::std::pin::Pin::as_mut(#headvar), __ctx)
407 };
408 let __ctx = <<#variant_generics_push as #root::dfir_pipes::push::Push<#variant_output_types, ()>>::Ctx<'_> as #root::dfir_pipes::Context<'_>>::unmerge_other(__ctx);
409 )*
410 let #lastvar = #root::dfir_pipes::push::Push::#method_name(::std::pin::Pin::as_mut(#lastvar), __ctx);
411 #(
413 if #variant_localvars_push.is_pending() {
414 return #root::dfir_pipes::push::PushStep::pending();
415 }
416 )*
417 }
418 })
419 };
420 let push_poll_ready_body = (push_poll_unwrap_context)(format_ident!("poll_ready"));
421 let push_poll_flush_body = (push_poll_unwrap_context)(format_ident!("poll_flush"));
422
423 let variant_pats_push_send =
424 variants
425 .iter()
426 .zip(variant_localvars_push.iter())
427 .map(|(variant, pushvar)| {
428 let Variant { ident, fields, .. } = variant;
429 let (fields_pat, push_item) = field_pattern_item(fields);
430 quote! {
431 Self::#ident #fields_pat => { #root::dfir_pipes::push::Push::start_send(#pushvar.as_mut(), #push_item, __meta); }
432 }
433 });
434
435 let (impl_generics_push, _ty_generics_push, where_clause_push) =
436 full_generics_push.split_for_impl();
437
438 let single_impl = (1 == variants.len()).then(|| {
439 let Variant { ident, fields, .. } = variants.first().unwrap();
440 let (fields_pat, push_item) = field_pattern_item(fields);
441 let out_type = variant_output_types.first().unwrap();
442 quote! {
443 impl #impl_generics_item #root::util::demux_enum::SingleVariant
444 for #item_ident #ty_generics #where_clause_item
445 {
446 type Output = #out_type;
447 fn single_variant(self) -> Self::Output {
448 match self {
449 Self::#ident #fields_pat => #push_item,
450 }
451 }
452 }
453 }
454 });
455
456 quote! {
457 impl #impl_generics_sink #root::util::demux_enum::DemuxEnumSink<#variant_generics_pinned_sink_all>
458 for #item_ident #ty_generics #where_clause_sink
459 {
460 type Error = #root::Never;
461
462 fn poll_ready(
463 ( #( #variant_localvars_sink, )* ): &mut #variant_generics_pinned_sink_all,
464 __cx: &mut ::std::task::Context<'_>,
465 ) -> ::std::task::Poll<::std::result::Result<(), Self::Error>> {
466 #(
468 let #variant_localvars_sink = #variant_localvars_sink.as_mut().poll_ready(__cx)?;
469 )*
470 #(
471 ::std::task::ready!(#variant_localvars_sink);
472 )*
473 ::std::task::Poll::Ready(::std::result::Result::Ok(()))
474 }
475
476 fn start_send(
477 self,
478 ( #( #variant_localvars_sink, )* ): &mut #variant_generics_pinned_sink_all,
479 ) -> ::std::result::Result<(), Self::Error> {
480 match self {
481 #( #variant_pats_sink_start_send, )*
482 }
483 }
484
485 fn poll_flush(
486 ( #( #variant_localvars_sink, )* ): &mut #variant_generics_pinned_sink_all,
487 __cx: &mut ::std::task::Context<'_>,
488 ) -> ::std::task::Poll<::std::result::Result<(), Self::Error>> {
489 #(
491 let #variant_localvars_sink = #variant_localvars_sink.as_mut().poll_flush(__cx)?;
492 )*
493 #(
494 ::std::task::ready!(#variant_localvars_sink);
495 )*
496 ::std::task::Poll::Ready(::std::result::Result::Ok(()))
497 }
498
499 fn poll_close(
500 ( #( #variant_localvars_sink, )* ): &mut #variant_generics_pinned_sink_all,
501 __cx: &mut ::std::task::Context<'_>,
502 ) -> ::std::task::Poll<::std::result::Result<(), Self::Error>> {
503 #(
505 let #variant_localvars_sink = #variant_localvars_sink.as_mut().poll_close(__cx)?;
506 )*
507 #(
508 ::std::task::ready!(#variant_localvars_sink);
509 )*
510 ::std::task::Poll::Ready(::std::result::Result::Ok(()))
511 }
512 }
513
514 impl #impl_generics_push #root::util::demux_enum::DemuxEnumPush<#variant_generics_pinned_push_all, ()>
515 for #item_ident #ty_generics #where_clause_push
516 {
517 type Ctx<'__ctx> = #ctx_type;
518 type CanPend = #can_pend;
519
520 fn poll_ready(
521 ( #( #variant_localvars_push, )* ): &mut #variant_generics_pinned_push_all,
522 __ctx: &mut Self::Ctx<'_>,
523 ) -> #root::dfir_pipes::push::PushStep<Self::CanPend> {
524 #push_poll_ready_body
525 #root::dfir_pipes::push::PushStep::Done
526 }
527
528 fn start_send(
529 self,
530 __meta: (),
531 ( #( #variant_localvars_push, )* ): &mut #variant_generics_pinned_push_all,
532 ) {
533 match self {
534 #( #variant_pats_push_send, )*
535 }
536 }
537
538 fn poll_flush(
539 ( #( #variant_localvars_push, )* ): &mut #variant_generics_pinned_push_all,
540 __ctx: &mut Self::Ctx<'_>,
541 ) -> #root::dfir_pipes::push::PushStep<Self::CanPend> {
542 #push_poll_flush_body
543 #root::dfir_pipes::push::PushStep::Done
544 }
545
546 fn size_hint(
547 ( #( #variant_localvars_push, )* ): &mut #variant_generics_pinned_push_all,
548 __size_hint: (usize, ::std::option::Option<usize>),
549 ) {
550 #(
551 #root::dfir_pipes::push::Push::size_hint(
552 ::std::pin::Pin::as_mut(#variant_localvars_push),
553 __size_hint,
554 );
555 )*
556 }
557 }
558
559 impl #impl_generics_item #root::util::demux_enum::DemuxEnumBase
560 for #item_ident #ty_generics #where_clause_item {}
561
562 #single_impl
563 }
564 .into()
565}
566
567fn field_pattern_item(fields: &Fields) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
569 let idents = fields
570 .iter()
571 .enumerate()
572 .map(|(i, field)| {
573 field
574 .ident
575 .clone()
576 .unwrap_or_else(|| format_ident!("_{}", i))
577 })
578 .collect::<Vec<_>>();
579 let (fields_pat, push_item) = match fields {
580 Fields::Named(_) => (quote!( { #( #idents, )* } ), quote!( ( #( #idents, )* ) )),
581 Fields::Unnamed(_) => (quote!( ( #( #idents ),* ) ), quote!( ( #( #idents, )* ) )),
582 Fields::Unit => (quote!(), quote!(())),
583 };
584 (fields_pat, push_item)
585}