1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
//! Union-find data structure, see [`UnionFind`].

use slotmap::{Key, SecondaryMap};

/// Union-find data structure.
///
/// Used to efficiently track sets of equivalent items.
///
/// <https://en.wikipedia.org/wiki/Disjoint-set_data_structure>
#[derive(Default, Clone)]
pub struct UnionFind<K>
where
    K: Key,
{
    links: SecondaryMap<K, K>,
}
impl<K> UnionFind<K>
where
    K: Key,
{
    /// Creates a new `UnionFind`, same as [`Default::default()`].
    pub fn new() -> Self {
        Self::default()
    }
    /// Creates a new `UnionFind` with the given key capacity pre-allocated.
    pub fn with_capacity(capacity: usize) -> Self {
        Self {
            links: SecondaryMap::with_capacity(capacity),
        }
    }

    /// Combines two items `a` and `b` as equivalent, in the same set.
    pub fn union(&mut self, a: K, b: K) {
        let i = self.find(a);
        let j = self.find(b);
        if i == j {
            return;
        }
        self.links[i] = j;
    }

    /// Finds the "representative" item for `k`. Each set of equivalent items is represented by one
    /// of its member items.
    pub fn find(&mut self, k: K) -> K {
        if let Some(next) = self.links.insert(k, k) {
            if k == next {
                return k;
            }
            self.links[k] = self.find(next);
        }
        self.links[k]
    }

    /// Returns if `a` and `b` are equivalent, i.e. in the same set.
    pub fn same_set(&mut self, a: K, b: K) -> bool {
        self.find(a) == self.find(b)
    }
}

#[cfg(test)]
mod test {
    use slotmap::SlotMap;

    use super::*;

    #[test]
    fn test_basic() {
        let mut sm = SlotMap::new();
        let a = sm.insert(());
        let b = sm.insert(());
        let c = sm.insert(());
        let d = sm.insert(());

        let mut uf = UnionFind::new();
        assert!(!uf.same_set(a, b));
        uf.union(a, b);
        assert!(uf.same_set(a, b));
        uf.union(c, a);
        assert!(uf.same_set(b, c));

        assert!(!uf.same_set(a, d));
        assert!(!uf.same_set(b, d));
        assert!(!uf.same_set(d, c));
    }
}