1#![cfg_attr(
2 nightly,
3 feature(proc_macro_diagnostic, proc_macro_span, proc_macro_def_site)
4)]
5
6use dfir_lang::diagnostic::{Diagnostic, Level};
7use dfir_lang::graph::{FlatGraphBuilder, build_hfcode, partition_graph};
8use dfir_lang::parse::DfirCode;
9use proc_macro2::{Ident, Literal, Span};
10use quote::{format_ident, quote, quote_spanned};
11use syn::{
12 Attribute, Fields, GenericParam, ItemEnum, Variant, WherePredicate, parse_macro_input,
13 parse_quote,
14};
15
16#[proc_macro]
23pub fn dfir_syntax(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
24 dfir_syntax_internal(input, Some(Level::Help))
25}
26
27#[proc_macro]
31pub fn dfir_syntax_noemit(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
32 dfir_syntax_internal(input, None)
33}
34
35fn root() -> proc_macro2::TokenStream {
36 use std::env::{VarError, var as env_var};
37
38 let root_crate_name = format!(
39 "{}_rs",
40 env!("CARGO_PKG_NAME").strip_suffix("_macro").unwrap()
41 );
42 let root_crate_ident = root_crate_name.replace('-', "_");
43 let root_crate = proc_macro_crate::crate_name(&root_crate_name)
44 .unwrap_or_else(|_| panic!("{root_crate_name} should be present in `Cargo.toml`"));
45 match root_crate {
46 proc_macro_crate::FoundCrate::Itself => {
47 if Err(VarError::NotPresent) == env_var("CARGO_BIN_NAME")
48 && Err(VarError::NotPresent) != env_var("CARGO_PRIMARY_PACKAGE")
49 && Ok(&*root_crate_ident) == env_var("CARGO_CRATE_NAME").as_deref()
50 {
51 quote! { crate }
53 } else {
54 let ident: Ident = Ident::new(&root_crate_ident, Span::call_site());
56 quote! { ::#ident }
57 }
58 }
59 proc_macro_crate::FoundCrate::Name(name) => {
60 let ident = Ident::new(&name, Span::call_site());
61 quote! { ::#ident }
62 }
63 }
64}
65
66fn dfir_syntax_internal(
67 input: proc_macro::TokenStream,
68 min_diagnostic_level: Option<Level>,
69) -> proc_macro::TokenStream {
70 let input = parse_macro_input!(input as DfirCode);
71 let root = root();
72 let (graph_code_opt, diagnostics) = build_hfcode(input, &root);
73 let tokens = graph_code_opt
74 .map(|(_graph, code)| code)
75 .unwrap_or_else(|| quote! { #root::scheduled::graph::Dfir::new() });
76
77 let diagnostics = diagnostics
78 .iter()
79 .filter(|diag: &&Diagnostic| Some(diag.level) <= min_diagnostic_level);
80
81 let diagnostic_tokens = Diagnostic::try_emit_all(diagnostics)
82 .err()
83 .unwrap_or_default();
84 quote! {
85 {
86 #diagnostic_tokens
87 #tokens
88 }
89 }
90 .into()
91}
92
93#[proc_macro]
97pub fn dfir_parser(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
98 let input = parse_macro_input!(input as DfirCode);
99
100 let flat_graph_builder = FlatGraphBuilder::from_dfir(input);
101 let (mut flat_graph, _uses, mut diagnostics) = flat_graph_builder.build();
102 if !diagnostics.iter().any(Diagnostic::is_error) {
103 if let Err(diagnostic) = flat_graph.merge_modules() {
104 diagnostics.push(diagnostic);
105 } else {
106 let flat_mermaid = flat_graph.mermaid_string_flat();
107
108 let part_graph = partition_graph(flat_graph).unwrap();
109 let part_mermaid = part_graph.to_mermaid(&Default::default());
110
111 let lit0 = Literal::string(&flat_mermaid);
112 let lit1 = Literal::string(&part_mermaid);
113
114 return quote! {
115 {
116 println!("{}\n\n{}\n", #lit0, #lit1);
117 }
118 }
119 .into();
120 }
121 }
122
123 Diagnostic::try_emit_all(diagnostics.iter())
124 .err()
125 .unwrap_or_default()
126 .into()
127}
128
129fn wrap_localset(item: proc_macro::TokenStream, attribute: Attribute) -> proc_macro::TokenStream {
130 use quote::ToTokens;
131
132 let root = root();
133
134 let mut input: syn::ItemFn = match syn::parse(item) {
135 Ok(it) => it,
136 Err(e) => return e.into_compile_error().into(),
137 };
138
139 let statements = input.block.stmts;
140
141 input.block.stmts = parse_quote!(
142 #root::tokio::task::LocalSet::new().run_until(async {
143 #( #statements )*
144 }).await
145 );
146
147 input.attrs.push(attribute);
148
149 input.into_token_stream().into()
150}
151
152#[proc_macro]
154pub fn morphism(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
155 item
157}
158
159#[proc_macro]
161pub fn monotonic_fn(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
162 item
164}
165
166#[proc_macro_attribute]
167pub fn dfir_test(
168 args: proc_macro::TokenStream,
169 item: proc_macro::TokenStream,
170) -> proc_macro::TokenStream {
171 let root = root();
172 let args_2: proc_macro2::TokenStream = args.into();
173
174 wrap_localset(
175 item,
176 parse_quote!(
177 #[#root::tokio::test(flavor = "current_thread", #args_2)]
178 ),
179 )
180}
181
182#[proc_macro_attribute]
183pub fn dfir_main(
184 _: proc_macro::TokenStream,
185 item: proc_macro::TokenStream,
186) -> proc_macro::TokenStream {
187 let root = root();
188
189 wrap_localset(
190 item,
191 parse_quote!(
192 #[#root::tokio::main(flavor = "current_thread")]
193 ),
194 )
195}
196
197#[proc_macro_derive(DemuxEnum)]
198pub fn derive_demux_enum(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
199 let root = root();
200
201 let ItemEnum {
202 ident: item_ident,
203 generics,
204 variants,
205 ..
206 } = parse_macro_input!(item as ItemEnum);
207
208 let mut variants = variants.into_iter().collect::<Vec<_>>();
210 variants.sort_by(|a, b| a.ident.cmp(&b.ident));
211
212 let variant_output_types = variants
214 .iter()
215 .map(|variant| match &variant.fields {
216 Fields::Named(fields) => {
217 let field_types = fields.named.iter().map(|field| &field.ty);
218 quote! {
219 ( #( #field_types, )* )
220 }
221 }
222 Fields::Unnamed(fields) => {
223 let field_types = fields.unnamed.iter().map(|field| &field.ty);
224 quote! {
225 ( #( #field_types, )* )
226 }
227 }
228 Fields::Unit => quote!(()),
229 })
230 .collect::<Vec<_>>();
231
232 let variant_generics_sink = variants
233 .iter()
234 .map(|variant| format_ident!("__Sink{}", variant.ident))
235 .collect::<Vec<_>>();
236 let variant_generics_pinned_sink = variant_generics_sink.iter().map(|ident| {
237 quote_spanned! {ident.span()=>
238 ::std::pin::Pin::<&mut #ident>
239 }
240 });
241 let variant_generics_pinned_sink_all = quote! {
242 ( #( #variant_generics_pinned_sink, )* )
243 };
244 let variant_localvars_sink = variants
245 .iter()
246 .map(|variant| {
247 format_ident!(
248 "__sink_{}",
249 variant.ident.to_string().to_lowercase(),
250 span = variant.ident.span()
251 )
252 })
253 .collect::<Vec<_>>();
254
255 let mut full_generics_sink = generics.clone();
256 full_generics_sink.params.extend(
257 variant_generics_sink
258 .iter()
259 .map::<GenericParam, _>(|ident| parse_quote!(#ident)),
260 );
261 full_generics_sink.make_where_clause().predicates.extend(
262 variant_generics_sink
263 .iter()
264 .zip(variant_output_types.iter())
265 .map::<WherePredicate, _>(|(sink_generic, output_type)| {
266 parse_quote! {
267 #sink_generic: #root::futures::sink::Sink<#output_type, Error = #root::Never>
269 }
270 }),
271 );
272
273 let variant_pats_sink_start_send =
274 variants
275 .iter()
276 .zip(variant_localvars_sink.iter())
277 .map(|(variant, sinkvar)| {
278 let Variant { ident, fields, .. } = variant;
279 let (fields_pat, push_item) = field_pattern_item(fields);
280 quote! {
281 Self::#ident #fields_pat => #sinkvar.as_mut().start_send(#push_item)
282 }
283 });
284
285 let (impl_generics_item, ty_generics, where_clause_item) = generics.split_for_impl();
286 let (impl_generics_sink, _ty_generics_sink, where_clause_sink) =
287 full_generics_sink.split_for_impl();
288
289 let single_impl = (1 == variants.len()).then(|| {
290 let Variant { ident, fields, .. } = variants.first().unwrap();
291 let (fields_pat, push_item) = field_pattern_item(fields);
292 let out_type = variant_output_types.first().unwrap();
293 quote! {
294 impl #impl_generics_item #root::util::demux_enum::SingleVariant
295 for #item_ident #ty_generics #where_clause_item
296 {
297 type Output = #out_type;
298 fn single_variant(self) -> Self::Output {
299 match self {
300 Self::#ident #fields_pat => #push_item,
301 }
302 }
303 }
304 }
305 });
306
307 quote! {
308 impl #impl_generics_sink #root::util::demux_enum::DemuxEnumSink<#variant_generics_pinned_sink_all>
309 for #item_ident #ty_generics #where_clause_sink
310 {
311 type Error = #root::Never;
312
313 fn poll_ready(
314 ( #( #variant_localvars_sink, )* ): &mut #variant_generics_pinned_sink_all,
315 __cx: &mut ::std::task::Context<'_>,
316 ) -> ::std::task::Poll<::std::result::Result<(), Self::Error>> {
317 #(
319 let #variant_localvars_sink = #variant_localvars_sink.as_mut().poll_ready(__cx)?;
320 )*
321 #(
322 ::std::task::ready!(#variant_localvars_sink);
323 )*
324 ::std::task::Poll::Ready(::std::result::Result::Ok(()))
325 }
326
327 fn start_send(
328 self,
329 ( #( #variant_localvars_sink, )* ): &mut #variant_generics_pinned_sink_all,
330 ) -> ::std::result::Result<(), Self::Error> {
331 match self {
332 #( #variant_pats_sink_start_send, )*
333 }
334 }
335
336 fn poll_flush(
337 ( #( #variant_localvars_sink, )* ): &mut #variant_generics_pinned_sink_all,
338 __cx: &mut ::std::task::Context<'_>,
339 ) -> ::std::task::Poll<::std::result::Result<(), Self::Error>> {
340 #(
342 let #variant_localvars_sink = #variant_localvars_sink.as_mut().poll_flush(__cx)?;
343 )*
344 #(
345 ::std::task::ready!(#variant_localvars_sink);
346 )*
347 ::std::task::Poll::Ready(::std::result::Result::Ok(()))
348 }
349
350 fn poll_close(
351 ( #( #variant_localvars_sink, )* ): &mut #variant_generics_pinned_sink_all,
352 __cx: &mut ::std::task::Context<'_>,
353 ) -> ::std::task::Poll<::std::result::Result<(), Self::Error>> {
354 #(
356 let #variant_localvars_sink = #variant_localvars_sink.as_mut().poll_close(__cx)?;
357 )*
358 #(
359 ::std::task::ready!(#variant_localvars_sink);
360 )*
361 ::std::task::Poll::Ready(::std::result::Result::Ok(()))
362 }
363 }
364
365 impl #impl_generics_item #root::util::demux_enum::DemuxEnumBase
366 for #item_ident #ty_generics #where_clause_item {}
367
368 #single_impl
369 }
370 .into()
371}
372
373fn field_pattern_item(fields: &Fields) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
375 let idents = fields
376 .iter()
377 .enumerate()
378 .map(|(i, field)| {
379 field
380 .ident
381 .clone()
382 .unwrap_or_else(|| format_ident!("_{}", i))
383 })
384 .collect::<Vec<_>>();
385 let (fields_pat, push_item) = match fields {
386 Fields::Named(_) => (quote!( { #( #idents, )* } ), quote!( ( #( #idents, )* ) )),
387 Fields::Unnamed(_) => (quote!( ( #( #idents ),* ) ), quote!( ( #( #idents, )* ) )),
388 Fields::Unit => (quote!(), quote!(())),
389 };
390 (fields_pat, push_item)
391}