1#![warn(missing_docs)]
2
3extern crate proc_macro;
4
5use std::collections::{BTreeMap, BTreeSet, VecDeque};
6use std::fmt::Debug;
7use std::iter::FusedIterator;
8
9use itertools::Itertools;
10use proc_macro2::{Ident, Literal, Span, TokenStream};
11use quote::{ToTokens, TokenStreamExt, format_ident, quote, quote_spanned};
12use serde::{Deserialize, Serialize};
13use slotmap::{Key, SecondaryMap, SlotMap, SparseSecondaryMap};
14use syn::spanned::Spanned;
15
16use super::graph_write::{Dot, GraphWrite, Mermaid};
17use super::ops::{
18 DelayType, OPERATORS, OperatorWriteOutput, WriteContextArgs, find_op_op_constraints,
19 null_write_iterator_fn,
20};
21use super::{
22 CONTEXT, Color, DiMulGraph, GRAPH, GraphEdgeId, GraphLoopId, GraphNode, GraphNodeId,
23 GraphSubgraphId, HANDOFF_NODE_STR, MODULE_BOUNDARY_NODE_STR, OperatorInstance, PortIndexValue,
24 Varname, change_spans, get_operator_generics,
25};
26use crate::diagnostic::{Diagnostic, Level};
27use crate::pretty_span::{PrettyRowCol, PrettySpan};
28use crate::process_singletons;
29
/// A DFIR graph: operator and handoff nodes connected by port-labelled edges, together with
/// the loop nesting, subgraph partitioning, and per-node metadata used for code generation.
#[derive(Default, Debug, Serialize, Deserialize)]
pub struct DfirGraph {
    /// Every node in the graph (operators, handoffs, and module boundaries).
    nodes: SlotMap<GraphNodeId, GraphNode>,

    /// Resolved [`OperatorInstance`] for each operator node.
    ///
    /// Not serialized (`#[serde(skip)]`), so this is empty after deserialization.
    #[serde(skip)]
    operator_instances: SecondaryMap<GraphNodeId, OperatorInstance>,
    /// Optional tag for an operator node. When present, code generation uses it in the
    /// generated per-operator function name instead of a source-location tag.
    operator_tag: SecondaryMap<GraphNodeId, String>,
    /// Directed multigraph structure connecting the nodes by edges.
    graph: DiMulGraph<GraphNodeId, GraphEdgeId>,
    /// `(source port, destination port)` for each edge.
    ports: SecondaryMap<GraphEdgeId, (PortIndexValue, PortIndexValue)>,

    /// Loop that each node belongs to, for nodes inside a loop.
    node_loops: SecondaryMap<GraphNodeId, GraphLoopId>,
    /// Nodes belonging to each loop.
    loop_nodes: SlotMap<GraphLoopId, Vec<GraphNodeId>>,
    /// Parent loop of each nested loop; loops without an entry have no parent.
    loop_parent: SparseSecondaryMap<GraphLoopId, GraphLoopId>,
    /// Loops that have no parent loop.
    root_loops: Vec<GraphLoopId>,
    /// Child loops of each loop.
    loop_children: SecondaryMap<GraphLoopId, Vec<GraphLoopId>>,

    /// Subgraph that each node belongs to. Empty before subgraph partitioning.
    node_subgraph: SecondaryMap<GraphNodeId, GraphSubgraphId>,

    /// Ordered nodes of each subgraph. Code generation splits each list into a pull half
    /// followed by a push half.
    subgraph_nodes: SlotMap<GraphSubgraphId, Vec<GraphNodeId>>,
    /// Stratum of each subgraph.
    subgraph_stratum: SecondaryMap<GraphSubgraphId, usize>,

    /// Singleton references of each node; a `None` entry is a reference that has not been
    /// resolved to a node.
    node_singleton_references: SparseSecondaryMap<GraphNodeId, Vec<Option<GraphNodeId>>>,
    /// Optional variable name associated with a node (recorded by `insert_node`).
    node_varnames: SparseSecondaryMap<GraphNodeId, Varname>,

