lattices/
union_find.rs

//! Module containing the [`UnionFind`] lattice and type aliases for different backing data structures.
2
3use std::cell::Cell;
4use std::cmp::Ordering::{self, *};
5use std::collections::{BTreeMap, HashMap};
6use std::fmt::Debug;
7
8use crate::cc_traits::{Keyed, Map, MapIter, MapMut};
9use crate::collections::{ArrayMap, OptionMap, SingletonMap, VecMap};
10use crate::{Atomize, DeepReveal, IsBot, IsTop, LatticeFrom, LatticeOrd, Max, Merge, Min};
11
12// TODO(mingwei): handling malformed trees - parents must be Ord smaller than children.
13
/// Union-find lattice.
///
/// Each value of `K` in the map represents an item in a set. The wrapped map stores, for each
/// item, a `Cell<K>` holding its parent; an item with no entry, or whose entry points back to
/// itself, is the representative (root) of its set. When two lattice instances are merged, any
/// sets with common elements will be unioned together.
///
/// [`union(a, b)`](Self::union) unions two sets together, which is equivalent to merging in a
/// [`UnionFindSingletonMap`] atom of `(a, b)` (or `(b, a)`).
///
/// Any union-find consisting only of singleton sets is bottom.
///
/// ## Hasse diagram of partitions of a set of size four:
///
/// <a href="https://en.wikipedia.org/wiki/File:Set_partitions_4;_Hasse;_circles.svg">
///     <img src="https://upload.wikimedia.org/wikipedia/commons/3/32/Set_partitions_4%3B_Hasse%3B_circles.svg" width="500" />
/// </a>
#[repr(transparent)]
#[derive(Copy, Clone, Debug, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct UnionFind<Map>(Map);
impl<Map> UnionFind<Map> {
    /// Create a new `UnionFind` from a `Map`.
    ///
    /// The map is wrapped as-is; no validation of the parent pointers is performed.
    pub fn new(val: Map) -> Self {
        Self(val)
    }

    /// Create a new `UnionFind` from anything convertible `Into<Map>`.
    pub fn new_from(val: impl Into<Map>) -> Self {
        Self::new(val.into())
    }

    /// Reveal the inner map as a shared reference.
    pub fn as_reveal_ref(&self) -> &Map {
        &self.0
    }

    /// Reveal the inner map as an exclusive reference.
    pub fn as_reveal_mut(&mut self) -> &mut Map {
        &mut self.0
    }

