Skip to main content

lattices/
union_find.rs

//! Module containing the [`UnionFind`] lattice and type aliases for it backed by different data structures.
2
3use std::boxed::Box;
4use std::cell::Cell;
5use std::cmp::Ordering::{self, *};
6use std::collections::{BTreeMap, HashMap};
7use std::fmt::Debug;
8
9use crate::cc_traits::{Keyed, Map, MapIter, MapMut};
10use crate::collections::{ArrayMap, OptionMap, SingletonMap, VecMap};
11use crate::{Atomize, DeepReveal, IsBot, IsTop, LatticeFrom, LatticeOrd, Max, Merge, Min};
12
13// TODO(mingwei): handling malformed trees - parents must be Ord smaller than children.
14
/// Union-find lattice.
///
/// Each value of `K` in the map represents an item in a set. When two lattice instances are
/// merged, any sets with common elements will be unioned together.
///
/// The map stores, for each item, a pointer (`Cell<K>`) to its parent in its set's tree. An item
/// that is absent from the map, or that maps to itself, is the representative (root) of its set.
///
/// [`Self::union(a, b)`](Self::union) unions two sets together, which is equivalent to merging
/// in a [`UnionFindSingletonMap`] atom of `(a, b)` (or `(b, a)`).
///
/// Any union-find consisting only of singleton sets is bottom.
///
/// ## Hasse diagram of partitions of a set of size four:
///
/// <a href="https://en.wikipedia.org/wiki/File:Set_partitions_4;_Hasse;_circles.svg">
///     <img src="https://upload.wikimedia.org/wikipedia/commons/3/32/Set_partitions_4%3B_Hasse%3B_circles.svg" width="500" />
/// </a>
#[repr(transparent)]
#[derive(Copy, Clone, Debug, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct UnionFind<Map>(Map);
34impl<Map> UnionFind<Map> {
35    /// Create a new `UnionFind` from a `Map`.
36    pub fn new(val: Map) -> Self {
37        Self(val)
38    }
39
40    /// Create a new `UnionFind` from an `Into<Map>`.
41    pub fn new_from(val: impl Into<Map>) -> Self {
42        Self::new(val.into())
43    }
44
45    /// Reveal the inner value as a shared reference.
46    pub fn as_reveal_ref(&self) -> &Map {
47        &self.0
48    }
49
50    /// Reveal the inner value as an exclusive reference.
51    pub fn as_reveal_mut(&mut self) -> &mut Map {
52        &mut self.0
53    }
54
55    /// Gets the inner by value, consuming self.
56    pub fn into_reveal(self) -> Map {
57        self.0
58    }
59}
60
61impl<Map> DeepReveal for UnionFind<Map> {
62    type Revealed = Map;
63
64    fn deep_reveal(self) -> Self::Revealed {
65        self.0
66    }
67}
68
69impl<Map, K> UnionFind<Map>
70where
71    Map: MapMut<K, Cell<K>, Key = K, Item = Cell<K>>,
72    K: Copy + Eq,
73{
74    /// Union the sets containg `a` and `b`.
75    ///
76    /// Returns true if the sets changed, false if `a` and `b` were already in the same set. Once
77    /// this returns false it will always return false for the same `a` and `b`, therefore it
78    /// returns a `Min<bool>` lattice.
79    pub fn union(&mut self, a: K, b: K) -> Min<bool> {
80        let a_root = self.find(a);
81        let b_root = self.find(b);
82        if a_root == b_root {
83            Min::new(false)
84        } else {
85            self.0.insert(b_root, Cell::new(a_root));
86            Min::new(true)
87        }
88    }
89}
90
91impl<MapSelf, K> UnionFind<MapSelf>
92where
93    MapSelf: Map<K, Cell<K>, Key = K, Item = Cell<K>>,
94    K: Copy + Eq,
95{
96    /// Returns if `a` and `b` are in the same set.
97    ///
98    /// This method is monotonic: once this returns true it will always return true for the same
99    /// `a` and `b`, therefore it returns a `Max<bool>` lattice.
100    pub fn same(&self, a: K, b: K) -> Max<bool> {
101        Max::new(a == b || self.find(a) == self.find(b))
102    }
103
104    /// Finds the representative root node for `item`.
105    fn find(&self, mut item: K) -> K {
106        let mut root = item;
107        while let Some(parent) = self.0.get(&root) {
108            // If root is the representative.
109            if parent.get() == root {
110                break;
111            }
112            // Loop detected, close the end.
113            if parent.get() == item {
114                parent.set(root);
115                break;
116            }
117            root = parent.get();
118        }
119        while item != root {
120            item = self.0.get(&item).unwrap().replace(root);
121        }
122        item
123    }
124}
125
126impl<MapSelf, MapOther, K> Merge<UnionFind<MapOther>> for UnionFind<MapSelf>
127where
128    MapSelf: MapMut<K, Cell<K>, Key = K, Item = Cell<K>>,
129    MapOther: IntoIterator<Item = (K, Cell<K>)>,
130    K: Copy + Eq,
131{
132    fn merge(&mut self, other: UnionFind<MapOther>) -> bool {
133        let mut changed = false;
134        for (item, parent) in other.0 {
135            // Do not short circuit.
136            changed |= self.union(item, parent.get()).into_reveal();
137        }
138        changed
139    }
140}
141
142impl<MapSelf, MapOther, K> LatticeFrom<UnionFind<MapOther>> for UnionFind<MapSelf>
143where
144    MapSelf: Keyed<Key = K, Item = Cell<K>> + FromIterator<(K, Cell<K>)>,
145    MapOther: IntoIterator<Item = (K, Cell<K>)>,
146    K: Copy + Eq,
147{
148    fn lattice_from(other: UnionFind<MapOther>) -> Self {
149        Self(other.0.into_iter().collect())
150    }
151}
152
153impl<MapSelf, MapOther, K> PartialOrd<UnionFind<MapOther>> for UnionFind<MapSelf>
154where
155    MapSelf: MapMut<K, Cell<K>, Key = K, Item = Cell<K>> + MapIter,
156    MapOther: MapMut<K, Cell<K>, Key = K, Item = Cell<K>> + MapIter,
157    K: Copy + Eq,
158{
159    fn partial_cmp(&self, other: &UnionFind<MapOther>) -> Option<Ordering> {
160        let self_any_greater = self
161            .0
162            .iter()
163            .any(|(item, parent)| !other.same(*item, parent.get()).into_reveal());
164        let other_any_greater = other
165            .0
166            .iter()
167            .any(|(item, parent)| !self.same(*item, parent.get()).into_reveal());
168        match (self_any_greater, other_any_greater) {
169            (true, true) => None,
170            (true, false) => Some(Greater),
171            (false, true) => Some(Less),
172            (false, false) => Some(Equal),
173        }
174    }
175}
176impl<MapSelf, MapOther> LatticeOrd<UnionFind<MapOther>> for UnionFind<MapSelf> where
177    Self: PartialOrd<UnionFind<MapOther>>
178{
179}
180
181impl<MapSelf, MapOther, K> PartialEq<UnionFind<MapOther>> for UnionFind<MapSelf>
182where
183    MapSelf: MapMut<K, Cell<K>, Key = K, Item = Cell<K>> + MapIter,
184    MapOther: MapMut<K, Cell<K>, Key = K, Item = Cell<K>> + MapIter,
185    K: Copy + Eq,
186{
187    fn eq(&self, other: &UnionFind<MapOther>) -> bool {
188        !(self
189            .0
190            .iter()
191            .any(|(item, parent)| !other.same(*item, parent.get()).into_reveal())
192            || other
193                .0
194                .iter()
195                .any(|(item, parent)| !self.same(*item, parent.get()).into_reveal()))
196    }
197}
198impl<Map> Eq for UnionFind<Map> where Self: PartialEq {}
199
200impl<Map, K> IsBot for UnionFind<Map>
201where
202    Map: MapIter<Key = K, Item = Cell<K>>,
203    K: Copy + Eq,
204{
205    fn is_bot(&self) -> bool {
206        self.0.iter().all(|(a, b)| *a == b.get())
207    }
208}
209
210impl<Map> IsTop for UnionFind<Map> {
211    fn is_top(&self) -> bool {
212        false
213    }
214}
215
216impl<Map, K> Atomize for UnionFind<Map>
217where
218    Map: 'static + MapMut<K, Cell<K>, Key = K, Item = Cell<K>> + IntoIterator<Item = (K, Cell<K>)>,
219    K: 'static + Copy + Eq,
220{
221    type Atom = UnionFindSingletonMap<K>;
222
223    // TODO: use impl trait, then remove 'static.
224    type AtomIter = Box<dyn Iterator<Item = Self::Atom>>;
225
226    fn atomize(self) -> Self::AtomIter {
227        Box::new(
228            self.0
229                .into_iter()
230                .filter(|(a, b)| *a != b.get())
231                .map(UnionFindSingletonMap::new_from),
232        )
233    }
234}
235
236/// [`std::collections::HashMap`]-backed [`UnionFind`] lattice.
237pub type UnionFindHashMap<K> = UnionFind<HashMap<K, Cell<K>>>;
238
239/// [`std::collections::BTreeMap`]-backed [`UnionFind`] lattice.
240pub type UnionFindBTreeMap<K> = UnionFind<BTreeMap<K, Cell<K>>>;
241
242/// [`Vec`](alloc::vec::Vec)-backed [`UnionFind`] lattice.
243pub type UnionFindVec<K> = UnionFind<VecMap<K, Cell<K>>>;
244
245/// Array-backed [`UnionFind`] lattice.
246pub type UnionFindArrayMap<K, const N: usize> = UnionFind<ArrayMap<K, Cell<K>, N>>;
247
248/// [`crate::collections::SingletonMap`]-backed [`UnionFind`] lattice.
249pub type UnionFindSingletonMap<K> = UnionFind<SingletonMap<K, Cell<K>>>;
250
251/// [`Option`]-backed [`UnionFind`] lattice.
252pub type UnionFindOptionMap<K> = UnionFind<OptionMap<K, Cell<K>>>;
253
#[cfg(test)]
mod test {
    use std::println;

