dfir_lang/graph/
mod.rs

1//! Graph representation stages for DFIR graphs.
2
3use std::borrow::Cow;
4use std::hash::Hash;
5
6use proc_macro2::{Ident, Span, TokenStream};
7use quote::ToTokens;
8use serde::{Deserialize, Serialize};
9use slotmap::new_key_type;
10use syn::punctuated::Punctuated;
11use syn::spanned::Spanned;
12use syn::{Expr, ExprPath, GenericArgument, Token, Type};
13
14use self::ops::{OperatorConstraints, Persistence};
15use crate::diagnostic::{Diagnostic, Level};
16use crate::parse::{DfirCode, IndexInt, Operator, PortIndex, Ported};
17use crate::pretty_span::PrettySpan;
18
19mod di_mul_graph;
20mod eliminate_extra_unions_tees;
21mod flat_graph_builder;
22mod flat_to_partitioned;
23mod graph_write;
24mod meta_graph;
25mod meta_graph_debugging;
26
27use std::fmt::Display;
28
29pub use di_mul_graph::DiMulGraph;
30pub use eliminate_extra_unions_tees::eliminate_extra_unions_tees;
31pub use flat_graph_builder::FlatGraphBuilder;
32pub use flat_to_partitioned::partition_graph;
33pub use meta_graph::{DfirGraph, WriteConfig, WriteGraphType};
34
35pub mod graph_algorithms;
36pub mod ops;
37
// Distinct `slotmap` key types, one per kind of graph component, so that (for example) a node ID
// cannot be accidentally used where an edge ID is expected.
new_key_type! {
    /// ID to identify a node (operator or handoff) in [`DfirGraph`].
    pub struct GraphNodeId;

    /// ID to identify an edge.
    pub struct GraphEdgeId;

    /// ID to identify a subgraph in [`DfirGraph`].
    pub struct GraphSubgraphId;

    /// ID to identify a loop block in [`DfirGraph`].
    pub struct GraphLoopId;
}
51
/// Context identifier as a string.
const CONTEXT: &str = "context";
/// Runnable DFIR graph object identifier as a string.
const GRAPH: &str = "df";

