1use std::collections::{BTreeMap, BTreeSet};
4
5use slotmap::{SecondaryMap, SparseSecondaryMap};
6use syn::parse_quote;
7
8use super::meta_graph::DfirGraph;
9use super::ops::{DelayType, FloType, find_node_op_constraints};
10use super::{Color, GraphEdgeId, GraphNode, GraphNodeId, GraphSubgraphId, graph_algorithms};
11use crate::diagnostic::{Diagnostic, Level};
12use crate::union_find::UnionFind;
13
14struct BarrierCrossers {
16 pub edge_barrier_crossers: SecondaryMap<GraphEdgeId, DelayType>,
18 pub singleton_barrier_crossers: Vec<(GraphNodeId, GraphNodeId)>,
20}
21impl BarrierCrossers {
22 fn iter_node_pairs<'a>(
24 &'a self,
25 partitioned_graph: &'a DfirGraph,
26 ) -> impl 'a + Iterator<Item = ((GraphNodeId, GraphNodeId), DelayType)> {
27 let edge_pairs_iter = self
28 .edge_barrier_crossers
29 .iter()
30 .map(|(edge_id, &delay_type)| {
31 let src_dst = partitioned_graph.edge(edge_id);
32 (src_dst, delay_type)
33 });
34 let singleton_pairs_iter = self
35 .singleton_barrier_crossers
36 .iter()
37 .map(|&src_dst| (src_dst, DelayType::Stratum));
38 edge_pairs_iter.chain(singleton_pairs_iter)
39 }
40
41 fn replace_edge(&mut self, old_edge_id: GraphEdgeId, new_edge_id: GraphEdgeId) {
43 if let Some(delay_type) = self.edge_barrier_crossers.remove(old_edge_id) {
44 self.edge_barrier_crossers.insert(new_edge_id, delay_type);
45 }
46 }
47}
48
49fn find_barrier_crossers(partitioned_graph: &DfirGraph) -> BarrierCrossers {
51 let edge_barrier_crossers = partitioned_graph
52 .edges()
53 .filter(|&(_, (_src, dst))| {
54 partitioned_graph.node_loop(dst).is_none()
56 })
57 .filter_map(|(edge_id, (_src, dst))| {
58 let (_src_port, dst_port) = partitioned_graph.edge_ports(edge_id);
59 let op_constraints = partitioned_graph.node_op_inst(dst)?.op_constraints;
60 let input_barrier = (op_constraints.input_delaytype_fn)(dst_port)?;
61 Some((edge_id, input_barrier))
62 })
63 .collect();
64 let singleton_barrier_crossers = partitioned_graph
65 .node_ids()
66 .flat_map(|dst| {
67 partitioned_graph
68 .node_singleton_references(dst)
69 .iter()
70 .flatten()
71 .map(move |&src_ref| (src_ref, dst))
72 })
73 .collect();
74 BarrierCrossers {
75 edge_barrier_crossers,
76 singleton_barrier_crossers,
77 }
78}
79
80fn find_subgraph_unionfind(
81 partitioned_graph: &DfirGraph,
82 barrier_crossers: &BarrierCrossers,
83) -> (UnionFind<GraphNodeId>, BTreeSet<GraphEdgeId>) {
84 let mut node_color = partitioned_graph
89 .node_ids()
90 .filter_map(|node_id| {
91 let op_color = partitioned_graph.node_color(node_id)?;
92 Some((node_id, op_color))
93 })
94 .collect::<SparseSecondaryMap<_, _>>();
95
96 let mut subgraph_unionfind: UnionFind<GraphNodeId> =
97 UnionFind::with_capacity(partitioned_graph.nodes().len());
98
99 let mut handoff_edges: BTreeSet<GraphEdgeId> = partitioned_graph.edge_ids().collect();
102 let mut progress = true;
111 while progress {
112 progress = false;
113 for (edge_id, (src, dst)) in partitioned_graph.edges().collect::<Vec<_>>() {
115 if subgraph_unionfind.same_set(src, dst) {
117 continue;
120 }
121
122 if barrier_crossers
124 .iter_node_pairs(partitioned_graph)
125 .any(|((x_src, x_dst), _)| {
126 (subgraph_unionfind.same_set(x_src, src)
127 && subgraph_unionfind.same_set(x_dst, dst))
128 || (subgraph_unionfind.same_set(x_src, dst)
129 && subgraph_unionfind.same_set(x_dst, src))
130 })
131 {
132 continue;
133 }
134
135 if partitioned_graph.node_loop(src) != partitioned_graph.node_loop(dst) {
137 continue;
138 }
139 if partitioned_graph.node_op_inst(dst).is_some_and(|op_inst| {
141 Some(FloType::NextIteration) == op_inst.op_constraints.flo_type
142 }) {
143 continue;
144 }
145
146 if can_connect_colorize(&mut node_color, src, dst) {
147 subgraph_unionfind.union(src, dst);
150 assert!(handoff_edges.remove(&edge_id));
151 progress = true;
152 }
153 }
154 }
155
156 (subgraph_unionfind, handoff_edges)
157}
158
159fn make_subgraph_collect(
163 partitioned_graph: &DfirGraph,
164 mut subgraph_unionfind: UnionFind<GraphNodeId>,
165) -> SecondaryMap<GraphNodeId, Vec<GraphNodeId>> {
166 let topo_sort = graph_algorithms::topo_sort(
170 partitioned_graph
171 .nodes()
172 .filter(|&(_, node)| !matches!(node, GraphNode::Handoff { .. }))
173 .map(|(node_id, _)| node_id),
174 |v| {
175 partitioned_graph
176 .node_predecessor_nodes(v)
177 .filter(|&pred_id| {
178 let pred = partitioned_graph.node(pred_id);
179 !matches!(pred, GraphNode::Handoff { .. })
180 })
181 },
182 )
183 .expect("Subgraphs are in-out trees.");
184
185 let mut grouped_nodes: SecondaryMap<GraphNodeId, Vec<GraphNodeId>> = Default::default();
186 for node_id in topo_sort {
187 let repr_node = subgraph_unionfind.find(node_id);
188 if !grouped_nodes.contains_key(repr_node) {
189 grouped_nodes.insert(repr_node, Default::default());
190 }
191 grouped_nodes[repr_node].push(node_id);
192 }
193 grouped_nodes
194}
195
196fn make_subgraphs(partitioned_graph: &mut DfirGraph, barrier_crossers: &mut BarrierCrossers) {
200 let (subgraph_unionfind, handoff_edges) =
209 find_subgraph_unionfind(partitioned_graph, barrier_crossers);
210
211 for edge_id in handoff_edges {
213 let (src_id, dst_id) = partitioned_graph.edge(edge_id);
214
215 let src_node = partitioned_graph.node(src_id);
217 let dst_node = partitioned_graph.node(dst_id);
218 if matches!(src_node, GraphNode::Handoff { .. })
219 || matches!(dst_node, GraphNode::Handoff { .. })
220 {
221 continue;
222 }
223
224 let hoff = GraphNode::Handoff {
225 src_span: src_node.span(),
226 dst_span: dst_node.span(),
227 };
228 let (_node_id, out_edge_id) = partitioned_graph.insert_intermediate_node(edge_id, hoff);
229
230 barrier_crossers.replace_edge(edge_id, out_edge_id);
232 }
233
234 let grouped_nodes = make_subgraph_collect(partitioned_graph, subgraph_unionfind);
238 for (_repr_node, member_nodes) in grouped_nodes {
239 partitioned_graph.insert_subgraph(member_nodes).unwrap();
240 }
241}
242
243fn can_connect_colorize(
249 node_color: &mut SparseSecondaryMap<GraphNodeId, Color>,
250 src: GraphNodeId,
251 dst: GraphNodeId,
252) -> bool {
253 let can_connect = match (node_color.get(src), node_color.get(dst)) {
258 (None, None) => false,
261
262 (None, Some(Color::Pull | Color::Comp)) => {
264 node_color.insert(src, Color::Pull);
265 true
266 }
267 (None, Some(Color::Push | Color::Hoff)) => {
268 node_color.insert(src, Color::Push);
269 true
270 }
271
272 (Some(Color::Pull | Color::Hoff), None) => {
274 node_color.insert(dst, Color::Pull);
275 true
276 }
277 (Some(Color::Comp | Color::Push), None) => {
278 node_color.insert(dst, Color::Push);
279 true
280 }
281
282 (Some(Color::Pull), Some(Color::Pull)) => true,
284 (Some(Color::Pull), Some(Color::Comp)) => true,
285 (Some(Color::Pull), Some(Color::Push)) => true,
286
287 (Some(Color::Comp), Some(Color::Pull)) => false,
288 (Some(Color::Comp), Some(Color::Comp)) => false,
289 (Some(Color::Comp), Some(Color::Push)) => true,
290
291 (Some(Color::Push), Some(Color::Pull)) => false,
292 (Some(Color::Push), Some(Color::Comp)) => false,
293 (Some(Color::Push), Some(Color::Push)) => true,
294
295 (Some(Color::Hoff), Some(_)) => false,
297 (Some(_), Some(Color::Hoff)) => false,
298 };
299 can_connect
300}
301
302fn find_subgraph_strata(
306 partitioned_graph: &mut DfirGraph,
307 barrier_crossers: &BarrierCrossers,
308) -> Result<(), Diagnostic> {
309 #[derive(Default)]
319 struct SubgraphGraph {
320 preds: BTreeMap<GraphSubgraphId, Vec<GraphSubgraphId>>,
321 succs: BTreeMap<GraphSubgraphId, Vec<GraphSubgraphId>>,
322 }
323 impl SubgraphGraph {
324 fn insert_edge(&mut self, src: GraphSubgraphId, dst: GraphSubgraphId) {
325 self.preds.entry(dst).or_default().push(src);
326 self.succs.entry(src).or_default().push(dst);
327 }
328 }
329 let mut subgraph_graph = SubgraphGraph::default();
330
331 let mut subgraph_stratum_barriers: BTreeSet<(GraphSubgraphId, GraphSubgraphId)> =
333 Default::default();
334
335 for (node_id, node) in partitioned_graph.nodes() {
337 if matches!(node, GraphNode::Handoff { .. }) {
338 assert_eq!(1, partitioned_graph.node_successors(node_id).count());
339 let (succ_edge, succ) = partitioned_graph.node_successors(node_id).next().unwrap();
340
341 let succ_edge_delaytype = barrier_crossers
343 .edge_barrier_crossers
344 .get(succ_edge)
345 .copied();
346 if let Some(DelayType::Tick | DelayType::TickLazy) = succ_edge_delaytype {
348 continue;
349 }
350
351 assert_eq!(1, partitioned_graph.node_predecessors(node_id).count());
352 let (_edge_id, pred) = partitioned_graph.node_predecessors(node_id).next().unwrap();
353
354 let pred_sg = partitioned_graph.node_subgraph(pred).unwrap();
355 let succ_sg = partitioned_graph.node_subgraph(succ).unwrap();
356
357 subgraph_graph.insert_edge(pred_sg, succ_sg);
358
359 if Some(DelayType::Stratum) == succ_edge_delaytype {
360 subgraph_stratum_barriers.insert((pred_sg, succ_sg));
361 }
362 }
363 }
364 for &(pred, succ) in barrier_crossers.singleton_barrier_crossers.iter() {
367 assert_ne!(pred, succ, "TODO(mingwei)");
368 let pred_sg = partitioned_graph.node_subgraph(pred).unwrap();
369 let succ_sg = partitioned_graph.node_subgraph(succ).unwrap();
370 assert_ne!(pred_sg, succ_sg);
371 subgraph_graph.insert_edge(pred_sg, succ_sg);
372 subgraph_stratum_barriers.insert((pred_sg, succ_sg));
373 }
374
375 let topo_sort_order = graph_algorithms::topo_sort_scc(
378 || partitioned_graph.subgraph_ids(),
379 |v| subgraph_graph.preds.get(&v).into_iter().flatten().cloned(),
380 |u| subgraph_graph.succs.get(&u).into_iter().flatten().cloned(),
381 );
382
383 for sg_id in topo_sort_order {
389 let curr_loop = partitioned_graph.subgraph_loop(sg_id);
390
391 let stratum = subgraph_graph
392 .preds
393 .get(&sg_id)
394 .into_iter()
395 .flatten()
396 .filter_map(|&pred_sg_id| {
397 partitioned_graph
398 .subgraph_stratum(pred_sg_id)
399 .map(|stratum| {
400 let pred_loop = partitioned_graph.subgraph_loop(pred_sg_id);
401 if curr_loop != pred_loop {
402 stratum + 1
404 } else if curr_loop.is_none()
405 && subgraph_stratum_barriers.contains(&(pred_sg_id, sg_id))
406 {
407 stratum + 1
409 } else {
410 stratum
411 }
412 })
413 })
414 .max()
415 .unwrap_or(0);
416 partitioned_graph.set_subgraph_stratum(sg_id, stratum);
417 }
418
419 let extra_stratum = partitioned_graph.max_stratum().unwrap_or(0) + 1; for (edge_id, &delay_type) in barrier_crossers.edge_barrier_crossers.iter() {
422 let (hoff, dst) = partitioned_graph.edge(edge_id);
423 if partitioned_graph.node_loop(dst).is_some() {
425 continue;
426 }
427 let (_hoff_port, dst_port) = partitioned_graph.edge_ports(edge_id);
428
429 assert_eq!(1, partitioned_graph.node_predecessors(hoff).count());
430 let src = partitioned_graph
431 .node_predecessor_nodes(hoff)
432 .next()
433 .unwrap();
434
435 let src_sg = partitioned_graph.node_subgraph(src).unwrap();
436 let dst_sg = partitioned_graph.node_subgraph(dst).unwrap();
437 let src_stratum = partitioned_graph.subgraph_stratum(src_sg);
438 let dst_stratum = partitioned_graph.subgraph_stratum(dst_sg);
439 let dst_span = partitioned_graph.node(dst).span();
440 match delay_type {
441 DelayType::Tick | DelayType::TickLazy => {
442 let is_lazy = matches!(delay_type, DelayType::TickLazy);
443 if src_stratum <= dst_stratum || is_lazy {
447 let (new_node_id, new_edge_id) = partitioned_graph.insert_intermediate_node(
453 edge_id,
454 GraphNode::Operator(parse_quote! { identity() }),
456 );
457 let hoff = GraphNode::Handoff {
459 src_span: dst_span,
461 dst_span,
462 };
463 let (_hoff_node_id, _hoff_edge_id) =
464 partitioned_graph.insert_intermediate_node(new_edge_id, hoff);
465 let new_subgraph_id = partitioned_graph
470 .insert_subgraph(vec![new_node_id])
471 .unwrap();
472
473 partitioned_graph.set_subgraph_stratum(new_subgraph_id, extra_stratum);
475
476 partitioned_graph.set_subgraph_laziness(new_subgraph_id, is_lazy);
478 }
479 }
480 DelayType::Stratum => {
481 if dst_stratum <= src_stratum {
485 return Err(Diagnostic::spanned(
486 dst_port.span(),
487 Level::Error,
488 "Negative edge creates a negative cycle which must be broken with a `defer_tick()` operator.",
489 ));
490 }
491 }
492 DelayType::MonotoneAccum => {
493 continue;
495 }
496 }
497 }
498 Ok(())
499}
500
501fn separate_external_inputs(partitioned_graph: &mut DfirGraph) {
504 let external_input_nodes: Vec<_> = partitioned_graph
505 .nodes()
506 .filter_map(|(node_id, node)| {
508 find_node_op_constraints(node).map(|op_constraints| (node_id, op_constraints))
509 })
510 .filter(|(_node_id, op_constraints)| op_constraints.is_external_input)
512 .map(|(node_id, _op_constraints)| node_id)
514 .filter(|&node_id| {
516 0 != partitioned_graph
517 .subgraph_stratum(partitioned_graph.node_subgraph(node_id).unwrap())
518 .unwrap()
519 })
520 .collect();
521
522 for node_id in external_input_nodes {
523 assert!(
525 partitioned_graph.remove_from_subgraph(node_id),
526 "Cannot move input node that is not in a subgraph, this is a bug."
527 );
528 let new_sg_id = partitioned_graph.insert_subgraph(vec![node_id]).unwrap();
530 partitioned_graph.set_subgraph_stratum(new_sg_id, 0);
531
532 for edge_id in partitioned_graph
534 .node_successor_edges(node_id)
535 .collect::<Vec<_>>()
536 {
537 let span = partitioned_graph.node(node_id).span();
538 let hoff = GraphNode::Handoff {
539 src_span: span,
540 dst_span: span,
541 };
542 partitioned_graph.insert_intermediate_node(edge_id, hoff);
543 }
544 }
545}
546
547pub fn partition_graph(flat_graph: DfirGraph) -> Result<DfirGraph, Diagnostic> {
551 let mut barrier_crossers = find_barrier_crossers(&flat_graph);
553 let mut partitioned_graph = flat_graph;
554
555 make_subgraphs(&mut partitioned_graph, &mut barrier_crossers);
557
558 find_subgraph_strata(&mut partitioned_graph, &barrier_crossers)?;
560
561 separate_external_inputs(&mut partitioned_graph);
563
564 Ok(partitioned_graph)
565}