    use super::*;
    use crate::test::{check_all, check_atomize_each};

    /// Checks `same('c', 'a')`, `same('c', 'b')`, `same('a', 'b')`, `same('a', 'z')` (in that
    /// order) against `expected`, then checks that `'z'` is still its own representative.
    fn assert_relations(uf: &UnionFindHashMap<char>, expected: [bool; 4]) {
        let pairs = [('c', 'a'), ('c', 'b'), ('a', 'b'), ('a', 'z')];
        for (&(a, b), &want) in pairs.iter().zip(expected.iter()) {
            assert_eq!(want, uf.same(a, b).into_reveal());
        }
        assert_eq!('z', uf.find('z'));
    }

    #[test]
    fn test_basic() {
        let mut my_uf_a = <UnionFindHashMap<char>>::default();
        let my_uf_b = <UnionFindSingletonMap<char>>::new(SingletonMap('c', Cell::new('a')));
        let my_uf_c = UnionFindSingletonMap::new_from(('c', Cell::new('b')));

        // Initially every item is alone in its own set.
        assert_relations(&my_uf_a, [false, false, false, false]);

        // Merging the edge `c -> a` joins only `c` and `a`.
        my_uf_a.merge(my_uf_b);
        assert_relations(&my_uf_a, [true, false, false, false]);

        // Merging the edge `c -> b` pulls `b` into the set of `a` and `c`.
        my_uf_a.merge(my_uf_c);
        assert_relations(&my_uf_a, [true, true, true, false]);
    }

