Skip to main content

dfir_lang/graph/
flat_graph_builder.rs

//! Build a flat graph from [`DfirStatement`]s.
2
3use std::borrow::Cow;
4use std::collections::btree_map::Entry;
5use std::collections::{BTreeMap, BTreeSet};
6
7use itertools::Itertools;
8use proc_macro2::Span;
9use quote::ToTokens;
10use syn::spanned::Spanned;
11use syn::{Error, Ident, ItemUse};
12
13use super::ops::next_iteration::NEXT_ITERATION;
14use super::ops::{FloType, Persistence};
15use super::{DfirGraph, GraphEdgeId, GraphLoopId, GraphNode, GraphNodeId, PortIndexValue};
16use crate::diagnostic::{Diagnostic, Diagnostics, Level};
17use crate::graph::graph_algorithms;
18use crate::graph::ops::{PortListSpec, RangeTrait};
19use crate::parse::{DfirCode, DfirStatement, Operator, Pipeline};
20use crate::pretty_span::PrettySpan;
21
/// The two open ends of a (sub)pipeline, used while wiring pipelines together.
///
/// `inn` is the end that receives data from upstream (it becomes an edge's destination) and
/// `out` is the end that sends data downstream (it becomes an edge's source). Each end pairs the
/// port used at that end with the node it attaches to, which may still be an unresolved variable
/// name (see [`GraphDet`]). An end is `None` when it could not be constructed, e.g. `mod` used
/// outside of a module.
#[derive(Clone, Debug)]
struct Ends {
    /// Input end: the port and (possibly undetermined) node that receives data.
    inn: Option<(PortIndexValue, GraphDet)>,
    /// Output end: the port and (possibly undetermined) node that sends data.
    out: Option<(PortIndexValue, GraphDet)>,
}
27
/// The node an [`Ends`] end attaches to, which may not be known yet while building.
#[derive(Clone, Debug)]
enum GraphDet {
    /// The end is attached to this already-inserted node.
    Determined(GraphNodeId),
    /// The end refers to a variable name. It is resolved to a node only after all names have been
    /// assigned (see `helper_resolve_name`), which is what allows forward references.
    Undetermined(Ident),
}
33
34/// Variable name info for each ident, see [`FlatGraphBuilder::varname_ends`].
35#[derive(Debug)]
36struct VarnameInfo {
37    /// What the variable name resolves to.
38    pub ends: Ends,
39    /// Set to true if the varname reference creates an illegal self-referential cycle.
40    pub illegal_cycle: bool,
41    /// Set to true once the in port is used. Used to track unused ports.
42    pub inn_used: bool,
43    /// Set to true once the out port is used. Used to track unused ports.
44    pub out_used: bool,
45}
46impl VarnameInfo {
47    pub fn new(ends: Ends) -> Self {
48        Self {
49            ends,
50            illegal_cycle: false,
51            inn_used: false,
52            out_used: false,
53        }
54    }
55}
56
/// Wrapper around [`DfirGraph`] to build a flat graph from AST code.
#[derive(Debug, Default)]
pub struct FlatGraphBuilder {
    /// Spanned error/warning/etc diagnostics to emit.
    diagnostics: Diagnostics,

    /// [`DfirGraph`] being built.
    flat_graph: DfirGraph,
    /// Assigned variable names, added by [`DfirStatement::Named`] statements and by
    /// [`Self::append_assign_pipeline`].
    varname_ends: BTreeMap<Ident, VarnameInfo>,
    /// Each (out -> inn) link inputted. These are turned into real graph edges by
    /// `finalize_connect_operator_links` during [`Self::build`].
    links: Vec<Ends>,

    /// Use statements.
    uses: Vec<ItemUse>,

