1use std::boxed::Box;
4use std::cell::Cell;
5use std::cmp::Ordering::{self, *};
6use std::collections::{BTreeMap, HashMap};
7use std::fmt::Debug;
8
9use crate::cc_traits::{Keyed, Map, MapIter, MapMut};
10use crate::collections::{ArrayMap, OptionMap, SingletonMap, VecMap};
11use crate::{Atomize, DeepReveal, IsBot, IsTop, LatticeFrom, LatticeOrd, Max, Merge, Min};
12
/// Union-find (disjoint-set) lattice wrapper around a parent-pointer map.
///
/// The wrapped `Map` associates each key with a [`Cell`] holding that key's
/// current parent. This type itself places no bounds on `Map`; the lattice
/// operations are provided by the bounded `impl` blocks in this file.
#[repr(transparent)]
#[derive(Copy, Clone, Debug, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct UnionFind<Map>(Map);
impl<Map> UnionFind<Map> {
    /// Wraps an existing map as a `UnionFind`.
    pub fn new(val: Map) -> Self {
        UnionFind(val)
    }

    /// Builds a `UnionFind` from anything convertible into the underlying map.
    pub fn new_from(val: impl Into<Map>) -> Self {
        let map: Map = val.into();
        UnionFind(map)
    }

    /// Borrows the underlying map.
    pub fn as_reveal_ref(&self) -> &Map {
        let UnionFind(map) = self;
        map
    }

    /// Mutably borrows the underlying map.
    pub fn as_reveal_mut(&mut self) -> &mut Map {
        let UnionFind(map) = self;
        map
    }

