1use std::collections::btree_map::Entry;
4use std::collections::{BTreeMap, BTreeSet};
5
6pub fn topo_sort_scc<Id, NodesFn, NodeIds, PredsFn, SuccsFn, PredsIter, SuccsIter>(
9 mut nodes_fn: NodesFn,
10 mut preds_fn: PredsFn,
11 succs_fn: SuccsFn,
12) -> Vec<Id>
13where
14 Id: Copy + Eq + Ord,
15 NodesFn: FnMut() -> NodeIds,
16 NodeIds: IntoIterator<Item = Id>,
17 PredsFn: FnMut(Id) -> PredsIter,
18 SuccsFn: FnMut(Id) -> SuccsIter,
19 PredsIter: IntoIterator<Item = Id>,
20 SuccsIter: IntoIterator<Item = Id>,
21{
22 let scc = scc_kosaraju((nodes_fn)(), &mut preds_fn, succs_fn);
23 let topo_sort_order = {
24 let mut condensed_preds: BTreeMap<Id, Vec<Id>> = Default::default();
26 for v in (nodes_fn)() {
27 let v = scc[&v];
28 condensed_preds.entry(v).or_default().extend(
29 (preds_fn)(v)
30 .into_iter()
31 .map(|u| scc[&u])
32 .filter(|&u| v != u),
33 );
34 }
35
36 topo_sort((nodes_fn)(), |v| {
37 condensed_preds.get(&v).into_iter().flatten().cloned()
38 })
39 .map_err(drop)
40 .expect("No cycles after SCC condensing.")
41 };
42 topo_sort_order
43}
44
/// Topologically sorts the nodes reachable from `node_ids` along predecessor edges.
///
/// `preds_fn(v)` yields the predecessors of `v`, i.e. the sources `u` of edges `u -> v`.
/// Starting from each node of `node_ids` in turn, the graph is walked depth-first through
/// `preds_fn`. A node is emitted only after all of its predecessors have been emitted.
/// Every node is emitted at most once. Predecessors that `node_ids` does not list are still
/// visited and included in the result.
///
/// The walk keeps its own stack instead of recursing, so long dependency chains cannot
/// overflow the call stack.
///
/// # Errors
///
/// If a cycle is reached, returns `Err` holding the nodes of that cycle. They are ordered so
/// that each node is a predecessor of the next one, and the last node is a predecessor of the
/// first.
pub fn topo_sort<Id, NodeIds, PredsFn, PredsIter>(
    node_ids: NodeIds,
    mut preds_fn: PredsFn,
) -> Result<Vec<Id>, Vec<Id>>
where
    Id: Copy + Eq + Ord,
    NodeIds: IntoIterator<Item = Id>,
    PredsFn: FnMut(Id) -> PredsIter,
    PredsIter: IntoIterator<Item = Id>,
{
    // `false`: the node is on the current DFS path (in progress); `true`: it has been emitted.
    let mut done: BTreeMap<Id, bool> = BTreeMap::new();
    let mut order = Vec::new();
    // The current DFS path. Each frame pairs a node with its not-yet-visited predecessors.
    let mut path: Vec<(Id, PredsIter::IntoIter)> = Vec::new();

    for start in node_ids {
        if done.contains_key(&start) {
            continue;
        }
        done.insert(start, false);
        path.push((start, preds_fn(start).into_iter()));

        while let Some((node, preds)) = path.last_mut() {
            match preds.next() {
                None => {
                    // All predecessors have been emitted, so the node itself can be.
                    let node = *node;
                    path.pop();
                    done.insert(node, true);
                    order.push(node);
                }
                Some(pred) => match done.get(&pred).copied() {
                    Some(true) => {}
                    None => {
                        done.insert(pred, false);
                        path.push((pred, preds_fn(pred).into_iter()));
                    }
                    Some(false) => {
                        // `pred` is on the current path, so the slice from `pred` to the top
                        // closes a cycle. Each path entry is a predecessor of the entry before
                        // it, so reversing the slice lists the cycle in edge direction.
                        let cycle_start = path
                            .iter()
                            .position(|&(n, _)| n == pred)
                            .expect("in-progress node must be on the DFS path");
                        return Err(path[cycle_start..].iter().rev().map(|&(n, _)| n).collect());
                    }
                },
            }
        }
    }

    Ok(order)
}
113
/// Computes the strongly connected components (SCCs) of a directed graph with Kosaraju's
/// algorithm.
///
/// * `nodes` are the starting points of the first depth-first pass. They should include every
///   node of the graph.
/// * `succs_fn(u)` yields the successors of `u` (the targets `v` of edges `u -> v`).
/// * `preds_fn(v)` yields the predecessors of `v`. It must describe the same edges as
///   `succs_fn`, reversed.
///
/// Returns a map from each node to the representative ("root") of its component. Two nodes are
/// in the same SCC exactly when they map to the same root.
///
/// The map covers:
/// * every node reached from `nodes` through `succs_fn`;
/// * any node reached from those through `preds_fn`. Such a node is absorbed into the
///   component that first reaches it, which is only meaningful if `nodes` already covered it.
///
/// Both passes use explicit stacks rather than recursion, so long paths cannot overflow the
/// call stack.
pub fn scc_kosaraju<Id, NodeIds, PredsFn, SuccsFn, PredsIter, SuccsIter>(
    nodes: NodeIds,
    mut preds_fn: PredsFn,
    mut succs_fn: SuccsFn,
) -> BTreeMap<Id, Id>
where
    Id: Copy + Eq + Ord,
    NodeIds: IntoIterator<Item = Id>,
    PredsFn: FnMut(Id) -> PredsIter,
    SuccsFn: FnMut(Id) -> SuccsIter,
    PredsIter: IntoIterator<Item = Id>,
    SuccsIter: IntoIterator<Item = Id>,
{
    // Pass 1: depth-first search along successor edges, recording nodes in post-order
    // (finishing order). Each frame pairs a node with its not-yet-visited successors.
    let mut seen = BTreeSet::new();
    let mut finish_order = Vec::new();
    let mut dfs: Vec<(Id, SuccsIter::IntoIter)> = Vec::new();
    for start in nodes {
        if !seen.insert(start) {
            continue;
        }
        dfs.push((start, succs_fn(start).into_iter()));
        while let Some((node, succs)) = dfs.last_mut() {
            if let Some(next) = succs.next() {
                if seen.insert(next) {
                    dfs.push((next, succs_fn(next).into_iter()));
                }
            } else {
                finish_order.push(*node);
                dfs.pop();
            }
        }
    }

    // Pass 2: in reverse finishing order, flood along predecessor edges (the transposed graph).
    // Every still-unassigned node reached from a root belongs to that root's component.
    let mut components = BTreeMap::new();
    let mut pending = Vec::new();
    for root in finish_order.into_iter().rev() {
        pending.push(root);
        while let Some(v) = pending.pop() {
            if let Entry::Vacant(slot) = components.entry(v) {
                slot.insert(root);
                pending.extend(preds_fn(v));
            }
        }
    }
    components
}
185
#[cfg(test)]
mod test {
    use itertools::Itertools;

