Skip to main content

dfir_lang/graph/
meta_graph.rs

1#![warn(missing_docs)]
2
3extern crate proc_macro;
4
5use std::collections::{BTreeMap, BTreeSet};
6use std::fmt::Debug;
7use std::iter::FusedIterator;
8
9use itertools::Itertools;
10use proc_macro2::{Ident, Literal, Span, TokenStream};
11use quote::{ToTokens, format_ident, quote, quote_spanned};
12use serde::{Deserialize, Serialize};
13use slotmap::{Key, SecondaryMap, SlotMap, SparseSecondaryMap};
14use syn::spanned::Spanned;
15
16use super::graph_write::{Dot, GraphWrite, Mermaid};
17use super::ops::{
18    DelayType, OPERATORS, OperatorWriteOutput, WriteContextArgs, find_op_op_constraints,
19    null_write_iterator_fn,
20};
21use super::{
22    CONTEXT, Color, DiMulGraph, GRAPH, GraphEdgeId, GraphLoopId, GraphNode, GraphNodeId,
23    GraphSubgraphId, HANDOFF_NODE_STR, HandoffKind, MODULE_BOUNDARY_NODE_STR, OperatorInstance,
24    PortIndexValue, SINGLETON_SLOT_NODE_STR, Varname, change_spans, get_operator_generics,
25};
26use crate::diagnostic::{Diagnostic, Diagnostics, Level};
27use crate::pretty_span::{PrettyRowCol, PrettySpan};
28use crate::process_singletons;
29
/// An abstract "meta graph" representation of a DFIR graph.
///
/// Can be with or without subgraph partitioning, stratification, and handoff insertion. This is
/// the meta graph used for generating Rust source code in macros from DFIR syntax.
///
/// This struct has a lot of methods for manipulating the graph, vaguely grouped together in
/// separate `impl` blocks. You might notice a few particularly specific, arbitrary-seeming methods
/// in here--those are just what was needed for the compilation algorithms. If you need another
/// method then add it.
#[derive(Default, Debug, Serialize, Deserialize)]
pub struct DfirGraph {
    /// Each node type (operator or handoff).
    nodes: SlotMap<GraphNodeId, GraphNode>,

    /// Instance data corresponding to each operator node.
    /// This field is skipped by serde, so it will be empty after deserialization.
    #[serde(skip)]
    operator_instances: SecondaryMap<GraphNodeId, OperatorInstance>,
    /// Debugging/tracing tag for each operator node.
    operator_tag: SecondaryMap<GraphNodeId, String>,
    /// Graph data structure (two-way adjacency list).
    graph: DiMulGraph<GraphNodeId, GraphEdgeId>,
    /// Port pair for each edge: `(source-side port, destination-side port)`.
    ports: SecondaryMap<GraphEdgeId, (PortIndexValue, PortIndexValue)>,

    /// Which loop a node belongs to (no entry for top-level nodes).
    node_loops: SecondaryMap<GraphNodeId, GraphLoopId>,
    /// Which nodes belong to each loop.
    loop_nodes: SlotMap<GraphLoopId, Vec<GraphNodeId>>,
    /// For the loop, what is its parent (no entry for top-level loops).
    loop_parent: SparseSecondaryMap<GraphLoopId, GraphLoopId>,
    /// What loops are at the root.
    root_loops: Vec<GraphLoopId>,
    /// For the loop, what are its child loops.
    loop_children: SecondaryMap<GraphLoopId, Vec<GraphLoopId>>,

    /// Which subgraph each node belongs to.
    node_subgraph: SecondaryMap<GraphNodeId, GraphSubgraphId>,

    /// Which nodes belong to each subgraph.
    subgraph_nodes: SlotMap<GraphSubgraphId, Vec<GraphNodeId>>,

    /// Resolved singleton varname references, per node.
    node_singleton_references: SparseSecondaryMap<GraphNodeId, Vec<Option<GraphNodeId>>>,
    /// What variable name each graph node belongs to (if any). For debugging (graph writing) purposes only.
    node_varnames: SparseSecondaryMap<GraphNodeId, Varname>,

    /// Delay type for handoff nodes that represent tick-boundary back-edges.
    /// Set by `order_subgraphs` for `defer_tick` / `defer_tick_lazy`, either on handoff nodes
    /// it injects or on existing handoff nodes that it marks as tick-boundary back-edges.
    handoff_delay_type: SparseSecondaryMap<GraphNodeId, DelayType>,
}
82
83/// Basic methods.
84impl DfirGraph {
85    /// Create a new empty graph.
86    pub fn new() -> Self {
87        Default::default()
88    }
89}
90
91/// Node methods.
92impl DfirGraph {
    /// Get the [`GraphNode`] stored for `node_id`.
    ///
    /// This returns only the node itself; use [`Self::node_op_inst`] for the operator instance.
    ///
    /// # Panics
    /// Panics with "Node not found." if `node_id` is not in the graph.
    pub fn node(&self, node_id: GraphNodeId) -> &GraphNode {
        self.nodes.get(node_id).expect("Node not found.")
    }
97
    /// Get the `OperatorInstance` for a given node. Node must be an operator and have an
    /// `OperatorInstance` present, otherwise will return `None`.
    ///
    /// Note that no operator instances will be present after deserialization.
    pub fn node_op_inst(&self, node_id: GraphNodeId) -> Option<&OperatorInstance> {
        self.operator_instances.get(node_id)
    }
105
    /// Get the debug variable name attached to a graph node, or `None` if the node has no
    /// recorded varname.
    pub fn node_varname(&self, node_id: GraphNodeId) -> Option<&Varname> {
        self.node_varnames.get(node_id)
    }
110
    /// Get the ID of the subgraph that `node_id` has been assigned to, or `None` if the node is
    /// not recorded as a member of any subgraph.
    pub fn node_subgraph(&self, node_id: GraphNodeId) -> Option<GraphSubgraphId> {
        self.node_subgraph.get(node_id).copied()
    }
115
    /// Degree into a node, i.e. the number of predecessors.
    ///
    /// Delegates to the underlying graph's `degree_in`.
    pub fn node_degree_in(&self, node_id: GraphNodeId) -> usize {
        self.graph.degree_in(node_id)
    }
120
    /// Degree out of a node, i.e. the number of successors.
    ///
    /// Delegates to the underlying graph's `degree_out`.
    pub fn node_degree_out(&self, node_id: GraphNodeId) -> usize {
        self.graph.degree_out(node_id)
    }
125
    /// Successors of `src`: an iterator of `(GraphEdgeId, GraphNodeId)` pairs, one per outgoing
    /// edge.
    pub fn node_successors(
        &self,
        src: GraphNodeId,
    ) -> impl '_
    + DoubleEndedIterator<Item = (GraphEdgeId, GraphNodeId)>
    + ExactSizeIterator
    + FusedIterator
    + Clone
    + Debug {
        self.graph.successors(src)
    }
138
    /// Predecessors of `dst`: an iterator of `(GraphEdgeId, GraphNodeId)` pairs, one per incoming
    /// edge.
    pub fn node_predecessors(
        &self,
        dst: GraphNodeId,
    ) -> impl '_
    + DoubleEndedIterator<Item = (GraphEdgeId, GraphNodeId)>
    + ExactSizeIterator
    + FusedIterator
    + Clone
    + Debug {
        self.graph.predecessors(dst)
    }
151
    /// Successor edges of `src`: an iterator over the `GraphEdgeId` of each outgoing edge.
    pub fn node_successor_edges(
        &self,
        src: GraphNodeId,
    ) -> impl '_
    + DoubleEndedIterator<Item = GraphEdgeId>
    + ExactSizeIterator
    + FusedIterator
    + Clone
    + Debug {
        self.graph.successor_edges(src)
    }
164
    /// Predecessor edges of `dst`: an iterator over the `GraphEdgeId` of each incoming edge.
    pub fn node_predecessor_edges(
        &self,
        dst: GraphNodeId,
    ) -> impl '_
    + DoubleEndedIterator<Item = GraphEdgeId>
    + ExactSizeIterator
    + FusedIterator
    + Clone
    + Debug {
        self.graph.predecessor_edges(dst)
    }
177
    /// Successor nodes of `src`: an iterator over the `GraphNodeId` at the far end of each
    /// outgoing edge.
    pub fn node_successor_nodes(
        &self,
        src: GraphNodeId,
    ) -> impl '_
    + DoubleEndedIterator<Item = GraphNodeId>
    + ExactSizeIterator
    + FusedIterator
    + Clone
    + Debug {
        self.graph.successor_vertices(src)
    }
190
    /// Predecessor nodes of `dst`: an iterator over the `GraphNodeId` at the near end of each
    /// incoming edge.
    pub fn node_predecessor_nodes(
        &self,
        dst: GraphNodeId,
    ) -> impl '_
    + DoubleEndedIterator<Item = GraphNodeId>
    + ExactSizeIterator
    + FusedIterator
    + Clone
    + Debug {
        self.graph.predecessor_vertices(dst)
    }
203
    /// Iterator over the `GraphNodeId` of every node currently in the graph.
    pub fn node_ids(&self) -> slotmap::basic::Keys<'_, GraphNodeId, GraphNode> {
        self.nodes.keys()
    }
208
    /// Iterator over `(GraphNodeId, &GraphNode)` pairs for every node currently in the graph.
    pub fn nodes(&self) -> slotmap::basic::Iter<'_, GraphNodeId, GraphNode> {
        self.nodes.iter()
    }
