hydro_lang/rewrites/partitioner.rs

1use std::collections::HashMap;
2
3use syn::visit_mut::{self, VisitMut};
4
5use crate::ir::*;
6
/// Which part of a message payload decides the partition the message is routed to.
///
/// The selected value is cast to `usize` and reduced modulo the number of partitions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PartitionAttribute {
    /// Partition on the entire payload value.
    All(),
    /// Partition on the field at this index of a tuple payload.
    TupleIndex(usize),
}
12
13pub struct Partitioner {
14    pub nodes_to_partition: HashMap<usize, PartitionAttribute>, /* ID of node right before a Network -> what to partition on */
15    pub num_partitions: usize,
16    pub partitioned_cluster_id: usize,
17}
18
/// `syn` visitor that rewrites the partitioned cluster's `CLUSTER_SELF_ID` so that each
/// partition reports the ID of the original member it was split from.
pub struct ClusterSelfIdReplace {
    /// Number of partitions per original member; the raw self ID is divided by this.
    pub num_partitions: usize,
    /// Raw location ID of the partitioned cluster; selects which self-ID variable is rewritten.
    pub partitioned_cluster_id: usize,
}
24
25impl VisitMut for ClusterSelfIdReplace {
26    fn visit_expr_mut(&mut self, expr: &mut syn::Expr) {
27        if let syn::Expr::Path(path_expr) = expr {
28            for segment in path_expr.path.segments.iter_mut() {
29                let ident = segment.ident.to_string();
30                let prefix = format!(
31                    "__hydro_lang_cluster_self_id_{}",
32                    self.partitioned_cluster_id
33                );
34                if ident.starts_with(&prefix) {
35                    let num_partitions = self.num_partitions;
36                    let expr_content = std::mem::replace(expr, syn::Expr::PLACEHOLDER);
37                    *expr = syn::parse_quote!({
38                        #expr_content / #num_partitions as u32
39                    });
40                    println!("Partitioning: Replaced CLUSTER_SELF_ID");
41                    return;
42                }
43            }
44        }
45        visit_mut::visit_expr_mut(self, expr);
46    }
47}
48
/// `syn` visitor that hides the extra partition IDs from the partitioned cluster's member
/// list, so the list exposes one ID per original member instead of one per partition.
pub struct ClusterMembersReplace {
    /// Number of partitions per original member; the member list is truncated by this factor.
    pub num_partitions: usize,
    /// Raw location ID of the partitioned cluster; selects which member-list variable is rewritten.
    pub partitioned_cluster_id: usize,
}
54
55impl VisitMut for ClusterMembersReplace {
56    fn visit_expr_mut(&mut self, expr: &mut syn::Expr) {
57        if let syn::Expr::Unsafe(unsafe_expr) = expr {
58            for stmt in &mut unsafe_expr.block.stmts {
59                if let syn::Stmt::Expr(syn::Expr::Call(call_expr), _) = stmt {
60                    for arg in call_expr.args.iter_mut() {
61                        if let syn::Expr::Path(path_expr) = arg {
62                            for segment in path_expr.path.segments.iter_mut() {
63                                let ident = segment.ident.to_string();
64                                let prefix = format!(
65                                    "__hydro_lang_cluster_ids_{}",
66                                    self.partitioned_cluster_id
67                                );
68                                if ident.starts_with(&prefix) {
69                                    let num_partitions = self.num_partitions;
70                                    let expr_content =
71                                        std::mem::replace(expr, syn::Expr::PLACEHOLDER);
72                                    *expr = syn::parse_quote!({
73                                        let all_ids = #expr_content;
74                                        &all_ids[0..all_ids.len() / #num_partitions]
75                                    });
76                                    println!("Partitioning: Replaced cluster members");
77                                    // Don't need to visit children
78                                    return;
79                                }
80                            }
81                        }
82                    }
83                }
84            }
85        }
86        visit_mut::visit_expr_mut(self, expr);
87    }
88}
89
90fn replace_membership_info(node: &mut HydroNode, partitioner: &Partitioner) {
91    let Partitioner {
92        num_partitions,
93        partitioned_cluster_id,
94        ..
95    } = *partitioner;
96
97    node.visit_debug_expr(|expr| {
98        let mut visitor = ClusterMembersReplace {
99            num_partitions,
100            partitioned_cluster_id,
101        };
102        visitor.visit_expr_mut(&mut expr.0);
103    });
104    node.visit_debug_expr(|expr| {
105        let mut visitor = ClusterSelfIdReplace {
106            num_partitions,
107            partitioned_cluster_id,
108        };
109        visitor.visit_expr_mut(&mut expr.0);
110    });
111}
112
113fn replace_sender_network(node: &mut HydroNode, partitioner: &Partitioner, next_stmt_id: usize) {
114    let Partitioner {
115        nodes_to_partition,
116        num_partitions,
117        ..
118    } = partitioner;
119
120    if let Some(partition_attr) = nodes_to_partition.get(&next_stmt_id) {
121        println!("Partitioning node {} {}", next_stmt_id, node.print_root());
122
123        let node_content = std::mem::replace(node, HydroNode::Placeholder);
124        let metadata = node_content.metadata().clone();
125
126        let f: syn::Expr = match partition_attr {
127            PartitionAttribute::All() => {
128                syn::parse_quote!(|(orig_dest, item)| {
129                    let orig_dest_id = orig_dest.raw_id;
130                    let new_dest_id = (orig_dest_id * #num_partitions as u32) + (item as usize % #num_partitions) as u32;
131                    (
132                        ClusterId::<()>::from_raw(new_dest_id),
133                        item
134                    )
135                })
136            }
137            PartitionAttribute::TupleIndex(tuple_index) => {
138                let tuple_index_ident = syn::Index::from(*tuple_index);
139                syn::parse_quote!(|(orig_dest, tuple)| {
140                    let orig_dest_id = orig_dest.raw_id;
141                    let new_dest_id = (orig_dest_id * #num_partitions as u32) + (tuple.#tuple_index_ident as usize % #num_partitions) as u32;
142                    (
143                        ClusterId::<()>::from_raw(new_dest_id),
144                        tuple
145                    )
146                })
147            }
148        };
149
150        let mapped_node = HydroNode::Map {
151            f: f.into(),
152            input: Box::new(node_content),
153            metadata,
154        };
155
156        *node = mapped_node;
157    }
158}
159
160fn replace_receiver_network(node: &mut HydroNode, partitioner: &Partitioner) {
161    let Partitioner {
162        num_partitions,
163        partitioned_cluster_id,
164        ..
165    } = partitioner;
166
167    if let HydroNode::Network {
168        input, metadata, ..
169    } = node
170    {
171        if input.metadata().location_kind.raw_id() == *partitioned_cluster_id {
172            println!("Rewriting network on receiver to remap location ids");
173
174            let metadata = metadata.clone();
175            let node_content = std::mem::replace(node, HydroNode::Placeholder);
176            let f: syn::Expr = syn::parse_quote!(|(sender_id, b)| (
177                ClusterId::<_>::from_raw(sender_id.raw_id / #num_partitions as u32),
178                b
179            ));
180
181            let mapped_node = HydroNode::Map {
182                f: f.into(),
183                input: Box::new(node_content),
184                metadata: metadata.clone(),
185            };
186
187            *node = mapped_node;
188        }
189    }
190}
191
192fn partition_node(node: &mut HydroNode, partitioner: &Partitioner, next_stmt_id: &mut usize) {
193    replace_membership_info(node, partitioner);
194    replace_sender_network(node, partitioner, *next_stmt_id);
195    replace_receiver_network(node, partitioner);
196}
197
198/// Limitations: Can only partition sends to clusters (not processes). Can only partition sends to 1 cluster at a time. Assumes that the partitioned attribute can be casted to usize.
199pub fn partition(ir: &mut [HydroLeaf], partitioner: &Partitioner) {
200    traverse_dfir(
201        ir,
202        |_, _| {},
203        |node, next_stmt_id| {
204            partition_node(node, partitioner, next_stmt_id);
205        },
206    );
207}