1#![warn(missing_docs)]
5
6use proc_macro2::{Span, TokenStream};
7use quote::{format_ident, quote};
8use syn::punctuated::Punctuated;
9use syn::visit_mut::VisitMut;
10use syn::{
11 Field, FieldsNamed, FieldsUnnamed, Generics, Ident, Index, ItemStruct, Member, Token,
12 WhereClause, WherePredicate, parse_macro_input,
13};
14
15fn root() -> TokenStream {
17 use std::env::{VarError, var as env_var};
18
19 use proc_macro_crate::FoundCrate;
20
21 if let Ok(FoundCrate::Itself) = proc_macro_crate::crate_name("lattices_macro") {
22 return quote! { lattices };
23 }
24
25 let lattices_crate_name = env!("CARGO_PKG_NAME").strip_suffix("_macro").unwrap();
26 let lattices_crate_ident = lattices_crate_name.replace('-', "_");
27 let lattices_crate = proc_macro_crate::crate_name(lattices_crate_name)
28 .unwrap_or_else(|_| panic!("`{lattices_crate_name}` should be present in `Cargo.toml`"));
29 match lattices_crate {
30 FoundCrate::Itself => {
31 if Err(VarError::NotPresent) == env_var("CARGO_BIN_NAME")
32 && Ok(&*lattices_crate_ident) == env_var("CARGO_CRATE_NAME").as_deref()
33 {
34 quote! { crate }
36 } else {
37 let ident = Ident::new(&lattices_crate_ident, Span::call_site());
39 quote! { ::#ident }
40 }
41 }
42 FoundCrate::Name(name) => {
43 let ident = Ident::new(&name, Span::call_site());
44 quote! { ::#ident }
45 }
46 }
47}
48
49fn rename_generics(
51 item_struct: &mut ItemStruct,
52 rename: impl FnMut(&Ident) -> Ident,
53) -> Vec<WherePredicate> {
54 struct RenameGenerics<F> {
55 rename: F,
56 names: Vec<Ident>,
57 pub triggered: bool,
58 }
59 impl<F> VisitMut for RenameGenerics<F>
60 where
61 F: FnMut(&Ident) -> Ident,
62 {
63 fn visit_ident_mut(&mut self, i: &mut Ident) {
64 if self.names.contains(i) {
65 *i = (self.rename)(i);
66 self.triggered = true;
67 }
68 }
69 }
70
71 let names = item_struct
72 .generics
73 .type_params()
74 .map(|type_param| type_param.ident.clone())
75 .collect();
76 let mut visit = RenameGenerics {
77 rename,
78 names,
79 triggered: false,
80 };
81
82 let mut out = Vec::new();
83 if let Some(where_clause) = &mut item_struct.generics.where_clause {
84 for where_predicate in where_clause.predicates.iter_mut() {
85 visit.visit_where_predicate_mut(where_predicate);
86 if std::mem::take(&mut visit.triggered) {
87 out.push(where_predicate.clone());
88 }
89 }
90 }
91 for type_param in item_struct.generics.type_params_mut() {
92 visit.visit_type_param_mut(type_param);
93 }
94 for field in item_struct.fields.iter_mut() {
95 visit.visit_type_mut(&mut field.ty);
96 }
97 out
98}
99
100fn ensure_trailing<T, P>(punctuated: &mut Punctuated<T, P>)
102where
103 P: Default,
104{
105 if !punctuated.empty_or_trailing() {
106 punctuated.push_punct(Default::default());
107 }
108}
109
110#[doc = include_str!("../README.md")]
111#[proc_macro_derive(Lattice)]
112pub fn derive_lattice_macro(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
113 derive_lattice(&process_item_struct(parse_macro_input!(item))).into()
114}
115#[proc_macro_derive(Merge)]
119pub fn derive_merge_macro(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
120 derive_merge(&process_item_struct(parse_macro_input!(item))).into()
121}
122#[proc_macro_derive(LatticeOrd)]
126pub fn derive_lattice_ord_macro(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
127 derive_lattice_ord(&process_item_struct(parse_macro_input!(item))).into()
128}
129#[proc_macro_derive(IsBot)]
133pub fn derive_is_bot_macro(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
134 derive_is_bot(&process_item_struct(parse_macro_input!(item))).into()
135}
136#[proc_macro_derive(IsTop)]
140pub fn derive_is_top_macro(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
141 derive_is_top(&process_item_struct(parse_macro_input!(item))).into()
142}
143#[proc_macro_derive(LatticeFrom)]
147pub fn derive_lattice_from_macro(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
148 derive_lattice_from(&process_item_struct(parse_macro_input!(item))).into()
149}
150
151struct ProcessItemStruct {
153 root: TokenStream,
154 item_struct: ItemStruct,
155 item_struct_renamed: ItemStruct,
156 self_where_predicates: Punctuated<WherePredicate, Token![,]>,
157 both_where_predicates: Punctuated<WherePredicate, Token![,]>,
158 field_names: Vec<Member>,
159 combined_generics: Generics,
160}
161fn process_item_struct(item_struct: ItemStruct) -> ProcessItemStruct {
163 let mut item_struct_renamed = item_struct.clone();
164 let extra_where_predicates = rename_generics(&mut item_struct_renamed, |ident| {
165 format_ident!("__{}Other", ident)
166 });
167
168 let mut self_where_predicates = item_struct
170 .generics
171 .where_clause
172 .clone()
173 .map(|WhereClause { predicates, .. }| predicates)
174 .unwrap_or_default();
175 ensure_trailing(&mut self_where_predicates);
176 let mut both_where_predicates = self_where_predicates.clone();
178 both_where_predicates.extend(extra_where_predicates);
179 ensure_trailing(&mut both_where_predicates);
180
181 let field_names = match &item_struct.fields {
183 syn::Fields::Named(FieldsNamed { named, .. }) => named
184 .iter()
185 .map(|Field { ident, .. }| Member::Named(ident.clone().unwrap()))
186 .collect::<Vec<_>>(),
187 syn::Fields::Unnamed(FieldsUnnamed { unnamed, .. }) => (0..(unnamed.len() as u32))
188 .map(|index| {
189 Member::Unnamed(Index {
190 index,
191 span: Span::call_site(),
192 })
193 })
194 .collect(),
195 syn::Fields::Unit => Vec::new(),
196 };
197
198 let mut combined_generics = item_struct.generics.clone();
200 combined_generics
201 .params
202 .extend(item_struct_renamed.generics.params.clone());
203
204 ProcessItemStruct {
205 root: root(),
206 item_struct,
207 item_struct_renamed,
208 self_where_predicates,
209 both_where_predicates,
210 field_names,
211 combined_generics,
212 }
213}
214
215fn derive_lattice(process_item_struct: &ProcessItemStruct) -> TokenStream {
217 let mut out = TokenStream::new();
218 out.extend(derive_merge(process_item_struct));
219 out.extend(derive_lattice_ord(process_item_struct));
220 out.extend(derive_is_bot(process_item_struct));
221 out.extend(derive_is_top(process_item_struct));
222 out.extend(derive_lattice_from(process_item_struct));
223 out
224}
225
226fn derive_merge(
228 ProcessItemStruct {
229 root,
230 item_struct,
231 item_struct_renamed,
232 self_where_predicates: _,
233 both_where_predicates,
234 field_names,
235 combined_generics,
236 }: &ProcessItemStruct,
237) -> TokenStream {
238 let merge_where_predicates = item_struct
239 .fields
240 .iter()
241 .zip(item_struct_renamed.fields.iter())
242 .map(|(field_self, field_othr)| {
243 let ty_self = &field_self.ty;
244 let ty_othr = &field_othr.ty;
245 quote! {
246 #ty_self: #root::Merge<#ty_othr>
247 }
248 });
249
250 let ident = &item_struct.ident;
251 let (_, ty_generics_self, _) = item_struct.generics.split_for_impl();
252 let (_, ty_generics_othr, _) = item_struct_renamed.generics.split_for_impl();
253 let (impl_generics_both, _, _) = combined_generics.split_for_impl();
254 quote! {
255 impl #impl_generics_both #root::Merge<#ident #ty_generics_othr> for #ident #ty_generics_self
256 where
257 #both_where_predicates
258 #( #merge_where_predicates ),*
259 {
260 fn merge(&mut self, other: #ident #ty_generics_othr) -> bool {
261 let mut changed = false;
262 #(
263 changed |= #root::Merge::merge(&mut self.#field_names, other.#field_names);
264 )*
265 changed
266 }
267 }
268 }
269}
270
271fn derive_lattice_ord(
273 ProcessItemStruct {
274 root,
275 item_struct,
276 item_struct_renamed,
277 self_where_predicates: _,
278 both_where_predicates,
279 field_names,
280 combined_generics,
281 }: &ProcessItemStruct,
282) -> TokenStream {
283 let pareq_where_predicates = item_struct
285 .fields
286 .iter()
287 .zip(item_struct_renamed.fields.iter())
288 .map(|(field_self, field_othr)| {
289 let ty_self = &field_self.ty;
290 let ty_othr = &field_othr.ty;
291 quote! {
292 #ty_self: ::core::cmp::PartialEq<#ty_othr>
293 }
294 });
295 let compare_where_predicates = item_struct
297 .fields
298 .iter()
299 .zip(item_struct_renamed.fields.iter())
300 .map(|(field_self, field_othr)| {
301 let ty_self = &field_self.ty;
302 let ty_othr = &field_othr.ty;
303 quote! {
304 #ty_self: ::core::cmp::PartialOrd<#ty_othr>
305 }
306 })
307 .collect::<Vec<_>>();
308
309 let ident = &item_struct.ident;
310 let (_, ty_generics_self, _) = item_struct.generics.split_for_impl();
311 let (_, ty_generics_othr, _) = item_struct_renamed.generics.split_for_impl();
312 let (impl_generics_both, _, _) = combined_generics.split_for_impl();
313 quote! {
314 impl #impl_generics_both ::core::cmp::PartialEq<#ident #ty_generics_othr> for #ident #ty_generics_self
315 where
316 #both_where_predicates
317 #( #pareq_where_predicates ),*
318 {
319 fn eq(&self, other: &#ident #ty_generics_othr) -> bool {
320 #(
321 if !::core::cmp::PartialEq::eq(&self.#field_names, &other.#field_names) {
322 return false;
323 }
324 )*
325 true
326 }
327 }
328
329 impl #impl_generics_both ::core::cmp::PartialOrd<#ident #ty_generics_othr> for #ident #ty_generics_self
330 where
331 #both_where_predicates
332 #( #compare_where_predicates ),*
333 {
334 fn partial_cmp(&self, other: &#ident #ty_generics_othr) -> ::core::option::Option<::core::cmp::Ordering> {
335 let mut self_any_greater = false;
336 let mut othr_any_greater = false;
337 #(
338 match ::core::cmp::PartialOrd::partial_cmp(&self.#field_names, &other.#field_names)? {
340 ::core::cmp::Ordering::Less => {
341 othr_any_greater = true;
342 }
343 ::core::cmp::Ordering::Greater => {
344 self_any_greater = true;
345 }
346 ::core::cmp::Ordering::Equal => {}
347 }
348 if self_any_greater && othr_any_greater {
349 return ::core::option::Option::None;
350 }
351 )*
352 ::core::option::Option::Some(
353 match (self_any_greater, othr_any_greater) {
354 (false, false) => ::core::cmp::Ordering::Equal,
355 (false, true) => ::core::cmp::Ordering::Less,
356 (true, false) => ::core::cmp::Ordering::Greater,
357 (true, true) => ::core::unreachable!(),
358 }
359 )
360 }
361 }
362 impl #impl_generics_both #root::LatticeOrd<#ident #ty_generics_othr> for #ident #ty_generics_self
363 where
364 #both_where_predicates
365 #( #compare_where_predicates ),*
366 {}
367 }
368}
369
370fn derive_is_bot(
372 ProcessItemStruct {
373 root,
374 item_struct,
375 item_struct_renamed: _,
376 self_where_predicates,
377 both_where_predicates: _,
378 field_names,
379 combined_generics: _,
380 }: &ProcessItemStruct,
381) -> TokenStream {
382 let isbot_where_predicates = item_struct.fields.iter().map(|Field { ty, .. }| {
383 quote! {
384 #ty: #root::IsBot
385 }
386 });
387
388 let ident = &item_struct.ident;
389 let (impl_generics_self, ty_generics_self, _) = item_struct.generics.split_for_impl();
390 quote! {
391 impl #impl_generics_self #root::IsBot for #ident #ty_generics_self
392 where
393 #self_where_predicates
394 #( #isbot_where_predicates ),*
395 {
396 fn is_bot(&self) -> bool {
397 #(
398 if !#root::IsBot::is_bot(&self.#field_names) {
399 return false;
400 }
401 )*
402 true
403 }
404 }
405 }
406}
407
408fn derive_is_top(
410 ProcessItemStruct {
411 root,
412 item_struct,
413 item_struct_renamed: _,
414 self_where_predicates,
415 both_where_predicates: _,
416 field_names,
417 combined_generics: _,
418 }: &ProcessItemStruct,
419) -> TokenStream {
420 let istop_where_predicates = item_struct.fields.iter().map(|Field { ty, .. }| {
421 quote! {
422 #ty: #root::IsTop
423 }
424 });
425
426 let ident = &item_struct.ident;
427 let (impl_generics_self, ty_generics_self, _) = item_struct.generics.split_for_impl();
428 quote! {
429 impl #impl_generics_self #root::IsTop for #ident #ty_generics_self
430 where
431 #self_where_predicates
432 #( #istop_where_predicates ),*
433 {
434 fn is_top(&self) -> bool {
435 #(
436 if !#root::IsTop::is_top(&self.#field_names) {
437 return false;
438 }
439 )*
440 true
441 }
442 }
443 }
444}
445
446fn derive_lattice_from(
448 ProcessItemStruct {
449 root,
450 item_struct,
451 item_struct_renamed,
452 self_where_predicates: _,
453 both_where_predicates,
454 field_names,
455 combined_generics,
456 }: &ProcessItemStruct,
457) -> TokenStream {
458 let latticefrom_where_predicates = item_struct
459 .fields
460 .iter()
461 .zip(item_struct_renamed.fields.iter())
462 .map(|(field_self, field_othr)| {
463 let ty_self = &field_self.ty;
464 let ty_othr = &field_othr.ty;
465 quote! {
466 #ty_self: #root::LatticeFrom<#ty_othr>
467 }
468 });
469
470 let ident = &item_struct.ident;
471 let (_, ty_generics_self, _) = item_struct.generics.split_for_impl();
472 let (_, ty_generics_othr, _) = item_struct_renamed.generics.split_for_impl();
473 let (impl_generics_both, _, _) = combined_generics.split_for_impl();
474 quote! {
475 impl #impl_generics_both #root::LatticeFrom<#ident #ty_generics_othr> for #ident #ty_generics_self
476 where
477 #both_where_predicates
478 #( #latticefrom_where_predicates ),*
479 {
480 fn lattice_from(other: #ident #ty_generics_othr) -> Self {
481 Self {
482 #(
483 #field_names: #root::LatticeFrom::lattice_from(other.#field_names),
484 )*
485 }
486 }
487 }
488 }
489}
490
// Snapshot tests for the full `Lattice` derive. Snapshot names come from the module path and
// the test function names, so renaming either would orphan the stored snapshots.
#[cfg(test)]
mod test {
    use syn::parse_quote;

