use std::collections::HashMap;
use std::fmt::{Debug, Display};
use std::ops::{Bound, RangeBounds};
use std::sync::OnceLock;
use proc_macro2::{Ident, Literal, Span, TokenStream};
use quote::quote_spanned;
use serde::{Deserialize, Serialize};
use slotmap::Key;
use syn::punctuated::Punctuated;
use syn::{parse_quote_spanned, Expr, Token};
use super::{
GraphNode, GraphNodeId, GraphSubgraphId, OpInstGenerics, OperatorInstance, PortIndexValue,
};
use crate::diagnostic::Diagnostic;
use crate::parse::{Operator, PortIndex};
/// Kind of delay that may be attached to an operator input.
///
/// Values of this type are produced by [`OperatorConstraints::input_delaytype_fn`], which maps a
/// `PortIndexValue` to an optional `DelayType`. `Ord`/`PartialOrd` are derived, so comparisons
/// follow the declaration order of the variants below.
// NOTE(review): the per-variant notes are inferred from the variant names and from operator names
// registered in this file (`next_stratum`, `defer_tick`, `defer_tick_lazy`) — confirm against the
// code that consumes `DelayType` before relying on them.
#[derive(Clone, Copy, PartialOrd, Ord, PartialEq, Eq, Debug)]
pub enum DelayType {
// Presumably: the input is delayed across a stratum boundary — TODO confirm.
Stratum,
// Presumably: a monotone accumulation that may not need a full barrier — TODO confirm.
MonotoneAccum,
// Presumably: the input is delayed until the next tick — TODO confirm.
Tick,
// Presumably: like `Tick`, but without eagerly scheduling the next tick — TODO confirm.
TickLazy,
}
/// Describes a list of ports for an operator.
///
/// Returned by the optional `ports_inn` / `ports_out` callbacks of [`OperatorConstraints`].
pub enum PortListSpec {
/// No explicit list of ports is given.
// NOTE(review): presumably means "any number of ports is accepted" — confirm at the use site.
Variadic,
/// An explicit, comma-punctuated list of port indices.
Fixed(Punctuated<PortIndex, Token![,]>),
}
/// Static description of a single operator.
///
/// One value per operator is collected into [`OPERATORS`] by `declare_ops!`, and
/// [`operator_lookup`] indexes those values by [`Self::name`].
// NOTE(review): in the range fields, `inn`/`out` presumably refer to the number of input/output
// ports, and `hard`/`soft` presumably distinguish an error threshold from a warning threshold —
// the validation code is not in this file, confirm there.
pub struct OperatorConstraints {
/// Operator name; used as the key of the map built by [`operator_lookup`].
pub name: &'static str,
/// Categories this operator belongs to (see [`OperatorCategory`]).
pub categories: &'static [OperatorCategory],
/// "Hard" range constraint for `inn` (see the review note on the struct).
pub hard_range_inn: &'static dyn RangeTrait<usize>,
/// "Soft" range constraint for `inn` (see the review note on the struct).
pub soft_range_inn: &'static dyn RangeTrait<usize>,
/// "Hard" range constraint for `out` (see the review note on the struct).
pub hard_range_out: &'static dyn RangeTrait<usize>,
/// "Soft" range constraint for `out` (see the review note on the struct).
pub soft_range_out: &'static dyn RangeTrait<usize>,
/// Number of arguments.
// NOTE(review): presumably the exact count of call arguments, e.g. 3 for `op(a, b, c)` — confirm.
pub num_args: usize,
/// Range constraint on the number of persistence arguments.
pub persistence_args: &'static dyn RangeTrait<usize>,
/// Range constraint on the number of type arguments.
pub type_args: &'static dyn RangeTrait<usize>,
/// Whether the operator is an external input.
// NOTE(review): scheduling consequences of this flag are not visible in this file.
pub is_external_input: bool,
/// Whether the operator has a singleton output.
// NOTE(review): presumably governs whether `WriteContextArgs::singleton_output_ident` is
// meaningful for this operator — confirm.
pub has_singleton_output: bool,
/// Optional [`FloType`] classification of the operator.
pub flo_type: Option<FloType>,
/// Optional callback returning the [`PortListSpec`] for `inn`.
pub ports_inn: Option<fn() -> PortListSpec>,
/// Optional callback returning the [`PortListSpec`] for `out`.
pub ports_out: Option<fn() -> PortListSpec>,
/// Maps a port to the [`DelayType`] it requires, or `None` if it requires none.
pub input_delaytype_fn: fn(&PortIndexValue) -> Option<DelayType>,
/// Code-generation callback for the operator; see [`WriteFn`].
pub write_fn: WriteFn,
}
/// Signature of an operator's code-generation callback (stored in
/// [`OperatorConstraints::write_fn`]).
///
/// The callback receives the [`WriteContextArgs`] for one operator instance plus a mutable list of
/// [`Diagnostic`]s, and returns either the generated [`OperatorWriteOutput`] or `Err(())`.
// NOTE(review): `Err(())` carries no payload, so failures are presumably explained through the
// diagnostics list — confirm with the callers.
pub type WriteFn =
fn(&WriteContextArgs<'_>, &mut Vec<Diagnostic>) -> Result<OperatorWriteOutput, ()>;
/// Manual `Debug` implementation for [`OperatorConstraints`].
///
/// Every data field is printed. The function-pointer fields `input_delaytype_fn` and `write_fn`
/// are not printed; `finish_non_exhaustive` renders a trailing `..` so the output makes it clear
/// that some fields were left out.
impl Debug for OperatorConstraints {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OperatorConstraints")
            .field("name", &self.name)
            .field("categories", &self.categories)
            .field("hard_range_inn", &self.hard_range_inn)
            .field("soft_range_inn", &self.soft_range_inn)
            .field("hard_range_out", &self.hard_range_out)
            .field("soft_range_out", &self.soft_range_out)
            .field("num_args", &self.num_args)
            .field("persistence_args", &self.persistence_args)
            .field("type_args", &self.type_args)
            .field("is_external_input", &self.is_external_input)
            .field("has_singleton_output", &self.has_singleton_output)
            .field("flo_type", &self.flo_type)
            .field("ports_inn", &self.ports_inn)
            .field("ports_out", &self.ports_out)
            .finish_non_exhaustive()
    }
}
/// Token streams produced by a [`WriteFn`].
///
/// The struct derives `Default` and is `#[non_exhaustive]`, so values are built with
/// `..Default::default()`; `IDENTITY_WRITE_FN` and `NULL_WRITE_FN` in this file set only
/// `write_iterator` and leave the other fields at their defaults.
#[derive(Default)]
#[non_exhaustive]
pub struct OperatorWriteOutput {
// NOTE(review): presumably code emitted once, before the per-run iterator code — confirm where
// the code generator splices this in.
pub write_prologue: TokenStream,
/// Code that binds the operator's iterator/pusherator to `WriteContextArgs::ident`
/// (this is what the write functions in this file generate).
pub write_iterator: TokenStream,
// NOTE(review): presumably code emitted after the iterator code has run — confirm where the
// code generator splices this in.
pub write_iterator_after: TokenStream,
}
/// Range `0..`: accepts any count (zero or more).
pub const RANGE_ANY: &'static dyn RangeTrait<usize> = &(0..);
/// Range `0..=0`: accepts exactly zero.
pub const RANGE_0: &'static dyn RangeTrait<usize> = &(0..=0);
/// Range `1..=1`: accepts exactly one.
pub const RANGE_1: &'static dyn RangeTrait<usize> = &(1..=1);
/// Generates the code for a pass-through (identity) operator.
///
/// The generated code binds `ident` to the operator's single input (pull side) or single output
/// (push side), routed through a small local `check_input` / `check_output` function. That
/// function only exists to pin the item type: it is instantiated with the operator's first type
/// argument, or with `_` (inferred) when no type argument was supplied.
///
/// # Panics
///
/// Panics if `inputs` is empty when `is_pull` is true, or if `outputs` is empty when `is_pull` is
/// false.
pub fn identity_write_iterator_fn(
    &WriteContextArgs {
        root,
        op_span,
        ident,
        inputs,
        outputs,
        is_pull,
        op_inst:
            OperatorInstance {
                generics: OpInstGenerics { type_args, .. },
                ..
            },
        ..
    }: &WriteContextArgs,
) -> TokenStream {
    // The `_` fallback is built lazily so that no token stream is constructed (and immediately
    // dropped) when an explicit type argument is present.
    let generic_type = type_args
        .first()
        .map(quote::ToTokens::to_token_stream)
        .unwrap_or_else(|| quote_spanned!(op_span=> _));

    if is_pull {
        let input = &inputs[0];
        quote_spanned! {op_span=>
            let #ident = {
                fn check_input<Iter: ::std::iter::Iterator<Item = Item>, Item>(iter: Iter) -> impl ::std::iter::Iterator<Item = Item> { iter }
                check_input::<_, #generic_type>(#input)
            };
        }
    } else {
        let output = &outputs[0];
        quote_spanned! {op_span=>
            let #ident = {
                fn check_output<Push: #root::pusherator::Pusherator<Item = Item>, Item>(push: Push) -> impl #root::pusherator::Pusherator<Item = Item> { push }
                check_output::<_, #generic_type>(#output)
            };
        }
    }
}
/// [`WriteFn`] that emits pass-through code.
///
/// Wraps [`identity_write_iterator_fn`]: its token stream becomes `write_iterator`, every other
/// output field keeps its default, and the diagnostics list is never touched.
pub const IDENTITY_WRITE_FN: WriteFn = |wc_args, _diagnostics| {
    Ok(OperatorWriteOutput {
        write_iterator: identity_write_iterator_fn(wc_args),
        ..Default::default()
    })
};
/// Generates the code for an operator that discards everything it receives and yields nothing.
///
/// * Pull side: every input is drained with `for_each(drop)`, then `ident` is bound to an empty
///   iterator.
/// * Push side: all outputs are moved into a discarded tuple, then `ident` is bound to a
///   `pusherator::null::Null`.
///
/// The item type is the operator's first type argument, or `_` (inferred) when none was supplied.
///
/// Standard-library items in the generated code are referenced through absolute `::std::...`
/// paths so that an item named `std` in the user's scope cannot shadow them.
pub fn null_write_iterator_fn(
    &WriteContextArgs {
        root,
        op_span,
        ident,
        inputs,
        outputs,
        is_pull,
        op_inst:
            OperatorInstance {
                generics: OpInstGenerics { type_args, .. },
                ..
            },
        ..
    }: &WriteContextArgs,
) -> TokenStream {
    // Fallback used when the operator was written without an explicit type argument.
    let default_type = parse_quote_spanned! {op_span=> _};
    let iter_type = type_args.first().unwrap_or(&default_type);

    if is_pull {
        quote_spanned! {op_span=>
            #(
                #inputs.for_each(::std::mem::drop);
            )*
            let #ident = ::std::iter::empty::<#iter_type>();
        }
    } else {
        quote_spanned! {op_span=>
            // With zero outputs the tuple below is `()`, hence the lint allowance.
            #[allow(clippy::let_unit_value)]
            let _ = (#(#outputs),*);
            let #ident = #root::pusherator::null::Null::<#iter_type>::new();
        }
    }
}
/// [`WriteFn`] that emits discard-everything code.
///
/// Wraps [`null_write_iterator_fn`]: its token stream becomes `write_iterator`, every other
/// output field keeps its default, and the diagnostics list is never touched.
pub const NULL_WRITE_FN: WriteFn = |wc_args, _diagnostics| {
    Ok(OperatorWriteOutput {
        write_iterator: null_write_iterator_fn(wc_args),
        ..Default::default()
    })
};
/// Declares the operator modules and the `OPERATORS` table in one place.
///
/// Takes a comma-terminated list of `module::CONST` entries. For every entry it emits
/// `pub(crate) mod module;`, and it builds `pub const OPERATORS: &[OperatorConstraints]`
/// containing each `module::CONST` in the order the entries were written.
macro_rules! declare_ops {
( $( $mod:ident :: $op:ident, )* ) => {
// One submodule declaration per entry.
$( pub(crate) mod $mod; )*
// The table of all operator constraints, in entry order.
pub const OPERATORS: &[OperatorConstraints] = &[
$( $mod :: $op, )*
];
};
}
// Registers every operator. Each `module::CONST` entry declares `pub(crate) mod module;` and puts
// `module::CONST` into `OPERATORS` at the same position it has in this list.
// NOTE(review): the list is mostly, but not strictly, alphabetical (see `fold_keyed`,
// `reduce_keyed`, `union`, `defer_*`, `spin`). Whether anything depends on the resulting order of
// `OPERATORS` is not visible from this file — confirm before re-sorting.
declare_ops![
all_once::ALL_ONCE,
anti_join::ANTI_JOIN,
anti_join_multiset::ANTI_JOIN_MULTISET,
assert::ASSERT,
assert_eq::ASSERT_EQ,
batch::BATCH,
chain::CHAIN,
cross_join::CROSS_JOIN,
cross_join_multiset::CROSS_JOIN_MULTISET,
cross_singleton::CROSS_SINGLETON,
demux::DEMUX,
demux_enum::DEMUX_ENUM,
dest_file::DEST_FILE,
dest_sink::DEST_SINK,
dest_sink_serde::DEST_SINK_SERDE,
difference::DIFFERENCE,
difference_multiset::DIFFERENCE_MULTISET,
enumerate::ENUMERATE,
filter::FILTER,
filter_map::FILTER_MAP,
flat_map::FLAT_MAP,
flatten::FLATTEN,
fold::FOLD,
for_each::FOR_EACH,
identity::IDENTITY,
initialize::INITIALIZE,
inspect::INSPECT,
join::JOIN,
join_fused::JOIN_FUSED,
join_fused_lhs::JOIN_FUSED_LHS,
join_fused_rhs::JOIN_FUSED_RHS,
join_multiset::JOIN_MULTISET,
fold_keyed::FOLD_KEYED,
reduce_keyed::REDUCE_KEYED,
lattice_bimorphism::LATTICE_BIMORPHISM,
_lattice_fold_batch::_LATTICE_FOLD_BATCH,
lattice_fold::LATTICE_FOLD,
_lattice_join_fused_join::_LATTICE_JOIN_FUSED_JOIN,
lattice_reduce::LATTICE_REDUCE,
map::MAP,
union::UNION,
multiset_delta::MULTISET_DELTA,
next_stratum::NEXT_STRATUM,
defer_signal::DEFER_SIGNAL,
defer_tick::DEFER_TICK,
defer_tick_lazy::DEFER_TICK_LAZY,
null::NULL,
partition::PARTITION,
persist::PERSIST,
persist_mut::PERSIST_MUT,
persist_mut_keyed::PERSIST_MUT_KEYED,
py_udf::PY_UDF,
reduce::REDUCE,
spin::SPIN,
sort::SORT,
sort_by_key::SORT_BY_KEY,
source_file::SOURCE_FILE,
source_interval::SOURCE_INTERVAL,
source_iter::SOURCE_ITER,
source_json::SOURCE_JSON,
source_stdin::SOURCE_STDIN,
source_stream::SOURCE_STREAM,
source_stream_serde::SOURCE_STREAM_SERDE,
state::STATE,
state_by::STATE_BY,
tee::TEE,
unique::UNIQUE,
unzip::UNZIP,
zip::ZIP,
zip_longest::ZIP_LONGEST,
];
/// Returns the map from operator name to its [`OperatorConstraints`].
///
/// The map is built from [`OPERATORS`] on the first call and cached in a `OnceLock`, so every
/// later call returns the same `&'static` map without rebuilding it. If two entries of
/// `OPERATORS` share a name, the later entry replaces the earlier one in the map.
pub fn operator_lookup() -> &'static HashMap<&'static str, &'static OperatorConstraints> {
    // Function-local item: it cannot be named from outside this function, so it carries no
    // visibility qualifier.
    static OPERATOR_LOOKUP: OnceLock<HashMap<&'static str, &'static OperatorConstraints>> =
        OnceLock::new();
    OPERATOR_LOOKUP.get_or_init(|| OPERATORS.iter().map(|op| (op.name, op)).collect())
}
/// Looks up the [`OperatorConstraints`] for a graph node.
///
/// Returns `None` when `node` is not a `GraphNode::Operator`, or when the operator's name is not
/// registered (see [`find_op_op_constraints`]).
pub fn find_node_op_constraints(node: &GraphNode) -> Option<&'static OperatorConstraints> {
    // Only operator nodes can have constraints.
    let GraphNode::Operator(op) = node else {
        return None;
    };
    find_op_op_constraints(op)
}
/// Looks up the [`OperatorConstraints`] registered under the operator's name.
///
/// The name is taken from `operator.name_string()` and searched in [`operator_lookup`]; `None` is
/// returned when no operator with that name is registered.
pub fn find_op_op_constraints(operator: &Operator) -> Option<&'static OperatorConstraints> {
    let op_name = operator.name_string();
    operator_lookup().get(&*op_name).copied()
}
/// Arguments handed to an operator's [`WriteFn`] when generating code for one operator instance.
#[derive(Clone)]
pub struct WriteContextArgs<'a> {
/// Path prefix spliced into generated code (the write functions in this file emit
/// `#root::pusherator::...`).
pub root: &'a TokenStream,
// NOTE(review): presumably the identifier of the runtime context variable in generated code —
// not used in this file, confirm.
pub context: &'a Ident,
// NOTE(review): presumably the identifier of the dataflow-instance variable in generated code —
// not used in this file, confirm.
pub hydroflow: &'a Ident,
/// ID of the subgraph; part of every identifier built by [`Self::make_ident`].
pub subgraph_id: GraphSubgraphId,
/// ID of the node; part of every identifier built by [`Self::make_ident`].
pub node_id: GraphNodeId,
/// Span applied to the generated tokens and to identifiers built by [`Self::make_ident`].
pub op_span: Span,
/// Identifier the generated code binds its iterator/pusherator to (`let #ident = ...`).
pub ident: &'a Ident,
/// `true` to generate pull-side (iterator) code, `false` to generate push-side (pusherator)
/// code.
pub is_pull: bool,
/// Identifiers of the operator's inputs; interpolated into pull-side code.
pub inputs: &'a [Ident],
/// Identifiers of the operator's outputs; interpolated into push-side code.
pub outputs: &'a [Ident],
// NOTE(review): presumably only meaningful when `OperatorConstraints::has_singleton_output` is
// set — confirm.
pub singleton_output_ident: &'a Ident,
/// Name of the operator.
pub op_name: &'static str,
/// The operator instance; the write functions in this file read its `generics.type_args`.
pub op_inst: &'a OperatorInstance,
/// Comma-punctuated argument expressions.
pub arguments: &'a Punctuated<Expr, Token![,]>,
// NOTE(review): how these differ from `arguments` is not visible in this file — confirm.
pub arguments_handles: &'a Punctuated<Expr, Token![,]>,
}
impl WriteContextArgs<'_> {
    /// Builds an identifier of the form `sg_<subgraph>_node_<node>_<suffix>`.
    ///
    /// `<subgraph>` and `<node>` are the `Debug` renderings of the key data of
    /// `self.subgraph_id` and `self.node_id`; the identifier is created at `self.op_span`.
    pub fn make_ident(&self, suffix: impl AsRef<str>) -> Ident {
        let subgraph = self.subgraph_id.data();
        let node = self.node_id.data();
        let name = format!("sg_{:?}_node_{:?}_{}", subgraph, node, suffix.as_ref());
        Ident::new(&name, self.op_span)
    }
}
/// Object-safe view of a range: its two bounds, a membership test, and an English description.
///
/// Implementors must also be `Send + Sync + Debug`.
pub trait RangeTrait<T>: Send + Sync + Debug
where
    T: ?Sized,
{
    /// Lower bound of the range.
    fn start_bound(&self) -> Bound<&T>;

    /// Upper bound of the range.
    fn end_bound(&self) -> Bound<&T>;

    /// Returns `true` if `item` lies inside the range.
    fn contains(&self, item: &T) -> bool
    where
        T: PartialOrd<T>;

    /// Describes the range in words, e.g. `"exactly 1"`, `"at least 2"`,
    /// `"more than 1 and at most 3"`, or `"any number of"` when both ends are unbounded.
    fn human_string(&self) -> String
    where
        T: Display + PartialEq,
    {
        let lower = self.start_bound();
        let upper = self.end_bound();

        // A closed range whose two endpoints coincide admits exactly one value.
        if let (Bound::Included(lo), Bound::Included(hi)) = (lower, upper) {
            if lo == hi {
                return format!("exactly {}", lo);
            }
        }

        // Phrase each side on its own; `None` means that side is unbounded.
        let lower_text = match lower {
            Bound::Included(lo) => Some(format!("at least {}", lo)),
            Bound::Excluded(lo) => Some(format!("more than {}", lo)),
            Bound::Unbounded => None,
        };
        let upper_text = match upper {
            Bound::Included(hi) => Some(format!("at most {}", hi)),
            Bound::Excluded(hi) => Some(format!("less than {}", hi)),
            Bound::Unbounded => None,
        };

        match (lower_text, upper_text) {
            (Some(lo), Some(hi)) => format!("{} and {}", lo, hi),
            (Some(lo), None) => lo,
            (None, Some(hi)) => hi,
            (None, None) => "any number of".to_owned(),
        }
    }
}
/// Blanket implementation: every `RangeBounds<T>` that is also `Send + Sync + Debug` is a
/// [`RangeTrait<T>`], with each method forwarding to the `RangeBounds` method of the same name.
impl<R, T> RangeTrait<T> for R
where
    R: RangeBounds<T> + Send + Sync + Debug,
{
    fn start_bound(&self) -> Bound<&T> {
        // Fully qualified so it is explicit that this forwards to `RangeBounds` rather than
        // calling back into `RangeTrait`.
        <R as RangeBounds<T>>::start_bound(self)
    }

    fn end_bound(&self) -> Bound<&T> {
        <R as RangeBounds<T>>::end_bound(self)
    }

    fn contains(&self, item: &T) -> bool
    where
        T: PartialOrd<T>,
    {
        <R as RangeBounds<T>>::contains(self, item)
    }
}
/// Persistence setting of an operator; serializable via `serde`.
///
/// `Ord`/`PartialOrd` are derived, so comparisons follow the declaration order of the variants.
// NOTE(review): the per-variant notes are inferred from the variant names only — confirm against
// the code that parses persistence arguments (see `OperatorConstraints::persistence_args`).
#[derive(Clone, Copy, PartialOrd, Ord, PartialEq, Eq, Debug, Serialize, Deserialize)]
pub enum Persistence {
// Presumably: state lasts for a single tick — TODO confirm.
Tick,
// Presumably: state lasts for the whole run — TODO confirm.
Static,
// Presumably: state persists and may also be modified or removed — TODO confirm.
Mutable,
}
/// Builds the string literal telling the user that operator `op_name` must be used inside a Tokio
/// runtime.
fn make_missing_runtime_msg(op_name: &str) -> Literal {
    let message = format!(
        "`{}()` must be used within a Tokio runtime. For example, use `#[dfir_rs::main]` on your main method.",
        op_name
    );
    Literal::string(&message)
}
/// Category an operator can be filed under (see `OperatorConstraints::categories`).
///
/// Each variant has a display name ([`Self::name`]) and a one-sentence description
/// ([`Self::description`]). `Ord`/`PartialOrd` are derived, so categories compare in declaration
/// order.
#[allow(
    clippy::allow_attributes,
    missing_docs,
    reason = "see `Self::description`"
)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum OperatorCategory {
    Map,
    Filter,
    Flatten,
    Fold,
    KeyedFold,
    LatticeFold,
    Persistence,
    MultiIn,
    MultiOut,
    Source,
    Sink,
    Control,
    CompilerFusionOperator,
    Windowing,
    Unwindowing,
}