213
214    /// Insert a node, assigning the given varname.
215    pub fn insert_node(
216        &mut self,
217        node: GraphNode,
218        varname_opt: Option<Ident>,
219        loop_opt: Option<GraphLoopId>,
220    ) -> GraphNodeId {
221        let node_id = self.nodes.insert(node);
222        if let Some(varname) = varname_opt {
223            self.node_varnames.insert(node_id, Varname(varname));
224        }
225        if let Some(loop_id) = loop_opt {
226            self.node_loops.insert(node_id, loop_id);
227            self.loop_nodes[loop_id].push(node_id);
228        }
229        node_id
230    }
231
232    /// Insert an operator instance for the given node. Panics if already set.
233    pub fn insert_node_op_inst(&mut self, node_id: GraphNodeId, op_inst: OperatorInstance) {
234        assert!(matches!(
235            self.nodes.get(node_id),
236            Some(GraphNode::Operator(_))
237        ));
238        let old_inst = self.operator_instances.insert(node_id, op_inst);
239        assert!(old_inst.is_none());
240    }
241
    /// Assign all operator instances if not set. Write diagnostic messages/errors into `diagnostics`.
    ///
    /// For every operator node without an instance this either:
    /// - lowers it to a `GraphNode::Handoff` if it is a `handoff()`/`singleton()` pseudo-operator,
    /// - reports an "Unknown operator" error and leaves the node without an instance, or
    /// - builds an `OperatorInstance` from its sorted input/output ports and generic arguments.
    ///
    /// Generic-argument arity errors are reported but do not prevent the instance from being set.
    pub fn insert_node_op_insts_all(&mut self, diagnostics: &mut Diagnostics) {
        // Handle all nodes in two phases, since the helper methods take total ownership of `&self`.
        // Possible to do in one phase, but would require accessing fields directly for partial mutable ownership.

        // Collect operator instances, then assign.
        let mut op_insts = Vec::new();
        // Collect nodes that should be lowered to handoffs (the `handoff()`/`singleton()` pseudo-operators).
        let mut handoff_nodes: Vec<(GraphNodeId, HandoffKind, Span)> = Vec::new();

        for (node_id, node) in self.nodes() {
            // Only operator nodes can receive an operator instance.
            let GraphNode::Operator(operator) = node else {
                continue;
            };
            // Skip operators whose instance was already assigned.
            if self.node_op_inst(node_id).is_some() {
                continue;
            };

            // Recognize `handoff()`/`singleton()` pseudo-operators and lower to GraphNode::Handoff.
            let handoff_kind = match &*operator.name_string() {
                "handoff" => Some(HandoffKind::Vec),
                "singleton" => Some(HandoffKind::Option),
                _ => None,
            };
            if let Some(kind) = handoff_kind {
                // Pseudo-operators accept neither arguments nor generics; report but still lower.
                if !operator.args.is_empty() {
                    diagnostics.push(Diagnostic::spanned(
                        operator.path.span(),
                        Level::Error,
                        format!("`{}` takes no arguments.", operator.name_string()),
                    ));
                }
                if operator.type_arguments().is_some() {
                    diagnostics.push(Diagnostic::spanned(
                        operator.path.span(),
                        Level::Error,
                        format!("`{}` takes no generic arguments.", operator.name_string()),
                    ));
                }
                handoff_nodes.push((node_id, kind, operator.path.span()));
                continue;
            }

            // Op constraints.
            let Some(op_constraints) = find_op_op_constraints(operator) else {
                diagnostics.push(Diagnostic::spanned(
                    operator.path.span(),
                    Level::Error,
                    format!("Unknown operator `{}`", operator.name_string()),
                ));
                continue;
            };

            // Input and output ports.
            let (input_ports, output_ports) = {
                // Collect input arguments (predecessors), using the destination-side port of each edge.
                let mut input_edges: Vec<(&PortIndexValue, GraphNodeId)> = self
                    .node_predecessors(node_id)
                    .map(|(edge_id, pred_id)| (self.edge_ports(edge_id).1, pred_id))
                    .collect();
                // Ensure sorted by port index.
                input_edges.sort();
                let input_ports: Vec<PortIndexValue> = input_edges
                    .into_iter()
                    .map(|(port, _pred)| port)
                    .cloned()
                    .collect();

                // Collect output arguments (successors), using the source-side port of each edge.
                let mut output_edges: Vec<(&PortIndexValue, GraphNodeId)> = self
                    .node_successors(node_id)
                    .map(|(edge_id, succ)| (self.edge_ports(edge_id).0, succ))
                    .collect();
                // Ensure sorted by port index.
                output_edges.sort();
                let output_ports: Vec<PortIndexValue> = output_edges
                    .into_iter()
                    .map(|(port, _succ)| port)
                    .cloned()
                    .collect();

                (input_ports, output_ports)
            };

            // Generic arguments.
            let generics = get_operator_generics(diagnostics, operator);
            // Generic argument errors.
            {
                // Span of `generic_args` (if it exists), otherwise span of the operator name.
                let generics_span = generics
                    .generic_args
                    .as_ref()
                    .map(Spanned::span)
                    .unwrap_or_else(|| operator.path.span());

                if !op_constraints
                    .persistence_args
                    .contains(&generics.persistence_args.len())
                {
                    diagnostics.push(Diagnostic::spanned(
                        generics.persistence_args_span().unwrap_or(generics_span),
                        Level::Error,
                        format!(
                            "`{}` should have {} persistence lifetime arguments, actually has {}.",
                            op_constraints.name,
                            op_constraints.persistence_args.human_string(),
                            generics.persistence_args.len()
                        ),
                    ));
                }
                if !op_constraints.type_args.contains(&generics.type_args.len()) {
                    diagnostics.push(Diagnostic::spanned(
                        generics.type_args_span().unwrap_or(generics_span),
                        Level::Error,
                        format!(
                            "`{}` should have {} generic type arguments, actually has {}.",
                            op_constraints.name,
                            op_constraints.type_args.human_string(),
                            generics.type_args.len()
                        ),
                    ));
                }
            }

            op_insts.push((
                node_id,
                OperatorInstance {
                    op_constraints,
                    input_ports,
                    output_ports,
                    singletons_referenced: operator.singletons_referenced.clone(),
                    generics,
                    arguments_pre: operator.args.clone(),
                    arguments_raw: operator.args_raw.clone(),
                },
            ));
        }

        // Second phase: the shared borrow of `self.nodes` has ended, so mutate.
        for (node_id, op_inst) in op_insts {
            self.insert_node_op_inst(node_id, op_inst);
        }

        // Replace pseudo-operator nodes with GraphNode::Handoff.
        // The operator's path span is used for both the source and destination span.
        for (node_id, kind, span) in handoff_nodes {
            self.nodes[node_id] = GraphNode::Handoff {
                kind,
                src_span: span,
                dst_span: span,
            };
        }
    }
392
393    /// Inserts a node between two existing nodes connected by the given `edge_id`.
394    ///
395    /// `edge`: (src, dst, dst_idx)
396    ///
397    /// Before: A (src) ------------> B (dst)
398    /// After:  A (src) -> X (new) -> B (dst)
399    ///
400    /// Returns the ID of X & ID of edge OUT of X.
401    ///
402    /// Note that both the edges will be new and `edge_id` will be removed. Both new edges will
403    /// get the edge type of the original edge.
404    pub fn insert_intermediate_node(
405        &mut self,
406        edge_id: GraphEdgeId,
407        new_node: GraphNode,
408    ) -> (GraphNodeId, GraphEdgeId) {
409        let span = Some(new_node.span());
410
411        // Make corresponding operator instance (if `node` is an operator).
412        let op_inst_opt = 'oc: {
413            let GraphNode::Operator(operator) = &new_node else {
414                break 'oc None;
415            };
416            let Some(op_constraints) = find_op_op_constraints(operator) else {
417                break 'oc None;
418            };
419            let (input_port, output_port) = self.ports.get(edge_id).cloned().unwrap();
420
421            let mut dummy_diagnostics = Diagnostics::new();
422            let generics = get_operator_generics(&mut dummy_diagnostics, operator);
423            assert!(dummy_diagnostics.is_empty());
424
425            Some(OperatorInstance {
426                op_constraints,
427                input_ports: vec![input_port],
428                output_ports: vec![output_port],
429                singletons_referenced: operator.singletons_referenced.clone(),
430                generics,
431                arguments_pre: operator.args.clone(),
432                arguments_raw: operator.args_raw.clone(),
433            })
434        };
435
436        // Insert new `node`.
437        let node_id = self.nodes.insert(new_node);
438        // Insert corresponding `OperatorInstance` if applicable.
439        if let Some(op_inst) = op_inst_opt {
440            self.operator_instances.insert(node_id, op_inst);
441        }
442        // Update edges to insert node within `edge_id`.
443        let (e0, e1) = self
444            .graph
445            .insert_intermediate_vertex(node_id, edge_id)
446            .unwrap();
447
448        // Update corresponding ports.
449        let (src_idx, dst_idx) = self.ports.remove(edge_id).unwrap();
450        self.ports
451            .insert(e0, (src_idx, PortIndexValue::Elided(span)));
452        self.ports
453            .insert(e1, (PortIndexValue::Elided(span), dst_idx));
454
455        (node_id, e1)
456    }
457
458    /// Remove the node `node_id` but preserves and connects the single predecessor and single successor.
459    /// Panics if the node does not have exactly one predecessor and one successor, or is not in the graph.
460    pub fn remove_intermediate_node(&mut self, node_id: GraphNodeId) {
461        assert_eq!(
462            1,
463            self.node_degree_in(node_id),
464            "Removed intermediate node must have one predecessor"
465        );
466        assert_eq!(
467            1,
468            self.node_degree_out(node_id),
469            "Removed intermediate node must have one successor"
470        );
471        assert!(
472            self.node_subgraph.is_empty() && self.subgraph_nodes.is_empty(),
473            "Should not remove intermediate node after subgraph partitioning"
474        );
475
476        assert!(self.nodes.remove(node_id).is_some());
477        let (new_edge_id, (pred_edge_id, succ_edge_id)) =
478            self.graph.remove_intermediate_vertex(node_id).unwrap();
479        self.operator_instances.remove(node_id);
480        self.node_varnames.remove(node_id);
481
482        let (src_port, _) = self.ports.remove(pred_edge_id).unwrap();
483        let (_, dst_port) = self.ports.remove(succ_edge_id).unwrap();
484        self.ports.insert(new_edge_id, (src_port, dst_port));
485    }
486
487    /// Helper method: determine the "color" (pull vs push) of a node based on its in and out degree,
488    /// excluding reference edges. If linear (1 in, 1 out), color is `None`, indicating it can be
489    /// either push or pull.
490    ///
491    /// Note that this does NOT consider `DelayType` barriers (which generally implies `Pull`).
492    pub(crate) fn node_color(&self, node_id: GraphNodeId) -> Option<Color> {
493        if matches!(self.node(node_id), GraphNode::Handoff { .. }) {
494            return Some(Color::Hoff);
495        }
496
497        // TODO(shadaj): this is a horrible hack
498        if let GraphNode::Operator(op) = self.node(node_id)
499            && (op.name_string() == "resolve_futures_blocking"
500                || op.name_string() == "resolve_futures_blocking_ordered")
501        {
502            return Some(Color::Push);
503        }
504
505        // In-degree, excluding ref-edges.
506        let inn_degree = self.node_predecessor_nodes(node_id).len();
507        // Out-degree excluding ref-edges.
508        let out_degree = self.node_successor_nodes(node_id).len();
509
510        match (inn_degree, out_degree) {
511            (0, 0) => None, // Generally should not happen, "Degenerate subgraph detected".
512            (0, 1) => Some(Color::Pull),
513            (1, 0) => Some(Color::Push),
514            (1, 1) => None, // Linear, can be either push or pull.
515            (_many, 0 | 1) => Some(Color::Pull),
516            (0 | 1, _many) => Some(Color::Push),
517            (_many, _to_many) => Some(Color::Comp),
518        }
519    }
520
    /// Set the operator tag (for debugging/tracing) of `node_id`, replacing any tag previously
    /// stored for that node.
    pub fn set_operator_tag(&mut self, node_id: GraphNodeId, tag: String) {
        self.operator_tag.insert(node_id, tag);
    }
