dfir_lang/graph/
meta_graph.rs

1#![warn(missing_docs)]
2
3extern crate proc_macro;
4
5use std::collections::{BTreeMap, BTreeSet, VecDeque};
6use std::fmt::Debug;
7use std::iter::FusedIterator;
8
9use itertools::Itertools;
10use proc_macro2::{Ident, Literal, Span, TokenStream};
11use quote::{ToTokens, TokenStreamExt, format_ident, quote, quote_spanned};
12use serde::{Deserialize, Serialize};
13use slotmap::{Key, SecondaryMap, SlotMap, SparseSecondaryMap};
14use syn::spanned::Spanned;
15
16use super::graph_write::{Dot, GraphWrite, Mermaid};
17use super::ops::{
18    DelayType, OPERATORS, OperatorWriteOutput, WriteContextArgs, find_op_op_constraints,
19    null_write_iterator_fn,
20};
21use super::{
22    CONTEXT, Color, DiMulGraph, GRAPH, GraphEdgeId, GraphLoopId, GraphNode, GraphNodeId,
23    GraphSubgraphId, HANDOFF_NODE_STR, MODULE_BOUNDARY_NODE_STR, OperatorInstance, PortIndexValue,
24    Varname, change_spans, get_operator_generics,
25};
26use crate::diagnostic::{Diagnostic, Level};
27use crate::pretty_span::{PrettyRowCol, PrettySpan};
28use crate::process_singletons;
29
30/// An abstract "meta graph" representation of a DFIR graph.
31///
32/// Can be with or without subgraph partitioning, stratification, and handoff insertion. This is
33/// the meta graph used for generating Rust source code in macros from DFIR sytnax.
34///
35/// This struct has a lot of methods for manipulating the graph, vaguely grouped together in
36/// separate `impl` blocks. You might notice a few particularly specific arbitray-seeming methods
37/// in here--those are just what was needed for the compilation algorithms. If you need another
38/// method then add it.
39#[derive(Default, Debug, Serialize, Deserialize)]
40pub struct DfirGraph {
41    /// Each node type (operator or handoff).
42    nodes: SlotMap<GraphNodeId, GraphNode>,
43
44    /// Instance data corresponding to each operator node.
45    /// This field will be empty after deserialization.
46    #[serde(skip)]
47    operator_instances: SecondaryMap<GraphNodeId, OperatorInstance>,
48    /// Debugging/tracing tag for each operator node.
49    operator_tag: SecondaryMap<GraphNodeId, String>,
50    /// Graph data structure (two-way adjacency list).
51    graph: DiMulGraph<GraphNodeId, GraphEdgeId>,
52    /// Input and output port for each edge.
53    ports: SecondaryMap<GraphEdgeId, (PortIndexValue, PortIndexValue)>,
54
55    /// Which loop a node belongs to (or none for top-level).
56    node_loops: SecondaryMap<GraphNodeId, GraphLoopId>,
57    /// Which nodes belong to each loop.
58    loop_nodes: SlotMap<GraphLoopId, Vec<GraphNodeId>>,
59    /// For the loop, what is its parent (`None` for top-level).
60    loop_parent: SparseSecondaryMap<GraphLoopId, GraphLoopId>,
61    /// What loops are at the root.
62    root_loops: Vec<GraphLoopId>,
63    /// For the loop, what are its child loops.
64    loop_children: SecondaryMap<GraphLoopId, Vec<GraphLoopId>>,
65
66    /// Which subgraph each node belongs to.
67    node_subgraph: SecondaryMap<GraphNodeId, GraphSubgraphId>,
68
69    /// Which nodes belong to each subgraph.
70    subgraph_nodes: SlotMap<GraphSubgraphId, Vec<GraphNodeId>>,
71    /// Which stratum each subgraph belongs to.
72    subgraph_stratum: SecondaryMap<GraphSubgraphId, usize>,
73
74    /// Resolved singletons varnames references, per node.
75    node_singleton_references: SparseSecondaryMap<GraphNodeId, Vec<Option<GraphNodeId>>>,
76    /// What variable name each graph node belongs to (if any). For debugging (graph writing) purposes only.
77    node_varnames: SparseSecondaryMap<GraphNodeId, Varname>,
78
79    /// If this subgraph is 'lazy' then when it sends data to a lower stratum it does not cause a new tick to start
80    /// This is to support lazy defers
81    /// If the value does not exist for a given subgraph id then the subgraph is not lazy.
82    subgraph_laziness: SecondaryMap<GraphSubgraphId, bool>,
83}
84
85/// Basic methods.
86impl DfirGraph {
87    /// Create a new empty graph.
88    pub fn new() -> Self {
89        Default::default()
90    }
91}
92
93/// Node methods.
94impl DfirGraph {
95    /// Get a node with its operator instance (if applicable).
96    pub fn node(&self, node_id: GraphNodeId) -> &GraphNode {
97        self.nodes.get(node_id).expect("Node not found.")
98    }
99
100    /// Get the `OperatorInstance` for a given node. Node must be an operator and have an
101    /// `OperatorInstance` present, otherwise will return `None`.
102    ///
103    /// Note that no operator instances will be persent after deserialization.
104    pub fn node_op_inst(&self, node_id: GraphNodeId) -> Option<&OperatorInstance> {
105        self.operator_instances.get(node_id)
106    }
107
108    /// Get the debug variable name attached to a graph node.
109    pub fn node_varname(&self, node_id: GraphNodeId) -> Option<&Varname> {
110        self.node_varnames.get(node_id)
111    }
112
113    /// Get subgraph for node.
114    pub fn node_subgraph(&self, node_id: GraphNodeId) -> Option<GraphSubgraphId> {
115        self.node_subgraph.get(node_id).copied()
116    }
117
118    /// Degree into a node, i.e. the number of predecessors.
119    pub fn node_degree_in(&self, node_id: GraphNodeId) -> usize {
120        self.graph.degree_in(node_id)
121    }
122
123    /// Degree out of a node, i.e. the number of successors.
124    pub fn node_degree_out(&self, node_id: GraphNodeId) -> usize {
125        self.graph.degree_out(node_id)
126    }
127
128    /// Successors, iterator of `(GraphEdgeId, GraphNodeId)` of outgoing edges.
129    pub fn node_successors(
130        &self,
131        src: GraphNodeId,
132    ) -> impl '_
133    + DoubleEndedIterator<Item = (GraphEdgeId, GraphNodeId)>
134    + ExactSizeIterator
135    + FusedIterator
136    + Clone
137    + Debug {
138        self.graph.successors(src)
139    }
140
141    /// Predecessors, iterator of `(GraphEdgeId, GraphNodeId)` of incoming edges.
142    pub fn node_predecessors(
143        &self,
144        dst: GraphNodeId,
145    ) -> impl '_
146    + DoubleEndedIterator<Item = (GraphEdgeId, GraphNodeId)>
147    + ExactSizeIterator
148    + FusedIterator
149    + Clone
150    + Debug {
151        self.graph.predecessors(dst)
152    }
153
154    /// Successor edges, iterator of `GraphEdgeId` of outgoing edges.
155    pub fn node_successor_edges(
156        &self,
157        src: GraphNodeId,
158    ) -> impl '_
159    + DoubleEndedIterator<Item = GraphEdgeId>
160    + ExactSizeIterator
161    + FusedIterator
162    + Clone
163    + Debug {
164        self.graph.successor_edges(src)
165    }
166
167    /// Predecessor edges, iterator of `GraphEdgeId` of incoming edges.
168    pub fn node_predecessor_edges(
169        &self,
170        dst: GraphNodeId,
171    ) -> impl '_
172    + DoubleEndedIterator<Item = GraphEdgeId>
173    + ExactSizeIterator
174    + FusedIterator
175    + Clone
176    + Debug {
177        self.graph.predecessor_edges(dst)
178    }
179
180    /// Successor nodes, iterator of `GraphNodeId`.
181    pub fn node_successor_nodes(
182        &self,
183        src: GraphNodeId,
184    ) -> impl '_
185    + DoubleEndedIterator<Item = GraphNodeId>
186    + ExactSizeIterator
187    + FusedIterator
188    + Clone
189    + Debug {
190        self.graph.successor_vertices(src)
191    }
192
193    /// Predecessor nodes, iterator of `GraphNodeId`.
194    pub fn node_predecessor_nodes(
195        &self,
196        dst: GraphNodeId,
197    ) -> impl '_
198    + DoubleEndedIterator<Item = GraphNodeId>
199    + ExactSizeIterator
200    + FusedIterator
201    + Clone
202    + Debug {
203        self.graph.predecessor_vertices(dst)
204    }
205
206    /// Iterator of node IDs `GraphNodeId`.
207    pub fn node_ids(&self) -> slotmap::basic::Keys<'_, GraphNodeId, GraphNode> {
208        self.nodes.keys()
209    }
210
211    /// Iterator over `(GraphNodeId, &Node)` pairs.
212    pub fn nodes(&self) -> slotmap::basic::Iter<'_, GraphNodeId, GraphNode> {
213        self.nodes.iter()
214    }
215
216    /// Insert a node, assigning the given varname.
217    pub fn insert_node(
218        &mut self,
219        node: GraphNode,
220        varname_opt: Option<Ident>,
221        loop_opt: Option<GraphLoopId>,
222    ) -> GraphNodeId {
223        let node_id = self.nodes.insert(node);
224        if let Some(varname) = varname_opt {
225            self.node_varnames.insert(node_id, Varname(varname));
226        }
227        if let Some(loop_id) = loop_opt {
228            self.node_loops.insert(node_id, loop_id);
229            self.loop_nodes[loop_id].push(node_id);
230        }
231        node_id
232    }
233
234    /// Insert an operator instance for the given node. Panics if already set.
235    pub fn insert_node_op_inst(&mut self, node_id: GraphNodeId, op_inst: OperatorInstance) {
236        assert!(matches!(
237            self.nodes.get(node_id),
238            Some(GraphNode::Operator(_))
239        ));
240        let old_inst = self.operator_instances.insert(node_id, op_inst);
241        assert!(old_inst.is_none());
242    }
243
244    /// Assign all operator instances if not set. Write diagnostic messages/errors into `diagnostics`.
245    pub fn insert_node_op_insts_all(&mut self, diagnostics: &mut Vec<Diagnostic>) {
246        let mut op_insts = Vec::new();
247        for (node_id, node) in self.nodes() {
248            let GraphNode::Operator(operator) = node else {
249                continue;
250            };
251            if self.node_op_inst(node_id).is_some() {
252                continue;
253            };
254
255            // Op constraints.
256            let Some(op_constraints) = find_op_op_constraints(operator) else {
257                diagnostics.push(Diagnostic::spanned(
258                    operator.path.span(),
259                    Level::Error,
260                    format!("Unknown operator `{}`", operator.name_string()),
261                ));
262                continue;
263            };
264
265            // Input and output ports.
266            let (input_ports, output_ports) = {
267                let mut input_edges: Vec<(&PortIndexValue, GraphNodeId)> = self
268                    .node_predecessors(node_id)
269                    .map(|(edge_id, pred_id)| (self.edge_ports(edge_id).1, pred_id))
270                    .collect();
271                // Ensure sorted by port index.
272                input_edges.sort();
273                let input_ports: Vec<PortIndexValue> = input_edges
274                    .into_iter()
275                    .map(|(port, _pred)| port)
276                    .cloned()
277                    .collect();
278
279                // Collect output arguments (successors).
280                let mut output_edges: Vec<(&PortIndexValue, GraphNodeId)> = self
281                    .node_successors(node_id)
282                    .map(|(edge_id, succ)| (self.edge_ports(edge_id).0, succ))
283                    .collect();
284                // Ensure sorted by port index.
285                output_edges.sort();
286                let output_ports: Vec<PortIndexValue> = output_edges
287                    .into_iter()
288                    .map(|(port, _succ)| port)
289                    .cloned()
290                    .collect();
291
292                (input_ports, output_ports)
293            };
294
295            // Generic arguments.
296            let generics = get_operator_generics(diagnostics, operator);
297            // Generic argument errors.
298            {
299                // Span of `generic_args` (if it exists), otherwise span of the operator name.
300                let generics_span = generics
301                    .generic_args
302                    .as_ref()
303                    .map(Spanned::span)
304                    .unwrap_or_else(|| operator.path.span());
305
306                if !op_constraints
307                    .persistence_args
308                    .contains(&generics.persistence_args.len())
309                {
310                    diagnostics.push(Diagnostic::spanned(
311                        generics_span,
312                        Level::Error,
313                        format!(
314                            "`{}` should have {} persistence lifetime arguments, actually has {}.",
315                            op_constraints.name,
316                            op_constraints.persistence_args.human_string(),
317                            generics.persistence_args.len()
318                        ),
319                    ));
320                }
321                if !op_constraints.type_args.contains(&generics.type_args.len()) {
322                    diagnostics.push(Diagnostic::spanned(
323                        generics_span,
324                        Level::Error,
325                        format!(
326                            "`{}` should have {} generic type arguments, actually has {}.",
327                            op_constraints.name,
328                            op_constraints.type_args.human_string(),
329                            generics.type_args.len()
330                        ),
331                    ));
332                }
333            }
334
335            op_insts.push((
336                node_id,
337                OperatorInstance {
338                    op_constraints,
339                    input_ports,
340                    output_ports,
341                    singletons_referenced: operator.singletons_referenced.clone(),
342                    generics,
343                    arguments_pre: operator.args.clone(),
344                    arguments_raw: operator.args_raw.clone(),
345                },
346            ));
347        }
348
349        for (node_id, op_inst) in op_insts {
350            self.insert_node_op_inst(node_id, op_inst);
351        }
352    }
353
354    /// Inserts a node between two existing nodes connected by the given `edge_id`.
355    ///
356    /// `edge`: (src, dst, dst_idx)
357    ///
358    /// Before: A (src) ------------> B (dst)
359    /// After:  A (src) -> X (new) -> B (dst)
360    ///
361    /// Returns the ID of X & ID of edge OUT of X.
362    ///
363    /// Note that both the edges will be new and `edge_id` will be removed. Both new edges will
364    /// get the edge type of the original edge.
365    pub fn insert_intermediate_node(
366        &mut self,
367        edge_id: GraphEdgeId,
368        new_node: GraphNode,
369    ) -> (GraphNodeId, GraphEdgeId) {
370        let span = Some(new_node.span());
371
372        // Make corresponding operator instance (if `node` is an operator).
373        let op_inst_opt = 'oc: {
374            let GraphNode::Operator(operator) = &new_node else {
375                break 'oc None;
376            };
377            let Some(op_constraints) = find_op_op_constraints(operator) else {
378                break 'oc None;
379            };
380            let (input_port, output_port) = self.ports.get(edge_id).cloned().unwrap();
381            let generics = get_operator_generics(
382                &mut Vec::new(), // TODO(mingwei) diagnostics
383                operator,
384            );
385            Some(OperatorInstance {
386                op_constraints,
387                input_ports: vec![input_port],
388                output_ports: vec![output_port],
389                singletons_referenced: operator.singletons_referenced.clone(),
390                generics,
391                arguments_pre: operator.args.clone(),
392                arguments_raw: operator.args_raw.clone(),
393            })
394        };
395
396        // Insert new `node`.
397        let node_id = self.nodes.insert(new_node);
398        // Insert corresponding `OperatorInstance` if applicable.
399        if let Some(op_inst) = op_inst_opt {
400            self.operator_instances.insert(node_id, op_inst);
401        }
402        // Update edges to insert node within `edge_id`.
403        let (e0, e1) = self
404            .graph
405            .insert_intermediate_vertex(node_id, edge_id)
406            .unwrap();
407
408        // Update corresponding ports.
409        let (src_idx, dst_idx) = self.ports.remove(edge_id).unwrap();
410        self.ports
411            .insert(e0, (src_idx, PortIndexValue::Elided(span)));
412        self.ports
413            .insert(e1, (PortIndexValue::Elided(span), dst_idx));
414
415        (node_id, e1)
416    }
417
418    /// Remove the node `node_id` but preserves and connects the single predecessor and single successor.
419    /// Panics if the node does not have exactly one predecessor and one successor, or is not in the graph.
420    pub fn remove_intermediate_node(&mut self, node_id: GraphNodeId) {
421        assert_eq!(
422            1,
423            self.node_degree_in(node_id),
424            "Removed intermediate node must have one predecessor"
425        );
426        assert_eq!(
427            1,
428            self.node_degree_out(node_id),
429            "Removed intermediate node must have one successor"
430        );
431        assert!(
432            self.node_subgraph.is_empty() && self.subgraph_nodes.is_empty(),
433            "Should not remove intermediate node after subgraph partitioning"
434        );
435
436        assert!(self.nodes.remove(node_id).is_some());
437        let (new_edge_id, (pred_edge_id, succ_edge_id)) =
438            self.graph.remove_intermediate_vertex(node_id).unwrap();
439        self.operator_instances.remove(node_id);
440        self.node_varnames.remove(node_id);
441
442        let (src_port, _) = self.ports.remove(pred_edge_id).unwrap();
443        let (_, dst_port) = self.ports.remove(succ_edge_id).unwrap();
444        self.ports.insert(new_edge_id, (src_port, dst_port));
445    }
446
447    /// Helper method: determine the "color" (pull vs push) of a node based on its in and out degree,
448    /// excluding reference edges. If linear (1 in, 1 out), color is `None`, indicating it can be
449    /// either push or pull.
450    ///
451    /// Note that this does NOT consider `DelayType` barriers (which generally implies `Pull`).
452    pub(crate) fn node_color(&self, node_id: GraphNodeId) -> Option<Color> {
453        if matches!(self.node(node_id), GraphNode::Handoff { .. }) {
454            return Some(Color::Hoff);
455        }
456
457        // TODO(shadaj): this is a horrible hack
458        if let GraphNode::Operator(op) = self.node(node_id)
459            && (op.name_string() == "resolve_futures_blocking"
460                || op.name_string() == "resolve_futures_blocking_ordered")
461        {
462            return Some(Color::Push);
463        }
464
465        // In-degree, excluding ref-edges.
466        let inn_degree = self.node_predecessor_nodes(node_id).count();
467        // Out-degree excluding ref-edges.
468        let out_degree = self.node_successor_nodes(node_id).count();
469
470        match (inn_degree, out_degree) {
471            (0, 0) => None, // Generally should not happen, "Degenerate subgraph detected".
472            (0, 1) => Some(Color::Pull),
473            (1, 0) => Some(Color::Push),
474            (1, 1) => None, // Linear, can be either push or pull.
475            (_many, 0 | 1) => Some(Color::Pull),
476            (0 | 1, _many) => Some(Color::Push),
477            (_many, _to_many) => Some(Color::Comp),
478        }
479    }
480
481    /// Set the operator tag (for debugging/tracing).
482    pub fn set_operator_tag(&mut self, node_id: GraphNodeId, tag: String) {
483        self.operator_tag.insert(node_id, tag.to_owned());
484    }
485}
486
487/// Singleton references.
488impl DfirGraph {
489    /// Set the singletons referenced for the `node_id` operator. Each reference corresponds to the
490    /// same index in the [`crate::parse::Operator::singletons_referenced`] vec.
491    pub fn set_node_singleton_references(
492        &mut self,
493        node_id: GraphNodeId,
494        singletons_referenced: Vec<Option<GraphNodeId>>,
495    ) -> Option<Vec<Option<GraphNodeId>>> {
496        self.node_singleton_references
497            .insert(node_id, singletons_referenced)
498    }
499
500    /// Gets the singletons referenced by a node. Returns an empty iterator for non-operators and
501    /// operators that do not reference singletons.
502    pub fn node_singleton_references(&self, node_id: GraphNodeId) -> &[Option<GraphNodeId>] {
503        self.node_singleton_references
504            .get(node_id)
505            .map(std::ops::Deref::deref)
506            .unwrap_or_default()
507    }
508}
509
510/// Module methods.
511impl DfirGraph {
512    /// When modules are imported into a flat graph, they come with an input and output ModuleBoundary node.
513    /// The partitioner doesn't understand these nodes and will panic if it encounters them.
514    /// merge_modules removes them from the graph, stitching the input and ouput sides of the ModuleBondaries based on their ports
515    /// For example:
516    ///     source_iter([]) -> \[myport\]ModuleBoundary(input)\[my_port\] -> map(|x| x) -> ModuleBoundary(output) -> null();
517    /// in the above eaxmple, the \[myport\] port will be used to connect the source_iter with the map that is inside of the module.
518    /// The output module boundary has elided ports, this is also used to match up the input/output across the module boundary.
519    pub fn merge_modules(&mut self) -> Result<(), Diagnostic> {
520        let mod_bound_nodes = self
521            .nodes()
522            .filter(|(_nid, node)| matches!(node, GraphNode::ModuleBoundary { .. }))
523            .map(|(nid, _node)| nid)
524            .collect::<Vec<_>>();
525
526        for mod_bound_node in mod_bound_nodes {
527            self.remove_module_boundary(mod_bound_node)?;
528        }
529
530        Ok(())
531    }
532
533    /// see `merge_modules`
534    /// This function removes a singular module boundary from the graph and performs the necessary stitching to fix the graph afterward.
535    /// `merge_modules` calls this function for each module boundary in the graph.
536    fn remove_module_boundary(&mut self, mod_bound_node: GraphNodeId) -> Result<(), Diagnostic> {
537        assert!(
538            self.node_subgraph.is_empty() && self.subgraph_nodes.is_empty(),
539            "Should not remove intermediate node after subgraph partitioning"
540        );
541
542        let mut mod_pred_ports = BTreeMap::new();
543        let mut mod_succ_ports = BTreeMap::new();
544
545        for mod_out_edge in self.node_predecessor_edges(mod_bound_node) {
546            let (pred_port, succ_port) = self.edge_ports(mod_out_edge);
547            mod_pred_ports.insert(succ_port.clone(), (mod_out_edge, pred_port.clone()));
548        }
549
550        for mod_inn_edge in self.node_successor_edges(mod_bound_node) {
551            let (pred_port, succ_port) = self.edge_ports(mod_inn_edge);
552            mod_succ_ports.insert(pred_port.clone(), (mod_inn_edge, succ_port.clone()));
553        }
554
555        if mod_pred_ports.keys().collect::<BTreeSet<_>>()
556            != mod_succ_ports.keys().collect::<BTreeSet<_>>()
557        {
558            // get module boundary node
559            let GraphNode::ModuleBoundary { input, import_expr } = self.node(mod_bound_node) else {
560                panic!();
561            };
562
563            if *input {
564                return Err(Diagnostic {
565                    span: *import_expr,
566                    level: Level::Error,
567                    message: format!(
568                        "The ports into the module did not match. input: {:?}, expected: {:?}",
569                        mod_pred_ports.keys().map(|x| x.to_string()).join(", "),
570                        mod_succ_ports.keys().map(|x| x.to_string()).join(", ")
571                    ),
572                });
573            } else {
574                return Err(Diagnostic {
575                    span: *import_expr,
576                    level: Level::Error,
577                    message: format!(
578                        "The ports out of the module did not match. output: {:?}, expected: {:?}",
579                        mod_succ_ports.keys().map(|x| x.to_string()).join(", "),
580                        mod_pred_ports.keys().map(|x| x.to_string()).join(", "),
581                    ),
582                });
583            }
584        }
585
586        for (port, (pred_edge, pred_port)) in mod_pred_ports {
587            let (succ_edge, succ_port) = mod_succ_ports.remove(&port).unwrap();
588
589            let (src, _) = self.edge(pred_edge);
590            let (_, dst) = self.edge(succ_edge);
591            self.remove_edge(pred_edge);
592            self.remove_edge(succ_edge);
593
594            let new_edge_id = self.graph.insert_edge(src, dst);
595            self.ports.insert(new_edge_id, (pred_port, succ_port));
596        }
597
598        self.graph.remove_vertex(mod_bound_node);
599        self.nodes.remove(mod_bound_node);
600
601        Ok(())
602    }
603}
604
605/// Edge methods.
606impl DfirGraph {
607    /// Get the `src` and `dst` for an edge: `(src GraphNodeId, dst GraphNodeId)`.
608    pub fn edge(&self, edge_id: GraphEdgeId) -> (GraphNodeId, GraphNodeId) {
609        let (src, dst) = self.graph.edge(edge_id).expect("Edge not found.");
610        (src, dst)
611    }
612
613    /// Get the source and destination ports for an edge: `(src &PortIndexValue, dst &PortIndexValue)`.
614    pub fn edge_ports(&self, edge_id: GraphEdgeId) -> (&PortIndexValue, &PortIndexValue) {
615        let (src_port, dst_port) = self.ports.get(edge_id).expect("Edge not found.");
616        (src_port, dst_port)
617    }
618
619    /// Iterator of all edge IDs `GraphEdgeId`.
620    pub fn edge_ids(&self) -> slotmap::basic::Keys<'_, GraphEdgeId, (GraphNodeId, GraphNodeId)> {
621        self.graph.edge_ids()
622    }
623
624    /// Iterator over all edges: `(GraphEdgeId, (src GraphNodeId, dst GraphNodeId))`.
625    pub fn edges(
626        &self,
627    ) -> impl '_
628    + ExactSizeIterator<Item = (GraphEdgeId, (GraphNodeId, GraphNodeId))>
629    + FusedIterator
630    + Clone
631    + Debug {
632        self.graph.edges()
633    }
634
635    /// Insert an edge between nodes thru the given ports.
636    pub fn insert_edge(
637        &mut self,
638        src: GraphNodeId,
639        src_port: PortIndexValue,
640        dst: GraphNodeId,
641        dst_port: PortIndexValue,
642    ) -> GraphEdgeId {
643        let edge_id = self.graph.insert_edge(src, dst);
644        self.ports.insert(edge_id, (src_port, dst_port));
645        edge_id
646    }
647
648    /// Removes an edge and its corresponding ports and edge type info.
649    pub fn remove_edge(&mut self, edge: GraphEdgeId) {
650        let (_src, _dst) = self.graph.remove_edge(edge).unwrap();
651        let (_src_port, _dst_port) = self.ports.remove(edge).unwrap();
652    }
653}
654
655/// Subgraph methods.
656impl DfirGraph {
657    /// Nodes belonging to the given subgraph.
658    pub fn subgraph(&self, subgraph_id: GraphSubgraphId) -> &Vec<GraphNodeId> {
659        self.subgraph_nodes
660            .get(subgraph_id)
661            .expect("Subgraph not found.")
662    }
663
664    /// Iterator over all subgraph IDs.
665    pub fn subgraph_ids(&self) -> slotmap::basic::Keys<'_, GraphSubgraphId, Vec<GraphNodeId>> {
666        self.subgraph_nodes.keys()
667    }
668
669    /// Iterator over all subgraphs, ID and members: `(GraphSubgraphId, Vec<GraphNodeId>)`.
670    pub fn subgraphs(&self) -> slotmap::basic::Iter<'_, GraphSubgraphId, Vec<GraphNodeId>> {
671        self.subgraph_nodes.iter()
672    }
673
674    /// Create a subgraph consisting of `node_ids`. Returns an error if any of the nodes are already in a subgraph.
675    pub fn insert_subgraph(
676        &mut self,
677        node_ids: Vec<GraphNodeId>,
678    ) -> Result<GraphSubgraphId, (GraphNodeId, GraphSubgraphId)> {
679        // Check none are already in subgraphs
680        for &node_id in node_ids.iter() {
681            if let Some(&old_sg_id) = self.node_subgraph.get(node_id) {
682                return Err((node_id, old_sg_id));
683            }
684        }
685        let subgraph_id = self.subgraph_nodes.insert_with_key(|sg_id| {
686            for &node_id in node_ids.iter() {
687                self.node_subgraph.insert(node_id, sg_id);
688            }
689            node_ids
690        });
691
692        Ok(subgraph_id)
693    }
694
695    /// Removes a node from its subgraph. Returns true if the node was in a subgraph.
696    pub fn remove_from_subgraph(&mut self, node_id: GraphNodeId) -> bool {
697        if let Some(old_sg_id) = self.node_subgraph.remove(node_id) {
698            self.subgraph_nodes[old_sg_id].retain(|&other_node_id| other_node_id != node_id);
699            true
700        } else {
701            false
702        }
703    }
704
705    /// Gets the stratum number of the subgraph.
706    pub fn subgraph_stratum(&self, sg_id: GraphSubgraphId) -> Option<usize> {
707        self.subgraph_stratum.get(sg_id).copied()
708    }
709
710    /// Set subgraph's stratum number, returning the old value if exists.
711    pub fn set_subgraph_stratum(
712        &mut self,
713        sg_id: GraphSubgraphId,
714        stratum: usize,
715    ) -> Option<usize> {
716        self.subgraph_stratum.insert(sg_id, stratum)
717    }
718
719    /// Gets whether the subgraph is lazy or not
720    fn subgraph_laziness(&self, sg_id: GraphSubgraphId) -> bool {
721        self.subgraph_laziness.get(sg_id).copied().unwrap_or(false)
722    }
723
724    /// Set subgraph's laziness, returning the old value.
725    pub fn set_subgraph_laziness(&mut self, sg_id: GraphSubgraphId, lazy: bool) -> bool {
726        self.subgraph_laziness.insert(sg_id, lazy).unwrap_or(false)
727    }
728
729    /// Returns the the stratum number of the largest (latest) stratum (inclusive).
730    pub fn max_stratum(&self) -> Option<usize> {
731        self.subgraph_stratum.values().copied().max()
732    }
733
734    /// Helper: finds the first index in `subgraph_nodes` where it transitions from pull to push.
735    fn find_pull_to_push_idx(&self, subgraph_nodes: &[GraphNodeId]) -> usize {
736        subgraph_nodes
737            .iter()
738            .position(|&node_id| {
739                self.node_color(node_id)
740                    .is_some_and(|color| Color::Pull != color)
741            })
742            .unwrap_or(subgraph_nodes.len())
743    }
744}
745
746/// Display/output methods.
747impl DfirGraph {
748    /// Helper to generate a deterministic `Ident` for the given node.
749    fn node_as_ident(&self, node_id: GraphNodeId, is_pred: bool) -> Ident {
750        let name = match &self.nodes[node_id] {
751            GraphNode::Operator(_) => format!("op_{:?}", node_id.data()),
752            GraphNode::Handoff { .. } => format!(
753                "hoff_{:?}_{}",
754                node_id.data(),
755                if is_pred { "recv" } else { "send" }
756            ),
757            GraphNode::ModuleBoundary { .. } => panic!(),
758        };
759        let span = match (is_pred, &self.nodes[node_id]) {
760            (_, GraphNode::Operator(operator)) => operator.span(),
761            (true, &GraphNode::Handoff { src_span, .. }) => src_span,
762            (false, &GraphNode::Handoff { dst_span, .. }) => dst_span,
763            (_, GraphNode::ModuleBoundary { .. }) => panic!(),
764        };
765        Ident::new(&name, span)
766    }
767
768    /// For per-node singleton references. Helper to generate a deterministic `Ident` for the given node.
769    fn node_as_singleton_ident(&self, node_id: GraphNodeId, span: Span) -> Ident {
770        Ident::new(&format!("singleton_op_{:?}", node_id.data()), span)
771    }
772
773    /// Resolve the singletons via [`Self::node_singleton_references`] for the given `node_id`.
774    fn helper_resolve_singletons(&self, node_id: GraphNodeId, span: Span) -> Vec<Ident> {
775        self.node_singleton_references(node_id)
776            .iter()
777            .map(|singleton_node_id| {
778                // TODO(mingwei): this `expect` should be caught in error checking
779                self.node_as_singleton_ident(
780                    singleton_node_id
781                        .expect("Expected singleton to be resolved but was not, this is a bug."),
782                    span,
783                )
784            })
785            .collect::<Vec<_>>()
786    }
787
788    /// Returns each subgraph's receive and send handoffs.
789    /// `Map<GraphSubgraphId, (recv handoffs, send handoffs)>`
790    fn helper_collect_subgraph_handoffs(
791        &self,
792    ) -> SecondaryMap<GraphSubgraphId, (Vec<GraphNodeId>, Vec<GraphNodeId>)> {
793        // Get data on handoff src and dst subgraphs.
794        let mut subgraph_handoffs: SecondaryMap<
795            GraphSubgraphId,
796            (Vec<GraphNodeId>, Vec<GraphNodeId>),
797        > = self
798            .subgraph_nodes
799            .keys()
800            .map(|k| (k, Default::default()))
801            .collect();
802
803        // For each handoff node, add it to the `send`/`recv` lists for the corresponding subgraphs.
804        for (hoff_id, node) in self.nodes() {
805            if !matches!(node, GraphNode::Handoff { .. }) {
806                continue;
807            }
808            // Receivers from the handoff. (Should really only be one).
809            for (_edge, succ_id) in self.node_successors(hoff_id) {
810                let succ_sg = self.node_subgraph(succ_id).unwrap();
811                subgraph_handoffs[succ_sg].0.push(hoff_id);
812            }
813            // Senders into the handoff. (Should really only be one).
814            for (_edge, pred_id) in self.node_predecessors(hoff_id) {
815                let pred_sg = self.node_subgraph(pred_id).unwrap();
816                subgraph_handoffs[pred_sg].1.push(hoff_id);
817            }
818        }
819
820        subgraph_handoffs
821    }
822
823    /// Code for adding all nested loops.
824    fn codegen_nested_loops(&self, df: &Ident) -> TokenStream {
825        // Breadth-first iteration from outermost (root) loops to deepest nested loops.
826        let mut out = TokenStream::new();
827        let mut queue = VecDeque::from_iter(self.root_loops.iter().copied());
828        while let Some(loop_id) = queue.pop_front() {
829            let parent_opt = self
830                .loop_parent(loop_id)
831                .map(|loop_id| loop_id.as_ident(Span::call_site()))
832                .map(|ident| quote! { Some(#ident) })
833                .unwrap_or_else(|| quote! { None });
834            let loop_name = loop_id.as_ident(Span::call_site());
835            out.append_all(quote! {
836                let #loop_name = #df.add_loop(#parent_opt);
837            });
838            queue.extend(self.loop_children.get(loop_id).into_iter().flatten());
839        }
840        out
841    }
842
843    /// Emit this graph as runnable Rust source code tokens.
844    pub fn as_code(
845        &self,
846        root: &TokenStream,
847        include_type_guards: bool,
848        prefix: TokenStream,
849        diagnostics: &mut Vec<Diagnostic>,
850    ) -> TokenStream {
851        let df = Ident::new(GRAPH, Span::call_site());
852        let context = Ident::new(CONTEXT, Span::call_site());
853
854        // Code for adding handoffs.
855        let handoff_code = self
856            .nodes
857            .iter()
858            .filter_map(|(node_id, node)| match node {
859                GraphNode::Operator(_) => None,
860                &GraphNode::Handoff { src_span, dst_span } => Some((node_id, (src_span, dst_span))),
861                GraphNode::ModuleBoundary { .. } => panic!(),
862            })
863            .map(|(node_id, (src_span, dst_span))| {
864                let ident_send = Ident::new(&format!("hoff_{:?}_send", node_id.data()), dst_span);
865                let ident_recv = Ident::new(&format!("hoff_{:?}_recv", node_id.data()), src_span);
866                let span = src_span.join(dst_span).unwrap_or(src_span);
867                let mut hoff_name = Literal::string(&format!("handoff {:?}", node_id));
868                hoff_name.set_span(span);
869                let hoff_type = quote_spanned! (span=> #root::scheduled::handoff::VecHandoff<_>);
870                quote_spanned! {span=>
871                    let (#ident_send, #ident_recv) =
872                        #df.make_edge::<_, #hoff_type>(#hoff_name);
873                }
874            });
875
876        let subgraph_handoffs = self.helper_collect_subgraph_handoffs();
877
878        // we first generate the subgraphs that have no inputs to guide type inference
879        let (subgraphs_without_preds, subgraphs_with_preds) = self
880            .subgraph_nodes
881            .iter()
882            .partition::<Vec<_>, _>(|(_, nodes)| {
883                nodes
884                    .iter()
885                    .any(|&node_id| self.node_degree_in(node_id) == 0)
886            });
887
888        let mut op_prologue_code = Vec::new();
889        let mut op_prologue_after_code = Vec::new();
890        let mut subgraphs = Vec::new();
891        {
892            for &(subgraph_id, subgraph_nodes) in subgraphs_without_preds
893                .iter()
894                .chain(subgraphs_with_preds.iter())
895            {
896                let (recv_hoffs, send_hoffs) = &subgraph_handoffs[subgraph_id];
897                let recv_ports: Vec<Ident> = recv_hoffs
898                    .iter()
899                    .map(|&hoff_id| self.node_as_ident(hoff_id, true))
900                    .collect();
901                let send_ports: Vec<Ident> = send_hoffs
902                    .iter()
903                    .map(|&hoff_id| self.node_as_ident(hoff_id, false))
904                    .collect();
905
906                let recv_port_code = recv_ports.iter().map(|ident| {
907                    quote_spanned! {ident.span()=>
908                        let mut #ident = #ident.borrow_mut_swap();
909                        let #ident = #root::futures::stream::iter(#ident.drain(..));
910                    }
911                });
912                let send_port_code = send_ports.iter().map(|ident| {
913                    quote_spanned! {ident.span()=>
914                        let #ident = #root::sinktools::for_each(|v| {
915                            #ident.give(Some(v));
916                        });
917                    }
918                });
919
920                let loop_id = self
921                    // All nodes in a subgraph should be in the same loop.
922                    .node_loop(subgraph_nodes[0]);
923
924                let mut subgraph_op_iter_code = Vec::new();
925                let mut subgraph_op_iter_after_code = Vec::new();
926                {
927                    let pull_to_push_idx = self.find_pull_to_push_idx(subgraph_nodes);
928
929                    let (pull_half, push_half) = subgraph_nodes.split_at(pull_to_push_idx);
930                    let nodes_iter = pull_half.iter().chain(push_half.iter().rev());
931
932                    for (idx, &node_id) in nodes_iter.enumerate() {
933                        let node = &self.nodes[node_id];
934                        assert!(
935                            matches!(node, GraphNode::Operator(_)),
936                            "Handoffs are not part of subgraphs."
937                        );
938                        let op_inst = &self.operator_instances[node_id];
939
940                        let op_span = node.span();
941                        let op_name = op_inst.op_constraints.name;
942                        // Use op's span for root. #root is expected to be correct, any errors should span back to the op gen.
943                        let root = change_spans(root.clone(), op_span);
944                        // TODO(mingwei): Just use `op_inst.op_constraints`?
945                        let op_constraints = OPERATORS
946                            .iter()
947                            .find(|op| op_name == op.name)
948                            .unwrap_or_else(|| panic!("Failed to find op: {}", op_name));
949
950                        let ident = self.node_as_ident(node_id, false);
951
952                        {
953                            // TODO clean this up.
954                            // Collect input arguments (predecessors).
955                            let mut input_edges = self
956                                .graph
957                                .predecessor_edges(node_id)
958                                .map(|edge_id| (self.edge_ports(edge_id).1, edge_id))
959                                .collect::<Vec<_>>();
960                            // Ensure sorted by port index.
961                            input_edges.sort();
962
963                            let inputs = input_edges
964                                .iter()
965                                .map(|&(_port, edge_id)| {
966                                    let (pred, _) = self.edge(edge_id);
967                                    self.node_as_ident(pred, true)
968                                })
969                                .collect::<Vec<_>>();
970
971                            // Collect output arguments (successors).
972                            let mut output_edges = self
973                                .graph
974                                .successor_edges(node_id)
975                                .map(|edge_id| (&self.ports[edge_id].0, edge_id))
976                                .collect::<Vec<_>>();
977                            // Ensure sorted by port index.
978                            output_edges.sort();
979
980                            let outputs = output_edges
981                                .iter()
982                                .map(|&(_port, edge_id)| {
983                                    let (_, succ) = self.edge(edge_id);
984                                    self.node_as_ident(succ, false)
985                                })
986                                .collect::<Vec<_>>();
987
988                            let is_pull = idx < pull_to_push_idx;
989
990                            let singleton_output_ident = &if op_constraints.has_singleton_output {
991                                self.node_as_singleton_ident(node_id, op_span)
992                            } else {
993                                // This ident *should* go unused.
994                                Ident::new(&format!("{}_has_no_singleton_output", op_name), op_span)
995                            };
996
997                            // There's a bit of dark magic hidden in `Span`s... you'd think it's just a `file:line:column`,
998                            // but it has one extra bit of info for _name resolution_, used for `Ident`s. `Span::call_site()`
999                            // has the (unhygienic) resolution we want, an ident is just solely determined by its string name,
1000                            // which is what you'd expect out of unhygienic proc macros like this. Meanwhile, declarative macros
1001                            // use `Span::mixed_site()` which is weird and I don't understand it. It turns out that if you call
1002                            // the dfir syntax proc macro from _within_ a declarative macro then `op_span` will have the
1003                            // bad `Span::mixed_site()` name resolution and cause "Cannot find value `df/context`" errors. So
1004                            // we call `.resolved_at()` to fix resolution back to `Span::call_site()`. -Mingwei
1005                            let df_local = &Ident::new(GRAPH, op_span.resolved_at(df.span()));
1006                            let context = &Ident::new(CONTEXT, op_span.resolved_at(context.span()));
1007
1008                            let singletons_resolved =
1009                                self.helper_resolve_singletons(node_id, op_span);
1010                            let arguments = &process_singletons::postprocess_singletons(
1011                                op_inst.arguments_raw.clone(),
1012                                singletons_resolved.clone(),
1013                                context,
1014                            );
1015                            let arguments_handles =
1016                                &process_singletons::postprocess_singletons_handles(
1017                                    op_inst.arguments_raw.clone(),
1018                                    singletons_resolved.clone(),
1019                                );
1020
1021                            let source_tag = 'a: {
1022                                if let Some(tag) = self.operator_tag.get(node_id).cloned() {
1023                                    break 'a tag;
1024                                }
1025
1026                                #[cfg(nightly)]
1027                                if proc_macro::is_available() {
1028                                    let op_span = op_span.unwrap();
1029                                    break 'a format!(
1030                                        "loc_{}_{}_{}_{}_{}",
1031                                        crate::pretty_span::make_source_path_relative(
1032                                            &op_span.file()
1033                                        )
1034                                        .display()
1035                                        .to_string()
1036                                        .replace(|x: char| !x.is_ascii_alphanumeric(), "_"),
1037                                        op_span.start().line(),
1038                                        op_span.start().column(),
1039                                        op_span.end().line(),
1040                                        op_span.end().column(),
1041                                    );
1042                                }
1043
1044                                format!(
1045                                    "loc_nopath_{}_{}_{}_{}",
1046                                    op_span.start().line,
1047                                    op_span.start().column,
1048                                    op_span.end().line,
1049                                    op_span.end().column
1050                                )
1051                            };
1052
1053                            let work_fn = format_ident!(
1054                                "{}__{}__{}",
1055                                ident,
1056                                op_name,
1057                                source_tag,
1058                                span = op_span
1059                            );
1060                            let work_fn_async = format_ident!("{}__async", work_fn, span = op_span);
1061
1062                            let context_args = WriteContextArgs {
1063                                root: &root,
1064                                df_ident: df_local,
1065                                context,
1066                                subgraph_id,
1067                                node_id,
1068                                loop_id,
1069                                op_span,
1070                                op_tag: self.operator_tag.get(node_id).cloned(),
1071                                work_fn: &work_fn,
1072                                work_fn_async: &work_fn_async,
1073                                ident: &ident,
1074                                is_pull,
1075                                inputs: &inputs,
1076                                outputs: &outputs,
1077                                singleton_output_ident,
1078                                op_name,
1079                                op_inst,
1080                                arguments,
1081                                arguments_handles,
1082                            };
1083
1084                            let write_result =
1085                                (op_constraints.write_fn)(&context_args, diagnostics);
1086                            let OperatorWriteOutput {
1087                                write_prologue,
1088                                write_prologue_after,
1089                                write_iterator,
1090                                write_iterator_after,
1091                            } = write_result.unwrap_or_else(|()| {
1092                                assert!(
1093                                    diagnostics.iter().any(Diagnostic::is_error),
1094                                    "Operator `{}` returned `Err` but emitted no diagnostics, this is a bug.",
1095                                    op_name,
1096                                );
1097                                OperatorWriteOutput { write_iterator: null_write_iterator_fn(&context_args), ..Default::default() }
1098                            });
1099
1100                            op_prologue_code.push(syn::parse_quote! {
1101                                #[allow(non_snake_case)]
1102                                #[inline(always)]
1103                                fn #work_fn<T>(thunk: impl ::std::ops::FnOnce() -> T) -> T {
1104                                    thunk()
1105                                }
1106
1107                                #[allow(non_snake_case)]
1108                                #[inline(always)]
1109                                async fn #work_fn_async<T>(thunk: impl ::std::future::Future<Output = T>) -> T {
1110                                    thunk.await
1111                                }
1112                            });
1113                            op_prologue_code.push(write_prologue);
1114                            op_prologue_after_code.push(write_prologue_after);
1115                            subgraph_op_iter_code.push(write_iterator);
1116
1117                            if include_type_guards {
1118                                let type_guard = if is_pull {
1119                                    quote_spanned! {op_span=>
1120                                        let #ident = {
1121                                            #[allow(non_snake_case)]
1122                                            #[inline(always)]
1123                                            pub fn #work_fn<Item, Input: #root::futures::stream::Stream<Item = Item>>(input: Input) -> impl #root::futures::stream::Stream<Item = Item> {
1124                                                #root::pin_project_lite::pin_project! {
1125                                                    #[repr(transparent)]
1126                                                    struct Pull<Item, Input: #root::futures::stream::Stream<Item = Item>> {
1127                                                        #[pin]
1128                                                        inner: Input
1129                                                    }
1130                                                }
1131
1132                                                impl<Item, Input> #root::futures::stream::Stream for Pull<Item, Input>
1133                                                where
1134                                                    Input: #root::futures::stream::Stream<Item = Item>,
1135                                                {
1136                                                    type Item = Item;
1137
1138                                                    #[inline(always)]
1139                                                    fn poll_next(
1140                                                        self: ::std::pin::Pin<&mut Self>,
1141                                                        cx: &mut ::std::task::Context<'_>,
1142                                                    ) -> ::std::task::Poll<::std::option::Option<Self::Item>> {
1143                                                        #root::futures::stream::Stream::poll_next(self.project().inner, cx)
1144                                                    }
1145
1146                                                    #[inline(always)]
1147                                                    fn size_hint(&self) -> (usize, Option<usize>) {
1148                                                        #root::futures::stream::Stream::size_hint(&self.inner)
1149                                                    }
1150                                                }
1151
1152                                                Pull {
1153                                                    inner: input
1154                                                }
1155                                            }
1156                                            #work_fn( #ident )
1157                                        };
1158                                    }
1159                                } else {
1160                                    quote_spanned! {op_span=>
1161                                        let #ident = {
1162                                            #[allow(non_snake_case)]
1163                                            #[inline(always)]
1164                                            pub fn #work_fn<Item, Si>(si: Si) -> impl #root::futures::sink::Sink<Item, Error = #root::Never>
1165                                            where
1166                                                Si: #root::futures::sink::Sink<Item, Error = #root::Never>
1167                                            {
1168                                                #root::pin_project_lite::pin_project! {
1169                                                    #[repr(transparent)]
1170                                                    struct Push<Si> {
1171                                                        #[pin]
1172                                                        si: Si,
1173                                                    }
1174                                                }
1175
1176                                                impl<Item, Si> #root::futures::sink::Sink<Item> for Push<Si>
1177                                                where
1178                                                    Si: #root::futures::sink::Sink<Item>,
1179                                                {
1180                                                    type Error = Si::Error;
1181
1182                                                    fn poll_ready(
1183                                                        self: ::std::pin::Pin<&mut Self>,
1184                                                        cx: &mut ::std::task::Context<'_>,
1185                                                    ) -> ::std::task::Poll<::std::result::Result<(), Self::Error>> {
1186                                                        self.project().si.poll_ready(cx)
1187                                                    }
1188
1189                                                    fn start_send(
1190                                                        self: ::std::pin::Pin<&mut Self>,
1191                                                        item: Item,
1192                                                    ) -> ::std::result::Result<(), Self::Error> {
1193                                                        self.project().si.start_send(item)
1194                                                    }
1195
1196                                                    fn poll_flush(
1197                                                        self: ::std::pin::Pin<&mut Self>,
1198                                                        cx: &mut ::std::task::Context<'_>,
1199                                                    ) -> ::std::task::Poll<::std::result::Result<(), Self::Error>> {
1200                                                        self.project().si.poll_flush(cx)
1201                                                    }
1202
1203                                                    fn poll_close(
1204                                                        self: ::std::pin::Pin<&mut Self>,
1205                                                        cx: &mut ::std::task::Context<'_>,
1206                                                    ) -> ::std::task::Poll<::std::result::Result<(), Self::Error>> {
1207                                                        self.project().si.poll_close(cx)
1208                                                    }
1209                                                }
1210
1211                                                Push {
1212                                                    si
1213                                                }
1214                                            }
1215                                            #work_fn( #ident )
1216                                        };
1217                                    }
1218                                };
1219                                subgraph_op_iter_code.push(type_guard);
1220                            }
1221                            subgraph_op_iter_after_code.push(write_iterator_after);
1222                        }
1223                    }
1224
1225                    {
1226                        // Determine pull and push halves of the `Pivot`.
1227                        let pull_ident = if 0 < pull_to_push_idx {
1228                            self.node_as_ident(subgraph_nodes[pull_to_push_idx - 1], false)
1229                        } else {
1230                            // Entire subgraph is push (with a single recv/pull handoff input).
1231                            recv_ports[0].clone()
1232                        };
1233
1234                        #[rustfmt::skip]
1235                        let push_ident = if let Some(&node_id) =
1236                            subgraph_nodes.get(pull_to_push_idx)
1237                        {
1238                            self.node_as_ident(node_id, false)
1239                        } else if 1 == send_ports.len() {
1240                            // Entire subgraph is pull (with a single send/push handoff output).
1241                            send_ports[0].clone()
1242                        } else {
1243                            diagnostics.push(Diagnostic::spanned(
1244                                pull_ident.span(),
1245                                Level::Error,
1246                                "Degenerate subgraph detected, is there a disconnected `null()` or other degenerate pipeline somewhere?",
1247                            ));
1248                            continue;
1249                        };
1250
1251                        // Pivot span is combination of pull and push spans (or if not possible, just take the push).
1252                        let pivot_span = pull_ident
1253                            .span()
1254                            .join(push_ident.span())
1255                            .unwrap_or_else(|| push_ident.span());
1256                        let pivot_fn_ident =
1257                            Ident::new(&format!("pivot_run_sg_{:?}", subgraph_id.0), pivot_span);
1258                        let root = change_spans(root.clone(), pivot_span);
1259                        subgraph_op_iter_code.push(quote_spanned! {pivot_span=>
1260                            #[inline(always)]
1261                            fn #pivot_fn_ident<Pull, Push, Item>(pull: Pull, push: Push)
1262                                -> impl ::std::future::Future<Output = ::std::result::Result<(), #root::Never>>
1263                            where
1264                                Pull: #root::futures::stream::Stream<Item = Item>,
1265                                Push: #root::futures::sink::Sink<Item, Error = #root::Never>,
1266                            {
1267                                #root::sinktools::send_stream(pull, push)
1268                            }
1269                            (#pivot_fn_ident)(#pull_ident, #push_ident).await.unwrap(/* Never */);
1270                        });
1271                    }
1272                };
1273
1274                let subgraph_name = Literal::string(&format!("Subgraph {:?}", subgraph_id));
1275                let stratum = Literal::usize_unsuffixed(
1276                    self.subgraph_stratum.get(subgraph_id).cloned().unwrap_or(0),
1277                );
1278                let laziness = self.subgraph_laziness(subgraph_id);
1279
1280                // Codegen: the loop that this subgraph is in `Some(<loop_id>)`, or `None` if not in a loop.
1281                let loop_id_opt = loop_id
1282                    .map(|loop_id| loop_id.as_ident(Span::call_site()))
1283                    .map(|ident| quote! { Some(#ident) })
1284                    .unwrap_or_else(|| quote! { None });
1285
1286                let sg_ident = subgraph_id.as_ident(Span::call_site());
1287
1288                subgraphs.push(quote! {
1289                    let #sg_ident = #df.add_subgraph_full(
1290                        #subgraph_name,
1291                        #stratum,
1292                        var_expr!( #( #recv_ports ),* ),
1293                        var_expr!( #( #send_ports ),* ),
1294                        #laziness,
1295                        #loop_id_opt,
1296                        async move |#context, var_args!( #( #recv_ports ),* ), var_args!( #( #send_ports ),* )| {
1297                            #( #recv_port_code )*
1298                            #( #send_port_code )*
1299                            #( #subgraph_op_iter_code )*
1300                            #( #subgraph_op_iter_after_code )*
1301                        },
1302                    );
1303                });
1304            }
1305        }
1306
1307        let loop_code = self.codegen_nested_loops(&df);
1308
1309        // These two are quoted separately here because iterators are lazily evaluated, so this
1310        // forces them to do their work. This work includes populating some data, namely
1311        // `diagonstics`, which we need to determine if it compilation was actually successful.
1312        // -Mingwei
1313        let code = quote! {
1314            #( #handoff_code )*
1315            #loop_code
1316            #( #op_prologue_code )*
1317            #( #subgraphs )*
1318            #( #op_prologue_after_code )*
1319        };
1320
1321        let meta_graph_json = serde_json::to_string(&self).unwrap();
1322        let meta_graph_json = Literal::string(&meta_graph_json);
1323
1324        let serde_diagnostics: Vec<_> = diagnostics.iter().map(Diagnostic::to_serde).collect();
1325        let diagnostics_json = serde_json::to_string(&*serde_diagnostics).unwrap();
1326        let diagnostics_json = Literal::string(&diagnostics_json);
1327
1328        quote! {
1329            {
1330                #[allow(unused_qualifications, clippy::await_holding_refcell_ref)]
1331                {
1332                    #prefix
1333
1334                    use #root::{var_expr, var_args};
1335
1336                    let mut #df = #root::scheduled::graph::Dfir::new();
1337                    #df.__assign_meta_graph(#meta_graph_json);
1338                    #df.__assign_diagnostics(#diagnostics_json);
1339
1340                    #code
1341
1342                    #df
1343                }
1344            }
1345        }
1346    }
1347
1348    /// Color mode (pull vs. push, handoff vs. comp) for nodes. Some nodes can be push *OR* pull;
1349    /// those nodes will not be set in the returned map.
1350    pub fn node_color_map(&self) -> SparseSecondaryMap<GraphNodeId, Color> {
1351        let mut node_color_map: SparseSecondaryMap<GraphNodeId, Color> = self
1352            .node_ids()
1353            .filter_map(|node_id| {
1354                let op_color = self.node_color(node_id)?;
1355                Some((node_id, op_color))
1356            })
1357            .collect();
1358
1359        // Fill in rest via subgraphs.
1360        for sg_nodes in self.subgraph_nodes.values() {
1361            let pull_to_push_idx = self.find_pull_to_push_idx(sg_nodes);
1362
1363            for (idx, node_id) in sg_nodes.iter().copied().enumerate() {
1364                let is_pull = idx < pull_to_push_idx;
1365                node_color_map.insert(node_id, if is_pull { Color::Pull } else { Color::Push });
1366            }
1367        }
1368
1369        node_color_map
1370    }
1371
1372    /// Writes this graph as mermaid into a string.
1373    pub fn to_mermaid(&self, write_config: &WriteConfig) -> String {
1374        let mut output = String::new();
1375        self.write_mermaid(&mut output, write_config).unwrap();
1376        output
1377    }
1378
1379    /// Writes this graph as mermaid into the given `Write`.
1380    pub fn write_mermaid(
1381        &self,
1382        output: impl std::fmt::Write,
1383        write_config: &WriteConfig,
1384    ) -> std::fmt::Result {
1385        let mut graph_write = Mermaid::new(output);
1386        self.write_graph(&mut graph_write, write_config)
1387    }
1388
1389    /// Writes this graph as DOT (graphviz) into a string.
1390    pub fn to_dot(&self, write_config: &WriteConfig) -> String {
1391        let mut output = String::new();
1392        let mut graph_write = Dot::new(&mut output);
1393        self.write_graph(&mut graph_write, write_config).unwrap();
1394        output
1395    }
1396
1397    /// Writes this graph as DOT (graphviz) into the given `Write`.
1398    pub fn write_dot(
1399        &self,
1400        output: impl std::fmt::Write,
1401        write_config: &WriteConfig,
1402    ) -> std::fmt::Result {
1403        let mut graph_write = Dot::new(output);
1404        self.write_graph(&mut graph_write, write_config)
1405    }
1406
1407    /// Write out this graph using the given `GraphWrite`. E.g. `Mermaid` or `Dot.
1408    pub(crate) fn write_graph<W>(
1409        &self,
1410        mut graph_write: W,
1411        write_config: &WriteConfig,
1412    ) -> Result<(), W::Err>
1413    where
1414        W: GraphWrite,
1415    {
1416        fn helper_edge_label(
1417            src_port: &PortIndexValue,
1418            dst_port: &PortIndexValue,
1419        ) -> Option<String> {
1420            let src_label = match src_port {
1421                PortIndexValue::Path(path) => Some(path.to_token_stream().to_string()),
1422                PortIndexValue::Int(index) => Some(index.value.to_string()),
1423                _ => None,
1424            };
1425            let dst_label = match dst_port {
1426                PortIndexValue::Path(path) => Some(path.to_token_stream().to_string()),
1427                PortIndexValue::Int(index) => Some(index.value.to_string()),
1428                _ => None,
1429            };
1430            let label = match (src_label, dst_label) {
1431                (Some(l1), Some(l2)) => Some(format!("{}\n{}", l1, l2)),
1432                (Some(l1), None) => Some(l1),
1433                (None, Some(l2)) => Some(l2),
1434                (None, None) => None,
1435            };
1436            label
1437        }
1438
1439        // Make node color map one time.
1440        let node_color_map = self.node_color_map();
1441
1442        // Write prologue.
1443        graph_write.write_prologue()?;
1444
1445        // Define nodes.
1446        let mut skipped_handoffs = BTreeSet::new();
1447        let mut subgraph_handoffs = <BTreeMap<GraphSubgraphId, Vec<GraphNodeId>>>::new();
1448        for (node_id, node) in self.nodes() {
1449            if matches!(node, GraphNode::Handoff { .. }) {
1450                if write_config.no_handoffs {
1451                    skipped_handoffs.insert(node_id);
1452                    continue;
1453                } else {
1454                    let pred_node = self.node_predecessor_nodes(node_id).next().unwrap();
1455                    let pred_sg = self.node_subgraph(pred_node);
1456                    let succ_node = self.node_successor_nodes(node_id).next().unwrap();
1457                    let succ_sg = self.node_subgraph(succ_node);
1458                    if let Some((pred_sg, succ_sg)) = pred_sg.zip(succ_sg)
1459                        && pred_sg == succ_sg
1460                    {
1461                        subgraph_handoffs.entry(pred_sg).or_default().push(node_id);
1462                    }
1463                }
1464            }
1465            graph_write.write_node_definition(
1466                node_id,
1467                &if write_config.op_short_text {
1468                    node.to_name_string()
1469                } else if write_config.op_text_no_imports {
1470                    // Remove any lines that start with "use" (imports)
1471                    let full_text = node.to_pretty_string();
1472                    let mut output = String::new();
1473                    for sentence in full_text.split('\n') {
1474                        if sentence.trim().starts_with("use") {
1475                            continue;
1476                        }
1477                        output.push('\n');
1478                        output.push_str(sentence);
1479                    }
1480                    output.into()
1481                } else {
1482                    node.to_pretty_string()
1483                },
1484                if write_config.no_pull_push {
1485                    None
1486                } else {
1487                    node_color_map.get(node_id).copied()
1488                },
1489            )?;
1490        }
1491
1492        // Write edges.
1493        for (edge_id, (src_id, mut dst_id)) in self.edges() {
1494            // Handling for if `write_config.no_handoffs` true.
1495            if skipped_handoffs.contains(&src_id) {
1496                continue;
1497            }
1498
1499            let (src_port, mut dst_port) = self.edge_ports(edge_id);
1500            if skipped_handoffs.contains(&dst_id) {
1501                let mut handoff_succs = self.node_successors(dst_id);
1502                assert_eq!(1, handoff_succs.len());
1503                let (succ_edge, succ_node) = handoff_succs.next().unwrap();
1504                dst_id = succ_node;
1505                dst_port = self.edge_ports(succ_edge).1;
1506            }
1507
1508            let label = helper_edge_label(src_port, dst_port);
1509            let delay_type = self
1510                .node_op_inst(dst_id)
1511                .and_then(|op_inst| (op_inst.op_constraints.input_delaytype_fn)(dst_port));
1512            graph_write.write_edge(src_id, dst_id, delay_type, label.as_deref(), false)?;
1513        }
1514
1515        // Write reference edges.
1516        if !write_config.no_references {
1517            for dst_id in self.node_ids() {
1518                for src_ref_id in self
1519                    .node_singleton_references(dst_id)
1520                    .iter()
1521                    .copied()
1522                    .flatten()
1523                {
1524                    let delay_type = Some(DelayType::Stratum);
1525                    let label = None;
1526                    graph_write.write_edge(src_ref_id, dst_id, delay_type, label, true)?;
1527                }
1528            }
1529        }
1530
1531        // The following code is a little bit tricky. Generally, the graph has the hierarchy:
1532        // `loop -> subgraph -> varname -> node`. However, each of these can be disabled via the `write_config`. To
1533        // handle both the enabled and disabled case, this code is structured as a series of nested loops. If the layer
1534        // is disabled, then the HashMap<Option<KEY>, Vec<VALUE>> will only have a single key (`None`) with a
1535        // corresponding `Vec` value containing everything. This way no special handling is needed for the next layer.
1536        //
1537        // (Note: `stratum` could also be included in this hierarchy, but it is being phased-out/deprecated in favor of
1538        // Flo loops).
1539
1540        // Loop -> Subgraphs
1541        let loop_subgraphs = self.subgraph_ids().map(|sg_id| {
1542            let loop_id = if write_config.no_loops {
1543                None
1544            } else {
1545                self.subgraph_loop(sg_id)
1546            };
1547            (loop_id, sg_id)
1548        });
1549        let loop_subgraphs = into_group_map(loop_subgraphs);
1550        for (loop_id, subgraph_ids) in loop_subgraphs {
1551            if let Some(loop_id) = loop_id {
1552                graph_write.write_loop_start(loop_id)?;
1553            }
1554
1555            // Subgraph -> Varnames.
1556            let subgraph_varnames_nodes = subgraph_ids.into_iter().flat_map(|sg_id| {
1557                self.subgraph(sg_id).iter().copied().map(move |node_id| {
1558                    let opt_sg_id = if write_config.no_subgraphs {
1559                        None
1560                    } else {
1561                        Some(sg_id)
1562                    };
1563                    (opt_sg_id, (self.node_varname(node_id), node_id))
1564                })
1565            });
1566            let subgraph_varnames_nodes = into_group_map(subgraph_varnames_nodes);
1567            for (sg_id, varnames) in subgraph_varnames_nodes {
1568                if let Some(sg_id) = sg_id {
1569                    let stratum = self.subgraph_stratum(sg_id).unwrap();
1570                    graph_write.write_subgraph_start(sg_id, stratum)?;
1571                }
1572
1573                // Varnames -> Nodes.
1574                let varname_nodes = varnames.into_iter().map(|(varname, node)| {
1575                    let varname = if write_config.no_varnames {
1576                        None
1577                    } else {
1578                        varname
1579                    };
1580                    (varname, node)
1581                });
1582                let varname_nodes = into_group_map(varname_nodes);
1583                for (varname, node_ids) in varname_nodes {
1584                    if let Some(varname) = varname {
1585                        graph_write.write_varname_start(&varname.0.to_string(), sg_id)?;
1586                    }
1587
1588                    // Write all nodes.
1589                    for node_id in node_ids {
1590                        graph_write.write_node(node_id)?;
1591                    }
1592
1593                    if varname.is_some() {
1594                        graph_write.write_varname_end()?;
1595                    }
1596                }
1597
1598                if sg_id.is_some() {
1599                    graph_write.write_subgraph_end()?;
1600                }
1601            }
1602
1603            if loop_id.is_some() {
1604                graph_write.write_loop_end()?;
1605            }
1606        }
1607
1608        // Write epilogue.
1609        graph_write.write_epilogue()?;
1610
1611        Ok(())
1612    }
1613
1614    /// Convert back into surface syntax.
1615    pub fn surface_syntax_string(&self) -> String {
1616        let mut string = String::new();
1617        self.write_surface_syntax(&mut string).unwrap();
1618        string
1619    }
1620
1621    /// Convert back into surface syntax.
1622    pub fn write_surface_syntax(&self, write: &mut impl std::fmt::Write) -> std::fmt::Result {
1623        for (key, node) in self.nodes.iter() {
1624            match node {
1625                GraphNode::Operator(op) => {
1626                    writeln!(write, "{:?} = {};", key.data(), op.to_token_stream())?;
1627                }
1628                GraphNode::Handoff { .. } => {
1629                    writeln!(write, "// {:?} = <handoff>;", key.data())?;
1630                }
1631                GraphNode::ModuleBoundary { .. } => panic!(),
1632            }
1633        }
1634        writeln!(write)?;
1635        for (_e, (src_key, dst_key)) in self.graph.edges() {
1636            writeln!(write, "{:?} -> {:?};", src_key.data(), dst_key.data())?;
1637        }
1638        Ok(())
1639    }
1640
1641    /// Convert into a [mermaid](https://mermaid-js.github.io/) graph. Ignores subgraphs.
1642    pub fn mermaid_string_flat(&self) -> String {
1643        let mut string = String::new();
1644        self.write_mermaid_flat(&mut string).unwrap();
1645        string
1646    }
1647
1648    /// Convert into a [mermaid](https://mermaid-js.github.io/) graph. Ignores subgraphs.
1649    pub fn write_mermaid_flat(&self, write: &mut impl std::fmt::Write) -> std::fmt::Result {
1650        writeln!(write, "flowchart TB")?;
1651        for (key, node) in self.nodes.iter() {
1652            match node {
1653                GraphNode::Operator(operator) => writeln!(
1654                    write,
1655                    "    %% {span}\n    {id:?}[\"{row_col} <tt>{code}</tt>\"]",
1656                    span = PrettySpan(node.span()),
1657                    id = key.data(),
1658                    row_col = PrettyRowCol(node.span()),
1659                    code = operator
1660                        .to_token_stream()
1661                        .to_string()
1662                        .replace('&', "&amp;")
1663                        .replace('<', "&lt;")
1664                        .replace('>', "&gt;")
1665                        .replace('"', "&quot;")
1666                        .replace('\n', "<br>"),
1667                ),
1668                GraphNode::Handoff { .. } => {
1669                    writeln!(write, r#"    {:?}{{"{}"}}"#, key.data(), HANDOFF_NODE_STR)
1670                }
1671                GraphNode::ModuleBoundary { .. } => {
1672                    writeln!(
1673                        write,
1674                        r#"    {:?}{{"{}"}}"#,
1675                        key.data(),
1676                        MODULE_BOUNDARY_NODE_STR
1677                    )
1678                }
1679            }?;
1680        }
1681        writeln!(write)?;
1682        for (_e, (src_key, dst_key)) in self.graph.edges() {
1683            writeln!(write, "    {:?}-->{:?}", src_key.data(), dst_key.data())?;
1684        }
1685        Ok(())
1686    }
1687}
1688
1689/// Loops
1690impl DfirGraph {
1691    /// Iterator over all loop IDs.
1692    pub fn loop_ids(&self) -> slotmap::basic::Keys<'_, GraphLoopId, Vec<GraphNodeId>> {
1693        self.loop_nodes.keys()
1694    }
1695
1696    /// Iterator over all loops, ID and members: `(GraphLoopId, Vec<GraphNodeId>)`.
1697    pub fn loops(&self) -> slotmap::basic::Iter<'_, GraphLoopId, Vec<GraphNodeId>> {
1698        self.loop_nodes.iter()
1699    }
1700
1701    /// Create a new loop context, with the given parent loop (or `None`).
1702    pub fn insert_loop(&mut self, parent_loop: Option<GraphLoopId>) -> GraphLoopId {
1703        let loop_id = self.loop_nodes.insert(Vec::new());
1704        self.loop_children.insert(loop_id, Vec::new());
1705        if let Some(parent_loop) = parent_loop {
1706            self.loop_parent.insert(loop_id, parent_loop);
1707            self.loop_children
1708                .get_mut(parent_loop)
1709                .unwrap()
1710                .push(loop_id);
1711        } else {
1712            self.root_loops.push(loop_id);
1713        }
1714        loop_id
1715    }
1716
1717    /// Get a node's loop context (or `None` for root).
1718    pub fn node_loop(&self, node_id: GraphNodeId) -> Option<GraphLoopId> {
1719        self.node_loops.get(node_id).copied()
1720    }
1721
1722    /// Get a subgraph's loop context (or `None` for root).
1723    pub fn subgraph_loop(&self, subgraph_id: GraphSubgraphId) -> Option<GraphLoopId> {
1724        let &node_id = self.subgraph(subgraph_id).first().unwrap();
1725        let out = self.node_loop(node_id);
1726        debug_assert!(
1727            self.subgraph(subgraph_id)
1728                .iter()
1729                .all(|&node_id| self.node_loop(node_id) == out),
1730            "Subgraph nodes should all have the same loop context."
1731        );
1732        out
1733    }
1734
1735    /// Get a loop context's parent loop context (or `None` for root).
1736    pub fn loop_parent(&self, loop_id: GraphLoopId) -> Option<GraphLoopId> {
1737        self.loop_parent.get(loop_id).copied()
1738    }
1739
1740    /// Get a loop context's child loops.
1741    pub fn loop_children(&self, loop_id: GraphLoopId) -> &Vec<GraphLoopId> {
1742        self.loop_children.get(loop_id).unwrap()
1743    }
1744}
1745
/// Configuration for writing graphs.
// Every flag defaults to `false` (via the derived `Default`), meaning every element is rendered
// with its full operator text. With the `clap-derive` feature enabled, the `///` field docs
// below double as `--help` text for the corresponding `--long` flags, so they are user-facing.
#[derive(Clone, Debug, Default)]
#[cfg_attr(feature = "clap-derive", derive(clap::Args))]
pub struct WriteConfig {
    // Elements to omit from the rendered graph.
    /// Subgraphs will not be rendered if set.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub no_subgraphs: bool,
    /// Variable names will not be rendered if set.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub no_varnames: bool,
    /// Will not render pull/push shapes if set.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub no_pull_push: bool,
    /// Will not render handoffs if set.
    // When set, `DfirGraph::write_graph` drops each hidden handoff's outgoing edge and reroutes
    // its incoming edge to the handoff's successor.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub no_handoffs: bool,
    /// Will not render singleton references if set.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub no_references: bool,
    /// Will not render loops if set.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub no_loops: bool,

