1use std::collections::{BTreeMap, BTreeSet};
4
5use proc_macro2::Span;
6use slotmap::{SecondaryMap, SparseSecondaryMap};
7use syn::parse_quote;
8
9use super::meta_graph::DfirGraph;
10use super::ops::{DelayType, FloType, find_node_op_constraints};
11use super::{Color, GraphEdgeId, GraphNode, GraphNodeId, GraphSubgraphId, graph_algorithms};
12use crate::diagnostic::{Diagnostic, Level};
13use crate::union_find::UnionFind;
14
15struct BarrierCrossers {
17 pub edge_barrier_crossers: SecondaryMap<GraphEdgeId, DelayType>,
19 pub singleton_barrier_crossers: Vec<(GraphNodeId, GraphNodeId)>,
21}
22impl BarrierCrossers {
23 fn iter_node_pairs<'a>(
25 &'a self,
26 partitioned_graph: &'a DfirGraph,
27 ) -> impl 'a + Iterator<Item = ((GraphNodeId, GraphNodeId), DelayType)> {
28 let edge_pairs_iter = self
29 .edge_barrier_crossers
30 .iter()
31 .map(|(edge_id, &delay_type)| {
32 let src_dst = partitioned_graph.edge(edge_id);
33 (src_dst, delay_type)
34 });
35 let singleton_pairs_iter = self
36 .singleton_barrier_crossers
37 .iter()
38 .map(|&src_dst| (src_dst, DelayType::Stratum));
39 edge_pairs_iter.chain(singleton_pairs_iter)
40 }
41
42 fn replace_edge(&mut self, old_edge_id: GraphEdgeId, new_edge_id: GraphEdgeId) {
44 if let Some(delay_type) = self.edge_barrier_crossers.remove(old_edge_id) {
45 self.edge_barrier_crossers.insert(new_edge_id, delay_type);
46 }
47 }
48}
49
50fn find_barrier_crossers(partitioned_graph: &DfirGraph) -> BarrierCrossers {
52 let edge_barrier_crossers = partitioned_graph
53 .edges()
54 .filter(|&(_, (_src, dst))| {
55 partitioned_graph.node_loop(dst).is_none()
57 })
58 .filter_map(|(edge_id, (_src, dst))| {
59 let (_src_port, dst_port) = partitioned_graph.edge_ports(edge_id);
60 let op_constraints = partitioned_graph.node_op_inst(dst)?.op_constraints;
61 let input_barrier = (op_constraints.input_delaytype_fn)(dst_port)?;
62 Some((edge_id, input_barrier))
63 })
64 .collect();
65 let singleton_barrier_crossers = partitioned_graph
66 .node_ids()
67 .flat_map(|dst| {
68 partitioned_graph
69 .node_singleton_references(dst)
70 .iter()
71 .flatten()
72 .map(move |&src_ref| (src_ref, dst))
73 })
74 .collect();
75 BarrierCrossers {
76 edge_barrier_crossers,
77 singleton_barrier_crossers,
78 }
79}
80
81fn find_subgraph_unionfind(
82 partitioned_graph: &DfirGraph,
83 barrier_crossers: &BarrierCrossers,
84) -> (UnionFind<GraphNodeId>, BTreeSet<GraphEdgeId>) {
85 let mut node_color = partitioned_graph
90 .node_ids()
91 .filter_map(|node_id| {
92 let op_color = partitioned_graph.node_color(node_id)?;
93 Some((node_id, op_color))
94 })
95 .collect::<SparseSecondaryMap<_, _>>();
96
97 let mut subgraph_unionfind: UnionFind<GraphNodeId> =
98 UnionFind::with_capacity(partitioned_graph.nodes().len());
99
100 let mut handoff_edges: BTreeSet<GraphEdgeId> = partitioned_graph.edge_ids().collect();
103 let mut progress = true;
112 while progress {
113 progress = false;
114 for (edge_id, (src, dst)) in partitioned_graph.edges().collect::<Vec<_>>() {
116 if subgraph_unionfind.same_set(src, dst) {
118 continue;
121 }
122
123 if barrier_crossers
125 .iter_node_pairs(partitioned_graph)
126 .any(|((x_src, x_dst), _)| {
127 (subgraph_unionfind.same_set(x_src, src)
128 && subgraph_unionfind.same_set(x_dst, dst))
129 || (subgraph_unionfind.same_set(x_src, dst)
130 && subgraph_unionfind.same_set(x_dst, src))
131 })
132 {
133 continue;
134 }
135
136 if partitioned_graph.node_loop(src) != partitioned_graph.node_loop(dst) {
138 continue;
139 }
140 if partitioned_graph.node_op_inst(dst).is_some_and(|op_inst| {
142 Some(FloType::NextIteration) == op_inst.op_constraints.flo_type
143 }) {
144 continue;
145 }
146
147 if can_connect_colorize(&mut node_color, src, dst) {
148 subgraph_unionfind.union(src, dst);
151 assert!(handoff_edges.remove(&edge_id));
152 progress = true;
153 }
154 }
155 }
156
157 (subgraph_unionfind, handoff_edges)
158}
159
160fn make_subgraph_collect(
164 partitioned_graph: &DfirGraph,
165 mut subgraph_unionfind: UnionFind<GraphNodeId>,
166) -> SecondaryMap<GraphNodeId, Vec<GraphNodeId>> {
167 let topo_sort = graph_algorithms::topo_sort(
171 partitioned_graph
172 .nodes()
173 .filter(|&(_, node)| !matches!(node, GraphNode::Handoff { .. }))
174 .map(|(node_id, _)| node_id),
175 |v| {
176 partitioned_graph
177 .node_predecessor_nodes(v)
178 .filter(|&pred_id| {
179 let pred = partitioned_graph.node(pred_id);
180 !matches!(pred, GraphNode::Handoff { .. })
181 })
182 },
183 )
184 .expect("Subgraphs are in-out trees.");
185
186 let mut grouped_nodes: SecondaryMap<GraphNodeId, Vec<GraphNodeId>> = Default::default();
187 for node_id in topo_sort {
188 let repr_node = subgraph_unionfind.find(node_id);
189 if !grouped_nodes.contains_key(repr_node) {
190 grouped_nodes.insert(repr_node, Default::default());
191 }
192 grouped_nodes[repr_node].push(node_id);
193 }
194 grouped_nodes
195}
196
197fn make_subgraphs(partitioned_graph: &mut DfirGraph, barrier_crossers: &mut BarrierCrossers) {
201 let (subgraph_unionfind, handoff_edges) =
210 find_subgraph_unionfind(partitioned_graph, barrier_crossers);
211
212 for edge_id in handoff_edges {
214 let (src_id, dst_id) = partitioned_graph.edge(edge_id);
215
216 let src_node = partitioned_graph.node(src_id);
218 let dst_node = partitioned_graph.node(dst_id);
219 if matches!(src_node, GraphNode::Handoff { .. })
220 || matches!(dst_node, GraphNode::Handoff { .. })
221 {
222 continue;
223 }
224
225 let hoff = GraphNode::Handoff {
226 src_span: src_node.span(),
227 dst_span: dst_node.span(),
228 };
229 let (_node_id, out_edge_id) = partitioned_graph.insert_intermediate_node(edge_id, hoff);
230
231 barrier_crossers.replace_edge(edge_id, out_edge_id);
233 }
234
235 let grouped_nodes = make_subgraph_collect(partitioned_graph, subgraph_unionfind);
239 for (_repr_node, member_nodes) in grouped_nodes {
240 partitioned_graph.insert_subgraph(member_nodes).unwrap();
241 }
242}
243
244fn can_connect_colorize(
250 node_color: &mut SparseSecondaryMap<GraphNodeId, Color>,
251 src: GraphNodeId,
252 dst: GraphNodeId,
253) -> bool {
254 let can_connect = match (node_color.get(src), node_color.get(dst)) {
259 (None, None) => false,
262
263 (None, Some(Color::Pull | Color::Comp)) => {
265 node_color.insert(src, Color::Pull);
266 true
267 }
268 (None, Some(Color::Push | Color::Hoff)) => {
269 node_color.insert(src, Color::Push);
270 true
271 }
272
273 (Some(Color::Pull | Color::Hoff), None) => {
275 node_color.insert(dst, Color::Pull);
276 true
277 }
278 (Some(Color::Comp | Color::Push), None) => {
279 node_color.insert(dst, Color::Push);
280 true
281 }
282
283 (Some(Color::Pull), Some(Color::Pull)) => true,
285 (Some(Color::Pull), Some(Color::Comp)) => true,
286 (Some(Color::Pull), Some(Color::Push)) => true,
287
288 (Some(Color::Comp), Some(Color::Pull)) => false,
289 (Some(Color::Comp), Some(Color::Comp)) => false,
290 (Some(Color::Comp), Some(Color::Push)) => true,
291
292 (Some(Color::Push), Some(Color::Pull)) => false,
293 (Some(Color::Push), Some(Color::Comp)) => false,
294 (Some(Color::Push), Some(Color::Push)) => true,
295
296 (Some(Color::Hoff), Some(_)) => false,
298 (Some(_), Some(Color::Hoff)) => false,
299 };
300 can_connect
301}
302
/// Assigns a stratum number to every subgraph and re-inserts tick delays.
///
/// 1. Builds a subgraph-level graph from handoffs and singleton references.
///    - Handoffs whose outgoing edge is a `Tick`/`TickLazy` barrier are left out.
///    - Links that are stratum barriers (`DelayType::Stratum` edges and all singleton
///      references) are remembered separately.
/// 2. Walks subgraphs in the order returned by `graph_algorithms::topo_sort_scc`. Each
///    subgraph's stratum is the max over its already-assigned predecessors, plus 1 when
///    entering or leaving a loop context. It is also plus 1 for a stratum barrier between
///    subgraphs that are both outside any loop.
/// 3. Walks the edge barrier crossers whose destination is outside any loop:
///    - `Tick`/`TickLazy`: when the edge does not go backwards in stratum, or is lazy, splices
///      in an `identity()` operator followed by a new handoff. The identity goes into its own
///      subgraph at `max_stratum + 1`, with laziness set from the delay type.
///    - `Stratum`: errors if the destination stratum is not strictly greater than the
///      source's.
///    - `MonotoneAccum`: no check.
///
/// # Errors
/// Returns a [`Diagnostic`] spanned at the destination port when a `DelayType::Stratum` edge
/// does not move forward in stratum (a negative cycle not broken by `defer_tick()`).
///
/// # Panics
/// Panics via the `assert!`/`unwrap` calls below if a handoff's predecessor/successor counts
/// are not 1, if an involved node has no subgraph, or if a singleton reference links a node
/// to itself or to its own subgraph.
fn find_subgraph_strata(
    partitioned_graph: &mut DfirGraph,
    barrier_crossers: &BarrierCrossers,
) -> Result<(), Diagnostic> {
    // Subgraph-level adjacency lists: each "node" here is a whole subgraph.
    #[derive(Default)]
    struct SubgraphGraph {
        preds: BTreeMap<GraphSubgraphId, Vec<GraphSubgraphId>>,
        succs: BTreeMap<GraphSubgraphId, Vec<GraphSubgraphId>>,
    }
    impl SubgraphGraph {
        fn insert_edge(&mut self, src: GraphSubgraphId, dst: GraphSubgraphId) {
            self.preds.entry(dst).or_default().push(src);
            self.succs.entry(src).or_default().push(dst);
        }
    }
    let mut subgraph_graph = SubgraphGraph::default();

    // Subgraph links that must step into a later stratum (tick edges are never recorded here).
    let mut subgraph_stratum_barriers: BTreeSet<(GraphSubgraphId, GraphSubgraphId)> =
        Default::default();

    // Every handoff links its predecessor's subgraph to its successor's subgraph.
    for (node_id, node) in partitioned_graph.nodes() {
        if matches!(node, GraphNode::Handoff { .. }) {
            assert_eq!(1, partitioned_graph.node_successors(node_id).count());
            let (succ_edge, succ) = partitioned_graph.node_successors(node_id).next().unwrap();

            let succ_edge_delaytype = barrier_crossers
                .edge_barrier_crossers
                .get(succ_edge)
                .copied();
            // Tick-delayed handoffs are excluded from the subgraph graph, so they neither
            // constrain the ordering nor raise strata in the assignment below.
            if let Some(DelayType::Tick | DelayType::TickLazy) = succ_edge_delaytype {
                continue;
            }

            assert_eq!(1, partitioned_graph.node_predecessors(node_id).count());
            let (_edge_id, pred) = partitioned_graph.node_predecessors(node_id).next().unwrap();

            let pred_sg = partitioned_graph.node_subgraph(pred).unwrap();
            let succ_sg = partitioned_graph.node_subgraph(succ).unwrap();

            subgraph_graph.insert_edge(pred_sg, succ_sg);

            if Some(DelayType::Stratum) == succ_edge_delaytype {
                subgraph_stratum_barriers.insert((pred_sg, succ_sg));
            }
        }
    }
    // Singleton references have no handoff; they always become stratum barriers between two
    // distinct subgraphs.
    for &(pred, succ) in barrier_crossers.singleton_barrier_crossers.iter() {
        assert_ne!(pred, succ, "TODO(mingwei)");
        let pred_sg = partitioned_graph.node_subgraph(pred).unwrap();
        let succ_sg = partitioned_graph.node_subgraph(succ).unwrap();
        assert_ne!(pred_sg, succ_sg);
        subgraph_graph.insert_edge(pred_sg, succ_sg);
        subgraph_stratum_barriers.insert((pred_sg, succ_sg));
    }

    // NOTE(review): presumably a topological order over strongly connected components of the
    // subgraph graph; see `graph_algorithms::topo_sort_scc` for the exact contract.
    let topo_sort_order = graph_algorithms::topo_sort_scc(
        || partitioned_graph.subgraph_ids(),
        |v| subgraph_graph.preds.get(&v).into_iter().flatten().cloned(),
        |u| subgraph_graph.succs.get(&u).into_iter().flatten().cloned(),
    );

    for sg_id in topo_sort_order {
        let curr_loop = partitioned_graph.subgraph_loop(sg_id);

        // Max over predecessors that already have a stratum (those returning `None` from
        // `subgraph_stratum` are skipped). With no such predecessor, the stratum is 0.
        let stratum = subgraph_graph
            .preds
            .get(&sg_id)
            .into_iter()
            .flatten()
            .filter_map(|&pred_sg_id| {
                partitioned_graph
                    .subgraph_stratum(pred_sg_id)
                    .map(|stratum| {
                        let pred_loop = partitioned_graph.subgraph_loop(pred_sg_id);
                        if curr_loop != pred_loop {
                            // Entering or leaving a loop context.
                            stratum + 1
                        } else if curr_loop.is_none()
                            && subgraph_stratum_barriers.contains(&(pred_sg_id, sg_id))
                        {
                            // Stratum barrier between two subgraphs that are both outside
                            // any loop.
                            stratum + 1
                        } else {
                            stratum
                        }
                    })
            })
            .max()
            .unwrap_or(0);
        partitioned_graph.set_subgraph_stratum(sg_id, stratum);
    }

    // One past the highest stratum assigned above; tick-delay subgraphs created below go here.
    let extra_stratum = partitioned_graph.max_stratum().unwrap_or(0) + 1;
    for (edge_id, &delay_type) in barrier_crossers.edge_barrier_crossers.iter() {
        // The source side of a barrier edge is treated as a handoff here (barrier edges were
        // re-keyed to handoff outgoing edges when handoffs were inserted).
        let (hoff, dst) = partitioned_graph.edge(edge_id);
        // Barriers whose destination is inside a loop are not handled here.
        if partitioned_graph.node_loop(dst).is_some() {
            continue;
        }
        let (_hoff_port, dst_port) = partitioned_graph.edge_ports(edge_id);

        assert_eq!(1, partitioned_graph.node_predecessors(hoff).count());
        let src = partitioned_graph
            .node_predecessor_nodes(hoff)
            .next()
            .unwrap();

        let src_sg = partitioned_graph.node_subgraph(src).unwrap();
        let dst_sg = partitioned_graph.node_subgraph(dst).unwrap();
        let src_stratum = partitioned_graph.subgraph_stratum(src_sg);
        let dst_stratum = partitioned_graph.subgraph_stratum(dst_sg);
        match delay_type {
            DelayType::Tick | DelayType::TickLazy => {
                let is_lazy = matches!(delay_type, DelayType::TickLazy);
                // A tick edge that does not go backwards in stratum would otherwise be
                // consumed in the same tick; a lazy one needs its own subgraph to carry the
                // laziness flag.
                if src_stratum <= dst_stratum || is_lazy {
                    // Before: src -> H -> dst. Splice in an `identity()` operator...
                    let (new_node_id, new_edge_id) = partitioned_graph.insert_intermediate_node(
                        edge_id,
                        GraphNode::Operator(parse_quote! { identity() }),
                    );
                    // ...then a new handoff on the identity's edge. This yields
                    // src -> H -> identity -> H' -> dst, assuming `new_edge_id` is the
                    // identity's outgoing edge.
                    let hoff = GraphNode::Handoff {
                        src_span: Span::call_site(),
                        dst_span: Span::call_site(),
                    };
                    let (_hoff_node_id, _hoff_edge_id) =
                        partitioned_graph.insert_intermediate_node(new_edge_id, hoff);
                    // The identity gets its own subgraph in the extra (last) stratum.
                    let new_subgraph_id = partitioned_graph
                        .insert_subgraph(vec![new_node_id])
                        .unwrap();

                    partitioned_graph.set_subgraph_stratum(new_subgraph_id, extra_stratum);

                    partitioned_graph.set_subgraph_laziness(new_subgraph_id, is_lazy);
                }
            }
            DelayType::Stratum => {
                // A negative edge that lands on the same or an earlier stratum means the
                // negation sits on an unbroken cycle.
                if dst_stratum <= src_stratum {
                    return Err(Diagnostic::spanned(
                        dst_port.span(),
                        Level::Error,
                        "Negative edge creates a negative cycle which must be broken with a `defer_tick()` operator.",
                    ));
                }
            }
            DelayType::MonotoneAccum => {
                // No stratum check is performed for monotone accumulation edges.
                continue;
            }
        }
    }
    Ok(())
}
499
500fn separate_external_inputs(partitioned_graph: &mut DfirGraph) {
503 let external_input_nodes: Vec<_> = partitioned_graph
504 .nodes()
505 .filter_map(|(node_id, node)| {
507 find_node_op_constraints(node).map(|op_constraints| (node_id, op_constraints))
508 })
509 .filter(|(_node_id, op_constraints)| op_constraints.is_external_input)
511 .map(|(node_id, _op_constraints)| node_id)
513 .filter(|&node_id| {
515 0 != partitioned_graph
516 .subgraph_stratum(partitioned_graph.node_subgraph(node_id).unwrap())
517 .unwrap()
518 })
519 .collect();
520
521 for node_id in external_input_nodes {
522 assert!(
524 partitioned_graph.remove_from_subgraph(node_id),
525 "Cannot move input node that is not in a subgraph, this is a bug."
526 );
527 let new_sg_id = partitioned_graph.insert_subgraph(vec![node_id]).unwrap();
529 partitioned_graph.set_subgraph_stratum(new_sg_id, 0);
530
531 for edge_id in partitioned_graph
533 .node_successor_edges(node_id)
534 .collect::<Vec<_>>()
535 {
536 let span = partitioned_graph.node(node_id).span();
537 let hoff = GraphNode::Handoff {
538 src_span: span,
539 dst_span: span,
540 };
541 partitioned_graph.insert_intermediate_node(edge_id, hoff);
542 }
543 }
544}
545
546pub fn partition_graph(flat_graph: DfirGraph) -> Result<DfirGraph, Diagnostic> {
550 let mut barrier_crossers = find_barrier_crossers(&flat_graph);
552 let mut partitioned_graph = flat_graph;
553
554 make_subgraphs(&mut partitioned_graph, &mut barrier_crossers);
556
557 find_subgraph_strata(&mut partitioned_graph, &barrier_crossers)?;
559
560 separate_external_inputs(&mut partitioned_graph);
562
563 Ok(partitioned_graph)
564}