525}
526
527/// Singleton references.
528impl DfirGraph {
529    /// Set the singletons referenced for the `node_id` operator. Each reference corresponds to the
530    /// same index in the [`crate::parse::Operator::singletons_referenced`] vec.
531    pub fn set_node_singleton_references(
532        &mut self,
533        node_id: GraphNodeId,
534        singletons_referenced: Vec<Option<GraphNodeId>>,
535    ) -> Option<Vec<Option<GraphNodeId>>> {
536        self.node_singleton_references
537            .insert(node_id, singletons_referenced)
538    }
539
540    /// Gets the singletons referenced by a node. Returns an empty iterator for non-operators and
541    /// operators that do not reference singletons.
542    pub fn node_singleton_references(&self, node_id: GraphNodeId) -> &[Option<GraphNodeId>] {
543        self.node_singleton_references
544            .get(node_id)
545            .map(std::ops::Deref::deref)
546            .unwrap_or_default()
547    }
548}
549
550/// Module methods.
551impl DfirGraph {
552    /// When modules are imported into a flat graph, they come with an input and output ModuleBoundary node.
553    /// The partitioner doesn't understand these nodes and will panic if it encounters them.
554    /// merge_modules removes them from the graph, stitching the input and ouput sides of the ModuleBondaries based on their ports
555    /// For example:
556    ///     source_iter([]) -> \[myport\]ModuleBoundary(input)\[my_port\] -> map(|x| x) -> ModuleBoundary(output) -> null();
557    /// in the above eaxmple, the \[myport\] port will be used to connect the source_iter with the map that is inside of the module.
558    /// The output module boundary has elided ports, this is also used to match up the input/output across the module boundary.
559    pub fn merge_modules(&mut self) -> Result<(), Diagnostic> {
560        let mod_bound_nodes = self
561            .nodes()
562            .filter(|(_nid, node)| matches!(node, GraphNode::ModuleBoundary { .. }))
563            .map(|(nid, _node)| nid)
564            .collect::<Vec<_>>();
565
566        for mod_bound_node in mod_bound_nodes {
567            self.remove_module_boundary(mod_bound_node)?;
568        }
569
570        Ok(())
571    }
572
573    /// see `merge_modules`
574    /// This function removes a singular module boundary from the graph and performs the necessary stitching to fix the graph afterward.
575    /// `merge_modules` calls this function for each module boundary in the graph.
576    fn remove_module_boundary(&mut self, mod_bound_node: GraphNodeId) -> Result<(), Diagnostic> {
577        assert!(
578            self.node_subgraph.is_empty() && self.subgraph_nodes.is_empty(),
579            "Should not remove intermediate node after subgraph partitioning"
580        );
581
582        let mut mod_pred_ports = BTreeMap::new();
583        let mut mod_succ_ports = BTreeMap::new();
584
585        for mod_out_edge in self.node_predecessor_edges(mod_bound_node) {
586            let (pred_port, succ_port) = self.edge_ports(mod_out_edge);
587            mod_pred_ports.insert(succ_port.clone(), (mod_out_edge, pred_port.clone()));
588        }
589
590        for mod_inn_edge in self.node_successor_edges(mod_bound_node) {
591            let (pred_port, succ_port) = self.edge_ports(mod_inn_edge);
592            mod_succ_ports.insert(pred_port.clone(), (mod_inn_edge, succ_port.clone()));
593        }
594
595        if mod_pred_ports.keys().collect::<BTreeSet<_>>()
596            != mod_succ_ports.keys().collect::<BTreeSet<_>>()
597        {
598            // get module boundary node
599            let GraphNode::ModuleBoundary { input, import_expr } = self.node(mod_bound_node) else {
600                panic!();
601            };
602
603            if *input {
604                return Err(Diagnostic {
605                    span: *import_expr,
606                    level: Level::Error,
607                    message: format!(
608                        "The ports into the module did not match. input: {:?}, expected: {:?}",
609                        mod_pred_ports.keys().map(|x| x.to_string()).join(", "),
610                        mod_succ_ports.keys().map(|x| x.to_string()).join(", ")
611                    ),
612                });
613            } else {
614                return Err(Diagnostic {
615                    span: *import_expr,
616                    level: Level::Error,
617                    message: format!(
618                        "The ports out of the module did not match. output: {:?}, expected: {:?}",
619                        mod_succ_ports.keys().map(|x| x.to_string()).join(", "),
620                        mod_pred_ports.keys().map(|x| x.to_string()).join(", "),
621                    ),
622                });
623            }
624        }
625
626        for (port, (pred_edge, pred_port)) in mod_pred_ports {
627            let (succ_edge, succ_port) = mod_succ_ports.remove(&port).unwrap();
628
629            let (src, _) = self.edge(pred_edge);
630            let (_, dst) = self.edge(succ_edge);
631            self.remove_edge(pred_edge);
632            self.remove_edge(succ_edge);
633
634            let new_edge_id = self.graph.insert_edge(src, dst);
635            self.ports.insert(new_edge_id, (pred_port, succ_port));
636        }
637
638        self.graph.remove_vertex(mod_bound_node);
639        self.nodes.remove(mod_bound_node);
640
641        Ok(())
642    }
643}
644
645/// Edge methods.
646impl DfirGraph {
647    /// Get the `src` and `dst` for an edge: `(src GraphNodeId, dst GraphNodeId)`.
648    pub fn edge(&self, edge_id: GraphEdgeId) -> (GraphNodeId, GraphNodeId) {
649        let (src, dst) = self.graph.edge(edge_id).expect("Edge not found.");
650        (src, dst)
651    }
652
653    /// Get the source and destination ports for an edge: `(src &PortIndexValue, dst &PortIndexValue)`.
654    pub fn edge_ports(&self, edge_id: GraphEdgeId) -> (&PortIndexValue, &PortIndexValue) {
655        let (src_port, dst_port) = self.ports.get(edge_id).expect("Edge not found.");
656        (src_port, dst_port)
657    }
658
    /// Iterator over the `GraphEdgeId` of every edge currently in the graph.
    pub fn edge_ids(&self) -> slotmap::basic::Keys<'_, GraphEdgeId, (GraphNodeId, GraphNodeId)> {
        self.graph.edge_ids()
    }
663
    /// Iterator over all edges with their endpoints:
    /// `(GraphEdgeId, (src GraphNodeId, dst GraphNodeId))`.
    pub fn edges(
        &self,
    ) -> impl '_
    + ExactSizeIterator<Item = (GraphEdgeId, (GraphNodeId, GraphNodeId))>
    + FusedIterator
    + Clone
    + Debug {
        self.graph.edges()
    }
674
675    /// Insert an edge between nodes thru the given ports.
676    pub fn insert_edge(
677        &mut self,
678        src: GraphNodeId,
679        src_port: PortIndexValue,
680        dst: GraphNodeId,
681        dst_port: PortIndexValue,
682    ) -> GraphEdgeId {
683        let edge_id = self.graph.insert_edge(src, dst);
684        self.ports.insert(edge_id, (src_port, dst_port));
685        edge_id
686    }
687
688    /// Removes an edge and its corresponding ports and edge type info.
689    pub fn remove_edge(&mut self, edge: GraphEdgeId) {
690        let (_src, _dst) = self.graph.remove_edge(edge).unwrap();
691        let (_src_port, _dst_port) = self.ports.remove(edge).unwrap();
692    }
693}
694
695/// Subgraph methods.
696impl DfirGraph {
    /// Nodes belonging to the given subgraph.
    ///
    /// # Panics
    /// Panics with "Subgraph not found." if `subgraph_id` is not in the graph.
    pub fn subgraph(&self, subgraph_id: GraphSubgraphId) -> &Vec<GraphNodeId> {
        self.subgraph_nodes
            .get(subgraph_id)
            .expect("Subgraph not found.")
    }
703
    /// Iterator over the `GraphSubgraphId` of every subgraph currently in the graph.
    pub fn subgraph_ids(&self) -> slotmap::basic::Keys<'_, GraphSubgraphId, Vec<GraphNodeId>> {
        self.subgraph_nodes.keys()
    }
708
    /// Iterator over all subgraphs, ID and members: `(GraphSubgraphId, &Vec<GraphNodeId>)`.
    pub fn subgraphs(&self) -> slotmap::basic::Iter<'_, GraphSubgraphId, Vec<GraphNodeId>> {
        self.subgraph_nodes.iter()
    }
713
714    /// Create a subgraph consisting of `node_ids`. Returns an error if any of the nodes are already in a subgraph.
715    pub fn insert_subgraph(
716        &mut self,
717        node_ids: Vec<GraphNodeId>,
718    ) -> Result<GraphSubgraphId, (GraphNodeId, GraphSubgraphId)> {
719        // Check none are already in subgraphs
720        for &node_id in node_ids.iter() {
721            if let Some(&old_sg_id) = self.node_subgraph.get(node_id) {
722                return Err((node_id, old_sg_id));
723            }
724        }
725        let subgraph_id = self.subgraph_nodes.insert_with_key(|sg_id| {
726            for &node_id in node_ids.iter() {
727                self.node_subgraph.insert(node_id, sg_id);
728            }
729            node_ids
730        });
731
732        Ok(subgraph_id)
733    }
734
735    /// Removes a node from its subgraph. Returns true if the node was in a subgraph.
736    pub fn remove_from_subgraph(&mut self, node_id: GraphNodeId) -> bool {
737        if let Some(old_sg_id) = self.node_subgraph.remove(node_id) {
738            self.subgraph_nodes[old_sg_id].retain(|&other_node_id| other_node_id != node_id);
739            true
740        } else {
741            false
742        }
743    }
744
    /// Gets the delay type for a handoff node, if set. Returns `None` when no delay type has
    /// been recorded for `node_id`.
    pub fn handoff_delay_type(&self, node_id: GraphNodeId) -> Option<DelayType> {
        self.handoff_delay_type.get(node_id).copied()
    }
749
    /// Sets the delay type for a handoff node, replacing any delay type previously recorded for
    /// `node_id`.
    pub fn set_handoff_delay_type(&mut self, node_id: GraphNodeId, delay_type: DelayType) {
        self.handoff_delay_type.insert(node_id, delay_type);
    }
