dfir_lang/graph/
meta_graph.rs

1#![warn(missing_docs)]
2
3extern crate proc_macro;
4
5use std::collections::{BTreeMap, BTreeSet, VecDeque};
6use std::fmt::Debug;
7use std::iter::FusedIterator;
8
9use itertools::Itertools;
10use proc_macro2::{Ident, Literal, Span, TokenStream};
11use quote::{ToTokens, TokenStreamExt, format_ident, quote, quote_spanned};
12use serde::{Deserialize, Serialize};
13use slotmap::{Key, SecondaryMap, SlotMap, SparseSecondaryMap};
14use syn::spanned::Spanned;
15
16use super::graph_write::{Dot, GraphWrite, Mermaid};
17use super::ops::{
18    DelayType, OPERATORS, OperatorWriteOutput, WriteContextArgs, find_op_op_constraints,
19    null_write_iterator_fn,
20};
21use super::{
22    CONTEXT, Color, DiMulGraph, GRAPH, GraphEdgeId, GraphLoopId, GraphNode, GraphNodeId,
23    GraphSubgraphId, HANDOFF_NODE_STR, MODULE_BOUNDARY_NODE_STR, OperatorInstance, PortIndexValue,
24    Varname, change_spans, get_operator_generics,
25};
26use crate::diagnostic::{Diagnostic, Level};
27use crate::pretty_span::{PrettyRowCol, PrettySpan};
28use crate::process_singletons;
29
30/// An abstract "meta graph" representation of a DFIR graph.
31///
32/// Can be with or without subgraph partitioning, stratification, and handoff insertion. This is
33/// the meta graph used for generating Rust source code in macros from DFIR sytnax.
34///
35/// This struct has a lot of methods for manipulating the graph, vaguely grouped together in
36/// separate `impl` blocks. You might notice a few particularly specific arbitray-seeming methods
37/// in here--those are just what was needed for the compilation algorithms. If you need another
38/// method then add it.
39#[derive(Default, Debug, Serialize, Deserialize)]
40pub struct DfirGraph {
41    /// Each node type (operator or handoff).
42    nodes: SlotMap<GraphNodeId, GraphNode>,
43
44    /// Instance data corresponding to each operator node.
45    /// This field will be empty after deserialization.
46    #[serde(skip)]
47    operator_instances: SecondaryMap<GraphNodeId, OperatorInstance>,
48    /// Debugging/tracing tag for each operator node.
49    operator_tag: SecondaryMap<GraphNodeId, String>,
50    /// Graph data structure (two-way adjacency list).
51    graph: DiMulGraph<GraphNodeId, GraphEdgeId>,
52    /// Input and output port for each edge.
53    ports: SecondaryMap<GraphEdgeId, (PortIndexValue, PortIndexValue)>,
54
55    /// Which loop a node belongs to (or none for top-level).
56    node_loops: SecondaryMap<GraphNodeId, GraphLoopId>,
57    /// Which nodes belong to each loop.
58    loop_nodes: SlotMap<GraphLoopId, Vec<GraphNodeId>>,
59    /// For the loop, what is its parent (`None` for top-level).
60    loop_parent: SparseSecondaryMap<GraphLoopId, GraphLoopId>,
61    /// What loops are at the root.
62    root_loops: Vec<GraphLoopId>,
63    /// For the loop, what are its child loops.
64    loop_children: SecondaryMap<GraphLoopId, Vec<GraphLoopId>>,
65
66    /// Which subgraph each node belongs to.
67    node_subgraph: SecondaryMap<GraphNodeId, GraphSubgraphId>,
68
69    /// Which nodes belong to each subgraph.
70    subgraph_nodes: SlotMap<GraphSubgraphId, Vec<GraphNodeId>>,
71    /// Which stratum each subgraph belongs to.
72    subgraph_stratum: SecondaryMap<GraphSubgraphId, usize>,
73
74    /// Resolved singletons varnames references, per node.
75    node_singleton_references: SparseSecondaryMap<GraphNodeId, Vec<Option<GraphNodeId>>>,
76    /// What variable name each graph node belongs to (if any). For debugging (graph writing) purposes only.
77    node_varnames: SparseSecondaryMap<GraphNodeId, Varname>,
78
79    /// If this subgraph is 'lazy' then when it sends data to a lower stratum it does not cause a new tick to start
80    /// This is to support lazy defers
81    /// If the value does not exist for a given subgraph id then the subgraph is not lazy.
82    subgraph_laziness: SecondaryMap<GraphSubgraphId, bool>,
83}
84
85/// Basic methods.
86impl DfirGraph {
87    /// Create a new empty graph.
88    pub fn new() -> Self {
89        Default::default()
90    }
91}
92
93/// Node methods.
94impl DfirGraph {
95    /// Get a node with its operator instance (if applicable).
96    pub fn node(&self, node_id: GraphNodeId) -> &GraphNode {
97        self.nodes.get(node_id).expect("Node not found.")
98    }
99
100    /// Get the `OperatorInstance` for a given node. Node must be an operator and have an
101    /// `OperatorInstance` present, otherwise will return `None`.
102    ///
103    /// Note that no operator instances will be persent after deserialization.
104    pub fn node_op_inst(&self, node_id: GraphNodeId) -> Option<&OperatorInstance> {
105        self.operator_instances.get(node_id)
106    }
107
108    /// Get the debug variable name attached to a graph node.
109    pub fn node_varname(&self, node_id: GraphNodeId) -> Option<Ident> {
110        self.node_varnames.get(node_id).map(|x| x.0.clone())
111    }
112
113    /// Get subgraph for node.
114    pub fn node_subgraph(&self, node_id: GraphNodeId) -> Option<GraphSubgraphId> {
115        self.node_subgraph.get(node_id).copied()
116    }
117
118    /// Degree into a node, i.e. the number of predecessors.
119    pub fn node_degree_in(&self, node_id: GraphNodeId) -> usize {
120        self.graph.degree_in(node_id)
121    }
122
123    /// Degree out of a node, i.e. the number of successors.
124    pub fn node_degree_out(&self, node_id: GraphNodeId) -> usize {
125        self.graph.degree_out(node_id)
126    }
127
128    /// Successors, iterator of `(GraphEdgeId, GraphNodeId)` of outgoing edges.
129    pub fn node_successors(
130        &self,
131        src: GraphNodeId,
132    ) -> impl '_
133    + DoubleEndedIterator<Item = (GraphEdgeId, GraphNodeId)>
134    + ExactSizeIterator
135    + FusedIterator
136    + Clone
137    + Debug {
138        self.graph.successors(src)
139    }
140
141    /// Predecessors, iterator of `(GraphEdgeId, GraphNodeId)` of incoming edges.
142    pub fn node_predecessors(
143        &self,
144        dst: GraphNodeId,
145    ) -> impl '_
146    + DoubleEndedIterator<Item = (GraphEdgeId, GraphNodeId)>
147    + ExactSizeIterator
148    + FusedIterator
149    + Clone
150    + Debug {
151        self.graph.predecessors(dst)
152    }
153
154    /// Successor edges, iterator of `GraphEdgeId` of outgoing edges.
155    pub fn node_successor_edges(
156        &self,
157        src: GraphNodeId,
158    ) -> impl '_
159    + DoubleEndedIterator<Item = GraphEdgeId>
160    + ExactSizeIterator
161    + FusedIterator
162    + Clone
163    + Debug {
164        self.graph.successor_edges(src)
165    }
166
167    /// Predecessor edges, iterator of `GraphEdgeId` of incoming edges.
168    pub fn node_predecessor_edges(
169        &self,
170        dst: GraphNodeId,
171    ) -> impl '_
172    + DoubleEndedIterator<Item = GraphEdgeId>
173    + ExactSizeIterator
174    + FusedIterator
175    + Clone
176    + Debug {
177        self.graph.predecessor_edges(dst)
178    }
179
180    /// Successor nodes, iterator of `GraphNodeId`.
181    pub fn node_successor_nodes(
182        &self,
183        src: GraphNodeId,
184    ) -> impl '_
185    + DoubleEndedIterator<Item = GraphNodeId>
186    + ExactSizeIterator
187    + FusedIterator
188    + Clone
189    + Debug {
190        self.graph.successor_vertices(src)
191    }
192
193    /// Predecessor nodes, iterator of `GraphNodeId`.
194    pub fn node_predecessor_nodes(
195        &self,
196        dst: GraphNodeId,
197    ) -> impl '_
198    + DoubleEndedIterator<Item = GraphNodeId>
199    + ExactSizeIterator
200    + FusedIterator
201    + Clone
202    + Debug {
203        self.graph.predecessor_vertices(dst)
204    }
205
206    /// Iterator of node IDs `GraphNodeId`.
207    pub fn node_ids(&self) -> slotmap::basic::Keys<'_, GraphNodeId, GraphNode> {
208        self.nodes.keys()
209    }
210
211    /// Iterator over `(GraphNodeId, &Node)` pairs.
212    pub fn nodes(&self) -> slotmap::basic::Iter<'_, GraphNodeId, GraphNode> {
213        self.nodes.iter()
214    }
215
216    /// Insert a node, assigning the given varname.
217    pub fn insert_node(
218        &mut self,
219        node: GraphNode,
220        varname_opt: Option<Ident>,
221        loop_opt: Option<GraphLoopId>,
222    ) -> GraphNodeId {
223        let node_id = self.nodes.insert(node);
224        if let Some(varname) = varname_opt {
225            self.node_varnames.insert(node_id, Varname(varname));
226        }
227        if let Some(loop_id) = loop_opt {
228            self.node_loops.insert(node_id, loop_id);
229            self.loop_nodes[loop_id].push(node_id);
230        }
231        node_id
232    }
233
234    /// Insert an operator instance for the given node. Panics if already set.
235    pub fn insert_node_op_inst(&mut self, node_id: GraphNodeId, op_inst: OperatorInstance) {
236        assert!(matches!(
237            self.nodes.get(node_id),
238            Some(GraphNode::Operator(_))
239        ));
240        let old_inst = self.operator_instances.insert(node_id, op_inst);
241        assert!(old_inst.is_none());
242    }
243
244    /// Assign all operator instances if not set. Write diagnostic messages/errors into `diagnostics`.
245    pub fn insert_node_op_insts_all(&mut self, diagnostics: &mut Vec<Diagnostic>) {
246        let mut op_insts = Vec::new();
247        for (node_id, node) in self.nodes() {
248            let GraphNode::Operator(operator) = node else {
249                continue;
250            };
251            if self.node_op_inst(node_id).is_some() {
252                continue;
253            };
254
255            // Op constraints.
256            let Some(op_constraints) = find_op_op_constraints(operator) else {
257                diagnostics.push(Diagnostic::spanned(
258                    operator.path.span(),
259                    Level::Error,
260                    format!("Unknown operator `{}`", operator.name_string()),
261                ));
262                continue;
263            };
264
265            // Input and output ports.
266            let (input_ports, output_ports) = {
267                let mut input_edges: Vec<(&PortIndexValue, GraphNodeId)> = self
268                    .node_predecessors(node_id)
269                    .map(|(edge_id, pred_id)| (self.edge_ports(edge_id).1, pred_id))
270                    .collect();
271                // Ensure sorted by port index.
272                input_edges.sort();
273                let input_ports: Vec<PortIndexValue> = input_edges
274                    .into_iter()
275                    .map(|(port, _pred)| port)
276                    .cloned()
277                    .collect();
278
279                // Collect output arguments (successors).
280                let mut output_edges: Vec<(&PortIndexValue, GraphNodeId)> = self
281                    .node_successors(node_id)
282                    .map(|(edge_id, succ)| (self.edge_ports(edge_id).0, succ))
283                    .collect();
284                // Ensure sorted by port index.
285                output_edges.sort();
286                let output_ports: Vec<PortIndexValue> = output_edges
287                    .into_iter()
288                    .map(|(port, _succ)| port)
289                    .cloned()
290                    .collect();
291
292                (input_ports, output_ports)
293            };
294
295            // Generic arguments.
296            let generics = get_operator_generics(diagnostics, operator);
297            // Generic argument errors.
298            {
299                // Span of `generic_args` (if it exists), otherwise span of the operator name.
300                let generics_span = generics
301                    .generic_args
302                    .as_ref()
303                    .map(Spanned::span)
304                    .unwrap_or_else(|| operator.path.span());
305
306                if !op_constraints
307                    .persistence_args
308                    .contains(&generics.persistence_args.len())
309                {
310                    diagnostics.push(Diagnostic::spanned(
311                        generics_span,
312                        Level::Error,
313                        format!(
314                            "`{}` should have {} persistence lifetime arguments, actually has {}.",
315                            op_constraints.name,
316                            op_constraints.persistence_args.human_string(),
317                            generics.persistence_args.len()
318                        ),
319                    ));
320                }
321                if !op_constraints.type_args.contains(&generics.type_args.len()) {
322                    diagnostics.push(Diagnostic::spanned(
323                        generics_span,
324                        Level::Error,
325                        format!(
326                            "`{}` should have {} generic type arguments, actually has {}.",
327                            op_constraints.name,
328                            op_constraints.type_args.human_string(),
329                            generics.type_args.len()
330                        ),
331                    ));
332                }
333            }
334
335            op_insts.push((
336                node_id,
337                OperatorInstance {
338                    op_constraints,
339                    input_ports,
340                    output_ports,
341                    singletons_referenced: operator.singletons_referenced.clone(),
342                    generics,
343                    arguments_pre: operator.args.clone(),
344                    arguments_raw: operator.args_raw.clone(),
345                },
346            ));
347        }
348
349        for (node_id, op_inst) in op_insts {
350            self.insert_node_op_inst(node_id, op_inst);
351        }
352    }
353
354    /// Inserts a node between two existing nodes connected by the given `edge_id`.
355    ///
356    /// `edge`: (src, dst, dst_idx)
357    ///
358    /// Before: A (src) ------------> B (dst)
359    /// After:  A (src) -> X (new) -> B (dst)
360    ///
361    /// Returns the ID of X & ID of edge OUT of X.
362    ///
363    /// Note that both the edges will be new and `edge_id` will be removed. Both new edges will
364    /// get the edge type of the original edge.
365    pub fn insert_intermediate_node(
366        &mut self,
367        edge_id: GraphEdgeId,
368        new_node: GraphNode,
369    ) -> (GraphNodeId, GraphEdgeId) {
370        let span = Some(new_node.span());
371
372        // Make corresponding operator instance (if `node` is an operator).
373        let op_inst_opt = 'oc: {
374            let GraphNode::Operator(operator) = &new_node else {
375                break 'oc None;
376            };
377            let Some(op_constraints) = find_op_op_constraints(operator) else {
378                break 'oc None;
379            };
380            let (input_port, output_port) = self.ports.get(edge_id).cloned().unwrap();
381            let generics = get_operator_generics(
382                &mut Vec::new(), // TODO(mingwei) diagnostics
383                operator,
384            );
385            Some(OperatorInstance {
386                op_constraints,
387                input_ports: vec![input_port],
388                output_ports: vec![output_port],
389                singletons_referenced: operator.singletons_referenced.clone(),
390                generics,
391                arguments_pre: operator.args.clone(),
392                arguments_raw: operator.args_raw.clone(),
393            })
394        };
395
396        // Insert new `node`.
397        let node_id = self.nodes.insert(new_node);
398        // Insert corresponding `OperatorInstance` if applicable.
399        if let Some(op_inst) = op_inst_opt {
400            self.operator_instances.insert(node_id, op_inst);
401        }
402        // Update edges to insert node within `edge_id`.
403        let (e0, e1) = self
404            .graph
405            .insert_intermediate_vertex(node_id, edge_id)
406            .unwrap();
407
408        // Update corresponding ports.
409        let (src_idx, dst_idx) = self.ports.remove(edge_id).unwrap();
410        self.ports
411            .insert(e0, (src_idx, PortIndexValue::Elided(span)));
412        self.ports
413            .insert(e1, (PortIndexValue::Elided(span), dst_idx));
414
415        (node_id, e1)
416    }
417
418    /// Remove the node `node_id` but preserves and connects the single predecessor and single successor.
419    /// Panics if the node does not have exactly one predecessor and one successor, or is not in the graph.
420    pub fn remove_intermediate_node(&mut self, node_id: GraphNodeId) {
421        assert_eq!(
422            1,
423            self.node_degree_in(node_id),
424            "Removed intermediate node must have one predecessor"
425        );
426        assert_eq!(
427            1,
428            self.node_degree_out(node_id),
429            "Removed intermediate node must have one successor"
430        );
431        assert!(
432            self.node_subgraph.is_empty() && self.subgraph_nodes.is_empty(),
433            "Should not remove intermediate node after subgraph partitioning"
434        );
435
436        assert!(self.nodes.remove(node_id).is_some());
437        let (new_edge_id, (pred_edge_id, succ_edge_id)) =
438            self.graph.remove_intermediate_vertex(node_id).unwrap();
439        self.operator_instances.remove(node_id);
440        self.node_varnames.remove(node_id);
441
442        let (src_port, _) = self.ports.remove(pred_edge_id).unwrap();
443        let (_, dst_port) = self.ports.remove(succ_edge_id).unwrap();
444        self.ports.insert(new_edge_id, (src_port, dst_port));
445    }
446
447    /// Helper method: determine the "color" (pull vs push) of a node based on its in and out degree,
448    /// excluding reference edges. If linear (1 in, 1 out), color is `None`, indicating it can be
449    /// either push or pull.
450    ///
451    /// Note that this does NOT consider `DelayType` barriers (which generally implies `Pull`).
452    pub(crate) fn node_color(&self, node_id: GraphNodeId) -> Option<Color> {
453        if matches!(self.node(node_id), GraphNode::Handoff { .. }) {
454            return Some(Color::Hoff);
455        }
456        // In-degree, excluding ref-edges.
457        let inn_degree = self.node_predecessor_nodes(node_id).count();
458        // Out-degree excluding ref-edges.
459        let out_degree = self.node_successor_nodes(node_id).count();
460
461        match (inn_degree, out_degree) {
462            (0, 0) => None, // Generally should not happen, "Degenerate subgraph detected".
463            (0, 1) => Some(Color::Pull),
464            (1, 0) => Some(Color::Push),
465            (1, 1) => None, // Linear, can be either push or pull.
466            (_many, 0 | 1) => Some(Color::Pull),
467            (0 | 1, _many) => Some(Color::Push),
468            (_many, _to_many) => Some(Color::Comp),
469        }
470    }
471
472    /// Set the operator tag (for debugging/tracing).
473    pub fn set_operator_tag(&mut self, node_id: GraphNodeId, tag: String) {
474        self.operator_tag.insert(node_id, tag.to_owned());
475    }
476}
477
478/// Singleton references.
479impl DfirGraph {
480    /// Set the singletons referenced for the `node_id` operator. Each reference corresponds to the
481    /// same index in the [`crate::parse::Operator::singletons_referenced`] vec.
482    pub fn set_node_singleton_references(
483        &mut self,
484        node_id: GraphNodeId,
485        singletons_referenced: Vec<Option<GraphNodeId>>,
486    ) -> Option<Vec<Option<GraphNodeId>>> {
487        self.node_singleton_references
488            .insert(node_id, singletons_referenced)
489    }
490
491    /// Gets the singletons referenced by a node. Returns an empty iterator for non-operators and
492    /// operators that do not reference singletons.
493    pub fn node_singleton_references(&self, node_id: GraphNodeId) -> &[Option<GraphNodeId>] {
494        self.node_singleton_references
495            .get(node_id)
496            .map(std::ops::Deref::deref)
497            .unwrap_or_default()
498    }
499}
500
501/// Module methods.
502impl DfirGraph {
503    /// When modules are imported into a flat graph, they come with an input and output ModuleBoundary node.
504    /// The partitioner doesn't understand these nodes and will panic if it encounters them.
505    /// merge_modules removes them from the graph, stitching the input and ouput sides of the ModuleBondaries based on their ports
506    /// For example:
507    ///     source_iter([]) -> \[myport\]ModuleBoundary(input)\[my_port\] -> map(|x| x) -> ModuleBoundary(output) -> null();
508    /// in the above eaxmple, the \[myport\] port will be used to connect the source_iter with the map that is inside of the module.
509    /// The output module boundary has elided ports, this is also used to match up the input/output across the module boundary.
510    pub fn merge_modules(&mut self) -> Result<(), Diagnostic> {
511        let mod_bound_nodes = self
512            .nodes()
513            .filter(|(_nid, node)| matches!(node, GraphNode::ModuleBoundary { .. }))
514            .map(|(nid, _node)| nid)
515            .collect::<Vec<_>>();
516
517        for mod_bound_node in mod_bound_nodes {
518            self.remove_module_boundary(mod_bound_node)?;
519        }
520
521        Ok(())
522    }
523
524    /// see `merge_modules`
525    /// This function removes a singular module boundary from the graph and performs the necessary stitching to fix the graph afterward.
526    /// `merge_modules` calls this function for each module boundary in the graph.
527    fn remove_module_boundary(&mut self, mod_bound_node: GraphNodeId) -> Result<(), Diagnostic> {
528        assert!(
529            self.node_subgraph.is_empty() && self.subgraph_nodes.is_empty(),
530            "Should not remove intermediate node after subgraph partitioning"
531        );
532
533        let mut mod_pred_ports = BTreeMap::new();
534        let mut mod_succ_ports = BTreeMap::new();
535
536        for mod_out_edge in self.node_predecessor_edges(mod_bound_node) {
537            let (pred_port, succ_port) = self.edge_ports(mod_out_edge);
538            mod_pred_ports.insert(succ_port.clone(), (mod_out_edge, pred_port.clone()));
539        }
540
541        for mod_inn_edge in self.node_successor_edges(mod_bound_node) {
542            let (pred_port, succ_port) = self.edge_ports(mod_inn_edge);
543            mod_succ_ports.insert(pred_port.clone(), (mod_inn_edge, succ_port.clone()));
544        }
545
546        if mod_pred_ports.keys().collect::<BTreeSet<_>>()
547            != mod_succ_ports.keys().collect::<BTreeSet<_>>()
548        {
549            // get module boundary node
550            let GraphNode::ModuleBoundary { input, import_expr } = self.node(mod_bound_node) else {
551                panic!();
552            };
553
554            if *input {
555                return Err(Diagnostic {
556                    span: *import_expr,
557                    level: Level::Error,
558                    message: format!(
559                        "The ports into the module did not match. input: {:?}, expected: {:?}",
560                        mod_pred_ports.keys().map(|x| x.to_string()).join(", "),
561                        mod_succ_ports.keys().map(|x| x.to_string()).join(", ")
562                    ),
563                });
564            } else {
565                return Err(Diagnostic {
566                    span: *import_expr,
567                    level: Level::Error,
568                    message: format!(
569                        "The ports out of the module did not match. output: {:?}, expected: {:?}",
570                        mod_succ_ports.keys().map(|x| x.to_string()).join(", "),
571                        mod_pred_ports.keys().map(|x| x.to_string()).join(", "),
572                    ),
573                });
574            }
575        }
576
577        for (port, (pred_edge, pred_port)) in mod_pred_ports {
578            let (succ_edge, succ_port) = mod_succ_ports.remove(&port).unwrap();
579
580            let (src, _) = self.edge(pred_edge);
581            let (_, dst) = self.edge(succ_edge);
582            self.remove_edge(pred_edge);
583            self.remove_edge(succ_edge);
584
585            let new_edge_id = self.graph.insert_edge(src, dst);
586            self.ports.insert(new_edge_id, (pred_port, succ_port));
587        }
588
589        self.graph.remove_vertex(mod_bound_node);
590        self.nodes.remove(mod_bound_node);
591
592        Ok(())
593    }
594}
595
596/// Edge methods.
597impl DfirGraph {
598    /// Get the `src` and `dst` for an edge: `(src GraphNodeId, dst GraphNodeId)`.
599    pub fn edge(&self, edge_id: GraphEdgeId) -> (GraphNodeId, GraphNodeId) {
600        let (src, dst) = self.graph.edge(edge_id).expect("Edge not found.");
601        (src, dst)
602    }
603
604    /// Get the source and destination ports for an edge: `(src &PortIndexValue, dst &PortIndexValue)`.
605    pub fn edge_ports(&self, edge_id: GraphEdgeId) -> (&PortIndexValue, &PortIndexValue) {
606        let (src_port, dst_port) = self.ports.get(edge_id).expect("Edge not found.");
607        (src_port, dst_port)
608    }
609
610    /// Iterator of all edge IDs `GraphEdgeId`.
611    pub fn edge_ids(&self) -> slotmap::basic::Keys<GraphEdgeId, (GraphNodeId, GraphNodeId)> {
612        self.graph.edge_ids()
613    }
614
615    /// Iterator over all edges: `(GraphEdgeId, (src GraphNodeId, dst GraphNodeId))`.
616    pub fn edges(
617        &self,
618    ) -> impl '_
619    + ExactSizeIterator<Item = (GraphEdgeId, (GraphNodeId, GraphNodeId))>
620    + FusedIterator
621    + Clone
622    + Debug {
623        self.graph.edges()
624    }
625
626    /// Insert an edge between nodes thru the given ports.
627    pub fn insert_edge(
628        &mut self,
629        src: GraphNodeId,
630        src_port: PortIndexValue,
631        dst: GraphNodeId,
632        dst_port: PortIndexValue,
633    ) -> GraphEdgeId {
634        let edge_id = self.graph.insert_edge(src, dst);
635        self.ports.insert(edge_id, (src_port, dst_port));
636        edge_id
637    }
638
639    /// Removes an edge and its corresponding ports and edge type info.
640    pub fn remove_edge(&mut self, edge: GraphEdgeId) {
641        let (_src, _dst) = self.graph.remove_edge(edge).unwrap();
642        let (_src_port, _dst_port) = self.ports.remove(edge).unwrap();
643    }
644}
645
646/// Subgraph methods.
647impl DfirGraph {
648    /// Nodes belonging to the given subgraph.
649    pub fn subgraph(&self, subgraph_id: GraphSubgraphId) -> &Vec<GraphNodeId> {
650        self.subgraph_nodes
651            .get(subgraph_id)
652            .expect("Subgraph not found.")
653    }
654
655    /// Iterator over all subgraph IDs.
656    pub fn subgraph_ids(&self) -> slotmap::basic::Keys<'_, GraphSubgraphId, Vec<GraphNodeId>> {
657        self.subgraph_nodes.keys()
658    }
659
660    /// Iterator over all subgraphs, ID and members: `(GraphSubgraphId, Vec<GraphNodeId>)`.
661    pub fn subgraphs(&self) -> slotmap::basic::Iter<'_, GraphSubgraphId, Vec<GraphNodeId>> {
662        self.subgraph_nodes.iter()
663    }
664
665    /// Create a subgraph consisting of `node_ids`. Returns an error if any of the nodes are already in a subgraph.
666    pub fn insert_subgraph(
667        &mut self,
668        node_ids: Vec<GraphNodeId>,
669    ) -> Result<GraphSubgraphId, (GraphNodeId, GraphSubgraphId)> {
670        // Check none are already in subgraphs
671        for &node_id in node_ids.iter() {
672            if let Some(&old_sg_id) = self.node_subgraph.get(node_id) {
673                return Err((node_id, old_sg_id));
674            }
675        }
676        let subgraph_id = self.subgraph_nodes.insert_with_key(|sg_id| {
677            for &node_id in node_ids.iter() {
678                self.node_subgraph.insert(node_id, sg_id);
679            }
680            node_ids
681        });
682
683        Ok(subgraph_id)
684    }
685
686    /// Removes a node from its subgraph. Returns true if the node was in a subgraph.
687    pub fn remove_from_subgraph(&mut self, node_id: GraphNodeId) -> bool {
688        if let Some(old_sg_id) = self.node_subgraph.remove(node_id) {
689            self.subgraph_nodes[old_sg_id].retain(|&other_node_id| other_node_id != node_id);
690            true
691        } else {
692            false
693        }
694    }
695
696    /// Gets the stratum number of the subgraph.
697    pub fn subgraph_stratum(&self, sg_id: GraphSubgraphId) -> Option<usize> {
698        self.subgraph_stratum.get(sg_id).copied()
699    }
700
701    /// Set subgraph's stratum number, returning the old value if exists.
702    pub fn set_subgraph_stratum(
703        &mut self,
704        sg_id: GraphSubgraphId,
705        stratum: usize,
706    ) -> Option<usize> {
707        self.subgraph_stratum.insert(sg_id, stratum)
708    }
709
710    /// Gets whether the subgraph is lazy or not
711    fn subgraph_laziness(&self, sg_id: GraphSubgraphId) -> bool {
712        self.subgraph_laziness.get(sg_id).copied().unwrap_or(false)
713    }
714
715    /// Set subgraph's laziness, returning the old value.
716    pub fn set_subgraph_laziness(&mut self, sg_id: GraphSubgraphId, lazy: bool) -> bool {
717        self.subgraph_laziness.insert(sg_id, lazy).unwrap_or(false)
718    }
719
720    /// Returns the the stratum number of the largest (latest) stratum (inclusive).
721    pub fn max_stratum(&self) -> Option<usize> {
722        self.subgraph_stratum.values().copied().max()
723    }
724
725    /// Helper: finds the first index in `subgraph_nodes` where it transitions from pull to push.
726    fn find_pull_to_push_idx(&self, subgraph_nodes: &[GraphNodeId]) -> usize {
727        subgraph_nodes
728            .iter()
729            .position(|&node_id| {
730                self.node_color(node_id)
731                    .is_some_and(|color| Color::Pull != color)
732            })
733            .unwrap_or(subgraph_nodes.len())
734    }
735}
736
737/// Display/output methods.
738impl DfirGraph {
739    /// Helper to generate a deterministic `Ident` for the given node.
740    fn node_as_ident(&self, node_id: GraphNodeId, is_pred: bool) -> Ident {
741        let name = match &self.nodes[node_id] {
742            GraphNode::Operator(_) => format!("op_{:?}", node_id.data()),
743            GraphNode::Handoff { .. } => format!(
744                "hoff_{:?}_{}",
745                node_id.data(),
746                if is_pred { "recv" } else { "send" }
747            ),
748            GraphNode::ModuleBoundary { .. } => panic!(),
749        };
750        let span = match (is_pred, &self.nodes[node_id]) {
751            (_, GraphNode::Operator(operator)) => operator.span(),
752            (true, &GraphNode::Handoff { src_span, .. }) => src_span,
753            (false, &GraphNode::Handoff { dst_span, .. }) => dst_span,
754            (_, GraphNode::ModuleBoundary { .. }) => panic!(),
755        };
756        Ident::new(&name, span)
757    }
758
759    /// For per-node singleton references. Helper to generate a deterministic `Ident` for the given node.
760    fn node_as_singleton_ident(&self, node_id: GraphNodeId, span: Span) -> Ident {
761        Ident::new(&format!("singleton_op_{:?}", node_id.data()), span)
762    }
763
764    /// Resolve the singletons via [`Self::node_singleton_references`] for the given `node_id`.
765    fn helper_resolve_singletons(&self, node_id: GraphNodeId, span: Span) -> Vec<Ident> {
766        self.node_singleton_references(node_id)
767            .iter()
768            .map(|singleton_node_id| {
769                // TODO(mingwei): this `expect` should be caught in error checking
770                self.node_as_singleton_ident(
771                    singleton_node_id
772                        .expect("Expected singleton to be resolved but was not, this is a bug."),
773                    span,
774                )
775            })
776            .collect::<Vec<_>>()
777    }
778
779    /// Returns each subgraph's receive and send handoffs.
780    /// `Map<GraphSubgraphId, (recv handoffs, send handoffs)>`
781    fn helper_collect_subgraph_handoffs(
782        &self,
783    ) -> SecondaryMap<GraphSubgraphId, (Vec<GraphNodeId>, Vec<GraphNodeId>)> {
784        // Get data on handoff src and dst subgraphs.
785        let mut subgraph_handoffs: SecondaryMap<
786            GraphSubgraphId,
787            (Vec<GraphNodeId>, Vec<GraphNodeId>),
788        > = self
789            .subgraph_nodes
790            .keys()
791            .map(|k| (k, Default::default()))
792            .collect();
793
794        // For each handoff node, add it to the `send`/`recv` lists for the corresponding subgraphs.
795        for (hoff_id, node) in self.nodes() {
796            if !matches!(node, GraphNode::Handoff { .. }) {
797                continue;
798            }
799            // Receivers from the handoff. (Should really only be one).
800            for (_edge, succ_id) in self.node_successors(hoff_id) {
801                let succ_sg = self.node_subgraph(succ_id).unwrap();
802                subgraph_handoffs[succ_sg].0.push(hoff_id);
803            }
804            // Senders into the handoff. (Should really only be one).
805            for (_edge, pred_id) in self.node_predecessors(hoff_id) {
806                let pred_sg = self.node_subgraph(pred_id).unwrap();
807                subgraph_handoffs[pred_sg].1.push(hoff_id);
808            }
809        }
810
811        subgraph_handoffs
812    }
813
814    /// Generate a deterministic `Ident` for the given loop ID.
815    fn loop_as_ident(loop_id: GraphLoopId) -> Ident {
816        Ident::new(&format!("loop_{:?}", loop_id.data()), Span::call_site())
817    }
818
819    /// Code for adding all nested loops.
820    fn codegen_nested_loops(&self, df: &Ident) -> TokenStream {
821        // Breadth-first iteration from outermost (root) loops to deepest nested loops.
822        let mut out = TokenStream::new();
823        let mut queue = VecDeque::from_iter(self.root_loops.iter().copied());
824        while let Some(loop_id) = queue.pop_front() {
825            let parent_opt = self
826                .loop_parent(loop_id)
827                .map(Self::loop_as_ident)
828                .map(|ident| quote! { Some(#ident) })
829                .unwrap_or_else(|| quote! { None });
830            let loop_name = Self::loop_as_ident(loop_id);
831            out.append_all(quote! {
832                let #loop_name = #df.add_loop(#parent_opt);
833            });
834            queue.extend(self.loop_children.get(loop_id).into_iter().flatten());
835        }
836        out
837    }
838
    /// Emit this graph as runnable Rust source code tokens.
    ///
    /// The returned tokens are a block expression that constructs a
    /// `#root::scheduled::graph::Dfir`, attaches the serialized meta graph and diagnostics to it,
    /// and then emits, in order: one `make_edge` call per handoff, the nested loop declarations,
    /// every operator's prologue code, and one `add_subgraph_full` call per subgraph (degenerate
    /// subgraphs are reported as error diagnostics and skipped). The block evaluates to the
    /// constructed `Dfir`.
    ///
    /// Subgraphs that contain a node with no inputs (in-degree 0) are emitted before the others.
    ///
    /// * `root` - path to the runtime crate, used to prefix generated paths such as
    ///   `#root::scheduled` and `#root::pusherator`.
    /// * `include_type_guards` - when true, each operator's output is additionally passed through
    ///   a generated wrapper function returning `impl Iterator` (pull) or `impl Pusherator`
    ///   (push).
    /// * `prefix` - tokens placed at the start of the generated block, before the
    ///   `use #root::{var_expr, var_args};` item.
    /// * `diagnostics` - operators' `write_fn`s and the degenerate-subgraph check append to this;
    ///   all diagnostics present once code generation finishes are serialized into the generated
    ///   code via `__assign_diagnostics`.
    ///
    /// # Panics
    /// * If the graph contains a `GraphNode::ModuleBoundary` node.
    /// * If an operator's name is not found in `OPERATORS`.
    /// * If an operator's `write_fn` returns `Err` without having pushed any error diagnostic.
    pub fn as_code(
        &self,
        root: &TokenStream,
        include_type_guards: bool,
        prefix: TokenStream,
        diagnostics: &mut Vec<Diagnostic>,
    ) -> TokenStream {
        let df = Ident::new(GRAPH, Span::call_site());
        let context = Ident::new(CONTEXT, Span::call_site());

        // Code for adding handoffs: one `make_edge` call per handoff node, binding
        // `hoff_<key>_send` / `hoff_<key>_recv` (the same names `node_as_ident` produces).
        // This iterator is lazy; it only runs when quoted into `code` below.
        let handoff_code = self
            .nodes
            .iter()
            .filter_map(|(node_id, node)| match node {
                GraphNode::Operator(_) => None,
                &GraphNode::Handoff { src_span, dst_span } => Some((node_id, (src_span, dst_span))),
                GraphNode::ModuleBoundary { .. } => panic!(),
            })
            .map(|(node_id, (src_span, dst_span))| {
                let ident_send = Ident::new(&format!("hoff_{:?}_send", node_id.data()), dst_span);
                let ident_recv = Ident::new(&format!("hoff_{:?}_recv", node_id.data()), src_span);
                let span = src_span.join(dst_span).unwrap_or(src_span);
                let mut hoff_name = Literal::string(&format!("handoff {:?}", node_id));
                hoff_name.set_span(span);
                let hoff_type = quote_spanned! (span=> #root::scheduled::handoff::VecHandoff<_>);
                quote_spanned! {span=>
                    let (#ident_send, #ident_recv) =
                        #df.make_edge::<_, #hoff_type>(#hoff_name);
                }
            });

        let subgraph_handoffs = self.helper_collect_subgraph_handoffs();

        // We first generate the subgraphs that have no inputs (i.e. contain a node with in-degree
        // 0) to guide type inference; all remaining subgraphs follow.
        let (subgraphs_without_preds, subgraphs_with_preds) = self
            .subgraph_nodes
            .iter()
            .partition::<Vec<_>, _>(|(_, nodes)| {
                nodes
                    .iter()
                    .any(|&node_id| self.node_degree_in(node_id) == 0)
            });

        let mut op_prologue_code = Vec::new();
        let mut subgraphs = Vec::new();
        {
            for &(subgraph_id, subgraph_nodes) in subgraphs_without_preds
                .iter()
                .chain(subgraphs_with_preds.iter())
            {
                let (recv_hoffs, send_hoffs) = &subgraph_handoffs[subgraph_id];
                let recv_ports: Vec<Ident> = recv_hoffs
                    .iter()
                    .map(|&hoff_id| self.node_as_ident(hoff_id, true))
                    .collect();
                let send_ports: Vec<Ident> = send_hoffs
                    .iter()
                    .map(|&hoff_id| self.node_as_ident(hoff_id, false))
                    .collect();

                // Inside the subgraph closure: turn each recv handoff into a draining iterator
                // and each send handoff into a `ForEach` pusherator that gives into it.
                let recv_port_code = recv_ports.iter().map(|ident| {
                    quote! {
                        let mut #ident = #ident.borrow_mut_swap();
                        let #ident = #ident.drain(..);
                    }
                });
                let send_port_code = send_ports.iter().map(|ident| {
                    quote! {
                        let #ident = #root::pusherator::for_each::ForEach::new(|v| {
                            #ident.give(Some(v));
                        });
                    }
                });

                let loop_id = self
                    // All nodes in a subgraph should be in the same loop.
                    .node_loop(subgraph_nodes[0]);

                let mut subgraph_op_iter_code = Vec::new();
                let mut subgraph_op_iter_after_code = Vec::new();
                {
                    let pull_to_push_idx = self.find_pull_to_push_idx(subgraph_nodes);

                    // Visit the pull half in order, then the push half in reverse. Each push
                    // operator's `outputs` name its successors' idents, so (presumably) those
                    // successors must be generated first.
                    let (pull_half, push_half) = subgraph_nodes.split_at(pull_to_push_idx);
                    let nodes_iter = pull_half.iter().chain(push_half.iter().rev());

                    for (idx, &node_id) in nodes_iter.enumerate() {
                        let node = &self.nodes[node_id];
                        assert!(
                            matches!(node, GraphNode::Operator(_)),
                            "Handoffs are not part of subgraphs."
                        );
                        let op_inst = &self.operator_instances[node_id];

                        let op_span = node.span();
                        let op_name = op_inst.op_constraints.name;
                        // Use op's span for root. #root is expected to be correct, any errors should span back to the op gen.
                        let root = change_spans(root.clone(), op_span);
                        // TODO(mingwei): Just use `op_inst.op_constraints`?
                        let op_constraints = OPERATORS
                            .iter()
                            .find(|op| op_name == op.name)
                            .unwrap_or_else(|| panic!("Failed to find op: {}", op_name));

                        let ident = self.node_as_ident(node_id, false);

                        {
                            // TODO clean this up.
                            // Collect input arguments (predecessors).
                            let mut input_edges = self
                                .graph
                                .predecessor_edges(node_id)
                                .map(|edge_id| (self.edge_ports(edge_id).1, edge_id))
                                .collect::<Vec<_>>();
                            // Ensure sorted by port index.
                            input_edges.sort();

                            let inputs = input_edges
                                .iter()
                                .map(|&(_port, edge_id)| {
                                    let (pred, _) = self.edge(edge_id);
                                    self.node_as_ident(pred, true)
                                })
                                .collect::<Vec<_>>();

                            // Collect output arguments (successors).
                            let mut output_edges = self
                                .graph
                                .successor_edges(node_id)
                                .map(|edge_id| (&self.ports[edge_id].0, edge_id))
                                .collect::<Vec<_>>();
                            // Ensure sorted by port index.
                            output_edges.sort();

                            let outputs = output_edges
                                .iter()
                                .map(|&(_port, edge_id)| {
                                    let (_, succ) = self.edge(edge_id);
                                    self.node_as_ident(succ, false)
                                })
                                .collect::<Vec<_>>();

                            let is_pull = idx < pull_to_push_idx;

                            let singleton_output_ident = &if op_constraints.has_singleton_output {
                                self.node_as_singleton_ident(node_id, op_span)
                            } else {
                                // This ident *should* go unused.
                                Ident::new(&format!("{}_has_no_singleton_output", op_name), op_span)
                            };

                            // There's a bit of dark magic hidden in `Span`s... you'd think it's just a `file:line:column`,
                            // but it has one extra bit of info for _name resolution_, used for `Ident`s. `Span::call_site()`
                            // has the (unhygienic) resolution we want, an ident is just solely determined by its string name,
                            // which is what you'd expect out of unhygienic proc macros like this. Meanwhile, declarative macros
                            // use `Span::mixed_site()` which is weird and I don't understand it. It turns out that if you call
                            // the dfir syntax proc macro from _within_ a declarative macro then `op_span` will have the
                            // bad `Span::mixed_site()` name resolution and cause "Cannot find value `df/context`" errors. So
                            // we call `.resolved_at()` to fix resolution back to `Span::call_site()`. -Mingwei
                            let df_local = &Ident::new(GRAPH, op_span.resolved_at(df.span()));
                            let context = &Ident::new(CONTEXT, op_span.resolved_at(context.span()));

                            let singletons_resolved =
                                self.helper_resolve_singletons(node_id, op_span);
                            let arguments = &process_singletons::postprocess_singletons(
                                op_inst.arguments_raw.clone(),
                                singletons_resolved.clone(),
                                context,
                            );
                            let arguments_handles =
                                &process_singletons::postprocess_singletons_handles(
                                    op_inst.arguments_raw.clone(),
                                    singletons_resolved.clone(),
                                );

                            // Tag used in the per-operator marker fn name: an explicit operator
                            // tag if present, otherwise a location-derived string (with the
                            // source file path only on nightly when `proc_macro` is available).
                            let source_tag = 'a: {
                                if let Some(tag) = self.operator_tag.get(node_id).cloned() {
                                    break 'a tag;
                                }

                                #[cfg(nightly)]
                                if proc_macro::is_available() {
                                    let op_span = op_span.unwrap();
                                    break 'a format!(
                                        "loc_{}_{}_{}_{}_{}",
                                        op_span
                                            .source_file()
                                            .path()
                                            .display()
                                            .to_string()
                                            .replace(|x: char| !x.is_alphanumeric(), "_"),
                                        op_span.start().line(),
                                        op_span.start().column(),
                                        op_span.end().line(),
                                        op_span.end().column(),
                                    );
                                }

                                format!(
                                    "loc_nopath_{}_{}_{}_{}",
                                    op_span.start().line,
                                    op_span.start().column,
                                    op_span.end().line,
                                    op_span.end().column
                                )
                            };

                            // `<ident>__<op_name>__<source_tag>`; a marker fn with this name is
                            // pushed into `op_prologue_code` below.
                            let fn_ident = format_ident!(
                                "{}__{}__{}",
                                ident,
                                op_name,
                                source_tag,
                                span = op_span
                            );

                            let context_args = WriteContextArgs {
                                root: &root,
                                df_ident: df_local,
                                context,
                                subgraph_id,
                                node_id,
                                loop_id,
                                op_span,
                                op_tag: self.operator_tag.get(node_id).cloned(),
                                work_fn: &fn_ident,
                                ident: &ident,
                                is_pull,
                                inputs: &inputs,
                                outputs: &outputs,
                                singleton_output_ident,
                                op_name,
                                op_inst,
                                arguments,
                                arguments_handles,
                            };

                            // On `Err`, the operator must already have pushed an error diagnostic;
                            // fall back to a null iterator so codegen can continue.
                            let write_result =
                                (op_constraints.write_fn)(&context_args, diagnostics);
                            let OperatorWriteOutput {
                                write_prologue,
                                write_iterator,
                                write_iterator_after,
                            } = write_result.unwrap_or_else(|()| {
                                assert!(
                                    diagnostics.iter().any(Diagnostic::is_error),
                                    "Operator `{}` returned `Err` but emitted no diagnostics, this is a bug.",
                                    op_name,
                                );
                                OperatorWriteOutput { write_iterator: null_write_iterator_fn(&context_args), ..Default::default() }
                            });

                            op_prologue_code.push(syn::parse_quote! {
                                #[allow(non_snake_case)]
                                #[inline(always)]
                                fn #fn_ident<T>(thunk: impl FnOnce() -> T) -> T {
                                    thunk()
                                }
                            });
                            op_prologue_code.push(write_prologue);

                            subgraph_op_iter_code.push(write_iterator);

                            // Optionally re-bind `#ident` through a wrapper fn returning
                            // `impl Iterator` (pull) or `impl Pusherator` (push).
                            if include_type_guards {
                                let type_guard = if is_pull {
                                    quote_spanned! {op_span=>
                                        let #ident = {
                                            #[allow(non_snake_case)]
                                            #[inline(always)]
                                            pub fn #fn_ident<Item, Input: ::std::iter::Iterator<Item = Item>>(input: Input) -> impl ::std::iter::Iterator<Item = Item> {
                                                #[repr(transparent)]
                                                struct Pull<Item, Input: ::std::iter::Iterator<Item = Item>> {
                                                    inner: Input
                                                }

                                                impl<Item, Input: ::std::iter::Iterator<Item = Item>> Iterator for Pull<Item, Input> {
                                                    type Item = Item;

                                                    #[inline(always)]
                                                    fn next(&mut self) -> Option<Self::Item> {
                                                        self.inner.next()
                                                    }

                                                    #[inline(always)]
                                                    fn size_hint(&self) -> (usize, Option<usize>) {
                                                        self.inner.size_hint()
                                                    }
                                                }

                                                Pull {
                                                    inner: input
                                                }
                                            }
                                            #fn_ident( #ident )
                                        };
                                    }
                                } else {
                                    quote_spanned! {op_span=>
                                        let #ident = {
                                            #[allow(non_snake_case)]
                                            #[inline(always)]
                                            pub fn #fn_ident<Item, Input: #root::pusherator::Pusherator<Item = Item>>(input: Input) -> impl #root::pusherator::Pusherator<Item = Item> {
                                                #[repr(transparent)]
                                                struct Push<Item, Input: #root::pusherator::Pusherator<Item = Item>> {
                                                    inner: Input
                                                }

                                                impl<Item, Input: #root::pusherator::Pusherator<Item = Item>> #root::pusherator::Pusherator for Push<Item, Input> {
                                                    type Item = Item;

                                                    #[inline(always)]
                                                    fn give(&mut self, item: Self::Item) {
                                                        self.inner.give(item)
                                                    }
                                                }

                                                Push {
                                                    inner: input
                                                }
                                            }
                                            #fn_ident( #ident )
                                        };
                                    }
                                };
                                subgraph_op_iter_code.push(type_guard);
                            }
                            subgraph_op_iter_after_code.push(write_iterator_after);
                        }
                    }

                    {
                        // Determine pull and push halves of the `Pivot`.
                        // NOTE(review): `recv_ports[0]` is indexed unchecked; this panics if the
                        // subgraph starts with a push node but has no recv handoff — confirm that
                        // cannot happen.
                        let pull_ident = if 0 < pull_to_push_idx {
                            self.node_as_ident(subgraph_nodes[pull_to_push_idx - 1], false)
                        } else {
                            // Entire subgraph is push (with a single recv/pull handoff input).
                            recv_ports[0].clone()
                        };

                        #[rustfmt::skip]
                        let push_ident = if let Some(&node_id) =
                            subgraph_nodes.get(pull_to_push_idx)
                        {
                            self.node_as_ident(node_id, false)
                        } else if 1 == send_ports.len() {
                            // Entire subgraph is pull (with a single send/push handoff output).
                            send_ports[0].clone()
                        } else {
                            // No push side at all: report an error and skip emitting this subgraph.
                            diagnostics.push(Diagnostic::spanned(
                                pull_ident.span(),
                                Level::Error,
                                "Degenerate subgraph detected, is there a disconnected `null()` or other degenerate pipeline somewhere?",
                            ));
                            continue;
                        };

                        // Pivot span is combination of pull and push spans (or if not possible, just take the push).
                        let pivot_span = pull_ident
                            .span()
                            .join(push_ident.span())
                            .unwrap_or_else(|| push_ident.span());
                        let pivot_fn_ident =
                            Ident::new(&format!("pivot_run_sg_{:?}", subgraph_id.0), pivot_span);
                        subgraph_op_iter_code.push(quote_spanned! {pivot_span=>
                            #[inline(always)]
                            fn #pivot_fn_ident<Pull: ::std::iter::Iterator<Item = Item>, Push: #root::pusherator::Pusherator<Item = Item>, Item>(pull: Pull, push: Push) {
                                #root::pusherator::pivot::Pivot::new(pull, push).run();
                            }
                            #pivot_fn_ident(#pull_ident, #push_ident);
                        });
                    }
                };

                let subgraph_name = Literal::string(&format!("Subgraph {:?}", subgraph_id));
                // Subgraphs without an assigned stratum default to stratum 0.
                let stratum = Literal::usize_unsuffixed(
                    self.subgraph_stratum.get(subgraph_id).cloned().unwrap_or(0),
                );
                let laziness = self.subgraph_laziness(subgraph_id);

                // Codegen: the loop that this subgraph is in `Some(<loop_id>)`, or `None` if not in a loop.
                let loop_id_opt = loop_id
                    .map(Self::loop_as_ident)
                    .map(|ident| quote! { Some(#ident) })
                    .unwrap_or_else(|| quote! { None });

                subgraphs.push(quote! {
                    #df.add_subgraph_full(
                        #subgraph_name,
                        #stratum,
                        var_expr!( #( #recv_ports ),* ),
                        var_expr!( #( #send_ports ),* ),
                        #laziness,
                        #loop_id_opt,
                        move |#context, var_args!( #( #recv_ports ),* ), var_args!( #( #send_ports ),* )| {
                            #( #recv_port_code )*
                            #( #send_port_code )*
                            #( #subgraph_op_iter_code )*
                            #( #subgraph_op_iter_after_code )*
                        },
                    );
                });
            }
        }

        let loop_code = self.codegen_nested_loops(&df);

        // These two are quoted separately here because iterators are lazily evaluated, so this
        // forces them to do their work. This work includes populating some data, namely
        // `diagnostics`, which we need to determine if compilation was actually successful.
        // -Mingwei
        let code = quote! {
            #( #handoff_code )*
            #loop_code
            #( #op_prologue_code )*
            #( #subgraphs )*
        };

        // Embed the meta graph and the final diagnostics as JSON string literals.
        let meta_graph_json = serde_json::to_string(&self).unwrap();
        let meta_graph_json = Literal::string(&meta_graph_json);

        let serde_diagnostics: Vec<_> = diagnostics.iter().map(Diagnostic::to_serde).collect();
        let diagnostics_json = serde_json::to_string(&*serde_diagnostics).unwrap();
        let diagnostics_json = Literal::string(&diagnostics_json);

        quote! {
            {
                #[allow(unused_qualifications)]
                {
                    #prefix

                    use #root::{var_expr, var_args};

                    let mut #df = #root::scheduled::graph::Dfir::new();
                    #df.__assign_meta_graph(#meta_graph_json);
                    #df.__assign_diagnostics(#diagnostics_json);

                    #code

                    #df
                }
            }
        }
    }
