dfir_lang/graph/
graph_algorithms.rs

1//! General graph algorithm utility functions
2
3use std::collections::btree_map::Entry;
4use std::collections::{BTreeMap, BTreeSet};
5
6/// Computes the topological sort of the nodes of a possibly cyclic graph by ordering strongly
7/// connected components together.
8pub fn topo_sort_scc<Id, NodesFn, NodeIds, PredsFn, SuccsFn, PredsIter, SuccsIter>(
9    mut nodes_fn: NodesFn,
10    mut preds_fn: PredsFn,
11    succs_fn: SuccsFn,
12) -> Vec<Id>
13where
14    Id: Copy + Eq + Ord,
15    NodesFn: FnMut() -> NodeIds,
16    NodeIds: IntoIterator<Item = Id>,
17    PredsFn: FnMut(Id) -> PredsIter,
18    SuccsFn: FnMut(Id) -> SuccsIter,
19    PredsIter: IntoIterator<Item = Id>,
20    SuccsIter: IntoIterator<Item = Id>,
21{
22    let scc = scc_kosaraju((nodes_fn)(), &mut preds_fn, succs_fn);
23    let topo_sort_order = {
24        // Condensed each SCC into a single node for toposort.
25        let mut condensed_preds: BTreeMap<Id, Vec<Id>> = Default::default();
26        for v in (nodes_fn)() {
27            let v = scc[&v];
28            condensed_preds.entry(v).or_default().extend(
29                (preds_fn)(v)
30                    .into_iter()
31                    .map(|u| scc[&u])
32                    .filter(|&u| v != u),
33            );
34        }
35
36        topo_sort((nodes_fn)(), |v| {
37            condensed_preds.get(&v).into_iter().flatten().cloned()
38        })
39        .map_err(drop)
40        .expect("No cycles after SCC condensing.")
41    };
42    topo_sort_order
43}
44
/// Produces a topological ordering of `node_ids`: every node appears after all of the
/// predecessors reported for it by `preds_fn`.
///
/// Returns `Ok` with the ordering when the input is a directed acyclic graph (DAG).
///
/// When a cycle is encountered, returns `Err` holding the nodes of that cycle, each listed
/// exactly once.
///
/// <https://en.wikipedia.org/wiki/Topological_sorting>
pub fn topo_sort<Id, NodeIds, PredsFn, PredsIter>(
    node_ids: NodeIds,
    mut preds_fn: PredsFn,
) -> Result<Vec<Id>, Vec<Id>>
where
    Id: Copy + Eq + Ord,
    NodeIds: IntoIterator<Item = Id>,
    PredsFn: FnMut(Id) -> PredsIter,
    PredsIter: IntoIterator<Item = Id>,
{
    /// Visitation state of a node during the depth-first search.
    #[derive(Clone, Copy)]
    enum Mark {
        /// The node is on the current DFS path (temporary mark).
        InProgress,
        /// The node and all of its predecessors have been emitted (permanent mark).
        Done,
    }

    /// Depth-first search along predecessor edges, emitting nodes in post-order.
    ///
    /// On `Err(())`, `output` no longer holds the ordering; it is reused to accumulate the
    /// cycle while the recursion unwinds.
    fn visit<Id, PredsFn, PredsIter>(
        current: Id,
        preds_fn: &mut PredsFn,
        marks: &mut BTreeMap<Id, Mark>,
        output: &mut Vec<Id>,
    ) -> Result<(), ()>
    where
        Id: Copy + Eq + Ord,
        PredsFn: FnMut(Id) -> PredsIter,
        PredsIter: IntoIterator<Item = Id>,
    {
        match marks.get(&current).copied() {
            Some(Mark::Done) => return Ok(()),
            Some(Mark::InProgress) => {
                // Reached a node already on the current path: a cycle. Restart `output` with
                // the node that closes the cycle.
                output.clear();
                output.push(current);
                return Err(());
            }
            None => {}
        }

        marks.insert(current, Mark::InProgress);
        for pred in (preds_fn)(current) {
            if visit(pred, preds_fn, marks, output).is_err() {
                // Keep appending path nodes until the closing node shows up a second time.
                let cycle_closed = output.len() > 1 && output.first() == output.last();
                if !cycle_closed {
                    output.push(current);
                }
                return Err(());
            }
        }
        output.push(current);
        marks.insert(current, Mark::Done);
        Ok(())
    }

    let mut marks: BTreeMap<Id, Mark> = BTreeMap::new();
    let mut output: Vec<Id> = Vec::new();

    for start in node_ids {
        if visit(start, &mut preds_fn, &mut marks, &mut output).is_err() {
            // `output` now ends with the closing node; drop everything up to and including
            // its first occurrence so that each cycle member is listed once.
            let closing = *output.last().unwrap();
            let first_hit = output.iter().position(|&n| n == closing).unwrap();
            output.drain(..=first_hit);
            return Err(output);
        }
    }

    Ok(output)
}
113
/// Computes the strongly connected components of a graph. A strongly connected component is a
/// set of nodes in which every node is reachable from every other node.
///
/// <https://en.wikipedia.org/wiki/Strongly_connected_component>
///
/// Every component is identified by one of its own members, its "representative." The returned
/// `BTreeMap` maps each node ID to the ID of its representative, so two nodes are in the same
/// component exactly when they map to the same representative.
///
/// Implemented with [Kosaraju's algorithm](https://en.wikipedia.org/wiki/Kosaraju%27s_algorithm).
pub fn scc_kosaraju<Id, NodeIds, PredsFn, SuccsFn, PredsIter, SuccsIter>(
    nodes: NodeIds,
    mut preds_fn: PredsFn,
    mut succs_fn: SuccsFn,
) -> BTreeMap<Id, Id>
where
    Id: Copy + Eq + Ord,
    NodeIds: IntoIterator<Item = Id>,
    PredsFn: FnMut(Id) -> PredsIter,
    SuccsFn: FnMut(Id) -> SuccsIter,
    PredsIter: IntoIterator<Item = Id>,
    SuccsIter: IntoIterator<Item = Id>,
{
    /// First pass: depth-first search along successor edges, appending each node to
    /// `finish_order` once all of its successors have been handled.
    fn record_finish_order<Id, SuccsFn, SuccsIter>(
        node: Id,
        succs_fn: &mut SuccsFn,
        visited: &mut BTreeSet<Id>,
        finish_order: &mut Vec<Id>,
    ) where
        Id: Copy + Eq + Ord,
        SuccsFn: FnMut(Id) -> SuccsIter,
        SuccsIter: IntoIterator<Item = Id>,
    {
        if !visited.insert(node) {
            // Already visited.
            return;
        }
        for succ in (succs_fn)(node) {
            record_finish_order(succ, succs_fn, visited, finish_order);
        }
        finish_order.push(node);
    }

    /// Second pass: assigns `representative` to `node` and, transitively along predecessor
    /// edges, to every node that has not been assigned yet.
    fn claim_component<Id, PredsFn, PredsIter>(
        node: Id,
        representative: Id,
        preds_fn: &mut PredsFn,
        representatives: &mut BTreeMap<Id, Id>,
    ) where
        Id: Copy + Eq + Ord,
        PredsFn: FnMut(Id) -> PredsIter,
        PredsIter: IntoIterator<Item = Id>,
    {
        match representatives.entry(node) {
            // Already belongs to a component.
            Entry::Occupied(_) => {}
            Entry::Vacant(slot) => {
                slot.insert(representative);
                for pred in (preds_fn)(node) {
                    claim_component(pred, representative, preds_fn, representatives);
                }
            }
        }
    }

    let mut visited: BTreeSet<Id> = BTreeSet::new();
    let mut finish_order: Vec<Id> = Vec::new();
    for node in nodes {
        record_finish_order(node, &mut succs_fn, &mut visited, &mut finish_order);
    }

    // Process nodes in reverse finishing order; each still-unassigned node becomes the
    // representative of a new component.
    let mut representatives: BTreeMap<Id, Id> = BTreeMap::new();
    while let Some(root) = finish_order.pop() {
        claim_component(root, root, &mut preds_fn, &mut representatives);
    }
    representatives
}
185
#[cfg(test)]
mod test {
    use itertools::Itertools;