    use super::*;

    // Parses the given struct definition, runs the complete `Lattice` derive on it, and
    // snapshots the pretty-printed generated code.
    macro_rules! assert_derive_snapshots {
        ( $( $t:tt )* ) => {
            {
                let item = parse_quote! {
                    $( $t )*
                };
                let process_item_struct = process_item_struct(item);
                let derive_lattice = derive_lattice(&process_item_struct);
                hydro_build_utils::assert_snapshot!(prettyplease::unparse(&parse_quote! { #derive_lattice }));
            }
        };
    }

    // Generic struct whose fields wrap its type parameters in other lattice types.
    #[test]
    fn derive_example() {
        assert_derive_snapshots! {
            struct MyLattice<KeySet, Epoch> {
                keys: SetUnion<KeySet>,
                epoch: Max<Epoch>,
            }
        };
    }

    // Public generic struct whose field types are bare type parameters.
    #[test]
    fn derive_pair() {
        assert_derive_snapshots! {
            pub struct Pair<LatA, LatB> {
                pub a: LatA,
                pub b: LatB,
            }
        };
    }

    // Non-generic struct with several fields of the same type.
    #[test]
    fn derive_similar_fields() {
        assert_derive_snapshots! {
            pub struct SimilarFields {
                a: Max<usize>,
                b: Max<usize>,
                c: Max<usize>,
            }
        };
    }
}