dfir_lang/
process_singletons.rs

1//! Utility methods for processing singleton references: `#my_var`.
2
3use itertools::Itertools;
4use proc_macro2::{Group, Ident, TokenStream, TokenTree};
5use quote::quote_spanned;
6use syn::punctuated::Punctuated;
7use syn::{Expr, Token};
8
9use crate::parse::parse_terminated;
10
11/// Finds all the singleton references `#my_var` and appends them to `found_idents`. Returns the
12/// `TokenStream` but with the hashes removed from the varnames.
13///
14/// The returned tokens are used for "preflight" parsing, to check that the rest of the syntax is
15/// OK. However the returned tokens are not used in the codegen as we need to use [`postprocess_singletons`]
16/// later to substitute-in the context referencing code for each singleton
17pub fn preprocess_singletons(tokens: TokenStream, found_idents: &mut Vec<Ident>) -> TokenStream {
18    process_singletons(tokens, &mut |singleton_ident| {
19        found_idents.push(singleton_ident.clone());
20        TokenTree::Ident(singleton_ident)
21    })
22}
23
24/// Replaces singleton references `#my_var` with the code needed to actually get the value inside.
25///
26/// * `tokens` - The tokens to update singleton references within.
27/// * `resolved_idents` - The context `StateHandle` varnames that correspond 1:1 and in the same
28///   order as the singleton references within `tokens` (found in-order via [`preprocess_singletons`]).
29///
30/// Generates borrowing code ([`std::cell::RefCell::borrow_mut`]). Use
31/// [`postprocess_singletons_handles`] for just the `StateHandle`s.
32pub fn postprocess_singletons(
33    tokens: TokenStream,
34    resolved_idents: impl IntoIterator<Item = Ident>,
35    context: &Ident,
36) -> Punctuated<Expr, Token![,]> {
37    let mut resolved_idents_iter = resolved_idents.into_iter();
38    let processed = process_singletons(tokens, &mut |singleton_ident| {
39        let span = singleton_ident.span();
40        let context = Ident::new(&context.to_string(), span.resolved_at(context.span()));
41        let mut resolved_ident = resolved_idents_iter.next().unwrap();
42        resolved_ident.set_span(span);
43        let mut group = Group::new(
44            proc_macro2::Delimiter::Parenthesis,
45            quote_spanned! {span=>
46                *(unsafe {
47                    // SAFETY: `handle` is from this instance.
48                    #context.state_ref_unchecked(#resolved_ident)
49                }.borrow_mut())
50            },
51        );
52        group.set_span(singleton_ident.span());
53        TokenTree::Group(group)
54    });
55    parse_terminated(processed).unwrap()
56}
57
58/// Same as [`postprocess_singletons`] but generates just the `StateHandle` ident rather than full
59/// `RefCell` borrowing code.
60pub fn postprocess_singletons_handles(
61    tokens: TokenStream,
62    resolved_idents: impl IntoIterator<Item = Ident>,
63) -> Punctuated<Expr, Token![,]> {
64    let mut resolved_idents_iter = resolved_idents.into_iter();
65    let processed = process_singletons(tokens, &mut |singleton_ident| {
66        let mut resolved_ident = resolved_idents_iter.next().unwrap();
67        resolved_ident.set_span(singleton_ident.span().resolved_at(resolved_ident.span()));
68        TokenTree::Ident(resolved_ident)
69    });
70    parse_terminated(processed).unwrap()
71}
72
73/// Traverse the token stream, applying the `map_singleton_fn` whenever a singleton is found,
74/// returning the transformed token stream.
75fn process_singletons(
76    tokens: TokenStream,
77    map_singleton_fn: &mut impl FnMut(Ident) -> TokenTree,
78) -> TokenStream {
79    tokens
80        .into_iter()
81        .peekable()
82        .batching(|iter| {
83            let out = match iter.next()? {
84                TokenTree::Group(group) => {
85                    let mut new_group = Group::new(
86                        group.delimiter(),
87                        process_singletons(group.stream(), map_singleton_fn),
88                    );
89                    new_group.set_span(group.span());
90                    TokenTree::Group(new_group)
91                }
92                TokenTree::Ident(ident) => TokenTree::Ident(ident),
93                TokenTree::Punct(punct) => {
94                    if '#' == punct.as_char() && matches!(iter.peek(), Some(TokenTree::Ident(_))) {
95                        // Found a singleton.
96                        let Some(TokenTree::Ident(mut singleton_ident)) = iter.next() else {
97                            unreachable!()
98                        };
99                        {
100                            // Include the `#` in the span.
101                            let span = singleton_ident
102                                .span()
103                                .join(punct.span())
104                                .unwrap_or(singleton_ident.span());
105                            singleton_ident.set_span(span.resolved_at(singleton_ident.span()));
106                        }
107                        (map_singleton_fn)(singleton_ident)
108                    } else {
109                        TokenTree::Punct(punct)
110                    }
111                }
112                TokenTree::Literal(lit) => TokenTree::Literal(lit),
113            };
114            Some(out)
115        })
116        .collect()
117}