1#![warn(missing_docs)]
2
3extern crate proc_macro;
4
5use std::collections::{BTreeMap, BTreeSet, VecDeque};
6use std::fmt::Debug;
7use std::iter::FusedIterator;
8
9use itertools::Itertools;
10use proc_macro2::{Ident, Literal, Span, TokenStream};
11use quote::{ToTokens, TokenStreamExt, format_ident, quote, quote_spanned};
12use serde::{Deserialize, Serialize};
13use slotmap::{Key, SecondaryMap, SlotMap, SparseSecondaryMap};
14use syn::spanned::Spanned;
15
16use super::graph_write::{Dot, GraphWrite, Mermaid};
17use super::ops::{
18 DelayType, OPERATORS, OperatorWriteOutput, WriteContextArgs, find_op_op_constraints,
19 null_write_iterator_fn,
20};
21use super::{
22 CONTEXT, Color, DiMulGraph, GRAPH, GraphEdgeId, GraphLoopId, GraphNode, GraphNodeId,
23 GraphSubgraphId, HANDOFF_NODE_STR, MODULE_BOUNDARY_NODE_STR, OperatorInstance, PortIndexValue,
24 Varname, change_spans, get_operator_generics,
25};
26use crate::diagnostic::{Diagnostic, Level};
27use crate::pretty_span::{PrettyRowCol, PrettySpan};
28use crate::process_singletons;
29
/// A dataflow graph made of operator and handoff nodes, plus the data derived from it.
///
/// Nodes are stored in `nodes`. Edge endpoints are stored in `graph`, and each edge's
/// `(source port, destination port)` pair is stored in `ports`. The remaining maps attach
/// optional information to nodes, loops, and subgraphs:
/// - loop nesting
/// - subgraph partitioning and strata
/// - singleton references
/// - variable names
/// - laziness
#[derive(Default, Debug, Serialize, Deserialize)]
pub struct DfirGraph {
    /// Every node in the graph, keyed by its id.
    nodes: SlotMap<GraphNodeId, GraphNode>,

    /// Operator instance for each operator node. It is not serialized.
    #[serde(skip)]
    operator_instances: SecondaryMap<GraphNodeId, OperatorInstance>,
    /// Optional per-node tag. When present, code generation uses it instead of a
    /// source-location-derived tag when naming the operator's work functions.
    operator_tag: SecondaryMap<GraphNodeId, String>,
    /// Edge topology: each edge id maps to its `(source node, destination node)`.
    graph: DiMulGraph<GraphNodeId, GraphEdgeId>,
    /// For each edge, `(port on the source node, port on the destination node)`.
    ports: SecondaryMap<GraphEdgeId, (PortIndexValue, PortIndexValue)>,

    /// The loop that each node was inserted into, if any.
    node_loops: SecondaryMap<GraphNodeId, GraphLoopId>,
    /// The nodes inserted into each loop.
    loop_nodes: SlotMap<GraphLoopId, Vec<GraphNodeId>>,
    /// Parent loop of each nested loop.
    /// NOTE(review): presumably absent for top-level loops; confirm.
    loop_parent: SparseSecondaryMap<GraphLoopId, GraphLoopId>,
    /// Loops from which loop code generation starts its breadth-first walk.
    /// NOTE(review): presumably the loops that have no parent; confirm.
    root_loops: Vec<GraphLoopId>,
    /// Child loops of each loop.
    loop_children: SecondaryMap<GraphLoopId, Vec<GraphLoopId>>,

    /// The subgraph that each node has been assigned to.
    node_subgraph: SecondaryMap<GraphNodeId, GraphSubgraphId>,

    /// The nodes in each subgraph. This is the reverse of `node_subgraph`.
    subgraph_nodes: SlotMap<GraphSubgraphId, Vec<GraphNodeId>>,
    /// Stratum of each subgraph. Code generation treats a missing entry as stratum 0.
    subgraph_stratum: SecondaryMap<GraphSubgraphId, usize>,

    /// For each node, the node that each of its singleton references resolved to
    /// (`None` while a reference is unresolved).
    node_singleton_references: SparseSecondaryMap<GraphNodeId, Vec<Option<GraphNodeId>>>,
    /// The variable name that a node was bound to when it was inserted, if any.
    node_varnames: SparseSecondaryMap<GraphNodeId, Varname>,

    /// Whether each subgraph is lazy. A missing entry is read as `false`.
    subgraph_laziness: SecondaryMap<GraphSubgraphId, bool>,
}
84
85impl DfirGraph {
87 pub fn new() -> Self {
89 Default::default()
90 }
91}
92
93impl DfirGraph {
95 pub fn node(&self, node_id: GraphNodeId) -> &GraphNode {
97 self.nodes.get(node_id).expect("Node not found.")
98 }
99
100 pub fn node_op_inst(&self, node_id: GraphNodeId) -> Option<&OperatorInstance> {
105 self.operator_instances.get(node_id)
106 }
107
108 pub fn node_varname(&self, node_id: GraphNodeId) -> Option<&Varname> {
110 self.node_varnames.get(node_id)
111 }
112
113 pub fn node_subgraph(&self, node_id: GraphNodeId) -> Option<GraphSubgraphId> {
115 self.node_subgraph.get(node_id).copied()
116 }
117
118 pub fn node_degree_in(&self, node_id: GraphNodeId) -> usize {
120 self.graph.degree_in(node_id)
121 }
122
123 pub fn node_degree_out(&self, node_id: GraphNodeId) -> usize {
125 self.graph.degree_out(node_id)
126 }
127
128 pub fn node_successors(
130 &self,
131 src: GraphNodeId,
132 ) -> impl '_
133 + DoubleEndedIterator<Item = (GraphEdgeId, GraphNodeId)>
134 + ExactSizeIterator
135 + FusedIterator
136 + Clone
137 + Debug {
138 self.graph.successors(src)
139 }
140
141 pub fn node_predecessors(
143 &self,
144 dst: GraphNodeId,
145 ) -> impl '_
146 + DoubleEndedIterator<Item = (GraphEdgeId, GraphNodeId)>
147 + ExactSizeIterator
148 + FusedIterator
149 + Clone
150 + Debug {
151 self.graph.predecessors(dst)
152 }
153
154 pub fn node_successor_edges(
156 &self,
157 src: GraphNodeId,
158 ) -> impl '_
159 + DoubleEndedIterator<Item = GraphEdgeId>
160 + ExactSizeIterator
161 + FusedIterator
162 + Clone
163 + Debug {
164 self.graph.successor_edges(src)
165 }
166
167 pub fn node_predecessor_edges(
169 &self,
170 dst: GraphNodeId,
171 ) -> impl '_
172 + DoubleEndedIterator<Item = GraphEdgeId>
173 + ExactSizeIterator
174 + FusedIterator
175 + Clone
176 + Debug {
177 self.graph.predecessor_edges(dst)
178 }
179
180 pub fn node_successor_nodes(
182 &self,
183 src: GraphNodeId,
184 ) -> impl '_
185 + DoubleEndedIterator<Item = GraphNodeId>
186 + ExactSizeIterator
187 + FusedIterator
188 + Clone
189 + Debug {
190 self.graph.successor_vertices(src)
191 }
192
193 pub fn node_predecessor_nodes(
195 &self,
196 dst: GraphNodeId,
197 ) -> impl '_
198 + DoubleEndedIterator<Item = GraphNodeId>
199 + ExactSizeIterator
200 + FusedIterator
201 + Clone
202 + Debug {
203 self.graph.predecessor_vertices(dst)
204 }
205
206 pub fn node_ids(&self) -> slotmap::basic::Keys<'_, GraphNodeId, GraphNode> {
208 self.nodes.keys()
209 }
210
211 pub fn nodes(&self) -> slotmap::basic::Iter<'_, GraphNodeId, GraphNode> {
213 self.nodes.iter()
214 }
215
216 pub fn insert_node(
218 &mut self,
219 node: GraphNode,
220 varname_opt: Option<Ident>,
221 loop_opt: Option<GraphLoopId>,
222 ) -> GraphNodeId {
223 let node_id = self.nodes.insert(node);
224 if let Some(varname) = varname_opt {
225 self.node_varnames.insert(node_id, Varname(varname));
226 }
227 if let Some(loop_id) = loop_opt {
228 self.node_loops.insert(node_id, loop_id);
229 self.loop_nodes[loop_id].push(node_id);
230 }
231 node_id
232 }
233
    /// Attaches `op_inst` to the operator node `node_id`.
    ///
    /// # Panics
    /// Panics if `node_id` is not an operator node. Also panics if the node already had an
    /// operator instance; the new instance has already replaced the old one when this
    /// assertion fires.
    pub fn insert_node_op_inst(&mut self, node_id: GraphNodeId, op_inst: OperatorInstance) {
        assert!(matches!(
            self.nodes.get(node_id),
            Some(GraphNode::Operator(_))
        ));
        let old_inst = self.operator_instances.insert(node_id, op_inst);
        // An operator node may only be given an instance once.
        assert!(old_inst.is_none());
    }
