1use itertools::Either;
2
3use super::HalfJoinState;
4
5pub struct SymmetricHashJoin<'a, Key, I1, V1, I2, V2, LhsState, RhsState>
6where
7 Key: Eq + std::hash::Hash + Clone,
8 V1: Clone,
9 V2: Clone,
10 I1: Iterator<Item = (Key, V1)>,
11 I2: Iterator<Item = (Key, V2)>,
12 LhsState: HalfJoinState<Key, V1, V2>,
13 RhsState: HalfJoinState<Key, V2, V1>,
14{
15 lhs: I1,
16 rhs: I2,
17 lhs_state: &'a mut LhsState,
18 rhs_state: &'a mut RhsState,
19}
20
21impl<Key, I1, V1, I2, V2, LhsState, RhsState> Iterator
22 for SymmetricHashJoin<'_, Key, I1, V1, I2, V2, LhsState, RhsState>
23where
24 Key: Eq + std::hash::Hash + Clone,
25 V1: Clone,
26 V2: Clone,
27 I1: Iterator<Item = (Key, V1)>,
28 I2: Iterator<Item = (Key, V2)>,
29 LhsState: HalfJoinState<Key, V1, V2>,
30 RhsState: HalfJoinState<Key, V2, V1>,
31{
32 type Item = (Key, (V1, V2));
33
34 fn next(&mut self) -> Option<Self::Item> {
35 loop {
36 if let Some((k, v2, v1)) = self.lhs_state.pop_match() {
37 return Some((k, (v1, v2)));
38 }
39 if let Some((k, v1, v2)) = self.rhs_state.pop_match() {
40 return Some((k, (v1, v2)));
41 }
42
43 if let Some((k, v1)) = self.lhs.next() {
44 if self.lhs_state.build(k.clone(), &v1) {
45 if let Some((k, v1, v2)) = self.rhs_state.probe(&k, &v1) {
46 return Some((k, (v1, v2)));
47 }
48 }
49 continue;
50 }
51 if let Some((k, v2)) = self.rhs.next() {
52 if self.rhs_state.build(k.clone(), &v2) {
53 if let Some((k, v2, v1)) = self.lhs_state.probe(&k, &v2) {
54 return Some((k, (v1, v2)));
55 }
56 }
57 continue;
58 }
59
60 return None;
61 }
62 }
63}
64
65pub fn symmetric_hash_join_into_iter<'a, Key, I1, V1, I2, V2, LhsState, RhsState>(
66 mut lhs: I1,
67 mut rhs: I2,
68 lhs_state: &'a mut LhsState,
69 rhs_state: &'a mut RhsState,
70 is_new_tick: bool,
71) -> impl 'a + Iterator<Item = (Key, (V1, V2))>
72where
73 Key: 'a + Eq + std::hash::Hash + Clone,
74 V1: 'a + Clone,
75 V2: 'a + Clone,
76 I1: 'a + Iterator<Item = (Key, V1)>,
77 I2: 'a + Iterator<Item = (Key, V2)>,
78 LhsState: HalfJoinState<Key, V1, V2>,
79 RhsState: HalfJoinState<Key, V2, V1>,
80{
81 if is_new_tick {
82 for (k, v1) in lhs.by_ref() {
83 lhs_state.build(k.clone(), &v1);
84 }
85
86 for (k, v2) in rhs.by_ref() {
87 rhs_state.build(k.clone(), &v2);
88 }
89
90 Either::Left(if lhs_state.len() < rhs_state.len() {
91 Either::Left(lhs_state.iter().flat_map(|(k, sv)| {
92 sv.iter().flat_map(|v1| {
93 rhs_state
94 .full_probe(k)
95 .map(|v2| (k.clone(), (v1.clone(), v2.clone())))
96 })
97 }))
98 } else {
99 Either::Right(rhs_state.iter().flat_map(|(k, sv)| {
100 sv.iter().flat_map(|v2| {
101 lhs_state
102 .full_probe(k)
103 .map(|v1| (k.clone(), (v1.clone(), v2.clone())))
104 })
105 }))
106 })
107 } else {
108 Either::Right(SymmetricHashJoin {
109 lhs,
110 rhs,
111 lhs_state,
112 rhs_state,
113 })
114 }
115}
116
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use crate::compiled::pull::{HalfSetJoinState, symmetric_hash_join_into_iter};

    /// Joins keys `0..10` on the left with keys `3..=7` (from `x / 2` over
    /// `6..15`) on the right and checks a selection of the expected pairs.
    #[test]
    fn hash_join() {
        let left_input = (0..10).map(|x| (x, format!("left {}", x)));
        let right_input = (6..15).map(|x| (x / 2, format!("right {} / 2", x)));

        let mut lhs_state = HalfSetJoinState::default();
        let mut rhs_state = HalfSetJoinState::default();

        let joined: HashSet<_> = symmetric_hash_join_into_iter(
            left_input,
            right_input,
            &mut lhs_state,
            &mut rhs_state,
            true,
        )
        .collect();

        // (key, left value, right value) rows that must be present.
        let expected = [
            (3, "left 3", "right 6 / 2"),
            (3, "left 3", "right 7 / 2"),
            (4, "left 4", "right 8 / 2"),
            (4, "left 4", "right 9 / 2"),
            (5, "left 5", "right 10 / 2"),
            (5, "left 5", "right 11 / 2"),
            (6, "left 6", "right 12 / 2"),
            (7, "left 7", "right 14 / 2"),
        ];

        for (key, left, right) in expected {
            assert!(
                joined.contains(&(key, (left.into(), right.into()))),
                "missing join row ({}, ({:?}, {:?}))",
                key,
                left,
                right
            );
        }
    }

    /// Feeds one matching pair through channels before creating the join,
    /// expects exactly one joined row, and then checks that items sent after
    /// the iterator was exhausted do not make it yield again.
    #[test]
    fn hash_join_subsequent_ticks_do_produce_even_if_nothing_is_changed() {
        let (left_tx, left_rx) = std::sync::mpsc::channel::<(usize, usize)>();
        let (right_tx, right_rx) = std::sync::mpsc::channel::<(usize, usize)>();

        left_tx.send((7, 3)).unwrap();
        right_tx.send((7, 3)).unwrap();

        let mut lhs_state = HalfSetJoinState::default();
        let mut rhs_state = HalfSetJoinState::default();

        let mut join = symmetric_hash_join_into_iter(
            left_rx.try_iter(),
            right_rx.try_iter(),
            &mut lhs_state,
            &mut rhs_state,
            true,
        );

        assert_eq!(join.next(), Some((7, (3, 3))));
        assert_eq!(join.next(), None);

        // Sending the same pair again must not produce another row.
        left_tx.send((7, 3)).unwrap();
        right_tx.send((7, 3)).unwrap();

        assert_eq!(join.next(), None);
    }
}