use std::borrow::Cow;
use std::hash::Hash;
use proc_macro2::{Ident, Span, TokenStream};
use quote::ToTokens;
use serde::{Deserialize, Serialize};
use slotmap::new_key_type;
use syn::punctuated::Punctuated;
use syn::spanned::Spanned;
use syn::{Expr, ExprPath, GenericArgument, Token, Type};
use self::ops::{OperatorConstraints, Persistence};
use crate::diagnostic::{Diagnostic, Level};
use crate::parse::{HfCode, IndexInt, Operator, PortIndex, Ported};
use crate::pretty_span::PrettySpan;
mod di_mul_graph;
mod eliminate_extra_unions_tees;
mod flat_graph_builder;
mod flat_to_partitioned;
mod graph_write;
mod hydroflow_graph;
mod hydroflow_graph_debugging;
use std::fmt::Display;
pub use di_mul_graph::DiMulGraph;
pub use eliminate_extra_unions_tees::eliminate_extra_unions_tees;
pub use flat_graph_builder::FlatGraphBuilder;
pub use flat_to_partitioned::partition_graph;
pub use hydroflow_graph::{DfirGraph, WriteConfig, WriteGraphType};
pub mod graph_algorithms;
pub mod ops;
// Slotmap key types used to index the graph's storage. Each is a distinct type, so a key for
// one kind of entity cannot accidentally be used to look up another.
new_key_type! {
    /// Slotmap key for a graph node.
    pub struct GraphNodeId;
    /// Slotmap key for a graph edge.
    pub struct GraphEdgeId;
    /// Slotmap key for a subgraph.
    pub struct GraphSubgraphId;
    /// Slotmap key for a loop.
    pub struct GraphLoopId;
}
// NOTE(review): not referenced in this part of the file; presumably the identifier used for the
// context variable in generated code — confirm against its users.
const CONTEXT: &str = "context";
// NOTE(review): not referenced in this part of the file; presumably the identifier used for the
// dataflow instance in generated code — confirm against its users.
const HYDROFLOW: &str = "df";
/// Fixed label for handoff nodes, returned by `GraphNode::to_pretty_string` and
/// `GraphNode::to_name_string`.
const HANDOFF_NODE_STR: &str = "handoff";
/// Fixed label for module-boundary nodes, returned by `GraphNode::to_pretty_string` and
/// `GraphNode::to_name_string`.
const MODULE_BOUNDARY_NODE_STR: &str = "module_boundary";
mod serde_syn {
use serde::{Deserialize, Deserializer, Serializer};
pub fn serialize<S, T>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
T: quote::ToTokens,
{
serializer.serialize_str(&value.to_token_stream().to_string())
}
pub fn deserialize<'de, D, T>(deserializer: D) -> Result<T, D::Error>
where
D: Deserializer<'de>,
T: syn::parse::Parse,
{
let s = String::deserialize(deserializer)?;
syn::parse_str(&s).map_err(<D::Error as serde::de::Error>::custom)
}
}
/// Newtype around an `Ident`. It is serialized as the identifier's token string through
/// `serde_syn`.
///
/// NOTE(review): presumably holds variable names bound in the source code — confirm at use sites.
#[derive(Clone, Debug, Serialize, Deserialize, PartialOrd, Ord, PartialEq, Eq)]
struct Varname(#[serde(with = "serde_syn")] pub Ident);
/// A node in the graph.
///
/// `Span` fields are not serialized (`#[serde(skip)]`). After deserialization they are
/// `Span::call_site()`.
#[derive(Clone, Serialize, Deserialize)]
pub enum GraphNode {
    /// An operator. It is serialized as its token string through `serde_syn`.
    Operator(#[serde(with = "serde_syn")] Operator),
    /// A handoff node. `GraphNode::span` joins the two spans, falling back to `src_span` when
    /// they cannot be joined.
    Handoff {
        /// Span associated with the source side of the handoff (meaning inferred from name).
        #[serde(skip, default = "Span::call_site")]
        src_span: Span,
        /// Span associated with the destination side of the handoff (meaning inferred from name).
        #[serde(skip, default = "Span::call_site")]
        dst_span: Span,
    },
    /// A module-boundary node.
    ModuleBoundary {
        /// NOTE(review): presumably `true` for an input boundary and `false` for an output
        /// boundary — confirm.
        input: bool,
        /// Span reported by `GraphNode::span` for this node.
        #[serde(skip, default = "Span::call_site")]
        import_expr: Span,
    },
}
impl GraphNode {
pub fn to_pretty_string(&self) -> Cow<'static, str> {
match self {
GraphNode::Operator(op) => op.to_pretty_string().into(),
GraphNode::Handoff { .. } => HANDOFF_NODE_STR.into(),
GraphNode::ModuleBoundary { .. } => MODULE_BOUNDARY_NODE_STR.into(),
}
}
pub fn to_name_string(&self) -> Cow<'static, str> {
match self {
GraphNode::Operator(op) => op.name_string().into(),
GraphNode::Handoff { .. } => HANDOFF_NODE_STR.into(),
GraphNode::ModuleBoundary { .. } => MODULE_BOUNDARY_NODE_STR.into(),
}
}
pub fn span(&self) -> Span {
match self {
Self::Operator(op) => op.span(),
&Self::Handoff { src_span, dst_span } => src_span.join(dst_span).unwrap_or(src_span),
Self::ModuleBoundary { import_expr, .. } => *import_expr,
}
}
}
impl std::fmt::Debug for GraphNode {
    /// Compact debug form.
    ///
    /// Operators show only their pretty-printed span, handoffs show just the variant name, and
    /// module boundaries show their `input` flag.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Operator(op) => write!(f, "Node::Operator({} span)", PrettySpan(op.span())),
            Self::Handoff { .. } => f.write_str("Node::Handoff"),
            Self::ModuleBoundary { input, .. } => {
                write!(f, "Node::ModuleBoundary{{input: {}}}", input)
            }
        }
    }
}
/// Per-instance information about an operator as written in the graph.
#[derive(Clone, Debug)]
pub struct OperatorInstance {
    /// Static constraints for this operator (type defined in `ops`).
    pub op_constraints: &'static OperatorConstraints,
    /// Input port indices.
    pub input_ports: Vec<PortIndexValue>,
    /// Output port indices.
    pub output_ports: Vec<PortIndexValue>,
    /// NOTE(review): presumably the identifiers of singletons referenced by the operator's
    /// arguments — confirm where this is populated.
    pub singletons_referenced: Vec<Ident>,
    /// Generic arguments split into persistence and type arguments (see `get_operator_generics`).
    pub generics: OpInstGenerics,
    /// Comma-separated argument expressions.
    /// NOTE(review): the `_pre` suffix suggests these are taken before some preprocessing step —
    /// confirm.
    pub arguments_pre: Punctuated<Expr, Token![,]>,
    /// The arguments as a raw token stream.
    pub arguments_raw: TokenStream,
}
/// Generic arguments of an operator instance, as produced by `get_operator_generics`.
#[derive(Clone, Debug)]
pub struct OpInstGenerics {
    /// All generic arguments exactly as written, or `None` if the operator has none.
    pub generic_args: Option<Punctuated<GenericArgument, Token![,]>>,
    /// Leading lifetime arguments (`'static`, `'tick`, `'mutable`) mapped to `Persistence`.
    pub persistence_args: Vec<Persistence>,
    /// Consecutive type arguments that immediately follow the persistence arguments.
    pub type_args: Vec<Type>,
}
/// Splits an operator's generic arguments into persistence lifetimes and type arguments.
///
/// Persistence arguments are read from the leading run of lifetimes. Reading stops at the first
/// non-lifetime argument or at the first unknown lifetime. An unknown lifetime also pushes an
/// error diagnostic. Type arguments are then read from the consecutive `Type` arguments directly
/// after the accepted persistence arguments.
pub fn get_operator_generics(
    diagnostics: &mut Vec<Diagnostic>,
    operator: &Operator,
) -> OpInstGenerics {
    let generic_args = operator.type_arguments().cloned();

    let mut persistence_args = Vec::new();
    for generic_arg in generic_args.iter().flatten() {
        let lifetime = match generic_arg {
            GenericArgument::Lifetime(lifetime) => lifetime,
            _ => break,
        };
        let persistence = match lifetime.ident.to_string().as_str() {
            "static" => Persistence::Static,
            "tick" => Persistence::Tick,
            "mutable" => Persistence::Mutable,
            _ => {
                diagnostics.push(Diagnostic::spanned(
                    generic_arg.span(),
                    Level::Error,
                    format!("Unknown lifetime generic argument `'{}`, expected `'tick`, `'static`, or `'mutable`.", lifetime.ident),
                ));
                break;
            }
        };
        persistence_args.push(persistence);
    }

    let mut type_args = Vec::new();
    for generic_arg in generic_args.iter().flatten().skip(persistence_args.len()) {
        match generic_arg {
            GenericArgument::Type(typ) => type_args.push(typ.clone()),
            _ => break,
        }
    }

    OpInstGenerics {
        generic_args,
        persistence_args,
        type_args,
    }
}
/// Node/subgraph color.
///
/// Variant order matters: it defines the derived `PartialOrd`/`Ord`
/// (`Pull < Push < Comp < Hoff`).
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum Color {
    /// NOTE(review): presumably pull-based execution — confirm.
    Pull,
    /// NOTE(review): presumably push-based execution — confirm.
    Push,
    /// NOTE(review): meaning of `Comp` is not evident from this file — confirm.
    Comp,
    /// NOTE(review): presumably a handoff — confirm.
    Hoff,
}
/// A port index on an edge endpoint: an integer, a path, or elided (not written).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum PortIndexValue {
    /// Integer port index. It is serialized as its token string through `serde_syn`.
    Int(#[serde(with = "serde_syn")] IndexInt),
    /// Path port index. It is serialized as its token string through `serde_syn`.
    Path(#[serde(with = "serde_syn")] ExprPath),
    /// No port index was written. It holds the span to report, if known. The span is not
    /// serialized and comes back as `None`.
    Elided(#[serde(skip)] Option<Span>),
}
impl PortIndexValue {
    /// Splits a `Ported` item into `(input port, inner item, output port)`.
    ///
    /// A missing port becomes `Elided`, carrying the inner item's span.
    pub fn from_ported<Inner>(ported: Ported<Inner>) -> (Self, Inner, Self)
    where
        Inner: Spanned,
    {
        let elided_span = Some(ported.inner.span());
        let port_inn: Self = match ported.inn {
            Some(idx) => idx.index.into(),
            None => Self::Elided(elided_span),
        };
        let port_out: Self = match ported.out {
            Some(idx) => idx.index.into(),
            None => Self::Elided(elided_span),
        };
        (port_inn, ported.inner, port_out)
    }

    /// `true` if an explicit index (integer or path) was written.
    pub fn is_specified(&self) -> bool {
        matches!(self, Self::Int(_) | Self::Path(_))
    }

    /// Merges two port indices, of which at most one may be specified.
    ///
    /// Returns the specified one, or `other` if neither is specified. If both are specified it
    /// returns `Err(self)`.
    pub fn combine(self, other: Self) -> Result<Self, Self> {
        match (self.is_specified(), other.is_specified()) {
            (true, true) => Err(self),
            (true, false) => Ok(self),
            (false, _) => Ok(other),
        }
    }

    /// Renders the index for use in diagnostics: backtick-quoted, or `<elided>`.
    pub fn as_error_message_string(&self) -> String {
        match self {
            Self::Int(int) => format!("`{}`", int.value),
            Self::Path(path) => format!("`{}`", path.to_token_stream()),
            Self::Elided(_) => String::from("<elided>"),
        }
    }

    /// Span of the index. Falls back to `Span::call_site()` for an elided index with no span.
    pub fn span(&self) -> Span {
        match self {
            Self::Int(int) => int.span(),
            Self::Path(path) => path.span(),
            Self::Elided(maybe_span) => maybe_span.unwrap_or_else(Span::call_site),
        }
    }
}
impl From<PortIndex> for PortIndexValue {
fn from(value: PortIndex) -> Self {
match value {
PortIndex::Int(x) => Self::Int(x),
PortIndex::Path(x) => Self::Path(x),
}
}
}
impl PartialEq for PortIndexValue {
fn eq(&self, other: &Self) -> bool {
match (self, other) {
(Self::Int(l0), Self::Int(r0)) => l0 == r0,
(Self::Path(l0), Self::Path(r0)) => l0 == r0,
(Self::Elided(_), Self::Elided(_)) => true,
_else => false,
}
}
}
// Marker impl (required as a supertrait of `Ord`). The hand-written `PartialEq` treats every
// `Elided` as equal. This is sound assuming `IndexInt` and `ExprPath` equality is itself total.
impl Eq for PortIndexValue {}
impl PartialOrd for PortIndexValue {
    /// Delegates to the total order in `Ord::cmp`, so this never returns `None`.
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(Ord::cmp(self, other))
    }
}
impl Ord for PortIndexValue {
    /// Total order: every `Int`, then every `Path`, then every `Elided`.
    ///
    /// Within a variant, `Int`s use their own ordering and `Path`s compare by the text of their
    /// token streams. All `Elided` values are equal.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        use std::cmp::Ordering;
        match (self, other) {
            (Self::Int(lhs), Self::Int(rhs)) => lhs.cmp(rhs),
            (Self::Path(lhs), Self::Path(rhs)) => {
                let lhs_text = lhs.to_token_stream().to_string();
                let rhs_text = rhs.to_token_stream().to_string();
                lhs_text.cmp(&rhs_text)
            }
            (Self::Elided(_), Self::Elided(_)) => Ordering::Equal,
            _ => {
                // Different variants: order by variant rank, `Int < Path < Elided`.
                let rank = |value: &Self| -> u8 {
                    match value {
                        Self::Int(_) => 0,
                        Self::Path(_) => 1,
                        Self::Elided(_) => 2,
                    }
                };
                rank(self).cmp(&rank(other))
            }
        }
    }
}
impl Display for PortIndexValue {
    /// Prints a specified index as its tokens and an elided one as `[]`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let tokens = match self {
            Self::Int(int) => int.to_token_stream(),
            Self::Path(path) => path.to_token_stream(),
            Self::Elided(_) => return f.write_str("[]"),
        };
        write!(f, "{}", tokens)
    }
}
/// Builds the graph from parsed code and generates its Rust code.
///
/// The pipeline is: build the flat graph, merge modules, eliminate extra unions/tees, partition,
/// then generate code. It stops at the first stage that produces an error.
///
/// Returns the partitioned graph and generated code only if no error diagnostics were emitted.
/// All collected diagnostics are returned in every case.
pub fn build_hfcode(
    hf_code: HfCode,
    root: &TokenStream,
) -> (Option<(DfirGraph, TokenStream)>, Vec<Diagnostic>) {
    let builder = FlatGraphBuilder::from_hfcode(hf_code);
    let (mut flat_graph, uses, mut diagnostics) = builder.build();
    if diagnostics.iter().any(Diagnostic::is_error) {
        return (None, diagnostics);
    }

    if let Err(diagnostic) = flat_graph.merge_modules() {
        diagnostics.push(diagnostic);
        return (None, diagnostics);
    }
    eliminate_extra_unions_tees(&mut flat_graph);

    let partitioned_graph = match partition_graph(flat_graph) {
        Ok(graph) => graph,
        Err(diagnostic) => {
            diagnostics.push(diagnostic);
            return (None, diagnostics);
        }
    };

    let code = partitioned_graph.as_code(
        root,
        true,
        quote::quote! { #( #uses )* },
        &mut diagnostics,
    );
    if diagnostics.iter().any(Diagnostic::is_error) {
        (None, diagnostics)
    } else {
        (Some((partitioned_graph, code)), diagnostics)
    }
}
/// Returns `tokens` with every token re-spanned to `span`, recursing into groups.
///
/// Groups (including their delimiters), punctuation, and literals take `span` directly.
/// Identifiers take `span`'s location but keep their original name resolution via
/// `Span::resolved_at`, so re-spanning does not change what an identifier refers to.
fn change_spans(tokens: TokenStream, span: Span) -> TokenStream {
    use proc_macro2::{Group, TokenTree};
    tokens
        .into_iter()
        .map(|token| match token {
            TokenTree::Group(group) => {
                // `Group::new` always spans the new group at `Span::call_site()`, so the span
                // must be set on the newly built group. Setting it on the old group is lost.
                let mut respanned =
                    Group::new(group.delimiter(), change_spans(group.stream(), span));
                respanned.set_span(span);
                TokenTree::Group(respanned)
            }
            TokenTree::Ident(mut ident) => {
                // Location from `span`, resolution from the identifier's original span.
                ident.set_span(span.resolved_at(ident.span()));
                TokenTree::Ident(ident)
            }
            TokenTree::Punct(mut punct) => {
                punct.set_span(span);
                TokenTree::Punct(punct)
            }
            TokenTree::Literal(mut literal) => {
                literal.set_span(span);
                TokenTree::Literal(literal)
            }
        })
        .collect()
}