    use super::*;

    #[test]
    fn test_toposort() {
        let edges = [
            (5, 11),
            (11, 2),
            (11, 9),
            (11, 10),
            (7, 11),
            (7, 8),
            (8, 9),
            (3, 8),
            (3, 10),
        ];

        let sort = topo_sort([2, 3, 5, 7, 8, 9, 10, 11], |v: i32| {
            edges
                .iter()
                .filter(move |&&(_, dst)| v == dst)
                .map(|&(src, _)| src)
        })
        .unwrap_or_else(|cycle| panic!("Did not expect cycle: {:?}", cycle));
        println!("{:?}", sort);

        // Every edge must point forward in the returned order.
        let position: BTreeMap<_, _> = sort.iter().enumerate().map(|(i, &x)| (x, i)).collect();
        for (src, dst) in edges.iter() {
            assert!(position[src] < position[dst], "{} must precede {}", src, dst);
        }
    }

    #[test]
    fn test_toposort_cycle() {
        let edges = [
            ('A', 'B'),
            ('B', 'C'),
            ('C', 'E'),
            ('D', 'B'),
            ('E', 'F'),
            ('E', 'D'),
        ];
        let ids = edges
            .iter()
            .flat_map(|&(a, b)| [a, b])
            .collect::<BTreeSet<_>>();
        // The only cycle is B -> C -> E -> D -> B; any rotation of it is an acceptable report.
        let cycle_rotations = BTreeSet::from_iter([
            ['B', 'C', 'E', 'D'],
            ['C', 'E', 'D', 'B'],
            ['E', 'D', 'B', 'C'],
            ['D', 'B', 'C', 'E'],
        ]);

        // The reported cycle must not depend on the order in which nodes are supplied.
        let permutations = ids.iter().copied().permutations(ids.len());
        for permutation in permutations {
            let result = topo_sort(permutation.iter().copied(), |v: char| {
                edges
                    .iter()
                    .filter(move |&&(_, dst)| v == dst)
                    .map(|&(src, _)| src)
            });
            assert!(result.is_err());
            let cycle = result.unwrap_err();
            assert!(
                cycle_rotations.contains(&*cycle),
                "cycle: {:?}, vertex order: {:?}",
                cycle,
                permutation
            );
        }
    }

