Skip to main content

lattices/
map_union.rs

//! Module containing the [`MapUnion`] lattice and type aliases for it over different backing data structures.
2
3#[cfg(feature = "alloc")]
4use alloc::collections::BTreeMap;
5use core::cmp::Ordering::{self, *};
6use core::fmt::Debug;
7use core::marker::PhantomData;
8#[cfg(feature = "std")]
9use std::collections::HashMap;
10
11use cc_traits::{Collection, GetKeyValue, Iter, MapInsert, SimpleCollectionRef};
12
13use crate::cc_traits::{GetMut, Keyed, Map, MapIter, SimpleKeyedRef};
14#[cfg(feature = "alloc")]
15use crate::collections::VecMap;
16use crate::collections::{ArrayMap, MapMapValues, OptionMap, SingletonMap};
17use crate::{Atomize, DeepReveal, IsBot, IsTop, LatticeBimorphism, LatticeFrom, LatticeOrd, Merge};
18
/// Map-union compound lattice.
///
/// Each key maps to a lattice value. Merging two map-union lattices takes the union of their
/// keys and merges the values of keys present in both. Comparisons and [`IsBot`] skip entries
/// whose value is bottom, so such entries behave as if the key were absent.
#[repr(transparent)]
#[derive(Copy, Clone, Debug, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct MapUnion<Map>(Map);
impl<Map> MapUnion<Map> {
    /// Wrap an existing `Map` as a `MapUnion` lattice.
    pub fn new(val: Map) -> Self {
        MapUnion(val)
    }

    /// Wrap any value convertible `Into<Map>` as a `MapUnion` lattice.
    pub fn new_from(val: impl Into<Map>) -> Self {
        let map: Map = val.into();
        MapUnion(map)
    }

    /// Borrow the wrapped map immutably.
    pub fn as_reveal_ref(&self) -> &Map {
        let MapUnion(inner) = self;
        inner
    }

    /// Borrow the wrapped map mutably.
    pub fn as_reveal_mut(&mut self) -> &mut Map {
        let MapUnion(inner) = self;
        inner
    }

