dfir_lang/graph/
flat_to_partitioned.rs

//! Subgraph partitioning algorithm
2
3use std::collections::{BTreeMap, BTreeSet};
4
5use slotmap::{SecondaryMap, SparseSecondaryMap};
6use syn::parse_quote;
7
8use super::meta_graph::DfirGraph;
9use super::ops::{DelayType, FloType, find_node_op_constraints};
10use super::{Color, GraphEdgeId, GraphNode, GraphNodeId, GraphSubgraphId, graph_algorithms};
11use crate::diagnostic::{Diagnostic, Level};
12use crate::union_find::UnionFind;
13
14/// Helper struct for tracking barrier crossers, see [`find_barrier_crossers`].
15struct BarrierCrossers {
16    /// Edge barrier crossers, including what type.
17    pub edge_barrier_crossers: SecondaryMap<GraphEdgeId, DelayType>,
18    /// Singleton reference barrier crossers, considered to be [`DelayType::Stratum`].
19    pub singleton_barrier_crossers: Vec<(GraphNodeId, GraphNodeId)>,
20}
21impl BarrierCrossers {
22    /// Iterate pairs of nodes that are across a barrier. Excludes `DelayType::NextIteration` pairs.
23    fn iter_node_pairs<'a>(
24        &'a self,
25        partitioned_graph: &'a DfirGraph,
26    ) -> impl 'a + Iterator<Item = ((GraphNodeId, GraphNodeId), DelayType)> {
27        let edge_pairs_iter = self
28            .edge_barrier_crossers
29            .iter()
30            .map(|(edge_id, &delay_type)| {
31                let src_dst = partitioned_graph.edge(edge_id);
32                (src_dst, delay_type)
33            });
34        let singleton_pairs_iter = self
35            .singleton_barrier_crossers
36            .iter()
37            .map(|&src_dst| (src_dst, DelayType::Stratum));
38        edge_pairs_iter.chain(singleton_pairs_iter)
39    }
40
41    /// Insert/replace edge.
42    fn replace_edge(&mut self, old_edge_id: GraphEdgeId, new_edge_id: GraphEdgeId) {
43        if let Some(delay_type) = self.edge_barrier_crossers.remove(old_edge_id) {
44            self.edge_barrier_crossers.insert(new_edge_id, delay_type);
45        }
46    }
47}
48
49/// Find all the barrier crossers.
50fn find_barrier_crossers(partitioned_graph: &DfirGraph) -> BarrierCrossers {
51    let edge_barrier_crossers = partitioned_graph
52        .edges()
53        .filter(|&(_, (_src, dst))| {
54            // Ignore barriers within `loop {` blocks.
55            partitioned_graph.node_loop(dst).is_none()
56        })
57        .filter_map(|(edge_id, (_src, dst))| {
58            let (_src_port, dst_port) = partitioned_graph.edge_ports(edge_id);
59            let op_constraints = partitioned_graph.node_op_inst(dst)?.op_constraints;
60            let input_barrier = (op_constraints.input_delaytype_fn)(dst_port)?;
61            Some((edge_id, input_barrier))
62        })
63        .collect();
64    let singleton_barrier_crossers = partitioned_graph
65        .node_ids()
66        .flat_map(|dst| {
67            partitioned_graph
68                .node_singleton_references(dst)
69                .iter()
70                .flatten()
71                .map(move |&src_ref| (src_ref, dst))
72        })
73        .collect();
74    BarrierCrossers {
75        edge_barrier_crossers,
76        singleton_barrier_crossers,
77    }
78}
79
80fn find_subgraph_unionfind(
81    partitioned_graph: &DfirGraph,
82    barrier_crossers: &BarrierCrossers,
83) -> (UnionFind<GraphNodeId>, BTreeSet<GraphEdgeId>) {
84    // Modality (color) of nodes, push or pull.
85    // TODO(mingwei)? This does NOT consider `DelayType` barriers (which generally imply `Pull`),
86    // which makes it inconsistant with the final output in `as_code()`. But this doesn't create
87    // any bugs since we exclude `DelayType` edges from joining subgraphs anyway.
88    let mut node_color = partitioned_graph
89        .node_ids()
90        .filter_map(|node_id| {
91            let op_color = partitioned_graph.node_color(node_id)?;
92            Some((node_id, op_color))
93        })
94        .collect::<SparseSecondaryMap<_, _>>();
95
96    let mut subgraph_unionfind: UnionFind<GraphNodeId> =
97        UnionFind::with_capacity(partitioned_graph.nodes().len());
98
99    // Will contain all edges which are handoffs. Starts out with all edges and
100    // we remove from this set as we combine nodes into subgraphs.
101    let mut handoff_edges: BTreeSet<GraphEdgeId> = partitioned_graph.edge_ids().collect();
102    // Would sort edges here for priority (for now, no sort/priority).
103
104    // Each edge gets looked at in order. However we may not know if a linear
105    // chain of operators is PUSH vs PULL until we look at the ends. A fancier
106    // algorithm would know to handle linear chains from the outside inward.
107    // But instead we just run through the edges in a loop until no more
108    // progress is made. Could have some sort of O(N^2) pathological worst
109    // case.
110    let mut progress = true;
111    while progress {
112        progress = false;
113        // TODO(mingwei): Could this iterate `handoff_edges` instead? (Modulo ownership). Then no case (1) below.
114        for (edge_id, (src, dst)) in partitioned_graph.edges().collect::<Vec<_>>() {
115            // Ignore (1) already added edges as well as (2) new self-cycles. (Unless reference edge).
116            if subgraph_unionfind.same_set(src, dst) {
117                // Note that the _edge_ `edge_id` might not be in the subgraph even when both `src` and `dst` are. This prevents case 2.
118                // Handoffs will be inserted later for this self-loop.
119                continue;
120            }
121
122            // Do not connect stratum crossers (next edges).
123            if barrier_crossers
124                .iter_node_pairs(partitioned_graph)
125                .any(|((x_src, x_dst), _)| {
126                    (subgraph_unionfind.same_set(x_src, src)
127                        && subgraph_unionfind.same_set(x_dst, dst))
128                        || (subgraph_unionfind.same_set(x_src, dst)
129                            && subgraph_unionfind.same_set(x_dst, src))
130                })
131            {
132                continue;
133            }
134
135            // Do not connect across loop contexts.
136            if partitioned_graph.node_loop(src) != partitioned_graph.node_loop(dst) {
137                continue;
138            }
139            // Do not connect `next_iteration()`.
140            if partitioned_graph.node_op_inst(dst).is_some_and(|op_inst| {
141                Some(FloType::NextIteration) == op_inst.op_constraints.flo_type
142            }) {
143                continue;
144            }
145
146            if can_connect_colorize(&mut node_color, src, dst) {
147                // At this point we have selected this edge and its src & dst to be
148                // within a single subgraph.
149                subgraph_unionfind.union(src, dst);
150                assert!(handoff_edges.remove(&edge_id));
151                progress = true;
152            }
153        }
154    }
155
156    (subgraph_unionfind, handoff_edges)
157}
158
159/// Builds the datastructures for checking which subgraph each node belongs to
160/// after handoffs have already been inserted to partition subgraphs.
161/// This list of nodes in each subgraph are returned in topological sort order.
162fn make_subgraph_collect(
163    partitioned_graph: &DfirGraph,
164    mut subgraph_unionfind: UnionFind<GraphNodeId>,
165) -> SecondaryMap<GraphNodeId, Vec<GraphNodeId>> {
166    // We want the nodes of each subgraph to be listed in topo-sort order.
167    // We could do this on each subgraph, or we could do it all at once on the
168    // whole node graph by ignoring handoffs, which is what we do here:
169    let topo_sort = graph_algorithms::topo_sort(
170        partitioned_graph
171            .nodes()
172            .filter(|&(_, node)| !matches!(node, GraphNode::Handoff { .. }))
173            .map(|(node_id, _)| node_id),
174        |v| {
175            partitioned_graph
176                .node_predecessor_nodes(v)
177                .filter(|&pred_id| {
178                    let pred = partitioned_graph.node(pred_id);
179                    !matches!(pred, GraphNode::Handoff { .. })
180                })
181        },
182    )
183    .expect("Subgraphs are in-out trees.");
184
185    let mut grouped_nodes: SecondaryMap<GraphNodeId, Vec<GraphNodeId>> = Default::default();
186    for node_id in topo_sort {
187        let repr_node = subgraph_unionfind.find(node_id);
188        if !grouped_nodes.contains_key(repr_node) {
189            grouped_nodes.insert(repr_node, Default::default());
190        }
191        grouped_nodes[repr_node].push(node_id);
192    }
193    grouped_nodes
194}
195
196/// Find subgraph and insert handoffs.
197/// Modifies barrier_crossers so that the edge OUT of an inserted handoff has
198/// the DelayType data.
199fn make_subgraphs(partitioned_graph: &mut DfirGraph, barrier_crossers: &mut BarrierCrossers) {
200    // Algorithm:
201    // 1. Each node begins as its own subgraph.
202    // 2. Collect edges. (Future optimization: sort so edges which should not be split across a handoff come first).
203    // 3. For each edge, try to join `(to, from)` into the same subgraph.
204
205    // TODO(mingwei):
206    // self.partitioned_graph.assert_valid();
207
208    let (subgraph_unionfind, handoff_edges) =
209        find_subgraph_unionfind(partitioned_graph, barrier_crossers);
210
211    // Insert handoffs between subgraphs (or on subgraph self-loop edges)
212    for edge_id in handoff_edges {
213        let (src_id, dst_id) = partitioned_graph.edge(edge_id);
214
215        // Already has a handoff, no need to insert one.
216        let src_node = partitioned_graph.node(src_id);
217        let dst_node = partitioned_graph.node(dst_id);
218        if matches!(src_node, GraphNode::Handoff { .. })
219            || matches!(dst_node, GraphNode::Handoff { .. })
220        {
221            continue;
222        }
223
224        let hoff = GraphNode::Handoff {
225            src_span: src_node.span(),
226            dst_span: dst_node.span(),
227        };
228        let (_node_id, out_edge_id) = partitioned_graph.insert_intermediate_node(edge_id, hoff);
229
230        // Update barrier_crossers for inserted node.
231        barrier_crossers.replace_edge(edge_id, out_edge_id);
232    }
233
234    // Determine node's subgraph and subgraph's nodes.
235    // This list of nodes in each subgraph are to be in topological sort order.
236    // Eventually returned directly in the [`DfirGraph`].
237    let grouped_nodes = make_subgraph_collect(partitioned_graph, subgraph_unionfind);
238    for (_repr_node, member_nodes) in grouped_nodes {
239        partitioned_graph.insert_subgraph(member_nodes).unwrap();
240    }
241}
242
243/// Set `src` or `dst` color if `None` based on the other (if possible):
244/// `None` indicates an op could be pull or push i.e. unary-in & unary-out.
245/// So in that case we color `src` or `dst` based on its newfound neighbor (the other one).
246///
247/// Returns if `src` and `dst` can be in the same subgraph.
248fn can_connect_colorize(
249    node_color: &mut SparseSecondaryMap<GraphNodeId, Color>,
250    src: GraphNodeId,
251    dst: GraphNodeId,
252) -> bool {
253    // Pull -> Pull
254    // Push -> Push
255    // Pull -> [Computation] -> Push
256    // Push -> [Handoff] -> Pull
257    let can_connect = match (node_color.get(src), node_color.get(dst)) {
258        // Linear chain, can't connect because it may cause future conflicts.
259        // But if it doesn't in the _future_ we can connect it (once either/both ends are determined).
260        (None, None) => false,
261
262        // Infer left side.
263        (None, Some(Color::Pull | Color::Comp)) => {
264            node_color.insert(src, Color::Pull);
265            true
266        }
267        (None, Some(Color::Push | Color::Hoff)) => {
268            node_color.insert(src, Color::Push);
269            true
270        }
271
272        // Infer right side.
273        (Some(Color::Pull | Color::Hoff), None) => {
274            node_color.insert(dst, Color::Pull);
275            true
276        }
277        (Some(Color::Comp | Color::Push), None) => {
278            node_color.insert(dst, Color::Push);
279            true
280        }
281
282        // Both sides already specified.
283        (Some(Color::Pull), Some(Color::Pull)) => true,
284        (Some(Color::Pull), Some(Color::Comp)) => true,
285        (Some(Color::Pull), Some(Color::Push)) => true,
286
287        (Some(Color::Comp), Some(Color::Pull)) => false,
288        (Some(Color::Comp), Some(Color::Comp)) => false,
289        (Some(Color::Comp), Some(Color::Push)) => true,
290
291        (Some(Color::Push), Some(Color::Pull)) => false,
292        (Some(Color::Push), Some(Color::Comp)) => false,
293        (Some(Color::Push), Some(Color::Push)) => true,
294
295        // Handoffs are not part of subgraphs.
296        (Some(Color::Hoff), Some(_)) => false,
297        (Some(_), Some(Color::Hoff)) => false,
298    };
299    can_connect
300}
301
302/// Stratification is surprisingly tricky. Basically it is topological sort, but with some nuance.
303///
304/// Returns an error if there is a cycle thru negation.
305fn find_subgraph_strata(
306    partitioned_graph: &mut DfirGraph,
307    barrier_crossers: &BarrierCrossers,
308) -> Result<(), Diagnostic> {
309    // Determine subgraphs's stratum number.
310    // Find SCCs ignoring `defer_tick()` (`DelayType::Tick`) edges, then do TopoSort on the
311    // resulting DAG.
312    // Cycles thru cross-stratum negative edges (both `DelayType::Tick` and `DelayType::Stratum`)
313    // are an error.
314
315    // Generate a subgraph graph. I.e. each node is a subgraph.
316    // Edges are connections between subgraphs, ignoring tick-crossers.
317    // TODO: use DiMulGraph here?
318    #[derive(Default)]
319    struct SubgraphGraph {
320        preds: BTreeMap<GraphSubgraphId, Vec<GraphSubgraphId>>,
321        succs: BTreeMap<GraphSubgraphId, Vec<GraphSubgraphId>>,
322    }
323    impl SubgraphGraph {
324        fn insert_edge(&mut self, src: GraphSubgraphId, dst: GraphSubgraphId) {
325            self.preds.entry(dst).or_default().push(src);
326            self.succs.entry(src).or_default().push(dst);
327        }
328    }
329    let mut subgraph_graph = SubgraphGraph::default();
330
331    // Negative (next stratum) connections between subgraphs. (Ignore `defer_tick()` connections).
332    let mut subgraph_stratum_barriers: BTreeSet<(GraphSubgraphId, GraphSubgraphId)> =
333        Default::default();
334
335    // Iterate handoffs between subgraphs, to build a subgraph meta-graph.
336    for (node_id, node) in partitioned_graph.nodes() {
337        if matches!(node, GraphNode::Handoff { .. }) {
338            assert_eq!(1, partitioned_graph.node_successors(node_id).count());
339            let (succ_edge, succ) = partitioned_graph.node_successors(node_id).next().unwrap();
340
341            // TODO(mingwei): Should we look at the singleton references too?
342            let succ_edge_delaytype = barrier_crossers
343                .edge_barrier_crossers
344                .get(succ_edge)
345                .copied();
346            // Ignore tick edges.
347            if let Some(DelayType::Tick | DelayType::TickLazy) = succ_edge_delaytype {
348                continue;
349            }
350
351            assert_eq!(1, partitioned_graph.node_predecessors(node_id).count());
352            let (_edge_id, pred) = partitioned_graph.node_predecessors(node_id).next().unwrap();
353
354            let pred_sg = partitioned_graph.node_subgraph(pred).unwrap();
355            let succ_sg = partitioned_graph.node_subgraph(succ).unwrap();
356
357            subgraph_graph.insert_edge(pred_sg, succ_sg);
358
359            if Some(DelayType::Stratum) == succ_edge_delaytype {
360                subgraph_stratum_barriers.insert((pred_sg, succ_sg));
361            }
362        }
363    }
364    // Include reference edges as well.
365    // TODO(mingwei): deduplicate graph building code.
366    for &(pred, succ) in barrier_crossers.singleton_barrier_crossers.iter() {
367        assert_ne!(pred, succ, "TODO(mingwei)");
368        let pred_sg = partitioned_graph.node_subgraph(pred).unwrap();
369        let succ_sg = partitioned_graph.node_subgraph(succ).unwrap();
370        assert_ne!(pred_sg, succ_sg);
371        subgraph_graph.insert_edge(pred_sg, succ_sg);
372        subgraph_stratum_barriers.insert((pred_sg, succ_sg));
373    }
374
375    // Topological sort (of strongly connected components) is how we find the (nondecreasing)
376    // order of strata.
377    let topo_sort_order = graph_algorithms::topo_sort_scc(
378        || partitioned_graph.subgraph_ids(),
379        |v| subgraph_graph.preds.get(&v).into_iter().flatten().cloned(),
380        |u| subgraph_graph.succs.get(&u).into_iter().flatten().cloned(),
381    );
382
383    // Each subgraph's stratum number is the same as it's predecessors.
384    //
385    // Unless:
386    // - At the top level: there is a negative edge (e.g. `fold()`), then we increment.
387    // - Entering or exiting a loop.
388    for sg_id in topo_sort_order {
389        let curr_loop = partitioned_graph.subgraph_loop(sg_id);
390
391        let stratum = subgraph_graph
392            .preds
393            .get(&sg_id)
394            .into_iter()
395            .flatten()
396            .filter_map(|&pred_sg_id| {
397                partitioned_graph
398                    .subgraph_stratum(pred_sg_id)
399                    .map(|stratum| {
400                        let pred_loop = partitioned_graph.subgraph_loop(pred_sg_id);
401                        if curr_loop != pred_loop {
402                            // Entering or exiting a loop.
403                            stratum + 1
404                        } else if curr_loop.is_none()
405                            && subgraph_stratum_barriers.contains(&(pred_sg_id, sg_id))
406                        {
407                            // Top level && negative edge.
408                            stratum + 1
409                        } else {
410                            stratum
411                        }
412                    })
413            })
414            .max()
415            .unwrap_or(0);
416        partitioned_graph.set_subgraph_stratum(sg_id, stratum);
417    }
418
419    // Re-introduce the `defer_tick()` edges, ensuring they actually go to the next tick.
420    let extra_stratum = partitioned_graph.max_stratum().unwrap_or(0) + 1; // Used for `defer_tick()` delayer subgraphs.
421    for (edge_id, &delay_type) in barrier_crossers.edge_barrier_crossers.iter() {
422        let (hoff, dst) = partitioned_graph.edge(edge_id);
423        // Ignore barriers within `loop {` blocks.
424        if partitioned_graph.node_loop(dst).is_some() {
425            continue;
426        }
427        let (_hoff_port, dst_port) = partitioned_graph.edge_ports(edge_id);
428
429        assert_eq!(1, partitioned_graph.node_predecessors(hoff).count());
430        let src = partitioned_graph
431            .node_predecessor_nodes(hoff)
432            .next()
433            .unwrap();
434
435        let src_sg = partitioned_graph.node_subgraph(src).unwrap();
436        let dst_sg = partitioned_graph.node_subgraph(dst).unwrap();
437        let src_stratum = partitioned_graph.subgraph_stratum(src_sg);
438        let dst_stratum = partitioned_graph.subgraph_stratum(dst_sg);
439        let dst_span = partitioned_graph.node(dst).span();
440        match delay_type {
441            DelayType::Tick | DelayType::TickLazy => {
442                let is_lazy = matches!(delay_type, DelayType::TickLazy);
443                // If tick edge goes foreward in stratum, need to buffer.
444                // (TODO(mingwei): could use a different kind of handoff.)
445                // Or if lazy, need to create extra subgraph to mark as lazy.
446                if src_stratum <= dst_stratum || is_lazy {
447                    // We inject a new subgraph between the src/dst which runs as the last stratum
448                    // of the tick and therefore delays the data until the next tick.
449
450                    // Before: A (src) -> H -> B (dst)
451                    // Then add intermediate identity:
452                    let (new_node_id, new_edge_id) = partitioned_graph.insert_intermediate_node(
453                        edge_id,
454                        // TODO(mingwei): Proper span w/ `parse_quote_spanned!`?
455                        GraphNode::Operator(parse_quote! { identity() }),
456                    );
457                    // Intermediate: A (src) -> H -> ID -> B (dst)
458                    let hoff = GraphNode::Handoff {
459                        // Span to the node that has the input stratum barrier.
460                        src_span: dst_span,
461                        dst_span,
462                    };
463                    let (_hoff_node_id, _hoff_edge_id) =
464                        partitioned_graph.insert_intermediate_node(new_edge_id, hoff);
465                    // After: A (src) -> H -> ID -> H' -> B (dst)
466
467                    // Set stratum number for new intermediate:
468                    // Create subgraph.
469                    let new_subgraph_id = partitioned_graph
470                        .insert_subgraph(vec![new_node_id])
471                        .unwrap();
472
473                    // Assign stratum.
474                    partitioned_graph.set_subgraph_stratum(new_subgraph_id, extra_stratum);
475
476                    // Assign laziness.
477                    partitioned_graph.set_subgraph_laziness(new_subgraph_id, is_lazy);
478                }
479            }
480            DelayType::Stratum => {
481                // Any negative edges which go onto the same or previous stratum are bad.
482                // Indicates an unbroken negative cycle.
483                // TODO(mingwei): This check is insufficient: https://github.com/hydro-project/hydro/issues/1115#issuecomment-2018385033
484                if dst_stratum <= src_stratum {
485                    return Err(Diagnostic::spanned(
486                        dst_port.span(),
487                        Level::Error,
488                        "Negative edge creates a negative cycle which must be broken with a `defer_tick()` operator.",
489                    ));
490                }
491            }
492            DelayType::MonotoneAccum => {
493                // cycles are actually fine
494                continue;
495            }
496        }
497    }
498    Ok(())
499}
500
501/// Put `is_external_input: true` operators in separate stratum 0 subgraphs if they are not in stratum 0.
502/// By ripping them out of their subgraph/stratum if they're not already in statum 0.
503fn separate_external_inputs(partitioned_graph: &mut DfirGraph) {
504    let external_input_nodes: Vec<_> = partitioned_graph
505        .nodes()
506        // Ensure node is an operator (not a handoff), get constraints spec.
507        .filter_map(|(node_id, node)| {
508            find_node_op_constraints(node).map(|op_constraints| (node_id, op_constraints))
509        })
510        // Ensure current `node_id` is an external input.
511        .filter(|(_node_id, op_constraints)| op_constraints.is_external_input)
512        // Collect just `node_id`s.
513        .map(|(node_id, _op_constraints)| node_id)
514        // Ignore if operator node is already stratum 0.
515        .filter(|&node_id| {
516            0 != partitioned_graph
517                .subgraph_stratum(partitioned_graph.node_subgraph(node_id).unwrap())
518                .unwrap()
519        })
520        .collect();
521
522    for node_id in external_input_nodes {
523        // Remove node from old subgraph.
524        assert!(
525            partitioned_graph.remove_from_subgraph(node_id),
526            "Cannot move input node that is not in a subgraph, this is a bug."
527        );
528        // Create new subgraph in stratum 0 for this source.
529        let new_sg_id = partitioned_graph.insert_subgraph(vec![node_id]).unwrap();
530        partitioned_graph.set_subgraph_stratum(new_sg_id, 0);
531
532        // Insert handoff.
533        for edge_id in partitioned_graph
534            .node_successor_edges(node_id)
535            .collect::<Vec<_>>()
536        {
537            let span = partitioned_graph.node(node_id).span();
538            let hoff = GraphNode::Handoff {
539                src_span: span,
540                dst_span: span,
541            };
542            partitioned_graph.insert_intermediate_node(edge_id, hoff);
543        }
544    }
545}
546
547/// Main method for this module. Partions a flat [`DfirGraph`] into one with subgraphs.
548///
549/// Returns an error if a negative cycle exists in the graph. Negative cycles prevent partioning.
550pub fn partition_graph(flat_graph: DfirGraph) -> Result<DfirGraph, Diagnostic> {
551    // Pre-find barrier crossers (input edges with a `DelayType`).
552    let mut barrier_crossers = find_barrier_crossers(&flat_graph);
553    let mut partitioned_graph = flat_graph;
554
555    // Partition into subgraphs.
556    make_subgraphs(&mut partitioned_graph, &mut barrier_crossers);
557
558    // Find strata for subgraphs (early returns with error if negative cycle found).
559    find_subgraph_strata(&mut partitioned_graph, &barrier_crossers)?;
560
561    // Ensure all external inputs are in stratum 0.
562    separate_external_inputs(&mut partitioned_graph);
563
564    Ok(partitioned_graph)
565}