1use std::borrow::Cow;
4use std::hash::Hash;
5
6use proc_macro2::{Ident, Span, TokenStream};
7use quote::ToTokens;
8use serde::{Deserialize, Serialize};
9use slotmap::new_key_type;
10use syn::punctuated::Punctuated;
11use syn::spanned::Spanned;
12use syn::{Expr, ExprPath, GenericArgument, Token, Type};
13
14use self::ops::{OperatorConstraints, Persistence};
15use crate::diagnostic::{Diagnostic, Level};
16use crate::parse::{DfirCode, IndexInt, Operator, PortIndex, Ported};
17use crate::pretty_span::PrettySpan;
18
19mod di_mul_graph;
20mod eliminate_extra_unions_tees;
21mod flat_graph_builder;
22mod flat_to_partitioned;
23mod graph_write;
24mod meta_graph;
25mod meta_graph_debugging;
26
27use std::fmt::Display;
28
29pub use di_mul_graph::DiMulGraph;
30pub use eliminate_extra_unions_tees::eliminate_extra_unions_tees;
31pub use flat_graph_builder::FlatGraphBuilder;
32pub use flat_to_partitioned::partition_graph;
33pub use meta_graph::{DfirGraph, WriteConfig, WriteGraphType};
34
35pub mod graph_algorithms;
36pub mod ops;
37
new_key_type! {
    /// Slotmap key type for graph nodes.
    pub struct GraphNodeId;

    /// Slotmap key type for graph edges.
    pub struct GraphEdgeId;

    /// Slotmap key type for subgraphs. Can be turned into an identifier via
    /// `GraphSubgraphId::as_ident`.
    pub struct GraphSubgraphId;

    /// Slotmap key type for loops. Can be turned into an identifier via `GraphLoopId::as_ident`.
    pub struct GraphLoopId;
}
51
52impl GraphSubgraphId {
53 pub fn as_ident(self, span: Span) -> Ident {
55 use slotmap::Key;
56 Ident::new(&format!("sgid_{:?}", self.data()), span)
57 }
58}
59
60impl GraphLoopId {
61 pub fn as_ident(self, span: Span) -> Ident {
63 use slotmap::Key;
64 Ident::new(&format!("loop_{:?}", self.data()), span)
65 }
66}
67
/// Identifier name `context`.
/// NOTE(review): presumably the name of the context variable in generated code — confirm at use sites.
const CONTEXT: &str = "context";
/// Identifier name `df`.
/// NOTE(review): presumably the name of the graph variable in generated code — confirm at use sites.
const GRAPH: &str = "df";

