1#![cfg_attr(
2 nightly,
3 feature(proc_macro_diagnostic, proc_macro_span, proc_macro_def_site)
4)]
5
6use dfir_lang::diagnostic::{Diagnostic, Level};
7use dfir_lang::graph::{FlatGraphBuilder, build_hfcode, partition_graph};
8use dfir_lang::parse::DfirCode;
9use proc_macro2::{Ident, Literal, Span};
10use quote::{format_ident, quote};
11use syn::{
12 Attribute, Fields, GenericParam, ItemEnum, Variant, WherePredicate, parse_macro_input,
13 parse_quote,
14};
15
16#[proc_macro]
23pub fn dfir_syntax(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
24 dfir_syntax_internal(input, Some(Level::Help))
25}
26
27#[proc_macro]
31pub fn dfir_syntax_noemit(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
32 dfir_syntax_internal(input, None)
33}
34
35fn root() -> proc_macro2::TokenStream {
36 use std::env::{VarError, var as env_var};
37
38 let root_crate =
39 proc_macro_crate::crate_name("dfir_rs").expect("dfir_rs should be present in `Cargo.toml`");
40 match root_crate {
41 proc_macro_crate::FoundCrate::Itself => {
42 if Err(VarError::NotPresent) == env_var("CARGO_BIN_NAME")
43 && Err(VarError::NotPresent) != env_var("CARGO_PRIMARY_PACKAGE")
44 && Ok("dfir_rs") == env_var("CARGO_CRATE_NAME").as_deref()
45 {
46 quote! { crate }
48 } else {
49 quote! { ::dfir_rs }
51 }
52 }
53 proc_macro_crate::FoundCrate::Name(name) => {
54 let ident: Ident = Ident::new(&name, Span::call_site());
55 quote! { ::#ident }
56 }
57 }
58}
59
60fn dfir_syntax_internal(
61 input: proc_macro::TokenStream,
62 min_diagnostic_level: Option<Level>,
63) -> proc_macro::TokenStream {
64 let input = parse_macro_input!(input as DfirCode);
65 let root = root();
66 let (graph_code_opt, diagnostics) = build_hfcode(input, &root);
67 let tokens = graph_code_opt
68 .map(|(_graph, code)| code)
69 .unwrap_or_else(|| quote! { #root::scheduled::graph::Dfir::new() });
70
71 let diagnostics = diagnostics
72 .iter()
73 .filter(|diag: &&Diagnostic| Some(diag.level) <= min_diagnostic_level);
74
75 let diagnostic_tokens = Diagnostic::try_emit_all(diagnostics)
76 .err()
77 .unwrap_or_default();
78 quote! {
79 {
80 #diagnostic_tokens
81 #tokens
82 }
83 }
84 .into()
85}
86
87#[proc_macro]
91pub fn dfir_parser(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
92 let input = parse_macro_input!(input as DfirCode);
93
94 let flat_graph_builder = FlatGraphBuilder::from_dfir(input);
95 let (mut flat_graph, _uses, mut diagnostics) = flat_graph_builder.build();
96 if !diagnostics.iter().any(Diagnostic::is_error) {
97 if let Err(diagnostic) = flat_graph.merge_modules() {
98 diagnostics.push(diagnostic);
99 } else {
100 let flat_mermaid = flat_graph.mermaid_string_flat();
101
102 let part_graph = partition_graph(flat_graph).unwrap();
103 let part_mermaid = part_graph.to_mermaid(&Default::default());
104
105 let lit0 = Literal::string(&flat_mermaid);
106 let lit1 = Literal::string(&part_mermaid);
107
108 return quote! {
109 {
110 println!("{}\n\n{}\n", #lit0, #lit1);
111 }
112 }
113 .into();
114 }
115 }
116
117 Diagnostic::try_emit_all(diagnostics.iter())
118 .err()
119 .unwrap_or_default()
120 .into()
121}
122
123fn wrap_localset(item: proc_macro::TokenStream, attribute: Attribute) -> proc_macro::TokenStream {
124 use quote::ToTokens;
125
126 let root = root();
127
128 let mut input: syn::ItemFn = match syn::parse(item) {
129 Ok(it) => it,
130 Err(e) => return e.into_compile_error().into(),
131 };
132
133 let statements = input.block.stmts;
134
135 input.block.stmts = parse_quote!(
136 #root::tokio::task::LocalSet::new().run_until(async {
137 #( #statements )*
138 }).await
139 );
140
141 input.attrs.push(attribute);
142
143 input.into_token_stream().into()
144}
145
146#[proc_macro]
148pub fn morphism(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
149 item
151}
152
153#[proc_macro]
155pub fn monotonic_fn(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
156 item
158}
159
160#[proc_macro_attribute]
161pub fn dfir_test(
162 args: proc_macro::TokenStream,
163 item: proc_macro::TokenStream,
164) -> proc_macro::TokenStream {
165 let root = root();
166 let args_2: proc_macro2::TokenStream = args.into();
167
168 wrap_localset(
169 item,
170 parse_quote!(
171 #[#root::tokio::test(flavor = "current_thread", #args_2)]
172 ),
173 )
174}
175
176#[proc_macro_attribute]
177pub fn dfir_main(
178 _: proc_macro::TokenStream,
179 item: proc_macro::TokenStream,
180) -> proc_macro::TokenStream {
181 let root = root();
182
183 wrap_localset(
184 item,
185 parse_quote!(
186 #[#root::tokio::main(flavor = "current_thread")]
187 ),
188 )
189}
190
191#[proc_macro_derive(DemuxEnum)]
192pub fn derive_answer_fn(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
193 let root = root();
194
195 let ItemEnum {
196 ident: item_ident,
197 generics,
198 variants,
199 ..
200 } = parse_macro_input!(item as ItemEnum);
201
202 let mut variants = variants.into_iter().collect::<Vec<_>>();
204 variants.sort_by(|a, b| a.ident.cmp(&b.ident));
205
206 let variant_pusherator_generics = variants
207 .iter()
208 .map(|variant| format_ident!("__Pusherator{}", variant.ident))
209 .collect::<Vec<_>>();
210 let variant_pusherator_localvars = variants
211 .iter()
212 .map(|variant| {
213 format_ident!(
214 "__pusherator_{}",
215 variant.ident.to_string().to_lowercase(),
216 span = variant.ident.span()
217 )
218 })
219 .collect::<Vec<_>>();
220 let variant_output_types = variants
221 .iter()
222 .map(|variant| match &variant.fields {
223 Fields::Named(fields) => {
224 let field_types = fields.named.iter().map(|field| &field.ty);
225 quote! {
226 ( #( #field_types, )* )
227 }
228 }
229 Fields::Unnamed(fields) => {
230 let field_types = fields.unnamed.iter().map(|field| &field.ty);
231 quote! {
232 ( #( #field_types, )* )
233 }
234 }
235 Fields::Unit => quote!(()),
236 })
237 .collect::<Vec<_>>();
238
239 let mut full_generics = generics.clone();
240 full_generics.params.extend(
241 variant_pusherator_generics
242 .iter()
243 .map::<GenericParam, _>(|ident| parse_quote!(#ident)),
244 );
245 full_generics.make_where_clause().predicates.extend(
246 variant_pusherator_generics
247 .iter()
248 .zip(variant_output_types.iter())
249 .map::<WherePredicate, _>(|(pusherator_generic, output_type)| {
250 parse_quote! {
251 #pusherator_generic: #root::pusherator::Pusherator<Item = #output_type>
252 }
253 }),
254 );
255
256 let (impl_generics_item, ty_generics, where_clause_item) = generics.split_for_impl();
257 let (impl_generics, _ty_generics, where_clause) = full_generics.split_for_impl();
258
259 let variant_pats = variants
260 .iter()
261 .zip(variant_pusherator_localvars.iter())
262 .map(|(variant, pushvar)| {
263 let Variant { ident, fields, .. } = variant;
264 let (fields_pat, push_item) = field_pattern_item(fields);
265 quote! {
266 Self::#ident #fields_pat => #pushvar.give(#push_item)
267 }
268 });
269
270 let single_impl = (1 == variants.len()).then(|| {
271 let Variant { ident, fields, .. } = variants.first().unwrap();
272 let (fields_pat, push_item) = field_pattern_item(fields);
273 let out_type = variant_output_types.first().unwrap();
274 quote! {
275 impl #impl_generics_item #root::util::demux_enum::SingleVariant
276 for #item_ident #ty_generics #where_clause_item
277 {
278 type Output = #out_type;
279 fn single_variant(self) -> Self::Output {
280 match self {
281 Self::#ident #fields_pat => #push_item,
282 }
283 }
284 }
285 }
286 });
287
288 quote! {
289 impl #impl_generics #root::util::demux_enum::DemuxEnum<( #( #variant_pusherator_generics, )* )>
290 for #item_ident #ty_generics #where_clause
291 {
292 fn demux_enum(
293 self,
294 ( #( #variant_pusherator_localvars, )* ):
295 &mut ( #( #variant_pusherator_generics, )* )
296 ) {
297 match self {
298 #( #variant_pats, )*
299 }
300 }
301 }
302
303 impl #impl_generics_item #root::util::demux_enum::DemuxEnumBase
304 for #item_ident #ty_generics #where_clause_item {}
305
306 #single_impl
307 }
308 .into()
309}
310
311fn field_pattern_item(fields: &Fields) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
313 let idents = fields
314 .iter()
315 .enumerate()
316 .map(|(i, field)| {
317 field
318 .ident
319 .clone()
320 .unwrap_or_else(|| format_ident!("_{}", i))
321 })
322 .collect::<Vec<_>>();
323 let (fields_pat, push_item) = match fields {
324 Fields::Named(_) => (quote!( { #( #idents, )* } ), quote!( ( #( #idents, )* ) )),
325 Fields::Unnamed(_) => (quote!( ( #( #idents ),* ) ), quote!( ( #( #idents, )* ) )),
326 Fields::Unit => (quote!(), quote!(())),
327 };
328 (fields_pat, push_item)
329}