    /// Gets the inner map by value, consuming `self`.
    pub fn into_reveal(self) -> Map {
        self.0
    }
}
59
60impl<Map> DeepReveal for UnionFind<Map> {
61    type Revealed = Map;
62
63    fn deep_reveal(self) -> Self::Revealed {
64        self.0
65    }
66}
67
68impl<Map, K> UnionFind<Map>
69where
70    Map: MapMut<K, Cell<K>, Key = K, Item = Cell<K>>,
71    K: Copy + Eq,
72{
73    /// Union the sets containg `a` and `b`.
74    ///
75    /// Returns true if the sets changed, false if `a` and `b` were already in the same set. Once
76    /// this returns false it will always return false for the same `a` and `b`, therefore it
77    /// returns a `Min<bool>` lattice.
78    pub fn union(&mut self, a: K, b: K) -> Min<bool> {
79        let a_root = self.find(a);
80        let b_root = self.find(b);
81        if a_root == b_root {
82            Min::new(false)
83        } else {
84            self.0.insert(b_root, Cell::new(a_root));
85            Min::new(true)
86        }
87    }
88}
89
90impl<MapSelf, K> UnionFind<MapSelf>
91where
92    MapSelf: Map<K, Cell<K>, Key = K, Item = Cell<K>>,
93    K: Copy + Eq,
94{
95    /// Returns if `a` and `b` are in the same set.
96    ///
97    /// This method is monotonic: once this returns true it will always return true for the same
98    /// `a` and `b`, therefore it returns a `Max<bool>` lattice.
99    pub fn same(&self, a: K, b: K) -> Max<bool> {
100        Max::new(a == b || self.find(a) == self.find(b))
101    }
102
103    /// Finds the representative root node for `item`.
104    fn find(&self, mut item: K) -> K {
105        let mut root = item;
106        while let Some(parent) = self.0.get(&root) {
107            // If root is the representative.
108            if parent.get() == root {
109                break;
110            }
111            // Loop detected, close the end.
112            if parent.get() == item {
113                parent.set(root);
114                break;
115            }
116            root = parent.get();
117        }
118        while item != root {
119            item = self.0.get(&item).unwrap().replace(root);
120        }
121        item
122    }
123}
124
125impl<MapSelf, MapOther, K> Merge<UnionFind<MapOther>> for UnionFind<MapSelf>
126where
127    MapSelf: MapMut<K, Cell<K>, Key = K, Item = Cell<K>>,
128    MapOther: IntoIterator<Item = (K, Cell<K>)>,
129    K: Copy + Eq,
130{
131    fn merge(&mut self, other: UnionFind<MapOther>) -> bool {
132        let mut changed = false;
133        for (item, parent) in other.0 {
134            // Do not short circuit.
135            changed |= self.union(item, parent.get()).into_reveal();
136        }
137        changed
138    }
139}
140
141impl<MapSelf, MapOther, K> LatticeFrom<UnionFind<MapOther>> for UnionFind<MapSelf>
142where
143    MapSelf: Keyed<Key = K, Item = Cell<K>> + FromIterator<(K, Cell<K>)>,
144    MapOther: IntoIterator<Item = (K, Cell<K>)>,
145    K: Copy + Eq,
146{
147    fn lattice_from(other: UnionFind<MapOther>) -> Self {
148        Self(other.0.into_iter().collect())
149    }
150}
151
152impl<MapSelf, MapOther, K> PartialOrd<UnionFind<MapOther>> for UnionFind<MapSelf>
153where
154    MapSelf: MapMut<K, Cell<K>, Key = K, Item = Cell<K>> + MapIter,
155    MapOther: MapMut<K, Cell<K>, Key = K, Item = Cell<K>> + MapIter,
156    K: Copy + Eq,
157{
158    fn partial_cmp(&self, other: &UnionFind<MapOther>) -> Option<Ordering> {
159        let self_any_greater = self
160            .0
161            .iter()
162            .any(|(item, parent)| !other.same(*item, parent.get()).into_reveal());
163        let other_any_greater = other
164            .0
165            .iter()
166            .any(|(item, parent)| !self.same(*item, parent.get()).into_reveal());
167        match (self_any_greater, other_any_greater) {
168            (true, true) => None,
169            (true, false) => Some(Greater),
170            (false, true) => Some(Less),
171            (false, false) => Some(Equal),
172        }
173    }
174}
175impl<MapSelf, MapOther> LatticeOrd<UnionFind<MapOther>> for UnionFind<MapSelf> where
176    Self: PartialOrd<UnionFind<MapOther>>
177{
178}
179
180impl<MapSelf, MapOther, K> PartialEq<UnionFind<MapOther>> for UnionFind<MapSelf>
181where
182    MapSelf: MapMut<K, Cell<K>, Key = K, Item = Cell<K>> + MapIter,
183    MapOther: MapMut<K, Cell<K>, Key = K, Item = Cell<K>> + MapIter,
184    K: Copy + Eq,
185{
186    fn eq(&self, other: &UnionFind<MapOther>) -> bool {
187        !(self
188            .0
189            .iter()
190            .any(|(item, parent)| !other.same(*item, parent.get()).into_reveal())
191            || other
192                .0
193                .iter()
194                .any(|(item, parent)| !self.same(*item, parent.get()).into_reveal()))
195    }
196}
197impl<Map> Eq for UnionFind<Map> where Self: PartialEq {}
198
199impl<Map, K> IsBot for UnionFind<Map>
200where
201    Map: MapIter<Key = K, Item = Cell<K>>,
202    K: Copy + Eq,
203{
204    fn is_bot(&self) -> bool {
205        self.0.iter().all(|(a, b)| *a == b.get())
206    }
207}
208
209impl<Map> IsTop for UnionFind<Map> {
210    fn is_top(&self) -> bool {
211        false
212    }
213}
214
215impl<Map, K> Atomize for UnionFind<Map>
216where
217    Map: 'static + MapMut<K, Cell<K>, Key = K, Item = Cell<K>> + IntoIterator<Item = (K, Cell<K>)>,
218    K: 'static + Copy + Eq,
219{
220    type Atom = UnionFindSingletonMap<K>;
221
222    // TODO: use impl trait, then remove 'static.
223    type AtomIter = Box<dyn Iterator<Item = Self::Atom>>;
224
225    fn atomize(self) -> Self::AtomIter {
226        Box::new(
227            self.0
228                .into_iter()
229                .filter(|(a, b)| *a != b.get())
230                .map(UnionFindSingletonMap::new_from),
231        )
232    }
233}
234
235/// [`std::collections::HashMap`]-backed [`UnionFind`] lattice.
236pub type UnionFindHashMap<K> = UnionFind<HashMap<K, Cell<K>>>;
237
238/// [`std::collections::BTreeMap`]-backed [`UnionFind`] lattice.
239pub type UnionFindBTreeMap<K> = UnionFind<BTreeMap<K, Cell<K>>>;
240
241/// [`Vec`]-backed [`UnionFind`] lattice.
242pub type UnionFindVec<K> = UnionFind<VecMap<K, Cell<K>>>;
243
244/// Array-backed [`UnionFind`] lattice.
245pub type UnionFindArrayMap<K, const N: usize> = UnionFind<ArrayMap<K, Cell<K>, N>>;
246
247/// [`crate::collections::SingletonMap`]-backed [`UnionFind`] lattice.
248pub type UnionFindSingletonMap<K> = UnionFind<SingletonMap<K, Cell<K>>>;
249
250/// [`Option`]-backed [`UnionFind`] lattice.
251pub type UnionFindOptionMap<K> = UnionFind<OptionMap<K, Cell<K>>>;
252
#[cfg(test)]
mod test {
    use super::*;
    use crate::test::{check_all, check_atomize_each};