/// Label for handoff nodes, returned by `GraphNode::to_pretty_string` and
/// `GraphNode::to_name_string`.
const HANDOFF_NODE_STR: &str = "handoff";
/// Label for module boundary nodes, returned by `GraphNode::to_pretty_string` and
/// `GraphNode::to_name_string`.
const MODULE_BOUNDARY_NODE_STR: &str = "module_boundary";
75
76mod serde_syn {
77 use serde::{Deserialize, Deserializer, Serializer};
78
79 pub fn serialize<S, T>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
80 where
81 S: Serializer,
82 T: quote::ToTokens,
83 {
84 serializer.serialize_str(&value.to_token_stream().to_string())
85 }
86
87 pub fn deserialize<'de, D, T>(deserializer: D) -> Result<T, D::Error>
88 where
89 D: Deserializer<'de>,
90 T: syn::parse::Parse,
91 {
92 let s = String::deserialize(deserializer)?;
93 syn::parse_str(&s).map_err(<D::Error as serde::de::Error>::custom)
94 }
95}
96
/// A variable name: a newtype around an [`Ident`].
///
/// Serialized as the identifier's token string via `serde_syn`. Equality, ordering and hashing
/// are derived from the wrapped `Ident`.
#[derive(Clone, Debug, Serialize, Deserialize, PartialOrd, Ord, PartialEq, Eq, Hash)]
pub struct Varname(#[serde(with = "serde_syn")] pub Ident);
102
/// A node in the graph: an operator, a handoff, or a module boundary.
///
/// All `Span` fields are skipped by serde and deserialize as `Span::call_site()`. The `Operator`
/// payload is (de)serialized as its token string via `serde_syn`.
#[derive(Clone, Serialize, Deserialize)]
pub enum GraphNode {
    /// An operator node. Serialized as the operator's token string.
    Operator(#[serde(with = "serde_syn")] Operator),
    /// A handoff node (labelled `"handoff"` by `to_pretty_string` / `to_name_string`).
    Handoff {
        /// Span for the source side of the handoff.
        /// NOTE(review): presumably the upstream operator's span — confirm at construction sites.
        #[serde(skip, default = "Span::call_site")]
        src_span: Span,
        /// Span for the destination side of the handoff.
        /// NOTE(review): presumably the downstream operator's span — confirm at construction sites.
        #[serde(skip, default = "Span::call_site")]
        dst_span: Span,
    },

    /// A module boundary node (labelled `"module_boundary"`).
    ModuleBoundary {
        /// `true` for an input boundary; presumably `false` means an output boundary — confirm.
        input: bool,

        /// The span reported by `GraphNode::span` for this node.
        /// NOTE(review): presumably the span of the module's import expression — confirm.
        #[serde(skip, default = "Span::call_site")]
        import_expr: Span,
    },
}
130impl GraphNode {
131 pub fn to_pretty_string(&self) -> Cow<'static, str> {
133 match self {
134 GraphNode::Operator(op) => op.to_pretty_string().into(),
135 GraphNode::Handoff { .. } => HANDOFF_NODE_STR.into(),
136 GraphNode::ModuleBoundary { .. } => MODULE_BOUNDARY_NODE_STR.into(),
137 }
138 }
139
140 pub fn to_name_string(&self) -> Cow<'static, str> {
142 match self {
143 GraphNode::Operator(op) => op.name_string().into(),
144 GraphNode::Handoff { .. } => HANDOFF_NODE_STR.into(),
145 GraphNode::ModuleBoundary { .. } => MODULE_BOUNDARY_NODE_STR.into(),
146 }
147 }
148
149 pub fn span(&self) -> Span {
151 match self {
152 Self::Operator(op) => op.span(),
153 &Self::Handoff {
154 src_span, dst_span, ..
155 } => src_span.join(dst_span).unwrap_or(src_span),
156 Self::ModuleBoundary { import_expr, .. } => *import_expr,
157 }
158 }
159}
160impl std::fmt::Debug for GraphNode {
161 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
162 match self {
163 Self::Operator(operator) => {
164 write!(f, "Node::Operator({} span)", PrettySpan(operator.span()))
165 }
166 Self::Handoff { .. } => write!(f, "Node::Handoff"),
167 Self::ModuleBoundary { input, .. } => {
168 write!(f, "Node::ModuleBoundary{{input: {}}}", input)
169 }
170 }
171 }
172}
173
/// An operator in the graph together with its constraints, ports, referenced singletons,
/// generic arguments, and arguments.
#[derive(Clone, Debug)]
pub struct OperatorInstance {
    /// Static constraints/definition for this operator's kind.
    pub op_constraints: &'static OperatorConstraints,
    /// Input port indices.
    /// NOTE(review): presumably one entry per incoming edge — confirm.
    pub input_ports: Vec<PortIndexValue>,
    /// Output port indices.
    /// NOTE(review): presumably one entry per outgoing edge — confirm.
    pub output_ports: Vec<PortIndexValue>,
    /// Identifiers of singletons referenced by this operator.
    /// NOTE(review): presumably found in the operator's arguments — confirm.
    pub singletons_referenced: Vec<Ident>,

