Skip to main content

lattices/
map_union.rs

//! Module containing the [`MapUnion`] lattice and type aliases for it backed by different map data structures.
2
3#[cfg(feature = "alloc")]
4use alloc::collections::BTreeMap;
5use core::cmp::Ordering::{self, *};
6use core::fmt::Debug;
7use core::marker::PhantomData;
8#[cfg(feature = "std")]
9use std::collections::HashMap;
10
11use cc_traits::{Collection, GetKeyValue, Iter, MapInsert, SimpleCollectionRef};
12
13#[cfg(feature = "alloc")]
14use crate::cc_traits::GetMut;
15use crate::cc_traits::{Keyed, Map, MapIter, SimpleKeyedRef};
16#[cfg(feature = "alloc")]
17use crate::collections::VecMap;
18use crate::collections::{ArrayMap, MapMapValues, OptionMap, SingletonMap};
19#[cfg(feature = "alloc")]
20use crate::{Atomize, Merge};
21use crate::{DeepReveal, IsBot, IsTop, LatticeBimorphism, LatticeFrom, LatticeOrd};
22
/// Map-union compound lattice.
///
/// Every key is associated with a lattice value. Combining two map-union lattices takes the union
/// of their key sets, and the values stored under keys present in both are merged together.
#[repr(transparent)]
#[derive(Copy, Clone, Debug, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct MapUnion<Map>(Map);
impl<Map> MapUnion<Map> {
    /// Wraps an existing `Map` as a `MapUnion`.
    pub fn new(val: Map) -> Self {
        MapUnion(val)
    }

    /// Converts `val` into a `Map` via [`Into`] and wraps it as a `MapUnion`.
    pub fn new_from(val: impl Into<Map>) -> Self {
        let map: Map = val.into();
        Self::new(map)
    }

    /// Borrows the wrapped map immutably.
    pub fn as_reveal_ref(&self) -> &Map {
        let Self(inner) = self;
        inner
    }

    /// Borrows the wrapped map mutably.
    pub fn as_reveal_mut(&mut self) -> &mut Map {
        let Self(inner) = self;
        inner
    }

