hydro_lang/rewrites/
partitioner.rs
1use std::collections::HashMap;
2
3use syn::visit_mut::{self, VisitMut};
4
5use crate::ir::*;
6
/// Selects which part of an outgoing item is used to choose its destination partition.
///
/// The chosen value is cast with `as usize` and reduced modulo the partition count by the
/// code that `replace_sender_network` generates, so it must be a primitive numeric value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PartitionAttribute {
    /// Partition on the whole item.
    All(),
    /// Partition on the tuple field at this index (e.g. `1` selects `item.1`).
    TupleIndex(usize),
}
12
13pub struct Partitioner {
14 pub nodes_to_partition: HashMap<usize, PartitionAttribute>, pub num_partitions: usize,
16 pub partitioned_cluster_id: usize,
17}
18
/// Syntax visitor that rewrites reads of the partitioned cluster's self-id variable so
/// that each partition divides its raw id by the partition count.
pub struct ClusterSelfIdReplace {
    /// Divisor applied to the self id.
    pub num_partitions: usize,
    /// Raw id of the cluster whose `__hydro_lang_cluster_self_id_<id>` variable is rewritten.
    pub partitioned_cluster_id: usize,
}
24
25impl VisitMut for ClusterSelfIdReplace {
26 fn visit_expr_mut(&mut self, expr: &mut syn::Expr) {
27 if let syn::Expr::Path(path_expr) = expr {
28 for segment in path_expr.path.segments.iter_mut() {
29 let ident = segment.ident.to_string();
30 let prefix = format!(
31 "__hydro_lang_cluster_self_id_{}",
32 self.partitioned_cluster_id
33 );
34 if ident.starts_with(&prefix) {
35 let num_partitions = self.num_partitions;
36 let expr_content = std::mem::replace(expr, syn::Expr::PLACEHOLDER);
37 *expr = syn::parse_quote!({
38 #expr_content / #num_partitions as u32
39 });
40 println!("Partitioning: Replaced CLUSTER_SELF_ID");
41 return;
42 }
43 }
44 }
45 visit_mut::visit_expr_mut(self, expr);
46 }
47}
48
/// Syntax visitor that rewrites the partitioned cluster's member-list expression so only
/// the first `1 / num_partitions` of the listed ids is exposed.
pub struct ClusterMembersReplace {
    /// Divisor applied to the length of the member list.
    pub num_partitions: usize,
    /// Raw id of the cluster whose `__hydro_lang_cluster_ids_<id>` variable is rewritten.
    pub partitioned_cluster_id: usize,
}
54
55impl VisitMut for ClusterMembersReplace {
56 fn visit_expr_mut(&mut self, expr: &mut syn::Expr) {
57 if let syn::Expr::Unsafe(unsafe_expr) = expr {
58 for stmt in &mut unsafe_expr.block.stmts {
59 if let syn::Stmt::Expr(syn::Expr::Call(call_expr), _) = stmt {
60 for arg in call_expr.args.iter_mut() {
61 if let syn::Expr::Path(path_expr) = arg {
62 for segment in path_expr.path.segments.iter_mut() {
63 let ident = segment.ident.to_string();
64 let prefix = format!(
65 "__hydro_lang_cluster_ids_{}",
66 self.partitioned_cluster_id
67 );
68 if ident.starts_with(&prefix) {
69 let num_partitions = self.num_partitions;
70 let expr_content =
71 std::mem::replace(expr, syn::Expr::PLACEHOLDER);
72 *expr = syn::parse_quote!({
73 let all_ids = #expr_content;
74 &all_ids[0..all_ids.len() / #num_partitions]
75 });
76 println!("Partitioning: Replaced cluster members");
77 return;
79 }
80 }
81 }
82 }
83 }
84 }
85 }
86 visit_mut::visit_expr_mut(self, expr);
87 }
88}
89
90fn replace_membership_info(node: &mut HydroNode, partitioner: &Partitioner) {
91 let Partitioner {
92 num_partitions,
93 partitioned_cluster_id,
94 ..
95 } = *partitioner;
96
97 node.visit_debug_expr(|expr| {
98 let mut visitor = ClusterMembersReplace {
99 num_partitions,
100 partitioned_cluster_id,
101 };
102 visitor.visit_expr_mut(&mut expr.0);
103 });
104 node.visit_debug_expr(|expr| {
105 let mut visitor = ClusterSelfIdReplace {
106 num_partitions,
107 partitioned_cluster_id,
108 };
109 visitor.visit_expr_mut(&mut expr.0);
110 });
111}
112
113fn replace_sender_network(node: &mut HydroNode, partitioner: &Partitioner, next_stmt_id: usize) {
114 let Partitioner {
115 nodes_to_partition,
116 num_partitions,
117 ..
118 } = partitioner;
119
120 if let Some(partition_attr) = nodes_to_partition.get(&next_stmt_id) {
121 println!("Partitioning node {} {}", next_stmt_id, node.print_root());
122
123 let node_content = std::mem::replace(node, HydroNode::Placeholder);
124 let metadata = node_content.metadata().clone();
125
126 let f: syn::Expr = match partition_attr {
127 PartitionAttribute::All() => {
128 syn::parse_quote!(|(orig_dest, item)| {
129 let orig_dest_id = orig_dest.raw_id;
130 let new_dest_id = (orig_dest_id * #num_partitions as u32) + (item as usize % #num_partitions) as u32;
131 (
132 ClusterId::<()>::from_raw(new_dest_id),
133 item
134 )
135 })
136 }
137 PartitionAttribute::TupleIndex(tuple_index) => {
138 let tuple_index_ident = syn::Index::from(*tuple_index);
139 syn::parse_quote!(|(orig_dest, tuple)| {
140 let orig_dest_id = orig_dest.raw_id;
141 let new_dest_id = (orig_dest_id * #num_partitions as u32) + (tuple.#tuple_index_ident as usize % #num_partitions) as u32;
142 (
143 ClusterId::<()>::from_raw(new_dest_id),
144 tuple
145 )
146 })
147 }
148 };
149
150 let mapped_node = HydroNode::Map {
151 f: f.into(),
152 input: Box::new(node_content),
153 metadata,
154 };
155
156 *node = mapped_node;
157 }
158}
159
160fn replace_receiver_network(node: &mut HydroNode, partitioner: &Partitioner) {
161 let Partitioner {
162 num_partitions,
163 partitioned_cluster_id,
164 ..
165 } = partitioner;
166
167 if let HydroNode::Network {
168 input, metadata, ..
169 } = node
170 {
171 if input.metadata().location_kind.raw_id() == *partitioned_cluster_id {
172 println!("Rewriting network on receiver to remap location ids");
173
174 let metadata = metadata.clone();
175 let node_content = std::mem::replace(node, HydroNode::Placeholder);
176 let f: syn::Expr = syn::parse_quote!(|(sender_id, b)| (
177 ClusterId::<_>::from_raw(sender_id.raw_id / #num_partitions as u32),
178 b
179 ));
180
181 let mapped_node = HydroNode::Map {
182 f: f.into(),
183 input: Box::new(node_content),
184 metadata: metadata.clone(),
185 };
186
187 *node = mapped_node;
188 }
189 }
190}
191
192fn partition_node(node: &mut HydroNode, partitioner: &Partitioner, next_stmt_id: &mut usize) {
193 replace_membership_info(node, partitioner);
194 replace_sender_network(node, partitioner, *next_stmt_id);
195 replace_receiver_network(node, partitioner);
196}
197
198pub fn partition(ir: &mut [HydroLeaf], partitioner: &Partitioner) {
200 traverse_dfir(
201 ir,
202 |_, _| {},
203 |node, next_stmt_id| {
204 partition_node(node, partitioner, next_stmt_id);
205 },
206 );
207}