multiplatform_test/
lib.rs
1#![doc = include_str!("../README.md")]
2
3use proc_macro::TokenStream;
4use quote::{quote, quote_spanned};
5
/// A test runner or test-environment setup that can be selected in
/// `#[multiplatform_test(...)]`.
///
/// Runner variants contribute a test attribute; the logging variants contribute only
/// initialization code that runs at the start of the test body.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Platform {
    /// The built-in `#[test]` harness.
    Default,
    /// `#[tokio::test]`.
    Tokio,
    /// `#[async_std::test]`.
    AsyncStd,
    /// `#[dfir_rs::test]`.
    Dfir,
    /// `#[wasm_bindgen_test::wasm_bindgen_test]`.
    Wasm,
    /// Initializes `env_logger` at the start of the test; adds no attribute.
    EnvLogging,
    /// Installs a `tracing_subscriber` fmt subscriber at the start of the test; adds no attribute.
    EnvTracing,
}
16impl Platform {
17 const ALL: [Self; 7] = [
19 Self::Default,
20 Self::Tokio,
21 Self::AsyncStd,
22 Self::Dfir,
23 Self::Wasm,
24 Self::EnvLogging,
25 Self::EnvTracing,
26 ];
27 const DEFAULT: [Self; 2] = [Self::Default, Self::Wasm];
29
30 const fn name(self) -> &'static str {
32 match self {
33 Self::Default => "test",
34 Self::Tokio => "tokio",
35 Self::AsyncStd => "async_std",
36 Self::Dfir => "dfir",
37 Self::Wasm => "wasm",
38 Self::EnvLogging => "env_logging",
39 Self::EnvTracing => "env_tracing",
40 }
41 }
42
43 fn make_attribute(self) -> proc_macro2::TokenStream {
45 match self {
50 Platform::Default => quote! { #[test] },
51 Platform::Tokio => quote! { #[tokio::test ] },
52 Platform::AsyncStd => quote! { #[async_std::test] },
53 Platform::Dfir => quote! { #[dfir_rs::test] },
54 Platform::Wasm => quote! { #[wasm_bindgen_test::wasm_bindgen_test] },
55 Platform::EnvLogging | Platform::EnvTracing => Default::default(),
56 }
57 }
58
59 fn make_init_code(self) -> proc_macro2::TokenStream {
61 match self {
62 Platform::EnvLogging => quote! {
63 let _ = env_logger::builder().is_test(true).try_init();
64 },
65 Platform::EnvTracing => quote! {
66 let subscriber = tracing_subscriber::FmtSubscriber::builder()
67 .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
68 .with_test_writer()
69 .finish();
70 let _ = tracing::subscriber::set_global_default(subscriber);
71 },
72 _ => Default::default(),
73 }
74 }
75}
76
77#[proc_macro_attribute]
79pub fn multiplatform_test(attr: TokenStream, body: TokenStream) -> TokenStream {
80 let ts = multiplatform_test_impl(
81 proc_macro2::TokenStream::from(attr),
82 proc_macro2::TokenStream::from(body),
83 );
84 TokenStream::from(ts)
85}
86
87fn multiplatform_test_impl(
88 attr: proc_macro2::TokenStream,
89 body: proc_macro2::TokenStream,
90) -> proc_macro2::TokenStream {
91 let mut attr = attr.into_iter();
92 let mut platforms = Vec::<Platform>::new();
93
94 while let Some(token) = attr.next() {
95 let proc_macro2::TokenTree::Ident(i) = &token else {
96 return quote_spanned! {token.span()=>
97 compile_error!("malformed #[multiplatform_test] attribute; expected identifier.");
98 };
99 };
100 let name = i.to_string();
101 let Some(&platform) = Platform::ALL
102 .iter()
103 .find(|platform| name == platform.name())
104 else {
105 let msg = proc_macro2::Literal::string(&format!(
106 "unknown platform {}; expected one of [{}]",
107 name,
108 Platform::ALL.map(Platform::name).join(", "),
109 ));
110 return quote_spanned! {token.span()=> compile_error!(#msg); };
111 };
112 platforms.push(platform);
113
114 match &attr.next() {
115 Some(proc_macro2::TokenTree::Punct(op)) if op.as_char() == ',' => {}
116 Some(other) => {
117 return quote_spanned! {other.span()=>
118 compile_error!("malformed `#[multiplatform_test]` attribute; expected `,`.");
119 };
120 }
121 None => break,
122 }
123 }
124 if platforms.is_empty() {
125 platforms.extend(Platform::DEFAULT.iter());
126 }
127
128 let mut output = proc_macro2::TokenStream::new();
129 let mut init_code = proc_macro2::TokenStream::new();
130
131 for p in platforms {
132 output.extend(p.make_attribute());
133 init_code.extend(p.make_init_code());
134 }
135
136 if init_code.is_empty() {
137 output.extend(body);
138 } else {
139 let mut body_head = body.into_iter().collect::<Vec<_>>();
140 let Some(proc_macro2::TokenTree::Group(body_code)) = body_head.pop() else {
141 panic!();
142 };
143
144 output.extend(body_head);
145 output.extend(quote! {
146 {
147 { #init_code };
148 #body_code
149 }
150 });
151 }
152 output
153}
154
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs the macro implementation with `attrs` on the item `fn test() { }` and returns the
    /// expansion as a string.
    fn expand(attrs: proc_macro2::TokenStream) -> String {
        let test_fn = quote! { fn test() { } };
        multiplatform_test_impl(attrs, test_fn).to_string()
    }

    /// Asserts that expanding with `attrs` yields a `compile_error!` and returns the
    /// expansion for further checks.
    fn expand_error(attrs: proc_macro2::TokenStream) -> String {
        let got = expand(attrs);
        assert!(got.starts_with("compile_error !"), "expected compile_error, got: {got}");
        got
    }

    #[test]
    fn test_default_platforms() {
        let want = quote! {
            #[test]
            #[wasm_bindgen_test::wasm_bindgen_test]
            fn test() { }
        };
        assert_eq!(want.to_string(), expand(proc_macro2::TokenStream::new()));
    }

    #[test]
    fn test_host_platform() {
        let want = quote! {
            #[test]
            fn test() { }
        };
        assert_eq!(want.to_string(), expand(quote! { test }));
    }

    #[test]
    fn test_wasm_platform() {
        let want = quote! {
            #[wasm_bindgen_test::wasm_bindgen_test]
            fn test() { }
        };
        assert_eq!(want.to_string(), expand(quote! { wasm }));
    }

    #[test]
    fn test_host_wasm_platform() {
        let want = quote! {
            #[test]
            #[wasm_bindgen_test::wasm_bindgen_test]
            fn test() { }
        };
        assert_eq!(want.to_string(), expand(quote! { test, wasm }));
    }

    #[test]
    fn test_trailing_comma() {
        let want = quote! {
            #[test]
            fn test() { }
        };
        assert_eq!(want.to_string(), expand(quote! { test, }));
    }

    #[test]
    fn test_env_logging_wraps_body() {
        // Init code runs in its own block before the original body block.
        let want = quote! {
            #[test]
            fn test() {
                { let _ = env_logger::builder().is_test(true).try_init(); };
                { }
            }
        };
        assert_eq!(want.to_string(), expand(quote! { test, env_logging }));
    }

    #[test]
    fn test_unknown_platform() {
        let got = expand_error(quote! { hello });
        assert!(got.contains("unknown platform hello"), "got: {got}");
    }

    #[test]
    fn test_invalid_attr_nocomma_platform() {
        expand_error(quote! { wasm() });
    }

    #[test]
    fn test_invalid_attr_noident_platform() {
        expand_error(quote! { () });
    }
}