    /// Generic arguments, split into persistence lifetimes and types (see
    /// `get_operator_generics`).
    pub generics: OpInstGenerics,
    /// Arguments parsed as a comma-separated list of expressions.
    /// NOTE(review): the exact meaning of the `pre` suffix (e.g. before some preprocessing)
    /// is not visible here — confirm.
    pub arguments_pre: Punctuated<Expr, Token![,]>,
    /// Arguments kept as a raw, unparsed token stream.
    pub arguments_raw: TokenStream,
}
204
/// Generic arguments of an operator instance, as produced by `get_operator_generics`.
#[derive(Clone, Debug)]
pub struct OpInstGenerics {
    /// The operator's complete generic argument list, cloned from `Operator::type_arguments`.
    /// `None` when that returns `None`.
    pub generic_args: Option<Punctuated<GenericArgument, Token![,]>>,
    /// The leading run of recognized lifetime arguments (`'none`, `'loop`, `'tick`, `'static`,
    /// `'mutable`), converted to [`Persistence`] values.
    pub persistence_args: Vec<Persistence>,
    /// The run of type arguments that immediately follows the persistence arguments.
    pub type_args: Vec<Type>,
}
215
216pub fn get_operator_generics(
221 diagnostics: &mut Vec<Diagnostic>,
222 operator: &Operator,
223) -> OpInstGenerics {
224 let generic_args = operator.type_arguments().cloned();
226 let persistence_args = generic_args.iter().flatten().map_while(|generic_arg| match generic_arg {
227 GenericArgument::Lifetime(lifetime) => {
228 match &*lifetime.ident.to_string() {
229 "none" => Some(Persistence::None),
230 "loop" => Some(Persistence::Loop),
231 "tick" => Some(Persistence::Tick),
232 "static" => Some(Persistence::Static),
233 "mutable" => Some(Persistence::Mutable),
234 _ => {
235 diagnostics.push(Diagnostic::spanned(
236 generic_arg.span(),
237 Level::Error,
238 format!("Unknown lifetime generic argument `'{}`, expected `'none`, `'loop`, `'tick`, `'static`, or `'mutable`.", lifetime.ident),
239 ));
240 None
242 }
243 }
244 },
245 _ => None,
246 }).collect::<Vec<_>>();
247 let type_args = generic_args
248 .iter()
249 .flatten()
250 .skip(persistence_args.len())
251 .map_while(|generic_arg| match generic_arg {
252 GenericArgument::Type(typ) => Some(typ),
253 _ => None,
254 })
255 .cloned()
256 .collect::<Vec<_>>();
257
258 OpInstGenerics {
259 generic_args,
260 persistence_args,
261 type_args,
262 }
263}
264
/// Color classification with four variants: `Pull`, `Push`, `Comp`, `Hoff`.
///
/// The derived `Ord` follows declaration order: `Pull < Push < Comp < Hoff`.
/// NOTE(review): presumably the push/pull coloring used when partitioning the graph — confirm.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum Color {
    /// Presumably pull-based — confirm.
    Pull,
    /// Presumably push-based — confirm.
    Push,
    /// NOTE(review): presumably "computation" (either push or pull) — confirm.
    Comp,
    /// NOTE(review): presumably "handoff" — confirm.
    Hoff,
}
277
/// A port specifier: an explicit integer or path index, or elided (none written).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum PortIndexValue {
    /// An integer port index, serialized as its token string.
    Int(#[serde(with = "serde_syn")] IndexInt),
    /// A path port index, serialized as its token string.
    Path(#[serde(with = "serde_syn")] ExprPath),
    /// No port was specified.
    ///
    /// Holds an optional fallback span for diagnostics. The span is skipped by serde, so it
    /// deserializes as `None`.
    Elided(#[serde(skip)] Option<Span>),
}
289impl PortIndexValue {
290 pub fn from_ported<Inner>(ported: Ported<Inner>) -> (Self, Inner, Self)
293 where
294 Inner: Spanned,
295 {
296 let ported_span = Some(ported.inner.span());
297 let port_inn = ported
298 .inn
299 .map(|idx| idx.index.into())
300 .unwrap_or_else(|| Self::Elided(ported_span));
301 let inner = ported.inner;
302 let port_out = ported
303 .out
304 .map(|idx| idx.index.into())
305 .unwrap_or_else(|| Self::Elided(ported_span));
306 (port_inn, inner, port_out)
307 }
308
309 pub fn is_specified(&self) -> bool {
311 !matches!(self, Self::Elided(_))
312 }
313
314 #[allow(clippy::allow_attributes, reason = "Only triggered on nightly.")]
318 #[allow(
319 clippy::result_large_err,
320 reason = "variants are same size, error isn't to be propagated."
321 )]
322 pub fn combine(self, other: Self) -> Result<Self, Self> {
323 match (self.is_specified(), other.is_specified()) {
324 (false, _other) => Ok(other),
325 (true, false) => Ok(self),
326 (true, true) => Err(self),
327 }
328 }
329
330 pub fn as_error_message_string(&self) -> String {
332 match self {
333 PortIndexValue::Int(n) => format!("`{}`", n.value),
334 PortIndexValue::Path(path) => format!("`{}`", path.to_token_stream()),
335 PortIndexValue::Elided(_) => "<elided>".to_owned(),
336 }
337 }
338
339 pub fn span(&self) -> Span {
341 match self {
342 PortIndexValue::Int(x) => x.span(),
343 PortIndexValue::Path(x) => x.span(),
344 PortIndexValue::Elided(span) => span.unwrap_or_else(Span::call_site),
345 }
346 }
347}
348impl From<PortIndex> for PortIndexValue {
349 fn from(value: PortIndex) -> Self {
350 match value {
351 PortIndex::Int(x) => Self::Int(x),
352 PortIndex::Path(x) => Self::Path(x),
353 }
354 }
355}
356impl PartialEq for PortIndexValue {
357 fn eq(&self, other: &Self) -> bool {
358 match (self, other) {
359 (Self::Int(l0), Self::Int(r0)) => l0 == r0,
360 (Self::Path(l0), Self::Path(r0)) => l0 == r0,
361 (Self::Elided(_), Self::Elided(_)) => true,
362 _else => false,
363 }
364 }
365}
366impl Eq for PortIndexValue {}
367impl PartialOrd for PortIndexValue {
368 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
369 Some(self.cmp(other))
370 }
371}
372impl Ord for PortIndexValue {
373 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
374 match (self, other) {
375 (Self::Int(s), Self::Int(o)) => s.cmp(o),
376 (Self::Path(s), Self::Path(o)) => s
377 .to_token_stream()
378 .to_string()
379 .cmp(&o.to_token_stream().to_string()),
380 (Self::Elided(_), Self::Elided(_)) => std::cmp::Ordering::Equal,
381 (Self::Int(_), Self::Path(_)) => std::cmp::Ordering::Less,
382 (Self::Path(_), Self::Int(_)) => std::cmp::Ordering::Greater,
383 (_, Self::Elided(_)) => std::cmp::Ordering::Less,
384 (Self::Elided(_), _) => std::cmp::Ordering::Greater,
385 }
386 }
387}
388
389impl Display for PortIndexValue {
390 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
391 match self {
392 PortIndexValue::Int(x) => write!(f, "{}", x.to_token_stream()),
393 PortIndexValue::Path(x) => write!(f, "{}", x.to_token_stream()),
394 PortIndexValue::Elided(_) => write!(f, "[]"),
395 }
396 }
397}
398
399pub fn build_hfcode(
402 hf_code: DfirCode,
403 root: &TokenStream,
404) -> (Option<(DfirGraph, TokenStream)>, Vec<Diagnostic>) {
405 let flat_graph_builder = FlatGraphBuilder::from_dfir(hf_code);
406 let (mut flat_graph, uses, mut diagnostics) = flat_graph_builder.build();
407 if !diagnostics.iter().any(Diagnostic::is_error) {
408 if let Err(diagnostic) = flat_graph.merge_modules() {
409 diagnostics.push(diagnostic);
410 return (None, diagnostics);
411 }
412
413 eliminate_extra_unions_tees(&mut flat_graph);
414 match partition_graph(flat_graph) {
415 Ok(partitioned_graph) => {
416 let code = partitioned_graph.as_code(
417 root,
418 true,
419 quote::quote! { #( #uses )* },
420 &mut diagnostics,
421 );
422 if !diagnostics.iter().any(Diagnostic::is_error) {
423 return (Some((partitioned_graph, code)), diagnostics);
425 }
426 }
427 Err(diagnostic) => diagnostics.push(diagnostic),
428 }
429 }
430 (None, diagnostics)
431}
432
433fn change_spans(tokens: TokenStream, span: Span) -> TokenStream {
435 use proc_macro2::{Group, TokenTree};
436 tokens
437 .into_iter()
438 .map(|token| match token {
439 TokenTree::Group(mut group) => {
440 group.set_span(span);
441 TokenTree::Group(Group::new(
442 group.delimiter(),
443 change_spans(group.stream(), span),
444 ))
445 }
446 TokenTree::Ident(mut ident) => {
447 ident.set_span(span.resolved_at(ident.span()));
448 TokenTree::Ident(ident)
449 }
450 TokenTree::Punct(mut punct) => {
451 punct.set_span(span);
452 TokenTree::Punct(punct)
453 }
454 TokenTree::Literal(mut literal) => {
455 literal.set_span(span);
456 TokenTree::Literal(literal)
457 }
458 })
459 .collect()
460}