1use std::borrow::Cow;
4use std::hash::Hash;
5
6use proc_macro2::{Ident, Span, TokenStream};
7use quote::ToTokens;
8use serde::{Deserialize, Serialize};
9use slotmap::new_key_type;
10use syn::punctuated::Punctuated;
11use syn::spanned::Spanned;
12use syn::{Expr, ExprPath, GenericArgument, Token, Type};
13
14use self::ops::{OperatorConstraints, Persistence};
15use crate::diagnostic::{Diagnostic, Level};
16use crate::parse::{DfirCode, IndexInt, Operator, PortIndex, Ported};
17use crate::pretty_span::PrettySpan;
18
19mod di_mul_graph;
20mod eliminate_extra_unions_tees;
21mod flat_graph_builder;
22mod flat_to_partitioned;
23mod graph_write;
24mod meta_graph;
25mod meta_graph_debugging;
26
27use std::fmt::Display;
28
29pub use di_mul_graph::DiMulGraph;
30pub use eliminate_extra_unions_tees::eliminate_extra_unions_tees;
31pub use flat_graph_builder::FlatGraphBuilder;
32pub use flat_to_partitioned::partition_graph;
33pub use meta_graph::{DfirGraph, WriteConfig, WriteGraphType};
34
35pub mod graph_algorithms;
36pub mod ops;
37
// Strongly-typed `slotmap` keys for the entities stored in a graph.
new_key_type! {
    /// Key identifying a node in a graph.
    pub struct GraphNodeId;

    /// Key identifying an edge in a graph.
    pub struct GraphEdgeId;

    /// Key identifying a subgraph in a graph.
    pub struct GraphSubgraphId;

    /// Key identifying a loop in a graph.
    pub struct GraphLoopId;
}
51
/// Identifier text `"context"`.
// NOTE(review): presumably the name bound to the runtime context in generated code; the use
// sites are outside this view — confirm.
const CONTEXT: &str = "context";
/// Identifier text `"df"`.
// NOTE(review): presumably the name bound to the graph instance in generated code; the use
// sites are outside this view — confirm.
const GRAPH: &str = "df";