    /// Consumes the wrapper and returns the underlying map.
    pub fn into_reveal(self) -> Map {
        let UnionFind(map) = self;
        map
    }
}
60
61impl<Map> DeepReveal for UnionFind<Map> {
62 type Revealed = Map;
63
64 fn deep_reveal(self) -> Self::Revealed {
65 self.0
66 }
67}
68
69impl<Map, K> UnionFind<Map>
70where
71 Map: MapMut<K, Cell<K>, Key = K, Item = Cell<K>>,
72 K: Copy + Eq,
73{
74 pub fn union(&mut self, a: K, b: K) -> Min<bool> {
80 let a_root = self.find(a);
81 let b_root = self.find(b);
82 if a_root == b_root {
83 Min::new(false)
84 } else {
85 self.0.insert(b_root, Cell::new(a_root));
86 Min::new(true)
87 }
88 }
89}
90
91impl<MapSelf, K> UnionFind<MapSelf>
92where
93 MapSelf: Map<K, Cell<K>, Key = K, Item = Cell<K>>,
94 K: Copy + Eq,
95{
96 pub fn same(&self, a: K, b: K) -> Max<bool> {
101 Max::new(a == b || self.find(a) == self.find(b))
102 }
103
104 fn find(&self, mut item: K) -> K {
106 let mut root = item;
107 while let Some(parent) = self.0.get(&root) {
108 if parent.get() == root {
110 break;
111 }
112 if parent.get() == item {
114 parent.set(root);
115 break;
116 }
117 root = parent.get();
118 }
119 while item != root {
120 item = self.0.get(&item).unwrap().replace(root);
121 }
122 item
123 }
124}
125
126impl<MapSelf, MapOther, K> Merge<UnionFind<MapOther>> for UnionFind<MapSelf>
127where
128 MapSelf: MapMut<K, Cell<K>, Key = K, Item = Cell<K>>,
129 MapOther: IntoIterator<Item = (K, Cell<K>)>,
130 K: Copy + Eq,
131{
132 fn merge(&mut self, other: UnionFind<MapOther>) -> bool {
133 let mut changed = false;
134 for (item, parent) in other.0 {
135 changed |= self.union(item, parent.get()).into_reveal();
137 }
138 changed
139 }
140}
141
142impl<MapSelf, MapOther, K> LatticeFrom<UnionFind<MapOther>> for UnionFind<MapSelf>
143where
144 MapSelf: Keyed<Key = K, Item = Cell<K>> + FromIterator<(K, Cell<K>)>,
145 MapOther: IntoIterator<Item = (K, Cell<K>)>,
146 K: Copy + Eq,
147{
148 fn lattice_from(other: UnionFind<MapOther>) -> Self {
149 Self(other.0.into_iter().collect())
150 }
151}
152
153impl<MapSelf, MapOther, K> PartialOrd<UnionFind<MapOther>> for UnionFind<MapSelf>
154where
155 MapSelf: MapMut<K, Cell<K>, Key = K, Item = Cell<K>> + MapIter,
156 MapOther: MapMut<K, Cell<K>, Key = K, Item = Cell<K>> + MapIter,
157 K: Copy + Eq,
158{
159 fn partial_cmp(&self, other: &UnionFind<MapOther>) -> Option<Ordering> {
160 let self_any_greater = self
161 .0
162 .iter()
163 .any(|(item, parent)| !other.same(*item, parent.get()).into_reveal());
164 let other_any_greater = other
165 .0
166 .iter()
167 .any(|(item, parent)| !self.same(*item, parent.get()).into_reveal());
168 match (self_any_greater, other_any_greater) {
169 (true, true) => None,
170 (true, false) => Some(Greater),
171 (false, true) => Some(Less),
172 (false, false) => Some(Equal),
173 }
174 }
175}
176impl<MapSelf, MapOther> LatticeOrd<UnionFind<MapOther>> for UnionFind<MapSelf> where
177 Self: PartialOrd<UnionFind<MapOther>>
178{
179}
180
181impl<MapSelf, MapOther, K> PartialEq<UnionFind<MapOther>> for UnionFind<MapSelf>
182where
183 MapSelf: MapMut<K, Cell<K>, Key = K, Item = Cell<K>> + MapIter,
184 MapOther: MapMut<K, Cell<K>, Key = K, Item = Cell<K>> + MapIter,
185 K: Copy + Eq,
186{
187 fn eq(&self, other: &UnionFind<MapOther>) -> bool {
188 !(self
189 .0
190 .iter()
191 .any(|(item, parent)| !other.same(*item, parent.get()).into_reveal())
192 || other
193 .0
194 .iter()
195 .any(|(item, parent)| !self.same(*item, parent.get()).into_reveal()))
196 }
197}
198impl<Map> Eq for UnionFind<Map> where Self: PartialEq {}
199
200impl<Map, K> IsBot for UnionFind<Map>
201where
202 Map: MapIter<Key = K, Item = Cell<K>>,
203 K: Copy + Eq,
204{
205 fn is_bot(&self) -> bool {
206 self.0.iter().all(|(a, b)| *a == b.get())
207 }
208}
209
210impl<Map> IsTop for UnionFind<Map> {
211 fn is_top(&self) -> bool {
212 false
213 }
214}
215
216impl<Map, K> Atomize for UnionFind<Map>
217where
218 Map: 'static + MapMut<K, Cell<K>, Key = K, Item = Cell<K>> + IntoIterator<Item = (K, Cell<K>)>,
219 K: 'static + Copy + Eq,
220{
221 type Atom = UnionFindSingletonMap<K>;
222
223 type AtomIter = Box<dyn Iterator<Item = Self::Atom>>;
225
226 fn atomize(self) -> Self::AtomIter {
227 Box::new(
228 self.0
229 .into_iter()
230 .filter(|(a, b)| *a != b.get())
231 .map(UnionFindSingletonMap::new_from),
232 )
233 }
234}
235
236pub type UnionFindHashMap<K> = UnionFind<HashMap<K, Cell<K>>>;
238
239pub type UnionFindBTreeMap<K> = UnionFind<BTreeMap<K, Cell<K>>>;
241
242pub type UnionFindVec<K> = UnionFind<VecMap<K, Cell<K>>>;
244
245pub type UnionFindArrayMap<K, const N: usize> = UnionFind<ArrayMap<K, Cell<K>, N>>;
247
248pub type UnionFindSingletonMap<K> = UnionFind<SingletonMap<K, Cell<K>>>;
250
251pub type UnionFindOptionMap<K> = UnionFind<OptionMap<K, Cell<K>>>;
253
#[cfg(test)]
mod test {
    use std::println;

