1#![warn(missing_docs)]
5
6use proc_macro2::{Span, TokenStream};
7use quote::{format_ident, quote};
8use syn::punctuated::Punctuated;
9use syn::visit_mut::VisitMut;
10use syn::{
11 Field, FieldsNamed, FieldsUnnamed, Generics, Ident, Index, ItemStruct, Member, Token,
12 WhereClause, WherePredicate, parse_macro_input,
13};
14
15fn root() -> TokenStream {
17 use std::env::{VarError, var as env_var};
18
19 use proc_macro_crate::FoundCrate;
20
21 if let Ok(FoundCrate::Itself) = proc_macro_crate::crate_name("lattices_macro") {
22 return quote! { lattices };
23 }
24
25 let lattices_crate = proc_macro_crate::crate_name("lattices")
26 .expect("`lattices` should be present in `Cargo.toml`");
27 match lattices_crate {
28 FoundCrate::Itself => {
29 if Err(VarError::NotPresent) == env_var("CARGO_BIN_NAME")
30 && Ok("lattices") == env_var("CARGO_CRATE_NAME").as_deref()
31 {
32 quote! { crate }
34 } else {
35 quote! { ::lattices }
37 }
38 }
39 FoundCrate::Name(name) => {
40 let ident = Ident::new(&name, Span::call_site());
41 quote! { ::#ident }
42 }
43 }
44}
45
46fn rename_generics(
48 item_struct: &mut ItemStruct,
49 rename: impl FnMut(&Ident) -> Ident,
50) -> Vec<WherePredicate> {
51 struct RenameGenerics<F> {
52 rename: F,
53 names: Vec<Ident>,
54 pub triggered: bool,
55 }
56 impl<F> VisitMut for RenameGenerics<F>
57 where
58 F: FnMut(&Ident) -> Ident,
59 {
60 fn visit_ident_mut(&mut self, i: &mut Ident) {
61 if self.names.contains(i) {
62 *i = (self.rename)(i);
63 self.triggered = true;
64 }
65 }
66 }
67
68 let names = item_struct
69 .generics
70 .type_params()
71 .map(|type_param| type_param.ident.clone())
72 .collect();
73 let mut visit = RenameGenerics {
74 rename,
75 names,
76 triggered: false,
77 };
78
79 let mut out = Vec::new();
80 if let Some(where_clause) = &mut item_struct.generics.where_clause {
81 for where_predicate in where_clause.predicates.iter_mut() {
82 visit.visit_where_predicate_mut(where_predicate);
83 if std::mem::take(&mut visit.triggered) {
84 out.push(where_predicate.clone());
85 }
86 }
87 }
88 for type_param in item_struct.generics.type_params_mut() {
89 visit.visit_type_param_mut(type_param);
90 }
91 for field in item_struct.fields.iter_mut() {
92 visit.visit_type_mut(&mut field.ty);
93 }
94 out
95}
96
97fn ensure_trailing<T, P>(punctuated: &mut Punctuated<T, P>)
99where
100 P: Default,
101{
102 if !punctuated.empty_or_trailing() {
103 punctuated.push_punct(Default::default());
104 }
105}
106
107#[doc = include_str!("../README.md")]
108#[proc_macro_derive(Lattice)]
109pub fn derive_lattice_macro(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
110 derive_lattice(&process_item_struct(parse_macro_input!(item))).into()
111}
112#[proc_macro_derive(Merge)]
116pub fn derive_merge_macro(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
117 derive_merge(&process_item_struct(parse_macro_input!(item))).into()
118}
119#[proc_macro_derive(LatticeOrd)]
123pub fn derive_lattice_ord_macro(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
124 derive_lattice_ord(&process_item_struct(parse_macro_input!(item))).into()
125}
126#[proc_macro_derive(IsBot)]
130pub fn derive_is_bot_macro(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
131 derive_is_bot(&process_item_struct(parse_macro_input!(item))).into()
132}
133#[proc_macro_derive(IsTop)]
137pub fn derive_is_top_macro(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
138 derive_is_top(&process_item_struct(parse_macro_input!(item))).into()
139}
140#[proc_macro_derive(LatticeFrom)]
144pub fn derive_lattice_from_macro(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
145 derive_lattice_from(&process_item_struct(parse_macro_input!(item))).into()
146}
147
148struct ProcessItemStruct {
150 root: TokenStream,
151 item_struct: ItemStruct,
152 item_struct_renamed: ItemStruct,
153 self_where_predicates: Punctuated<WherePredicate, Token![,]>,
154 both_where_predicates: Punctuated<WherePredicate, Token![,]>,
155 field_names: Vec<Member>,
156 combined_generics: Generics,
157}
158fn process_item_struct(item_struct: ItemStruct) -> ProcessItemStruct {
160 let mut item_struct_renamed = item_struct.clone();
161 let extra_where_predicates = rename_generics(&mut item_struct_renamed, |ident| {
162 format_ident!("__{}Other", ident)
163 });
164
165 let mut self_where_predicates = item_struct
167 .generics
168 .where_clause
169 .clone()
170 .map(|WhereClause { predicates, .. }| predicates)
171 .unwrap_or_default();
172 ensure_trailing(&mut self_where_predicates);
173 let mut both_where_predicates = self_where_predicates.clone();
175 both_where_predicates.extend(extra_where_predicates);
176 ensure_trailing(&mut both_where_predicates);
177
178 let field_names = match &item_struct.fields {
180 syn::Fields::Named(FieldsNamed { named, .. }) => named
181 .iter()
182 .map(|Field { ident, .. }| Member::Named(ident.clone().unwrap()))
183 .collect::<Vec<_>>(),
184 syn::Fields::Unnamed(FieldsUnnamed { unnamed, .. }) => (0..(unnamed.len() as u32))
185 .map(|index| {
186 Member::Unnamed(Index {
187 index,
188 span: Span::call_site(),
189 })
190 })
191 .collect(),
192 syn::Fields::Unit => Vec::new(),
193 };
194
195 let mut combined_generics = item_struct.generics.clone();
197 combined_generics
198 .params
199 .extend(item_struct_renamed.generics.params.clone());
200
201 ProcessItemStruct {
202 root: root(),
203 item_struct,
204 item_struct_renamed,
205 self_where_predicates,
206 both_where_predicates,
207 field_names,
208 combined_generics,
209 }
210}
211
212fn derive_lattice(process_item_struct: &ProcessItemStruct) -> TokenStream {
214 let mut out = TokenStream::new();
215 out.extend(derive_merge(process_item_struct));
216 out.extend(derive_lattice_ord(process_item_struct));
217 out.extend(derive_is_bot(process_item_struct));
218 out.extend(derive_is_top(process_item_struct));
219 out.extend(derive_lattice_from(process_item_struct));
220 out
221}
222
223fn derive_merge(
225 ProcessItemStruct {
226 root,
227 item_struct,
228 item_struct_renamed,
229 self_where_predicates: _,
230 both_where_predicates,
231 field_names,
232 combined_generics,
233 }: &ProcessItemStruct,
234) -> TokenStream {
235 let merge_where_predicates = item_struct
236 .fields
237 .iter()
238 .zip(item_struct_renamed.fields.iter())
239 .map(|(field_self, field_othr)| {
240 let ty_self = &field_self.ty;
241 let ty_othr = &field_othr.ty;
242 quote! {
243 #ty_self: #root::Merge<#ty_othr>
244 }
245 });
246
247 let ident = &item_struct.ident;
248 let (_, ty_generics_self, _) = item_struct.generics.split_for_impl();
249 let (_, ty_generics_othr, _) = item_struct_renamed.generics.split_for_impl();
250 let (impl_generics_both, _, _) = combined_generics.split_for_impl();
251 quote! {
252 impl #impl_generics_both #root::Merge<#ident #ty_generics_othr> for #ident #ty_generics_self
253 where
254 #both_where_predicates
255 #( #merge_where_predicates ),*
256 {
257 fn merge(&mut self, other: #ident #ty_generics_othr) -> bool {
258 let mut changed = false;
259 #(
260 changed |= #root::Merge::merge(&mut self.#field_names, other.#field_names);
261 )*
262 changed
263 }
264 }
265 }
266}
267
268fn derive_lattice_ord(
270 ProcessItemStruct {
271 root,
272 item_struct,
273 item_struct_renamed,
274 self_where_predicates: _,
275 both_where_predicates,
276 field_names,
277 combined_generics,
278 }: &ProcessItemStruct,
279) -> TokenStream {
280 let pareq_where_predicates = item_struct
282 .fields
283 .iter()
284 .zip(item_struct_renamed.fields.iter())
285 .map(|(field_self, field_othr)| {
286 let ty_self = &field_self.ty;
287 let ty_othr = &field_othr.ty;
288 quote! {
289 #ty_self: ::core::cmp::PartialEq<#ty_othr>
290 }
291 });
292 let compare_where_predicates = item_struct
294 .fields
295 .iter()
296 .zip(item_struct_renamed.fields.iter())
297 .map(|(field_self, field_othr)| {
298 let ty_self = &field_self.ty;
299 let ty_othr = &field_othr.ty;
300 quote! {
301 #ty_self: ::core::cmp::PartialOrd<#ty_othr>
302 }
303 })
304 .collect::<Vec<_>>();
305
306 let ident = &item_struct.ident;
307 let (_, ty_generics_self, _) = item_struct.generics.split_for_impl();
308 let (_, ty_generics_othr, _) = item_struct_renamed.generics.split_for_impl();
309 let (impl_generics_both, _, _) = combined_generics.split_for_impl();
310 quote! {
311 impl #impl_generics_both ::core::cmp::PartialEq<#ident #ty_generics_othr> for #ident #ty_generics_self
312 where
313 #both_where_predicates
314 #( #pareq_where_predicates ),*
315 {
316 fn eq(&self, other: &#ident #ty_generics_othr) -> bool {
317 #(
318 if !::core::cmp::PartialEq::eq(&self.#field_names, &other.#field_names) {
319 return false;
320 }
321 )*
322 true
323 }
324 }
325
326 impl #impl_generics_both ::core::cmp::PartialOrd<#ident #ty_generics_othr> for #ident #ty_generics_self
327 where
328 #both_where_predicates
329 #( #compare_where_predicates ),*
330 {
331 fn partial_cmp(&self, other: &#ident #ty_generics_othr) -> ::core::option::Option<::core::cmp::Ordering> {
332 let mut self_any_greater = false;
333 let mut othr_any_greater = false;
334 #(
335 match ::core::cmp::PartialOrd::partial_cmp(&self.#field_names, &other.#field_names)? {
337 ::core::cmp::Ordering::Less => {
338 othr_any_greater = true;
339 }
340 ::core::cmp::Ordering::Greater => {
341 self_any_greater = true;
342 }
343 ::core::cmp::Ordering::Equal => {}
344 }
345 if self_any_greater && othr_any_greater {
346 return ::core::option::Option::None;
347 }
348 )*
349 ::core::option::Option::Some(
350 match (self_any_greater, othr_any_greater) {
351 (false, false) => ::core::cmp::Ordering::Equal,
352 (false, true) => ::core::cmp::Ordering::Less,
353 (true, false) => ::core::cmp::Ordering::Greater,
354 (true, true) => ::core::unreachable!(),
355 }
356 )
357 }
358 }
359 impl #impl_generics_both #root::LatticeOrd<#ident #ty_generics_othr> for #ident #ty_generics_self
360 where
361 #both_where_predicates
362 #( #compare_where_predicates ),*
363 {}
364 }
365}
366
367fn derive_is_bot(
369 ProcessItemStruct {
370 root,
371 item_struct,
372 item_struct_renamed: _,
373 self_where_predicates,
374 both_where_predicates: _,
375 field_names,
376 combined_generics: _,
377 }: &ProcessItemStruct,
378) -> TokenStream {
379 let isbot_where_predicates = item_struct.fields.iter().map(|Field { ty, .. }| {
380 quote! {
381 #ty: #root::IsBot
382 }
383 });
384
385 let ident = &item_struct.ident;
386 let (impl_generics_self, ty_generics_self, _) = item_struct.generics.split_for_impl();
387 quote! {
388 impl #impl_generics_self #root::IsBot for #ident #ty_generics_self
389 where
390 #self_where_predicates
391 #( #isbot_where_predicates ),*
392 {
393 fn is_bot(&self) -> bool {
394 #(
395 if !#root::IsBot::is_bot(&self.#field_names) {
396 return false;
397 }
398 )*
399 true
400 }
401 }
402 }
403}
404
405fn derive_is_top(
407 ProcessItemStruct {
408 root,
409 item_struct,
410 item_struct_renamed: _,
411 self_where_predicates,
412 both_where_predicates: _,
413 field_names,
414 combined_generics: _,
415 }: &ProcessItemStruct,
416) -> TokenStream {
417 let istop_where_predicates = item_struct.fields.iter().map(|Field { ty, .. }| {
418 quote! {
419 #ty: #root::IsTop
420 }
421 });
422
423 let ident = &item_struct.ident;
424 let (impl_generics_self, ty_generics_self, _) = item_struct.generics.split_for_impl();
425 quote! {
426 impl #impl_generics_self #root::IsTop for #ident #ty_generics_self
427 where
428 #self_where_predicates
429 #( #istop_where_predicates ),*
430 {
431 fn is_top(&self) -> bool {
432 #(
433 if !#root::IsTop::is_top(&self.#field_names) {
434 return false;
435 }
436 )*
437 true
438 }
439 }
440 }
441}
442
443fn derive_lattice_from(
445 ProcessItemStruct {
446 root,
447 item_struct,
448 item_struct_renamed,
449 self_where_predicates: _,
450 both_where_predicates,
451 field_names,
452 combined_generics,
453 }: &ProcessItemStruct,
454) -> TokenStream {
455 let latticefrom_where_predicates = item_struct
456 .fields
457 .iter()
458 .zip(item_struct_renamed.fields.iter())
459 .map(|(field_self, field_othr)| {
460 let ty_self = &field_self.ty;
461 let ty_othr = &field_othr.ty;
462 quote! {
463 #ty_self: #root::LatticeFrom<#ty_othr>
464 }
465 });
466
467 let ident = &item_struct.ident;
468 let (_, ty_generics_self, _) = item_struct.generics.split_for_impl();
469 let (_, ty_generics_othr, _) = item_struct_renamed.generics.split_for_impl();
470 let (impl_generics_both, _, _) = combined_generics.split_for_impl();
471 quote! {
472 impl #impl_generics_both #root::LatticeFrom<#ident #ty_generics_othr> for #ident #ty_generics_self
473 where
474 #both_where_predicates
475 #( #latticefrom_where_predicates ),*
476 {
477 fn lattice_from(other: #ident #ty_generics_othr) -> Self {
478 Self {
479 #(
480 #field_names: #root::LatticeFrom::lattice_from(other.#field_names),
481 )*
482 }
483 }
484 }
485 }
486}
487
// Snapshot tests: each test feeds a struct definition through the full `Lattice` derive and
// compares the pretty-printed output against an `insta` snapshot file.
#[cfg(test)]
mod test {
    use syn::parse_quote;