/// Label returned by `GraphNode::to_pretty_string` / `to_name_string` for handoff nodes.
const HANDOFF_NODE_STR: &str = "handoff";
/// Label returned by `GraphNode::to_pretty_string` / `to_name_string` for module boundary nodes.
const MODULE_BOUNDARY_NODE_STR: &str = "module_boundary";
59
60mod serde_syn {
61 use serde::{Deserialize, Deserializer, Serializer};
62
63 pub fn serialize<S, T>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
64 where
65 S: Serializer,
66 T: quote::ToTokens,
67 {
68 serializer.serialize_str(&value.to_token_stream().to_string())
69 }
70
71 pub fn deserialize<'de, D, T>(deserializer: D) -> Result<T, D::Error>
72 where
73 D: Deserializer<'de>,
74 T: syn::parse::Parse,
75 {
76 let s = String::deserialize(deserializer)?;
77 syn::parse_str(&s).map_err(<D::Error as serde::de::Error>::custom)
78 }
79}
80
/// Newtype around an [`Ident`], (de)serialized as its token string via `serde_syn`.
// NOTE(review): presumably holds a variable name (per the type name) — confirm at use sites.
#[derive(Clone, Debug, Serialize, Deserialize, PartialOrd, Ord, PartialEq, Eq)]
struct Varname(#[serde(with = "serde_syn")] pub Ident);
83
/// A node in a graph: an operator, a handoff, or a module boundary.
#[derive(Clone, Serialize, Deserialize)]
pub enum GraphNode {
    /// An operator node wrapping the parsed [`Operator`]; serialized as its token string.
    Operator(#[serde(with = "serde_syn")] Operator),
    /// A handoff node.
    Handoff {
        /// Source-side span; `GraphNode::span` joins it with `dst_span`.
        /// Not serialized: deserializes to `Span::call_site()`.
        #[serde(skip, default = "Span::call_site")]
        src_span: Span,
        /// Destination-side span.
        /// Not serialized: deserializes to `Span::call_site()`.
        #[serde(skip, default = "Span::call_site")]
        dst_span: Span,
    },

    /// A module boundary node.
    ModuleBoundary {
        /// Boundary direction flag.
        // NOTE(review): presumably `true` for an input boundary and `false` for an output
        // boundary — confirm.
        input: bool,

        /// Span of the import expression; used as this node's span by `GraphNode::span`.
        /// Not serialized: deserializes to `Span::call_site()`.
        #[serde(skip, default = "Span::call_site")]
        import_expr: Span,
    },
}
111impl GraphNode {
112 pub fn to_pretty_string(&self) -> Cow<'static, str> {
114 match self {
115 GraphNode::Operator(op) => op.to_pretty_string().into(),
116 GraphNode::Handoff { .. } => HANDOFF_NODE_STR.into(),
117 GraphNode::ModuleBoundary { .. } => MODULE_BOUNDARY_NODE_STR.into(),
118 }
119 }
120
121 pub fn to_name_string(&self) -> Cow<'static, str> {
123 match self {
124 GraphNode::Operator(op) => op.name_string().into(),
125 GraphNode::Handoff { .. } => HANDOFF_NODE_STR.into(),
126 GraphNode::ModuleBoundary { .. } => MODULE_BOUNDARY_NODE_STR.into(),
127 }
128 }
129
130 pub fn span(&self) -> Span {
132 match self {
133 Self::Operator(op) => op.span(),
134 &Self::Handoff {
135 src_span, dst_span, ..
136 } => src_span.join(dst_span).unwrap_or(src_span),
137 Self::ModuleBoundary { import_expr, .. } => *import_expr,
138 }
139 }
140}
141impl std::fmt::Debug for GraphNode {
142 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
143 match self {
144 Self::Operator(operator) => {
145 write!(f, "Node::Operator({} span)", PrettySpan(operator.span()))
146 }
147 Self::Handoff { .. } => write!(f, "Node::Handoff"),
148 Self::ModuleBoundary { input, .. } => {
149 write!(f, "Node::ModuleBoundary{{input: {}}}", input)
150 }
151 }
152 }
153}
154
/// A particular instance of an operator: its constraints, ports, generics and arguments.
#[derive(Clone, Debug)]
pub struct OperatorInstance {
    /// Static constraints for this operator.
    pub op_constraints: &'static OperatorConstraints,
    /// Port indices of this operator's inputs.
    pub input_ports: Vec<PortIndexValue>,
    /// Port indices of this operator's outputs.
    pub output_ports: Vec<PortIndexValue>,
    /// Identifiers referenced by this operator.
    // NOTE(review): presumably names of singletons referenced in the arguments — confirm.
    pub singletons_referenced: Vec<Ident>,

