1use dfir_lang::diagnostic::Level;
2use dfir_lang::graph::{
3 BuildDfirCodeOutput, FlatGraphBuilder, FlatGraphBuilderOutput, build_dfir_code, partition_graph,
4};
5use dfir_lang::parse::DfirCode;
6use proc_macro2::{Ident, Literal, Span};
7use quote::{format_ident, quote, quote_spanned};
8use syn::spanned::Spanned;
9use syn::{
10 Attribute, Fields, GenericParam, ItemEnum, Variant, WherePredicate, parse_macro_input,
11 parse_quote,
12};
13
14#[proc_macro]
21pub fn dfir_syntax(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
22 dfir_syntax_internal(input, Some(Level::Help))
23}
24
25#[proc_macro]
29pub fn dfir_syntax_noemit(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
30 dfir_syntax_internal(input, None)
31}
32
33fn root() -> proc_macro2::TokenStream {
34 use std::env::{VarError, var as env_var};
35
36 let root_crate_name = format!(
37 "{}_rs",
38 env!("CARGO_PKG_NAME").strip_suffix("_macro").unwrap()
39 );
40 let root_crate_ident = root_crate_name.replace('-', "_");
41 let root_crate = proc_macro_crate::crate_name(&root_crate_name)
42 .unwrap_or_else(|_| panic!("{root_crate_name} should be present in `Cargo.toml`"));
43 match root_crate {
44 proc_macro_crate::FoundCrate::Itself => {
45 if Err(VarError::NotPresent) == env_var("CARGO_BIN_NAME")
46 && Err(VarError::NotPresent) != env_var("CARGO_PRIMARY_PACKAGE")
47 && Ok(&*root_crate_ident) == env_var("CARGO_CRATE_NAME").as_deref()
48 {
49 quote! { crate }
51 } else {
52 let ident: Ident = Ident::new(&root_crate_ident, Span::call_site());
54 quote! { ::#ident }
55 }
56 }
57 proc_macro_crate::FoundCrate::Name(name) => {
58 let ident = Ident::new(&name, Span::call_site());
59 quote! { ::#ident }
60 }
61 }
62}
63
64fn dfir_syntax_internal(
65 input: proc_macro::TokenStream,
66 retain_diagnostic_level: Option<Level>,
67) -> proc_macro::TokenStream {
68 let input = parse_macro_input!(input as DfirCode);
69 let root = root();
70
71 let (code, mut diagnostics) = match build_dfir_code(input, &root) {
72 Ok(BuildDfirCodeOutput {
73 partitioned_graph: _,
74 code,
75 diagnostics,
76 }) => (code, diagnostics),
77 Err(diagnostics) => (
78 quote! {
79 {
80 #root::scheduled::context::Dfir::new(
81 #root::scheduled::context::NullTickClosure,
82 <#root::scheduled::context::Context as ::std::default::Default>::default(),
83 None,
84 None,
85 )
86 }
87 },
88 diagnostics,
89 ),
90 };
91
92 let diagnostic_tokens = retain_diagnostic_level.and_then(|level| {
93 diagnostics.retain_level(level);
94 diagnostics.try_emit_all().err()
95 });
96
97 quote! {
98 {
99 #diagnostic_tokens
100 #code
101 }
102 }
103 .into()
104}
105
106#[proc_macro]
110pub fn dfir_parser(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
111 let input = parse_macro_input!(input as DfirCode);
112
113 let flat_graph_builder = FlatGraphBuilder::from_dfir(input);
114 let err_diagnostics = 'err: {
115 let (mut flat_graph, mut diagnostics) = match flat_graph_builder.build() {
116 Ok(FlatGraphBuilderOutput {
117 flat_graph,
118 uses: _,
119 diagnostics,
120 }) => (flat_graph, diagnostics),
121 Err(diagnostics) => {
122 break 'err diagnostics;
123 }
124 };
125
126 if let Err(diagnostic) = flat_graph.merge_modules() {
127 diagnostics.push(diagnostic);
128 break 'err diagnostics;
129 }
130
131 let flat_mermaid = flat_graph.mermaid_string_flat();
132
133 let part_mermaid = partition_graph(flat_graph)
134 .map(|part_graph| part_graph.to_mermaid(&Default::default()))
135 .unwrap_or_else(|err| format!("failed to partition: {err}"));
136
137 let lit0 = Literal::string(&flat_mermaid);
138 let lit1 = Literal::string(&part_mermaid);
139
140 return quote! {
141 {
142 println!("{}\n\n{}\n", #lit0, #lit1);
143 }
144 }
145 .into();
146 };
147
148 err_diagnostics
149 .try_emit_all()
150 .err()
151 .unwrap_or_default()
152 .into()
153}
154
155fn wrap_localset(item: proc_macro::TokenStream, attribute: Attribute) -> proc_macro::TokenStream {
156 use quote::ToTokens;
157
158 let root = root();
159
160 let mut input: syn::ItemFn = match syn::parse(item) {
161 Ok(it) => it,
162 Err(e) => return e.into_compile_error().into(),
163 };
164
165 let statements = input.block.stmts;
166
167 input.block.stmts = parse_quote!(
168 #root::tokio::task::LocalSet::new().run_until(async {
169 #( #statements )*
170 }).await
171 );
172
173 input.attrs.push(attribute);
174
175 input.into_token_stream().into()
176}
177
178#[proc_macro]
180pub fn morphism(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
181 item
183}
184
185#[proc_macro]
187pub fn monotonic_fn(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
188 item
190}
191
192#[proc_macro_attribute]
193pub fn dfir_test(
194 args: proc_macro::TokenStream,
195 item: proc_macro::TokenStream,
196) -> proc_macro::TokenStream {
197 let root = root();
198 let args_2: proc_macro2::TokenStream = args.into();
199
200 wrap_localset(
201 item,
202 parse_quote!(
203 #[#root::tokio::test(flavor = "current_thread", #args_2)]
204 ),
205 )
206}
207
208#[proc_macro_attribute]
209pub fn dfir_main(
210 _: proc_macro::TokenStream,
211 item: proc_macro::TokenStream,
212) -> proc_macro::TokenStream {
213 let root = root();
214
215 wrap_localset(
216 item,
217 parse_quote!(
218 #[#root::tokio::main(flavor = "current_thread")]
219 ),
220 )
221}
222
223#[proc_macro_derive(DemuxEnum)]
224pub fn derive_demux_enum(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
225 let root = root();
226
227 let ItemEnum {
228 ident: item_ident,
229 generics,
230 variants,
231 ..
232 } = parse_macro_input!(item as ItemEnum);
233
234 let mut variants = variants.into_iter().collect::<Vec<_>>();
236 variants.sort_by(|a, b| a.ident.cmp(&b.ident));
237
238 let variant_output_types = variants
240 .iter()
241 .map(|variant| match &variant.fields {
242 Fields::Named(fields) => {
243 let field_types = fields.named.iter().map(|field| &field.ty);
244 quote! {
245 ( #( #field_types, )* )
246 }
247 }
248 Fields::Unnamed(fields) => {
249 let field_types = fields.unnamed.iter().map(|field| &field.ty);
250 quote! {
251 ( #( #field_types, )* )
252 }
253 }
254 Fields::Unit => quote!(()),
255 })
256 .collect::<Vec<_>>();
257
258 let variant_generics_sink = variants
259 .iter()
260 .map(|variant| format_ident!("__Sink{}", variant.ident))
261 .collect::<Vec<_>>();
262 let variant_generics_pinned_sink = variant_generics_sink.iter().map(|ident| {
263 quote_spanned! {ident.span()=>
264 ::std::pin::Pin::<&mut #ident>
265 }
266 });
267 let variant_generics_pinned_sink_all = quote! {
268 ( #( #variant_generics_pinned_sink, )* )
269 };
270 let variant_localvars_sink = variants
271 .iter()
272 .map(|variant| {
273 format_ident!(
274 "__sink_{}",
275 variant.ident.to_string().to_lowercase(),
276 span = variant.ident.span()
277 )
278 })
279 .collect::<Vec<_>>();
280
281 let mut full_generics_sink = generics.clone();
282 full_generics_sink.params.extend(
283 variant_generics_sink
284 .iter()
285 .map::<GenericParam, _>(|ident| parse_quote!(#ident)),
286 );
287 full_generics_sink.make_where_clause().predicates.extend(
288 variant_generics_sink
289 .iter()
290 .zip(variant_output_types.iter())
291 .map::<WherePredicate, _>(|(sink_generic, output_type)| {
292 parse_quote! {
293 #sink_generic: #root::futures::sink::Sink<#output_type, Error = #root::Never>
295 }
296 }),
297 );
298
299 let variant_pats_sink_start_send = variants.iter().zip(variant_localvars_sink.iter()).map(
300 |(variant, sinkvar)| {
301 let Variant { ident, fields, .. } = variant;
302 let (fields_pat, push_item) = field_pattern_item(fields);
303 quote! {
304 Self::#ident #fields_pat => ::std::pin::Pin::as_mut(#sinkvar).start_send(#push_item)
305 }
306 },
307 );
308
309 let (impl_generics_item, ty_generics, where_clause_item) = generics.split_for_impl();
310 let (impl_generics_sink, _ty_generics_sink, where_clause_sink) =
311 full_generics_sink.split_for_impl();
312
313 let variant_generics_push = variants
314 .iter()
315 .map(|variant| format_ident!("__Push{}", variant.ident))
316 .collect::<Vec<_>>();
317 let variant_generics_pinned_push = variant_generics_push.iter().map(|ident| {
318 quote_spanned! {ident.span()=>
319 ::std::pin::Pin::<&mut #ident>
320 }
321 });
322 let variant_generics_pinned_push_all = quote! {
323 ( #( #variant_generics_pinned_push, )* )
324 };
325 let variant_localvars_push = variants
326 .iter()
327 .map(|variant| {
328 format_ident!(
329 "__push_{}",
330 variant.ident.to_string().to_lowercase(),
331 span = variant.ident.span()
332 )
333 })
334 .collect::<Vec<_>>();
335
336 let mut full_generics_push = generics.clone();
337 full_generics_push.params.extend(
338 variant_generics_push
339 .iter()
340 .map::<GenericParam, _>(|ident| parse_quote!(#ident)),
341 );
342 full_generics_push.make_where_clause().predicates.extend(
344 variant_generics_push
345 .iter()
346 .zip(variant_output_types.iter())
347 .map::<WherePredicate, _>(|(push_generic, output_type)| {
348 parse_quote! {
349 #push_generic: #root::dfir_pipes::push::Push<#output_type, ()>
350 }
351 }),
352 );
353
354 let ctx_type = variant_generics_push
360 .iter()
361 .zip(variant_output_types.iter())
362 .rev()
363 .map(|(push_generic, output_type)| {
364 quote_spanned! {push_generic.span()=>
365 <#push_generic as #root::dfir_pipes::push::Push<#output_type, ()>>::Ctx<'__ctx>
366 }
367 })
368 .reduce(|rest, next| {
369 quote_spanned! {next.span()=>
370 <#next as #root::dfir_pipes::Context<'__ctx>>::Merged<#rest>
371 }
372 })
373 .unwrap_or_else(|| quote!(()));
374
375 let can_pend = variant_generics_push
376 .iter()
377 .zip(variant_output_types.iter())
378 .rev()
379 .map(|(push_generic, output_type)| {
380 quote_spanned! {push_generic.span()=>
381 <#push_generic as #root::dfir_pipes::push::Push<#output_type, ()>>::CanPend
382 }
383 })
384 .reduce(|rest, next| {
385 quote_spanned! {next.span()=>
386 <#next as #root::dfir_pipes::Toggle>::Or<#rest>
387 }
388 })
389 .unwrap_or_else(|| quote!(#root::dfir_pipes::No));
390
391 let push_poll_unwrap_context = |method_name: Ident| {
394 variant_localvars_push.split_last().map(|(lastvar, headvar)| {
395 quote! {
398 #(
399 let #headvar = {
400 let __ctx = <<#variant_generics_push as #root::dfir_pipes::push::Push<#variant_output_types, ()>>::Ctx<'_> as #root::dfir_pipes::Context<'_>>::unmerge_self(__ctx);
401 #root::dfir_pipes::push::Push::#method_name(::std::pin::Pin::as_mut(#headvar), __ctx)
402 };
403 let __ctx = <<#variant_generics_push as #root::dfir_pipes::push::Push<#variant_output_types, ()>>::Ctx<'_> as #root::dfir_pipes::Context<'_>>::unmerge_other(__ctx);
404 )*
405 let #lastvar = #root::dfir_pipes::push::Push::#method_name(::std::pin::Pin::as_mut(#lastvar), __ctx);
406 #(
408 if #variant_localvars_push.is_pending() {
409 return #root::dfir_pipes::push::PushStep::pending();
410 }
411 )*
412 }
413 })
414 };
415 let push_poll_ready_body = (push_poll_unwrap_context)(format_ident!("poll_ready"));
416 let push_poll_finalize_body = (push_poll_unwrap_context)(format_ident!("poll_finalize"));
417
418 let variant_pats_push_send =
419 variants
420 .iter()
421 .zip(variant_localvars_push.iter())
422 .map(|(variant, pushvar)| {
423 let Variant { ident, fields, .. } = variant;
424 let (fields_pat, push_item) = field_pattern_item(fields);
425 quote! {
426 Self::#ident #fields_pat => { #root::dfir_pipes::push::Push::start_send(#pushvar.as_mut(), #push_item, __meta); }
427 }
428 });
429
430 let (impl_generics_push, _ty_generics_push, where_clause_push) =
431 full_generics_push.split_for_impl();
432
433 let single_impl = (1 == variants.len()).then(|| {
434 let Variant { ident, fields, .. } = variants.first().unwrap();
435 let (fields_pat, push_item) = field_pattern_item(fields);
436 let out_type = variant_output_types.first().unwrap();
437 quote! {
438 impl #impl_generics_item #root::util::demux_enum::SingleVariant
439 for #item_ident #ty_generics #where_clause_item
440 {
441 type Output = #out_type;
442 fn single_variant(self) -> Self::Output {
443 match self {
444 Self::#ident #fields_pat => #push_item,
445 }
446 }
447 }
448 }
449 });
450
451 quote! {
452 impl #impl_generics_sink #root::util::demux_enum::DemuxEnumSink<#variant_generics_pinned_sink_all>
453 for #item_ident #ty_generics #where_clause_sink
454 {
455 type Error = #root::Never;
456
457 fn poll_ready(
458 ( #( #variant_localvars_sink, )* ): &mut #variant_generics_pinned_sink_all,
459 __cx: &mut ::std::task::Context<'_>,
460 ) -> ::std::task::Poll<::std::result::Result<(), Self::Error>> {
461 #(
463 let #variant_localvars_sink = #variant_localvars_sink.as_mut().poll_ready(__cx)?;
464 )*
465 #(
466 ::std::task::ready!(#variant_localvars_sink);
467 )*
468 ::std::task::Poll::Ready(::std::result::Result::Ok(()))
469 }
470
471 fn start_send(
472 self,
473 ( #( #variant_localvars_sink, )* ): &mut #variant_generics_pinned_sink_all,
474 ) -> ::std::result::Result<(), Self::Error> {
475 match self {
476 #( #variant_pats_sink_start_send, )*
477 }
478 }
479
480 fn poll_flush(
481 ( #( #variant_localvars_sink, )* ): &mut #variant_generics_pinned_sink_all,
482 __cx: &mut ::std::task::Context<'_>,
483 ) -> ::std::task::Poll<::std::result::Result<(), Self::Error>> {
484 #(
486 let #variant_localvars_sink = #variant_localvars_sink.as_mut().poll_flush(__cx)?;
487 )*
488 #(
489 ::std::task::ready!(#variant_localvars_sink);
490 )*
491 ::std::task::Poll::Ready(::std::result::Result::Ok(()))
492 }
493
494 fn poll_close(
495 ( #( #variant_localvars_sink, )* ): &mut #variant_generics_pinned_sink_all,
496 __cx: &mut ::std::task::Context<'_>,
497 ) -> ::std::task::Poll<::std::result::Result<(), Self::Error>> {
498 #(
500 let #variant_localvars_sink = #variant_localvars_sink.as_mut().poll_close(__cx)?;
501 )*
502 #(
503 ::std::task::ready!(#variant_localvars_sink);
504 )*
505 ::std::task::Poll::Ready(::std::result::Result::Ok(()))
506 }
507 }
508
509 impl #impl_generics_push #root::util::demux_enum::DemuxEnumPush<#variant_generics_pinned_push_all, ()>
510 for #item_ident #ty_generics #where_clause_push
511 {
512 type Ctx<'__ctx> = #ctx_type;
513 type CanPend = #can_pend;
514
515 fn poll_ready(
516 ( #( #variant_localvars_push, )* ): &mut #variant_generics_pinned_push_all,
517 __ctx: &mut Self::Ctx<'_>,
518 ) -> #root::dfir_pipes::push::PushStep<Self::CanPend> {
519 #push_poll_ready_body
520 #root::dfir_pipes::push::PushStep::Done
521 }
522
523 fn start_send(
524 self,
525 __meta: (),
526 ( #( #variant_localvars_push, )* ): &mut #variant_generics_pinned_push_all,
527 ) {
528 match self {
529 #( #variant_pats_push_send, )*
530 }
531 }
532
533 fn poll_finalize(
534 ( #( #variant_localvars_push, )* ): &mut #variant_generics_pinned_push_all,
535 __ctx: &mut Self::Ctx<'_>,
536 ) -> #root::dfir_pipes::push::PushStep<Self::CanPend> {
537 #push_poll_finalize_body
538 #root::dfir_pipes::push::PushStep::Done
539 }
540
541 fn size_hint(
542 ( #( #variant_localvars_push, )* ): &mut #variant_generics_pinned_push_all,
543 __size_hint: (usize, ::std::option::Option<usize>),
544 ) {
545 #(
546 #root::dfir_pipes::push::Push::size_hint(
547 ::std::pin::Pin::as_mut(#variant_localvars_push),
548 __size_hint,
549 );
550 )*
551 }
552 }
553
554 impl #impl_generics_item #root::util::demux_enum::DemuxEnumBase
555 for #item_ident #ty_generics #where_clause_item {}
556
557 #single_impl
558 }
559 .into()
560}
561
562fn field_pattern_item(fields: &Fields) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
564 let idents = fields
565 .iter()
566 .enumerate()
567 .map(|(i, field)| {
568 field
569 .ident
570 .clone()
571 .unwrap_or_else(|| format_ident!("_{}", i))
572 })
573 .collect::<Vec<_>>();
574 let (fields_pat, push_item) = match fields {
575 Fields::Named(_) => (quote!( { #( #idents, )* } ), quote!( ( #( #idents, )* ) )),
576 Fields::Unnamed(_) => (quote!( ( #( #idents ),* ) ), quote!( ( #( #idents, )* ) )),
577 Fields::Unit => (quote!(), quote!(())),
578 };
579 (fields_pat, push_item)
580}