lattices_macro/
lib.rs

1//! Macros for the `lattices` crate.
2//!
//! See [`#[derive(Lattice)]`](Lattice).
4#![warn(missing_docs)]
5
6use proc_macro2::{Span, TokenStream};
7use quote::{format_ident, quote};
8use syn::punctuated::Punctuated;
9use syn::visit_mut::VisitMut;
10use syn::{
11    Field, FieldsNamed, FieldsUnnamed, Generics, Ident, Index, ItemStruct, Member, Token,
12    WhereClause, WherePredicate, parse_macro_input,
13};
14
15/// Tokens to reference the `lattices` crate.
16fn root() -> TokenStream {
17    use std::env::{VarError, var as env_var};
18
19    use proc_macro_crate::FoundCrate;
20
21    if let Ok(FoundCrate::Itself) = proc_macro_crate::crate_name("lattices_macro") {
22        return quote! { lattices };
23    }
24
25    let lattices_crate = proc_macro_crate::crate_name("lattices")
26        .expect("`lattices` should be present in `Cargo.toml`");
27    match lattices_crate {
28        FoundCrate::Itself => {
29            if Err(VarError::NotPresent) == env_var("CARGO_BIN_NAME")
30                && Ok("lattices") == env_var("CARGO_CRATE_NAME").as_deref()
31            {
32                // In the crate itself, including unit tests.
33                quote! { crate }
34            } else {
35                // In an integration test, example, bench, etc.
36                quote! { ::lattices }
37            }
38        }
39        FoundCrate::Name(name) => {
40            let ident = Ident::new(&name, Span::call_site());
41            quote! { ::#ident }
42        }
43    }
44}
45
46/// Renames the generics and returns the updated `WherePredicate`s.
47fn rename_generics(
48    item_struct: &mut ItemStruct,
49    rename: impl FnMut(&Ident) -> Ident,
50) -> Vec<WherePredicate> {
51    struct RenameGenerics<F> {
52        rename: F,
53        names: Vec<Ident>,
54        pub triggered: bool,
55    }
56    impl<F> VisitMut for RenameGenerics<F>
57    where
58        F: FnMut(&Ident) -> Ident,
59    {
60        fn visit_ident_mut(&mut self, i: &mut Ident) {
61            if self.names.contains(i) {
62                *i = (self.rename)(i);
63                self.triggered = true;
64            }
65        }
66    }
67
68    let names = item_struct
69        .generics
70        .type_params()
71        .map(|type_param| type_param.ident.clone())
72        .collect();
73    let mut visit = RenameGenerics {
74        rename,
75        names,
76        triggered: false,
77    };
78
79    let mut out = Vec::new();
80    if let Some(where_clause) = &mut item_struct.generics.where_clause {
81        for where_predicate in where_clause.predicates.iter_mut() {
82            visit.visit_where_predicate_mut(where_predicate);
83            if std::mem::take(&mut visit.triggered) {
84                out.push(where_predicate.clone());
85            }
86        }
87    }
88    for type_param in item_struct.generics.type_params_mut() {
89        visit.visit_type_param_mut(type_param);
90    }
91    for field in item_struct.fields.iter_mut() {
92        visit.visit_type_mut(&mut field.ty);
93    }
94    out
95}
96
97/// Ensures that `punctuated` has trailing punctuation (or is empty).
98fn ensure_trailing<T, P>(punctuated: &mut Punctuated<T, P>)
99where
100    P: Default,
101{
102    if !punctuated.empty_or_trailing() {
103        punctuated.push_punct(Default::default());
104    }
105}
106
107#[doc = include_str!("../README.md")]
108#[proc_macro_derive(Lattice)]
109pub fn derive_lattice_macro(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
110    derive_lattice(&process_item_struct(parse_macro_input!(item))).into()
111}
112/// Derives lattice `Merge`.
113///
114/// See [`#[derive(Lattice)]`](`Lattice`) for more info.
115#[proc_macro_derive(Merge)]
116pub fn derive_merge_macro(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
117    derive_merge(&process_item_struct(parse_macro_input!(item))).into()
118}
119/// Derives [`PartialEq`], [`PartialOrd`], and `LatticeOrd` together.
120///
121/// See [`#[derive(Lattice)]`](`Lattice`) for more info.
122#[proc_macro_derive(LatticeOrd)]
123pub fn derive_lattice_ord_macro(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
124    derive_lattice_ord(&process_item_struct(parse_macro_input!(item))).into()
125}
126/// Derives lattice `IsBot`.
127///
128/// See [`#[derive(Lattice)]`](`Lattice`) for more info.
129#[proc_macro_derive(IsBot)]
130pub fn derive_is_bot_macro(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
131    derive_is_bot(&process_item_struct(parse_macro_input!(item))).into()
132}
133/// Derives lattice `IsTop`.
134///
135/// See [`#[derive(Lattice)]`](`Lattice`) for more info.
136#[proc_macro_derive(IsTop)]
137pub fn derive_is_top_macro(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
138    derive_is_top(&process_item_struct(parse_macro_input!(item))).into()
139}
140/// Derives `LatticeFrom`.
141///
142/// See [`#[derive(Lattice)]`](`Lattice`) for more info.
143#[proc_macro_derive(LatticeFrom)]
144pub fn derive_lattice_from_macro(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
145    derive_lattice_from(&process_item_struct(parse_macro_input!(item))).into()
146}
147
148/// [`process_item_struct`] return value helper struct.
149struct ProcessItemStruct {
150    root: TokenStream,
151    item_struct: ItemStruct,
152    item_struct_renamed: ItemStruct,
153    self_where_predicates: Punctuated<WherePredicate, Token![,]>,
154    both_where_predicates: Punctuated<WherePredicate, Token![,]>,
155    field_names: Vec<Member>,
156    combined_generics: Generics,
157}
158/// Helper for common pre-processing code shared between macros.
159fn process_item_struct(item_struct: ItemStruct) -> ProcessItemStruct {
160    let mut item_struct_renamed = item_struct.clone();
161    let extra_where_predicates = rename_generics(&mut item_struct_renamed, |ident| {
162        format_ident!("__{}Other", ident)
163    });
164
165    // Basic `where` predicates, no extras.
166    let mut self_where_predicates = item_struct
167        .generics
168        .where_clause
169        .clone()
170        .map(|WhereClause { predicates, .. }| predicates)
171        .unwrap_or_default();
172    ensure_trailing(&mut self_where_predicates);
173    // Basic `where` predicates for combined original and renamed parameters.
174    let mut both_where_predicates = self_where_predicates.clone();
175    both_where_predicates.extend(extra_where_predicates);
176    ensure_trailing(&mut both_where_predicates);
177
178    // Fields.
179    let field_names = match &item_struct.fields {
180        syn::Fields::Named(FieldsNamed { named, .. }) => named
181            .iter()
182            .map(|Field { ident, .. }| Member::Named(ident.clone().unwrap()))
183            .collect::<Vec<_>>(),
184        syn::Fields::Unnamed(FieldsUnnamed { unnamed, .. }) => (0..(unnamed.len() as u32))
185            .map(|index| {
186                Member::Unnamed(Index {
187                    index,
188                    span: Span::call_site(),
189                })
190            })
191            .collect(),
192        syn::Fields::Unit => Vec::new(),
193    };
194
195    // Extend the original generics.
196    let mut combined_generics = item_struct.generics.clone();
197    combined_generics
198        .params
199        .extend(item_struct_renamed.generics.params.clone());
200
201    ProcessItemStruct {
202        root: root(),
203        item_struct,
204        item_struct_renamed,
205        self_where_predicates,
206        both_where_predicates,
207        field_names,
208        combined_generics,
209    }
210}
211
212/// See [`derive_lattice_macro`].
213fn derive_lattice(process_item_struct: &ProcessItemStruct) -> TokenStream {
214    let mut out = TokenStream::new();
215    out.extend(derive_merge(process_item_struct));
216    out.extend(derive_lattice_ord(process_item_struct));
217    out.extend(derive_is_bot(process_item_struct));
218    out.extend(derive_is_top(process_item_struct));
219    out.extend(derive_lattice_from(process_item_struct));
220    out
221}
222
223/// See [`derive_merge_macro`].
224fn derive_merge(
225    ProcessItemStruct {
226        root,
227        item_struct,
228        item_struct_renamed,
229        self_where_predicates: _,
230        both_where_predicates,
231        field_names,
232        combined_generics,
233    }: &ProcessItemStruct,
234) -> TokenStream {
235    let merge_where_predicates = item_struct
236        .fields
237        .iter()
238        .zip(item_struct_renamed.fields.iter())
239        .map(|(field_self, field_othr)| {
240            let ty_self = &field_self.ty;
241            let ty_othr = &field_othr.ty;
242            quote! {
243                #ty_self: #root::Merge<#ty_othr>
244            }
245        });
246
247    let ident = &item_struct.ident;
248    let (_, ty_generics_self, _) = item_struct.generics.split_for_impl();
249    let (_, ty_generics_othr, _) = item_struct_renamed.generics.split_for_impl();
250    let (impl_generics_both, _, _) = combined_generics.split_for_impl();
251    quote! {
252        impl #impl_generics_both #root::Merge<#ident #ty_generics_othr> for #ident #ty_generics_self
253        where
254            #both_where_predicates
255            #( #merge_where_predicates ),*
256        {
257            fn merge(&mut self, other: #ident #ty_generics_othr) -> bool {
258                let mut changed = false;
259                #(
260                    changed |= #root::Merge::merge(&mut self.#field_names, other.#field_names);
261                )*
262                changed
263            }
264        }
265    }
266}
267
268/// See [`derive_lattice_ord_macro`].
269fn derive_lattice_ord(
270    ProcessItemStruct {
271        root,
272        item_struct,
273        item_struct_renamed,
274        self_where_predicates: _,
275        both_where_predicates,
276        field_names,
277        combined_generics,
278    }: &ProcessItemStruct,
279) -> TokenStream {
280    // PartialEq.
281    let pareq_where_predicates = item_struct
282        .fields
283        .iter()
284        .zip(item_struct_renamed.fields.iter())
285        .map(|(field_self, field_othr)| {
286            let ty_self = &field_self.ty;
287            let ty_othr = &field_othr.ty;
288            quote! {
289                #ty_self: ::core::cmp::PartialEq<#ty_othr>
290            }
291        });
292    // PartialOrd and LatticeOrd.
293    let compare_where_predicates = item_struct
294        .fields
295        .iter()
296        .zip(item_struct_renamed.fields.iter())
297        .map(|(field_self, field_othr)| {
298            let ty_self = &field_self.ty;
299            let ty_othr = &field_othr.ty;
300            quote! {
301                #ty_self: ::core::cmp::PartialOrd<#ty_othr>
302            }
303        })
304        .collect::<Vec<_>>();
305
306    let ident = &item_struct.ident;
307    let (_, ty_generics_self, _) = item_struct.generics.split_for_impl();
308    let (_, ty_generics_othr, _) = item_struct_renamed.generics.split_for_impl();
309    let (impl_generics_both, _, _) = combined_generics.split_for_impl();
310    quote! {
311        impl #impl_generics_both ::core::cmp::PartialEq<#ident #ty_generics_othr> for #ident #ty_generics_self
312        where
313            #both_where_predicates
314            #( #pareq_where_predicates ),*
315        {
316            fn eq(&self, other: &#ident #ty_generics_othr) -> bool {
317                #(
318                    if !::core::cmp::PartialEq::eq(&self.#field_names, &other.#field_names) {
319                        return false;
320                    }
321                )*
322                true
323            }
324        }
325
326        impl #impl_generics_both ::core::cmp::PartialOrd<#ident #ty_generics_othr> for #ident #ty_generics_self
327        where
328            #both_where_predicates
329            #( #compare_where_predicates ),*
330        {
331            fn partial_cmp(&self, other: &#ident #ty_generics_othr) -> ::core::option::Option<::core::cmp::Ordering> {
332                let mut self_any_greater = false;
333                let mut othr_any_greater = false;
334                #(
335                    // `?` short-circuits `None` (uncomparable).
336                    match ::core::cmp::PartialOrd::partial_cmp(&self.#field_names, &other.#field_names)? {
337                        ::core::cmp::Ordering::Less => {
338                            othr_any_greater = true;
339                        }
340                        ::core::cmp::Ordering::Greater => {
341                            self_any_greater = true;
342                        }
343                        ::core::cmp::Ordering::Equal => {}
344                    }
345                    if self_any_greater && othr_any_greater {
346                        return ::core::option::Option::None;
347                    }
348                )*
349                ::core::option::Option::Some(
350                    match (self_any_greater, othr_any_greater) {
351                        (false, false) => ::core::cmp::Ordering::Equal,
352                        (false, true) => ::core::cmp::Ordering::Less,
353                        (true, false) => ::core::cmp::Ordering::Greater,
354                        (true, true) => ::core::unreachable!(),
355                    }
356                )
357            }
358        }
359        impl #impl_generics_both #root::LatticeOrd<#ident #ty_generics_othr> for #ident #ty_generics_self
360        where
361            #both_where_predicates
362            #( #compare_where_predicates ),*
363        {}
364    }
365}
366
367/// See [`derive_is_bot_macro`].
368fn derive_is_bot(
369    ProcessItemStruct {
370        root,
371        item_struct,
372        item_struct_renamed: _,
373        self_where_predicates,
374        both_where_predicates: _,
375        field_names,
376        combined_generics: _,
377    }: &ProcessItemStruct,
378) -> TokenStream {
379    let isbot_where_predicates = item_struct.fields.iter().map(|Field { ty, .. }| {
380        quote! {
381            #ty: #root::IsBot
382        }
383    });
384
385    let ident = &item_struct.ident;
386    let (impl_generics_self, ty_generics_self, _) = item_struct.generics.split_for_impl();
387    quote! {
388        impl #impl_generics_self #root::IsBot for #ident #ty_generics_self
389        where
390            #self_where_predicates
391            #( #isbot_where_predicates ),*
392        {
393            fn is_bot(&self) -> bool {
394                #(
395                    if !#root::IsBot::is_bot(&self.#field_names) {
396                        return false;
397                    }
398                )*
399                true
400            }
401        }
402    }
403}
404
405/// See [`derive_is_top_macro`].
406fn derive_is_top(
407    ProcessItemStruct {
408        root,
409        item_struct,
410        item_struct_renamed: _,
411        self_where_predicates,
412        both_where_predicates: _,
413        field_names,
414        combined_generics: _,
415    }: &ProcessItemStruct,
416) -> TokenStream {
417    let istop_where_predicates = item_struct.fields.iter().map(|Field { ty, .. }| {
418        quote! {
419            #ty: #root::IsTop
420        }
421    });
422
423    let ident = &item_struct.ident;
424    let (impl_generics_self, ty_generics_self, _) = item_struct.generics.split_for_impl();
425    quote! {
426        impl #impl_generics_self #root::IsTop for #ident #ty_generics_self
427        where
428            #self_where_predicates
429            #( #istop_where_predicates ),*
430        {
431            fn is_top(&self) -> bool {
432                #(
433                    if !#root::IsTop::is_top(&self.#field_names) {
434                        return false;
435                    }
436                )*
437                true
438            }
439        }
440    }
441}
442
443/// See [`derive_lattice_from_macro`].
444fn derive_lattice_from(
445    ProcessItemStruct {
446        root,
447        item_struct,
448        item_struct_renamed,
449        self_where_predicates: _,
450        both_where_predicates,
451        field_names,
452        combined_generics,
453    }: &ProcessItemStruct,
454) -> TokenStream {
455    let latticefrom_where_predicates = item_struct
456        .fields
457        .iter()
458        .zip(item_struct_renamed.fields.iter())
459        .map(|(field_self, field_othr)| {
460            let ty_self = &field_self.ty;
461            let ty_othr = &field_othr.ty;
462            quote! {
463                #ty_self: #root::LatticeFrom<#ty_othr>
464            }
465        });
466
467    let ident = &item_struct.ident;
468    let (_, ty_generics_self, _) = item_struct.generics.split_for_impl();
469    let (_, ty_generics_othr, _) = item_struct_renamed.generics.split_for_impl();
470    let (impl_generics_both, _, _) = combined_generics.split_for_impl();
471    quote! {
472        impl #impl_generics_both #root::LatticeFrom<#ident #ty_generics_othr> for #ident #ty_generics_self
473        where
474            #both_where_predicates
475            #( #latticefrom_where_predicates ),*
476        {
477            fn lattice_from(other: #ident #ty_generics_othr) -> Self {
478                Self {
479                    #(
480                        #field_names: #root::LatticeFrom::lattice_from(other.#field_names),
481                    )*
482                }
483            }
484        }
485    }
486}
487
/// Snapshot tests for the code generated by `derive_lattice`.
///
/// Also see `lattices/tests/macro.rs`
#[cfg(test)]
mod test {
    use syn::parse_quote;

