1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
use std::collections::HashSet;

use stageleft::*;

use crate::ir::{HydroLeaf, HydroNode, SeenTees};

/// Structure for tracking expressions known to have particular algebraic properties.
///
/// # Schema
///
/// Each field in this struct corresponds to an algebraic property and holds the *set* of
/// expressions (as untyped [`syn::Expr`] syntax trees) that have been tagged with that
/// property. Currently the only property tracked is `commutative`.
///
/// # Interface
///
/// "Tag" an expression with a property to add it to the corresponding set, for example via
/// [`Self::add_commutative_tag`]. Use the matching check, for example
/// [`Self::is_tagged_commutative`], to test whether an expression has been tagged.
#[derive(Default)]
pub struct PropertyDatabase {
    /// Expressions tagged as commutative, stored in their spliced (untyped) form.
    commutative: HashSet<syn::Expr>,
}

/// Adapts a dfir-style fold accumulator (`Fn(&mut A, I)`) into a binary operation
/// `Fn(I, I) -> A`, so fold functions can be fed to the algebraic property tests.
///
/// The returned closure starts from `A::default()`, folds in its first argument, then its
/// second argument, and returns the resulting accumulator.
#[allow(clippy::allow_attributes, dead_code, reason = "staged programming")]
fn convert_hf_to_binary<I, A: Default, F: Fn(&mut A, I)>(f: F) -> impl Fn(I, I) -> A {
    move |lhs, rhs| {
        let mut state = A::default();
        // Order matters for non-commutative folds: left operand first, then right.
        for item in [lhs, rhs] {
            f(&mut state, item);
        }
        state
    }
}

impl PropertyDatabase {
    /// Records `expr` as commutative and returns it unchanged.
    ///
    /// A clone of the quoted expression is spliced against `ctx` into an untyped
    /// [`syn::Expr`], and that syntax tree is what gets stored. Lookups through
    /// [`Self::is_tagged_commutative`] compare against this spliced form.
    pub fn add_commutative_tag<
        'a,
        I,
        A,
        F: Fn(&mut A, I),
        Ctx,
        Q: QuotedWithContext<'a, F, Ctx> + Clone,
    >(
        &mut self,
        expr: Q,
        ctx: &Ctx,
    ) -> Q {
        // Splicing consumes the quoted value, so splice a copy and hand the original back.
        let spliced: syn::Expr = expr.clone().splice_untyped_ctx(ctx);
        self.commutative.insert(spliced);
        expr
    }

    /// Returns `true` if `expr` equals an expression previously recorded with
    /// [`Self::add_commutative_tag`].
    pub fn is_tagged_commutative(&self, expr: &syn::Expr) -> bool {
        let tagged = &self.commutative;
        tagged.contains(expr)
    }
}

// Dataflow graph optimization rewrite rules based on algebraic property tags
// TODO add a test that verifies the space of possible graphs after rewrites is correct for each property

/// Recursively visits `node` and its children, looking for operators whose expression has
/// been tagged in `db` with an algebraic property.
///
/// Children are visited first (post-order), with `seen_tees` threaded through to
/// `transform_children` (presumably so shared/tee'd subgraphs are visited once — TODO confirm).
///
/// Currently this only reports matches. A `ReduceKeyed` whose `f` is tagged commutative is
/// logged to stderr; no rewrite is performed yet.
fn properties_optimize_node(node: &mut HydroNode, db: &PropertyDatabase, seen_tees: &mut SeenTees) {
    // Process children before inspecting this node so any future rewrite sees optimized inputs.
    node.transform_children(
        |node, seen_tees| properties_optimize_node(node, db, seen_tees),
        seen_tees,
    );
    match node {
        // NOTE(review): only `ReduceKeyed` is matched. The `fold_keyed` pipeline in this file's
        // test may lower to a different node variant — confirm whether it should be handled too.
        HydroNode::ReduceKeyed { f, .. } if db.is_tagged_commutative(&f.0) => {
            // `dbg!` does not interpolate format strings (it printed the literal "{:?}" and `f`
            // as separate values), so emit the intended message with `eprintln!`.
            // TODO: replace this diagnostic with an actual commutativity-based rewrite.
            eprintln!("IDENTIFIED COMMUTATIVE OPTIMIZATION for {:?}", f);
        }
        _ => {}
    }
}

pub fn properties_optimize(ir: Vec<HydroLeaf>, db: &PropertyDatabase) -> Vec<HydroLeaf> {
    let mut seen_tees = Default::default();
    ir.into_iter()
        .map(|l| {
            l.transform_children(
                |node, seen_tees| properties_optimize_node(node, db, seen_tees),
                &mut seen_tees,
            )
        })
        .collect()
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::deploy::SingleProcessGraph;
    use crate::location::Location;
    use crate::FlowBuilder;

    /// Checks that an expression is untagged until `add_commutative_tag` records it, and that
    /// a separately quoted, token-identical expression is then recognized as commutative.
    #[test]
    fn test_property_database() {
        let mut db = PropertyDatabase::default();

        // Nothing has been tagged yet.
        assert!(
            !db.is_tagged_commutative(&(q!(|a: &mut i32, b: i32| *a += b).splice_untyped_ctx(&())))
        );

        // The returned quoted expression is not needed here.
        let _ = db.add_commutative_tag(q!(|a: &mut i32, b: i32| *a += b), &());

        // A freshly quoted copy of the same closure now matches the stored expression.
        assert!(
            db.is_tagged_commutative(&(q!(|a: &mut i32, b: i32| *a += b).splice_untyped_ctx(&())))
        );
    }

    /// Builds a small keyed-fold pipeline whose accumulator is tagged commutative, runs
    /// `properties_optimize` as a custom optimization pass, snapshots the resulting IR, and
    /// checks that it still compiles.
    #[test]
    fn test_property_optimized() {
        let flow = FlowBuilder::new();
        let mut database = PropertyDatabase::default();

        let process = flow.process::<()>();
        let tick = process.tick();

        // Tag the accumulator; the returned copy is discarded because `counter_func` itself is
        // reused as the `fold_keyed` accumulator below.
        let counter_func = q!(|count: &mut i32, _| *count += 1);
        let _ = database.add_commutative_tag(counter_func, &tick);

        // SAFETY(review): `tick_batch` is an `unsafe` API here; presumably because batch
        // boundaries are nondeterministic — TODO confirm. The source is an empty `vec![]`, so
        // no elements are batched in this test.
        unsafe {
            process
                .source_iter(q!(vec![]))
                .map(q!(|string: String| (string, ())))
                .timestamped(&tick)
                .tick_batch()
        }
        .fold_keyed(q!(|| 0), counter_func)
        .all_ticks()
        .for_each(q!(|(string, count)| println!("{}: {}", string, count)));

        // Run the property pass as a custom optimization, then the default optimizations.
        let built = flow
            .optimize_with(|ir| properties_optimize(ir, &database))
            .with_default_optimize::<SingleProcessGraph>();

        insta::assert_debug_snapshot!(built.ir());

        let _ = built.compile_no_network();
    }
}