1#![cfg_attr(
2 nightly,
3 feature(proc_macro_diagnostic, proc_macro_span, proc_macro_def_site)
4)]
5
6use dfir_lang::diagnostic::{Diagnostic, Level};
7use dfir_lang::graph::{FlatGraphBuilder, build_hfcode, partition_graph};
8use dfir_lang::parse::DfirCode;
9use proc_macro2::{Ident, Literal, Span};
10use quote::{format_ident, quote};
11use syn::{
12 Attribute, Fields, GenericParam, ItemEnum, Variant, WherePredicate, parse_macro_input,
13 parse_quote,
14};
15
16#[proc_macro]
23pub fn dfir_syntax(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
24 dfir_syntax_internal(input, Some(Level::Help))
25}
26
27#[proc_macro]
31pub fn dfir_syntax_noemit(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
32 dfir_syntax_internal(input, None)
33}
34
35fn root() -> proc_macro2::TokenStream {
36 use std::env::{VarError, var as env_var};
37
38 let root_crate_name = format!(
39 "{}_rs",
40 env!("CARGO_PKG_NAME").strip_suffix("_macro").unwrap()
41 );
42 let root_crate_ident = root_crate_name.replace('-', "_");
43 let root_crate = proc_macro_crate::crate_name(&root_crate_name)
44 .unwrap_or_else(|_| panic!("{root_crate_name} should be present in `Cargo.toml`"));
45 match root_crate {
46 proc_macro_crate::FoundCrate::Itself => {
47 if Err(VarError::NotPresent) == env_var("CARGO_BIN_NAME")
48 && Err(VarError::NotPresent) != env_var("CARGO_PRIMARY_PACKAGE")
49 && Ok(&*root_crate_ident) == env_var("CARGO_CRATE_NAME").as_deref()
50 {
51 quote! { crate }
53 } else {
54 let ident: Ident = Ident::new(&root_crate_ident, Span::call_site());
56 quote! { ::#ident }
57 }
58 }
59 proc_macro_crate::FoundCrate::Name(name) => {
60 let ident = Ident::new(&name, Span::call_site());
61 quote! { ::#ident }
62 }
63 }
64}
65
66fn dfir_syntax_internal(
67 input: proc_macro::TokenStream,
68 min_diagnostic_level: Option<Level>,
69) -> proc_macro::TokenStream {
70 let input = parse_macro_input!(input as DfirCode);
71 let root = root();
72 let (graph_code_opt, diagnostics) = build_hfcode(input, &root);
73 let tokens = graph_code_opt
74 .map(|(_graph, code)| code)
75 .unwrap_or_else(|| quote! { #root::scheduled::graph::Dfir::new() });
76
77 let diagnostics = diagnostics
78 .iter()
79 .filter(|diag: &&Diagnostic| Some(diag.level) <= min_diagnostic_level);
80
81 let diagnostic_tokens = Diagnostic::try_emit_all(diagnostics)
82 .err()
83 .unwrap_or_default();
84 quote! {
85 {
86 #diagnostic_tokens
87 #tokens
88 }
89 }
90 .into()
91}
92
93#[proc_macro]
97pub fn dfir_parser(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
98 let input = parse_macro_input!(input as DfirCode);
99
100 let flat_graph_builder = FlatGraphBuilder::from_dfir(input);
101 let (mut flat_graph, _uses, mut diagnostics) = flat_graph_builder.build();
102 if !diagnostics.iter().any(Diagnostic::is_error) {
103 if let Err(diagnostic) = flat_graph.merge_modules() {
104 diagnostics.push(diagnostic);
105 } else {
106 let flat_mermaid = flat_graph.mermaid_string_flat();
107
108 let part_graph = partition_graph(flat_graph).unwrap();
109 let part_mermaid = part_graph.to_mermaid(&Default::default());
110
111 let lit0 = Literal::string(&flat_mermaid);
112 let lit1 = Literal::string(&part_mermaid);
113
114 return quote! {
115 {
116 println!("{}\n\n{}\n", #lit0, #lit1);
117 }
118 }
119 .into();
120 }
121 }
122
123 Diagnostic::try_emit_all(diagnostics.iter())
124 .err()
125 .unwrap_or_default()
126 .into()
127}
128
129fn wrap_localset(item: proc_macro::TokenStream, attribute: Attribute) -> proc_macro::TokenStream {
130 use quote::ToTokens;
131
132 let root = root();
133
134 let mut input: syn::ItemFn = match syn::parse(item) {
135 Ok(it) => it,
136 Err(e) => return e.into_compile_error().into(),
137 };
138
139 let statements = input.block.stmts;
140
141 input.block.stmts = parse_quote!(
142 #root::tokio::task::LocalSet::new().run_until(async {
143 #( #statements )*
144 }).await
145 );
146
147 input.attrs.push(attribute);
148
149 input.into_token_stream().into()
150}
151
152#[proc_macro]
154pub fn morphism(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
155 item
157}
158
159#[proc_macro]
161pub fn monotonic_fn(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
162 item
164}
165
166#[proc_macro_attribute]
167pub fn dfir_test(
168 args: proc_macro::TokenStream,
169 item: proc_macro::TokenStream,
170) -> proc_macro::TokenStream {
171 let root = root();
172 let args_2: proc_macro2::TokenStream = args.into();
173
174 wrap_localset(
175 item,
176 parse_quote!(
177 #[#root::tokio::test(flavor = "current_thread", #args_2)]
178 ),
179 )
180}
181
182#[proc_macro_attribute]
183pub fn dfir_main(
184 _: proc_macro::TokenStream,
185 item: proc_macro::TokenStream,
186) -> proc_macro::TokenStream {
187 let root = root();
188
189 wrap_localset(
190 item,
191 parse_quote!(
192 #[#root::tokio::main(flavor = "current_thread")]
193 ),
194 )
195}
196
197#[proc_macro_derive(DemuxEnum)]
198pub fn derive_answer_fn(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
199 let root = root();
200
201 let ItemEnum {
202 ident: item_ident,
203 generics,
204 variants,
205 ..
206 } = parse_macro_input!(item as ItemEnum);
207
208 let mut variants = variants.into_iter().collect::<Vec<_>>();
210 variants.sort_by(|a, b| a.ident.cmp(&b.ident));
211
212 let variant_pusherator_generics = variants
213 .iter()
214 .map(|variant| format_ident!("__Pusherator{}", variant.ident))
215 .collect::<Vec<_>>();
216 let variant_pusherator_localvars = variants
217 .iter()
218 .map(|variant| {
219 format_ident!(
220 "__pusherator_{}",
221 variant.ident.to_string().to_lowercase(),
222 span = variant.ident.span()
223 )
224 })
225 .collect::<Vec<_>>();
226 let variant_output_types = variants
227 .iter()
228 .map(|variant| match &variant.fields {
229 Fields::Named(fields) => {
230 let field_types = fields.named.iter().map(|field| &field.ty);
231 quote! {
232 ( #( #field_types, )* )
233 }
234 }
235 Fields::Unnamed(fields) => {
236 let field_types = fields.unnamed.iter().map(|field| &field.ty);
237 quote! {
238 ( #( #field_types, )* )
239 }
240 }
241 Fields::Unit => quote!(()),
242 })
243 .collect::<Vec<_>>();
244
245 let mut full_generics = generics.clone();
246 full_generics.params.extend(
247 variant_pusherator_generics
248 .iter()
249 .map::<GenericParam, _>(|ident| parse_quote!(#ident)),
250 );
251 full_generics.make_where_clause().predicates.extend(
252 variant_pusherator_generics
253 .iter()
254 .zip(variant_output_types.iter())
255 .map::<WherePredicate, _>(|(pusherator_generic, output_type)| {
256 parse_quote! {
257 #pusherator_generic: #root::pusherator::Pusherator<Item = #output_type>
258 }
259 }),
260 );
261
262 let (impl_generics_item, ty_generics, where_clause_item) = generics.split_for_impl();
263 let (impl_generics, _ty_generics, where_clause) = full_generics.split_for_impl();
264
265 let variant_pats = variants
266 .iter()
267 .zip(variant_pusherator_localvars.iter())
268 .map(|(variant, pushvar)| {
269 let Variant { ident, fields, .. } = variant;
270 let (fields_pat, push_item) = field_pattern_item(fields);
271 quote! {
272 Self::#ident #fields_pat => #pushvar.give(#push_item)
273 }
274 });
275
276 let single_impl = (1 == variants.len()).then(|| {
277 let Variant { ident, fields, .. } = variants.first().unwrap();
278 let (fields_pat, push_item) = field_pattern_item(fields);
279 let out_type = variant_output_types.first().unwrap();
280 quote! {
281 impl #impl_generics_item #root::util::demux_enum::SingleVariant
282 for #item_ident #ty_generics #where_clause_item
283 {
284 type Output = #out_type;
285 fn single_variant(self) -> Self::Output {
286 match self {
287 Self::#ident #fields_pat => #push_item,
288 }
289 }
290 }
291 }
292 });
293
294 quote! {
295 impl #impl_generics #root::util::demux_enum::DemuxEnum<( #( #variant_pusherator_generics, )* )>
296 for #item_ident #ty_generics #where_clause
297 {
298 fn demux_enum(
299 self,
300 ( #( #variant_pusherator_localvars, )* ):
301 &mut ( #( #variant_pusherator_generics, )* )
302 ) {
303 match self {
304 #( #variant_pats, )*
305 }
306 }
307 }
308
309 impl #impl_generics_item #root::util::demux_enum::DemuxEnumBase
310 for #item_ident #ty_generics #where_clause_item {}
311
312 #single_impl
313 }
314 .into()
315}
316
317fn field_pattern_item(fields: &Fields) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
319 let idents = fields
320 .iter()
321 .enumerate()
322 .map(|(i, field)| {
323 field
324 .ident
325 .clone()
326 .unwrap_or_else(|| format_ident!("_{}", i))
327 })
328 .collect::<Vec<_>>();
329 let (fields_pat, push_item) = match fields {
330 Fields::Named(_) => (quote!( { #( #idents, )* } ), quote!( ( #( #idents, )* ) )),
331 Fields::Unnamed(_) => (quote!( ( #( #idents ),* ) ), quote!( ( #( #idents, )* ) )),
332 Fields::Unit => (quote!(), quote!(())),
333 };
334 (fields_pat, push_item)
335}