/// Name returned for [`GraphNode::Handoff`] nodes by [`GraphNode::to_pretty_string`] and
/// [`GraphNode::to_name_string`].
const HANDOFF_NODE_STR: &str = "handoff";
/// Name returned for [`GraphNode::ModuleBoundary`] nodes by [`GraphNode::to_pretty_string`] and
/// [`GraphNode::to_name_string`].
const MODULE_BOUNDARY_NODE_STR: &str = "module_boundary";
59
60mod serde_syn {
61    use serde::{Deserialize, Deserializer, Serializer};
62
63    pub fn serialize<S, T>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
64    where
65        S: Serializer,
66        T: quote::ToTokens,
67    {
68        serializer.serialize_str(&value.to_token_stream().to_string())
69    }
70
71    pub fn deserialize<'de, D, T>(deserializer: D) -> Result<T, D::Error>
72    where
73        D: Deserializer<'de>,
74        T: syn::parse::Parse,
75    {
76        let s = String::deserialize(deserializer)?;
77        syn::parse_str(&s).map_err(<D::Error as serde::de::Error>::custom)
78    }
79}
80
/// Newtype wrapper around an [`Ident`] which is (de)serialized as a string via `serde_syn`.
#[derive(Clone, Debug, Serialize, Deserialize, PartialOrd, Ord, PartialEq, Eq)]
struct Varname(#[serde(with = "serde_syn")] pub Ident);
83
/// A node, corresponding to an operator or a handoff.
///
/// All [`Span`] fields are marked `#[serde(skip)]`: they are not serialized, and are filled in
/// with [`Span::call_site`] when deserializing.
#[derive(Clone, Serialize, Deserialize)]
pub enum GraphNode {
    /// An operator. (De)serialized as its token-stream string via `serde_syn`.
    Operator(#[serde(with = "serde_syn")] Operator),
    /// A handoff point, used between subgraphs (or within a subgraph to break a cycle).
    Handoff {
        /// The span of the input into the handoff.
        #[serde(skip, default = "Span::call_site")]
        src_span: Span,
        /// The span of the output out of the handoff.
        #[serde(skip, default = "Span::call_site")]
        dst_span: Span,
    },

    /// Module Boundary, used for importing modules. Only exists prior to partitioning.
    ModuleBoundary {
        /// If this module is an input or output boundary.
        input: bool,

        /// The span of the import!() expression that imported this module.
        /// The value of this span when the ModuleBoundary node is still inside the module is Span::call_site()
        /// TODO: This could one day reference into the module file itself?
        #[serde(skip, default = "Span::call_site")]
        import_expr: Span,
    },
}
111impl GraphNode {
112    /// Return the node as a human-readable string.
113    pub fn to_pretty_string(&self) -> Cow<'static, str> {
114        match self {
115            GraphNode::Operator(op) => op.to_pretty_string().into(),
116            GraphNode::Handoff { .. } => HANDOFF_NODE_STR.into(),
117            GraphNode::ModuleBoundary { .. } => MODULE_BOUNDARY_NODE_STR.into(),
118        }
119    }
120
121    /// Return the name of the node as a string, excluding parenthesis and op source code.
122    pub fn to_name_string(&self) -> Cow<'static, str> {
123        match self {
124            GraphNode::Operator(op) => op.name_string().into(),
125            GraphNode::Handoff { .. } => HANDOFF_NODE_STR.into(),
126            GraphNode::ModuleBoundary { .. } => MODULE_BOUNDARY_NODE_STR.into(),
127        }
128    }
129
130    /// Return the source code span of the node (for operators) or input/otput spans for handoffs.
131    pub fn span(&self) -> Span {
132        match self {
133            Self::Operator(op) => op.span(),
134            &Self::Handoff {
135                src_span, dst_span, ..
136            } => src_span.join(dst_span).unwrap_or(src_span),
137            Self::ModuleBoundary { import_expr, .. } => *import_expr,
138        }
139    }
140}
141impl std::fmt::Debug for GraphNode {
142    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
143        match self {
144            Self::Operator(operator) => {
145                write!(f, "Node::Operator({} span)", PrettySpan(operator.span()))
146            }
147            Self::Handoff { .. } => write!(f, "Node::Handoff"),
148            Self::ModuleBoundary { input, .. } => {
149                write!(f, "Node::ModuleBoundary{{input: {}}}", input)
150            }
151        }
152    }
153}
154
/// Meta-data relating to operators which may be useful throughout the compilation process.
///
/// This data can be generated from the graph, but it is useful to have it readily available
/// pre-computed as many algorithms use the same info. Stuff like port names, arguments, and the
/// [`OperatorConstraints`] for the operator.
///
/// Because it is derived from the graph itself, there can be "cache invalidation"-esque issues
/// if this data is not kept in sync with the graph.
#[derive(Clone, Debug)]
pub struct OperatorInstance {
    /// The static constraints for the operator. The operator's name is available as
    /// [`OperatorConstraints::name`].
    pub op_constraints: &'static OperatorConstraints,
    /// Port values used as this operator's input.
    pub input_ports: Vec<PortIndexValue>,
    /// Port values used as this operator's output.
    pub output_ports: Vec<PortIndexValue>,
    /// Singleton references within the operator arguments.
    pub singletons_referenced: Vec<Ident>,