    use super::*;

    /// A known DAG must sort without error, and every edge must point forward in the result.
    #[test]
    pub fn test_toposort() {
        // https://commons.wikimedia.org/wiki/File:Directed_acyclic_graph_2.svg
        let edges = [
            (5, 11),
            (11, 2),
            (11, 9),
            (11, 10),
            (7, 11),
            (7, 8),
            (8, 9),
            (3, 8),
            (3, 10),
        ];

        let sort = topo_sort([2, 3, 5, 7, 8, 9, 10, 11], |v| {
            edges
                .iter()
                .filter(move |&&(_, dst)| v == dst)
                .map(|&(src, _)| src)
        });
        let sort = match sort {
            Ok(order) => order,
            Err(cycle) => panic!("Did not expect cycle: {:?}", cycle),
        };
        println!("{:?}", sort);

        // Index of each node within the sorted output.
        let mut position = BTreeMap::new();
        for (index, &node) in sort.iter().enumerate() {
            position.insert(node, index);
        }
        for &(src, dst) in &edges {
            assert!(position[&src] < position[&dst]);
        }
    }

    /// Whatever order the vertices are supplied in, the reported cycle must be a rotation of
    /// `B -> C -> E -> D`.
    #[test]
    pub fn test_toposort_cycle() {
        // https://commons.wikimedia.org/wiki/File:Directed_graph,_cyclic.svg
        //          ┌────►C──────┐
        //          │            │
        //          │            ▼
        // A───────►B            E ─────►F
        //          ▲            │
        //          │            │
        //          └─────D◄─────┘
        let edges = [
            ('A', 'B'),
            ('B', 'C'),
            ('C', 'E'),
            ('D', 'B'),
            ('E', 'F'),
            ('E', 'D'),
        ];

        // Every vertex mentioned by some edge.
        let mut ids = BTreeSet::new();
        for &(src, dst) in &edges {
            ids.insert(src);
            ids.insert(dst);
        }

        let cycle_rotations = BTreeSet::from_iter([
            ['B', 'C', 'E', 'D'],
            ['C', 'E', 'D', 'B'],
            ['E', 'D', 'B', 'C'],
            ['D', 'B', 'C', 'E'],
        ]);

        for permutation in ids.iter().copied().permutations(ids.len()) {
            let result = topo_sort(permutation.iter().copied(), |v| {
                edges
                    .iter()
                    .filter(move |&&(_, dst)| v == dst)
                    .map(|&(src, _)| src)
            });
            assert!(result.is_err());
            let cycle = result.unwrap_err();
            assert!(
                cycle_rotations.contains(cycle.as_slice()),
                "cycle: {:?}, vertex order: {:?}",
                cycle,
                permutation
            );
        }
    }