1283
1284    /// Color mode (pull vs. push, handoff vs. comp) for nodes. Some nodes can be push *OR* pull;
1285    /// those nodes will not be set in the returned map.
1286    pub fn node_color_map(&self) -> SparseSecondaryMap<GraphNodeId, Color> {
1287        let mut node_color_map: SparseSecondaryMap<GraphNodeId, Color> = self
1288            .node_ids()
1289            .filter_map(|node_id| {
1290                let op_color = self.node_color(node_id)?;
1291                Some((node_id, op_color))
1292            })
1293            .collect();
1294
1295        // Fill in rest via subgraphs.
1296        for sg_nodes in self.subgraph_nodes.values() {
1297            let pull_to_push_idx = self.find_pull_to_push_idx(sg_nodes);
1298
1299            for (idx, node_id) in sg_nodes.iter().copied().enumerate() {
1300                let is_pull = idx < pull_to_push_idx;
1301                node_color_map.insert(node_id, if is_pull { Color::Pull } else { Color::Push });
1302            }
1303        }
1304
1305        node_color_map
1306    }
1307
1308    /// Writes this graph as mermaid into a string.
1309    pub fn to_mermaid(&self, write_config: &WriteConfig) -> String {
1310        let mut output = String::new();
1311        self.write_mermaid(&mut output, write_config).unwrap();
1312        output
1313    }
1314
1315    /// Writes this graph as mermaid into the given `Write`.
1316    pub fn write_mermaid(
1317        &self,
1318        output: impl std::fmt::Write,
1319        write_config: &WriteConfig,
1320    ) -> std::fmt::Result {
1321        let mut graph_write = Mermaid::new(output);
1322        self.write_graph(&mut graph_write, write_config)
1323    }
1324
1325    /// Writes this graph as DOT (graphviz) into a string.
1326    pub fn to_dot(&self, write_config: &WriteConfig) -> String {
1327        let mut output = String::new();
1328        let mut graph_write = Dot::new(&mut output);
1329        self.write_graph(&mut graph_write, write_config).unwrap();
1330        output
1331    }
1332
1333    /// Writes this graph as DOT (graphviz) into the given `Write`.
1334    pub fn write_dot(
1335        &self,
1336        output: impl std::fmt::Write,
1337        write_config: &WriteConfig,
1338    ) -> std::fmt::Result {
1339        let mut graph_write = Dot::new(output);
1340        self.write_graph(&mut graph_write, write_config)
1341    }
1342
    /// Write out this graph using the given `GraphWrite`, e.g. `Mermaid` or `Dot`.