    #[test]
    fn test_scc_kosaraju() {
        let edges = [
            ('a', 'b'),
            ('b', 'c'),
            ('b', 'f'),
            ('b', 'e'),
            ('c', 'd'),
            ('c', 'g'),
            ('d', 'c'),
            ('d', 'h'),
            ('e', 'a'),
            ('e', 'f'),
            ('f', 'g'),
            ('g', 'f'),
            ('h', 'd'),
            ('h', 'g'),
        ];

        // List every node, including 'h', instead of relying on it being reachable.
        let scc = scc_kosaraju(
            'a'..='h',
            |v: char| {
                edges
                    .iter()
                    .filter(move |&&(_, dst)| v == dst)
                    .map(|&(src, _)| src)
            },
            |u: char| {
                edges
                    .iter()
                    .filter(move |&&(src, _)| u == src)
                    .map(|&(_, dst)| dst)
            },
        );
        assert_eq!(scc.len(), 8);

        assert_ne!(scc[&'a'], scc[&'c']);
        assert_ne!(scc[&'a'], scc[&'f']);
        assert_ne!(scc[&'c'], scc[&'f']);

        assert_eq!(scc[&'a'], scc[&'b']);
        assert_eq!(scc[&'a'], scc[&'e']);

        assert_eq!(scc[&'c'], scc[&'d']);
        assert_eq!(scc[&'c'], scc[&'h']);

        assert_eq!(scc[&'f'], scc[&'g']);
    }

    #[test]
    fn test_topo_sort_scc_acyclic() {
        let edges = [
            (5, 11),
            (11, 2),
            (11, 9),
            (11, 10),
            (7, 11),
            (7, 8),
            (8, 9),
            (3, 8),
            (3, 10),
        ];
        let nodes = [2, 3, 5, 7, 8, 9, 10, 11];

        let order = topo_sort_scc(
            || nodes,
            |v: i32| {
                edges
                    .iter()
                    .filter(move |&&(_, dst)| v == dst)
                    .map(|&(src, _)| src)
            },
            |u: i32| {
                edges
                    .iter()
                    .filter(move |&&(src, _)| u == src)
                    .map(|&(_, dst)| dst)
            },
        );
        assert_eq!(order.len(), nodes.len());

        let position: BTreeMap<_, _> = order.iter().enumerate().map(|(i, &x)| (x, i)).collect();
        for (src, dst) in edges.iter() {
            assert!(position[src] < position[dst], "{} must precede {}", src, dst);
        }
    }
}