    /// Consumes the lattice and returns the wrapped map.
    pub fn into_reveal(self) -> Map {
        let Self(inner) = self;
        inner
    }
}
57
58impl<Map, Val> DeepReveal for MapUnion<Map>
59where
60    Map: Keyed<Item = Val> + MapMapValues<Val>,
61    Val: DeepReveal,
62{
63    type Revealed = Map::MapValue<Val::Revealed>;
64
65    fn deep_reveal(self) -> Self::Revealed {
66        self.0.map_values(DeepReveal::deep_reveal)
67    }
68}
69
#[cfg(feature = "alloc")]
impl<MapSelf, MapOther, K, ValSelf, ValOther> Merge<MapUnion<MapOther>> for MapUnion<MapSelf>
where
    MapSelf: Keyed<Key = K, Item = ValSelf>
        + Extend<(K, ValSelf)>
        + for<'a> GetMut<&'a K, Item = ValSelf>,
    MapOther: IntoIterator<Item = (K, ValOther)>,
    ValSelf: Merge<ValOther> + LatticeFrom<ValOther>,
    ValOther: IsBot,
{
    /// Merges `other` into `self` and returns `true` if `self` changed.
    ///
    /// Entries of `other` whose value is bottom are skipped. A value under a key already in
    /// `self` is merged in place. A value under a key `self` lacks is converted with
    /// [`LatticeFrom`] and inserted.
    fn merge(&mut self, other: MapUnion<MapOther>) -> bool {
        use alloc::vec::Vec;

        let mut changed = false;
        // Entries for keys that `self` does not contain yet. They are buffered here and
        // inserted with a single `extend` once iteration over `other` has finished.
        // TODO(mingwei): A `Collection` entry API could avoid this intermediate buffer.
        let mut additions = Vec::new();
        for (key, val_other) in other.0 {
            if val_other.is_bot() {
                continue;
            }
            match self.0.get_mut(&key) {
                // Key collision: merge into the existing value.
                Some(mut val_self) => changed |= val_self.merge(val_other),
                // Fresh key: convert now, insert after the loop.
                None => {
                    changed = true;
                    additions.push((key, ValSelf::lattice_from(val_other)));
                }
            }
        }
        self.0.extend(additions);
        changed
    }
}
111
112impl<MapSelf, MapOther, K, ValSelf, ValOther> LatticeFrom<MapUnion<MapOther>> for MapUnion<MapSelf>
113where
114    MapSelf: Keyed<Key = K, Item = ValSelf> + FromIterator<(K, ValSelf)>,
115    MapOther: IntoIterator<Item = (K, ValOther)>,
116    ValSelf: LatticeFrom<ValOther>,
117{
118    fn lattice_from(other: MapUnion<MapOther>) -> Self {
119        Self(
120            other
121                .0
122                .into_iter()
123                .map(|(k_other, val_other)| (k_other, LatticeFrom::lattice_from(val_other)))
124                .collect(),
125        )
126    }
127}
128
129impl<MapSelf, MapOther, K, ValSelf, ValOther> PartialOrd<MapUnion<MapOther>> for MapUnion<MapSelf>
130where
131    MapSelf: Map<K, ValSelf, Key = K, Item = ValSelf> + MapIter + SimpleKeyedRef,
132    MapOther: Map<K, ValOther, Key = K, Item = ValOther> + MapIter + SimpleKeyedRef,
133    ValSelf: PartialOrd<ValOther> + IsBot,
134    ValOther: IsBot,
135{
136    fn partial_cmp(&self, other: &MapUnion<MapOther>) -> Option<Ordering> {
137        let mut self_any_greater = false;
138        let mut other_any_greater = false;
139        let self_keys = self
140            .0
141            .iter()
142            .filter(|(_k, v)| !v.is_bot())
143            .map(|(k, _v)| <MapSelf as SimpleKeyedRef>::into_ref(k));
144        let other_keys = other
145            .0
146            .iter()
147            .filter(|(_k, v)| !v.is_bot())
148            .map(|(k, _v)| <MapOther as SimpleKeyedRef>::into_ref(k));
149        for k in self_keys.chain(other_keys) {
150            match (self.0.get(k), other.0.get(k)) {
151                (Some(self_value), Some(other_value)) => {
152                    match self_value.partial_cmp(&*other_value)? {
153                        Less => {
154                            other_any_greater = true;
155                        }
156                        Greater => {
157                            self_any_greater = true;
158                        }
159                        Equal => {}
160                    }
161                }
162                (Some(_), None) => {
163                    self_any_greater = true;
164                }
165                (None, Some(_)) => {
166                    other_any_greater = true;
167                }
168                (None, None) => unreachable!(),
169            }
170            if self_any_greater && other_any_greater {
171                return None;
172            }
173        }
174        match (self_any_greater, other_any_greater) {
175            (true, false) => Some(Greater),
176            (false, true) => Some(Less),
177            (false, false) => Some(Equal),
178            // We check this one after each loop iteration above.
179            (true, true) => unreachable!(),
180        }
181    }
182}
183impl<MapSelf, MapOther> LatticeOrd<MapUnion<MapOther>> for MapUnion<MapSelf> where
184    Self: PartialOrd<MapUnion<MapOther>>
185{
186}
187
188impl<MapSelf, MapOther, K, ValSelf, ValOther> PartialEq<MapUnion<MapOther>> for MapUnion<MapSelf>
189where
190    MapSelf: Map<K, ValSelf, Key = K, Item = ValSelf> + MapIter + SimpleKeyedRef,
191    MapOther: Map<K, ValOther, Key = K, Item = ValOther> + MapIter + SimpleKeyedRef,
192    ValSelf: PartialEq<ValOther> + IsBot,
193    ValOther: IsBot,
194{
195    fn eq(&self, other: &MapUnion<MapOther>) -> bool {
196        let self_keys = self
197            .0
198            .iter()
199            .filter(|(_k, v)| !v.is_bot())
200            .map(|(k, _v)| <MapSelf as SimpleKeyedRef>::into_ref(k));
201        let other_keys = other
202            .0
203            .iter()
204            .filter(|(_k, v)| !v.is_bot())
205            .map(|(k, _v)| <MapOther as SimpleKeyedRef>::into_ref(k));
206        for k in self_keys.chain(other_keys) {
207            match (self.0.get(k), other.0.get(k)) {
208                (Some(self_value), Some(other_value)) => {
209                    if *self_value != *other_value {
210                        return false;
211                    }
212                }
213                (None, None) => unreachable!(),
214                _ => {
215                    return false;
216                }
217            }
218        }
219
220        true
221    }
222}
223impl<MapSelf> Eq for MapUnion<MapSelf> where Self: PartialEq {}
224
225impl<Map> IsBot for MapUnion<Map>
226where
227    Map: Iter,
228    Map::Item: IsBot,
229{
230    fn is_bot(&self) -> bool {
231        self.0.iter().all(|v| v.is_bot())
232    }
233}
234
235impl<Map> IsTop for MapUnion<Map> {
236    fn is_top(&self) -> bool {
237        false
238    }
239}
240
#[cfg(feature = "alloc")]
impl<Map, K, Val> Atomize for MapUnion<Map>
where
    Map: 'static
        + IntoIterator<Item = (K, Val)>
        + Keyed<Key = K, Item = Val>
        + Extend<(K, Val)>
        + for<'a> GetMut<&'a K, Item = Val>,
    K: 'static + Clone,
    Val: 'static + Atomize + LatticeFrom<<Val as Atomize>::Atom>,
{
    type Atom = MapUnionSingletonMap<K, Val::Atom>;

    // TODO: use impl trait, then remove 'static.
    type AtomIter = alloc::boxed::Box<dyn Iterator<Item = Self::Atom>>;

    /// Splits the lattice into singleton maps. Each value is atomized, and every resulting
    /// atom is paired with (a clone of) the key it was stored under.
    fn atomize(self) -> Self::AtomIter {
        let atoms = self.0.into_iter().flat_map(|(key, val)| {
            val.atomize()
                .map(move |atom| MapUnionSingletonMap::new_from((key.clone(), atom)))
        });
        alloc::boxed::Box::new(atoms)
    }
}
264
265/// [`std::collections::HashMap`]-backed [`MapUnion`] lattice.
266#[cfg(feature = "std")]
267pub type MapUnionHashMap<K, Val> = MapUnion<HashMap<K, Val>>;
268
269/// [`std::collections::BTreeMap`]-backed [`MapUnion`] lattice.
270#[cfg(feature = "alloc")]
271pub type MapUnionBTreeMap<K, Val> = MapUnion<BTreeMap<K, Val>>;
272
273/// [`Vec`](alloc::vec::Vec)-backed [`MapUnion`] lattice.
274#[cfg(feature = "alloc")]
275pub type MapUnionVec<K, Val> = MapUnion<VecMap<K, Val>>;
276
277/// Array-backed [`MapUnion`] lattice.
278pub type MapUnionArrayMap<K, Val, const N: usize> = MapUnion<ArrayMap<K, Val, N>>;
279
280/// [`crate::collections::SingletonMap`]-backed [`MapUnion`] lattice.
281pub type MapUnionSingletonMap<K, Val> = MapUnion<SingletonMap<K, Val>>;
282
283/// [`Option`]-backed [`MapUnion`] lattice.
284pub type MapUnionOptionMap<K, Val> = MapUnion<OptionMap<K, Val>>;
285
/// Composable bimorphism that applies an inner bimorphism separately for each key.
///
/// Keys present in both inputs have their values combined by the wrapped bimorphism. For
/// example, `KeyedBimorphism<..., CartesianProduct<...>>` is a join.
pub struct KeyedBimorphism<MapOut, Bimorphism> {
    bimorphism: Bimorphism,
    _phantom: PhantomData<fn() -> MapOut>,
}
impl<MapOut, Bimorphism> KeyedBimorphism<MapOut, Bimorphism> {
    /// Builds a `KeyedBimorphism` that uses `bimorphism` to combine values under matching keys.
    pub fn new(bimorphism: Bimorphism) -> Self {
        let _phantom = PhantomData;
        Self {
            bimorphism,
            _phantom,
        }
    }
}
301}
302impl<MapA, MapB, MapOut, ValFunc> LatticeBimorphism<MapUnion<MapA>, MapUnion<MapB>>
303    for KeyedBimorphism<MapOut, ValFunc>
304where
305    ValFunc: LatticeBimorphism<MapA::Item, MapB::Item>,
306    MapA: MapIter + SimpleKeyedRef + SimpleCollectionRef,
307    MapB: for<'a> GetKeyValue<&'a MapA::Key, Key = MapA::Key> + SimpleCollectionRef,
308    MapA::Key: Clone + Eq,
309    MapA::Item: Clone,
310    MapB::Item: Clone,
311    MapOut: Default + MapInsert<MapA::Key> + Collection<Item = ValFunc::Output>,
312{
313    type Output = MapUnion<MapOut>;
314
315    fn call(&mut self, lat_a: MapUnion<MapA>, lat_b: MapUnion<MapB>) -> Self::Output {
316        let mut output = MapUnion::<MapOut>::default();
317        for (key, val_a) in lat_a.as_reveal_ref().iter() {
318            let key = <MapA as SimpleKeyedRef>::into_ref(key);
319            let Some((_key, val_b)) = lat_b.as_reveal_ref().get_key_value(key) else {
320                continue;
321            };
322            let val_a = <MapA as SimpleCollectionRef>::into_ref(val_a).clone();
323            let val_b = <MapB as SimpleCollectionRef>::into_ref(val_b).clone();
324
325            let val_out = LatticeBimorphism::call(&mut self.bimorphism, val_a, val_b);
326            <MapOut as MapInsert<_>>::insert(output.as_reveal_mut(), key.clone(), val_out);
327        }
328        output
329    }
330}
331
#[cfg(test)]
mod test {
    use std::collections::HashSet;

