Skip to main content

lattices/
vec_union.rs

1use alloc::vec::Vec;
2use core::cmp::Ordering::{self, *};
3
4use cc_traits::Iter;
5
6use crate::{DeepReveal, IsBot, IsTop, LatticeFrom, LatticeOrd, Merge};
7
8/// Vec-union compound lattice.
9///
10/// Contains any number of `Lat` sub-lattices. Sub-lattices are indexed starting at zero, merging
11/// combines corresponding sub-lattices and keeps any excess.
12///
13/// Similar to [`MapUnion<<usize, Lat>>`](super::map_union::MapUnion) but requires the key indices
14/// start with `0`, `1`, `2`, etc: i.e. integers starting at zero with no gaps.
15#[derive(Clone, Debug, Eq)]
16#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
17pub struct VecUnion<Lat> {
18    vec: Vec<Lat>,
19}
20
21impl<Lat> VecUnion<Lat> {
22    /// Create a new `VecUnion` from a `Vec` of `Lat` instances.
23    pub fn new(vec: Vec<Lat>) -> Self {
24        Self { vec }
25    }
26
27    /// Create a new `VecUnion` from an `Into<Vec<Lat>>`.
28    pub fn new_from(vec: impl Into<Vec<Lat>>) -> Self {
29        Self::new(vec.into())
30    }
31
32    /// Reveal the inner value as a shared reference.
33    pub fn as_reveal_ref(&self) -> &Vec<Lat> {
34        &self.vec
35    }
36
37    /// Reveal the inner value as an exclusive reference.
38    pub fn as_reveal_mut(&mut self) -> &mut Vec<Lat> {
39        &mut self.vec
40    }
41
42    /// Gets the inner by value, consuming self.
43    pub fn into_reveal(self) -> Vec<Lat> {
44        self.vec
45    }
46}
47
48impl<Lat> DeepReveal for VecUnion<Lat>
49where
50    Lat: DeepReveal,
51{
52    type Revealed = Vec<Lat::Revealed>;
53
54    fn deep_reveal(self) -> Self::Revealed {
55        self.vec.into_iter().map(DeepReveal::deep_reveal).collect()
56    }
57}
58
59impl<Lat> Default for VecUnion<Lat> {
60    fn default() -> Self {
61        Self {
62            vec: Default::default(),
63        }
64    }
65}
66
67impl<LatSelf, LatOther> Merge<VecUnion<LatOther>> for VecUnion<LatSelf>
68where
69    LatSelf: Merge<LatOther> + LatticeFrom<LatOther>,
70{
71    fn merge(&mut self, mut other: VecUnion<LatOther>) -> bool {
72        let mut changed = false;
73        // Extend `self` if `other` is longer.
74        if self.vec.len() < other.vec.len() {
75            self.vec
76                .extend(other.vec.drain(self.vec.len()..).map(LatSelf::lattice_from));
77            changed = true;
78        }
79        // Merge intersecting indices.
80        for (self_val, other_val) in self.vec.iter_mut().zip(other.vec) {
81            changed |= self_val.merge(other_val);
82        }
83        changed
84    }
85}
86
87impl<LatSelf, LatOther> LatticeFrom<VecUnion<LatOther>> for VecUnion<LatSelf>
88where
89    LatSelf: LatticeFrom<LatOther>,
90{
91    fn lattice_from(other: VecUnion<LatOther>) -> Self {
92        Self::new(other.vec.into_iter().map(LatSelf::lattice_from).collect())
93    }
94}
95
96impl<LatSelf, LatOther> PartialEq<VecUnion<LatOther>> for VecUnion<LatSelf>
97where
98    LatSelf: PartialEq<LatOther>,
99{
100    fn eq(&self, other: &VecUnion<LatOther>) -> bool {
101        if self.vec.len() != other.vec.len() {
102            false
103        } else {
104            self.vec
105                .iter()
106                .zip(other.vec.iter())
107                .all(|(val_self, val_other)| val_self == val_other)
108        }
109    }
110}
111
112impl<LatSelf, LatOther> PartialOrd<VecUnion<LatOther>> for VecUnion<LatSelf>
113where
114    LatSelf: PartialOrd<LatOther>,
115{
116    fn partial_cmp(&self, other: &VecUnion<LatOther>) -> Option<Ordering> {
117        let (self_len, other_len) = (self.vec.len(), other.vec.len());
118        let mut self_any_greater = other_len < self_len;
119        let mut other_any_greater = self_len < other_len;
120        for (self_val, other_val) in self.vec.iter().zip(other.vec.iter()) {
121            match self_val.partial_cmp(other_val) {
122                None => {
123                    return None;
124                }
125                Some(Less) => {
126                    other_any_greater = true;
127                }
128                Some(Greater) => {
129                    self_any_greater = true;
130                }
131                Some(Equal) => {}
132            }
133            if self_any_greater && other_any_greater {
134                return None;
135            }
136        }
137        match (self_any_greater, other_any_greater) {
138            (true, false) => Some(Greater),
139            (false, true) => Some(Less),
140            (false, false) => Some(Equal),
141            // We check this one after each loop iteration above.
142            (true, true) => unreachable!(),
143        }
144    }
145}
146impl<LatSelf, LatOther> LatticeOrd<VecUnion<LatOther>> for VecUnion<LatSelf> where
147    Self: PartialOrd<VecUnion<LatOther>>
148{
149}
150
151impl<Lat> IsBot for VecUnion<Lat> {
152    fn is_bot(&self) -> bool {
153        self.vec.is_empty()
154    }
155}
156
157impl<Lat> IsTop for VecUnion<Lat> {
158    fn is_top(&self) -> bool {
159        false
160    }
161}
162
#[cfg(test)]
mod test {
    use std::collections::HashSet;
    use std::vec;
    use std::vec::Vec;