    /// Whether each subgraph is lazy; subgraphs without an entry are treated as not lazy.
    subgraph_laziness: SecondaryMap<GraphSubgraphId, bool>,
}
84
85impl DfirGraph {
87 pub fn new() -> Self {
89 Default::default()
90 }
91}
92
93impl DfirGraph {
95 pub fn node(&self, node_id: GraphNodeId) -> &GraphNode {
97 self.nodes.get(node_id).expect("Node not found.")
98 }
99
100 pub fn node_op_inst(&self, node_id: GraphNodeId) -> Option<&OperatorInstance> {
105 self.operator_instances.get(node_id)
106 }
107
108 pub fn node_varname(&self, node_id: GraphNodeId) -> Option<Ident> {
110 self.node_varnames.get(node_id).map(|x| x.0.clone())
111 }
112
113 pub fn node_subgraph(&self, node_id: GraphNodeId) -> Option<GraphSubgraphId> {
115 self.node_subgraph.get(node_id).copied()
116 }
117
118 pub fn node_degree_in(&self, node_id: GraphNodeId) -> usize {
120 self.graph.degree_in(node_id)
121 }
122
123 pub fn node_degree_out(&self, node_id: GraphNodeId) -> usize {
125 self.graph.degree_out(node_id)
126 }
127
128 pub fn node_successors(
130 &self,
131 src: GraphNodeId,
132 ) -> impl '_
133 + DoubleEndedIterator<Item = (GraphEdgeId, GraphNodeId)>
134 + ExactSizeIterator
135 + FusedIterator
136 + Clone
137 + Debug {
138 self.graph.successors(src)
139 }
140
141 pub fn node_predecessors(
143 &self,
144 dst: GraphNodeId,
145 ) -> impl '_
146 + DoubleEndedIterator<Item = (GraphEdgeId, GraphNodeId)>
147 + ExactSizeIterator
148 + FusedIterator
149 + Clone
150 + Debug {
151 self.graph.predecessors(dst)
152 }
153
154 pub fn node_successor_edges(
156 &self,
157 src: GraphNodeId,
158 ) -> impl '_
159 + DoubleEndedIterator<Item = GraphEdgeId>
160 + ExactSizeIterator
161 + FusedIterator
162 + Clone
163 + Debug {
164 self.graph.successor_edges(src)
165 }
166
167 pub fn node_predecessor_edges(
169 &self,
170 dst: GraphNodeId,
171 ) -> impl '_
172 + DoubleEndedIterator<Item = GraphEdgeId>
173 + ExactSizeIterator
174 + FusedIterator
175 + Clone
176 + Debug {
177 self.graph.predecessor_edges(dst)
178 }
179
180 pub fn node_successor_nodes(
182 &self,
183 src: GraphNodeId,
184 ) -> impl '_
185 + DoubleEndedIterator<Item = GraphNodeId>
186 + ExactSizeIterator
187 + FusedIterator
188 + Clone
189 + Debug {
190 self.graph.successor_vertices(src)
191 }
192
193 pub fn node_predecessor_nodes(
195 &self,
196 dst: GraphNodeId,
197 ) -> impl '_
198 + DoubleEndedIterator<Item = GraphNodeId>
199 + ExactSizeIterator
200 + FusedIterator
201 + Clone
202 + Debug {
203 self.graph.predecessor_vertices(dst)
204 }
205
206 pub fn node_ids(&self) -> slotmap::basic::Keys<'_, GraphNodeId, GraphNode> {
208 self.nodes.keys()
209 }
210
211 pub fn nodes(&self) -> slotmap::basic::Iter<'_, GraphNodeId, GraphNode> {
213 self.nodes.iter()
214 }
215
216 pub fn insert_node(
218 &mut self,
219 node: GraphNode,
220 varname_opt: Option<Ident>,
221 loop_opt: Option<GraphLoopId>,
222 ) -> GraphNodeId {
223 let node_id = self.nodes.insert(node);
224 if let Some(varname) = varname_opt {
225 self.node_varnames.insert(node_id, Varname(varname));
226 }
227 if let Some(loop_id) = loop_opt {
228 self.node_loops.insert(node_id, loop_id);
229 self.loop_nodes[loop_id].push(node_id);
230 }
231 node_id
232 }
233
234 pub fn insert_node_op_inst(&mut self, node_id: GraphNodeId, op_inst: OperatorInstance) {
236 assert!(matches!(
237 self.nodes.get(node_id),
238 Some(GraphNode::Operator(_))
239 ));
240 let old_inst = self.operator_instances.insert(node_id, op_inst);
241 assert!(old_inst.is_none());
242 }
243
244 pub fn insert_node_op_insts_all(&mut self, diagnostics: &mut Vec<Diagnostic>) {
246 let mut op_insts = Vec::new();
247 for (node_id, node) in self.nodes() {
248 let GraphNode::Operator(operator) = node else {
249 continue;
250 };
251 if self.node_op_inst(node_id).is_some() {
252 continue;
253 };
254
255 let Some(op_constraints) = find_op_op_constraints(operator) else {
257 diagnostics.push(Diagnostic::spanned(
258 operator.path.span(),
259 Level::Error,
260 format!("Unknown operator `{}`", operator.name_string()),
261 ));
262 continue;
263 };
264
265 let (input_ports, output_ports) = {
267 let mut input_edges: Vec<(&PortIndexValue, GraphNodeId)> = self
268 .node_predecessors(node_id)
269 .map(|(edge_id, pred_id)| (self.edge_ports(edge_id).1, pred_id))
270 .collect();
271 input_edges.sort();
273 let input_ports: Vec<PortIndexValue> = input_edges
274 .into_iter()
275 .map(|(port, _pred)| port)
276 .cloned()
277 .collect();
278
279 let mut output_edges: Vec<(&PortIndexValue, GraphNodeId)> = self
281 .node_successors(node_id)
282 .map(|(edge_id, succ)| (self.edge_ports(edge_id).0, succ))
283 .collect();
284 output_edges.sort();
286 let output_ports: Vec<PortIndexValue> = output_edges
287 .into_iter()
288 .map(|(port, _succ)| port)
289 .cloned()
290 .collect();
291
292 (input_ports, output_ports)
293 };
294
295 let generics = get_operator_generics(diagnostics, operator);
297 {
299 let generics_span = generics
301 .generic_args
302 .as_ref()
303 .map(Spanned::span)
304 .unwrap_or_else(|| operator.path.span());
305
306 if !op_constraints
307 .persistence_args
308 .contains(&generics.persistence_args.len())
309 {
310 diagnostics.push(Diagnostic::spanned(
311 generics_span,
312 Level::Error,
313 format!(
314 "`{}` should have {} persistence lifetime arguments, actually has {}.",
315 op_constraints.name,
316 op_constraints.persistence_args.human_string(),
317 generics.persistence_args.len()
318 ),
319 ));
320 }
321 if !op_constraints.type_args.contains(&generics.type_args.len()) {
322 diagnostics.push(Diagnostic::spanned(
323 generics_span,
324 Level::Error,
325 format!(
326 "`{}` should have {} generic type arguments, actually has {}.",
327 op_constraints.name,
328 op_constraints.type_args.human_string(),
329 generics.type_args.len()
330 ),
331 ));
332 }
333 }
334
335 op_insts.push((
336 node_id,
337 OperatorInstance {
338 op_constraints,
339 input_ports,
340 output_ports,
341 singletons_referenced: operator.singletons_referenced.clone(),
342 generics,
343 arguments_pre: operator.args.clone(),
344 arguments_raw: operator.args_raw.clone(),
345 },
346 ));
347 }
348
349 for (node_id, op_inst) in op_insts {
350 self.insert_node_op_inst(node_id, op_inst);
351 }
352 }
353
354 pub fn insert_intermediate_node(
366 &mut self,
367 edge_id: GraphEdgeId,
368 new_node: GraphNode,
369 ) -> (GraphNodeId, GraphEdgeId) {
370 let span = Some(new_node.span());
371
372 let op_inst_opt = 'oc: {
374 let GraphNode::Operator(operator) = &new_node else {
375 break 'oc None;
376 };
377 let Some(op_constraints) = find_op_op_constraints(operator) else {
378 break 'oc None;
379 };
380 let (input_port, output_port) = self.ports.get(edge_id).cloned().unwrap();
381 let generics = get_operator_generics(
382 &mut Vec::new(), operator,
384 );
385 Some(OperatorInstance {
386 op_constraints,
387 input_ports: vec![input_port],
388 output_ports: vec![output_port],
389 singletons_referenced: operator.singletons_referenced.clone(),
390 generics,
391 arguments_pre: operator.args.clone(),
392 arguments_raw: operator.args_raw.clone(),
393 })
394 };
395
396 let node_id = self.nodes.insert(new_node);
398 if let Some(op_inst) = op_inst_opt {
400 self.operator_instances.insert(node_id, op_inst);
401 }
402 let (e0, e1) = self
404 .graph
405 .insert_intermediate_vertex(node_id, edge_id)
406 .unwrap();
407
408 let (src_idx, dst_idx) = self.ports.remove(edge_id).unwrap();
410 self.ports
411 .insert(e0, (src_idx, PortIndexValue::Elided(span)));
412 self.ports
413 .insert(e1, (PortIndexValue::Elided(span), dst_idx));
414
415 (node_id, e1)
416 }
417
418 pub fn remove_intermediate_node(&mut self, node_id: GraphNodeId) {
421 assert_eq!(
422 1,
423 self.node_degree_in(node_id),
424 "Removed intermediate node must have one predecessor"
425 );
426 assert_eq!(
427 1,
428 self.node_degree_out(node_id),
429 "Removed intermediate node must have one successor"
430 );
431 assert!(
432 self.node_subgraph.is_empty() && self.subgraph_nodes.is_empty(),
433 "Should not remove intermediate node after subgraph partitioning"
434 );
435
436 assert!(self.nodes.remove(node_id).is_some());
437 let (new_edge_id, (pred_edge_id, succ_edge_id)) =
438 self.graph.remove_intermediate_vertex(node_id).unwrap();
439 self.operator_instances.remove(node_id);
440 self.node_varnames.remove(node_id);
441
442 let (src_port, _) = self.ports.remove(pred_edge_id).unwrap();
443 let (_, dst_port) = self.ports.remove(succ_edge_id).unwrap();
444 self.ports.insert(new_edge_id, (src_port, dst_port));
445 }
446
447 pub(crate) fn node_color(&self, node_id: GraphNodeId) -> Option<Color> {
453 if matches!(self.node(node_id), GraphNode::Handoff { .. }) {
454 return Some(Color::Hoff);
455 }
456 let inn_degree = self.node_predecessor_nodes(node_id).count();
458 let out_degree = self.node_successor_nodes(node_id).count();
460
461 match (inn_degree, out_degree) {
462 (0, 0) => None, (0, 1) => Some(Color::Pull),
464 (1, 0) => Some(Color::Push),
465 (1, 1) => None, (_many, 0 | 1) => Some(Color::Pull),
467 (0 | 1, _many) => Some(Color::Push),
468 (_many, _to_many) => Some(Color::Comp),
469 }
470 }
471
472 pub fn set_operator_tag(&mut self, node_id: GraphNodeId, tag: String) {
474 self.operator_tag.insert(node_id, tag.to_owned());
475 }
476}
477
478impl DfirGraph {
480 pub fn set_node_singleton_references(
483 &mut self,
484 node_id: GraphNodeId,
485 singletons_referenced: Vec<Option<GraphNodeId>>,
486 ) -> Option<Vec<Option<GraphNodeId>>> {
487 self.node_singleton_references
488 .insert(node_id, singletons_referenced)
489 }
490
491 pub fn node_singleton_references(&self, node_id: GraphNodeId) -> &[Option<GraphNodeId>] {
494 self.node_singleton_references
495 .get(node_id)
496 .map(std::ops::Deref::deref)
497 .unwrap_or_default()
498 }
499}
500
501impl DfirGraph {
503 pub fn merge_modules(&mut self) -> Result<(), Diagnostic> {
511 let mod_bound_nodes = self
512 .nodes()
513 .filter(|(_nid, node)| matches!(node, GraphNode::ModuleBoundary { .. }))
514 .map(|(nid, _node)| nid)
515 .collect::<Vec<_>>();
516
517 for mod_bound_node in mod_bound_nodes {
518 self.remove_module_boundary(mod_bound_node)?;
519 }
520
521 Ok(())
522 }
523
524 fn remove_module_boundary(&mut self, mod_bound_node: GraphNodeId) -> Result<(), Diagnostic> {
528 assert!(
529 self.node_subgraph.is_empty() && self.subgraph_nodes.is_empty(),
530 "Should not remove intermediate node after subgraph partitioning"
531 );
532
533 let mut mod_pred_ports = BTreeMap::new();
534 let mut mod_succ_ports = BTreeMap::new();
535
536 for mod_out_edge in self.node_predecessor_edges(mod_bound_node) {
537 let (pred_port, succ_port) = self.edge_ports(mod_out_edge);
538 mod_pred_ports.insert(succ_port.clone(), (mod_out_edge, pred_port.clone()));
539 }
540
541 for mod_inn_edge in self.node_successor_edges(mod_bound_node) {
542 let (pred_port, succ_port) = self.edge_ports(mod_inn_edge);
543 mod_succ_ports.insert(pred_port.clone(), (mod_inn_edge, succ_port.clone()));
544 }
545
546 if mod_pred_ports.keys().collect::<BTreeSet<_>>()
547 != mod_succ_ports.keys().collect::<BTreeSet<_>>()
548 {
549 let GraphNode::ModuleBoundary { input, import_expr } = self.node(mod_bound_node) else {
551 panic!();
552 };
553
554 if *input {
555 return Err(Diagnostic {
556 span: *import_expr,
557 level: Level::Error,
558 message: format!(
559 "The ports into the module did not match. input: {:?}, expected: {:?}",
560 mod_pred_ports.keys().map(|x| x.to_string()).join(", "),
561 mod_succ_ports.keys().map(|x| x.to_string()).join(", ")
562 ),
563 });
564 } else {
565 return Err(Diagnostic {
566 span: *import_expr,
567 level: Level::Error,
568 message: format!(
569 "The ports out of the module did not match. output: {:?}, expected: {:?}",
570 mod_succ_ports.keys().map(|x| x.to_string()).join(", "),
571 mod_pred_ports.keys().map(|x| x.to_string()).join(", "),
572 ),
573 });
574 }
575 }
576
577 for (port, (pred_edge, pred_port)) in mod_pred_ports {
578 let (succ_edge, succ_port) = mod_succ_ports.remove(&port).unwrap();
579
580 let (src, _) = self.edge(pred_edge);
581 let (_, dst) = self.edge(succ_edge);
582 self.remove_edge(pred_edge);
583 self.remove_edge(succ_edge);
584
585 let new_edge_id = self.graph.insert_edge(src, dst);
586 self.ports.insert(new_edge_id, (pred_port, succ_port));
587 }
588
589 self.graph.remove_vertex(mod_bound_node);
590 self.nodes.remove(mod_bound_node);
591
592 Ok(())
593 }
594}
595
596impl DfirGraph {
598 pub fn edge(&self, edge_id: GraphEdgeId) -> (GraphNodeId, GraphNodeId) {
600 let (src, dst) = self.graph.edge(edge_id).expect("Edge not found.");
601 (src, dst)
602 }
603
604 pub fn edge_ports(&self, edge_id: GraphEdgeId) -> (&PortIndexValue, &PortIndexValue) {
606 let (src_port, dst_port) = self.ports.get(edge_id).expect("Edge not found.");
607 (src_port, dst_port)
608 }
609
610 pub fn edge_ids(&self) -> slotmap::basic::Keys<GraphEdgeId, (GraphNodeId, GraphNodeId)> {
612 self.graph.edge_ids()
613 }
614
615 pub fn edges(
617 &self,
618 ) -> impl '_
619 + ExactSizeIterator<Item = (GraphEdgeId, (GraphNodeId, GraphNodeId))>
620 + FusedIterator
621 + Clone
622 + Debug {
623 self.graph.edges()
624 }
625
626 pub fn insert_edge(
628 &mut self,
629 src: GraphNodeId,
630 src_port: PortIndexValue,
631 dst: GraphNodeId,
632 dst_port: PortIndexValue,
633 ) -> GraphEdgeId {
634 let edge_id = self.graph.insert_edge(src, dst);
635 self.ports.insert(edge_id, (src_port, dst_port));
636 edge_id
637 }
638
639 pub fn remove_edge(&mut self, edge: GraphEdgeId) {
641 let (_src, _dst) = self.graph.remove_edge(edge).unwrap();
642 let (_src_port, _dst_port) = self.ports.remove(edge).unwrap();
643 }
644}
645
646impl DfirGraph {
648 pub fn subgraph(&self, subgraph_id: GraphSubgraphId) -> &Vec<GraphNodeId> {
650 self.subgraph_nodes
651 .get(subgraph_id)
652 .expect("Subgraph not found.")
653 }
654
655 pub fn subgraph_ids(&self) -> slotmap::basic::Keys<'_, GraphSubgraphId, Vec<GraphNodeId>> {
657 self.subgraph_nodes.keys()
658 }
659
660 pub fn subgraphs(&self) -> slotmap::basic::Iter<'_, GraphSubgraphId, Vec<GraphNodeId>> {
662 self.subgraph_nodes.iter()
663 }
664
665 pub fn insert_subgraph(
667 &mut self,
668 node_ids: Vec<GraphNodeId>,
669 ) -> Result<GraphSubgraphId, (GraphNodeId, GraphSubgraphId)> {
670 for &node_id in node_ids.iter() {
672 if let Some(&old_sg_id) = self.node_subgraph.get(node_id) {
673 return Err((node_id, old_sg_id));
674 }
675 }
676 let subgraph_id = self.subgraph_nodes.insert_with_key(|sg_id| {
677 for &node_id in node_ids.iter() {
678 self.node_subgraph.insert(node_id, sg_id);
679 }
680 node_ids
681 });
682
683 Ok(subgraph_id)
684 }
685
686 pub fn remove_from_subgraph(&mut self, node_id: GraphNodeId) -> bool {
688 if let Some(old_sg_id) = self.node_subgraph.remove(node_id) {
689 self.subgraph_nodes[old_sg_id].retain(|&other_node_id| other_node_id != node_id);
690 true
691 } else {
692 false
693 }
694 }
695
696 pub fn subgraph_stratum(&self, sg_id: GraphSubgraphId) -> Option<usize> {
698 self.subgraph_stratum.get(sg_id).copied()
699 }
700
701 pub fn set_subgraph_stratum(
703 &mut self,
704 sg_id: GraphSubgraphId,
705 stratum: usize,
706 ) -> Option<usize> {
707 self.subgraph_stratum.insert(sg_id, stratum)
708 }
709
710 fn subgraph_laziness(&self, sg_id: GraphSubgraphId) -> bool {
712 self.subgraph_laziness.get(sg_id).copied().unwrap_or(false)
713 }
714
715 pub fn set_subgraph_laziness(&mut self, sg_id: GraphSubgraphId, lazy: bool) -> bool {
717 self.subgraph_laziness.insert(sg_id, lazy).unwrap_or(false)
718 }
719
720 pub fn max_stratum(&self) -> Option<usize> {
722 self.subgraph_stratum.values().copied().max()
723 }
724
725 fn find_pull_to_push_idx(&self, subgraph_nodes: &[GraphNodeId]) -> usize {
727 subgraph_nodes
728 .iter()
729 .position(|&node_id| {
730 self.node_color(node_id)
731 .is_some_and(|color| Color::Pull != color)
732 })
733 .unwrap_or(subgraph_nodes.len())
734 }
735}
736
737impl DfirGraph {
739 fn node_as_ident(&self, node_id: GraphNodeId, is_pred: bool) -> Ident {
741 let name = match &self.nodes[node_id] {
742 GraphNode::Operator(_) => format!("op_{:?}", node_id.data()),
743 GraphNode::Handoff { .. } => format!(
744 "hoff_{:?}_{}",
745 node_id.data(),
746 if is_pred { "recv" } else { "send" }
747 ),
748 GraphNode::ModuleBoundary { .. } => panic!(),
749 };
750 let span = match (is_pred, &self.nodes[node_id]) {
751 (_, GraphNode::Operator(operator)) => operator.span(),
752 (true, &GraphNode::Handoff { src_span, .. }) => src_span,
753 (false, &GraphNode::Handoff { dst_span, .. }) => dst_span,
754 (_, GraphNode::ModuleBoundary { .. }) => panic!(),
755 };
756 Ident::new(&name, span)
757 }
758
759 fn node_as_singleton_ident(&self, node_id: GraphNodeId, span: Span) -> Ident {
761 Ident::new(&format!("singleton_op_{:?}", node_id.data()), span)
762 }
763
764 fn helper_resolve_singletons(&self, node_id: GraphNodeId, span: Span) -> Vec<Ident> {
766 self.node_singleton_references(node_id)
767 .iter()
768 .map(|singleton_node_id| {
769 self.node_as_singleton_ident(
771 singleton_node_id
772 .expect("Expected singleton to be resolved but was not, this is a bug."),
773 span,
774 )
775 })
776 .collect::<Vec<_>>()
777 }
778
779 fn helper_collect_subgraph_handoffs(
782 &self,
783 ) -> SecondaryMap<GraphSubgraphId, (Vec<GraphNodeId>, Vec<GraphNodeId>)> {
784 let mut subgraph_handoffs: SecondaryMap<
786 GraphSubgraphId,
787 (Vec<GraphNodeId>, Vec<GraphNodeId>),
788 > = self
789 .subgraph_nodes
790 .keys()
791 .map(|k| (k, Default::default()))
792 .collect();
793
794 for (hoff_id, node) in self.nodes() {
796 if !matches!(node, GraphNode::Handoff { .. }) {
797 continue;
798 }
799 for (_edge, succ_id) in self.node_successors(hoff_id) {
801 let succ_sg = self.node_subgraph(succ_id).unwrap();
802 subgraph_handoffs[succ_sg].0.push(hoff_id);
803 }
804 for (_edge, pred_id) in self.node_predecessors(hoff_id) {
806 let pred_sg = self.node_subgraph(pred_id).unwrap();
807 subgraph_handoffs[pred_sg].1.push(hoff_id);
808 }
809 }
810
811 subgraph_handoffs
812 }
813
814 fn loop_as_ident(loop_id: GraphLoopId) -> Ident {
816 Ident::new(&format!("loop_{:?}", loop_id.data()), Span::call_site())
817 }
818
819 fn codegen_nested_loops(&self, df: &Ident) -> TokenStream {
821 let mut out = TokenStream::new();
823 let mut queue = VecDeque::from_iter(self.root_loops.iter().copied());
824 while let Some(loop_id) = queue.pop_front() {
825 let parent_opt = self
826 .loop_parent(loop_id)
827 .map(Self::loop_as_ident)
828 .map(|ident| quote! { Some(#ident) })
829 .unwrap_or_else(|| quote! { None });
830 let loop_name = Self::loop_as_ident(loop_id);
831 out.append_all(quote! {
832 let #loop_name = #df.add_loop(#parent_opt);
833 });
834 queue.extend(self.loop_children.get(loop_id).into_iter().flatten());
835 }
836 out
837 }
838
839 pub fn as_code(
841 &self,
842 root: &TokenStream,
843 include_type_guards: bool,
844 prefix: TokenStream,
845 diagnostics: &mut Vec<Diagnostic>,
846 ) -> TokenStream {
847 let df = Ident::new(GRAPH, Span::call_site());
848 let context = Ident::new(CONTEXT, Span::call_site());
849
850 let handoff_code = self
852 .nodes
853 .iter()
854 .filter_map(|(node_id, node)| match node {
855 GraphNode::Operator(_) => None,
856 &GraphNode::Handoff { src_span, dst_span } => Some((node_id, (src_span, dst_span))),
857 GraphNode::ModuleBoundary { .. } => panic!(),
858 })
859 .map(|(node_id, (src_span, dst_span))| {
860 let ident_send = Ident::new(&format!("hoff_{:?}_send", node_id.data()), dst_span);
861 let ident_recv = Ident::new(&format!("hoff_{:?}_recv", node_id.data()), src_span);
862 let span = src_span.join(dst_span).unwrap_or(src_span);
863 let mut hoff_name = Literal::string(&format!("handoff {:?}", node_id));
864 hoff_name.set_span(span);
865 let hoff_type = quote_spanned! (span=> #root::scheduled::handoff::VecHandoff<_>);
866 quote_spanned! {span=>
867 let (#ident_send, #ident_recv) =
868 #df.make_edge::<_, #hoff_type>(#hoff_name);
869 }
870 });
871
872 let subgraph_handoffs = self.helper_collect_subgraph_handoffs();
873
874 let (subgraphs_without_preds, subgraphs_with_preds) = self
876 .subgraph_nodes
877 .iter()
878 .partition::<Vec<_>, _>(|(_, nodes)| {
879 nodes
880 .iter()
881 .any(|&node_id| self.node_degree_in(node_id) == 0)
882 });
883
884 let mut op_prologue_code = Vec::new();
885 let mut subgraphs = Vec::new();
886 {
887 for &(subgraph_id, subgraph_nodes) in subgraphs_without_preds
888 .iter()
889 .chain(subgraphs_with_preds.iter())
890 {
891 let (recv_hoffs, send_hoffs) = &subgraph_handoffs[subgraph_id];
892 let recv_ports: Vec<Ident> = recv_hoffs
893 .iter()
894 .map(|&hoff_id| self.node_as_ident(hoff_id, true))
895 .collect();
896 let send_ports: Vec<Ident> = send_hoffs
897 .iter()
898 .map(|&hoff_id| self.node_as_ident(hoff_id, false))
899 .collect();
900
901 let recv_port_code = recv_ports.iter().map(|ident| {
902 quote! {
903 let mut #ident = #ident.borrow_mut_swap();
904 let #ident = #ident.drain(..);
905 }
906 });
907 let send_port_code = send_ports.iter().map(|ident| {
908 quote! {
909 let #ident = #root::pusherator::for_each::ForEach::new(|v| {
910 #ident.give(Some(v));
911 });
912 }
913 });
914
915 let loop_id = self
916 .node_loop(subgraph_nodes[0]);
918
919 let mut subgraph_op_iter_code = Vec::new();
920 let mut subgraph_op_iter_after_code = Vec::new();
921 {
922 let pull_to_push_idx = self.find_pull_to_push_idx(subgraph_nodes);
923
924 let (pull_half, push_half) = subgraph_nodes.split_at(pull_to_push_idx);
925 let nodes_iter = pull_half.iter().chain(push_half.iter().rev());
926
927 for (idx, &node_id) in nodes_iter.enumerate() {
928 let node = &self.nodes[node_id];
929 assert!(
930 matches!(node, GraphNode::Operator(_)),
931 "Handoffs are not part of subgraphs."
932 );
933 let op_inst = &self.operator_instances[node_id];
934
935 let op_span = node.span();
936 let op_name = op_inst.op_constraints.name;
937 let root = change_spans(root.clone(), op_span);
939 let op_constraints = OPERATORS
941 .iter()
942 .find(|op| op_name == op.name)
943 .unwrap_or_else(|| panic!("Failed to find op: {}", op_name));
944
945 let ident = self.node_as_ident(node_id, false);
946
947 {
948 let mut input_edges = self
951 .graph
952 .predecessor_edges(node_id)
953 .map(|edge_id| (self.edge_ports(edge_id).1, edge_id))
954 .collect::<Vec<_>>();
955 input_edges.sort();
957
958 let inputs = input_edges
959 .iter()
960 .map(|&(_port, edge_id)| {
961 let (pred, _) = self.edge(edge_id);
962 self.node_as_ident(pred, true)
963 })
964 .collect::<Vec<_>>();
965
966 let mut output_edges = self
968 .graph
969 .successor_edges(node_id)
970 .map(|edge_id| (&self.ports[edge_id].0, edge_id))
971 .collect::<Vec<_>>();
972 output_edges.sort();
974
975 let outputs = output_edges
976 .iter()
977 .map(|&(_port, edge_id)| {
978 let (_, succ) = self.edge(edge_id);
979 self.node_as_ident(succ, false)
980 })
981 .collect::<Vec<_>>();
982
983 let is_pull = idx < pull_to_push_idx;
984
985 let singleton_output_ident = &if op_constraints.has_singleton_output {
986 self.node_as_singleton_ident(node_id, op_span)
987 } else {
988 Ident::new(&format!("{}_has_no_singleton_output", op_name), op_span)
990 };
991
992 let df_local = &Ident::new(GRAPH, op_span.resolved_at(df.span()));
1001 let context = &Ident::new(CONTEXT, op_span.resolved_at(context.span()));
1002
1003 let singletons_resolved =
1004 self.helper_resolve_singletons(node_id, op_span);
1005 let arguments = &process_singletons::postprocess_singletons(
1006 op_inst.arguments_raw.clone(),
1007 singletons_resolved.clone(),
1008 context,
1009 );
1010 let arguments_handles =
1011 &process_singletons::postprocess_singletons_handles(
1012 op_inst.arguments_raw.clone(),
1013 singletons_resolved.clone(),
1014 );
1015
1016 let source_tag = 'a: {
1017 if let Some(tag) = self.operator_tag.get(node_id).cloned() {
1018 break 'a tag;
1019 }
1020
1021 #[cfg(nightly)]
1022 if proc_macro::is_available() {
1023 let op_span = op_span.unwrap();
1024 break 'a format!(
1025 "loc_{}_{}_{}_{}_{}",
1026 op_span
1027 .source_file()
1028 .path()
1029 .display()
1030 .to_string()
1031 .replace(|x: char| !x.is_alphanumeric(), "_"),
1032 op_span.start().line(),
1033 op_span.start().column(),
1034 op_span.end().line(),
1035 op_span.end().column(),
1036 );
1037 }
1038
1039 format!(
1040 "loc_nopath_{}_{}_{}_{}",
1041 op_span.start().line,
1042 op_span.start().column,
1043 op_span.end().line,
1044 op_span.end().column
1045 )
1046 };
1047
1048 let fn_ident = format_ident!(
1049 "{}__{}__{}",
1050 ident,
1051 op_name,
1052 source_tag,
1053 span = op_span
1054 );
1055
1056 let context_args = WriteContextArgs {
1057 root: &root,
1058 df_ident: df_local,
1059 context,
1060 subgraph_id,
1061 node_id,
1062 loop_id,
1063 op_span,
1064 op_tag: self.operator_tag.get(node_id).cloned(),
1065 work_fn: &fn_ident,
1066 ident: &ident,
1067 is_pull,
1068 inputs: &inputs,
1069 outputs: &outputs,
1070 singleton_output_ident,
1071 op_name,
1072 op_inst,
1073 arguments,
1074 arguments_handles,
1075 };
1076
1077 let write_result =
1078 (op_constraints.write_fn)(&context_args, diagnostics);
1079 let OperatorWriteOutput {
1080 write_prologue,
1081 write_iterator,
1082 write_iterator_after,
1083 } = write_result.unwrap_or_else(|()| {
1084 assert!(
1085 diagnostics.iter().any(Diagnostic::is_error),
1086 "Operator `{}` returned `Err` but emitted no diagnostics, this is a bug.",
1087 op_name,
1088 );
1089 OperatorWriteOutput { write_iterator: null_write_iterator_fn(&context_args), ..Default::default() }
1090 });
1091
1092 op_prologue_code.push(syn::parse_quote! {
1093 #[allow(non_snake_case)]
1094 #[inline(always)]
1095 fn #fn_ident<T>(thunk: impl FnOnce() -> T) -> T {
1096 thunk()
1097 }
1098 });
1099 op_prologue_code.push(write_prologue);
1100
1101 subgraph_op_iter_code.push(write_iterator);
1102
1103 if include_type_guards {
1104 let type_guard = if is_pull {
1105 quote_spanned! {op_span=>
1106 let #ident = {
1107 #[allow(non_snake_case)]
1108 #[inline(always)]
1109 pub fn #fn_ident<Item, Input: ::std::iter::Iterator<Item = Item>>(input: Input) -> impl ::std::iter::Iterator<Item = Item> {
1110 #[repr(transparent)]
1111 struct Pull<Item, Input: ::std::iter::Iterator<Item = Item>> {
1112 inner: Input
1113 }
1114
1115 impl<Item, Input: ::std::iter::Iterator<Item = Item>> Iterator for Pull<Item, Input> {
1116 type Item = Item;
1117
1118 #[inline(always)]
1119 fn next(&mut self) -> Option<Self::Item> {
1120 self.inner.next()
1121 }
1122
1123 #[inline(always)]
1124 fn size_hint(&self) -> (usize, Option<usize>) {
1125 self.inner.size_hint()
1126 }
1127 }
1128
1129 Pull {
1130 inner: input
1131 }
1132 }
1133 #fn_ident( #ident )
1134 };
1135 }
1136 } else {
1137 quote_spanned! {op_span=>
1138 let #ident = {
1139 #[allow(non_snake_case)]
1140 #[inline(always)]
1141 pub fn #fn_ident<Item, Input: #root::pusherator::Pusherator<Item = Item>>(input: Input) -> impl #root::pusherator::Pusherator<Item = Item> {
1142 #[repr(transparent)]
1143 struct Push<Item, Input: #root::pusherator::Pusherator<Item = Item>> {
1144 inner: Input
1145 }
1146
1147 impl<Item, Input: #root::pusherator::Pusherator<Item = Item>> #root::pusherator::Pusherator for Push<Item, Input> {
1148 type Item = Item;
1149
1150 #[inline(always)]
1151 fn give(&mut self, item: Self::Item) {
1152 self.inner.give(item)
1153 }
1154 }
1155
1156 Push {
1157 inner: input
1158 }
1159 }
1160 #fn_ident( #ident )
1161 };
1162 }
1163 };
1164 subgraph_op_iter_code.push(type_guard);
1165 }
1166 subgraph_op_iter_after_code.push(write_iterator_after);
1167 }
1168 }
1169
1170 {
1171 let pull_ident = if 0 < pull_to_push_idx {
1173 self.node_as_ident(subgraph_nodes[pull_to_push_idx - 1], false)
1174 } else {
1175 recv_ports[0].clone()
1177 };
1178
1179 #[rustfmt::skip]
1180 let push_ident = if let Some(&node_id) =
1181 subgraph_nodes.get(pull_to_push_idx)
1182 {
1183 self.node_as_ident(node_id, false)
1184 } else if 1 == send_ports.len() {
1185 send_ports[0].clone()
1187 } else {
1188 diagnostics.push(Diagnostic::spanned(
1189 pull_ident.span(),
1190 Level::Error,
1191 "Degenerate subgraph detected, is there a disconnected `null()` or other degenerate pipeline somewhere?",
1192 ));
1193 continue;
1194 };
1195
1196 let pivot_span = pull_ident
1198 .span()
1199 .join(push_ident.span())
1200 .unwrap_or_else(|| push_ident.span());
1201 let pivot_fn_ident =
1202 Ident::new(&format!("pivot_run_sg_{:?}", subgraph_id.0), pivot_span);
1203 subgraph_op_iter_code.push(quote_spanned! {pivot_span=>
1204 #[inline(always)]
1205 fn #pivot_fn_ident<Pull: ::std::iter::Iterator<Item = Item>, Push: #root::pusherator::Pusherator<Item = Item>, Item>(pull: Pull, push: Push) {
1206 #root::pusherator::pivot::Pivot::new(pull, push).run();
1207 }
1208 #pivot_fn_ident(#pull_ident, #push_ident);
1209 });
1210 }
1211 };
1212
1213 let subgraph_name = Literal::string(&format!("Subgraph {:?}", subgraph_id));
1214 let stratum = Literal::usize_unsuffixed(
1215 self.subgraph_stratum.get(subgraph_id).cloned().unwrap_or(0),
1216 );
1217 let laziness = self.subgraph_laziness(subgraph_id);
1218
1219 let loop_id_opt = loop_id
1221 .map(Self::loop_as_ident)
1222 .map(|ident| quote! { Some(#ident) })
1223 .unwrap_or_else(|| quote! { None });
1224
1225 subgraphs.push(quote! {
1226 #df.add_subgraph_full(
1227 #subgraph_name,
1228 #stratum,
1229 var_expr!( #( #recv_ports ),* ),
1230 var_expr!( #( #send_ports ),* ),
1231 #laziness,
1232 #loop_id_opt,
1233 move |#context, var_args!( #( #recv_ports ),* ), var_args!( #( #send_ports ),* )| {
1234 #( #recv_port_code )*
1235 #( #send_port_code )*
1236 #( #subgraph_op_iter_code )*
1237 #( #subgraph_op_iter_after_code )*
1238 },
1239 );
1240 });
1241 }
1242 }
1243
1244 let loop_code = self.codegen_nested_loops(&df);
1245
1246 let code = quote! {
1251 #( #handoff_code )*
1252 #loop_code
1253 #( #op_prologue_code )*
1254 #( #subgraphs )*
1255 };
1256
1257 let meta_graph_json = serde_json::to_string(&self).unwrap();
1258 let meta_graph_json = Literal::string(&meta_graph_json);
1259
1260 let serde_diagnostics: Vec<_> = diagnostics.iter().map(Diagnostic::to_serde).collect();
1261 let diagnostics_json = serde_json::to_string(&*serde_diagnostics).unwrap();
1262 let diagnostics_json = Literal::string(&diagnostics_json);
1263
1264 quote! {
1265 {
1266 #[allow(unused_qualifications)]
1267 {
1268 #prefix
1269
1270 use #root::{var_expr, var_args};
1271
1272 let mut #df = #root::scheduled::graph::Dfir::new();
1273 #df.__assign_meta_graph(#meta_graph_json);
1274 #df.__assign_diagnostics(#diagnostics_json);
1275
1276 #code
1277
1278 #df
1279 }
1280 }
1281 }
1282 }
1283
1284 pub fn node_color_map(&self) -> SparseSecondaryMap<GraphNodeId, Color> {
1287 let mut node_color_map: SparseSecondaryMap<GraphNodeId, Color> = self
1288 .node_ids()
1289 .filter_map(|node_id| {
1290 let op_color = self.node_color(node_id)?;
1291 Some((node_id, op_color))
1292 })
1293 .collect();
1294
1295 for sg_nodes in self.subgraph_nodes.values() {
1297 let pull_to_push_idx = self.find_pull_to_push_idx(sg_nodes);
1298
1299 for (idx, node_id) in sg_nodes.iter().copied().enumerate() {
1300 let is_pull = idx < pull_to_push_idx;
1301 node_color_map.insert(node_id, if is_pull { Color::Pull } else { Color::Push });
1302 }
1303 }
1304
1305 node_color_map
1306 }
1307
1308 pub fn to_mermaid(&self, write_config: &WriteConfig) -> String {
1310 let mut output = String::new();
1311 self.write_mermaid(&mut output, write_config).unwrap();
1312 output
1313 }
1314
1315 pub fn write_mermaid(
1317 &self,
1318 output: impl std::fmt::Write,
1319 write_config: &WriteConfig,
1320 ) -> std::fmt::Result {
1321 let mut graph_write = Mermaid::new(output);
1322 self.write_graph(&mut graph_write, write_config)
1323 }
1324
1325 pub fn to_dot(&self, write_config: &WriteConfig) -> String {
1327 let mut output = String::new();
1328 let mut graph_write = Dot::new(&mut output);
1329 self.write_graph(&mut graph_write, write_config).unwrap();
1330 output
1331 }
1332
1333 pub fn write_dot(
1335 &self,
1336 output: impl std::fmt::Write,
1337 write_config: &WriteConfig,
1338 ) -> std::fmt::Result {
1339 let mut graph_write = Dot::new(output);
1340 self.write_graph(&mut graph_write, write_config)
1341 }
1342
1343 pub(crate) fn write_graph<W>(
1345 &self,
1346 mut graph_write: W,
1347 write_config: &WriteConfig,
1348 ) -> Result<(), W::Err>
1349 where
1350 W: GraphWrite,
1351 {
1352 fn helper_edge_label(
1353 src_port: &PortIndexValue,
1354 dst_port: &PortIndexValue,
1355 ) -> Option<String> {
1356 let src_label = match src_port {
1357 PortIndexValue::Path(path) => Some(path.to_token_stream().to_string()),
1358 PortIndexValue::Int(index) => Some(index.value.to_string()),
1359 _ => None,
1360 };
1361 let dst_label = match dst_port {
1362 PortIndexValue::Path(path) => Some(path.to_token_stream().to_string()),
1363 PortIndexValue::Int(index) => Some(index.value.to_string()),
1364 _ => None,
1365 };
1366 let label = match (src_label, dst_label) {
1367 (Some(l1), Some(l2)) => Some(format!("{}\n{}", l1, l2)),
1368 (Some(l1), None) => Some(l1),
1369 (None, Some(l2)) => Some(l2),
1370 (None, None) => None,
1371 };
1372 label
1373 }
1374
1375 let node_color_map = self.node_color_map();
1377
1378 let mut sg_varname_nodes =
1380 <SparseSecondaryMap<GraphSubgraphId, BTreeMap<Varname, BTreeSet<GraphNodeId>>>>::new();
1381 let mut varname_nodes = <BTreeMap<Varname, BTreeSet<GraphNodeId>>>::new();
1382 if !write_config.no_varnames {
1383 for (node_id, varname) in self.node_varnames.iter() {
1384 let varname_map = if !write_config.no_subgraphs {
1386 let Some(sg_id) = self.node_subgraph(node_id) else {
1387 continue;
1388 };
1389 sg_varname_nodes.entry(sg_id).unwrap().or_default()
1390 } else {
1391 &mut varname_nodes
1392 };
1393 varname_map
1394 .entry(varname.clone())
1395 .or_default()
1396 .insert(node_id);
1397 }
1398 }
1399
1400 graph_write.write_prologue()?;
1402
1403 let mut skipped_handoffs = BTreeSet::new();
1405 let mut subgraph_handoffs = <BTreeMap<GraphSubgraphId, Vec<GraphNodeId>>>::new();
1406 for (node_id, node) in self.nodes() {
1407 if matches!(node, GraphNode::Handoff { .. }) {
1408 if write_config.no_handoffs {
1409 skipped_handoffs.insert(node_id);
1410 continue;
1411 } else {
1412 let pred_node = self.node_predecessor_nodes(node_id).next().unwrap();
1413 let pred_sg = self.node_subgraph(pred_node);
1414 let succ_node = self.node_successor_nodes(node_id).next().unwrap();
1415 let succ_sg = self.node_subgraph(succ_node);
1416 if let Some((pred_sg, succ_sg)) = pred_sg.zip(succ_sg) {
1417 if pred_sg == succ_sg {
1418 subgraph_handoffs.entry(pred_sg).or_default().push(node_id);
1419 }
1420 }
1421 }
1422 }
1423 graph_write.write_node(
1424 node_id,
1425 &if write_config.op_short_text {
1426 node.to_name_string()
1427 } else if write_config.op_text_no_imports {
1428 let full_text = node.to_pretty_string();
1430 let mut output = String::new();
1431 for sentence in full_text.split('\n') {
1432 if sentence.trim().starts_with("use") {
1433 continue;
1434 }
1435 output.push('\n');
1436 output.push_str(sentence);
1437 }
1438 output.into()
1439 } else {
1440 node.to_pretty_string()
1441 },
1442 if write_config.no_pull_push {
1443 None
1444 } else {
1445 node_color_map.get(node_id).copied()
1446 },
1447 )?;
1448 }
1449
1450 for (edge_id, (src_id, mut dst_id)) in self.edges() {
1452 if skipped_handoffs.contains(&src_id) {
1454 continue;
1455 }
1456
1457 let (src_port, mut dst_port) = self.edge_ports(edge_id);
1458 if skipped_handoffs.contains(&dst_id) {
1459 let mut handoff_succs = self.node_successors(dst_id);
1460 assert_eq!(1, handoff_succs.len());
1461 let (succ_edge, succ_node) = handoff_succs.next().unwrap();
1462 dst_id = succ_node;
1463 dst_port = self.edge_ports(succ_edge).1;
1464 }
1465
1466 let label = helper_edge_label(src_port, dst_port);
1467 let delay_type = self
1468 .node_op_inst(dst_id)
1469 .and_then(|op_inst| (op_inst.op_constraints.input_delaytype_fn)(dst_port));
1470 graph_write.write_edge(src_id, dst_id, delay_type, label.as_deref(), false)?;
1471 }
1472
1473 if !write_config.no_references {
1475 for dst_id in self.node_ids() {
1476 for src_ref_id in self
1477 .node_singleton_references(dst_id)
1478 .iter()
1479 .copied()
1480 .flatten()
1481 {
1482 let delay_type = Some(DelayType::Stratum);
1483 let label = None;
1484 graph_write.write_edge(src_ref_id, dst_id, delay_type, label, true)?;
1485 }
1486 }
1487 }
1488
1489 if !write_config.no_subgraphs {
1491 for (subgraph_id, subgraph_node_ids) in self.subgraph_nodes.iter() {
1492 let handoff_node_ids = subgraph_handoffs.get(&subgraph_id).into_iter().flatten();
1493 let subgraph_node_ids = subgraph_node_ids.iter();
1494 let all_node_ids = handoff_node_ids.chain(subgraph_node_ids).copied();
1495
1496 let stratum = self.subgraph_stratum.get(subgraph_id);
1497 graph_write.write_subgraph_start(subgraph_id, *stratum.unwrap(), all_node_ids)?;
1498 if !write_config.no_varnames {
1500 for (varname, varname_node_ids) in
1501 sg_varname_nodes.remove(subgraph_id).into_iter().flatten()
1502 {
1503 assert!(!varname_node_ids.is_empty());
1504 graph_write.write_varname(
1505 &varname.0.to_string(),
1506 varname_node_ids.into_iter(),
1507 Some(subgraph_id),
1508 )?;
1509 }
1510 }
1511 graph_write.write_subgraph_end()?;
1512 }
1513 } else if !write_config.no_varnames {
1514 for (varname, varname_node_ids) in varname_nodes {
1515 graph_write.write_varname(
1516 &varname.0.to_string(),
1517 varname_node_ids.into_iter(),
1518 None,
1519 )?;
1520 }
1521 }
1522
1523 graph_write.write_epilogue()?;
1525
1526 Ok(())
1527 }
1528
1529 pub fn surface_syntax_string(&self) -> String {
1531 let mut string = String::new();
1532 self.write_surface_syntax(&mut string).unwrap();
1533 string
1534 }
1535
1536 pub fn write_surface_syntax(&self, write: &mut impl std::fmt::Write) -> std::fmt::Result {
1538 for (key, node) in self.nodes.iter() {
1539 match node {
1540 GraphNode::Operator(op) => {
1541 writeln!(write, "{:?} = {};", key.data(), op.to_token_stream())?;
1542 }
1543 GraphNode::Handoff { .. } => {
1544 writeln!(write, "// {:?} = <handoff>;", key.data())?;
1545 }
1546 GraphNode::ModuleBoundary { .. } => panic!(),
1547 }
1548 }
1549 writeln!(write)?;
1550 for (_e, (src_key, dst_key)) in self.graph.edges() {
1551 writeln!(write, "{:?} -> {:?};", src_key.data(), dst_key.data())?;
1552 }
1553 Ok(())
1554 }
1555
1556 pub fn mermaid_string_flat(&self) -> String {
1558 let mut string = String::new();
1559 self.write_mermaid_flat(&mut string).unwrap();
1560 string
1561 }
1562
1563 pub fn write_mermaid_flat(&self, write: &mut impl std::fmt::Write) -> std::fmt::Result {
1565 writeln!(write, "flowchart TB")?;
1566 for (key, node) in self.nodes.iter() {
1567 match node {
1568 GraphNode::Operator(operator) => writeln!(
1569 write,
1570 " %% {span}\n {id:?}[\"{row_col} <tt>{code}</tt>\"]",
1571 span = PrettySpan(node.span()),
1572 id = key.data(),
1573 row_col = PrettyRowCol(node.span()),
1574 code = operator
1575 .to_token_stream()
1576 .to_string()
1577 .replace('&', "&")
1578 .replace('<', "<")
1579 .replace('>', ">")
1580 .replace('"', """)
1581 .replace('\n', "<br>"),
1582 ),
1583 GraphNode::Handoff { .. } => {
1584 writeln!(write, r#" {:?}{{"{}"}}"#, key.data(), HANDOFF_NODE_STR)
1585 }
1586 GraphNode::ModuleBoundary { .. } => {
1587 writeln!(
1588 write,
1589 r#" {:?}{{"{}"}}"#,
1590 key.data(),
1591 MODULE_BOUNDARY_NODE_STR
1592 )
1593 }
1594 }?;
1595 }
1596 writeln!(write)?;
1597 for (_e, (src_key, dst_key)) in self.graph.edges() {
1598 writeln!(write, " {:?}-->{:?}", src_key.data(), dst_key.data())?;
1599 }
1600 Ok(())
1601 }
1602}
1603
1604impl DfirGraph {
1606 pub fn loop_ids(&self) -> slotmap::basic::Keys<'_, GraphLoopId, Vec<GraphNodeId>> {
1608 self.loop_nodes.keys()
1609 }
1610
1611 pub fn loops(&self) -> slotmap::basic::Iter<'_, GraphLoopId, Vec<GraphNodeId>> {
1613 self.loop_nodes.iter()
1614 }
1615
1616 pub fn insert_loop(&mut self, parent_loop: Option<GraphLoopId>) -> GraphLoopId {
1618 let loop_id = self.loop_nodes.insert(Vec::new());
1619 self.loop_children.insert(loop_id, Vec::new());
1620 if let Some(parent_loop) = parent_loop {
1621 self.loop_parent.insert(loop_id, parent_loop);
1622 self.loop_children
1623 .get_mut(parent_loop)
1624 .unwrap()
1625 .push(loop_id);
1626 } else {
1627 self.root_loops.push(loop_id);
1628 }
1629 loop_id
1630 }
1631
1632 pub fn node_loop(&self, node_id: GraphNodeId) -> Option<GraphLoopId> {
1634 self.node_loops.get(node_id).copied()
1635 }
1636
1637 pub fn subgraph_loop(&self, subgraph_id: GraphSubgraphId) -> Option<GraphLoopId> {
1639 let &node_id = self.subgraph(subgraph_id).first().unwrap();
1640 let out = self.node_loop(node_id);
1641 debug_assert!(
1642 self.subgraph(subgraph_id)
1643 .iter()
1644 .all(|&node_id| self.node_loop(node_id) == out),
1645 "Subgraph nodes should all have the same loop context."
1646 );
1647 out
1648 }
1649
1650 pub fn loop_parent(&self, loop_id: GraphLoopId) -> Option<GraphLoopId> {
1652 self.loop_parent.get(loop_id).copied()
1653 }
1654
1655 pub fn loop_children(&self, loop_id: GraphLoopId) -> &Vec<GraphLoopId> {
1657 self.loop_children.get(loop_id).unwrap()
1658 }
1659}
1660
/// Display options for writing a `DfirGraph` as a Mermaid or DOT graph.
///
/// Every option defaults to `false`, i.e. the most detailed output.
#[derive(Clone, Debug, Default)]
#[cfg_attr(feature = "clap-derive", derive(clap::Args))]
pub struct WriteConfig {
    /// Do not draw subgraphs; variable names (if shown) are written without subgraph grouping.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub no_subgraphs: bool,
    /// Do not draw variable-name groupings.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub no_varnames: bool,
    /// Do not color nodes by pull/push.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub no_pull_push: bool,
    /// Hide handoff nodes, drawing edges straight through to the handoff's successor.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub no_handoffs: bool,
    /// Do not draw singleton reference edges.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub no_references: bool,

    /// Label operators with their short name instead of their full text.
    /// Takes precedence over `op_text_no_imports`.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub op_short_text: bool,
    /// Omit `use` import lines from the full operator text.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub op_text_no_imports: bool,
}
1688
/// Output format to use when writing a graph.
#[derive(Copy, Clone, Debug)]
#[cfg_attr(feature = "clap-derive", derive(clap::Parser, clap::ValueEnum))]
pub enum WriteGraphType {
    /// Mermaid diagram format.
    Mermaid,
    /// DOT (Graphviz) format.
    Dot,
}