    use super::*;

    /// Snapshots the macro output without actually testing if it compiles.
    /// See `lattices/tests/macro.rs` for compiling tests.
    ///
    /// Parses the given tokens as the derive input, runs the full `derive_lattice`
    /// expansion, pretty-prints it with `prettyplease`, and compares the result against the
    /// stored `insta` snapshot.
    macro_rules! assert_derive_snapshots {
        ( $( $t:tt )* ) => {
            {
                let item = parse_quote! {
                    $( $t )*
                };
                let process_item_struct = process_item_struct(item);
                let derive_lattice = derive_lattice(&process_item_struct);
                insta::assert_snapshot!(prettyplease::unparse(&parse_quote! { #derive_lattice }));
            }
        };
    }

    /// Generic struct whose fields wrap the generic parameters in other types.
    #[test]
    fn derive_example() {
        assert_derive_snapshots! {
            struct MyLattice<KeySet, Epoch> {
                keys: SetUnion<KeySet>,
                epoch: Max<Epoch>,
            }
        };
    }

    /// Public generic struct whose field types are the bare generic parameters.
    #[test]
    fn derive_pair() {
        assert_derive_snapshots! {
            pub struct Pair<LatA, LatB> {
                pub a: LatA,
                pub b: LatB,
            }
        };
    }

    /// Non-generic struct with several fields of the same type.
    #[test]
    fn derive_similar_fields() {
        // Will create duplicate where clauses, but that is OK.
        assert_derive_snapshots! {
            pub struct SimilarFields {
                a: Max<usize>,
                b: Max<usize>,
                c: Max<usize>,
            }
        };
    }
}