    /// Unwrap the inner map, consuming `self`.
    pub fn into_reveal(self) -> Map {
        let MapUnion(inner) = self;
        inner
    }
}
53
54impl<Map, Val> DeepReveal for MapUnion<Map>
55where
56    Map: Keyed<Item = Val> + MapMapValues<Val>,
57    Val: DeepReveal,
58{
59    type Revealed = Map::MapValue<Val::Revealed>;
60
61    fn deep_reveal(self) -> Self::Revealed {
62        self.0.map_values(DeepReveal::deep_reveal)
63    }
64}
65
#[cfg(feature = "alloc")]
impl<MapSelf, MapOther, K, ValSelf, ValOther> Merge<MapUnion<MapOther>> for MapUnion<MapSelf>
where
    MapSelf: Keyed<Key = K, Item = ValSelf>
        + Extend<(K, ValSelf)>
        + for<'a> GetMut<&'a K, Item = ValSelf>,
    MapOther: IntoIterator<Item = (K, ValOther)>,
    ValSelf: Merge<ValOther> + LatticeFrom<ValOther>,
    ValOther: IsBot,
{
    /// Merge `other` into `self`, returning `true` if `self` changed.
    ///
    /// Bottom-valued entries of `other` are skipped. For a key already in `self`, the values are
    /// merged in place. A key missing from `self` is converted with [`LatticeFrom`] and inserted,
    /// which always counts as a change.
    fn merge(&mut self, other: MapUnion<MapOther>) -> bool {
        use alloc::vec::Vec;

        let mut changed = false;
        // New entries are buffered and inserted after the loop. `GetMut::get_mut` and
        // `Extend::extend` both need `&mut self.0`, so they cannot be interleaved.
        // TODO(mingwei): This could be fixed with a different structure, maybe some sort of
        // `Collection` entry API.
        let mut pending: Vec<(K, ValSelf)> = Vec::new();
        for (key, incoming) in other.0 {
            if incoming.is_bot() {
                continue;
            }
            match self.0.get_mut(&key) {
                // Key collision: merge into the existing value.
                Some(mut existing) => {
                    changed |= existing.merge(incoming);
                }
                // New key: convert now and insert after the loop.
                None => {
                    changed = true;
                    pending.push((key, ValSelf::lattice_from(incoming)));
                }
            }
        }
        self.0.extend(pending);
        changed
    }
}
107
108impl<MapSelf, MapOther, K, ValSelf, ValOther> LatticeFrom<MapUnion<MapOther>> for MapUnion<MapSelf>
109where
110    MapSelf: Keyed<Key = K, Item = ValSelf> + FromIterator<(K, ValSelf)>,
111    MapOther: IntoIterator<Item = (K, ValOther)>,
112    ValSelf: LatticeFrom<ValOther>,
113{
114    fn lattice_from(other: MapUnion<MapOther>) -> Self {
115        Self(
116            other
117                .0
118                .into_iter()
119                .map(|(k_other, val_other)| (k_other, LatticeFrom::lattice_from(val_other)))
120                .collect(),
121        )
122    }
123}
124
125impl<MapSelf, MapOther, K, ValSelf, ValOther> PartialOrd<MapUnion<MapOther>> for MapUnion<MapSelf>
126where
127    MapSelf: Map<K, ValSelf, Key = K, Item = ValSelf> + MapIter + SimpleKeyedRef,
128    MapOther: Map<K, ValOther, Key = K, Item = ValOther> + MapIter + SimpleKeyedRef,
129    ValSelf: PartialOrd<ValOther> + IsBot,
130    ValOther: IsBot,
131{
132    fn partial_cmp(&self, other: &MapUnion<MapOther>) -> Option<Ordering> {
133        let mut self_any_greater = false;
134        let mut other_any_greater = false;
135        let self_keys = self
136            .0
137            .iter()
138            .filter(|(_k, v)| !v.is_bot())
139            .map(|(k, _v)| <MapSelf as SimpleKeyedRef>::into_ref(k));
140        let other_keys = other
141            .0
142            .iter()
143            .filter(|(_k, v)| !v.is_bot())
144            .map(|(k, _v)| <MapOther as SimpleKeyedRef>::into_ref(k));
145        for k in self_keys.chain(other_keys) {
146            match (self.0.get(k), other.0.get(k)) {
147                (Some(self_value), Some(other_value)) => {
148                    match self_value.partial_cmp(&*other_value)? {
149                        Less => {
150                            other_any_greater = true;
151                        }
152                        Greater => {
153                            self_any_greater = true;
154                        }
155                        Equal => {}
156                    }
157                }
158                (Some(_), None) => {
159                    self_any_greater = true;
160                }
161                (None, Some(_)) => {
162                    other_any_greater = true;
163                }
164                (None, None) => unreachable!(),
165            }
166            if self_any_greater && other_any_greater {
167                return None;
168            }
169        }
170        match (self_any_greater, other_any_greater) {
171            (true, false) => Some(Greater),
172            (false, true) => Some(Less),
173            (false, false) => Some(Equal),
174            // We check this one after each loop iteration above.
175            (true, true) => unreachable!(),
176        }
177    }
178}
179impl<MapSelf, MapOther> LatticeOrd<MapUnion<MapOther>> for MapUnion<MapSelf> where
180    Self: PartialOrd<MapUnion<MapOther>>
181{
182}
183
184impl<MapSelf, MapOther, K, ValSelf, ValOther> PartialEq<MapUnion<MapOther>> for MapUnion<MapSelf>
185where
186    MapSelf: Map<K, ValSelf, Key = K, Item = ValSelf> + MapIter + SimpleKeyedRef,
187    MapOther: Map<K, ValOther, Key = K, Item = ValOther> + MapIter + SimpleKeyedRef,
188    ValSelf: PartialEq<ValOther> + IsBot,
189    ValOther: IsBot,
190{
191    fn eq(&self, other: &MapUnion<MapOther>) -> bool {
192        let self_keys = self
193            .0
194            .iter()
195            .filter(|(_k, v)| !v.is_bot())
196            .map(|(k, _v)| <MapSelf as SimpleKeyedRef>::into_ref(k));
197        let other_keys = other
198            .0
199            .iter()
200            .filter(|(_k, v)| !v.is_bot())
201            .map(|(k, _v)| <MapOther as SimpleKeyedRef>::into_ref(k));
202        for k in self_keys.chain(other_keys) {
203            match (self.0.get(k), other.0.get(k)) {
204                (Some(self_value), Some(other_value)) => {
205                    if *self_value != *other_value {
206                        return false;
207                    }
208                }
209                (None, None) => unreachable!(),
210                _ => {
211                    return false;
212                }
213            }
214        }
215
216        true
217    }
218}
219impl<MapSelf> Eq for MapUnion<MapSelf> where Self: PartialEq {}
220
221impl<Map> IsBot for MapUnion<Map>
222where
223    Map: Iter,
224    Map::Item: IsBot,
225{
226    fn is_bot(&self) -> bool {
227        self.0.iter().all(|v| v.is_bot())
228    }
229}
230
231impl<Map> IsTop for MapUnion<Map> {
232    fn is_top(&self) -> bool {
233        false
234    }
235}
236
#[cfg(feature = "alloc")]
impl<Map, K, Val> Atomize for MapUnion<Map>
where
    Map: 'static
        + IntoIterator<Item = (K, Val)>
        + Keyed<Key = K, Item = Val>
        + Extend<(K, Val)>
        + for<'a> GetMut<&'a K, Item = Val>,
    K: 'static + Clone,
    Val: 'static + Atomize + LatticeFrom<<Val as Atomize>::Atom>,
{
    type Atom = MapUnionSingletonMap<K, Val::Atom>;

    // TODO: use impl trait, then remove 'static.
    type AtomIter = alloc::boxed::Box<dyn Iterator<Item = Self::Atom>>;

    /// Split the map into single-entry maps. Each value is atomized, and each resulting atom is
    /// paired with a clone of its key.
    fn atomize(self) -> Self::AtomIter {
        let MapUnion(entries) = self;
        let atoms = entries.into_iter().flat_map(|(key, val)| {
            val.atomize()
                .map(move |atom| MapUnionSingletonMap::new_from((key.clone(), atom)))
        });
        alloc::boxed::Box::new(atoms)
    }
}
260
261/// [`std::collections::HashMap`]-backed [`MapUnion`] lattice.
262#[cfg(feature = "std")]
263pub type MapUnionHashMap<K, Val> = MapUnion<HashMap<K, Val>>;
264
265/// [`std::collections::BTreeMap`]-backed [`MapUnion`] lattice.
266#[cfg(feature = "alloc")]
267pub type MapUnionBTreeMap<K, Val> = MapUnion<BTreeMap<K, Val>>;
268
269/// [`Vec`](alloc::vec::Vec)-backed [`MapUnion`] lattice.
270#[cfg(feature = "alloc")]
271pub type MapUnionVec<K, Val> = MapUnion<VecMap<K, Val>>;
272
273/// Array-backed [`MapUnion`] lattice.
274pub type MapUnionArrayMap<K, Val, const N: usize> = MapUnion<ArrayMap<K, Val, N>>;
275
276/// [`crate::collections::SingletonMap`]-backed [`MapUnion`] lattice.
277pub type MapUnionSingletonMap<K, Val> = MapUnion<SingletonMap<K, Val>>;
278
279/// [`Option`]-backed [`MapUnion`] lattice.
280pub type MapUnionOptionMap<K, Val> = MapUnion<OptionMap<K, Val>>;
281
/// Composable bimorphism that wraps an existing value bimorphism and applies it per key.
///
/// For example, `KeyedBimorphism<..., CartesianProduct<...>>` is a join.
pub struct KeyedBimorphism<MapOut, Bimorphism> {
    bimorphism: Bimorphism,
    // Records the output map type without storing a value of it.
    _phantom: PhantomData<fn() -> MapOut>,
}
impl<MapOut, Bimorphism> KeyedBimorphism<MapOut, Bimorphism> {
    /// Create a `KeyedBimorphism` that uses `bimorphism` to combine the values of matching keys.
    pub fn new(bimorphism: Bimorphism) -> Self {
        let _phantom = PhantomData;
        Self {
            _phantom,
            bimorphism,
        }
    }
}
298impl<MapA, MapB, MapOut, ValFunc> LatticeBimorphism<MapUnion<MapA>, MapUnion<MapB>>
299    for KeyedBimorphism<MapOut, ValFunc>
300where
301    ValFunc: LatticeBimorphism<MapA::Item, MapB::Item>,
302    MapA: MapIter + SimpleKeyedRef + SimpleCollectionRef,
303    MapB: for<'a> GetKeyValue<&'a MapA::Key, Key = MapA::Key> + SimpleCollectionRef,
304    MapA::Key: Clone + Eq,
305    MapA::Item: Clone,
306    MapB::Item: Clone,
307    MapOut: Default + MapInsert<MapA::Key> + Collection<Item = ValFunc::Output>,
308{
309    type Output = MapUnion<MapOut>;
310
311    fn call(&mut self, lat_a: MapUnion<MapA>, lat_b: MapUnion<MapB>) -> Self::Output {
312        let mut output = MapUnion::<MapOut>::default();
313        for (key, val_a) in lat_a.as_reveal_ref().iter() {
314            let key = <MapA as SimpleKeyedRef>::into_ref(key);
315            let Some((_key, val_b)) = lat_b.as_reveal_ref().get_key_value(key) else {
316                continue;
317            };
318            let val_a = <MapA as SimpleCollectionRef>::into_ref(val_a).clone();
319            let val_b = <MapB as SimpleCollectionRef>::into_ref(val_b).clone();
320
321            let val_out = LatticeBimorphism::call(&mut self.bimorphism, val_a, val_b);
322            <MapOut as MapInsert<_>>::insert(output.as_reveal_mut(), key.clone(), val_out);
323        }
324        output
325    }
326}
327
#[cfg(test)]
mod test {
    use std::collections::HashSet;