    /// If the flat graph is being loaded as a module, then two initial ModuleBoundary nodes are inserted into the graph. One
    /// for the input into the module and one for the output out of the module.
    /// Stored as `(input_node, output_node)`.
    module_boundary_nodes: Option<(GraphNodeId, GraphNodeId)>,
}
77
/// Output of [`FlatGraphBuilder::build`], produced only when no error diagnostics were emitted.
pub struct FlatGraphBuilderOutput {
    /// The flat DFIR graph.
    pub flat_graph: DfirGraph,
    /// Any `use` statements.
    pub uses: Vec<ItemUse>,
    /// Any emitted diagnostics. This contains no errors, but may contain warnings.
    pub diagnostics: Diagnostics,
}
87
88impl FlatGraphBuilder {
89    /// Create a new empty graph builder.
90    pub fn new() -> Self {
91        Default::default()
92    }
93
94    /// Convert the DFIR code AST into a graph builder.
95    pub fn from_dfir(input: DfirCode) -> Self {
96        let mut builder = Self::default();
97        builder.add_dfir(input, None, None);
98        builder
99    }
100
101    /// Build into an unpartitioned [`DfirGraph`], returning a struct containing the flat graph, any diagnostics, and
102    /// other outputs.
103    ///
104    /// If any diagnostics are errors, `Err` is returned and the underlying graph is lost.
105    pub fn build(mut self) -> Result<FlatGraphBuilderOutput, Diagnostics> {
106        self.finalize_connect_operator_links();
107        self.process_operator_errors();
108
109        if self.diagnostics.has_error() {
110            Err(self.diagnostics)
111        } else {
112            Ok(FlatGraphBuilderOutput {
113                flat_graph: self.flat_graph,
114                uses: self.uses,
115                diagnostics: self.diagnostics,
116            })
117        }
118    }
119
120    /// Adds all [`DfirStatement`]s within the [`DfirCode`] to this [`DfirGraph`].
121    ///
122    /// Optional configuration:
123    /// * In the given loop context `current_loop`.
124    /// * With the given operator tag `operator_tag`.
125    pub fn add_dfir(
126        &mut self,
127        dfir: DfirCode,
128        current_loop: Option<GraphLoopId>,
129        operator_tag: Option<&str>,
130    ) {
131        for stmt in dfir.statements {
132            self.add_statement_internal(stmt, current_loop, operator_tag);
133        }
134    }
135
136    /// Add a single [`DfirStatement`] line to this [`DfirGraph`] in the root context.
137    pub fn add_statement(&mut self, stmt: DfirStatement) {
138        self.add_statement_internal(stmt, None, None);
139    }
140
141    /// Add a single [`DfirStatement`] line to this [`DfirGraph`] with given configuration.
142    ///
143    /// Optional configuration:
144    /// * In the given loop context `current_loop`.
145    /// * With the given operator tag `operator_tag`.
146    fn add_statement_internal(
147        &mut self,
148        stmt: DfirStatement,
149        current_loop: Option<GraphLoopId>,
150        operator_tag: Option<&str>,
151    ) {
152        match stmt {
153            DfirStatement::Use(yuse) => {
154                self.uses.push(yuse);
155            }
156            DfirStatement::Named(named) => {
157                let stmt_span = named.span();
158                let ends = self.add_pipeline(
159                    named.pipeline,
160                    Some(&named.name),
161                    current_loop,
162                    operator_tag,
163                );
164                self.assign_varname_checked(named.name, stmt_span, ends);
165            }
166            DfirStatement::Pipeline(pipeline_stmt) => {
167                let ends =
168                    self.add_pipeline(pipeline_stmt.pipeline, None, current_loop, operator_tag);
169                Self::helper_check_unused_port(&mut self.diagnostics, &ends, true);
170                Self::helper_check_unused_port(&mut self.diagnostics, &ends, false);
171            }
172            DfirStatement::Loop(loop_statement) => {
173                let inner_loop = self.flat_graph.insert_loop(current_loop);
174                for stmt in loop_statement.statements {
175                    self.add_statement_internal(stmt, Some(inner_loop), operator_tag);
176                }
177            }
178        }
179    }
180
181    /// Programatically add an pipeline, optionally adding `pred_name` as a single predecessor and
182    /// assigning it all to `asgn_name`.
183    ///
184    /// In DFIR syntax, equivalent to [`Self::add_statement`] of (if all names are supplied):
185    /// ```text
186    /// #asgn_name = #pred_name -> #pipeline;
187    /// ```
188    ///
189    /// But with, optionally:
190    /// * A `current_loop` to put the operator in.
191    /// * An `operator_tag` to tag the operator with, for debugging/tracing.
192    pub fn append_assign_pipeline(
193        &mut self,
194        asgn_name: Option<&Ident>,
195        pred_name: Option<&Ident>,
196        pipeline: Pipeline,
197        current_loop: Option<GraphLoopId>,
198        operator_tag: Option<&str>,
199    ) {
200        let span = pipeline.span();
201        let mut ends = self.add_pipeline(pipeline, asgn_name, current_loop, operator_tag);
202
203        // Connect `pred_name` if supplied.
204        if let Some(pred_name) = pred_name {
205            if let Some(pred_varname_info) = self.varname_ends.get(pred_name) {
206                // Update ends for `asgn_name`.
207                ends = self.connect_ends(pred_varname_info.ends.clone(), ends);
208            } else {
209                self.diagnostics.push(Diagnostic::spanned(
210                    pred_name.span(),
211                    Level::Error,
212                    format!(
213                        "Cannot find referenced name `{}`; name was never assigned.",
214                        pred_name
215                    ),
216                ));
217            }
218        }
219
220        // Assign `asgn_name` if supplied.
221        if let Some(asgn_name) = asgn_name {
222            self.assign_varname_checked(asgn_name.clone(), span, ends);
223        }
224    }
225}
226
227/// Internal methods.
228impl FlatGraphBuilder {
229    /// Assign a variable name to a pipeline, checking for conflicts.
230    fn assign_varname_checked(&mut self, name: Ident, stmt_span: Span, ends: Ends) {
231        match self.varname_ends.entry(name) {
232            Entry::Vacant(vacant_entry) => {
233                vacant_entry.insert(VarnameInfo::new(ends));
234            }
235            Entry::Occupied(occupied_entry) => {
236                let prev_conflict = occupied_entry.key();
237                self.diagnostics.push(Diagnostic::spanned(
238                    prev_conflict.span(),
239                    Level::Error,
240                    format!(
241                        "Existing assignment to `{}` conflicts with later assignment: {} (1/2)",
242                        prev_conflict,
243                        PrettySpan(stmt_span),
244                    ),
245                ));
246                self.diagnostics.push(Diagnostic::spanned(
247                    stmt_span,
248                    Level::Error,
249                    format!(
250                        "Name assignment to `{}` conflicts with existing assignment: {} (2/2)",
251                        prev_conflict,
252                        PrettySpan(prev_conflict.span())
253                    ),
254                ));
255            }
256        }
257    }
258
259    /// Helper: Add a pipeline, i.e. `a -> b -> c`. Return the input and output [`Ends`] for it.
260    fn add_pipeline(
261        &mut self,
262        pipeline: Pipeline,
263        current_varname: Option<&Ident>,
264        current_loop: Option<GraphLoopId>,
265        operator_tag: Option<&str>,
266    ) -> Ends {
267        match pipeline {
268            Pipeline::Paren(ported_pipeline_paren) => {
269                let (inn_port, pipeline_paren, out_port) =
270                    PortIndexValue::from_ported(ported_pipeline_paren);
271                let og_ends = self.add_pipeline(
272                    *pipeline_paren.pipeline,
273                    current_varname,
274                    current_loop,
275                    operator_tag,
276                );
277                Self::helper_combine_ends(&mut self.diagnostics, og_ends, inn_port, out_port)
278            }
279            Pipeline::Name(pipeline_name) => {
280                let (inn_port, ident, out_port) = PortIndexValue::from_ported(pipeline_name);
281
282                // Mingwei: We could lookup non-forward references immediately, but easier to just
283                // have one consistent code path: `GraphDet::Undetermined`.
284                Ends {
285                    inn: Some((inn_port, GraphDet::Undetermined(ident.clone()))),
286                    out: Some((out_port, GraphDet::Undetermined(ident))),
287                }
288            }
289            Pipeline::ModuleBoundary(pipeline_name) => {
290                let Some((input_node, output_node)) = self.module_boundary_nodes else {
291                    self.diagnostics.push(
292                        Error::new(
293                            pipeline_name.span(),
294                            "`mod` is only usable inside of a module.",
295                        )
296                        .into(),
297                    );
298
299                    return Ends {
300                        inn: None,
301                        out: None,
302                    };
303                };
304
305                let (inn_port, _, out_port) = PortIndexValue::from_ported(pipeline_name);
306
307                Ends {
308                    inn: Some((inn_port, GraphDet::Determined(output_node))),
309                    out: Some((out_port, GraphDet::Determined(input_node))),
310                }
311            }
312            Pipeline::Link(pipeline_link) => {
313                // Add the nested LHS and RHS of this link.
314                let lhs_ends = self.add_pipeline(
315                    *pipeline_link.lhs,
316                    current_varname,
317                    current_loop,
318                    operator_tag,
319                );
320                let rhs_ends = self.add_pipeline(
321                    *pipeline_link.rhs,
322                    current_varname,
323                    current_loop,
324                    operator_tag,
325                );
326
327                self.connect_ends(lhs_ends, rhs_ends)
328            }
329            Pipeline::Operator(operator) => {
330                let op_span = Some(operator.span());
331                let (node_id, ends) =
332                    self.add_operator(current_varname, current_loop, operator, op_span);
333                if let Some(operator_tag) = operator_tag {
334                    self.flat_graph
335                        .set_operator_tag(node_id, operator_tag.to_owned());
336                }
337                ends
338            }
339        }
340    }
341
342    /// Connects two [`Ends`] together. Returns the outer [`Ends`] for the connection.
343    ///
344    /// Links the inner ends together by adding it to `self.links`.
345    fn connect_ends(&mut self, lhs_ends: Ends, rhs_ends: Ends) -> Ends {
346        // Outer (first and last) ends.
347        let outer_ends = Ends {
348            inn: lhs_ends.inn,
349            out: rhs_ends.out,
350        };
351        // Inner (link) ends.
352        let link_ends = Ends {
353            out: lhs_ends.out,
354            inn: rhs_ends.inn,
355        };
356        self.links.push(link_ends);
357        outer_ends
358    }
359
360    /// Adds an operator to the graph, returning its [`GraphNodeId`] the input and output [`Ends`] for it.
361    fn add_operator(
362        &mut self,
363        current_varname: Option<&Ident>,
364        current_loop: Option<GraphLoopId>,
365        operator: Operator,
366        op_span: Option<Span>,
367    ) -> (GraphNodeId, Ends) {
368        let node_id = self.flat_graph.insert_node(
369            GraphNode::Operator(operator),
370            current_varname.cloned(),
371            current_loop,
372        );
373        let ends = Ends {
374            inn: Some((
375                PortIndexValue::Elided(op_span),
376                GraphDet::Determined(node_id),
377            )),
378            out: Some((
379                PortIndexValue::Elided(op_span),
380                GraphDet::Determined(node_id),
381            )),
382        };
383        (node_id, ends)
384    }
385
386    /// Connects operator links as a final building step. Processes all the links stored in
387    /// `self.links` and actually puts them into the graph.
388    fn finalize_connect_operator_links(&mut self) {
389        // `->` edges
390        for Ends { out, inn } in std::mem::take(&mut self.links) {
391            let out_opt = self.helper_resolve_name(out, false);
392            let inn_opt = self.helper_resolve_name(inn, true);
393            // `None` already have errors in `self.diagnostics`.
394            if let (Some((out_port, out_node)), Some((inn_port, inn_node))) = (out_opt, inn_opt) {
395                let _ = self.finalize_connect_operators(out_port, out_node, inn_port, inn_node);
396            }
397        }
398
399        // Resolve the singleton references for each node.
400        for node_id in self.flat_graph.node_ids().collect::<Vec<_>>() {
401            if let GraphNode::Operator(operator) = self.flat_graph.node(node_id) {
402                let singletons_referenced = operator
403                    .singletons_referenced
404                    .clone()
405                    .into_iter()
406                    .map(|singleton_ref| {
407                        let port_det = self
408                            .varname_ends
409                            .get(&singleton_ref)
410                            .filter(|varname_info| !varname_info.illegal_cycle)
411                            .map(|varname_info| &varname_info.ends)
412                            .and_then(|ends| ends.out.as_ref())
413                            .cloned();
414                        if let Some((_port, node_id)) = self.helper_resolve_name(port_det, false) {
415                            Some(node_id)
416                        } else {
417                            self.diagnostics.push(Diagnostic::spanned(
418                                singleton_ref.span(),
419                                Level::Error,
420                                format!(
421                                    "Cannot find referenced name `{}`; name was never assigned.",
422                                    singleton_ref
423                                ),
424                            ));
425                            None
426                        }
427                    })
428                    .collect();
429
430                self.flat_graph
431                    .set_node_singleton_references(node_id, singletons_referenced);
432            }
433        }
434    }
435
    /// Resolves a possibly-named end to a concrete `(port, node)`, following variable names
    /// iteratively. This handles forward (and backward) name references once all names have been
    /// assigned.
    ///
    /// Each step looks up the referenced name and marks its `inn` end (if `is_in`) or `out` end as
    /// used. It then merges that end with the current port via `helper_combine_end`.
    ///
    /// Returns `None` if the name is not resolvable, because it was never assigned, or because it
    /// forms or reaches a self-referential cycle. In those cases an error is pushed to
    /// `self.diagnostics`. `None` is also returned (with no diagnostic from this function) when
    /// `port_det` is `None`, or when `helper_combine_end` yields `None`. Resolution gives up with
    /// an error after `BACKUP_RECURSION_LIMIT` steps.
    ///
    /// `is_in` set to `true` means the _input_ side will be returned. `false` means the _output_ side will be returned.
    fn helper_resolve_name(
        &mut self,
        mut port_det: Option<(PortIndexValue, GraphDet)>,
        is_in: bool,
    ) -> Option<(PortIndexValue, GraphNodeId)> {
        const BACKUP_RECURSION_LIMIT: usize = 1024;

        // Chain of names followed so far, used for cycle detection and error reporting.
        let mut names = Vec::new();
        for _ in 0..BACKUP_RECURSION_LIMIT {
            match port_det? {
                (port, GraphDet::Determined(node_id)) => {
                    return Some((port, node_id));
                }
                (port, GraphDet::Undetermined(ident)) => {
                    let Some(varname_info) = self.varname_ends.get_mut(&ident) else {
                        self.diagnostics.push(Diagnostic::spanned(
                            ident.span(),
                            Level::Error,
                            format!("Cannot find name `{}`; name was never assigned.", ident),
                        ));
                        return None;
                    };
                    // Check for a self-referential cycle.
                    let cycle_found = names.contains(&ident);
                    if !cycle_found {
                        names.push(ident);
                    };
                    // Also bail out if this name was flagged by an earlier resolution.
                    if cycle_found || varname_info.illegal_cycle {
                        let len = names.len();
                        for (i, name) in names.into_iter().enumerate() {
                            self.diagnostics.push(Diagnostic::spanned(
                                name.span(),
                                Level::Error,
                                format!(
                                    "Name `{}` forms or references an illegal self-referential cycle ({}/{}).",
                                    name,
                                    i + 1,
                                    len
                                ),
                            ));
                            // Flag the name so that later resolutions through it fail at once.
                            // This applies both here and in the singleton-reference lookup in
                            // `finalize_connect_operator_links`, which filters flagged names out.
                            // The `unwrap` cannot panic: every entry of `names` was pushed only
                            // after a successful lookup.
                            self.varname_ends.get_mut(&name).unwrap().illegal_cycle = true;
                        }
                        return None;
                    }

                    // No self-cycle: step to the referenced side of this name.
                    let prev = if is_in {
                        varname_info.inn_used = true;
                        &varname_info.ends.inn
                    } else {
                        varname_info.out_used = true;
                        &varname_info.ends.out
                    };
                    port_det = Self::helper_combine_end(
                        &mut self.diagnostics,
                        prev.clone(),
                        port,
                        if is_in { "input" } else { "output" },
                    );
                }
            }
        }
        self.diagnostics.push(Diagnostic::spanned(
            Span::call_site(),
            Level::Error,
            format!(
                "Reached the recursion limit {} while resolving names. This is either a dfir bug or you have an absurdly long chain of names: `{}`.",
                BACKUP_RECURSION_LIMIT,
                names.iter().map(ToString::to_string).collect::<Vec<_>>().join("` -> `"),
            )
        ));
        None
    }
517
518    /// Connect two operators on the given port indexes.
519    fn finalize_connect_operators(
520        &mut self,
521        src_port: PortIndexValue,
522        src: GraphNodeId,
523        dst_port: PortIndexValue,
524        dst: GraphNodeId,
525    ) -> GraphEdgeId {
526        {
527            /// Helper to emit conflicts when a port is used twice.
528            fn emit_conflict(
529                inout: &str,
530                old: &PortIndexValue,
531                new: &PortIndexValue,
532                diagnostics: &mut Diagnostics,
533            ) {
534                // TODO(mingwei): Use `MultiSpan` once `proc_macro2` supports it.
535                diagnostics.push(Diagnostic::spanned(
536                    old.span(),
537                    Level::Error,
538                    format!(
539                        "{} connection conflicts with below ({}) (1/2)",
540                        inout,
541                        PrettySpan(new.span()),
542                    ),
543                ));
544                diagnostics.push(Diagnostic::spanned(
545                    new.span(),
546                    Level::Error,
547                    format!(
548                        "{} connection conflicts with above ({}) (2/2)",
549                        inout,
550                        PrettySpan(old.span()),
551                    ),
552                ));
553            }
554
555            // Handle src's successor port conflicts:
556            if src_port.is_specified() {
557                for conflicting_port in self
558                    .flat_graph
559                    .node_successor_edges(src)
560                    .map(|edge_id| self.flat_graph.edge_ports(edge_id).0)
561                    .filter(|&port| port == &src_port)
562                {
563                    emit_conflict("Output", conflicting_port, &src_port, &mut self.diagnostics);
564                }
565            }
566
567            // Handle dst's predecessor port conflicts:
568            if dst_port.is_specified() {
569                for conflicting_port in self
570                    .flat_graph
571                    .node_predecessor_edges(dst)
572                    .map(|edge_id| self.flat_graph.edge_ports(edge_id).1)
573                    .filter(|&port| port == &dst_port)
574                {
575                    emit_conflict("Input", conflicting_port, &dst_port, &mut self.diagnostics);
576                }
577            }
578        }
579        self.flat_graph.insert_edge(src, src_port, dst, dst_port)
580    }
581
    /// Process operators and emit operator errors.
    ///
    /// Runs these passes in order: `make_operator_instances`, `check_operator_errors` (which reads
    /// the operator instances created by the first pass), `warn_unused_port_indexing`, and
    /// `check_loop_errors`.
    // NOTE(review): `warn_unused_port_indexing` and `check_loop_errors` presumably report into
    // `self.diagnostics` too; confirm.
    fn process_operator_errors(&mut self) {
        self.make_operator_instances();
        self.check_operator_errors();
        self.warn_unused_port_indexing();
        self.check_loop_errors();
    }
589
590    /// Make `OperatorInstance`s for each operator node.
591    fn make_operator_instances(&mut self) {
592        self.flat_graph
593            .insert_node_op_insts_all(&mut self.diagnostics);
594    }
595
596    /// Validates that operators have valid number of inputs, outputs, & arguments.
597    /// Adds errors (and warnings) to `self.diagnostics`.
598    fn check_operator_errors(&mut self) {
599        for (node_id, node) in self.flat_graph.nodes() {
600            match node {
601                GraphNode::Operator(operator) => {
602                    let Some(op_inst) = self.flat_graph.node_op_inst(node_id) else {
603                        // Error already emitted by `insert_node_op_insts_all`.
604                        continue;
605                    };
606                    let op_constraints = op_inst.op_constraints;
607                    let op_name = operator.name_string();
608
609                    // Check number of args
610                    if op_constraints.num_args != operator.args.len() {
611                        self.diagnostics.push(Diagnostic::spanned(
612                            operator.span(),
613                            Level::Error,
614                            format!(
615                                "`{}` expects {} argument(s), received {}.",
616                                op_name,
617                                op_constraints.num_args,
618                                operator.args.len()
619                            ),
620                        ));
621                    }
622
623                    // Check input/output (port) arity
624                    /// Returns true if an error was found.
625                    fn emit_arity_error(
626                        op_span: Span,
627                        op_name: &str,
628                        is_in: bool,
629                        is_hard: bool,
630                        degree: usize,
631                        range: &dyn RangeTrait<usize>,
632                        diagnostics: &mut Diagnostics,
633                    ) -> bool {
634                        let message = format!(
635                            "`{}` {} have {} {}, actually has {}.",
636                            op_name,
637                            if is_hard { "must" } else { "should" },
638                            range.human_string(),
639                            if is_in { "input(s)" } else { "output(s)" },
640                            degree,
641                        );
642                        let out_of_range = !range.contains(&degree);
643                        if out_of_range {
644                            diagnostics.push(Diagnostic::spanned(
645                                op_span,
646                                if is_hard {
647                                    Level::Error
648                                } else {
649                                    Level::Warning
650                                },
651                                message,
652                            ));
653                        }
654                        out_of_range
655                    }
656
657                    let inn_degree = self.flat_graph.node_degree_in(node_id);
658                    let _ = emit_arity_error(
659                        operator.span(),
660                        &op_name,
661                        true,
662                        true,
663                        inn_degree,
664                        op_constraints.hard_range_inn,
665                        &mut self.diagnostics,
666                    ) || emit_arity_error(
667                        operator.span(),
668                        &op_name,
669                        true,
670                        false,
671                        inn_degree,
672                        op_constraints.soft_range_inn,
673                        &mut self.diagnostics,
674                    );
675
676                    let out_degree = self.flat_graph.node_degree_out(node_id);
677                    let _ = emit_arity_error(
678                        operator.span(),
679                        &op_name,
680                        false,
681                        true,
682                        out_degree,
683                        op_constraints.hard_range_out,
684                        &mut self.diagnostics,
685                    ) || emit_arity_error(
686                        operator.span(),
687                        &op_name,
688                        false,
689                        false,
690                        out_degree,
691                        op_constraints.soft_range_out,
692                        &mut self.diagnostics,
693                    );
694
695                    fn emit_port_error<'a>(
696                        op_span: Span,
697                        op_name: &str,
698                        expected_ports_fn: Option<fn() -> PortListSpec>,
699                        actual_ports_iter: impl Iterator<Item = &'a PortIndexValue>,
700                        input_output: &'static str,
701                        diagnostics: &mut Diagnostics,
702                    ) {
703                        let Some(expected_ports_fn) = expected_ports_fn else {
704                            return;
705                        };
706                        let PortListSpec::Fixed(expected_ports) = (expected_ports_fn)() else {
707                            // Separate check inside of `demux` special case.
708                            return;
709                        };
710                        let expected_ports: Vec<_> = expected_ports.into_iter().collect();
711
712                        // Reject unexpected ports.
713                        let ports: BTreeSet<_> = actual_ports_iter
714                            // Use `inspect` before collecting into `BTreeSet` to ensure we get
715                            // both error messages on duplicated port names.
716                            .inspect(|actual_port_iv| {
717                                // For each actually used port `port_index_value`, check if it is expected.
718                                let is_expected = expected_ports.iter().any(|port_index| {
719                                    actual_port_iv == &&port_index.clone().into()
720                                });
721                                // If it is not expected, emit a diagnostic error.
722                                if !is_expected {
723                                    diagnostics.push(Diagnostic::spanned(
724                                        actual_port_iv.span(),
725                                        Level::Error,
726                                        format!(
727                                            "`{}` received unexpected {} port: {}. Expected one of: `{}`",
728                                            op_name,
729                                            input_output,
730                                            actual_port_iv.as_error_message_string(),
731                                            Itertools::intersperse(
732                                                expected_ports
733                                                    .iter()
734                                                    .map(|port| port.to_token_stream().to_string())
735                                                    .map(Cow::Owned),
736                                                Cow::Borrowed("`, `"),
737                                            ).collect::<String>()
738                                        ),
739                                    ))
740                                }
741                            })
742                            .collect();
743
744                        // List missing expected ports.
745                        let missing: Vec<_> = expected_ports
746                            .into_iter()
747                            .filter_map(|expected_port| {
748                                let tokens = expected_port.to_token_stream();
749                                if !ports.contains(&&expected_port.into()) {
750                                    Some(tokens)
751                                } else {
752                                    None
753                                }
754                            })
755                            .collect();
756                        if !missing.is_empty() {
757                            diagnostics.push(Diagnostic::spanned(
758                                op_span,
759                                Level::Error,
760                                format!(
761                                    "`{}` missing expected {} port(s): `{}`.",
762                                    op_name,
763                                    input_output,
764                                    Itertools::intersperse(
765                                        missing.into_iter().map(|port| Cow::Owned(
766                                            port.to_token_stream().to_string()
767                                        )),
768                                        Cow::Borrowed("`, `")
769                                    )
770                                    .collect::<String>()
771                                ),
772                            ));
773                        }
774                    }
775
776                    emit_port_error(
777                        operator.span(),
778                        &op_name,
779                        op_constraints.ports_inn,
780                        self.flat_graph
781                            .node_predecessor_edges(node_id)
782                            .map(|edge_id| self.flat_graph.edge_ports(edge_id).1),
783                        "input",
784                        &mut self.diagnostics,
785                    );
786                    emit_port_error(
787                        operator.span(),
788                        &op_name,
789                        op_constraints.ports_out,
790                        self.flat_graph
791                            .node_successor_edges(node_id)
792                            .map(|edge_id| self.flat_graph.edge_ports(edge_id).0),
793                        "output",
794                        &mut self.diagnostics,
795                    );
796
797                    // Check that singleton references actually reference *stateful* operators.
798                    {
799                        let singletons_resolved =
800                            self.flat_graph.node_singleton_references(node_id);
801                        for (singleton_node_id, singleton_ident) in singletons_resolved
802                            .iter()
803                            .zip_eq(&*operator.singletons_referenced)
804                        {
805                            let &Some(singleton_node_id) = singleton_node_id else {
806                                // Error already emitted by `connect_operator_links`, "Cannot find referenced name...".
807                                continue;
808                            };
809                            let Some(ref_op_inst) = self.flat_graph.node_op_inst(singleton_node_id)
810                            else {
811                                // Error already emitted by `insert_node_op_insts_all`.
812                                continue;
813                            };
814                            let ref_op_constraints = ref_op_inst.op_constraints;
815                            if !ref_op_constraints.has_singleton_output {
816                                self.diagnostics.push(Diagnostic::spanned(
817                                    singleton_ident.span(),
818                                    Level::Error,
819                                    format!(
820                                        "Cannot reference operator `{}`. Only operators with singleton state can be referenced.",
821                                        ref_op_constraints.name,
822                                    ),
823                                ));
824                            }
825                        }
826                    }
827                }
828                GraphNode::Handoff { .. } => todo!("Node::Handoff"),
829                GraphNode::ModuleBoundary { .. } => {
830                    // Module boundaries don't require any checking.
831                }
832            }
833        }
834    }
835
836    /// Warns about unused port indexing referenced in [`Self::varname_ends`].
837    /// https://github.com/hydro-project/hydro/issues/1108
838    fn warn_unused_port_indexing(&mut self) {
839        for (_ident, varname_info) in self.varname_ends.iter() {
840            if !varname_info.inn_used {
841                Self::helper_check_unused_port(&mut self.diagnostics, &varname_info.ends, true);
842            }
843            if !varname_info.out_used {
844                Self::helper_check_unused_port(&mut self.diagnostics, &varname_info.ends, false);
845            }
846        }
847    }
848
849    /// Emit a warning to `diagnostics` for an unused port (i.e. if the port is specified for
850    /// reason).
851    fn helper_check_unused_port(diagnostics: &mut Diagnostics, ends: &Ends, is_in: bool) {
852        let port = if is_in { &ends.inn } else { &ends.out };
853        if let Some((port, _)) = port
854            && port.is_specified()
855        {
856            diagnostics.push(Diagnostic::spanned(
857                port.span(),
858                Level::Error,
859                format!(
860                    "{} port index is unused. (Is the port on the correct side?)",
861                    if is_in { "Input" } else { "Output" },
862                ),
863            ));
864        }
865    }
866
867    /// Helper function.
868    /// Combine the port indexing information for indexing wrapped around a name.
869    /// Because the name may already have indexing, this may introduce double indexing (i.e. `[0][0]my_var[0][0]`)
870    /// which would be an error.
871    fn helper_combine_ends(
872        diagnostics: &mut Diagnostics,
873        og_ends: Ends,
874        inn_port: PortIndexValue,
875        out_port: PortIndexValue,
876    ) -> Ends {
877        Ends {
878            inn: Self::helper_combine_end(diagnostics, og_ends.inn, inn_port, "input"),
879            out: Self::helper_combine_end(diagnostics, og_ends.out, out_port, "output"),
880        }
881    }
882
883    /// Helper function.
884    /// Combine the port indexing info for one input or output.
885    fn helper_combine_end(
886        diagnostics: &mut Diagnostics,
887        og: Option<(PortIndexValue, GraphDet)>,
888        other: PortIndexValue,
889        input_output: &'static str,
890    ) -> Option<(PortIndexValue, GraphDet)> {
891        // TODO(mingwei): minification pass over this code?
892
893        let other_span = other.span();
894
895        let (og_port, og_node) = og?;
896        match og_port.combine(other) {
897            Ok(combined_port) => Some((combined_port, og_node)),
898            Err(og_port) => {
899                // TODO(mingwei): Use `MultiSpan` once `proc_macro2` supports it.
900                diagnostics.push(Diagnostic::spanned(
901                    og_port.span(),
902                    Level::Error,
903                    format!(
904                        "Indexing on {} is overwritten below ({}) (1/2).",
905                        input_output,
906                        PrettySpan(other_span),
907                    ),
908                ));
909                diagnostics.push(Diagnostic::spanned(
910                    other_span,
911                    Level::Error,
912                    format!(
913                        "Cannot index on already-indexed {}, previously indexed above ({}) (2/2).",
914                        input_output,
915                        PrettySpan(og_port.span()),
916                    ),
917                ));
918                // When errored, just use original and ignore OTHER port to minimize
919                // noisy/extra diagnostics.
920                Some((og_port, og_node))
921            }
922        }
923    }
924
925    /// Check for loop context-related errors.
926    fn check_loop_errors(&mut self) {
927        for (node_id, node) in self.flat_graph.nodes() {
928            let Some(op_inst) = self.flat_graph.node_op_inst(node_id) else {
929                continue;
930            };
931            let loop_opt = self.flat_graph.node_loop(node_id);
932
933            // Ensure no `'tick` or `'static` persistences are used WITHIN a loop context.
934            // Ensure no `'loop` persistences are used OUTSIDE a loop context.
935            for persistence in &op_inst.generics.persistence_args {
936                let span = op_inst.generics.generic_args.span();
937                match (loop_opt, persistence) {
938                    (Some(_loop_id), p @ (Persistence::Tick | Persistence::Static)) => {
939                        self.diagnostics.push(Diagnostic::spanned(
940                            span,
941                            Level::Error,
942                            format!(
943                                "Operator uses `'{}` persistence, which is not allowed within a `loop {{ ... }}` context.",
944                                p.to_str_lowercase(),
945                            ),
946                        ));
947                    }
948                    (None, p @ (Persistence::None | Persistence::Loop)) => {
949                        self.diagnostics.push(Diagnostic::spanned(
950                            span,
951                            Level::Error,
952                            format!(
953                                "Operator uses `'{}` persistence, but is not within a `loop {{ ... }}` context.",
954                                p.to_str_lowercase(),
955                            ),
956                        ));
957                    }
958                    _ => {}
959                }
960            }
961
962            // All inputs must be declared in the root block.
963            if let (Some(_loop_id), Some(FloType::Source)) =
964                (loop_opt, op_inst.op_constraints.flo_type)
965            {
966                self.diagnostics.push(Diagnostic::spanned(
967                    node.span(),
968                    Level::Error,
969                    format!(
970                        "Source operator `{}(...)` must be at the root level, not within any `loop {{ ... }}` contexts.",
971                        op_inst.op_constraints.name
972                    )
973                ));
974            }
975        }
976
977        // Check windowing and un-windowing operators, for loop inputs and outputs respectively.
978        for (_edge_id, (pred_id, node_id)) in self.flat_graph.edges() {
979            let Some(op_inst) = self.flat_graph.node_op_inst(node_id) else {
980                continue;
981            };
982            let flo_type = &op_inst.op_constraints.flo_type;
983
984            let pred_loop_id = self.flat_graph.node_loop(pred_id);
985            let loop_id = self.flat_graph.node_loop(node_id);
986
987            let span = self.flat_graph.node(node_id).span();
988
989            let (is_input, is_output) = {
990                let parent_pred_loop_id =
991                    pred_loop_id.and_then(|lid| self.flat_graph.loop_parent(lid));
992                let parent_loop_id = loop_id.and_then(|lid| self.flat_graph.loop_parent(lid));
993                let is_same = pred_loop_id == loop_id;
994                let is_input = !is_same && parent_loop_id == pred_loop_id;
995                let is_output = !is_same && parent_pred_loop_id == loop_id;
996                if !(is_input || is_output || is_same) {
997                    self.diagnostics.push(Diagnostic::spanned(
998                        span,
999                        Level::Error,
1000                        "Operator input edge may not cross multiple loop contexts.",
1001                    ));
1002                    continue;
1003                }
1004                (is_input, is_output)
1005            };
1006
1007            match flo_type {
1008                None => {
1009                    if is_input {
1010                        self.diagnostics.push(Diagnostic::spanned(
1011                            span,
1012                            Level::Error,
1013                            format!(
1014                                "Operator `{}(...)` entering a loop context must be a windowing operator, but is not.",
1015                                op_inst.op_constraints.name
1016                            )
1017                        ));
1018                    }
1019                    if is_output {
1020                        self.diagnostics.push(Diagnostic::spanned(
1021                            span,
1022                            Level::Error,
1023                            format!(
1024                                "Operator `{}(...)` exiting a loop context must be an un-windowing operator, but is not.",
1025                                op_inst.op_constraints.name
1026                            )
1027                        ));
1028                    }
1029                }
1030                Some(FloType::Windowing) => {
1031                    if !is_input {
1032                        self.diagnostics.push(Diagnostic::spanned(
1033                            span,
1034                            Level::Error,
1035                            format!(
1036                                "Windowing operator `{}(...)` must be the first input operator into a `loop {{ ... }} context.",
1037                                op_inst.op_constraints.name
1038                            )
1039                        ));
1040                    }
1041                }
1042                Some(FloType::Unwindowing) => {
1043                    if !is_output {
1044                        self.diagnostics.push(Diagnostic::spanned(
1045                            span,
1046                            Level::Error,
1047                            format!(
1048                                "Un-windowing operator `{}(...)` must be the first output operator after exiting a `loop {{ ... }} context.",
1049                                op_inst.op_constraints.name
1050                            )
1051                        ));
1052                    }
1053                }
1054                Some(FloType::NextIteration) => {
1055                    // Must be in a loop context.
1056                    if loop_id.is_none() {
1057                        self.diagnostics.push(Diagnostic::spanned(
1058                            span,
1059                            Level::Error,
1060                            format!(
1061                                "Operator `{}(...)` must be within a `loop {{ ... }}` context.",
1062                                op_inst.op_constraints.name
1063                            ),
1064                        ));
1065                    }
1066                }
1067                Some(FloType::Source) => {
1068                    // Handled above.
1069                }
1070            }
1071        }
1072
1073        // Must be a DAG (excluding `next_iteration()` operators).
1074        // TODO(mingwei): Nested loop blocks should count as a single node.
1075        // But this doesn't cause any correctness issues because the nested loops are also DAGs.
1076        for (loop_id, loop_nodes) in self.flat_graph.loops() {
1077            // Filter out `next_iteration()` operators.
1078            let filter_next_iteration = |&node_id: &GraphNodeId| {
1079                self.flat_graph
1080                    .node_op_inst(node_id)
1081                    .map(|op_inst| Some(FloType::NextIteration) != op_inst.op_constraints.flo_type)
1082                    .unwrap_or(true)
1083            };
1084
1085            let topo_sort_result = graph_algorithms::topo_sort(
1086                loop_nodes.iter().copied().filter(filter_next_iteration),
1087                |dst| {
1088                    self.flat_graph
1089                        .node_predecessor_nodes(dst)
1090                        .filter(|&src| Some(loop_id) == self.flat_graph.node_loop(src))
1091                        .filter(filter_next_iteration)
1092                },
1093            );
1094            if let Err(cycle) = topo_sort_result {
1095                let len = cycle.len();
1096                for (i, node_id) in cycle.into_iter().enumerate() {
1097                    let span = self.flat_graph.node(node_id).span();
1098                    self.diagnostics.push(Diagnostic::spanned(
1099                        span,
1100                        Level::Error,
1101                        format!(
1102                            "Operator forms an illegal cycle within a `loop {{ ... }}` block. Use `{}()` to pass data across loop iterations. ({}/{})",
1103                            NEXT_ITERATION.name,
1104                            i + 1,
1105                            len,
1106                        ),
1107                    ));
1108                }
1109            }
1110        }
1111    }
1112}