    use super::*;
    use crate::Max;
    use crate::set_union::SetUnionHashSet;
    use crate::test::{cartesian_power, check_all};

    /// `merge` reports a change exactly when the accumulated value grows, and the result is the
    /// pointwise max with the longer vec's excess appended.
    #[test]
    fn basic() {
        let mut my_vec_a = VecUnion::<Max<usize>>::default();
        let my_vec_b = VecUnion::new(vec![Max::new(9), Max::new(4), Max::new(5)]);
        let my_vec_c = VecUnion::new(vec![Max::new(2), Max::new(5)]);

        assert!(my_vec_a.merge(my_vec_b.clone()));
        assert!(!my_vec_a.merge(my_vec_b));
        assert!(my_vec_a.merge(my_vec_c.clone()));
        assert!(!my_vec_a.merge(my_vec_c));

        // Previously only the `changed` flags were checked; pin the merged value too.
        let merged: Vec<usize> = my_vec_a
            .into_reveal()
            .into_iter()
            .map(Max::into_reveal)
            .collect();
        assert_eq!(merged, vec![9, 5, 5]);
    }

    /// Directly exercises the incomparable (`None`) branch and the length-extension ordering.
    #[test]
    fn partial_cmp() {
        let short = VecUnion::new(vec![Max::new(2usize), Max::new(5)]);
        let long = VecUnion::new(vec![Max::new(9usize), Max::new(4), Max::new(5)]);
        let prefix = VecUnion::new(vec![Max::new(2usize)]);

        // Index 0 favours `long`, index 1 favours `short`: incomparable.
        assert_eq!(short.partial_cmp(&long), None);
        // A pointwise-equal prefix is strictly below the longer vec.
        assert_eq!(prefix.partial_cmp(&short), Some(Less));
        assert_eq!(short.partial_cmp(&prefix), Some(Greater));
        assert_eq!(short.partial_cmp(&short.clone()), Some(Equal));
    }

    #[test]
    fn consistency() {
        let mut test_vec = vec![VecUnion::new(vec![] as Vec<SetUnionHashSet<_>>)];

        let vals = [vec![], vec![0], vec![1], vec![0, 1]]
            .map(HashSet::from_iter)
            .map(SetUnionHashSet::new);

        test_vec.extend(
            cartesian_power::<_, 1>(&vals)
                .map(|row| VecUnion::new(row.into_iter().cloned().collect())),
        );
        test_vec.extend(
            cartesian_power::<_, 2>(&vals)
                .map(|row| VecUnion::new(row.into_iter().cloned().collect())),
        );

        check_all(&test_vec);
    }
}