use std::collections::HashSet;
use stageleft::*;
use crate::ir::{HydroLeaf, HydroNode, SeenTees};
/// Registry of expressions that have been tagged with a property.
///
/// The only property tracked here is commutativity: expressions are registered
/// with `add_commutative_tag` and looked up with `is_tagged_commutative`.
#[derive(Default)]
pub struct PropertyDatabase {
/// Spliced `syn::Expr` form of every expression tagged as commutative.
commutative: HashSet<syn::Expr>,
}
/// Adapts an in-place fold step `f(&mut acc, item)` into a two-argument function.
///
/// The returned closure starts from `A::default()`, applies `f` to the first
/// argument and then to the second, and returns the resulting accumulator.
#[allow(clippy::allow_attributes, dead_code, reason = "staged programming")]
fn convert_hf_to_binary<I, A: Default, F: Fn(&mut A, I)>(f: F) -> impl Fn(I, I) -> A {
    move |lhs, rhs| {
        // Fresh accumulator per call; the order of application is lhs, then rhs.
        let mut folded = A::default();
        f(&mut folded, lhs);
        f(&mut folded, rhs);
        folded
    }
}
impl PropertyDatabase {
    /// Tags `expr` as commutative and hands the same expression back.
    ///
    /// A clone of `expr` is spliced against `ctx` and the resulting `syn::Expr`
    /// is recorded in the database. The original `expr` is returned unchanged,
    /// so the call can be used inline wherever the quoted function is needed.
    pub fn add_commutative_tag<'a, I, A, F, Ctx, Q>(&mut self, expr: Q, ctx: &Ctx) -> Q
    where
        F: Fn(&mut A, I),
        Q: QuotedWithContext<'a, F, Ctx> + Clone,
    {
        // The splice is performed on a clone so that `expr` can still be returned.
        self.commutative
            .insert(expr.clone().splice_untyped_ctx(ctx));
        expr
    }

    /// Returns `true` if `expr` equals an expression previously registered
    /// through `add_commutative_tag`.
    pub fn is_tagged_commutative(&self, expr: &syn::Expr) -> bool {
        self.commutative.contains(expr)
    }
}
/// Looks for nodes whose function has been tagged in `db`.
///
/// The function first recurses through `transform_children`, then inspects
/// `node` itself. The only recognised case is a `ReduceKeyed` whose function is
/// tagged commutative; it is reported on stderr and the node is left unchanged.
fn properties_optimize_node(node: &mut HydroNode, db: &PropertyDatabase, seen_tees: &mut SeenTees) {
    node.transform_children(
        |child, seen_tees| properties_optimize_node(child, db, seen_tees),
        seen_tees,
    );

    match node {
        HydroNode::ReduceKeyed { f, .. } if db.is_tagged_commutative(&f.0) => {
            // `dbg!` does not interpolate format strings (the `{:?}` placeholder
            // would be printed literally), so format the message explicitly.
            eprintln!("IDENTIFIED COMMUTATIVE OPTIMIZATION for {:?}", f);
        }
        _ => {}
    }
}
/// Runs the property-based pass over every leaf of the IR.
///
/// Each leaf is passed through `transform_children` with
/// `properties_optimize_node` as the per-node callback, and the results are
/// returned in the same order. One `seen_tees` table is shared across all leaves.
pub fn properties_optimize(ir: Vec<HydroLeaf>, db: &PropertyDatabase) -> Vec<HydroLeaf> {
    let mut seen_tees = Default::default();
    let mut optimized = Vec::with_capacity(ir.len());
    for leaf in ir {
        optimized.push(leaf.transform_children(
            |node, seen_tees| properties_optimize_node(node, db, seen_tees),
            &mut seen_tees,
        ));
    }
    optimized
}
// Unit tests for the property database and the optimization pass built on it.
#[cfg(test)]
mod tests {
use super::*;
use crate::deploy::SingleProcessGraph;
use crate::location::Location;
use crate::FlowBuilder;
// Checks that a closure is reported as commutative only after it has been
// tagged. The lookups use separately quoted closures with the same source
// text, so the test expects matching by expression equality, not identity.
#[test]
fn test_property_database() {
let mut db = PropertyDatabase::default();
// Nothing has been registered yet, so the lookup must fail.
assert!(
!db.is_tagged_commutative(&(q!(|a: &mut i32, b: i32| *a += b).splice_untyped_ctx(&())))
);
let _ = db.add_commutative_tag(q!(|a: &mut i32, b: i32| *a += b), &());
// After tagging, an identically written closure must be found.
assert!(
db.is_tagged_commutative(&(q!(|a: &mut i32, b: i32| *a += b).splice_untyped_ctx(&())))
);
}
// Builds a small flow whose keyed fold uses a tagged closure, runs
// `properties_optimize` over it, snapshots the resulting IR and checks that
// the optimized flow still compiles.
#[test]
fn test_property_optimized() {
let flow = FlowBuilder::new();
let mut database = PropertyDatabase::default();
let process = flow.process::<()>();
let tick = process.tick();
let counter_func = q!(|count: &mut i32, _| *count += 1);
// Register the closure before it is used in the flow below.
let _ = database.add_commutative_tag(counter_func, &tick);
// NOTE(review): at least one call in this chain is `unsafe`, but the
// invariant that makes it sound here is not visible in this file; add a
// `// SAFETY:` note once confirmed.
unsafe {
process
.source_iter(q!(vec![]))
.map(q!(|string: String| (string, ())))
.timestamped(&tick)
.tick_batch()
}
// NOTE(review): the pass in this file only matches `HydroNode::ReduceKeyed`;
// whether `fold_keyed` lowers to that variant is not visible here — confirm
// that this test actually exercises the commutative branch.
.fold_keyed(q!(|| 0), counter_func)
.all_ticks()
.for_each(q!(|(string, count)| println!("{}: {}", string, count)));
let built = flow
.optimize_with(|ir| properties_optimize(ir, &database))
.with_default_optimize::<SingleProcessGraph>();
// The snapshot pins the IR produced after the property pass has run.
insta::assert_debug_snapshot!(built.ir());
let _ = built.compile_no_network();
}
}