    use super::*;
    use crate::collections::SingletonSet;
    use crate::set_union::{CartesianProductBimorphism, SetUnionHashSet, SetUnionSingletonSet};
    use crate::test::{cartesian_power, check_all, check_atomize_each, check_lattice_bimorphism};

    /// Merging a new key reports a change. Merging a larger value into an existing key also
    /// reports a change. The result holds the union of the values.
    #[test]
    fn test_map_union() {
        let mut my_map_a = <MapUnionHashMap<&str, SetUnionHashSet<u64>>>::default();
        let my_map_b = <MapUnionSingletonMap<&str, SetUnionSingletonSet<u64>>>::new(SingletonMap(
            "hello",
            SetUnionSingletonSet::new(SingletonSet(100)),
        ));
        let my_map_c =
            MapUnionSingletonMap::new_from(("hello", SetUnionHashSet::new_from([100, 200])));

        // "hello" is new to `my_map_a`.
        assert!(my_map_a.merge(my_map_b));
        // "hello" already exists and gains `200`.
        assert!(my_map_a.merge(my_map_c));

        let expected = <MapUnionHashMap<&str, SetUnionHashSet<u64>>>::new_from([(
            "hello",
            SetUnionHashSet::new_from([100, 200]),
        )]);
        assert_eq!(my_map_a, expected);
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn consistency_atomize() {
        use alloc::vec;
        use alloc::vec::Vec;

        let mut test_vec = Vec::new();

        // Size 0.
        test_vec.push(MapUnionHashMap::default());
        // Size 1.
        for key in [0, 1] {
            for value in [vec![], vec![0], vec![1], vec![0, 1]] {
                test_vec.push(MapUnionHashMap::new(HashMap::from_iter([(
                    key,
                    SetUnionHashSet::new(HashSet::from_iter(value)),
                )])));
            }
        }
        // Size 2.
        for [val_a, val_b] in cartesian_power(&[vec![], vec![0], vec![1], vec![0, 1]]) {
            test_vec.push(MapUnionHashMap::new(HashMap::from_iter([
                (0, SetUnionHashSet::new(HashSet::from_iter(val_a.clone()))),
                (1, SetUnionHashSet::new(HashSet::from_iter(val_b.clone()))),
            ])));
        }

        check_all(&test_vec);
        check_atomize_each(&test_vec);
    }

    /// Check that a key whose value is bottom compares equal to an empty map, and that two maps
    /// holding only bottom values under different keys compare equal to each other.
    #[test]
    fn test_collapse_bot() {
        let map_empty = <MapUnionHashMap<&str, SetUnionHashSet<u64>>>::default();
        let map_a_bot = <MapUnionSingletonMap<&str, SetUnionHashSet<u64>>>::new(SingletonMap(
            "a",
            Default::default(),
        ));
        let map_b_bot = <MapUnionSingletonMap<&str, SetUnionHashSet<u64>>>::new(SingletonMap(
            "b",
            Default::default(),
        ));

        assert_eq!(map_empty, map_a_bot);
        assert_eq!(map_empty, map_b_bot);
        assert_eq!(map_a_bot, map_b_bot);
    }

    #[test]
    fn test_join_aka_keyed_cartesian_product() {
        let items_a = &[
            MapUnionHashMap::new_from([("foo", SetUnionHashSet::new_from(["bar"]))]),
            MapUnionHashMap::new_from([("foo", SetUnionHashSet::new_from(["baz"]))]),
            MapUnionHashMap::new_from([("hello", SetUnionHashSet::new_from(["world"]))]),
        ];
        let items_b = &[
            MapUnionHashMap::new_from([("foo", SetUnionHashSet::new_from(["bang"]))]),
            MapUnionHashMap::new_from([(
                "hello",
                SetUnionHashSet::new_from(["goodbye", "farewell"]),
            )]),
        ];

        check_lattice_bimorphism(
            KeyedBimorphism::<HashMap<_, _>, _>::new(
                CartesianProductBimorphism::<HashSet<_>>::default(),
            ),
            items_a,
            items_a,
        );
        check_lattice_bimorphism(
            KeyedBimorphism::<HashMap<_, _>, _>::new(
                CartesianProductBimorphism::<HashSet<_>>::default(),
            ),
            items_a,
            items_b,
        );
        check_lattice_bimorphism(
            KeyedBimorphism::<HashMap<_, _>, _>::new(
                CartesianProductBimorphism::<HashSet<_>>::default(),
            ),
            items_b,
            items_a,
        );
        check_lattice_bimorphism(
            KeyedBimorphism::<HashMap<_, _>, _>::new(
                CartesianProductBimorphism::<HashSet<_>>::default(),
            ),
            items_b,
            items_b,
        );

        check_lattice_bimorphism(
            KeyedBimorphism::<BTreeMap<_, _>, _>::new(
                CartesianProductBimorphism::<HashSet<_>>::default(),
            ),
            items_a,
            items_a,
        );
        check_lattice_bimorphism(
            KeyedBimorphism::<BTreeMap<_, _>, _>::new(
                CartesianProductBimorphism::<HashSet<_>>::default(),
            ),
            items_a,
            items_b,
        );
        check_lattice_bimorphism(
            KeyedBimorphism::<BTreeMap<_, _>, _>::new(
                CartesianProductBimorphism::<HashSet<_>>::default(),
            ),
            items_b,
            items_a,
        );
        check_lattice_bimorphism(
            KeyedBimorphism::<BTreeMap<_, _>, _>::new(
                CartesianProductBimorphism::<HashSet<_>>::default(),
            ),
            items_b,
            items_b,
        );
    }
}