lattices_macro/
lib.rs

1//! Macros for the `lattices` crate.
2//!
//! See [`#[derive(Lattice)]`](Lattice).
4#![warn(missing_docs)]
5
6use proc_macro2::{Span, TokenStream};
7use quote::{format_ident, quote};
8use syn::punctuated::Punctuated;
9use syn::visit_mut::VisitMut;
10use syn::{
11    Field, FieldsNamed, FieldsUnnamed, Generics, Ident, Index, ItemStruct, Member, Token,
12    WhereClause, WherePredicate, parse_macro_input,
13};
14
15/// Tokens to reference the `lattices` crate.
16fn root() -> TokenStream {
17    use std::env::{VarError, var as env_var};
18
19    use proc_macro_crate::FoundCrate;
20
21    if let Ok(FoundCrate::Itself) = proc_macro_crate::crate_name("lattices_macro") {
22        return quote! { lattices };
23    }
24
25    let lattices_crate_name = env!("CARGO_PKG_NAME").strip_suffix("_macro").unwrap();
26    let lattices_crate_ident = lattices_crate_name.replace('-', "_");
27    let lattices_crate = proc_macro_crate::crate_name(lattices_crate_name)
28        .unwrap_or_else(|_| panic!("`{lattices_crate_name}` should be present in `Cargo.toml`"));
29    match lattices_crate {
30        FoundCrate::Itself => {
31            if Err(VarError::NotPresent) == env_var("CARGO_BIN_NAME")
32                && Ok(&*lattices_crate_ident) == env_var("CARGO_CRATE_NAME").as_deref()
33            {
34                // In the crate itself, including unit tests.
35                quote! { crate }
36            } else {
37                // In an integration test, example, bench, etc.
38                let ident = Ident::new(&lattices_crate_ident, Span::call_site());
39                quote! { ::#ident }
40            }
41        }
42        FoundCrate::Name(name) => {
43            let ident = Ident::new(&name, Span::call_site());
44            quote! { ::#ident }
45        }
46    }
47}
48
49/// Renames the generics and returns the updated `WherePredicate`s.
50fn rename_generics(
51    item_struct: &mut ItemStruct,
52    rename: impl FnMut(&Ident) -> Ident,
53) -> Vec<WherePredicate> {
54    struct RenameGenerics<F> {
55        rename: F,
56        names: Vec<Ident>,
57        pub triggered: bool,
58    }
59    impl<F> VisitMut for RenameGenerics<F>
60    where
61        F: FnMut(&Ident) -> Ident,
62    {
63        fn visit_ident_mut(&mut self, i: &mut Ident) {
64            if self.names.contains(i) {
65                *i = (self.rename)(i);
66                self.triggered = true;
67            }
68        }
69    }
70
71    let names = item_struct
72        .generics
73        .type_params()
74        .map(|type_param| type_param.ident.clone())
75        .collect();
76    let mut visit = RenameGenerics {
77        rename,
78        names,
79        triggered: false,
80    };
81
82    let mut out = Vec::new();
83    if let Some(where_clause) = &mut item_struct.generics.where_clause {
84        for where_predicate in where_clause.predicates.iter_mut() {
85            visit.visit_where_predicate_mut(where_predicate);
86            if std::mem::take(&mut visit.triggered) {
87                out.push(where_predicate.clone());
88            }
89        }
90    }
91    for type_param in item_struct.generics.type_params_mut() {
92        visit.visit_type_param_mut(type_param);
93    }
94    for field in item_struct.fields.iter_mut() {
95        visit.visit_type_mut(&mut field.ty);
96    }
97    out
98}
99
100/// Ensures that `punctuated` has trailing punctuation (or is empty).
101fn ensure_trailing<T, P>(punctuated: &mut Punctuated<T, P>)
102where
103    P: Default,
104{
105    if !punctuated.empty_or_trailing() {
106        punctuated.push_punct(Default::default());
107    }
108}
109
110#[doc = include_str!("../README.md")]
111#[proc_macro_derive(Lattice)]
112pub fn derive_lattice_macro(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
113    derive_lattice(&process_item_struct(parse_macro_input!(item))).into()
114}
115/// Derives lattice `Merge`.
116///
117/// See [`#[derive(Lattice)]`](`Lattice`) for more info.
118#[proc_macro_derive(Merge)]
119pub fn derive_merge_macro(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
120    derive_merge(&process_item_struct(parse_macro_input!(item))).into()
121}
122/// Derives [`PartialEq`], [`PartialOrd`], and `LatticeOrd` together.
123///
124/// See [`#[derive(Lattice)]`](`Lattice`) for more info.
125#[proc_macro_derive(LatticeOrd)]
126pub fn derive_lattice_ord_macro(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
127    derive_lattice_ord(&process_item_struct(parse_macro_input!(item))).into()
128}
129/// Derives lattice `IsBot`.
130///
131/// See [`#[derive(Lattice)]`](`Lattice`) for more info.
132#[proc_macro_derive(IsBot)]
133pub fn derive_is_bot_macro(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
134    derive_is_bot(&process_item_struct(parse_macro_input!(item))).into()
135}
136/// Derives lattice `IsTop`.
137///
138/// See [`#[derive(Lattice)]`](`Lattice`) for more info.
139#[proc_macro_derive(IsTop)]
140pub fn derive_is_top_macro(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
141    derive_is_top(&process_item_struct(parse_macro_input!(item))).into()
142}
143/// Derives `LatticeFrom`.
144///
145/// See [`#[derive(Lattice)]`](`Lattice`) for more info.
146#[proc_macro_derive(LatticeFrom)]
147pub fn derive_lattice_from_macro(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
148    derive_lattice_from(&process_item_struct(parse_macro_input!(item))).into()
149}
150
151/// [`process_item_struct`] return value helper struct.
152struct ProcessItemStruct {
153    root: TokenStream,
154    item_struct: ItemStruct,
155    item_struct_renamed: ItemStruct,
156    self_where_predicates: Punctuated<WherePredicate, Token![,]>,
157    both_where_predicates: Punctuated<WherePredicate, Token![,]>,
158    field_names: Vec<Member>,
159    combined_generics: Generics,
160}
161/// Helper for common pre-processing code shared between macros.
162fn process_item_struct(item_struct: ItemStruct) -> ProcessItemStruct {
163    let mut item_struct_renamed = item_struct.clone();
164    let extra_where_predicates = rename_generics(&mut item_struct_renamed, |ident| {
165        format_ident!("__{}Other", ident)
166    });
167
168    // Basic `where` predicates, no extras.
169    let mut self_where_predicates = item_struct
170        .generics
171        .where_clause
172        .clone()
173        .map(|WhereClause { predicates, .. }| predicates)
174        .unwrap_or_default();
175    ensure_trailing(&mut self_where_predicates);
176    // Basic `where` predicates for combined original and renamed parameters.
177    let mut both_where_predicates = self_where_predicates.clone();
178    both_where_predicates.extend(extra_where_predicates);
179    ensure_trailing(&mut both_where_predicates);
180
181    // Fields.
182    let field_names = match &item_struct.fields {
183        syn::Fields::Named(FieldsNamed { named, .. }) => named
184            .iter()
185            .map(|Field { ident, .. }| Member::Named(ident.clone().unwrap()))
186            .collect::<Vec<_>>(),
187        syn::Fields::Unnamed(FieldsUnnamed { unnamed, .. }) => (0..(unnamed.len() as u32))
188            .map(|index| {
189                Member::Unnamed(Index {
190                    index,
191                    span: Span::call_site(),
192                })
193            })
194            .collect(),
195        syn::Fields::Unit => Vec::new(),
196    };
197
198    // Extend the original generics.
199    let mut combined_generics = item_struct.generics.clone();
200    combined_generics
201        .params
202        .extend(item_struct_renamed.generics.params.clone());
203
204    ProcessItemStruct {
205        root: root(),
206        item_struct,
207        item_struct_renamed,
208        self_where_predicates,
209        both_where_predicates,
210        field_names,
211        combined_generics,
212    }
213}
214
215/// See [`derive_lattice_macro`].
216fn derive_lattice(process_item_struct: &ProcessItemStruct) -> TokenStream {
217    let mut out = TokenStream::new();
218    out.extend(derive_merge(process_item_struct));
219    out.extend(derive_lattice_ord(process_item_struct));
220    out.extend(derive_is_bot(process_item_struct));
221    out.extend(derive_is_top(process_item_struct));
222    out.extend(derive_lattice_from(process_item_struct));
223    out
224}
225
226/// See [`derive_merge_macro`].
227fn derive_merge(
228    ProcessItemStruct {
229        root,
230        item_struct,
231        item_struct_renamed,
232        self_where_predicates: _,
233        both_where_predicates,
234        field_names,
235        combined_generics,
236    }: &ProcessItemStruct,
237) -> TokenStream {
238    let merge_where_predicates = item_struct
239        .fields
240        .iter()
241        .zip(item_struct_renamed.fields.iter())
242        .map(|(field_self, field_othr)| {
243            let ty_self = &field_self.ty;
244            let ty_othr = &field_othr.ty;
245            quote! {
246                #ty_self: #root::Merge<#ty_othr>
247            }
248        });
249
250    let ident = &item_struct.ident;
251    let (_, ty_generics_self, _) = item_struct.generics.split_for_impl();
252    let (_, ty_generics_othr, _) = item_struct_renamed.generics.split_for_impl();
253    let (impl_generics_both, _, _) = combined_generics.split_for_impl();
254    quote! {
255        impl #impl_generics_both #root::Merge<#ident #ty_generics_othr> for #ident #ty_generics_self
256        where
257            #both_where_predicates
258            #( #merge_where_predicates ),*
259        {
260            fn merge(&mut self, other: #ident #ty_generics_othr) -> bool {
261                let mut changed = false;
262                #(
263                    changed |= #root::Merge::merge(&mut self.#field_names, other.#field_names);
264                )*
265                changed
266            }
267        }
268    }
269}
270
271/// See [`derive_lattice_ord_macro`].
272fn derive_lattice_ord(
273    ProcessItemStruct {
274        root,
275        item_struct,
276        item_struct_renamed,
277        self_where_predicates: _,
278        both_where_predicates,
279        field_names,
280        combined_generics,
281    }: &ProcessItemStruct,
282) -> TokenStream {
283    // PartialEq.
284    let pareq_where_predicates = item_struct
285        .fields
286        .iter()
287        .zip(item_struct_renamed.fields.iter())
288        .map(|(field_self, field_othr)| {
289            let ty_self = &field_self.ty;
290            let ty_othr = &field_othr.ty;
291            quote! {
292                #ty_self: ::core::cmp::PartialEq<#ty_othr>
293            }
294        });
295    // PartialOrd and LatticeOrd.
296    let compare_where_predicates = item_struct
297        .fields
298        .iter()
299        .zip(item_struct_renamed.fields.iter())
300        .map(|(field_self, field_othr)| {
301            let ty_self = &field_self.ty;
302            let ty_othr = &field_othr.ty;
303            quote! {
304                #ty_self: ::core::cmp::PartialOrd<#ty_othr>
305            }
306        })
307        .collect::<Vec<_>>();
308
309    let ident = &item_struct.ident;
310    let (_, ty_generics_self, _) = item_struct.generics.split_for_impl();
311    let (_, ty_generics_othr, _) = item_struct_renamed.generics.split_for_impl();
312    let (impl_generics_both, _, _) = combined_generics.split_for_impl();
313    quote! {
314        impl #impl_generics_both ::core::cmp::PartialEq<#ident #ty_generics_othr> for #ident #ty_generics_self
315        where
316            #both_where_predicates
317            #( #pareq_where_predicates ),*
318        {
319            fn eq(&self, other: &#ident #ty_generics_othr) -> bool {
320                #(
321                    if !::core::cmp::PartialEq::eq(&self.#field_names, &other.#field_names) {
322                        return false;
323                    }
324                )*
325                true
326            }
327        }
328
329        impl #impl_generics_both ::core::cmp::PartialOrd<#ident #ty_generics_othr> for #ident #ty_generics_self
330        where
331            #both_where_predicates
332            #( #compare_where_predicates ),*
333        {
334            fn partial_cmp(&self, other: &#ident #ty_generics_othr) -> ::core::option::Option<::core::cmp::Ordering> {
335                let mut self_any_greater = false;
336                let mut othr_any_greater = false;
337                #(
338                    // `?` short-circuits `None` (uncomparable).
339                    match ::core::cmp::PartialOrd::partial_cmp(&self.#field_names, &other.#field_names)? {
340                        ::core::cmp::Ordering::Less => {
341                            othr_any_greater = true;
342                        }
343                        ::core::cmp::Ordering::Greater => {
344                            self_any_greater = true;
345                        }
346                        ::core::cmp::Ordering::Equal => {}
347                    }
348                    if self_any_greater && othr_any_greater {
349                        return ::core::option::Option::None;
350                    }
351                )*
352                ::core::option::Option::Some(
353                    match (self_any_greater, othr_any_greater) {
354                        (false, false) => ::core::cmp::Ordering::Equal,
355                        (false, true) => ::core::cmp::Ordering::Less,
356                        (true, false) => ::core::cmp::Ordering::Greater,
357                        (true, true) => ::core::unreachable!(),
358                    }
359                )
360            }
361        }
362        impl #impl_generics_both #root::LatticeOrd<#ident #ty_generics_othr> for #ident #ty_generics_self
363        where
364            #both_where_predicates
365            #( #compare_where_predicates ),*
366        {}
367    }
368}
369
370/// See [`derive_is_bot_macro`].
371fn derive_is_bot(
372    ProcessItemStruct {
373        root,
374        item_struct,
375        item_struct_renamed: _,
376        self_where_predicates,
377        both_where_predicates: _,
378        field_names,
379        combined_generics: _,
380    }: &ProcessItemStruct,
381) -> TokenStream {
382    let isbot_where_predicates = item_struct.fields.iter().map(|Field { ty, .. }| {
383        quote! {
384            #ty: #root::IsBot
385        }
386    });
387
388    let ident = &item_struct.ident;
389    let (impl_generics_self, ty_generics_self, _) = item_struct.generics.split_for_impl();
390    quote! {
391        impl #impl_generics_self #root::IsBot for #ident #ty_generics_self
392        where
393            #self_where_predicates
394            #( #isbot_where_predicates ),*
395        {
396            fn is_bot(&self) -> bool {
397                #(
398                    if !#root::IsBot::is_bot(&self.#field_names) {
399                        return false;
400                    }
401                )*
402                true
403            }
404        }
405    }
406}
407
408/// See [`derive_is_top_macro`].
409fn derive_is_top(
410    ProcessItemStruct {
411        root,
412        item_struct,
413        item_struct_renamed: _,
414        self_where_predicates,
415        both_where_predicates: _,
416        field_names,
417        combined_generics: _,
418    }: &ProcessItemStruct,
419) -> TokenStream {
420    let istop_where_predicates = item_struct.fields.iter().map(|Field { ty, .. }| {
421        quote! {
422            #ty: #root::IsTop
423        }
424    });
425
426    let ident = &item_struct.ident;
427    let (impl_generics_self, ty_generics_self, _) = item_struct.generics.split_for_impl();
428    quote! {
429        impl #impl_generics_self #root::IsTop for #ident #ty_generics_self
430        where
431            #self_where_predicates
432            #( #istop_where_predicates ),*
433        {
434            fn is_top(&self) -> bool {
435                #(
436                    if !#root::IsTop::is_top(&self.#field_names) {
437                        return false;
438                    }
439                )*
440                true
441            }
442        }
443    }
444}
445
446/// See [`derive_lattice_from_macro`].
447fn derive_lattice_from(
448    ProcessItemStruct {
449        root,
450        item_struct,
451        item_struct_renamed,
452        self_where_predicates: _,
453        both_where_predicates,
454        field_names,
455        combined_generics,
456    }: &ProcessItemStruct,
457) -> TokenStream {
458    let latticefrom_where_predicates = item_struct
459        .fields
460        .iter()
461        .zip(item_struct_renamed.fields.iter())
462        .map(|(field_self, field_othr)| {
463            let ty_self = &field_self.ty;
464            let ty_othr = &field_othr.ty;
465            quote! {
466                #ty_self: #root::LatticeFrom<#ty_othr>
467            }
468        });
469
470    let ident = &item_struct.ident;
471    let (_, ty_generics_self, _) = item_struct.generics.split_for_impl();
472    let (_, ty_generics_othr, _) = item_struct_renamed.generics.split_for_impl();
473    let (impl_generics_both, _, _) = combined_generics.split_for_impl();
474    quote! {
475        impl #impl_generics_both #root::LatticeFrom<#ident #ty_generics_othr> for #ident #ty_generics_self
476        where
477            #both_where_predicates
478            #( #latticefrom_where_predicates ),*
479        {
480            fn lattice_from(other: #ident #ty_generics_othr) -> Self {
481                Self {
482                    #(
483                        #field_names: #root::LatticeFrom::lattice_from(other.#field_names),
484                    )*
485                }
486            }
487        }
488    }
489}
490
/// Also see `lattices/tests/macro.rs`
#[cfg(test)]
mod test {
    use syn::parse_quote;