    /// Parsed generic arguments (see `get_operator_generics`).
    pub generics: OpInstGenerics,
    /// Arguments parsed as comma-separated expressions.
    // NOTE(review): the `_pre` suffix suggests these are the arguments before some later
    // processing step — confirm.
    pub arguments_pre: Punctuated<Expr, Token![,]>,
    /// Arguments as raw, unparsed tokens.
    pub arguments_raw: TokenStream,
}
185
/// Generic arguments of an operator instance, as produced by `get_operator_generics`.
#[derive(Clone, Debug)]
pub struct OpInstGenerics {
    /// All generic arguments as written, or `None` when `Operator::type_arguments` returned
    /// `None`.
    pub generic_args: Option<Punctuated<GenericArgument, Token![,]>>,
    /// Leading lifetime arguments, parsed into persistence levels.
    pub persistence_args: Vec<Persistence>,
    /// Type arguments immediately following the persistence arguments.
    pub type_args: Vec<Type>,
}
196
197pub fn get_operator_generics(
202 diagnostics: &mut Vec<Diagnostic>,
203 operator: &Operator,
204) -> OpInstGenerics {
205 let generic_args = operator.type_arguments().cloned();
207 let persistence_args = generic_args.iter().flatten().map_while(|generic_arg| match generic_arg {
208 GenericArgument::Lifetime(lifetime) => {
209 match &*lifetime.ident.to_string() {
210 "none" => Some(Persistence::None),
211 "tick" => Some(Persistence::Tick),
212 "static" => Some(Persistence::Static),
213 "mutable" => Some(Persistence::Mutable),
214 _ => {
215 diagnostics.push(Diagnostic::spanned(
216 generic_arg.span(),
217 Level::Error,
218 format!("Unknown lifetime generic argument `'{}`, expected `'tick`, `'static`, or `'mutable`.", lifetime.ident),
219 ));
220 None
222 }
223 }
224 },
225 _ => None,
226 }).collect::<Vec<_>>();
227 let type_args = generic_args
228 .iter()
229 .flatten()
230 .skip(persistence_args.len())
231 .map_while(|generic_arg| match generic_arg {
232 GenericArgument::Type(typ) => Some(typ),
233 _ => None,
234 })
235 .cloned()
236 .collect::<Vec<_>>();
237
238 OpInstGenerics {
239 generic_args,
240 persistence_args,
241 type_args,
242 }
243}
244
/// Color classification assigned to graph nodes.
// NOTE(review): presumably describes push/pull execution style during partitioning — confirm
// against the partitioning code.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum Color {
    /// Pull. NOTE(review): presumably a pull-based node — confirm.
    Pull,
    /// Push. NOTE(review): presumably a push-based node — confirm.
    Push,
    /// Comp. NOTE(review): meaning not visible here (possibly "compatible with either") — confirm.
    Comp,
    /// Hoff. NOTE(review): presumably marks a handoff — confirm.
    Hoff,
}
257
/// A port index on an edge endpoint: an integer, a path, or elided (not specified).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum PortIndexValue {
    /// An integer port index; serialized as its token string.
    Int(#[serde(with = "serde_syn")] IndexInt),
    /// A path port index; serialized as its token string.
    Path(#[serde(with = "serde_syn")] ExprPath),
    /// No port index was specified. The optional span is returned by `span()` when present.
    /// It is not serialized, so it deserializes as `None`.
    Elided(#[serde(skip)] Option<Span>),
}
269impl PortIndexValue {
270 pub fn from_ported<Inner>(ported: Ported<Inner>) -> (Self, Inner, Self)
273 where
274 Inner: Spanned,
275 {
276 let ported_span = Some(ported.inner.span());
277 let port_inn = ported
278 .inn
279 .map(|idx| idx.index.into())
280 .unwrap_or_else(|| Self::Elided(ported_span));
281 let inner = ported.inner;
282 let port_out = ported
283 .out
284 .map(|idx| idx.index.into())
285 .unwrap_or_else(|| Self::Elided(ported_span));
286 (port_inn, inner, port_out)
287 }
288
289 pub fn is_specified(&self) -> bool {
291 !matches!(self, Self::Elided(_))
292 }
293
294 #[allow(clippy::allow_attributes, reason = "Only triggered on nightly.")]
298 #[allow(
299 clippy::result_large_err,
300 reason = "variants are same size, error isn't to be propagated."
301 )]
302 pub fn combine(self, other: Self) -> Result<Self, Self> {
303 match (self.is_specified(), other.is_specified()) {
304 (false, _other) => Ok(other),
305 (true, false) => Ok(self),
306 (true, true) => Err(self),
307 }
308 }
309
310 pub fn as_error_message_string(&self) -> String {
312 match self {
313 PortIndexValue::Int(n) => format!("`{}`", n.value),
314 PortIndexValue::Path(path) => format!("`{}`", path.to_token_stream()),
315 PortIndexValue::Elided(_) => "<elided>".to_owned(),
316 }
317 }
318
319 pub fn span(&self) -> Span {
321 match self {
322 PortIndexValue::Int(x) => x.span(),
323 PortIndexValue::Path(x) => x.span(),
324 PortIndexValue::Elided(span) => span.unwrap_or_else(Span::call_site),
325 }
326 }
327}
328impl From<PortIndex> for PortIndexValue {
329 fn from(value: PortIndex) -> Self {
330 match value {
331 PortIndex::Int(x) => Self::Int(x),
332 PortIndex::Path(x) => Self::Path(x),
333 }
334 }
335}
336impl PartialEq for PortIndexValue {
337 fn eq(&self, other: &Self) -> bool {
338 match (self, other) {
339 (Self::Int(l0), Self::Int(r0)) => l0 == r0,
340 (Self::Path(l0), Self::Path(r0)) => l0 == r0,
341 (Self::Elided(_), Self::Elided(_)) => true,
342 _else => false,
343 }
344 }
345}
346impl Eq for PortIndexValue {}
347impl PartialOrd for PortIndexValue {
348 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
349 Some(self.cmp(other))
350 }
351}
352impl Ord for PortIndexValue {
353 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
354 match (self, other) {
355 (Self::Int(s), Self::Int(o)) => s.cmp(o),
356 (Self::Path(s), Self::Path(o)) => s
357 .to_token_stream()
358 .to_string()
359 .cmp(&o.to_token_stream().to_string()),
360 (Self::Elided(_), Self::Elided(_)) => std::cmp::Ordering::Equal,
361 (Self::Int(_), Self::Path(_)) => std::cmp::Ordering::Less,
362 (Self::Path(_), Self::Int(_)) => std::cmp::Ordering::Greater,
363 (_, Self::Elided(_)) => std::cmp::Ordering::Less,
364 (Self::Elided(_), _) => std::cmp::Ordering::Greater,
365 }
366 }
367}
368
369impl Display for PortIndexValue {
370 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
371 match self {
372 PortIndexValue::Int(x) => write!(f, "{}", x.to_token_stream()),
373 PortIndexValue::Path(x) => write!(f, "{}", x.to_token_stream()),
374 PortIndexValue::Elided(_) => write!(f, "[]"),
375 }
376 }
377}
378
379pub fn build_hfcode(
382 hf_code: DfirCode,
383 root: &TokenStream,
384) -> (Option<(DfirGraph, TokenStream)>, Vec<Diagnostic>) {
385 let flat_graph_builder = FlatGraphBuilder::from_dfir(hf_code);
386 let (mut flat_graph, uses, mut diagnostics) = flat_graph_builder.build();
387 if !diagnostics.iter().any(Diagnostic::is_error) {
388 if let Err(diagnostic) = flat_graph.merge_modules() {
389 diagnostics.push(diagnostic);
390 return (None, diagnostics);
391 }
392
393 eliminate_extra_unions_tees(&mut flat_graph);
394 match partition_graph(flat_graph) {
395 Ok(partitioned_graph) => {
396 let code = partitioned_graph.as_code(
397 root,
398 true,
399 quote::quote! { #( #uses )* },
400 &mut diagnostics,
401 );
402 if !diagnostics.iter().any(Diagnostic::is_error) {
403 return (Some((partitioned_graph, code)), diagnostics);
405 }
406 }
407 Err(diagnostic) => diagnostics.push(diagnostic),
408 }
409 }
410 (None, diagnostics)
411}
412
413fn change_spans(tokens: TokenStream, span: Span) -> TokenStream {
415 use proc_macro2::{Group, TokenTree};
416 tokens
417 .into_iter()
418 .map(|token| match token {
419 TokenTree::Group(mut group) => {
420 group.set_span(span);
421 TokenTree::Group(Group::new(
422 group.delimiter(),
423 change_spans(group.stream(), span),
424 ))
425 }
426 TokenTree::Ident(mut ident) => {
427 ident.set_span(span.resolved_at(ident.span()));
428 TokenTree::Ident(ident)
429 }
430 TokenTree::Punct(mut punct) => {
431 punct.set_span(span);
432 TokenTree::Punct(punct)
433 }
434 TokenTree::Literal(mut literal) => {
435 literal.set_span(span);
436 TokenTree::Literal(literal)
437 }
438 })
439 .collect()
440}