    use super::*;
    use crate::collections::SingletonSet;
    use crate::set_union::{CartesianProductBimorphism, SetUnionHashSet, SetUnionSingletonSet};
    use crate::test::{cartesian_power, check_all, check_atomize_each, check_lattice_bimorphism};

    /// Merging into an empty map inserts the key; merging again under the same key merges the
    /// two values.
    #[test]
    fn test_map_union() {
        let mut my_map_a = <MapUnionHashMap<&str, SetUnionHashSet<u64>>>::default();
        let my_map_b = <MapUnionSingletonMap<&str, SetUnionSingletonSet<u64>>>::new(SingletonMap(
            "hello",
            SetUnionSingletonSet::new(SingletonSet(100)),
        ));
        let my_map_c =
            MapUnionSingletonMap::new_from(("hello", SetUnionHashSet::new_from([100, 200])));

        // New key: converted via `LatticeFrom` and inserted, so `my_map_a` changes.
        assert!(my_map_a.merge(my_map_b));
        // Existing key: `200` is new to the stored set, so `my_map_a` changes again.
        assert!(my_map_a.merge(my_map_c));

        let expected = <MapUnionHashMap<&str, SetUnionHashSet<u64>>>::new_from([(
            "hello",
            SetUnionHashSet::new_from([100, 200]),
        )]);
        assert_eq!(my_map_a, expected);
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn consistency_atomize() {
        use alloc::vec;
        use alloc::vec::Vec;

        let mut test_vec = Vec::new();

        // Size 0.
        test_vec.push(MapUnionHashMap::default());
        // Size 1.
        for key in [0, 1] {
            for value in [vec![], vec![0], vec![1], vec![0, 1]] {
                test_vec.push(MapUnionHashMap::new(HashMap::from_iter([(
                    key,
                    SetUnionHashSet::new(HashSet::from_iter(value)),
                )])));
            }
        }
        // Size 2.
        for [val_a, val_b] in cartesian_power(&[vec![], vec![0], vec![1], vec![0, 1]]) {
            test_vec.push(MapUnionHashMap::new(HashMap::from_iter([
                (0, SetUnionHashSet::new(HashSet::from_iter(val_a.clone()))),
                (1, SetUnionHashSet::new(HashSet::from_iter(val_b.clone()))),
            ])));
        }

        check_all(&test_vec);
        check_atomize_each(&test_vec);
    }

    /// Checks that a map whose only key holds bottom equals the empty map, and that two such
    /// maps equal each other even when their keys differ.
    #[test]
    fn test_collapses_bot() {
        let map_empty = <MapUnionHashMap<&str, SetUnionHashSet<u64>>>::default();
        let map_a_bot = <MapUnionSingletonMap<&str, SetUnionHashSet<u64>>>::new(SingletonMap(
            "a",
            Default::default(),
        ));
        let map_b_bot = <MapUnionSingletonMap<&str, SetUnionHashSet<u64>>>::new(SingletonMap(
            "b",
            Default::default(),
        ));

        assert_eq!(map_empty, map_a_bot);
        assert_eq!(map_empty, map_b_bot);
        assert_eq!(map_a_bot, map_b_bot);
    }

    #[test]
    fn test_join_aka_keyed_cartesian_product() {
        let items_a = &[
            MapUnionHashMap::new_from([("foo", SetUnionHashSet::new_from(["bar"]))]),
            MapUnionHashMap::new_from([("foo", SetUnionHashSet::new_from(["baz"]))]),
            MapUnionHashMap::new_from([("hello", SetUnionHashSet::new_from(["world"]))]),
        ];
        let items_b = &[
            MapUnionHashMap::new_from([("foo", SetUnionHashSet::new_from(["bang"]))]),
            MapUnionHashMap::new_from([(
                "hello",
                SetUnionHashSet::new_from(["goodbye", "farewell"]),
            )]),
        ];

        check_lattice_bimorphism(
            KeyedBimorphism::<HashMap<_, _>, _>::new(
                CartesianProductBimorphism::<HashSet<_>>::default(),
            ),
            items_a,
            items_a,
        );
        check_lattice_bimorphism(
            KeyedBimorphism::<HashMap<_, _>, _>::new(
                CartesianProductBimorphism::<HashSet<_>>::default(),
            ),
            items_a,
            items_b,
        );
        check_lattice_bimorphism(
            KeyedBimorphism::<HashMap<_, _>, _>::new(
                CartesianProductBimorphism::<HashSet<_>>::default(),
            ),
            items_b,
            items_a,
        );
        check_lattice_bimorphism(
            KeyedBimorphism::<HashMap<_, _>, _>::new(
                CartesianProductBimorphism::<HashSet<_>>::default(),
            ),
            items_b,
            items_b,
        );

        check_lattice_bimorphism(
            KeyedBimorphism::<BTreeMap<_, _>, _>::new(
                CartesianProductBimorphism::<HashSet<_>>::default(),
            ),
            items_a,
            items_a,
        );
        check_lattice_bimorphism(
            KeyedBimorphism::<BTreeMap<_, _>, _>::new(
                CartesianProductBimorphism::<HashSet<_>>::default(),
            ),
            items_a,
            items_b,
        );
        check_lattice_bimorphism(
            KeyedBimorphism::<BTreeMap<_, _>, _>::new(
                CartesianProductBimorphism::<HashSet<_>>::default(),
            ),
            items_b,
            items_a,
        );
        check_lattice_bimorphism(
            KeyedBimorphism::<BTreeMap<_, _>, _>::new(
                CartesianProductBimorphism::<HashSet<_>>::default(),
            ),
            items_b,
            items_b,
        );
    }
}