dfir_lang/
union_find.rs

1//! Union-find data structure, see [`UnionFind`].
2
3use slotmap::{Key, SecondaryMap};
4
5/// Union-find data structure.
6///
7/// Used to efficiently track sets of equivalent items.
8///
9/// <https://en.wikipedia.org/wiki/Disjoint-set_data_structure>
10#[derive(Default, Clone)]
11pub struct UnionFind<K>
12where
13    K: Key,
14{
15    links: SecondaryMap<K, K>,
16}
17impl<K> UnionFind<K>
18where
19    K: Key,
20{
21    /// Creates a new `UnionFind`, same as [`Default::default()`].
22    pub fn new() -> Self {
23        Self::default()
24    }
25    /// Creates a new `UnionFind` with the given key capacity pre-allocated.
26    pub fn with_capacity(capacity: usize) -> Self {
27        Self {
28            links: SecondaryMap::with_capacity(capacity),
29        }
30    }
31
32    /// Combines two items `a` and `b` as equivalent, in the same set.
33    pub fn union(&mut self, a: K, b: K) {
34        let i = self.find(a);
35        let j = self.find(b);
36        if i == j {
37            return;
38        }
39        self.links[i] = j;
40    }
41
42    /// Finds the "representative" item for `k`. Each set of equivalent items is represented by one
43    /// of its member items.
44    pub fn find(&mut self, k: K) -> K {
45        if let Some(next) = self.links.insert(k, k) {
46            if k == next {
47                return k;
48            }
49            self.links[k] = self.find(next);
50        }
51        self.links[k]
52    }
53
54    /// Returns if `a` and `b` are equivalent, i.e. in the same set.
55    pub fn same_set(&mut self, a: K, b: K) -> bool {
56        self.find(a) == self.find(b)
57    }
58}
59
#[cfg(test)]
mod test {
    use slotmap::SlotMap;

    use super::*;

    /// Merges `a`, `b` and `c` step by step and checks that `d`, never merged, stays apart.
    #[test]
    fn test_basic() {
        let mut keys = SlotMap::new();
        let [a, b, c, d] = [(); 4].map(|unit| keys.insert(unit));

        let mut uf = UnionFind::new();

        // Distinct keys start out in distinct sets.
        assert!(!uf.same_set(a, b));

        // A direct merge is visible immediately.
        uf.union(a, b);
        assert!(uf.same_set(a, b));

        // Merging is transitive: `c` joins `{a, b}`, so it matches `b` without a direct union.
        uf.union(c, a);
        assert!(uf.same_set(b, c));

        // `d` was never merged, so it is separate from every other key.
        for other in [a, b] {
            assert!(!uf.same_set(other, d));
        }
        assert!(!uf.same_set(d, c));
    }
}