hydro_lang/rewrites/
properties.rs

1use std::collections::HashSet;
2
3use stageleft::*;
4
5use crate::ir::{HydroNode, HydroRoot, transform_bottom_up};
6
7/// Structure for tracking expressions known to have particular algebraic properties.
8///
9/// # Schema
10///
11/// Each field in this struct corresponds to an algebraic property, and contains the list of
12/// expressions that satisfy the property. Currently only `commutative`.
13///
14/// # Interface
15///
16/// "Tag" an expression with a property and it will add it to that table. For example, [`Self::add_commutative_tag`].
17/// Can also run a check to see if an expression satisfies a property.
18#[derive(Default)]
19pub struct PropertyDatabase {
20    commutative: HashSet<syn::Expr>,
21}
22
23impl PropertyDatabase {
24    /// Tags the expression as commutative.
25    pub fn add_commutative_tag<
26        'a,
27        I,
28        A,
29        F: Fn(&mut A, I),
30        Ctx,
31        Q: QuotedWithContext<'a, F, Ctx> + Clone,
32    >(
33        &mut self,
34        expr: Q,
35        ctx: &Ctx,
36    ) -> Q {
37        let expr_clone = expr.clone();
38        self.commutative.insert(expr_clone.splice_untyped_ctx(ctx));
39        expr
40    }
41
42    pub fn is_tagged_commutative(&self, expr: &syn::Expr) -> bool {
43        self.commutative.contains(expr)
44    }
45}
46
47// Dataflow graph optimization rewrite rules based on algebraic property tags
48// TODO add a test that verifies the space of possible graphs after rewrites is correct for each property
49
50fn properties_optimize_node(node: &mut HydroNode, db: &mut PropertyDatabase) {
51    match node {
52        HydroNode::ReduceKeyed { f, .. } if db.is_tagged_commutative(&f.0) => {
53            dbg!("IDENTIFIED COMMUTATIVE OPTIMIZATION for {:?}", &f);
54        }
55        _ => {}
56    }
57}
58
59pub fn properties_optimize(ir: &mut [HydroRoot], db: &mut PropertyDatabase) {
60    transform_bottom_up(
61        ir,
62        &mut |_| (),
63        &mut |node| properties_optimize_node(node, db),
64        false,
65    );
66}
67
#[cfg(stageleft_runtime)]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::deploy::HydroDeploy;
    use crate::location::Location;
    use crate::{FlowBuilder, nondet};

    /// Checks that tagging a quoted closure makes an identical, independently quoted closure
    /// report as commutative, and that it is not reported before tagging.
    #[test]
    fn test_property_database() {
        let mut db = PropertyDatabase::default();

        // Nothing has been tagged yet, so the lookup must fail.
        assert!(
            !db.is_tagged_commutative(&(q!(|a: &mut i32, b: i32| *a += b).splice_untyped_ctx(&())))
        );

        // Tag the closure; the returned quoted expression is not needed here.
        let _ = db.add_commutative_tag(q!(|a: &mut i32, b: i32| *a += b), &());

        // A separately written but identical closure, spliced with the same `()` context,
        // now matches the tagged entry.
        assert!(
            db.is_tagged_commutative(&(q!(|a: &mut i32, b: i32| *a += b).splice_untyped_ctx(&())))
        );
    }

    /// Builds a small keyed counting pipeline whose combiner is tagged commutative.
    ///
    /// The test then applies `properties_optimize` via `optimize_with` plus the default
    /// optimizations, snapshots the resulting IR, and compiles it without networking.
    ///
    /// NOTE(review): the pipeline uses `.fold(...)`, while `properties_optimize_node` only
    /// matches `HydroNode::ReduceKeyed`. If `fold` does not lower to `ReduceKeyed`, the
    /// commutative branch is not exercised here; confirm.
    ///
    /// NOTE(review): the snapshot of `built.ir()` likely embeds the quoted expressions, so
    /// edits inside any `q!` here may require regenerating the snapshot.
    #[test]
    fn test_property_optimized() {
        let flow = FlowBuilder::new();
        let mut database = PropertyDatabase::default();

        let process = flow.process::<()>();
        let tick = process.tick();

        // Combiner that is both tagged below and used as the fold's accumulator.
        let counter_func = q!(|count: &mut i32, _| *count += 1);
        let _ = database.add_commutative_tag(counter_func, &tick);

        // The source is empty: this test inspects the built IR and that it compiles,
        // not any runtime output.
        process
            .source_iter(q!(vec![]))
            .map(q!(|string: String| (string, ())))
            .batch(&tick, nondet!(/** test */))
            .into_keyed()
            .fold(q!(|| 0), counter_func)
            .entries()
            .all_ticks()
            .for_each(q!(|(string, count)| println!("{}: {}", string, count)));

        let built = flow
            .optimize_with(|ir| properties_optimize(ir, &mut database))
            .with_default_optimize::<HydroDeploy>();

        hydro_build_utils::assert_debug_snapshot!(built.ir());

        let _ = built.compile_no_network();
    }
}