impl OperatorCategory {
    /// Human-readable name of the category.
    pub fn name(self) -> &'static str {
        match self {
            Self::Map => "Maps",
            Self::Filter => "Filters",
            Self::Flatten => "Flattens",
            Self::Fold => "Folds",
            Self::KeyedFold => "Keyed Folds",
            Self::LatticeFold => "Lattice Folds",
            Self::Persistence => "Persistent Operators",
            Self::MultiIn => "Multi-Input Operators",
            Self::MultiOut => "Multi-Output Operators",
            Self::Source => "Sources",
            Self::Sink => "Sinks",
            Self::Control => "Control Flow Operators",
            Self::CompilerFusionOperator => "Compiler Fusion Operators",
            Self::Windowing => "Windowing Operator",
            Self::Unwindowing => "Un-Windowing Operator",
        }
    }

    /// One-sentence description of the category.
    ///
    /// Every description is a full sentence terminated by a period.
    pub fn description(self) -> &'static str {
        match self {
            Self::Map => "Simple one-in-one-out operators.",
            Self::Filter => "One-in zero-or-one-out operators.",
            Self::Flatten => "One-in multiple-out operators.",
            Self::Fold => "Operators which accumulate elements together.",
            Self::KeyedFold => "Keyed fold operators.",
            Self::LatticeFold => "Folds based on lattice-merge.",
            Self::Persistence => "Persistent (stateful) operators.",
            Self::MultiIn => "Operators with multiple inputs.",
            Self::MultiOut => "Operators with multiple outputs.",
            Self::Source => "Operators which produce output elements (and consume no inputs).",
            Self::Sink => "Operators which consume input elements (and produce no outputs).",
            Self::Control => "Operators which affect control flow/scheduling.",
            Self::CompilerFusionOperator => {
                "Operators which are necessary to implement certain optimizations and rewrite rules."
            }
            Self::Windowing => "Operators for windowing `loop` inputs.",
            Self::Unwindowing => "Operators for collecting `loop` outputs.",
        }
    }
}
/// Optional classification of an operator, stored in `OperatorConstraints::flo_type`.
///
/// `Ord`/`PartialOrd` are derived, so comparisons follow the declaration order of the variants.
// NOTE(review): the variant names match the `Source`, `Windowing` and `Unwindowing` variants of
// `OperatorCategory`, whose descriptions mention `loop` inputs/outputs; the exact semantics of
// each `FloType` are not visible in this file — confirm at the use site.
#[derive(Clone, Copy, PartialOrd, Ord, PartialEq, Eq, Debug)]
pub enum FloType {
// Presumably: the operator is a source — TODO confirm.
Source,
// Presumably: the operator windows data entering a `loop` — TODO confirm.
Windowing,
// Presumably: the operator collects data leaving a `loop` — TODO confirm.
Unwindowing,
}