    /// Checks `uf.same(a, b)` against each `(a, b, expected)` triple in order, then checks that
    /// `'z'` (never inserted) is its own representative.
    fn assert_same_then_z_alone(uf: &UnionFindHashMap<char>, cases: &[(char, char, bool)]) {
        for &(a, b, expected) in cases {
            assert_eq!(expected, uf.same(a, b).into_reveal());
        }
        assert_eq!('z', uf.find('z'));
    }

    #[test]
    fn test_basic() {
        let mut my_uf_a = <UnionFindHashMap<char>>::default();
        let my_uf_b = <UnionFindSingletonMap<char>>::new(SingletonMap('c', Cell::new('a')));
        let my_uf_c = UnionFindSingletonMap::new_from(('c', Cell::new('b')));

        // Initially every item is alone in its own set.
        assert_same_then_z_alone(
            &my_uf_a,
            &[('c', 'a', false), ('c', 'b', false), ('a', 'b', false), ('a', 'z', false)],
        );

        // Joining `c` with `a` leaves `b` separate.
        my_uf_a.merge(my_uf_b);
        assert_same_then_z_alone(
            &my_uf_a,
            &[('c', 'a', true), ('c', 'b', false), ('a', 'b', false), ('a', 'z', false)],
        );

        // Joining `c` with `b` transitively joins `a` and `b` too.
        my_uf_a.merge(my_uf_c);
        assert_same_then_z_alone(
            &my_uf_a,
            &[('c', 'a', true), ('c', 'b', true), ('a', 'b', true), ('a', 'z', false)],
        );
    }

    // Make sure loops are considered as one group and don't hang.
    #[test]
    fn test_malformed() {
        let cycles: [&[(char, char)]; 2] = [
            &[('a', 'b'), ('b', 'c'), ('c', 'a')],
            &[('a', 'b'), ('b', 'c'), ('c', 'd'), ('d', 'a')],
        ];
        for edges in cycles.iter() {
            let my_uf: UnionFindBTreeMap<char> = UnionFind::new(
                edges
                    .iter()
                    .map(|&(item, parent)| (item, Cell::new(parent)))
                    .collect(),
            );
            println!("{:?}", my_uf);
            assert!(my_uf.same('a', 'b').into_reveal());
            println!("{:?}", my_uf);
        }
    }

    #[test]
    fn consistency_atomize() {
        /// Builds a hash-map-backed union-find from `(item, parent)` pairs.
        fn uf(pairs: &[(char, char)]) -> UnionFindHashMap<char> {
            UnionFind::new(
                pairs
                    .iter()
                    .map(|&(item, parent)| (item, Cell::new(parent)))
                    .collect(),
            )
        }

        let items = &[
            uf(&[]),
            uf(&[('a', 'a')]),
            uf(&[('a', 'a'), ('b', 'a')]),
            uf(&[('b', 'a')]),
            uf(&[('b', 'a'), ('c', 'b')]),
            uf(&[('b', 'a'), ('c', 'b')]),
            uf(&[('d', 'b')]),
            uf(&[('b', 'a'), ('c', 'b'), ('d', 'a')]),
            uf(&[('b', 'a'), ('c', 'b'), ('d', 'd')]),
        ];

        check_all(items);
        check_atomize_each(items);
    }
}