1344    pub(crate) fn write_graph<W>(
1345        &self,
1346        mut graph_write: W,
1347        write_config: &WriteConfig,
1348    ) -> Result<(), W::Err>
1349    where
1350        W: GraphWrite,
1351    {
1352        fn helper_edge_label(
1353            src_port: &PortIndexValue,
1354            dst_port: &PortIndexValue,
1355        ) -> Option<String> {
1356            let src_label = match src_port {
1357                PortIndexValue::Path(path) => Some(path.to_token_stream().to_string()),
1358                PortIndexValue::Int(index) => Some(index.value.to_string()),
1359                _ => None,
1360            };
1361            let dst_label = match dst_port {
1362                PortIndexValue::Path(path) => Some(path.to_token_stream().to_string()),
1363                PortIndexValue::Int(index) => Some(index.value.to_string()),
1364                _ => None,
1365            };
1366            let label = match (src_label, dst_label) {
1367                (Some(l1), Some(l2)) => Some(format!("{}\n{}", l1, l2)),
1368                (Some(l1), None) => Some(l1),
1369                (None, Some(l2)) => Some(l2),
1370                (None, None) => None,
1371            };
1372            label
1373        }
1374
1375        // Make node color map one time.
1376        let node_color_map = self.node_color_map();
1377
1378        // Collect varnames.
1379        let mut sg_varname_nodes =
1380            <SparseSecondaryMap<GraphSubgraphId, BTreeMap<Varname, BTreeSet<GraphNodeId>>>>::new();
1381        let mut varname_nodes = <BTreeMap<Varname, BTreeSet<GraphNodeId>>>::new();
1382        if !write_config.no_varnames {
1383            for (node_id, varname) in self.node_varnames.iter() {
1384                // Only collect if needed.
1385                let varname_map = if !write_config.no_subgraphs {
1386                    let Some(sg_id) = self.node_subgraph(node_id) else {
1387                        continue;
1388                    };
1389                    sg_varname_nodes.entry(sg_id).unwrap().or_default()
1390                } else {
1391                    &mut varname_nodes
1392                };
1393                varname_map
1394                    .entry(varname.clone())
1395                    .or_default()
1396                    .insert(node_id);
1397            }
1398        }
1399
1400        // Write prologue.
1401        graph_write.write_prologue()?;
1402
1403        // Write nodes.
1404        let mut skipped_handoffs = BTreeSet::new();
1405        let mut subgraph_handoffs = <BTreeMap<GraphSubgraphId, Vec<GraphNodeId>>>::new();
1406        for (node_id, node) in self.nodes() {
1407            if matches!(node, GraphNode::Handoff { .. }) {
1408                if write_config.no_handoffs {
1409                    skipped_handoffs.insert(node_id);
1410                    continue;
1411                } else {
1412                    let pred_node = self.node_predecessor_nodes(node_id).next().unwrap();
1413                    let pred_sg = self.node_subgraph(pred_node);
1414                    let succ_node = self.node_successor_nodes(node_id).next().unwrap();
1415                    let succ_sg = self.node_subgraph(succ_node);
1416                    if let Some((pred_sg, succ_sg)) = pred_sg.zip(succ_sg) {
1417                        if pred_sg == succ_sg {
1418                            subgraph_handoffs.entry(pred_sg).or_default().push(node_id);
1419                        }
1420                    }
1421                }
1422            }
1423            graph_write.write_node(
1424                node_id,
1425                &if write_config.op_short_text {
1426                    node.to_name_string()
1427                } else if write_config.op_text_no_imports {
1428                    // Remove any lines that start with "use" (imports)
1429                    let full_text = node.to_pretty_string();
1430                    let mut output = String::new();
1431                    for sentence in full_text.split('\n') {
1432                        if sentence.trim().starts_with("use") {
1433                            continue;
1434                        }
1435                        output.push('\n');
1436                        output.push_str(sentence);
1437                    }
1438                    output.into()
1439                } else {
1440                    node.to_pretty_string()
1441                },
1442                if write_config.no_pull_push {
1443                    None
1444                } else {
1445                    node_color_map.get(node_id).copied()
1446                },
1447            )?;
1448        }
1449
1450        // Write edges.
1451        for (edge_id, (src_id, mut dst_id)) in self.edges() {
1452            // Handling for if `write_config.no_handoffs` true.
1453            if skipped_handoffs.contains(&src_id) {
1454                continue;
1455            }
1456
1457            let (src_port, mut dst_port) = self.edge_ports(edge_id);
1458            if skipped_handoffs.contains(&dst_id) {
1459                let mut handoff_succs = self.node_successors(dst_id);
1460                assert_eq!(1, handoff_succs.len());
1461                let (succ_edge, succ_node) = handoff_succs.next().unwrap();
1462                dst_id = succ_node;
1463                dst_port = self.edge_ports(succ_edge).1;
1464            }
1465
1466            let label = helper_edge_label(src_port, dst_port);
1467            let delay_type = self
1468                .node_op_inst(dst_id)
1469                .and_then(|op_inst| (op_inst.op_constraints.input_delaytype_fn)(dst_port));
1470            graph_write.write_edge(src_id, dst_id, delay_type, label.as_deref(), false)?;
1471        }
1472
1473        // Write reference edges.
1474        if !write_config.no_references {
1475            for dst_id in self.node_ids() {
1476                for src_ref_id in self
1477                    .node_singleton_references(dst_id)
1478                    .iter()
1479                    .copied()
1480                    .flatten()
1481                {
1482                    let delay_type = Some(DelayType::Stratum);
1483                    let label = None;
1484                    graph_write.write_edge(src_ref_id, dst_id, delay_type, label, true)?;
1485                }
1486            }
1487        }
1488
1489        // Write subgraphs.
1490        if !write_config.no_subgraphs {
1491            for (subgraph_id, subgraph_node_ids) in self.subgraph_nodes.iter() {
1492                let handoff_node_ids = subgraph_handoffs.get(&subgraph_id).into_iter().flatten();
1493                let subgraph_node_ids = subgraph_node_ids.iter();
1494                let all_node_ids = handoff_node_ids.chain(subgraph_node_ids).copied();
1495
1496                let stratum = self.subgraph_stratum.get(subgraph_id);
1497                graph_write.write_subgraph_start(subgraph_id, *stratum.unwrap(), all_node_ids)?;
1498                // Write out any variable names within the subgraph.
1499                if !write_config.no_varnames {
1500                    for (varname, varname_node_ids) in
1501                        sg_varname_nodes.remove(subgraph_id).into_iter().flatten()
1502                    {
1503                        assert!(!varname_node_ids.is_empty());
1504                        graph_write.write_varname(
1505                            &varname.0.to_string(),
1506                            varname_node_ids.into_iter(),
1507                            Some(subgraph_id),
1508                        )?;
1509                    }
1510                }
1511                graph_write.write_subgraph_end()?;
1512            }
1513        } else if !write_config.no_varnames {
1514            for (varname, varname_node_ids) in varname_nodes {
1515                graph_write.write_varname(
1516                    &varname.0.to_string(),
1517                    varname_node_ids.into_iter(),
1518                    None,
1519                )?;
1520            }
1521        }
1522
1523        // Write epilogue.
1524        graph_write.write_epilogue()?;
1525
1526        Ok(())
1527    }
1528
1529    /// Convert back into surface syntax.
1530    pub fn surface_syntax_string(&self) -> String {
1531        let mut string = String::new();
1532        self.write_surface_syntax(&mut string).unwrap();
1533        string
1534    }
1535
1536    /// Convert back into surface syntax.
1537    pub fn write_surface_syntax(&self, write: &mut impl std::fmt::Write) -> std::fmt::Result {
1538        for (key, node) in self.nodes.iter() {
1539            match node {
1540                GraphNode::Operator(op) => {
1541                    writeln!(write, "{:?} = {};", key.data(), op.to_token_stream())?;
1542                }
1543                GraphNode::Handoff { .. } => {
1544                    writeln!(write, "// {:?} = <handoff>;", key.data())?;
1545                }
1546                GraphNode::ModuleBoundary { .. } => panic!(),
1547            }
1548        }
1549        writeln!(write)?;
1550        for (_e, (src_key, dst_key)) in self.graph.edges() {
1551            writeln!(write, "{:?} -> {:?};", src_key.data(), dst_key.data())?;
1552        }
1553        Ok(())
1554    }
1555
1556    /// Convert into a [mermaid](https://mermaid-js.github.io/) graph. Ignores subgraphs.
1557    pub fn mermaid_string_flat(&self) -> String {
1558        let mut string = String::new();
1559        self.write_mermaid_flat(&mut string).unwrap();
1560        string
1561    }
1562
1563    /// Convert into a [mermaid](https://mermaid-js.github.io/) graph. Ignores subgraphs.
1564    pub fn write_mermaid_flat(&self, write: &mut impl std::fmt::Write) -> std::fmt::Result {
1565        writeln!(write, "flowchart TB")?;
1566        for (key, node) in self.nodes.iter() {
1567            match node {
1568                GraphNode::Operator(operator) => writeln!(
1569                    write,
1570                    "    %% {span}\n    {id:?}[\"{row_col} <tt>{code}</tt>\"]",
1571                    span = PrettySpan(node.span()),
1572                    id = key.data(),
1573                    row_col = PrettyRowCol(node.span()),
1574                    code = operator
1575                        .to_token_stream()
1576                        .to_string()
1577                        .replace('&', "&amp;")
1578                        .replace('<', "&lt;")
1579                        .replace('>', "&gt;")
1580                        .replace('"', "&quot;")
1581                        .replace('\n', "<br>"),
1582                ),
1583                GraphNode::Handoff { .. } => {
1584                    writeln!(write, r#"    {:?}{{"{}"}}"#, key.data(), HANDOFF_NODE_STR)
1585                }
1586                GraphNode::ModuleBoundary { .. } => {
1587                    writeln!(
1588                        write,
1589                        r#"    {:?}{{"{}"}}"#,
1590                        key.data(),
1591                        MODULE_BOUNDARY_NODE_STR
1592                    )
1593                }
1594            }?;
1595        }
1596        writeln!(write)?;
1597        for (_e, (src_key, dst_key)) in self.graph.edges() {
1598            writeln!(write, "    {:?}-->{:?}", src_key.data(), dst_key.data())?;
1599        }
1600        Ok(())
1601    }
1602}
1603
1604/// Loops
1605impl DfirGraph {
1606    /// Iterator over all loop IDs.
1607    pub fn loop_ids(&self) -> slotmap::basic::Keys<'_, GraphLoopId, Vec<GraphNodeId>> {
1608        self.loop_nodes.keys()
1609    }
1610
1611    /// Iterator over all loops, ID and members: `(GraphLoopId, Vec<GraphNodeId>)`.
1612    pub fn loops(&self) -> slotmap::basic::Iter<'_, GraphLoopId, Vec<GraphNodeId>> {
1613        self.loop_nodes.iter()
1614    }
1615
1616    /// Create a new loop context, with the given parent loop (or `None`).
1617    pub fn insert_loop(&mut self, parent_loop: Option<GraphLoopId>) -> GraphLoopId {
1618        let loop_id = self.loop_nodes.insert(Vec::new());
1619        self.loop_children.insert(loop_id, Vec::new());
1620        if let Some(parent_loop) = parent_loop {
1621            self.loop_parent.insert(loop_id, parent_loop);
1622            self.loop_children
1623                .get_mut(parent_loop)
1624                .unwrap()
1625                .push(loop_id);
1626        } else {
1627            self.root_loops.push(loop_id);
1628        }
1629        loop_id
1630    }
1631
1632    /// Get a node's loop context (or `None` for root).
1633    pub fn node_loop(&self, node_id: GraphNodeId) -> Option<GraphLoopId> {
1634        self.node_loops.get(node_id).copied()
1635    }
1636
1637    /// Get a subgraph's loop context (or `None` for root).
1638    pub fn subgraph_loop(&self, subgraph_id: GraphSubgraphId) -> Option<GraphLoopId> {
1639        let &node_id = self.subgraph(subgraph_id).first().unwrap();
1640        let out = self.node_loop(node_id);
1641        debug_assert!(
1642            self.subgraph(subgraph_id)
1643                .iter()
1644                .all(|&node_id| self.node_loop(node_id) == out),
1645            "Subgraph nodes should all have the same loop context."
1646        );
1647        out
1648    }
1649
1650    /// Get a loop context's parent loop context (or `None` for root).
1651    pub fn loop_parent(&self, loop_id: GraphLoopId) -> Option<GraphLoopId> {
1652        self.loop_parent.get(loop_id).copied()
1653    }
1654
1655    /// Get a loop context's child loops.
1656    pub fn loop_children(&self, loop_id: GraphLoopId) -> &Vec<GraphLoopId> {
1657        self.loop_children.get(loop_id).unwrap()
1658    }
1659}
1660
/// Configuration for writing graphs.
// Every flag is a `bool` and defaults to `false` via `#[derive(Default)]`, so the default
// configuration renders everything with full operator text. With the `clap-derive` feature,
// each field is also exposed as a `--long` command-line flag. clap's derive uses the `///` doc
// comments on those fields as the flag's help text, so editing them changes `--help` output.
#[derive(Clone, Debug, Default)]
#[cfg_attr(feature = "clap-derive", derive(clap::Args))]
pub struct WriteConfig {
    /// Subgraphs will not be rendered if set.
    // Also decides where variable names are written: inside each subgraph when unset, at the
    // top level when set.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub no_subgraphs: bool,
    /// Variable names will not be rendered if set.
    // When set, variable-name groupings are neither collected nor written.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub no_varnames: bool,
    /// Will not render pull/push shapes if set.
    // Implemented by passing no node color to the graph writer.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub no_pull_push: bool,
    /// Will not render handoffs if set.
    // Handoff nodes are skipped. Each edge into a skipped handoff is redirected to the handoff's
    // single successor, using that successor edge's destination port.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub no_handoffs: bool,
    /// Will not render singleton references if set.
    // When set, the extra edges drawn for singleton references are not written.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub no_references: bool,

    /// Op text will only be their name instead of the whole source.
    // Takes precedence over `op_text_no_imports` when both are set.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub op_short_text: bool,
    /// Op text will exclude any line that starts with "use".
    // NOTE(review): the writer does a plain prefix check on each trimmed line. Lines that merely
    // begin with `use` (e.g. `user`, `used`) are therefore dropped too, and the filtered text
    // gains a leading newline. Confirm this is intended.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub op_text_no_imports: bool,
}
1688
/// Enum for choosing between mermaid and dot graph writing.
// With the `clap-derive` feature this is also a `clap::ValueEnum`. The `///` variant docs below
// then double as help text for the accepted values, so editing them changes CLI output.
// NOTE(review): `clap::Parser` is derived here as well. On an enum, clap's `Parser` derive
// treats the variants as subcommands; confirm both derives are intended.
#[derive(Copy, Clone, Debug)]
#[cfg_attr(feature = "clap-derive", derive(clap::Parser, clap::ValueEnum))]
pub enum WriteGraphType {
    /// Mermaid graphs.
    Mermaid,
    /// Dot (Graphviz) graphs.
    Dot,
}