    // Make sure loops are considered as one group and don't hang.
    #[test]
    fn test_malformed() {
        let three_cycle = <UnionFindBTreeMap<char>>::new_from([
            ('a', Cell::new('b')),
            ('b', Cell::new('c')),
            ('c', Cell::new('a')),
        ]);
        let four_cycle = <UnionFindBTreeMap<char>>::new_from([
            ('a', Cell::new('b')),
            ('b', Cell::new('c')),
            ('c', Cell::new('d')),
            ('d', Cell::new('a')),
        ]);
        for my_uf in [three_cycle, four_cycle] {
            println!("{:?}", my_uf);
            assert!(my_uf.same('a', 'b').into_reveal());
            println!("{:?}", my_uf);
        }
    }

    #[test]
    fn consistency_atomize() {
        /// Builds a `UnionFindHashMap` from `(item, parent)` pairs.
        fn uf(pairs: &[(char, char)]) -> UnionFindHashMap<char> {
            UnionFind::new(
                pairs
                    .iter()
                    .map(|&(item, parent)| (item, Cell::new(parent)))
                    .collect(),
            )
        }

        let items = &[
            uf(&[]),
            uf(&[('a', 'a')]),
            uf(&[('a', 'a'), ('b', 'a')]),
            uf(&[('b', 'a')]),
            uf(&[('b', 'a'), ('c', 'b')]),
            uf(&[('b', 'a'), ('c', 'b')]),
            uf(&[('d', 'b')]),
            uf(&[('b', 'a'), ('c', 'b'), ('d', 'a')]),
            uf(&[('b', 'a'), ('c', 'b'), ('d', 'd')]),
        ];

        check_all(items);
        check_atomize_each(items);
    }
}