243
244 pub fn insert_node_op_insts_all(&mut self, diagnostics: &mut Vec<Diagnostic>) {
246 let mut op_insts = Vec::new();
247 for (node_id, node) in self.nodes() {
248 let GraphNode::Operator(operator) = node else {
249 continue;
250 };
251 if self.node_op_inst(node_id).is_some() {
252 continue;
253 };
254
255 let Some(op_constraints) = find_op_op_constraints(operator) else {
257 diagnostics.push(Diagnostic::spanned(
258 operator.path.span(),
259 Level::Error,
260 format!("Unknown operator `{}`", operator.name_string()),
261 ));
262 continue;
263 };
264
265 let (input_ports, output_ports) = {
267 let mut input_edges: Vec<(&PortIndexValue, GraphNodeId)> = self
268 .node_predecessors(node_id)
269 .map(|(edge_id, pred_id)| (self.edge_ports(edge_id).1, pred_id))
270 .collect();
271 input_edges.sort();
273 let input_ports: Vec<PortIndexValue> = input_edges
274 .into_iter()
275 .map(|(port, _pred)| port)
276 .cloned()
277 .collect();
278
279 let mut output_edges: Vec<(&PortIndexValue, GraphNodeId)> = self
281 .node_successors(node_id)
282 .map(|(edge_id, succ)| (self.edge_ports(edge_id).0, succ))
283 .collect();
284 output_edges.sort();
286 let output_ports: Vec<PortIndexValue> = output_edges
287 .into_iter()
288 .map(|(port, _succ)| port)
289 .cloned()
290 .collect();
291
292 (input_ports, output_ports)
293 };
294
295 let generics = get_operator_generics(diagnostics, operator);
297 {
299 let generics_span = generics
301 .generic_args
302 .as_ref()
303 .map(Spanned::span)
304 .unwrap_or_else(|| operator.path.span());
305
306 if !op_constraints
307 .persistence_args
308 .contains(&generics.persistence_args.len())
309 {
310 diagnostics.push(Diagnostic::spanned(
311 generics_span,
312 Level::Error,
313 format!(
314 "`{}` should have {} persistence lifetime arguments, actually has {}.",
315 op_constraints.name,
316 op_constraints.persistence_args.human_string(),
317 generics.persistence_args.len()
318 ),
319 ));
320 }
321 if !op_constraints.type_args.contains(&generics.type_args.len()) {
322 diagnostics.push(Diagnostic::spanned(
323 generics_span,
324 Level::Error,
325 format!(
326 "`{}` should have {} generic type arguments, actually has {}.",
327 op_constraints.name,
328 op_constraints.type_args.human_string(),
329 generics.type_args.len()
330 ),
331 ));
332 }
333 }
334
335 op_insts.push((
336 node_id,
337 OperatorInstance {
338 op_constraints,
339 input_ports,
340 output_ports,
341 singletons_referenced: operator.singletons_referenced.clone(),
342 generics,
343 arguments_pre: operator.args.clone(),
344 arguments_raw: operator.args_raw.clone(),
345 },
346 ));
347 }
348
349 for (node_id, op_inst) in op_insts {
350 self.insert_node_op_inst(node_id, op_inst);
351 }
352 }
353
354 pub fn insert_intermediate_node(
366 &mut self,
367 edge_id: GraphEdgeId,
368 new_node: GraphNode,
369 ) -> (GraphNodeId, GraphEdgeId) {
370 let span = Some(new_node.span());
371
372 let op_inst_opt = 'oc: {
374 let GraphNode::Operator(operator) = &new_node else {
375 break 'oc None;
376 };
377 let Some(op_constraints) = find_op_op_constraints(operator) else {
378 break 'oc None;
379 };
380 let (input_port, output_port) = self.ports.get(edge_id).cloned().unwrap();
381 let generics = get_operator_generics(
382 &mut Vec::new(), operator,
384 );
385 Some(OperatorInstance {
386 op_constraints,
387 input_ports: vec![input_port],
388 output_ports: vec![output_port],
389 singletons_referenced: operator.singletons_referenced.clone(),
390 generics,
391 arguments_pre: operator.args.clone(),
392 arguments_raw: operator.args_raw.clone(),
393 })
394 };
395
396 let node_id = self.nodes.insert(new_node);
398 if let Some(op_inst) = op_inst_opt {
400 self.operator_instances.insert(node_id, op_inst);
401 }
402 let (e0, e1) = self
404 .graph
405 .insert_intermediate_vertex(node_id, edge_id)
406 .unwrap();
407
408 let (src_idx, dst_idx) = self.ports.remove(edge_id).unwrap();
410 self.ports
411 .insert(e0, (src_idx, PortIndexValue::Elided(span)));
412 self.ports
413 .insert(e1, (PortIndexValue::Elided(span), dst_idx));
414
415 (node_id, e1)
416 }
417
418 pub fn remove_intermediate_node(&mut self, node_id: GraphNodeId) {
421 assert_eq!(
422 1,
423 self.node_degree_in(node_id),
424 "Removed intermediate node must have one predecessor"
425 );
426 assert_eq!(
427 1,
428 self.node_degree_out(node_id),
429 "Removed intermediate node must have one successor"
430 );
431 assert!(
432 self.node_subgraph.is_empty() && self.subgraph_nodes.is_empty(),
433 "Should not remove intermediate node after subgraph partitioning"
434 );
435
436 assert!(self.nodes.remove(node_id).is_some());
437 let (new_edge_id, (pred_edge_id, succ_edge_id)) =
438 self.graph.remove_intermediate_vertex(node_id).unwrap();
439 self.operator_instances.remove(node_id);
440 self.node_varnames.remove(node_id);
441
442 let (src_port, _) = self.ports.remove(pred_edge_id).unwrap();
443 let (_, dst_port) = self.ports.remove(succ_edge_id).unwrap();
444 self.ports.insert(new_edge_id, (src_port, dst_port));
445 }
446
447 pub(crate) fn node_color(&self, node_id: GraphNodeId) -> Option<Color> {
453 if matches!(self.node(node_id), GraphNode::Handoff { .. }) {
454 return Some(Color::Hoff);
455 }
456
457 if let GraphNode::Operator(op) = self.node(node_id)
459 && (op.name_string() == "resolve_futures_blocking"
460 || op.name_string() == "resolve_futures_blocking_ordered")
461 {
462 return Some(Color::Push);
463 }
464
465 let inn_degree = self.node_predecessor_nodes(node_id).count();
467 let out_degree = self.node_successor_nodes(node_id).count();
469
470 match (inn_degree, out_degree) {
471 (0, 0) => None, (0, 1) => Some(Color::Pull),
473 (1, 0) => Some(Color::Push),
474 (1, 1) => None, (_many, 0 | 1) => Some(Color::Pull),
476 (0 | 1, _many) => Some(Color::Push),
477 (_many, _to_many) => Some(Color::Comp),
478 }
479 }
480
481 pub fn set_operator_tag(&mut self, node_id: GraphNodeId, tag: String) {
483 self.operator_tag.insert(node_id, tag.to_owned());
484 }
485}
486
487impl DfirGraph {
489 pub fn set_node_singleton_references(
492 &mut self,
493 node_id: GraphNodeId,
494 singletons_referenced: Vec<Option<GraphNodeId>>,
495 ) -> Option<Vec<Option<GraphNodeId>>> {
496 self.node_singleton_references
497 .insert(node_id, singletons_referenced)
498 }
499
500 pub fn node_singleton_references(&self, node_id: GraphNodeId) -> &[Option<GraphNodeId>] {
503 self.node_singleton_references
504 .get(node_id)
505 .map(std::ops::Deref::deref)
506 .unwrap_or_default()
507 }
508}
509
510impl DfirGraph {
512 pub fn merge_modules(&mut self) -> Result<(), Diagnostic> {
520 let mod_bound_nodes = self
521 .nodes()
522 .filter(|(_nid, node)| matches!(node, GraphNode::ModuleBoundary { .. }))
523 .map(|(nid, _node)| nid)
524 .collect::<Vec<_>>();
525
526 for mod_bound_node in mod_bound_nodes {
527 self.remove_module_boundary(mod_bound_node)?;
528 }
529
530 Ok(())
531 }
532
533 fn remove_module_boundary(&mut self, mod_bound_node: GraphNodeId) -> Result<(), Diagnostic> {
537 assert!(
538 self.node_subgraph.is_empty() && self.subgraph_nodes.is_empty(),
539 "Should not remove intermediate node after subgraph partitioning"
540 );
541
542 let mut mod_pred_ports = BTreeMap::new();
543 let mut mod_succ_ports = BTreeMap::new();
544
545 for mod_out_edge in self.node_predecessor_edges(mod_bound_node) {
546 let (pred_port, succ_port) = self.edge_ports(mod_out_edge);
547 mod_pred_ports.insert(succ_port.clone(), (mod_out_edge, pred_port.clone()));
548 }
549
550 for mod_inn_edge in self.node_successor_edges(mod_bound_node) {
551 let (pred_port, succ_port) = self.edge_ports(mod_inn_edge);
552 mod_succ_ports.insert(pred_port.clone(), (mod_inn_edge, succ_port.clone()));
553 }
554
555 if mod_pred_ports.keys().collect::<BTreeSet<_>>()
556 != mod_succ_ports.keys().collect::<BTreeSet<_>>()
557 {
558 let GraphNode::ModuleBoundary { input, import_expr } = self.node(mod_bound_node) else {
560 panic!();
561 };
562
563 if *input {
564 return Err(Diagnostic {
565 span: *import_expr,
566 level: Level::Error,
567 message: format!(
568 "The ports into the module did not match. input: {:?}, expected: {:?}",
569 mod_pred_ports.keys().map(|x| x.to_string()).join(", "),
570 mod_succ_ports.keys().map(|x| x.to_string()).join(", ")
571 ),
572 });
573 } else {
574 return Err(Diagnostic {
575 span: *import_expr,
576 level: Level::Error,
577 message: format!(
578 "The ports out of the module did not match. output: {:?}, expected: {:?}",
579 mod_succ_ports.keys().map(|x| x.to_string()).join(", "),
580 mod_pred_ports.keys().map(|x| x.to_string()).join(", "),
581 ),
582 });
583 }
584 }
585
586 for (port, (pred_edge, pred_port)) in mod_pred_ports {
587 let (succ_edge, succ_port) = mod_succ_ports.remove(&port).unwrap();
588
589 let (src, _) = self.edge(pred_edge);
590 let (_, dst) = self.edge(succ_edge);
591 self.remove_edge(pred_edge);
592 self.remove_edge(succ_edge);
593
594 let new_edge_id = self.graph.insert_edge(src, dst);
595 self.ports.insert(new_edge_id, (pred_port, succ_port));
596 }
597
598 self.graph.remove_vertex(mod_bound_node);
599 self.nodes.remove(mod_bound_node);
600
601 Ok(())
602 }
603}
604
605impl DfirGraph {
607 pub fn edge(&self, edge_id: GraphEdgeId) -> (GraphNodeId, GraphNodeId) {
609 let (src, dst) = self.graph.edge(edge_id).expect("Edge not found.");
610 (src, dst)
611 }
612
613 pub fn edge_ports(&self, edge_id: GraphEdgeId) -> (&PortIndexValue, &PortIndexValue) {
615 let (src_port, dst_port) = self.ports.get(edge_id).expect("Edge not found.");
616 (src_port, dst_port)
617 }
618
619 pub fn edge_ids(&self) -> slotmap::basic::Keys<'_, GraphEdgeId, (GraphNodeId, GraphNodeId)> {
621 self.graph.edge_ids()
622 }
623
624 pub fn edges(
626 &self,
627 ) -> impl '_
628 + ExactSizeIterator<Item = (GraphEdgeId, (GraphNodeId, GraphNodeId))>
629 + FusedIterator
630 + Clone
631 + Debug {
632 self.graph.edges()
633 }
634
635 pub fn insert_edge(
637 &mut self,
638 src: GraphNodeId,
639 src_port: PortIndexValue,
640 dst: GraphNodeId,
641 dst_port: PortIndexValue,
642 ) -> GraphEdgeId {
643 let edge_id = self.graph.insert_edge(src, dst);
644 self.ports.insert(edge_id, (src_port, dst_port));
645 edge_id
646 }
647
648 pub fn remove_edge(&mut self, edge: GraphEdgeId) {
650 let (_src, _dst) = self.graph.remove_edge(edge).unwrap();
651 let (_src_port, _dst_port) = self.ports.remove(edge).unwrap();
652 }
653}
654
655impl DfirGraph {
657 pub fn subgraph(&self, subgraph_id: GraphSubgraphId) -> &Vec<GraphNodeId> {
659 self.subgraph_nodes
660 .get(subgraph_id)
661 .expect("Subgraph not found.")
662 }
663
664 pub fn subgraph_ids(&self) -> slotmap::basic::Keys<'_, GraphSubgraphId, Vec<GraphNodeId>> {
666 self.subgraph_nodes.keys()
667 }
668
669 pub fn subgraphs(&self) -> slotmap::basic::Iter<'_, GraphSubgraphId, Vec<GraphNodeId>> {
671 self.subgraph_nodes.iter()
672 }
673
674 pub fn insert_subgraph(
676 &mut self,
677 node_ids: Vec<GraphNodeId>,
678 ) -> Result<GraphSubgraphId, (GraphNodeId, GraphSubgraphId)> {
679 for &node_id in node_ids.iter() {
681 if let Some(&old_sg_id) = self.node_subgraph.get(node_id) {
682 return Err((node_id, old_sg_id));
683 }
684 }
685 let subgraph_id = self.subgraph_nodes.insert_with_key(|sg_id| {
686 for &node_id in node_ids.iter() {
687 self.node_subgraph.insert(node_id, sg_id);
688 }
689 node_ids
690 });
691
692 Ok(subgraph_id)
693 }
694
695 pub fn remove_from_subgraph(&mut self, node_id: GraphNodeId) -> bool {
697 if let Some(old_sg_id) = self.node_subgraph.remove(node_id) {
698 self.subgraph_nodes[old_sg_id].retain(|&other_node_id| other_node_id != node_id);
699 true
700 } else {
701 false
702 }
703 }
704
705 pub fn subgraph_stratum(&self, sg_id: GraphSubgraphId) -> Option<usize> {
707 self.subgraph_stratum.get(sg_id).copied()
708 }
709
710 pub fn set_subgraph_stratum(
712 &mut self,
713 sg_id: GraphSubgraphId,
714 stratum: usize,
715 ) -> Option<usize> {
716 self.subgraph_stratum.insert(sg_id, stratum)
717 }
718
719 fn subgraph_laziness(&self, sg_id: GraphSubgraphId) -> bool {
721 self.subgraph_laziness.get(sg_id).copied().unwrap_or(false)
722 }
723
724 pub fn set_subgraph_laziness(&mut self, sg_id: GraphSubgraphId, lazy: bool) -> bool {
726 self.subgraph_laziness.insert(sg_id, lazy).unwrap_or(false)
727 }
728
729 pub fn max_stratum(&self) -> Option<usize> {
731 self.subgraph_stratum.values().copied().max()
732 }
733
734 fn find_pull_to_push_idx(&self, subgraph_nodes: &[GraphNodeId]) -> usize {
736 subgraph_nodes
737 .iter()
738 .position(|&node_id| {
739 self.node_color(node_id)
740 .is_some_and(|color| Color::Pull != color)
741 })
742 .unwrap_or(subgraph_nodes.len())
743 }
744}
745
746impl DfirGraph {
748 fn node_as_ident(&self, node_id: GraphNodeId, is_pred: bool) -> Ident {
750 let name = match &self.nodes[node_id] {
751 GraphNode::Operator(_) => format!("op_{:?}", node_id.data()),
752 GraphNode::Handoff { .. } => format!(
753 "hoff_{:?}_{}",
754 node_id.data(),
755 if is_pred { "recv" } else { "send" }
756 ),
757 GraphNode::ModuleBoundary { .. } => panic!(),
758 };
759 let span = match (is_pred, &self.nodes[node_id]) {
760 (_, GraphNode::Operator(operator)) => operator.span(),
761 (true, &GraphNode::Handoff { src_span, .. }) => src_span,
762 (false, &GraphNode::Handoff { dst_span, .. }) => dst_span,
763 (_, GraphNode::ModuleBoundary { .. }) => panic!(),
764 };
765 Ident::new(&name, span)
766 }
767
768 fn node_as_singleton_ident(&self, node_id: GraphNodeId, span: Span) -> Ident {
770 Ident::new(&format!("singleton_op_{:?}", node_id.data()), span)
771 }
772
773 fn helper_resolve_singletons(&self, node_id: GraphNodeId, span: Span) -> Vec<Ident> {
775 self.node_singleton_references(node_id)
776 .iter()
777 .map(|singleton_node_id| {
778 self.node_as_singleton_ident(
780 singleton_node_id
781 .expect("Expected singleton to be resolved but was not, this is a bug."),
782 span,
783 )
784 })
785 .collect::<Vec<_>>()
786 }
787
788 fn helper_collect_subgraph_handoffs(
791 &self,
792 ) -> SecondaryMap<GraphSubgraphId, (Vec<GraphNodeId>, Vec<GraphNodeId>)> {
793 let mut subgraph_handoffs: SecondaryMap<
795 GraphSubgraphId,
796 (Vec<GraphNodeId>, Vec<GraphNodeId>),
797 > = self
798 .subgraph_nodes
799 .keys()
800 .map(|k| (k, Default::default()))
801 .collect();
802
803 for (hoff_id, node) in self.nodes() {
805 if !matches!(node, GraphNode::Handoff { .. }) {
806 continue;
807 }
808 for (_edge, succ_id) in self.node_successors(hoff_id) {
810 let succ_sg = self.node_subgraph(succ_id).unwrap();
811 subgraph_handoffs[succ_sg].0.push(hoff_id);
812 }
813 for (_edge, pred_id) in self.node_predecessors(hoff_id) {
815 let pred_sg = self.node_subgraph(pred_id).unwrap();
816 subgraph_handoffs[pred_sg].1.push(hoff_id);
817 }
818 }
819
820 subgraph_handoffs
821 }
822
823 fn codegen_nested_loops(&self, df: &Ident) -> TokenStream {
825 let mut out = TokenStream::new();
827 let mut queue = VecDeque::from_iter(self.root_loops.iter().copied());
828 while let Some(loop_id) = queue.pop_front() {
829 let parent_opt = self
830 .loop_parent(loop_id)
831 .map(|loop_id| loop_id.as_ident(Span::call_site()))
832 .map(|ident| quote! { Some(#ident) })
833 .unwrap_or_else(|| quote! { None });
834 let loop_name = loop_id.as_ident(Span::call_site());
835 out.append_all(quote! {
836 let #loop_name = #df.add_loop(#parent_opt);
837 });
838 queue.extend(self.loop_children.get(loop_id).into_iter().flatten());
839 }
840 out
841 }
842
843 pub fn as_code(
845 &self,
846 root: &TokenStream,
847 include_type_guards: bool,
848 prefix: TokenStream,
849 diagnostics: &mut Vec<Diagnostic>,
850 ) -> TokenStream {
851 let df = Ident::new(GRAPH, Span::call_site());
852 let context = Ident::new(CONTEXT, Span::call_site());
853
854 let handoff_code = self
856 .nodes
857 .iter()
858 .filter_map(|(node_id, node)| match node {
859 GraphNode::Operator(_) => None,
860 &GraphNode::Handoff { src_span, dst_span } => Some((node_id, (src_span, dst_span))),
861 GraphNode::ModuleBoundary { .. } => panic!(),
862 })
863 .map(|(node_id, (src_span, dst_span))| {
864 let ident_send = Ident::new(&format!("hoff_{:?}_send", node_id.data()), dst_span);
865 let ident_recv = Ident::new(&format!("hoff_{:?}_recv", node_id.data()), src_span);
866 let span = src_span.join(dst_span).unwrap_or(src_span);
867 let mut hoff_name = Literal::string(&format!("handoff {:?}", node_id));
868 hoff_name.set_span(span);
869 let hoff_type = quote_spanned! (span=> #root::scheduled::handoff::VecHandoff<_>);
870 quote_spanned! {span=>
871 let (#ident_send, #ident_recv) =
872 #df.make_edge::<_, #hoff_type>(#hoff_name);
873 }
874 });
875
876 let subgraph_handoffs = self.helper_collect_subgraph_handoffs();
877
878 let (subgraphs_without_preds, subgraphs_with_preds) = self
880 .subgraph_nodes
881 .iter()
882 .partition::<Vec<_>, _>(|(_, nodes)| {
883 nodes
884 .iter()
885 .any(|&node_id| self.node_degree_in(node_id) == 0)
886 });
887
888 let mut op_prologue_code = Vec::new();
889 let mut op_prologue_after_code = Vec::new();
890 let mut subgraphs = Vec::new();
891 {
892 for &(subgraph_id, subgraph_nodes) in subgraphs_without_preds
893 .iter()
894 .chain(subgraphs_with_preds.iter())
895 {
896 let (recv_hoffs, send_hoffs) = &subgraph_handoffs[subgraph_id];
897 let recv_ports: Vec<Ident> = recv_hoffs
898 .iter()
899 .map(|&hoff_id| self.node_as_ident(hoff_id, true))
900 .collect();
901 let send_ports: Vec<Ident> = send_hoffs
902 .iter()
903 .map(|&hoff_id| self.node_as_ident(hoff_id, false))
904 .collect();
905
906 let recv_port_code = recv_ports.iter().map(|ident| {
907 quote_spanned! {ident.span()=>
908 let mut #ident = #ident.borrow_mut_swap();
909 let #ident = #root::futures::stream::iter(#ident.drain(..));
910 }
911 });
912 let send_port_code = send_ports.iter().map(|ident| {
913 quote_spanned! {ident.span()=>
914 let #ident = #root::sinktools::for_each(|v| {
915 #ident.give(Some(v));
916 });
917 }
918 });
919
920 let loop_id = self
921 .node_loop(subgraph_nodes[0]);
923
924 let mut subgraph_op_iter_code = Vec::new();
925 let mut subgraph_op_iter_after_code = Vec::new();
926 {
927 let pull_to_push_idx = self.find_pull_to_push_idx(subgraph_nodes);
928
929 let (pull_half, push_half) = subgraph_nodes.split_at(pull_to_push_idx);
930 let nodes_iter = pull_half.iter().chain(push_half.iter().rev());
931
932 for (idx, &node_id) in nodes_iter.enumerate() {
933 let node = &self.nodes[node_id];
934 assert!(
935 matches!(node, GraphNode::Operator(_)),
936 "Handoffs are not part of subgraphs."
937 );
938 let op_inst = &self.operator_instances[node_id];
939
940 let op_span = node.span();
941 let op_name = op_inst.op_constraints.name;
942 let root = change_spans(root.clone(), op_span);
944 let op_constraints = OPERATORS
946 .iter()
947 .find(|op| op_name == op.name)
948 .unwrap_or_else(|| panic!("Failed to find op: {}", op_name));
949
950 let ident = self.node_as_ident(node_id, false);
951
952 {
953 let mut input_edges = self
956 .graph
957 .predecessor_edges(node_id)
958 .map(|edge_id| (self.edge_ports(edge_id).1, edge_id))
959 .collect::<Vec<_>>();
960 input_edges.sort();
962
963 let inputs = input_edges
964 .iter()
965 .map(|&(_port, edge_id)| {
966 let (pred, _) = self.edge(edge_id);
967 self.node_as_ident(pred, true)
968 })
969 .collect::<Vec<_>>();
970
971 let mut output_edges = self
973 .graph
974 .successor_edges(node_id)
975 .map(|edge_id| (&self.ports[edge_id].0, edge_id))
976 .collect::<Vec<_>>();
977 output_edges.sort();
979
980 let outputs = output_edges
981 .iter()
982 .map(|&(_port, edge_id)| {
983 let (_, succ) = self.edge(edge_id);
984 self.node_as_ident(succ, false)
985 })
986 .collect::<Vec<_>>();
987
988 let is_pull = idx < pull_to_push_idx;
989
990 let singleton_output_ident = &if op_constraints.has_singleton_output {
991 self.node_as_singleton_ident(node_id, op_span)
992 } else {
993 Ident::new(&format!("{}_has_no_singleton_output", op_name), op_span)
995 };
996
997 let df_local = &Ident::new(GRAPH, op_span.resolved_at(df.span()));
1006 let context = &Ident::new(CONTEXT, op_span.resolved_at(context.span()));
1007
1008 let singletons_resolved =
1009 self.helper_resolve_singletons(node_id, op_span);
1010 let arguments = &process_singletons::postprocess_singletons(
1011 op_inst.arguments_raw.clone(),
1012 singletons_resolved.clone(),
1013 context,
1014 );
1015 let arguments_handles =
1016 &process_singletons::postprocess_singletons_handles(
1017 op_inst.arguments_raw.clone(),
1018 singletons_resolved.clone(),
1019 );
1020
1021 let source_tag = 'a: {
1022 if let Some(tag) = self.operator_tag.get(node_id).cloned() {
1023 break 'a tag;
1024 }
1025
1026 #[cfg(nightly)]
1027 if proc_macro::is_available() {
1028 let op_span = op_span.unwrap();
1029 break 'a format!(
1030 "loc_{}_{}_{}_{}_{}",
1031 crate::pretty_span::make_source_path_relative(
1032 &op_span.file()
1033 )
1034 .display()
1035 .to_string()
1036 .replace(|x: char| !x.is_ascii_alphanumeric(), "_"),
1037 op_span.start().line(),
1038 op_span.start().column(),
1039 op_span.end().line(),
1040 op_span.end().column(),
1041 );
1042 }
1043
1044 format!(
1045 "loc_nopath_{}_{}_{}_{}",
1046 op_span.start().line,
1047 op_span.start().column,
1048 op_span.end().line,
1049 op_span.end().column
1050 )
1051 };
1052
1053 let work_fn = format_ident!(
1054 "{}__{}__{}",
1055 ident,
1056 op_name,
1057 source_tag,
1058 span = op_span
1059 );
1060 let work_fn_async = format_ident!("{}__async", work_fn, span = op_span);
1061
1062 let context_args = WriteContextArgs {
1063 root: &root,
1064 df_ident: df_local,
1065 context,
1066 subgraph_id,
1067 node_id,
1068 loop_id,
1069 op_span,
1070 op_tag: self.operator_tag.get(node_id).cloned(),
1071 work_fn: &work_fn,
1072 work_fn_async: &work_fn_async,
1073 ident: &ident,
1074 is_pull,
1075 inputs: &inputs,
1076 outputs: &outputs,
1077 singleton_output_ident,
1078 op_name,
1079 op_inst,
1080 arguments,
1081 arguments_handles,
1082 };
1083
1084 let write_result =
1085 (op_constraints.write_fn)(&context_args, diagnostics);
1086 let OperatorWriteOutput {
1087 write_prologue,
1088 write_prologue_after,
1089 write_iterator,
1090 write_iterator_after,
1091 } = write_result.unwrap_or_else(|()| {
1092 assert!(
1093 diagnostics.iter().any(Diagnostic::is_error),
1094 "Operator `{}` returned `Err` but emitted no diagnostics, this is a bug.",
1095 op_name,
1096 );
1097 OperatorWriteOutput { write_iterator: null_write_iterator_fn(&context_args), ..Default::default() }
1098 });
1099
1100 op_prologue_code.push(syn::parse_quote! {
1101 #[allow(non_snake_case)]
1102 #[inline(always)]
1103 fn #work_fn<T>(thunk: impl ::std::ops::FnOnce() -> T) -> T {
1104 thunk()
1105 }
1106
1107 #[allow(non_snake_case)]
1108 #[inline(always)]
1109 async fn #work_fn_async<T>(thunk: impl ::std::future::Future<Output = T>) -> T {
1110 thunk.await
1111 }
1112 });
1113 op_prologue_code.push(write_prologue);
1114 op_prologue_after_code.push(write_prologue_after);
1115 subgraph_op_iter_code.push(write_iterator);
1116
1117 if include_type_guards {
1118 let type_guard = if is_pull {
1119 quote_spanned! {op_span=>
1120 let #ident = {
1121 #[allow(non_snake_case)]
1122 #[inline(always)]
1123 pub fn #work_fn<Item, Input: #root::futures::stream::Stream<Item = Item>>(input: Input) -> impl #root::futures::stream::Stream<Item = Item> {
1124 #root::pin_project_lite::pin_project! {
1125 #[repr(transparent)]
1126 struct Pull<Item, Input: #root::futures::stream::Stream<Item = Item>> {
1127 #[pin]
1128 inner: Input
1129 }
1130 }
1131
1132 impl<Item, Input> #root::futures::stream::Stream for Pull<Item, Input>
1133 where
1134 Input: #root::futures::stream::Stream<Item = Item>,
1135 {
1136 type Item = Item;
1137
1138 #[inline(always)]
1139 fn poll_next(
1140 self: ::std::pin::Pin<&mut Self>,
1141 cx: &mut ::std::task::Context<'_>,
1142 ) -> ::std::task::Poll<::std::option::Option<Self::Item>> {
1143 #root::futures::stream::Stream::poll_next(self.project().inner, cx)
1144 }
1145
1146 #[inline(always)]
1147 fn size_hint(&self) -> (usize, Option<usize>) {
1148 #root::futures::stream::Stream::size_hint(&self.inner)
1149 }
1150 }
1151
1152 Pull {
1153 inner: input
1154 }
1155 }
1156 #work_fn( #ident )
1157 };
1158 }
1159 } else {
1160 quote_spanned! {op_span=>
1161 let #ident = {
1162 #[allow(non_snake_case)]
1163 #[inline(always)]
1164 pub fn #work_fn<Item, Si>(si: Si) -> impl #root::futures::sink::Sink<Item, Error = #root::Never>
1165 where
1166 Si: #root::futures::sink::Sink<Item, Error = #root::Never>
1167 {
1168 #root::pin_project_lite::pin_project! {
1169 #[repr(transparent)]
1170 struct Push<Si> {
1171 #[pin]
1172 si: Si,
1173 }
1174 }
1175
1176 impl<Item, Si> #root::futures::sink::Sink<Item> for Push<Si>
1177 where
1178 Si: #root::futures::sink::Sink<Item>,
1179 {
1180 type Error = Si::Error;
1181
1182 fn poll_ready(
1183 self: ::std::pin::Pin<&mut Self>,
1184 cx: &mut ::std::task::Context<'_>,
1185 ) -> ::std::task::Poll<::std::result::Result<(), Self::Error>> {
1186 self.project().si.poll_ready(cx)
1187 }
1188
1189 fn start_send(
1190 self: ::std::pin::Pin<&mut Self>,
1191 item: Item,
1192 ) -> ::std::result::Result<(), Self::Error> {
1193 self.project().si.start_send(item)
1194 }
1195
1196 fn poll_flush(
1197 self: ::std::pin::Pin<&mut Self>,
1198 cx: &mut ::std::task::Context<'_>,
1199 ) -> ::std::task::Poll<::std::result::Result<(), Self::Error>> {
1200 self.project().si.poll_flush(cx)
1201 }
1202
1203 fn poll_close(
1204 self: ::std::pin::Pin<&mut Self>,
1205 cx: &mut ::std::task::Context<'_>,
1206 ) -> ::std::task::Poll<::std::result::Result<(), Self::Error>> {
1207 self.project().si.poll_close(cx)
1208 }
1209 }
1210
1211 Push {
1212 si
1213 }
1214 }
1215 #work_fn( #ident )
1216 };
1217 }
1218 };
1219 subgraph_op_iter_code.push(type_guard);
1220 }
1221 subgraph_op_iter_after_code.push(write_iterator_after);
1222 }
1223 }
1224
1225 {
1226 let pull_ident = if 0 < pull_to_push_idx {
1228 self.node_as_ident(subgraph_nodes[pull_to_push_idx - 1], false)
1229 } else {
1230 recv_ports[0].clone()
1232 };
1233
1234 #[rustfmt::skip]
1235 let push_ident = if let Some(&node_id) =
1236 subgraph_nodes.get(pull_to_push_idx)
1237 {
1238 self.node_as_ident(node_id, false)
1239 } else if 1 == send_ports.len() {
1240 send_ports[0].clone()
1242 } else {
1243 diagnostics.push(Diagnostic::spanned(
1244 pull_ident.span(),
1245 Level::Error,
1246 "Degenerate subgraph detected, is there a disconnected `null()` or other degenerate pipeline somewhere?",
1247 ));
1248 continue;
1249 };
1250
1251 let pivot_span = pull_ident
1253 .span()
1254 .join(push_ident.span())
1255 .unwrap_or_else(|| push_ident.span());
1256 let pivot_fn_ident =
1257 Ident::new(&format!("pivot_run_sg_{:?}", subgraph_id.0), pivot_span);
1258 let root = change_spans(root.clone(), pivot_span);
1259 subgraph_op_iter_code.push(quote_spanned! {pivot_span=>
1260 #[inline(always)]
1261 fn #pivot_fn_ident<Pull, Push, Item>(pull: Pull, push: Push)
1262 -> impl ::std::future::Future<Output = ::std::result::Result<(), #root::Never>>
1263 where
1264 Pull: #root::futures::stream::Stream<Item = Item>,
1265 Push: #root::futures::sink::Sink<Item, Error = #root::Never>,
1266 {
1267 #root::sinktools::send_stream(pull, push)
1268 }
1269 (#pivot_fn_ident)(#pull_ident, #push_ident).await.unwrap();
1270 });
1271 }
1272 };
1273
1274 let subgraph_name = Literal::string(&format!("Subgraph {:?}", subgraph_id));
1275 let stratum = Literal::usize_unsuffixed(
1276 self.subgraph_stratum.get(subgraph_id).cloned().unwrap_or(0),
1277 );
1278 let laziness = self.subgraph_laziness(subgraph_id);
1279
1280 let loop_id_opt = loop_id
1282 .map(|loop_id| loop_id.as_ident(Span::call_site()))
1283 .map(|ident| quote! { Some(#ident) })
1284 .unwrap_or_else(|| quote! { None });
1285
1286 let sg_ident = subgraph_id.as_ident(Span::call_site());
1287
1288 subgraphs.push(quote! {
1289 let #sg_ident = #df.add_subgraph_full(
1290 #subgraph_name,
1291 #stratum,
1292 var_expr!( #( #recv_ports ),* ),
1293 var_expr!( #( #send_ports ),* ),
1294 #laziness,
1295 #loop_id_opt,
1296 async move |#context, var_args!( #( #recv_ports ),* ), var_args!( #( #send_ports ),* )| {
1297 #( #recv_port_code )*
1298 #( #send_port_code )*
1299 #( #subgraph_op_iter_code )*
1300 #( #subgraph_op_iter_after_code )*
1301 },
1302 );
1303 });
1304 }
1305 }
1306
1307 let loop_code = self.codegen_nested_loops(&df);
1308
1309 let code = quote! {
1314 #( #handoff_code )*
1315 #loop_code
1316 #( #op_prologue_code )*
1317 #( #subgraphs )*
1318 #( #op_prologue_after_code )*
1319 };
1320
1321 let meta_graph_json = serde_json::to_string(&self).unwrap();
1322 let meta_graph_json = Literal::string(&meta_graph_json);
1323
1324 let serde_diagnostics: Vec<_> = diagnostics.iter().map(Diagnostic::to_serde).collect();
1325 let diagnostics_json = serde_json::to_string(&*serde_diagnostics).unwrap();
1326 let diagnostics_json = Literal::string(&diagnostics_json);
1327
1328 quote! {
1329 {
1330 #[allow(unused_qualifications, clippy::await_holding_refcell_ref)]
1331 {
1332 #prefix
1333
1334 use #root::{var_expr, var_args};
1335
1336 let mut #df = #root::scheduled::graph::Dfir::new();
1337 #df.__assign_meta_graph(#meta_graph_json);
1338 #df.__assign_diagnostics(#diagnostics_json);
1339
1340 #code
1341
1342 #df
1343 }
1344 }
1345 }
1346 }
1347
1348 pub fn node_color_map(&self) -> SparseSecondaryMap<GraphNodeId, Color> {
1351 let mut node_color_map: SparseSecondaryMap<GraphNodeId, Color> = self
1352 .node_ids()
1353 .filter_map(|node_id| {
1354 let op_color = self.node_color(node_id)?;
1355 Some((node_id, op_color))
1356 })
1357 .collect();
1358
1359 for sg_nodes in self.subgraph_nodes.values() {
1361 let pull_to_push_idx = self.find_pull_to_push_idx(sg_nodes);
1362
1363 for (idx, node_id) in sg_nodes.iter().copied().enumerate() {
1364 let is_pull = idx < pull_to_push_idx;
1365 node_color_map.insert(node_id, if is_pull { Color::Pull } else { Color::Push });
1366 }
1367 }
1368
1369 node_color_map
1370 }
1371
1372 pub fn to_mermaid(&self, write_config: &WriteConfig) -> String {
1374 let mut output = String::new();
1375 self.write_mermaid(&mut output, write_config).unwrap();
1376 output
1377 }
1378
1379 pub fn write_mermaid(
1381 &self,
1382 output: impl std::fmt::Write,
1383 write_config: &WriteConfig,
1384 ) -> std::fmt::Result {
1385 let mut graph_write = Mermaid::new(output);
1386 self.write_graph(&mut graph_write, write_config)
1387 }
1388
1389 pub fn to_dot(&self, write_config: &WriteConfig) -> String {
1391 let mut output = String::new();
1392 let mut graph_write = Dot::new(&mut output);
1393 self.write_graph(&mut graph_write, write_config).unwrap();
1394 output
1395 }
1396
1397 pub fn write_dot(
1399 &self,
1400 output: impl std::fmt::Write,
1401 write_config: &WriteConfig,
1402 ) -> std::fmt::Result {
1403 let mut graph_write = Dot::new(output);
1404 self.write_graph(&mut graph_write, write_config)
1405 }
1406
1407 pub(crate) fn write_graph<W>(
1409 &self,
1410 mut graph_write: W,
1411 write_config: &WriteConfig,
1412 ) -> Result<(), W::Err>
1413 where
1414 W: GraphWrite,
1415 {
1416 fn helper_edge_label(
1417 src_port: &PortIndexValue,
1418 dst_port: &PortIndexValue,
1419 ) -> Option<String> {
1420 let src_label = match src_port {
1421 PortIndexValue::Path(path) => Some(path.to_token_stream().to_string()),
1422 PortIndexValue::Int(index) => Some(index.value.to_string()),
1423 _ => None,
1424 };
1425 let dst_label = match dst_port {
1426 PortIndexValue::Path(path) => Some(path.to_token_stream().to_string()),
1427 PortIndexValue::Int(index) => Some(index.value.to_string()),
1428 _ => None,
1429 };
1430 let label = match (src_label, dst_label) {
1431 (Some(l1), Some(l2)) => Some(format!("{}\n{}", l1, l2)),
1432 (Some(l1), None) => Some(l1),
1433 (None, Some(l2)) => Some(l2),
1434 (None, None) => None,
1435 };
1436 label
1437 }
1438
1439 let node_color_map = self.node_color_map();
1441
1442 graph_write.write_prologue()?;
1444
1445 let mut skipped_handoffs = BTreeSet::new();
1447 let mut subgraph_handoffs = <BTreeMap<GraphSubgraphId, Vec<GraphNodeId>>>::new();
1448 for (node_id, node) in self.nodes() {
1449 if matches!(node, GraphNode::Handoff { .. }) {
1450 if write_config.no_handoffs {
1451 skipped_handoffs.insert(node_id);
1452 continue;
1453 } else {
1454 let pred_node = self.node_predecessor_nodes(node_id).next().unwrap();
1455 let pred_sg = self.node_subgraph(pred_node);
1456 let succ_node = self.node_successor_nodes(node_id).next().unwrap();
1457 let succ_sg = self.node_subgraph(succ_node);
1458 if let Some((pred_sg, succ_sg)) = pred_sg.zip(succ_sg)
1459 && pred_sg == succ_sg
1460 {
1461 subgraph_handoffs.entry(pred_sg).or_default().push(node_id);
1462 }
1463 }
1464 }
1465 graph_write.write_node_definition(
1466 node_id,
1467 &if write_config.op_short_text {
1468 node.to_name_string()
1469 } else if write_config.op_text_no_imports {
1470 let full_text = node.to_pretty_string();
1472 let mut output = String::new();
1473 for sentence in full_text.split('\n') {
1474 if sentence.trim().starts_with("use") {
1475 continue;
1476 }
1477 output.push('\n');
1478 output.push_str(sentence);
1479 }
1480 output.into()
1481 } else {
1482 node.to_pretty_string()
1483 },
1484 if write_config.no_pull_push {
1485 None
1486 } else {
1487 node_color_map.get(node_id).copied()
1488 },
1489 )?;
1490 }
1491
1492 for (edge_id, (src_id, mut dst_id)) in self.edges() {
1494 if skipped_handoffs.contains(&src_id) {
1496 continue;
1497 }
1498
1499 let (src_port, mut dst_port) = self.edge_ports(edge_id);
1500 if skipped_handoffs.contains(&dst_id) {
1501 let mut handoff_succs = self.node_successors(dst_id);
1502 assert_eq!(1, handoff_succs.len());
1503 let (succ_edge, succ_node) = handoff_succs.next().unwrap();
1504 dst_id = succ_node;
1505 dst_port = self.edge_ports(succ_edge).1;
1506 }
1507
1508 let label = helper_edge_label(src_port, dst_port);
1509 let delay_type = self
1510 .node_op_inst(dst_id)
1511 .and_then(|op_inst| (op_inst.op_constraints.input_delaytype_fn)(dst_port));
1512 graph_write.write_edge(src_id, dst_id, delay_type, label.as_deref(), false)?;
1513 }
1514
1515 if !write_config.no_references {
1517 for dst_id in self.node_ids() {
1518 for src_ref_id in self
1519 .node_singleton_references(dst_id)
1520 .iter()
1521 .copied()
1522 .flatten()
1523 {
1524 let delay_type = Some(DelayType::Stratum);
1525 let label = None;
1526 graph_write.write_edge(src_ref_id, dst_id, delay_type, label, true)?;
1527 }
1528 }
1529 }
1530
1531 let loop_subgraphs = self.subgraph_ids().map(|sg_id| {
1542 let loop_id = if write_config.no_loops {
1543 None
1544 } else {
1545 self.subgraph_loop(sg_id)
1546 };
1547 (loop_id, sg_id)
1548 });
1549 let loop_subgraphs = into_group_map(loop_subgraphs);
1550 for (loop_id, subgraph_ids) in loop_subgraphs {
1551 if let Some(loop_id) = loop_id {
1552 graph_write.write_loop_start(loop_id)?;
1553 }
1554
1555 let subgraph_varnames_nodes = subgraph_ids.into_iter().flat_map(|sg_id| {
1557 self.subgraph(sg_id).iter().copied().map(move |node_id| {
1558 let opt_sg_id = if write_config.no_subgraphs {
1559 None
1560 } else {
1561 Some(sg_id)
1562 };
1563 (opt_sg_id, (self.node_varname(node_id), node_id))
1564 })
1565 });
1566 let subgraph_varnames_nodes = into_group_map(subgraph_varnames_nodes);
1567 for (sg_id, varnames) in subgraph_varnames_nodes {
1568 if let Some(sg_id) = sg_id {
1569 let stratum = self.subgraph_stratum(sg_id).unwrap();
1570 graph_write.write_subgraph_start(sg_id, stratum)?;
1571 }
1572
1573 let varname_nodes = varnames.into_iter().map(|(varname, node)| {
1575 let varname = if write_config.no_varnames {
1576 None
1577 } else {
1578 varname
1579 };
1580 (varname, node)
1581 });
1582 let varname_nodes = into_group_map(varname_nodes);
1583 for (varname, node_ids) in varname_nodes {
1584 if let Some(varname) = varname {
1585 graph_write.write_varname_start(&varname.0.to_string(), sg_id)?;
1586 }
1587
1588 for node_id in node_ids {
1590 graph_write.write_node(node_id)?;
1591 }
1592
1593 if varname.is_some() {
1594 graph_write.write_varname_end()?;
1595 }
1596 }
1597
1598 if sg_id.is_some() {
1599 graph_write.write_subgraph_end()?;
1600 }
1601 }
1602
1603 if loop_id.is_some() {
1604 graph_write.write_loop_end()?;
1605 }
1606 }
1607
1608 graph_write.write_epilogue()?;
1610
1611 Ok(())
1612 }
1613
1614 pub fn surface_syntax_string(&self) -> String {
1616 let mut string = String::new();
1617 self.write_surface_syntax(&mut string).unwrap();
1618 string
1619 }
1620
1621 pub fn write_surface_syntax(&self, write: &mut impl std::fmt::Write) -> std::fmt::Result {
1623 for (key, node) in self.nodes.iter() {
1624 match node {
1625 GraphNode::Operator(op) => {
1626 writeln!(write, "{:?} = {};", key.data(), op.to_token_stream())?;
1627 }
1628 GraphNode::Handoff { .. } => {
1629 writeln!(write, "// {:?} = <handoff>;", key.data())?;
1630 }
1631 GraphNode::ModuleBoundary { .. } => panic!(),
1632 }
1633 }
1634 writeln!(write)?;
1635 for (_e, (src_key, dst_key)) in self.graph.edges() {
1636 writeln!(write, "{:?} -> {:?};", src_key.data(), dst_key.data())?;
1637 }
1638 Ok(())
1639 }
1640
1641 pub fn mermaid_string_flat(&self) -> String {
1643 let mut string = String::new();
1644 self.write_mermaid_flat(&mut string).unwrap();
1645 string
1646 }
1647
1648 pub fn write_mermaid_flat(&self, write: &mut impl std::fmt::Write) -> std::fmt::Result {
1650 writeln!(write, "flowchart TB")?;
1651 for (key, node) in self.nodes.iter() {
1652 match node {
1653 GraphNode::Operator(operator) => writeln!(
1654 write,
1655 " %% {span}\n {id:?}[\"{row_col} <tt>{code}</tt>\"]",
1656 span = PrettySpan(node.span()),
1657 id = key.data(),
1658 row_col = PrettyRowCol(node.span()),
1659 code = operator
1660 .to_token_stream()
1661 .to_string()
1662 .replace('&', "&")
1663 .replace('<', "<")
1664 .replace('>', ">")
1665 .replace('"', """)
1666 .replace('\n', "<br>"),
1667 ),
1668 GraphNode::Handoff { .. } => {
1669 writeln!(write, r#" {:?}{{"{}"}}"#, key.data(), HANDOFF_NODE_STR)
1670 }
1671 GraphNode::ModuleBoundary { .. } => {
1672 writeln!(
1673 write,
1674 r#" {:?}{{"{}"}}"#,
1675 key.data(),
1676 MODULE_BOUNDARY_NODE_STR
1677 )
1678 }
1679 }?;
1680 }
1681 writeln!(write)?;
1682 for (_e, (src_key, dst_key)) in self.graph.edges() {
1683 writeln!(write, " {:?}-->{:?}", src_key.data(), dst_key.data())?;
1684 }
1685 Ok(())
1686 }
1687}
1688
1689impl DfirGraph {
1691 pub fn loop_ids(&self) -> slotmap::basic::Keys<'_, GraphLoopId, Vec<GraphNodeId>> {
1693 self.loop_nodes.keys()
1694 }
1695
1696 pub fn loops(&self) -> slotmap::basic::Iter<'_, GraphLoopId, Vec<GraphNodeId>> {
1698 self.loop_nodes.iter()
1699 }
1700
1701 pub fn insert_loop(&mut self, parent_loop: Option<GraphLoopId>) -> GraphLoopId {
1703 let loop_id = self.loop_nodes.insert(Vec::new());
1704 self.loop_children.insert(loop_id, Vec::new());
1705 if let Some(parent_loop) = parent_loop {
1706 self.loop_parent.insert(loop_id, parent_loop);
1707 self.loop_children
1708 .get_mut(parent_loop)
1709 .unwrap()
1710 .push(loop_id);
1711 } else {
1712 self.root_loops.push(loop_id);
1713 }
1714 loop_id
1715 }
1716
1717 pub fn node_loop(&self, node_id: GraphNodeId) -> Option<GraphLoopId> {
1719 self.node_loops.get(node_id).copied()
1720 }
1721
1722 pub fn subgraph_loop(&self, subgraph_id: GraphSubgraphId) -> Option<GraphLoopId> {
1724 let &node_id = self.subgraph(subgraph_id).first().unwrap();
1725 let out = self.node_loop(node_id);
1726 debug_assert!(
1727 self.subgraph(subgraph_id)
1728 .iter()
1729 .all(|&node_id| self.node_loop(node_id) == out),
1730 "Subgraph nodes should all have the same loop context."
1731 );
1732 out
1733 }
1734
1735 pub fn loop_parent(&self, loop_id: GraphLoopId) -> Option<GraphLoopId> {
1737 self.loop_parent.get(loop_id).copied()
1738 }
1739
1740 pub fn loop_children(&self, loop_id: GraphLoopId) -> &Vec<GraphLoopId> {
1742 self.loop_children.get(loop_id).unwrap()
1743 }
1744}
1745
/// Options controlling how a [`DfirGraph`] is rendered as a Mermaid or DOT graph.
///
/// Every flag defaults to `false`, which means "include everything with full operator text".
#[derive(Clone, Debug, Default)]
#[cfg_attr(feature = "clap-derive", derive(clap::Args))]
pub struct WriteConfig {
    /// Do not group nodes into subgraphs.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub no_subgraphs: bool,
    /// Do not group nodes by variable name.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub no_varnames: bool,
    /// Do not color nodes as pull or push.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub no_pull_push: bool,
    /// Omit handoff nodes, drawing edges straight through them instead.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub no_handoffs: bool,
    /// Omit singleton reference edges.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub no_references: bool,
    /// Do not group subgraphs by loop.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub no_loops: bool,

    /// Label operators with their name only. Takes precedence over `op_text_no_imports`.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub op_short_text: bool,
    /// Label operators with their full pretty-printed text, leaving out `use` import lines.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub op_text_no_imports: bool,
}
1776
/// Graph output format.
#[derive(Copy, Clone, Debug)]
#[cfg_attr(feature = "clap-derive", derive(clap::Parser, clap::ValueEnum))]
pub enum WriteGraphType {
    /// Mermaid flowchart syntax.
    Mermaid,
    /// Graphviz DOT syntax.
    Dot,
}
1786
/// Groups `(key, value)` pairs by key into a [`BTreeMap`], so groups come out in key order.
///
/// Within each group, values keep the order in which `iter` produced them. An empty input
/// yields an empty map.
fn into_group_map<K, V>(iter: impl IntoIterator<Item = (K, V)>) -> BTreeMap<K, Vec<V>>
where
    K: Ord,
{
    iter.into_iter().fold(
        BTreeMap::new(),
        |mut groups: BTreeMap<K, Vec<V>>, (key, value)| {
            groups.entry(key).or_insert_with(Vec::new).push(value);
            groups
        },
    )
}