    use super::*;

    /// Parses the given tokens as a struct, runs `derive_lattice` on it, and snapshots the
    /// pretty-printed result with `insta::assert_snapshot!`.
    macro_rules! assert_derive_snapshots {
        ( $( $t:tt )* ) => {
            {
                let item = parse_quote! {
                    $( $t )*
                };
                let process_item_struct = process_item_struct(item);
                let derive_lattice = derive_lattice(&process_item_struct);
                // The generated tokens are re-parsed with `parse_quote!` so that
                // `prettyplease::unparse` can format them.
                insta::assert_snapshot!(prettyplease::unparse(&parse_quote! { #derive_lattice }));
            }
        };
    }

    // Generic struct whose fields wrap its type parameters in other types.
    #[test]
    fn derive_example() {
        assert_derive_snapshots! {
            struct MyLattice<KeySet, Epoch> {
                keys: SetUnion<KeySet>,
                epoch: Max<Epoch>,
            }
        };
    }

    // Public generic struct whose field types are bare type parameters.
    #[test]
    fn derive_pair() {
        assert_derive_snapshots! {
            pub struct Pair<LatA, LatB> {
                pub a: LatA,
                pub b: LatB,
            }
        };
    }

    // Non-generic struct whose fields share a type, so the generated per-field bounds repeat.
    #[test]
    fn derive_similar_fields() {
        assert_derive_snapshots! {
            pub struct SimilarFields {
                a: Max<usize>,
                b: Max<usize>,
                c: Max<usize>,
            }
        };
    }
}