    /// Generic arguments.
    pub generics: OpInstGenerics,
    /// Arguments provided by the user into the operator as arguments.
    /// I.e. the `a, b, c` in `-> my_op(a, b, c) -> `.
    ///
    /// These arguments do not include singleton postprocessing codegen. Instead use
    /// [`ops::WriteContextArgs::arguments`].
    pub arguments_pre: Punctuated<Expr, Token![,]>,
    /// Unparsed arguments, for singleton parsing.
    pub arguments_raw: TokenStream,
}
185
/// Operator generic arguments, split into specific categories.
///
/// See [`get_operator_generics`] for how the categories are extracted.
#[derive(Clone, Debug)]
pub struct OpInstGenerics {
    /// Operator generic (type or lifetime) arguments.
    pub generic_args: Option<Punctuated<GenericArgument, Token![,]>>,
    /// Lifetime persistence arguments. Corresponds to a prefix of [`Self::generic_args`].
    pub persistence_args: Vec<Persistence>,
    /// Type arguments. Corresponds to the run of type arguments in [`Self::generic_args`] which
    /// directly follows the persistence arguments.
    pub type_args: Vec<Type>,
}
196
197/// Gets the generic arguments for the operator.
198///
199/// This helper method is useful due to the special handling of persistence lifetimes (`'static`,
200/// `'tick`, `'mutable`) which must come before other generic parameters.
201pub fn get_operator_generics(
202    diagnostics: &mut Vec<Diagnostic>,
203    operator: &Operator,
204) -> OpInstGenerics {
205    // Generic arguments.
206    let generic_args = operator.type_arguments().cloned();
207    let persistence_args = generic_args.iter().flatten().map_while(|generic_arg| match generic_arg {
208            GenericArgument::Lifetime(lifetime) => {
209                match &*lifetime.ident.to_string() {
210                    "none" => Some(Persistence::None),
211                    "tick" => Some(Persistence::Tick),
212                    "static" => Some(Persistence::Static),
213                    "mutable" => Some(Persistence::Mutable),
214                    _ => {
215                        diagnostics.push(Diagnostic::spanned(
216                            generic_arg.span(),
217                            Level::Error,
218                            format!("Unknown lifetime generic argument `'{}`, expected `'tick`, `'static`, or `'mutable`.", lifetime.ident),
219                        ));
220                        // TODO(mingwei): should really keep going and not short circuit?
221                        None
222                    }
223                }
224            },
225            _ => None,
226        }).collect::<Vec<_>>();
227    let type_args = generic_args
228        .iter()
229        .flatten()
230        .skip(persistence_args.len())
231        .map_while(|generic_arg| match generic_arg {
232            GenericArgument::Type(typ) => Some(typ),
233            _ => None,
234        })
235        .cloned()
236        .collect::<Vec<_>>();
237
238    OpInstGenerics {
239        generic_args,
240        persistence_args,
241        type_args,
242    }
243}
244
/// Push, Pull, Comp, or Hoff polarity.
///
/// The derived ordering follows declaration order: `Pull < Push < Comp < Hoff`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum Color {
    /// Pull (green)
    Pull,
    /// Push (blue)
    Push,
    /// Computation (yellow)
    Comp,
    /// Handoff (grey) -- not a color for operators, inserted between subgraphs.
    Hoff,
}
257
/// Helper struct for [`PortIndex`] which keeps span information for elided ports.
///
/// Equality and ordering (see the manual `PartialEq`/`Ord` impls) ignore the span stored in
/// [`Self::Elided`]: all elided ports compare equal to each other.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum PortIndexValue {
    /// An integer value: `[0]`, `[1]`, etc. Can be negative although we don't use that (2023-08-16).
    Int(#[serde(with = "serde_syn")] IndexInt),
    /// A name or path. `[pos]`, `[neg]`, etc. Can use `::` separators but we don't use that (2023-08-16).
    Path(#[serde(with = "serde_syn")] ExprPath),
    /// Elided, unspecified port. We have this variant, rather than wrapping in `Option`, in order
    /// to preserve the `Span` information.
    ///
    /// The span is not serialized (`#[serde(skip)]`).
    Elided(#[serde(skip)] Option<Span>),
}
269impl PortIndexValue {
270    /// For a [`Ported`] value like `[port_in]name[port_out]`, get the `port_in` and `port_out` as
271    /// [`PortIndexValue`]s.
272    pub fn from_ported<Inner>(ported: Ported<Inner>) -> (Self, Inner, Self)
273    where
274        Inner: Spanned,
275    {
276        let ported_span = Some(ported.inner.span());
277        let port_inn = ported
278            .inn
279            .map(|idx| idx.index.into())
280            .unwrap_or_else(|| Self::Elided(ported_span));
281        let inner = ported.inner;
282        let port_out = ported
283            .out
284            .map(|idx| idx.index.into())
285            .unwrap_or_else(|| Self::Elided(ported_span));
286        (port_inn, inner, port_out)
287    }
288
289    /// Returns `true` if `self` is not [`PortIndexValue::Elided`].
290    pub fn is_specified(&self) -> bool {
291        !matches!(self, Self::Elided(_))
292    }
293
294    /// Returns whichever of the two ports are specified.
295    /// If both are [`Self::Elided`], returns [`Self::Elided`].
296    /// If both are specified, returns `Err(self)`.
297    #[allow(clippy::allow_attributes, reason = "Only triggered on nightly.")]
298    #[allow(
299        clippy::result_large_err,
300        reason = "variants are same size, error isn't to be propagated."
301    )]
302    pub fn combine(self, other: Self) -> Result<Self, Self> {
303        match (self.is_specified(), other.is_specified()) {
304            (false, _other) => Ok(other),
305            (true, false) => Ok(self),
306            (true, true) => Err(self),
307        }
308    }
309
310    /// Formats self as a human-readable string for error messages.
311    pub fn as_error_message_string(&self) -> String {
312        match self {
313            PortIndexValue::Int(n) => format!("`{}`", n.value),
314            PortIndexValue::Path(path) => format!("`{}`", path.to_token_stream()),
315            PortIndexValue::Elided(_) => "<elided>".to_owned(),
316        }
317    }
318
319    /// Returns the span of this port value.
320    pub fn span(&self) -> Span {
321        match self {
322            PortIndexValue::Int(x) => x.span(),
323            PortIndexValue::Path(x) => x.span(),
324            PortIndexValue::Elided(span) => span.unwrap_or_else(Span::call_site),
325        }
326    }
327}
328impl From<PortIndex> for PortIndexValue {
329    fn from(value: PortIndex) -> Self {
330        match value {
331            PortIndex::Int(x) => Self::Int(x),
332            PortIndex::Path(x) => Self::Path(x),
333        }
334    }
335}
336impl PartialEq for PortIndexValue {
337    fn eq(&self, other: &Self) -> bool {
338        match (self, other) {
339            (Self::Int(l0), Self::Int(r0)) => l0 == r0,
340            (Self::Path(l0), Self::Path(r0)) => l0 == r0,
341            (Self::Elided(_), Self::Elided(_)) => true,
342            _else => false,
343        }
344    }
345}
346impl Eq for PortIndexValue {}
347impl PartialOrd for PortIndexValue {
348    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
349        Some(self.cmp(other))
350    }
351}
352impl Ord for PortIndexValue {
353    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
354        match (self, other) {
355            (Self::Int(s), Self::Int(o)) => s.cmp(o),
356            (Self::Path(s), Self::Path(o)) => s
357                .to_token_stream()
358                .to_string()
359                .cmp(&o.to_token_stream().to_string()),
360            (Self::Elided(_), Self::Elided(_)) => std::cmp::Ordering::Equal,
361            (Self::Int(_), Self::Path(_)) => std::cmp::Ordering::Less,
362            (Self::Path(_), Self::Int(_)) => std::cmp::Ordering::Greater,
363            (_, Self::Elided(_)) => std::cmp::Ordering::Less,
364            (Self::Elided(_), _) => std::cmp::Ordering::Greater,
365        }
366    }
367}
368
369impl Display for PortIndexValue {
370    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
371        match self {
372            PortIndexValue::Int(x) => write!(f, "{}", x.to_token_stream()),
373            PortIndexValue::Path(x) => write!(f, "{}", x.to_token_stream()),
374            PortIndexValue::Elided(_) => write!(f, "[]"),
375        }
376    }
377}
378
379/// The main function of this module. Compiles a [`DfirCode`] AST into a [`DfirGraph`] and
380/// source code, or [`Diagnostic`] errors.
381pub fn build_hfcode(
382    hf_code: DfirCode,
383    root: &TokenStream,
384) -> (Option<(DfirGraph, TokenStream)>, Vec<Diagnostic>) {
385    let flat_graph_builder = FlatGraphBuilder::from_dfir(hf_code);
386    let (mut flat_graph, uses, mut diagnostics) = flat_graph_builder.build();
387    if !diagnostics.iter().any(Diagnostic::is_error) {
388        if let Err(diagnostic) = flat_graph.merge_modules() {
389            diagnostics.push(diagnostic);
390            return (None, diagnostics);
391        }
392
393        eliminate_extra_unions_tees(&mut flat_graph);
394        match partition_graph(flat_graph) {
395            Ok(partitioned_graph) => {
396                let code = partitioned_graph.as_code(
397                    root,
398                    true,
399                    quote::quote! { #( #uses )* },
400                    &mut diagnostics,
401                );
402                if !diagnostics.iter().any(Diagnostic::is_error) {
403                    // Success.
404                    return (Some((partitioned_graph, code)), diagnostics);
405                }
406            }
407            Err(diagnostic) => diagnostics.push(diagnostic),
408        }
409    }
410    (None, diagnostics)
411}
412
413/// Changes all of token's spans to `span`, recursing into groups.
414fn change_spans(tokens: TokenStream, span: Span) -> TokenStream {
415    use proc_macro2::{Group, TokenTree};
416    tokens
417        .into_iter()
418        .map(|token| match token {
419            TokenTree::Group(mut group) => {
420                group.set_span(span);
421                TokenTree::Group(Group::new(
422                    group.delimiter(),
423                    change_spans(group.stream(), span),
424                ))
425            }
426            TokenTree::Ident(mut ident) => {
427                ident.set_span(span.resolved_at(ident.span()));
428                TokenTree::Ident(ident)
429            }
430            TokenTree::Punct(mut punct) => {
431                punct.set_span(span);
432                TokenTree::Punct(punct)
433            }
434            TokenTree::Literal(mut literal) => {
435                literal.set_span(span);
436                TokenTree::Literal(literal)
437            }
438        })
439        .collect()
440}