    use super::*;
    use crate::test::{check_all, check_atomize_each};

    /// Checks which of `'a'`, `'b'`, `'c'` currently share a set in `uf`, and
    /// that the never-inserted key `'z'` stays separate and is its own root.
    fn assert_grouping(
        uf: &UnionFindHashMap<char>,
        c_with_a: bool,
        c_with_b: bool,
        a_with_b: bool,
    ) {
        assert_eq!(c_with_a, uf.same('c', 'a').into_reveal());
        assert_eq!(c_with_b, uf.same('c', 'b').into_reveal());
        assert_eq!(a_with_b, uf.same('a', 'b').into_reveal());
        assert!(!uf.same('a', 'z').into_reveal());
        assert_eq!('z', uf.find('z'));
    }

    /// Builds a union-find from `links` and checks that `'a'` and `'b'` end up
    /// in one set, printing the map before and after the lookup.
    fn assert_a_b_grouped<const N: usize>(links: [(char, Cell<char>); N]) {
        let my_uf = <UnionFindBTreeMap<char>>::new_from(links);
        println!("{:?}", my_uf);
        assert!(my_uf.same('a', 'b').into_reveal());
        println!("{:?}", my_uf);
    }

    /// Merging singleton atoms into an empty union-find grows the groups step
    /// by step.
    #[test]
    fn test_basic() {
        let mut my_uf_a = <UnionFindHashMap<char>>::default();
        let my_uf_b = <UnionFindSingletonMap<char>>::new(SingletonMap('c', Cell::new('a')));
        let my_uf_c = UnionFindSingletonMap::new_from(('c', Cell::new('b')));

        // Nothing merged yet: every item is alone.
        assert_grouping(&my_uf_a, false, false, false);

        // After merging `c -> a`, only `a` and `c` are grouped.
        my_uf_a.merge(my_uf_b);
        assert_grouping(&my_uf_a, true, false, false);

        // After merging `c -> b`, all of `a`, `b`, `c` are grouped.
        my_uf_a.merge(my_uf_c);
        assert_grouping(&my_uf_a, true, true, true);
    }

    /// Parent chains that loop back on themselves must be treated as a single
    /// group, and looking them up must terminate.
    #[test]
    fn test_malformed() {
        // Three-element cycle.
        assert_a_b_grouped([
            ('a', Cell::new('b')),
            ('b', Cell::new('c')),
            ('c', Cell::new('a')),
        ]);
        // Four-element cycle.
        assert_a_b_grouped([
            ('a', Cell::new('b')),
            ('b', Cell::new('c')),
            ('c', Cell::new('d')),
            ('d', Cell::new('a')),
        ]);
    }

    /// Runs the crate's shared lattice checks and the atomize check over a
    /// handful of sample union-finds.
    #[test]
    fn consistency_atomize() {
        let link = |child: char, parent: char| (child, Cell::new(parent));
        let items = &[
            <UnionFindHashMap<char>>::default(),
            <UnionFindHashMap<_>>::new_from([link('a', 'a')]),
            <UnionFindHashMap<_>>::new_from([link('a', 'a'), link('b', 'a')]),
            <UnionFindHashMap<_>>::new_from([link('b', 'a')]),
            <UnionFindHashMap<_>>::new_from([link('b', 'a'), link('c', 'b')]),
            <UnionFindHashMap<_>>::new_from([link('b', 'a'), link('c', 'b')]),
            <UnionFindHashMap<_>>::new_from([link('d', 'b')]),
            <UnionFindHashMap<_>>::new_from([link('b', 'a'), link('c', 'b'), link('d', 'a')]),
            <UnionFindHashMap<_>>::new_from([link('b', 'a'), link('c', 'b'), link('d', 'd')]),
        ];

        check_all(items);
        check_atomize_each(items);
    }
}