754
755    /// Helper: finds the first index in `subgraph_nodes` where it transitions from pull to push.
756    fn find_pull_to_push_idx(&self, subgraph_nodes: &[GraphNodeId]) -> usize {
757        subgraph_nodes
758            .iter()
759            .position(|&node_id| {
760                self.node_color(node_id)
761                    .is_some_and(|color| Color::Pull != color)
762            })
763            .unwrap_or(subgraph_nodes.len())
764    }
765}
766
767/// Display/output methods.
768impl DfirGraph {
769    /// Helper to generate a deterministic `Ident` for the given node.
770    fn node_as_ident(&self, node_id: GraphNodeId, is_pred: bool) -> Ident {
771        let name = match &self.nodes[node_id] {
772            GraphNode::Operator(_) => format!("op_{:?}", node_id.data()),
773            GraphNode::Handoff {
774                kind: HandoffKind::Vec,
775                ..
776            } => format!(
777                "hoff_{:?}_{}",
778                node_id.data(),
779                if is_pred { "recv" } else { "send" }
780            ),
781            GraphNode::Handoff {
782                kind: HandoffKind::Option,
783                ..
784            } => format!(
785                "singleton_{:?}_{}",
786                node_id.data(),
787                if is_pred { "recv" } else { "send" }
788            ),
789            GraphNode::ModuleBoundary { .. } => panic!(),
790        };
791        let span = match (is_pred, &self.nodes[node_id]) {
792            (_, GraphNode::Operator(operator)) => operator.span(),
793            (true, &GraphNode::Handoff { src_span, .. }) => src_span,
794            (false, &GraphNode::Handoff { dst_span, .. }) => dst_span,
795            (_, GraphNode::ModuleBoundary { .. }) => panic!(),
796        };
797        Ident::new(&name, span)
798    }
799
800    /// Helper to generate the main buffer `Ident` for a handoff node.
801    fn hoff_buf_ident(&self, hoff_id: GraphNodeId, span: Span) -> Ident {
802        Ident::new(&format!("hoff_{:?}_buf", hoff_id.data()), span)
803    }
804
805    /// Helper to generate the back (double-buffer) `Ident` for a handoff node.
806    fn hoff_back_ident(&self, hoff_id: GraphNodeId, span: Span) -> Ident {
807        Ident::new(&format!("hoff_{:?}_back", hoff_id.data()), span)
808    }
809
810    /// For per-node singleton references. Helper to generate a deterministic `Ident` for the given node.
811    fn node_as_singleton_ident(&self, node_id: GraphNodeId, span: Span) -> Ident {
812        Ident::new(&format!("singleton_op_{:?}", node_id.data()), span)
813    }
814
815    /// Resolve the singletons via [`Self::node_singleton_references`] for the given `node_id`.
816    /// Returns token streams for each reference:
817    /// - For stateful operators: `&singleton_op_XXX` (borrow the operator's state)
818    /// - For HandoffKind::Option: `&(hoff_XXX_buf.as_ref().unwrap())` (intentionally produce `&&T`
819    ///   so the later `(*expr)` deref yields `&T`) - TODO(mingwei)
820    fn helper_resolve_singletons(&self, node_id: GraphNodeId, span: Span) -> Vec<TokenStream> {
821        self.node_singleton_references(node_id)
822            .iter()
823            .map(|singleton_node_id| {
824                // TODO(mingwei): this `expect` should be caught in error checking
825                let ref_node_id = singleton_node_id
826                    .expect("Expected singleton to be resolved but was not, this is a bug.");
827                if matches!(
828                    self.node(ref_node_id),
829                    GraphNode::Handoff {
830                        kind: HandoffKind::Option,
831                        ..
832                    }
833                ) {
834                    let buf_ident = self.hoff_buf_ident(ref_node_id, span);
835                    // Wrapping in &(...) produces &&T so that postprocess_singletons'
836                    // (*expr) deref gives &T — matching `type O = &'a T`.
837                    // TODO(mingwei): Make postprocess_singletons not deref, remove old singletons (the `else` case below).
838                    quote_spanned! {span=> &(#buf_ident.as_ref().unwrap()) }
839                } else {
840                    let singleton_ident = self.node_as_singleton_ident(ref_node_id, span);
841                    quote_spanned! {span=> &#singleton_ident }
842                }
843            })
844            .collect::<Vec<_>>()
845    }
846
847    /// Returns each subgraph's receive and send handoffs.
848    /// `Map<GraphSubgraphId, (recv handoffs, send handoffs)>`
849    fn helper_collect_subgraph_handoffs(
850        &self,
851    ) -> SecondaryMap<GraphSubgraphId, (Vec<GraphNodeId>, Vec<GraphNodeId>)> {
852        // Get data on handoff src and dst subgraphs.
853        let mut subgraph_handoffs: SecondaryMap<
854            GraphSubgraphId,
855            (Vec<GraphNodeId>, Vec<GraphNodeId>),
856        > = self
857            .subgraph_nodes
858            .keys()
859            .map(|k| (k, Default::default()))
860            .collect();
861
862        // For each handoff/singleton node, add it to the `send`/`recv` lists for the corresponding subgraphs.
863        for (hoff_id, hoff) in self.nodes() {
864            if !matches!(hoff, GraphNode::Handoff { .. }) {
865                continue;
866            }
867            // Receivers from the handoff. (Should really only be one).
868            for (_edge, succ_id) in self.node_successors(hoff_id) {
869                let succ_sg = self.node_subgraph(succ_id).unwrap();
870                subgraph_handoffs[succ_sg].0.push(hoff_id);
871            }
872            // Senders into the handoff. (Should really only be one).
873            for (_edge, pred_id) in self.node_predecessors(hoff_id) {
874                let pred_sg = self.node_subgraph(pred_id).unwrap();
875                subgraph_handoffs[pred_sg].1.push(hoff_id);
876            }
877        }
878
879        subgraph_handoffs
880    }
881
882    /// Emit this graph as runnable Rust source code tokens that execute inline.
883    /// Generates a flat `async move |df: &mut Context|` closure where subgraph
884    /// blocks are inlined in topological order, using local `Vec<T>` buffers
885    /// instead of runtime handoffs. Each call to the closure runs one tick.
886    ///
887    /// The generated code block evaluates to a `Dfir` instance wrapping the
888    /// closure. Operator prologues run at construction time on the `Context`
889    /// before it is moved into `Dfir::new`. `Dfir` provides the `Context`
890    /// to the closure on each tick run.
891    ///
892    /// # Errors
893    ///
894    /// Returns all diagnostics as `Err(diagnostics)` if any are errors
895    /// (leaving `&mut diagnostics` empty).
896    pub fn as_code(
897        &self,
898        root: &TokenStream,
899        include_type_guards: bool,
900        prefix: TokenStream,
901        diagnostics: &mut Diagnostics,
902    ) -> Result<TokenStream, Diagnostics> {
903        self.as_code_with_options(root, include_type_guards, true, prefix, diagnostics)
904    }
905
906    /// Like [`Self::as_code`], but with `include_meta` controlling whether
907    /// the runtime meta graph + diagnostics JSON blobs are baked into the
908    /// generated `Dfir::new(...)` call.
909    ///
    /// The simulator calls `Dfir::new()` on each iteration, and as part of that it parses the
    /// meta graph and diagnostics blobs. Parsing one of them causes spans to be allocated; each
    /// time a span is allocated, a thread-local `u32` is incremented, and on a long simulator
    /// run that `u32` overflows and panics. Leaving the blobs out (`include_meta == false`) is
    /// meant to avoid this.
914    pub fn as_code_with_options(
915        &self,
916        root: &TokenStream,
917        include_type_guards: bool,
918        include_meta: bool,
919        prefix: TokenStream,
920        diagnostics: &mut Diagnostics,
921    ) -> Result<TokenStream, Diagnostics> {
922        let df = Ident::new(GRAPH, Span::call_site());
923        let context = Ident::new(CONTEXT, Span::call_site());
924
925        // 1. Generate local buffers for each handoff node (Vec for streams, Option for singletons).
926        let handoff_nodes: Vec<_> = self
927            .nodes
928            .iter()
929            .filter_map(|(node_id, node)| match node {
930                &GraphNode::Handoff {
931                    kind,
932                    src_span,
933                    dst_span,
934                } => Some((node_id, kind, (src_span, dst_span))),
935                GraphNode::Operator(_) => None,
936                GraphNode::ModuleBoundary { .. } => panic!(),
937            })
938            .collect();
939
940        let buffer_code: Vec<TokenStream> = handoff_nodes
941            .iter()
942            .map(|&(node_id, kind, (src_span, dst_span))| {
943                let span = src_span.join(dst_span).unwrap_or(src_span);
944                let buf_ident = self.hoff_buf_ident(node_id, span);
945                match kind {
946                    HandoffKind::Vec => quote_spanned! {span=>
947                        let mut #buf_ident = ::std::vec::Vec::new();
948                    },
949                    HandoffKind::Option => quote_spanned! {span=>
950                        let mut #buf_ident = ::std::option::Option::None;
951                    },
952                }
953            })
954            .collect();
955
956        // For tick-boundary handoffs (`defer_tick` / `defer_tick_lazy`), declare a
957        // second "back" buffer for double-buffering. At the start of each tick, the
958        // main buffer and back buffer are swapped so the consumer reads last tick's
959        // data while the producer writes to a fresh buffer.
960        let back_buffer_code: Vec<TokenStream> = handoff_nodes
961            .iter()
962            .filter(|(node_id, _kind, _)| self.handoff_delay_type(*node_id).is_some())
963            .map(|&(node_id, kind, (src_span, dst_span))| {
964                assert!(
965                    matches!(kind, HandoffKind::Vec),
966                    "bug: only Vec handoffs should have delay types"
967                );
968                let span = src_span.join(dst_span).unwrap_or(src_span);
969                let back_ident = self.hoff_back_ident(node_id, span);
970                quote_spanned! {span=>
971                    let mut #back_ident: Vec<_> = Vec::new();
972                }
973            })
974            .collect();
975
976        // 2. Collect subgraph handoffs (same as as_code).
977        let subgraph_handoffs = self.helper_collect_subgraph_handoffs();
978
979        // 3. Sort subgraphs topologically and collect non-lazy defer_tick buffer idents.
980        //
981        // Handoffs marked with a `DelayType` (Tick/TickLazy) are tick-boundary back-edges.
982        // These are excluded from the topo sort (no ordering constraint). Double-buffering
983        // ensures data written by the producer in tick N is only visible to the consumer
984        // in tick N+1, regardless of execution order.
985        //
986        // While iterating handoffs, we also collect buffer idents for non-lazy tick-boundary
987        // edges (defer_tick). When these buffers are non-empty at end of tick, we set
988        // can_start_tick so that run_available continues ticking.
989        let mut defer_tick_buf_idents: Vec<Ident> = Vec::new();
990        let mut back_edge_hoff_ids: BTreeSet<GraphNodeId> = BTreeSet::new();
991        let all_subgraphs = {
992            // Build predecessor map for subgraphs.
993            let mut sg_preds: SecondaryMap<GraphSubgraphId, Vec<GraphSubgraphId>> =
994                SecondaryMap::<_, Vec<_>>::with_capacity(self.subgraph_nodes.len());
995            for (hoff_id, hoff) in self.nodes() {
996                if !matches!(hoff, GraphNode::Handoff { .. }) {
997                    // Not a handoff; skip.
998                    continue;
999                }
1000                if 0 == self.node_successors(hoff_id).len() {
1001                    // Is a handoff only used by reference, not consumed.
1002                    continue;
1003                }
1004                assert_eq!(1, self.node_successors(hoff_id).len());
1005                assert_eq!(1, self.node_predecessors(hoff_id).len());
1006                let (_edge_id, pred) = self.node_predecessors(hoff_id).next().unwrap();
1007                let (_edge_id, succ) = self.node_successors(hoff_id).next().unwrap();
1008                let pred_sg = self.node_subgraph(pred).unwrap();
1009                let succ_sg = self.node_subgraph(succ).unwrap();
1010                if pred_sg == succ_sg {
1011                    panic!("bug: unexpected subgraph self-handoff cycle");
1012                }
1013                if let Some(delay_type) = self.handoff_delay_type(hoff_id) {
1014                    debug_assert!(matches!(delay_type, DelayType::Tick | DelayType::TickLazy));
1015                    // Tick/back-edge handoff: no ordering constraint. Double-buffering
1016                    // handles the tick deferral regardless of execution order.
1017                    back_edge_hoff_ids.insert(hoff_id);
1018
1019                    // Non-lazy tick-boundary: defer_tick (not defer_tick_lazy).
1020                    if !matches!(delay_type, DelayType::TickLazy) {
1021                        defer_tick_buf_idents.push(self.hoff_buf_ident(hoff_id, hoff.span()));
1022                    }
1023                } else {
1024                    sg_preds.entry(succ_sg).unwrap().or_default().push(pred_sg);
1025                }
1026            }
1027
1028            // Include singleton reference edges: if node A references the
1029            // singleton output of node B, then A's subgraph must run after B's.
1030            for dst_id in self.node_ids() {
1031                for src_ref_id in self
1032                    .node_singleton_references(dst_id)
1033                    .iter()
1034                    .copied()
1035                    .flatten()
1036                {
1037                    // For handoff nodes (no subgraph), use the predecessor's subgraph.
1038                    let src_sg = if let Some(sg) = self.node_subgraph(src_ref_id) {
1039                        sg
1040                    } else {
1041                        let (_edge, pred) = self
1042                            .node_predecessors(src_ref_id)
1043                            .next()
1044                            .expect("handoff must have a predecessor");
1045                        self.node_subgraph(pred).unwrap()
1046                    };
1047                    let dst_sg = self
1048                        .node_subgraph(dst_id)
1049                        .expect("bug: singleton ref consumer must belong to a subgraph");
1050                    if src_sg != dst_sg {
1051                        sg_preds.entry(dst_sg).unwrap().or_default().push(src_sg);
1052                    }
1053
1054                    // Ensure the borrower runs before the pipe consumer
1055                    // (which takes/drains the value).
1056                    // All handoffs should have at most one successor.
1057                    if self.node_subgraph(src_ref_id).is_none() {
1058                        assert!(
1059                            self.node_degree_out(src_ref_id) <= 1,
1060                            "handoff should have at most one successor"
1061                        );
1062                        if let Some((_edge, succ_id)) = self.node_successors(src_ref_id).next()
1063                            && let Some(consumer_sg) = self.node_subgraph(succ_id)
1064                            && consumer_sg != dst_sg
1065                        {
1066                            sg_preds
1067                                .entry(consumer_sg)
1068                                .unwrap()
1069                                .or_default()
1070                                .push(dst_sg);
1071                        }
1072                    }
1073                }
1074            }
1075
1076            let topo_sort = super::graph_algorithms::topo_sort(self.subgraph_ids(), |sg_id| {
1077                sg_preds.get(sg_id).into_iter().flatten().copied()
1078            })
1079            .expect("bug: unexpected cycle between subgraphs within the tick");
1080
1081            topo_sort
1082                .into_iter()
1083                .map(|sg_id| (sg_id, self.subgraph(sg_id)))
1084                .collect::<Vec<_>>()
1085        };
1086
1087        // Generate swap code for tick-boundary (defer_tick / defer_tick_lazy) handoffs.
1088        // At the start of each tick, swap the main buffer and back buffer so the
1089        // consumer reads last tick's data from the back buffer.
1090        let back_edge_swap_code: Vec<TokenStream> = back_edge_hoff_ids
1091            .iter()
1092            .map(|&hoff_id| {
1093                let span = self.nodes[hoff_id].span();
1094                let buf_ident = self.hoff_buf_ident(hoff_id, span);
1095                let back_ident = self.hoff_back_ident(hoff_id, span);
1096                quote_spanned! {span=>
1097                    ::std::mem::swap(&mut #buf_ident, &mut #back_ident);
1098                }
1099            })
1100            .collect();
1101
1102        // Generate drain code for handoffs with no pipe consumer (0 successors).
1103        // These are only accessed via #var references and must be cleared each tick.
1104        let no_consumer_drain_code: Vec<TokenStream> = handoff_nodes
1105            .iter()
1106            .filter(|&&(node_id, _, _)| self.node_degree_out(node_id) == 0)
1107            .map(|&(node_id, kind, (src_span, dst_span))| {
1108                let span = src_span.join(dst_span).unwrap_or(src_span);
1109                let buf_ident = self.hoff_buf_ident(node_id, span);
1110                match kind {
1111                    HandoffKind::Option => quote_spanned! {span=> #buf_ident.take(); },
1112                    HandoffKind::Vec => quote_spanned! {span=> #buf_ident.clear(); },
1113                }
1114            })
1115            .collect();
1116
1117        let mut op_prologue_code = Vec::new();
1118        let mut op_tick_end_code = Vec::new();
1119        let mut subgraph_blocks = Vec::new();
1120        {
1121            for &(subgraph_id, subgraph_nodes) in all_subgraphs.iter() {
1122                let sg_metrics_ffi = subgraph_id.data().as_ffi();
1123                let (recv_hoffs, send_hoffs) = &subgraph_handoffs[subgraph_id];
1124
1125                // Generate buffer ident helpers for this subgraph's handoffs.
1126                let recv_port_idents: Vec<Ident> = recv_hoffs
1127                    .iter()
1128                    .map(|&hoff_id| self.node_as_ident(hoff_id, true))
1129                    .collect();
1130                let send_port_idents: Vec<Ident> = send_hoffs
1131                    .iter()
1132                    .map(|&hoff_id| self.node_as_ident(hoff_id, false))
1133                    .collect();
1134
1135                // Map handoff node IDs to buffer idents.
1136                let recv_buf_idents: Vec<Ident> = recv_hoffs
1137                    .iter()
1138                    .map(|&hoff_id| self.hoff_buf_ident(hoff_id, self.nodes[hoff_id].span()))
1139                    .collect();
1140                let send_buf_idents: Vec<Ident> = send_hoffs
1141                    .iter()
1142                    .map(|&hoff_id| self.hoff_buf_ident(hoff_id, self.nodes[hoff_id].span()))
1143                    .collect();
1144
1145                // Recv port code: drain from buffer into iterator, tracking if non-empty.
1146                // For back-edge (defer_tick) handoffs, drain from the back buffer instead.
1147                // Also update handoff metrics (measured at recv, not send — see graph.rs).
1148                let recv_port_code: Vec<TokenStream> = recv_port_idents
1149                    .iter()
1150                    .zip(recv_buf_idents.iter())
1151                    .zip(recv_hoffs.iter())
1152                    .map(|((port_ident, buf_ident), &hoff_id)| {
1153                        let hoff_ffi = hoff_id.data().as_ffi();
1154                        // Use call_site span for internal identifiers to avoid
1155                        // hygiene issues when invoked through declarative macros
1156                        // (e.g. dfir_expect_warnings!). TODO(#2781): define these once.
1157                        let work_done = Ident::new("__dfir_work_done", Span::call_site());
1158                        let metrics = Ident::new("__dfir_metrics", Span::call_site());
1159
1160                        let GraphNode::Handoff { kind, .. } = self.node(hoff_id) else {
1161                            unreachable!()
1162                        };
1163
1164                        // Compute len and drain expressions based on handoff kind.
1165                        let (len_expr, drain_expr) = match kind {
1166                            HandoffKind::Option => (
1167                                quote! { if #buf_ident.is_some() { 1usize } else { 0usize } },
1168                                quote! { #root::dfir_pipes::pull::iter(#buf_ident.take().into_iter()) },
1169                            ),
1170                            HandoffKind::Vec => {
1171                                let drain_ident = if back_edge_hoff_ids.contains(&hoff_id) {
1172                                    self.hoff_back_ident(hoff_id, buf_ident.span())
1173                                } else {
1174                                    buf_ident.clone()
1175                                };
1176                                (
1177                                    quote! { #drain_ident.len() },
1178                                    quote! { #root::dfir_pipes::pull::iter(#drain_ident.drain(..)) },
1179                                )
1180                            }
1181                        };
1182
1183                        quote_spanned! {port_ident.span()=>
1184                            {
1185                                let hoff_len = #len_expr;
1186                                if hoff_len > 0 {
1187                                    #work_done = true;
1188                                }
1189                                let hoff_metrics = &#metrics.handoffs[
1190                                    #root::slotmap::KeyData::from_ffi(#hoff_ffi).into()
1191                                ];
1192                                hoff_metrics.total_items_count.update(|x| x + hoff_len);
1193                                hoff_metrics.curr_items_count.set(hoff_len);
1194                            }
1195                            let #port_ident = #drain_expr;
1196                        }
1197                    })
1198                    .collect();
1199
1200                // Send port code: push into buffer.
1201                let send_port_code: Vec<TokenStream> = send_port_idents
1202                    .iter()
1203                    .zip(send_buf_idents.iter())
1204                    .zip(send_hoffs.iter())
1205                    .map(|((port_ident, buf_ident), &hoff_id)| {
1206                        let GraphNode::Handoff { kind, .. } = self.node(hoff_id) else {
1207                            unreachable!()
1208                        };
1209                        match kind {
1210                            HandoffKind::Option => {
1211                                // Singleton slot: store exactly one item, panic on duplicate.
1212                                quote_spanned! {port_ident.span()=>
1213                                    let #port_ident = #root::dfir_pipes::push::for_each(|__item| {
1214                                        if #buf_ident.replace(__item).is_some() {
1215                                            panic!("singleton() received more than one item");
1216                                        }
1217                                    });
1218                                }
1219                            }
1220                            HandoffKind::Vec => {
1221                                quote_spanned! {port_ident.span()=>
1222                                    let #port_ident = #root::dfir_pipes::push::vec_push(&mut #buf_ident);
1223                                }
1224                            }
1225                        }
1226                    })
1227                    .collect();
1228
1229                // All nodes in a subgraph should be in the same loop.
1230                let loop_id = self.node_loop(subgraph_nodes[0]);
1231
1232                let mut subgraph_op_iter_code = Vec::new();
1233                let mut subgraph_op_iter_after_code = Vec::new();
1234                {
1235                    let pull_to_push_idx = self.find_pull_to_push_idx(subgraph_nodes);
1236
1237                    let (pull_half, push_half) = subgraph_nodes.split_at(pull_to_push_idx);
1238                    let nodes_iter = pull_half.iter().chain(push_half.iter().rev());
1239
1240                    for (idx, &node_id) in nodes_iter.enumerate() {
1241                        let node = &self.nodes[node_id];
1242                        assert!(
1243                            matches!(node, GraphNode::Operator(_)),
1244                            "Handoffs are not part of subgraphs."
1245                        );
1246                        let op_inst = &self.operator_instances[node_id];
1247
1248                        let op_span = node.span();
1249                        let op_name = op_inst.op_constraints.name;
1250                        // Use op's span for root. #root is expected to be correct, any errors should span back to the op gen.
1251                        let root = change_spans(root.clone(), op_span);
1252                        let op_constraints = OPERATORS
1253                            .iter()
1254                            .find(|op| op_name == op.name)
1255                            .unwrap_or_else(|| panic!("Failed to find op: {}", op_name));
1256
1257                        let ident = self.node_as_ident(node_id, false);
1258
1259                        {
1260                            // TODO clean this up.
1261                            // Collect input arguments (predecessors).
1262                            let mut input_edges = self
1263                                .graph
1264                                .predecessor_edges(node_id)
1265                                .map(|edge_id| (self.edge_ports(edge_id).1, edge_id))
1266                                .collect::<Vec<_>>();
1267                            // Ensure sorted by port index.
1268                            input_edges.sort();
1269
1270                            let inputs = input_edges
1271                                .iter()
1272                                .map(|&(_port, edge_id)| {
1273                                    let (pred, _) = self.edge(edge_id);
1274                                    self.node_as_ident(pred, true)
1275                                })
1276                                .collect::<Vec<_>>();
1277
1278                            // Collect output arguments (successors).
1279                            let mut output_edges = self
1280                                .graph
1281                                .successor_edges(node_id)
1282                                .map(|edge_id| (&self.ports[edge_id].0, edge_id))
1283                                .collect::<Vec<_>>();
1284                            // Ensure sorted by port index.
1285                            output_edges.sort();
1286
1287                            let outputs = output_edges
1288                                .iter()
1289                                .map(|&(_port, edge_id)| {
1290                                    let (_, succ) = self.edge(edge_id);
1291                                    self.node_as_ident(succ, false)
1292                                })
1293                                .collect::<Vec<_>>();
1294
1295                            let is_pull = idx < pull_to_push_idx;
1296
1297                            let singleton_output_ident = &if op_constraints.has_singleton_output {
1298                                self.node_as_singleton_ident(node_id, op_span)
1299                            } else {
1300                                // This ident *should* go unused.
1301                                Ident::new(&format!("{}_has_no_singleton_output", op_name), op_span)
1302                            };
1303
1304                            // There's a bit of dark magic hidden in `Span`s... you'd think it's just a `file:line:column`,
1305                            // but it has one extra bit of info for _name resolution_, used for `Ident`s. `Span::call_site()`
1306                            // has the (unhygienic) resolution we want, an ident is just solely determined by its string name,
1307                            // which is what you'd expect out of unhygienic proc macros like this. Meanwhile, declarative macros
1308                            // use `Span::mixed_site()` which is weird and I don't understand it. It turns out that if you call
1309                            // the dfir syntax proc macro from _within_ a declarative macro then `op_span` will have the
1310                            // bad `Span::mixed_site()` name resolution and cause "Cannot find value `df/context`" errors. So
1311                            // we call `.resolved_at()` to fix resolution back to `Span::call_site()`. -Mingwei
1312                            let df_local = &Ident::new(GRAPH, op_span.resolved_at(df.span()));
1313                            let context = &Ident::new(CONTEXT, op_span.resolved_at(context.span()));
1314
1315                            let singletons_resolved =
1316                                self.helper_resolve_singletons(node_id, op_span);
1317
1318                            let arguments = &process_singletons::postprocess_singletons(
1319                                op_inst.arguments_raw.clone(),
1320                                singletons_resolved,
1321                            );
1322
1323                            let source_tag = 'a: {
1324                                if let Some(tag) = self.operator_tag.get(node_id).cloned() {
1325                                    break 'a tag;
1326                                }
1327
1328                                if proc_macro::is_available() {
1329                                    let op_span = op_span.unwrap();
1330                                    break 'a format!(
1331                                        "loc_{}_{}_{}_{}_{}",
1332                                        crate::pretty_span::make_source_path_relative(
1333                                            &op_span.file()
1334                                        )
1335                                        .display()
1336                                        .to_string()
1337                                        .replace(|x: char| !x.is_ascii_alphanumeric(), "_"),
1338                                        op_span.start().line(),
1339                                        op_span.start().column(),
1340                                        op_span.end().line(),
1341                                        op_span.end().column(),
1342                                    );
1343                                }
1344
1345                                format!(
1346                                    "loc_nopath_{}_{}_{}_{}",
1347                                    op_span.start().line,
1348                                    op_span.start().column,
1349                                    op_span.end().line,
1350                                    op_span.end().column
1351                                )
1352                            };
1353
1354                            let work_fn = format_ident!(
1355                                "{}__{}__{}",
1356                                ident,
1357                                op_name,
1358                                source_tag,
1359                                span = op_span
1360                            );
1361                            let work_fn_async = format_ident!("{}__async", work_fn, span = op_span);
1362
1363                            let context_args = WriteContextArgs {
1364                                root: &root,
1365                                df_ident: df_local,
1366                                context,
1367                                subgraph_id,
1368                                node_id,
1369                                loop_id,
1370                                op_span,
1371                                op_tag: self.operator_tag.get(node_id).cloned(),
1372                                work_fn: &work_fn,
1373                                work_fn_async: &work_fn_async,
1374                                ident: &ident,
1375                                is_pull,
1376                                inputs: &inputs,
1377                                outputs: &outputs,
1378                                singleton_output_ident,
1379                                op_name,
1380                                op_inst,
1381                                arguments,
1382                            };
1383
1384                            let write_result =
1385                                (op_constraints.write_fn)(&context_args, diagnostics);
1386                            let OperatorWriteOutput {
1387                                write_prologue,
1388                                write_iterator,
1389                                write_iterator_after,
1390                                write_tick_end,
1391                            } = write_result.unwrap_or_else(|()| {
1392                                assert!(
1393                                    diagnostics.has_error(),
1394                                    "Operator `{}` returned `Err` but emitted no diagnostics, this is a bug.",
1395                                    op_name,
1396                                );
1397                                OperatorWriteOutput {
1398                                    write_iterator: null_write_iterator_fn(&context_args),
1399                                    ..Default::default()
1400                                }
1401                            });
1402
1403                            op_prologue_code.push(syn::parse_quote! {
1404                                #[allow(non_snake_case)]
1405                                #[inline(always)]
1406                                fn #work_fn<T>(thunk: impl ::std::ops::FnOnce() -> T) -> T {
1407                                    thunk()
1408                                }
1409
1410                                #[allow(non_snake_case)]
1411                                #[inline(always)]
1412                                async fn #work_fn_async<T>(
1413                                    thunk: impl ::std::future::Future<Output = T>,
1414                                ) -> T {
1415                                    thunk.await
1416                                }
1417                            });
1418                            op_prologue_code.push(write_prologue);
1419                            op_tick_end_code.push(write_tick_end);
1420                            subgraph_op_iter_code.push(write_iterator);
1421
1422                            if include_type_guards {
1423                                let type_guard = if is_pull {
1424                                    quote_spanned! {op_span=>
1425                                        let #ident = {
1426                                            #[allow(non_snake_case)]
1427                                            #[inline(always)]
1428                                            pub fn #work_fn<Item, Input>(input: Input)
1429                                                -> impl #root::dfir_pipes::pull::Pull<Item = Item, Meta = (), CanPend = Input::CanPend, CanEnd = Input::CanEnd>
1430                                            where
1431                                                Input: #root::dfir_pipes::pull::Pull<Item = Item, Meta = ()>,
1432                                            {
1433                                                #root::pin_project_lite::pin_project! {
1434                                                    #[repr(transparent)]
1435                                                    struct Pull<Item, Input: #root::dfir_pipes::pull::Pull<Item = Item>> {
1436                                                        #[pin]
1437                                                        inner: Input
1438                                                    }
1439                                                }
1440
1441                                                impl<Item, Input> #root::dfir_pipes::pull::Pull for Pull<Item, Input>
1442                                                where
1443                                                    Input: #root::dfir_pipes::pull::Pull<Item = Item>,
1444                                                {
1445                                                    type Ctx<'ctx> = Input::Ctx<'ctx>;
1446
1447                                                    type Item = Item;
1448                                                    type Meta = Input::Meta;
1449                                                    type CanPend = Input::CanPend;
1450                                                    type CanEnd = Input::CanEnd;
1451
1452                                                    #[inline(always)]
1453                                                    fn pull(
1454                                                        self: ::std::pin::Pin<&mut Self>,
1455                                                        ctx: &mut Self::Ctx<'_>,
1456                                                    ) -> #root::dfir_pipes::pull::PullStep<Self::Item, Self::Meta, Self::CanPend, Self::CanEnd> {
1457                                                        #root::dfir_pipes::pull::Pull::pull(self.project().inner, ctx)
1458                                                    }
1459
1460                                                    #[inline(always)]
1461                                                    fn size_hint(&self) -> (usize, Option<usize>) {
1462                                                        #root::dfir_pipes::pull::Pull::size_hint(&self.inner)
1463                                                    }
1464                                                }
1465
1466                                                Pull {
1467                                                    inner: input
1468                                                }
1469                                            }
1470                                            #work_fn::<_, _>( #ident )
1471                                        };
1472                                    }
1473                                } else {
1474                                    quote_spanned! {op_span=>
1475                                        let #ident = {
1476                                            #[allow(non_snake_case)]
1477                                            #[inline(always)]
1478                                            pub fn #work_fn<Item, Psh>(psh: Psh) -> impl #root::dfir_pipes::push::Push<Item, (), CanPend = Psh::CanPend>
1479                                            where
1480                                                Psh: #root::dfir_pipes::push::Push<Item, ()>
1481                                            {
1482                                                #root::pin_project_lite::pin_project! {
1483                                                    #[repr(transparent)]
1484                                                    struct PushGuard<Psh> {
1485                                                        #[pin]
1486                                                        inner: Psh,
1487                                                    }
1488                                                }
1489
1490                                                impl<Item, Psh> #root::dfir_pipes::push::Push<Item, ()> for PushGuard<Psh>
1491                                                where
1492                                                    Psh: #root::dfir_pipes::push::Push<Item, ()>,
1493                                                {
1494                                                    type Ctx<'ctx> = Psh::Ctx<'ctx>;
1495
1496                                                    type CanPend = Psh::CanPend;
1497
1498                                                    #[inline(always)]
1499                                                    fn poll_ready(
1500                                                        self: ::std::pin::Pin<&mut Self>,
1501                                                        ctx: &mut Self::Ctx<'_>,
1502                                                    ) -> #root::dfir_pipes::push::PushStep<Self::CanPend> {
1503                                                        #root::dfir_pipes::push::Push::poll_ready(self.project().inner, ctx)
1504                                                    }
1505
1506                                                    #[inline(always)]
1507                                                    fn start_send(
1508                                                        self: ::std::pin::Pin<&mut Self>,
1509                                                        item: Item,
1510                                                        meta: (),
1511                                                    ) {
1512                                                        #root::dfir_pipes::push::Push::start_send(self.project().inner, item, meta)
1513                                                    }
1514
1515                                                    #[inline(always)]
1516                                                    fn poll_finalize(
1517                                                        self: ::std::pin::Pin<&mut Self>,
1518                                                        ctx: &mut Self::Ctx<'_>,
1519                                                    ) -> #root::dfir_pipes::push::PushStep<Self::CanPend> {
1520                                                        #root::dfir_pipes::push::Push::poll_finalize(self.project().inner, ctx)
1521                                                    }
1522
1523                                                    #[inline(always)]
1524                                                    fn size_hint(
1525                                                        self: ::std::pin::Pin<&mut Self>,
1526                                                        hint: (usize, Option<usize>),
1527                                                    ) {
1528                                                        #root::dfir_pipes::push::Push::size_hint(self.project().inner, hint)
1529                                                    }
1530                                                }
1531
1532                                                PushGuard {
1533                                                    inner: psh
1534                                                }
1535                                            }
1536                                            #work_fn( #ident )
1537                                        };
1538                                    }
1539                                };
1540                                subgraph_op_iter_code.push(type_guard);
1541                            }
1542                            subgraph_op_iter_after_code.push(write_iterator_after);
1543                        }
1544                    }
1545
1546                    {
1547                        // Determine pull and push halves of the `Pivot`.
1548                        let pull_ident = if 0 < pull_to_push_idx {
1549                            self.node_as_ident(subgraph_nodes[pull_to_push_idx - 1], false)
1550                        } else {
1551                            // Entire subgraph is push (with a single recv/pull handoff input).
1552                            recv_port_idents[0].clone()
1553                        };
1554
1555                        #[rustfmt::skip]
1556                        let push_ident = if let Some(&node_id) =
1557                            subgraph_nodes.get(pull_to_push_idx)
1558                        {
1559                            self.node_as_ident(node_id, false)
1560                        } else if 1 == send_port_idents.len() {
1561                            // Entire subgraph is pull (with a single send/push handoff output).
1562                            send_port_idents[0].clone()
1563                        } else {
1564                            diagnostics.push(Diagnostic::spanned(
1565                                pull_ident.span(),
1566                                Level::Error,
1567                                "Degenerate subgraph detected, is there a disconnected `null()` or other degenerate pipeline somewhere?",
1568                            ));
1569                            continue;
1570                        };
1571
1572                        // Pivot span is combination of pull and push spans (or if not possible, just take the push).
1573                        let pivot_span = pull_ident
1574                            .span()
1575                            .join(push_ident.span())
1576                            .unwrap_or_else(|| push_ident.span());
1577                        let pivot_fn_ident = Ident::new(
1578                            &format!("pivot_run_sg_{:?}", subgraph_id.data()),
1579                            pivot_span,
1580                        );
1581                        let root = change_spans(root.clone(), pivot_span);
1582                        subgraph_op_iter_code.push(quote_spanned! {pivot_span=>
1583                            #[inline(always)]
1584                            fn #pivot_fn_ident<Pul, Psh, Item>(pull: Pul, push: Psh)
1585                                -> impl ::std::future::Future<Output = ()>
1586                            where
1587                                Pul: #root::dfir_pipes::pull::Pull<Item = Item>,
1588                                Psh: #root::dfir_pipes::push::Push<Item, Pul::Meta>,
1589                            {
1590                                #root::dfir_pipes::pull::Pull::send_push(pull, push)
1591                            }
1592                            (#pivot_fn_ident)(#pull_ident, #push_ident).await;
1593                        });
1594                    }
1595                };
1596
1597                // Each subgraph block is an async block so it can be individually instrumented.
1598                // Note: this ident is for the subgraph future, not a runtime SubgraphId binding
1599                // (unlike the scheduled path's `sg_ident`).
1600                let sg_fut_ident = subgraph_id.as_ident(Span::call_site());
1601
1602                // Generate send-side curr_items_count updates (after subgraph runs).
1603                let send_metrics_code: Vec<TokenStream> = send_hoffs
1604                    .iter()
1605                    .zip(send_buf_idents.iter())
1606                    .map(|(&hoff_id, buf_ident)| {
1607                        let hoff_ffi = hoff_id.data().as_ffi();
1608                        let GraphNode::Handoff { kind, .. } = self.node(hoff_id) else {
1609                            unreachable!()
1610                        };
1611                        let len_expr = match kind {
1612                            HandoffKind::Option => {
1613                                quote! { if #buf_ident.is_some() { 1 } else { 0 } }
1614                            }
1615                            HandoffKind::Vec => {
1616                                quote! { #buf_ident.len() }
1617                            }
1618                        };
1619                        quote! {
1620                            __dfir_metrics.handoffs[
1621                                #root::slotmap::KeyData::from_ffi(#hoff_ffi).into()
1622                            ].curr_items_count.set(#len_expr);
1623                        }
1624                    })
1625                    .collect();
1626
1627                subgraph_blocks.push(quote! {
1628                    let #sg_fut_ident = async {
1629                        let #context = &#df;
1630                        #( #recv_port_code )*
1631                        #( #send_port_code )*
1632                        #( #subgraph_op_iter_code )*
1633                        #( #subgraph_op_iter_after_code )*
1634                    };
1635                    {
1636                        let sg_metrics = &__dfir_metrics.subgraphs[
1637                            #root::slotmap::KeyData::from_ffi(#sg_metrics_ffi).into()
1638                        ];
1639                        #root::scheduled::metrics::InstrumentSubgraph::new(
1640                            #sg_fut_ident, sg_metrics
1641                        ).await;
1642                        sg_metrics.total_run_count.update(|x| x + 1);
1643                    }
1644                    #( #send_metrics_code )*
1645                });
1646
1647                // Collect per-subgraph prologues into the main prologue lists.
1648                // (They are already pushed above in the operator loop.)
1649            }
1650        }
1651
1652        if diagnostics.has_error() {
1653            return Err(std::mem::take(diagnostics));
1654        }
1655        let _ = diagnostics; // Ensure no more diagnostics may be added after checking for errors.
1656
1657        let (meta_graph_arg, diagnostics_arg) = if include_meta {
1658            let meta_graph_json = serde_json::to_string(&self).unwrap();
1659            let meta_graph_json = Literal::string(&meta_graph_json);
1660
1661            let serde_diagnostics: Vec<_> = diagnostics.iter().map(Diagnostic::to_serde).collect();
1662            let diagnostics_json = serde_json::to_string(&*serde_diagnostics).unwrap();
1663            let diagnostics_json = Literal::string(&diagnostics_json);
1664
1665            (
1666                quote! { Some(#meta_graph_json) },
1667                quote! { Some(#diagnostics_json) },
1668            )
1669        } else {
1670            (quote! { None }, quote! { None })
1671        };
1672
1673        // Generate metrics initialization: one entry per handoff and per subgraph.
1674        let metrics_init_code = {
1675            let handoff_inits = handoff_nodes.iter().map(|&(node_id, _, _)| {
1676                let ffi = node_id.data().as_ffi();
1677                quote! {
1678                    dfir_metrics.handoffs.insert(
1679                        #root::slotmap::KeyData::from_ffi(#ffi).into(),
1680                        ::std::default::Default::default(),
1681                    );
1682                }
1683            });
1684            let subgraph_inits = all_subgraphs.iter().map(|&(sg_id, _)| {
1685                let ffi = sg_id.data().as_ffi();
1686                quote! {
1687                    dfir_metrics.subgraphs.insert(
1688                        #root::slotmap::KeyData::from_ffi(#ffi).into(),
1689                        ::std::default::Default::default(),
1690                    );
1691                }
1692            });
1693            handoff_inits.chain(subgraph_inits).collect::<Vec<_>>()
1694        };
1695
1696        // Prologues and buffer declarations persist across ticks (outside the closure).
1697        // Subgraph blocks run each tick (inside the closure).
1698        Ok(quote! {
1699            {
1700                #prefix
1701
1702                use #root::{var_expr, var_args};
1703
1704                let __dfir_wake_state = ::std::sync::Arc::new(
1705                    #root::scheduled::context::WakeState::default()
1706                );
1707
1708                let __dfir_metrics = {
1709                    let mut dfir_metrics = #root::scheduled::metrics::DfirMetrics::default();
1710                    #( #metrics_init_code )*
1711                    ::std::rc::Rc::new(dfir_metrics)
1712                };
1713
1714                #[allow(unused_mut)]
1715                let mut #df = #root::scheduled::context::Context::new(
1716                    ::std::clone::Clone::clone(&__dfir_wake_state),
1717                    __dfir_metrics,
1718                );
1719
1720                #( #buffer_code )*
1721                #( #back_buffer_code )*
1722                #( #op_prologue_code )*
1723
1724                // Pre-set to true so the first tick always returns true
1725                // (matching Dfir pre-scheduling behavior). Subsequent ticks
1726                // start false (from take()) and are set true by recv port code
1727                // if any handoff buffer has data.
1728                let mut __dfir_work_done = true;
1729                #[allow(unused_qualifications, unused_mut, unused_variables, clippy::await_holding_refcell_ref, clippy::deref_addrof)]
1730                let __dfir_inline_tick = async move |#df: &mut #root::scheduled::context::Context| {
1731                    let __dfir_metrics = #df.metrics();
1732                    // Double-buffer swap for defer_tick handoffs: move last tick's
1733                    // producer output into the back buffer for the consumer to drain.
1734                    #( #back_edge_swap_code )*
1735                    #( #subgraph_blocks )*
1736
1737                    // For non-lazy defer_tick: if any deferred buffer has data,
1738                    // signal that another tick should run.
1739                    if false #( || !#defer_tick_buf_idents.is_empty() )* {
1740                        #df.schedule_subgraph(true);
1741                    }
1742
1743                    // End-of-tick state reset (e.g. 'tick persistence).
1744                    #( #op_tick_end_code )*
1745
1746                    // Drain handoff buffers that have no pipe consumer (e.g. singleton
1747                    // used only via #var reference). Without this, the value would
1748                    // persist across ticks and cause panics on the next write.
1749                    #( #no_consumer_drain_code )*
1750
1751                    #df.__end_tick();
1752                    ::std::mem::take(&mut __dfir_work_done)
1753                };
1754                #root::scheduled::context::Dfir::new(
1755                    __dfir_inline_tick,
1756                    #df,
1757                    #meta_graph_arg,
1758                    #diagnostics_arg,
1759                )
1760            }
1761        })
1762    }
1763
1764    /// Color mode (pull vs. push, handoff vs. comp) for nodes. Some nodes can be push *OR* pull;
1765    /// those nodes will not be set in the returned map.
1766    pub fn node_color_map(&self) -> SparseSecondaryMap<GraphNodeId, Color> {
1767        let mut node_color_map: SparseSecondaryMap<GraphNodeId, Color> = self
1768            .node_ids()
1769            .filter_map(|node_id| {
1770                let op_color = self.node_color(node_id)?;
1771                Some((node_id, op_color))
1772            })
1773            .collect();
1774
1775        // Fill in rest via subgraphs.
1776        for sg_nodes in self.subgraph_nodes.values() {
1777            let pull_to_push_idx = self.find_pull_to_push_idx(sg_nodes);
1778
1779            for (idx, node_id) in sg_nodes.iter().copied().enumerate() {
1780                let is_pull = idx < pull_to_push_idx;
1781                node_color_map.insert(node_id, if is_pull { Color::Pull } else { Color::Push });
1782            }
1783        }
1784
1785        node_color_map
1786    }
1787
1788    /// Writes this graph as mermaid into a string.
1789    pub fn to_mermaid(&self, write_config: &WriteConfig) -> String {
1790        let mut output = String::new();
1791        self.write_mermaid(&mut output, write_config).unwrap();
1792        output
1793    }
1794
1795    /// Writes this graph as mermaid into the given `Write`.
1796    pub fn write_mermaid(
1797        &self,
1798        output: impl std::fmt::Write,
1799        write_config: &WriteConfig,
1800    ) -> std::fmt::Result {
1801        let mut graph_write = Mermaid::new(output);
1802        self.write_graph(&mut graph_write, write_config)
1803    }
1804
1805    /// Writes this graph as DOT (graphviz) into a string.
1806    pub fn to_dot(&self, write_config: &WriteConfig) -> String {
1807        let mut output = String::new();
1808        let mut graph_write = Dot::new(&mut output);
1809        self.write_graph(&mut graph_write, write_config).unwrap();
1810        output
1811    }
1812
1813    /// Writes this graph as DOT (graphviz) into the given `Write`.
1814    pub fn write_dot(
1815        &self,
1816        output: impl std::fmt::Write,
1817        write_config: &WriteConfig,
1818    ) -> std::fmt::Result {
1819        let mut graph_write = Dot::new(output);
1820        self.write_graph(&mut graph_write, write_config)
1821    }
1822
1823    /// Write out this graph using the given `GraphWrite`. E.g. `Mermaid` or `Dot.
1824    pub(crate) fn write_graph<W>(
1825        &self,
1826        mut graph_write: W,
1827        write_config: &WriteConfig,
1828    ) -> Result<(), W::Err>
1829    where
1830        W: GraphWrite,
1831    {
1832        fn helper_edge_label(
1833            src_port: &PortIndexValue,
1834            dst_port: &PortIndexValue,
1835        ) -> Option<String> {
1836            let src_label = match src_port {
1837                PortIndexValue::Path(path) => Some(path.to_token_stream().to_string()),
1838                PortIndexValue::Int(index) => Some(index.value.to_string()),
1839                _ => None,
1840            };
1841            let dst_label = match dst_port {
1842                PortIndexValue::Path(path) => Some(path.to_token_stream().to_string()),
1843                PortIndexValue::Int(index) => Some(index.value.to_string()),
1844                _ => None,
1845            };
1846            let label = match (src_label, dst_label) {
1847                (Some(l1), Some(l2)) => Some(format!("{}\n{}", l1, l2)),
1848                (Some(l1), None) => Some(l1),
1849                (None, Some(l2)) => Some(l2),
1850                (None, None) => None,
1851            };
1852            label
1853        }
1854
1855        // Make node color map one time.
1856        let node_color_map = self.node_color_map();
1857
1858        // Write prologue.
1859        graph_write.write_prologue()?;
1860
1861        // Define nodes.
1862        let mut skipped_handoffs = BTreeSet::new();
1863        let mut subgraph_handoffs = <BTreeMap<GraphSubgraphId, Vec<GraphNodeId>>>::new();
1864        for (node_id, node) in self.nodes() {
1865            if matches!(node, GraphNode::Handoff { .. }) {
1866                if write_config.no_handoffs {
1867                    skipped_handoffs.insert(node_id);
1868                    continue;
1869                } else {
1870                    let pred_node = self.node_predecessor_nodes(node_id).next().unwrap();
1871                    let pred_sg = self.node_subgraph(pred_node);
1872                    let succ_node = self.node_successor_nodes(node_id).next();
1873                    let succ_sg = succ_node.and_then(|n| self.node_subgraph(n));
1874                    if let Some((pred_sg, succ_sg)) = pred_sg.zip(succ_sg)
1875                        && pred_sg == succ_sg
1876                    {
1877                        subgraph_handoffs.entry(pred_sg).or_default().push(node_id);
1878                    }
1879                }
1880            }
1881            graph_write.write_node_definition(
1882                node_id,
1883                &if write_config.op_short_text {
1884                    node.to_name_string()
1885                } else if write_config.op_text_no_imports {
1886                    // Remove any lines that start with "use" (imports)
1887                    let full_text = node.to_pretty_string();
1888                    let mut output = String::new();
1889                    for sentence in full_text.split('\n') {
1890                        if sentence.trim().starts_with("use") {
1891                            continue;
1892                        }
1893                        output.push('\n');
1894                        output.push_str(sentence);
1895                    }
1896                    output.into()
1897                } else {
1898                    node.to_pretty_string()
1899                },
1900                if write_config.no_pull_push {
1901                    None
1902                } else {
1903                    node_color_map.get(node_id).copied()
1904                },
1905            )?;
1906        }
1907
1908        // Write edges.
1909        for (edge_id, (src_id, mut dst_id)) in self.edges() {
1910            // Handling for if `write_config.no_handoffs` true.
1911            if skipped_handoffs.contains(&src_id) {
1912                continue;
1913            }
1914
1915            let (src_port, mut dst_port) = self.edge_ports(edge_id);
1916            if skipped_handoffs.contains(&dst_id) {
1917                let mut handoff_succs = self.node_successors(dst_id);
1918                assert_eq!(1, handoff_succs.len());
1919                let (succ_edge, succ_node) = handoff_succs.next().unwrap();
1920                dst_id = succ_node;
1921                dst_port = self.edge_ports(succ_edge).1;
1922            }
1923
1924            let label = helper_edge_label(src_port, dst_port);
1925            let delay_type = self
1926                .node_op_inst(dst_id)
1927                .and_then(|op_inst| (op_inst.op_constraints.input_delaytype_fn)(dst_port));
1928            graph_write.write_edge(src_id, dst_id, delay_type, label.as_deref(), false)?;
1929        }
1930
1931        // Write reference edges.
1932        if !write_config.no_references {
1933            for dst_id in self.node_ids() {
1934                for src_ref_id in self
1935                    .node_singleton_references(dst_id)
1936                    .iter()
1937                    .copied()
1938                    .flatten()
1939                {
1940                    let delay_type = Some(DelayType::Stratum);
1941                    let label = None;
1942                    graph_write.write_edge(src_ref_id, dst_id, delay_type, label, true)?;
1943                }
1944            }
1945        }
1946
1947        // The following code is a little bit tricky. Generally, the graph has the hierarchy:
1948        // `loop -> subgraph -> varname -> node`. However, each of these can be disabled via the `write_config`. To
1949        // handle both the enabled and disabled case, this code is structured as a series of nested loops. If the layer
1950        // is disabled, then the HashMap<Option<KEY>, Vec<VALUE>> will only have a single key (`None`) with a
1951        // corresponding `Vec` value containing everything. This way no special handling is needed for the next layer.
1952
1953        // Loop -> Subgraphs
1954        let loop_subgraphs = self.subgraph_ids().map(|sg_id| {
1955            let loop_id = if write_config.no_loops {
1956                None
1957            } else {
1958                self.subgraph_loop(sg_id)
1959            };
1960            (loop_id, sg_id)
1961        });
1962        let loop_subgraphs = into_group_map(loop_subgraphs);
1963        for (loop_id, subgraph_ids) in loop_subgraphs {
1964            if let Some(loop_id) = loop_id {
1965                graph_write.write_loop_start(loop_id)?;
1966            }
1967
1968            // Subgraph -> Varnames.
1969            let subgraph_varnames_nodes = subgraph_ids.into_iter().flat_map(|sg_id| {
1970                self.subgraph(sg_id).iter().copied().map(move |node_id| {
1971                    let opt_sg_id = if write_config.no_subgraphs {
1972                        None
1973                    } else {
1974                        Some(sg_id)
1975                    };
1976                    (opt_sg_id, (self.node_varname(node_id), node_id))
1977                })
1978            });
1979            let subgraph_varnames_nodes = into_group_map(subgraph_varnames_nodes);
1980            for (sg_id, varnames) in subgraph_varnames_nodes {
1981                if let Some(sg_id) = sg_id {
1982                    graph_write.write_subgraph_start(sg_id)?;
1983                }
1984
1985                // Varnames -> Nodes.
1986                let varname_nodes = varnames.into_iter().map(|(varname, node)| {
1987                    let varname = if write_config.no_varnames {
1988                        None
1989                    } else {
1990                        varname
1991                    };
1992                    (varname, node)
1993                });
1994                let varname_nodes = into_group_map(varname_nodes);
1995                for (varname, node_ids) in varname_nodes {
1996                    if let Some(varname) = varname {
1997                        graph_write.write_varname_start(&varname.0.to_string(), sg_id)?;
1998                    }
1999
2000                    // Write all nodes.
2001                    for node_id in node_ids {
2002                        graph_write.write_node(node_id)?;
2003                    }
2004
2005                    if varname.is_some() {
2006                        graph_write.write_varname_end()?;
2007                    }
2008                }
2009
2010                if sg_id.is_some() {
2011                    graph_write.write_subgraph_end()?;
2012                }
2013            }
2014
2015            if loop_id.is_some() {
2016                graph_write.write_loop_end()?;
2017            }
2018        }
2019
2020        // Write epilogue.
2021        graph_write.write_epilogue()?;
2022
2023        Ok(())
2024    }
2025
2026    /// Convert back into surface syntax.
2027    pub fn surface_syntax_string(&self) -> String {
2028        let mut string = String::new();
2029        self.write_surface_syntax(&mut string).unwrap();
2030        string
2031    }
2032
2033    /// Convert back into surface syntax.
2034    pub fn write_surface_syntax(&self, write: &mut impl std::fmt::Write) -> std::fmt::Result {
2035        for (key, node) in self.nodes.iter() {
2036            match node {
2037                GraphNode::Operator(op) => {
2038                    writeln!(write, "{:?} = {};", key.data(), op.to_token_stream())?;
2039                }
2040                GraphNode::Handoff {
2041                    kind: HandoffKind::Vec,
2042                    ..
2043                } => {
2044                    writeln!(write, "{:?} = handoff();", key.data())?;
2045                }
2046                GraphNode::Handoff {
2047                    kind: HandoffKind::Option,
2048                    ..
2049                } => {
2050                    writeln!(write, "{:?} = singleton();", key.data())?;
2051                }
2052                GraphNode::ModuleBoundary { .. } => panic!(),
2053            }
2054        }
2055        writeln!(write)?;
2056        for (_e, (src_key, dst_key)) in self.graph.edges() {
2057            writeln!(write, "{:?} -> {:?};", src_key.data(), dst_key.data())?;
2058        }
2059        Ok(())
2060    }
2061
2062    /// Convert into a [mermaid](https://mermaid-js.github.io/) graph. Ignores subgraphs.
2063    pub fn mermaid_string_flat(&self) -> String {
2064        let mut string = String::new();
2065        self.write_mermaid_flat(&mut string).unwrap();
2066        string
2067    }
2068
2069    /// Convert into a [mermaid](https://mermaid-js.github.io/) graph. Ignores subgraphs.
2070    pub fn write_mermaid_flat(&self, write: &mut impl std::fmt::Write) -> std::fmt::Result {
2071        writeln!(write, "flowchart TB")?;
2072        for (key, node) in self.nodes.iter() {
2073            match node {
2074                GraphNode::Operator(operator) => writeln!(
2075                    write,
2076                    "    %% {span}\n    {id:?}[\"{row_col} <tt>{code}</tt>\"]",
2077                    span = PrettySpan(node.span()),
2078                    id = key.data(),
2079                    row_col = PrettyRowCol(node.span()),
2080                    code = operator
2081                        .to_token_stream()
2082                        .to_string()
2083                        .replace('&', "&amp;")
2084                        .replace('<', "&lt;")
2085                        .replace('>', "&gt;")
2086                        .replace('"', "&quot;")
2087                        .replace('\n', "<br>"),
2088                ),
2089                GraphNode::Handoff {
2090                    kind: HandoffKind::Vec,
2091                    ..
2092                } => {
2093                    writeln!(write, r#"    {:?}{{"{}"}}"#, key.data(), HANDOFF_NODE_STR)
2094                }
2095                GraphNode::Handoff {
2096                    kind: HandoffKind::Option,
2097                    ..
2098                } => {
2099                    writeln!(
2100                        write,
2101                        r#"    {:?}{{"{}"}}"#,
2102                        key.data(),
2103                        SINGLETON_SLOT_NODE_STR
2104                    )
2105                }
2106                GraphNode::ModuleBoundary { .. } => {
2107                    writeln!(
2108                        write,
2109                        r#"    {:?}{{"{}"}}"#,
2110                        key.data(),
2111                        MODULE_BOUNDARY_NODE_STR
2112                    )
2113                }
2114            }?;
2115        }
2116        writeln!(write)?;
2117        for (_e, (src_key, dst_key)) in self.graph.edges() {
2118            writeln!(write, "    {:?}-->{:?}", src_key.data(), dst_key.data())?;
2119        }
2120        Ok(())
2121    }
2122}
2123
2124/// Loops
2125impl DfirGraph {
2126    /// Iterator over all loop IDs.
2127    pub fn loop_ids(&self) -> slotmap::basic::Keys<'_, GraphLoopId, Vec<GraphNodeId>> {
2128        self.loop_nodes.keys()
2129    }
2130
2131    /// Iterator over all loops, ID and members: `(GraphLoopId, Vec<GraphNodeId>)`.
2132    pub fn loops(&self) -> slotmap::basic::Iter<'_, GraphLoopId, Vec<GraphNodeId>> {
2133        self.loop_nodes.iter()
2134    }
2135
2136    /// Create a new loop context, with the given parent loop (or `None`).
2137    pub fn insert_loop(&mut self, parent_loop: Option<GraphLoopId>) -> GraphLoopId {
2138        let loop_id = self.loop_nodes.insert(Vec::new());
2139        self.loop_children.insert(loop_id, Vec::new());
2140        if let Some(parent_loop) = parent_loop {
2141            self.loop_parent.insert(loop_id, parent_loop);
2142            self.loop_children
2143                .get_mut(parent_loop)
2144                .unwrap()
2145                .push(loop_id);
2146        } else {
2147            self.root_loops.push(loop_id);
2148        }
2149        loop_id
2150    }
2151
2152    /// Get a node's loop context (or `None` for root).
2153    pub fn node_loop(&self, node_id: GraphNodeId) -> Option<GraphLoopId> {
2154        self.node_loops.get(node_id).copied()
2155    }
2156
2157    /// Get a subgraph's loop context (or `None` for root).
2158    pub fn subgraph_loop(&self, subgraph_id: GraphSubgraphId) -> Option<GraphLoopId> {
2159        let &node_id = self.subgraph(subgraph_id).first().unwrap();
2160        let out = self.node_loop(node_id);
2161        debug_assert!(
2162            self.subgraph(subgraph_id)
2163                .iter()
2164                .all(|&node_id| self.node_loop(node_id) == out),
2165            "Subgraph nodes should all have the same loop context."
2166        );
2167        out
2168    }
2169
2170    /// Get a loop context's parent loop context (or `None` for root).
2171    pub fn loop_parent(&self, loop_id: GraphLoopId) -> Option<GraphLoopId> {
2172        self.loop_parent.get(loop_id).copied()
2173    }
2174
2175    /// Get a loop context's child loops.
2176    pub fn loop_children(&self, loop_id: GraphLoopId) -> &Vec<GraphLoopId> {
2177        self.loop_children.get(loop_id).unwrap()
2178    }
2179}
2180
/// Configuration for writing graphs.
// Every flag is a `bool` and `Default` is derived, so a default config has all flags `false`.
// NOTE(review): with the `clap-derive` feature this struct derives `clap::Args`, and clap's
// derive macros take argument help text from `///` doc comments -- so rewording the field docs
// below presumably changes the CLI `--help` output too; confirm before editing them.
#[derive(Clone, Debug, Default)]
#[cfg_attr(feature = "clap-derive", derive(clap::Args))]
pub struct WriteConfig {
    // --- Flags that drop a layer or element from the rendered graph. ---
    /// Subgraphs will not be rendered if set.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub no_subgraphs: bool,
    /// Variable names will not be rendered if set.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub no_varnames: bool,
    /// Will not render pull/push shapes if set.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub no_pull_push: bool,
    /// Will not render handoffs if set.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub no_handoffs: bool,
    /// Will not render singleton references if set.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub no_references: bool,
    /// Will not render loops if set.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub no_loops: bool,