    use super::*;

    /// Snapshots the macro output without actually testing if it compiles.
    /// See `lattices/tests/macro.rs` for compiling tests.
    ///
    /// Steps:
    /// 1. Parses the given tokens as the struct to derive on.
    /// 2. Runs the full `derive_lattice` pipeline (all derives).
    /// 3. Pretty-prints the result with `prettyplease`.
    /// 4. Hands the text to `hydro_build_utils::assert_snapshot!`, which presumably compares it
    ///    against a stored snapshot file.
    macro_rules! assert_derive_snapshots {
        ( $( $t:tt )* ) => {
            {
                let item = parse_quote! {
                    $( $t )*
                };
                let process_item_struct = process_item_struct(item);
                let derive_lattice = derive_lattice(&process_item_struct);
                hydro_build_utils::assert_snapshot!(prettyplease::unparse(&parse_quote! { #derive_lattice }));
            }
        };
    }

    // NOTE(review): snapshot names are presumably derived from these test function names
    // (insta-style). Renaming a test likely requires regenerating its snapshot — confirm.

    /// Generic struct whose fields wrap the type parameters in other lattice types.
    #[test]
    fn derive_example() {
        assert_derive_snapshots! {
            struct MyLattice<KeySet, Epoch> {
                keys: SetUnion<KeySet>,
                epoch: Max<Epoch>,
            }
        };
    }

    /// Generic struct whose fields are the type parameters themselves.
    #[test]
    fn derive_pair() {
        assert_derive_snapshots! {
            pub struct Pair<LatA, LatB> {
                pub a: LatA,
                pub b: LatB,
            }
        };
    }

    /// Non-generic struct whose fields all share one type.
    #[test]
    fn derive_similar_fields() {
        // Will create duplicate where clauses (one identical bound per field), but that is OK.
        assert_derive_snapshots! {
            pub struct SimilarFields {
                a: Max<usize>,
                b: Max<usize>,
                c: Max<usize>,
            }
        };
    }
}