multiplatform_test/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use proc_macro::TokenStream;
4use quote::{quote, quote_spanned};
5
/// One entry that may appear in the `#[multiplatform_test(...)]` attribute list.
///
/// Most variants select a test-harness attribute to emit on the annotated function.
/// The `Env*` variants emit no attribute; they instead contribute initialization
/// statements that are placed ahead of the test body.
///
/// `Debug`, `PartialEq` and `Eq` are derived so values can be printed and compared
/// (for example with `assert_eq!`).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Platform {
    /// Written `test` in the attribute; emits `#[test]`.
    Default,
    /// Written `tokio` in the attribute; emits `#[tokio::test]`.
    Tokio,
    /// Written `async_std` in the attribute; emits `#[async_std::test]`.
    AsyncStd,
    /// Written `dfir` in the attribute; emits `#[dfir_rs::test]`.
    Dfir,
    /// Written `wasm` in the attribute; emits `#[wasm_bindgen_test::wasm_bindgen_test]`.
    Wasm,
    /// Written `env_logging` in the attribute; emits no attribute, only `env_logger`
    /// initialization code.
    EnvLogging,
    /// Written `env_tracing` in the attribute; emits no attribute, only
    /// `tracing_subscriber` initialization code.
    EnvTracing,
}
16impl Platform {
17    // All platforms.
18    const ALL: [Self; 7] = [
19        Self::Default,
20        Self::Tokio,
21        Self::AsyncStd,
22        Self::Dfir,
23        Self::Wasm,
24        Self::EnvLogging,
25        Self::EnvTracing,
26    ];
27    // Default when no platforms are specified.
28    const DEFAULT: [Self; 2] = [Self::Default, Self::Wasm];
29
30    /// Name of platform ident in attribute.
31    const fn name(self) -> &'static str {
32        match self {
33            Self::Default => "test",
34            Self::Tokio => "tokio",
35            Self::AsyncStd => "async_std",
36            Self::Dfir => "dfir",
37            Self::Wasm => "wasm",
38            Self::EnvLogging => "env_logging",
39            Self::EnvTracing => "env_tracing",
40        }
41    }
42
43    /// Generate the attribute for this platform (if any).
44    fn make_attribute(self) -> proc_macro2::TokenStream {
45        // Fully specify crate names so that the consumer does not need to add another
46        // use statement. They still need to depend on the crate in their `Cargo.toml`,
47        // though.
48        // TODO(mingwei): use `proc_macro_crate::crate_name(...)` to handle renames.
49        match self {
50            Platform::Default => quote! { #[test] },
51            Platform::Tokio => quote! { #[tokio::test ] },
52            Platform::AsyncStd => quote! { #[async_std::test] },
53            Platform::Dfir => quote! { #[dfir_rs::test] },
54            Platform::Wasm => quote! { #[wasm_bindgen_test::wasm_bindgen_test] },
55            Platform::EnvLogging | Platform::EnvTracing => Default::default(),
56        }
57    }
58
59    /// Generate the initialization code statements for this platform (if any).
60    fn make_init_code(self) -> proc_macro2::TokenStream {
61        match self {
62            Platform::EnvLogging => quote! {
63                let _ = env_logger::builder().is_test(true).try_init();
64            },
65            Platform::EnvTracing => quote! {
66                let subscriber = tracing_subscriber::FmtSubscriber::builder()
67                    .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
68                    .with_test_writer()
69                    .finish();
70                let _ = tracing::subscriber::set_global_default(subscriber);
71            },
72            _ => Default::default(),
73        }
74    }
75}
76
77/// See the [crate] docs for usage information.
78#[proc_macro_attribute]
79pub fn multiplatform_test(attr: TokenStream, body: TokenStream) -> TokenStream {
80    let ts = multiplatform_test_impl(
81        proc_macro2::TokenStream::from(attr),
82        proc_macro2::TokenStream::from(body),
83    );
84    TokenStream::from(ts)
85}
86
87fn multiplatform_test_impl(
88    attr: proc_macro2::TokenStream,
89    body: proc_macro2::TokenStream,
90) -> proc_macro2::TokenStream {
91    let mut attr = attr.into_iter();
92    let mut platforms = Vec::<Platform>::new();
93
94    while let Some(token) = attr.next() {
95        let proc_macro2::TokenTree::Ident(i) = &token else {
96            return quote_spanned! {token.span()=>
97                compile_error!("malformed #[multiplatform_test] attribute; expected identifier.");
98            };
99        };
100        let name = i.to_string();
101        let Some(&platform) = Platform::ALL
102            .iter()
103            .find(|platform| name == platform.name())
104        else {
105            let msg = proc_macro2::Literal::string(&format!(
106                "unknown platform {}; expected one of [{}]",
107                name,
108                Platform::ALL.map(Platform::name).join(", "),
109            ));
110            return quote_spanned! {token.span()=> compile_error!(#msg); };
111        };
112        platforms.push(platform);
113
114        match &attr.next() {
115            Some(proc_macro2::TokenTree::Punct(op)) if op.as_char() == ',' => {}
116            Some(other) => {
117                return quote_spanned! {other.span()=>
118                    compile_error!("malformed `#[multiplatform_test]` attribute; expected `,`.");
119                };
120            }
121            None => break,
122        }
123    }
124    if platforms.is_empty() {
125        platforms.extend(Platform::DEFAULT.iter());
126    }
127
128    let mut output = proc_macro2::TokenStream::new();
129    let mut init_code = proc_macro2::TokenStream::new();
130
131    for p in platforms {
132        output.extend(p.make_attribute());
133        init_code.extend(p.make_init_code());
134    }
135
136    if init_code.is_empty() {
137        output.extend(body);
138    } else {
139        let mut body_head = body.into_iter().collect::<Vec<_>>();
140        let Some(proc_macro2::TokenTree::Group(body_code)) = body_head.pop() else {
141            panic!();
142        };
143
144        output.extend(body_head);
145        output.extend(quote! {
146            {
147                { #init_code };
148                #body_code
149            }
150        });
151    }
152    output
153}
154
/// Unit tests for `multiplatform_test_impl`.
///
/// Expansions are compared through their string form, since token streams do not
/// implement equality.
#[cfg(test)]
mod tests {
    use super::*;

    /// An empty attribute list expands to the default platforms (`test` and `wasm`).
    #[test]
    fn test_default_platforms() {
        let test_fn: proc_macro2::TokenStream = quote! { fn test() { } };
        let attrs = proc_macro2::TokenStream::new();
        let got: proc_macro2::TokenStream = multiplatform_test_impl(attrs, test_fn);
        let want = quote! {
            #[test]
            #[wasm_bindgen_test::wasm_bindgen_test]
            fn test() { }
        };

        assert_eq!(want.to_string(), got.to_string());
    }

    /// `test` alone emits only `#[test]`.
    #[test]
    fn test_host_platform() {
        let test_fn = quote! { fn test() { } };
        let attrs = quote! { test };
        let got = multiplatform_test_impl(attrs, test_fn);
        let want = quote! {
            #[test]
            fn test() { }
        };

        assert_eq!(want.to_string(), got.to_string());
    }

    /// `wasm` alone emits only the wasm-bindgen test attribute.
    #[test]
    fn test_wasm_platform() {
        let test_fn = quote! { fn test() { } };
        let attrs = quote! { wasm };
        let got = multiplatform_test_impl(attrs, test_fn);
        let want = quote! {
            #[wasm_bindgen_test::wasm_bindgen_test]
            fn test() { }
        };

        assert_eq!(want.to_string(), got.to_string());
    }

    /// Multiple platforms emit their attributes in the order listed.
    #[test]
    fn test_host_wasm_platform() {
        let test_fn = quote! { fn test() { } };
        let attrs = quote! { test, wasm };
        let got = multiplatform_test_impl(attrs, test_fn);
        let want = quote! {
            #[test]
            #[wasm_bindgen_test::wasm_bindgen_test]
            fn test() { }
        };

        assert_eq!(want.to_string(), got.to_string());
    }

    /// A trailing comma after the last platform is accepted.
    #[test]
    fn test_trailing_comma() {
        let test_fn = quote! { fn test() { } };
        let attrs = quote! { test, };
        let got = multiplatform_test_impl(attrs, test_fn);
        let want = quote! {
            #[test]
            fn test() { }
        };

        assert_eq!(want.to_string(), got.to_string());
    }

    /// `env_logging` adds no attribute but wraps the body so the logger
    /// initialization runs before the original block.
    #[test]
    fn test_env_logging_init_code() {
        let test_fn = quote! { fn test() { } };
        let attrs = quote! { test, env_logging };
        let got = multiplatform_test_impl(attrs, test_fn);
        let want = quote! {
            #[test]
            fn test() {
                { let _ = env_logger::builder().is_test(true).try_init(); };
                { }
            }
        };

        assert_eq!(want.to_string(), got.to_string());
    }

    /// An identifier that is not a known platform is a compile error.
    #[test]
    fn test_unknown_platform() {
        let test_fn = quote! { fn test() { } };
        let attrs = quote! { hello };
        let got = multiplatform_test_impl(attrs, test_fn);
        assert!(got.to_string().starts_with("compile_error !"));
    }

    /// A platform followed by something other than `,` is a compile error.
    #[test]
    fn test_invalid_attr_nocomma_platform() {
        let test_fn = quote! { fn test() { } };
        let attrs = quote! { wasm() };
        let got = multiplatform_test_impl(attrs, test_fn);
        assert!(got.to_string().starts_with("compile_error !"));
    }

    /// Two platforms with no separating comma are a compile error.
    #[test]
    fn test_invalid_attr_missing_separator() {
        let test_fn = quote! { fn test() { } };
        let attrs = quote! { test wasm };
        let got = multiplatform_test_impl(attrs, test_fn);
        assert!(got.to_string().starts_with("compile_error !"));
    }

    /// A non-identifier where a platform name is expected is a compile error.
    #[test]
    fn test_invalid_attr_noident_platform() {
        let test_fn = quote! { fn test() { } };
        let attrs = quote! { () };
        let got = multiplatform_test_impl(attrs, test_fn);
        assert!(got.to_string().starts_with("compile_error !"));
    }
}