    /// The example graph has three components: `{a, b, e}`, `{c, d, h}` and `{f, g}`.
    #[test]
    pub fn test_scc_kosaraju() {
        // https://commons.wikimedia.org/wiki/File:Scc-1.svg
        let edges = [
            ('a', 'b'),
            ('b', 'c'),
            ('b', 'f'),
            ('b', 'e'),
            ('c', 'd'),
            ('c', 'g'),
            ('d', 'c'),
            ('d', 'h'),
            ('e', 'a'),
            ('e', 'f'),
            ('f', 'g'),
            ('g', 'f'),
            ('h', 'd'),
            ('h', 'g'),
        ];

        let scc = scc_kosaraju(
            'a'..='g',
            |v| {
                edges
                    .iter()
                    .filter(move |&&(_, dst)| v == dst)
                    .map(|&(src, _)| src)
            },
            |u| {
                edges
                    .iter()
                    .filter(move |&&(src, _)| u == src)
                    .map(|&(_, dst)| dst)
            },
        );

        // Nodes from different components must have different representatives.
        for (x, y) in [('a', 'c'), ('a', 'f'), ('c', 'f')] {
            assert_ne!(scc[&x], scc[&y]);
        }

        // Nodes from the same component must share a representative.
        for (x, y) in [('a', 'b'), ('a', 'e'), ('c', 'd'), ('c', 'h'), ('f', 'g')] {
            assert_eq!(scc[&x], scc[&y]);
        }
    }
}