    // --- Flags that shorten the text shown for each operator. ---
    /// Op text will only be their name instead of the whole source.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub op_short_text: bool,
    /// Op text will exclude any line that starts with "use".
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub op_text_no_imports: bool,
}
2211
/// Enum for choosing between mermaid and dot graph writing.
// NOTE(review): with the `clap-derive` feature this derives `clap::ValueEnum`, so the variant
// doc comments below presumably surface as help text for the possible values -- confirm before
// rewording them.
#[derive(Copy, Clone, Debug)]
#[cfg_attr(feature = "clap-derive", derive(clap::Parser, clap::ValueEnum))]
pub enum WriteGraphType {
    /// Mermaid graphs.
    Mermaid,
    /// Dot (Graphviz) graphs.
    Dot,
}
2221
/// Groups `(key, value)` pairs into a `BTreeMap` from each key to the `Vec` of its values.
///
/// Like [`itertools::Itertools::into_group_map`], but collecting into a `BTreeMap`, so the
/// resulting groups iterate in ascending key order. Within a group, values keep the order in
/// which they appeared in `iter`.
fn into_group_map<K, V>(iter: impl IntoIterator<Item = (K, V)>) -> BTreeMap<K, Vec<V>>
where
    K: Ord,
{
    iter.into_iter()
        .fold(BTreeMap::<K, Vec<V>>::new(), |mut groups, (key, value)| {
            groups.entry(key).or_default().push(value);
            groups
        })
}