    // Operator label text. If both flags are set, `op_short_text` wins: `DfirGraph::write_graph`
    // checks it first.
    /// Op text will only be their name instead of the whole source.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub op_short_text: bool,
    /// Op text will exclude any line that starts with "use".
    // Lines are trimmed before the check, so indented `use` lines are removed too.
    // NOTE(review): this is a plain prefix match, so non-import lines such as `user_fn(..)` are
    // dropped as well.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub op_text_no_imports: bool,
}
1776
/// Enum for choosing between mermaid and dot graph writing.
// With the `clap-derive` feature enabled this is a `clap::ValueEnum`, so the variant `///` docs
// below also serve as the per-value help text on the command line.
#[derive(Copy, Clone, Debug)]
#[cfg_attr(feature = "clap-derive", derive(clap::Parser, clap::ValueEnum))]
pub enum WriteGraphType {
    /// Mermaid graphs.
    Mermaid,
    /// Dot (Graphviz) graphs.
    Dot,
}
1786
/// Groups `(key, value)` pairs by key, like [`itertools::Itertools::into_group_map`] but
/// producing a `BTreeMap`.
///
/// Groups are iterated in ascending key order. Within each group, values keep the order in which
/// they appeared in `iter`.
fn into_group_map<K, V>(iter: impl IntoIterator<Item = (K, V)>) -> BTreeMap<K, Vec<V>>
where
    K: Ord,
{
    iter.into_iter()
        .fold(BTreeMap::<K, Vec<V>>::new(), |mut groups, (key, value)| {
            groups.entry(key).or_insert_with(Vec::new).push(value);
            groups
        })
}