// dfir_rs/compiled/pull/symmetric_hash_join.rs

1use itertools::Either;
2
3use super::HalfJoinState;
4
5pub struct SymmetricHashJoin<'a, Key, I1, V1, I2, V2, LhsState, RhsState>
6where
7    Key: Eq + std::hash::Hash + Clone,
8    V1: Clone,
9    V2: Clone,
10    I1: Iterator<Item = (Key, V1)>,
11    I2: Iterator<Item = (Key, V2)>,
12    LhsState: HalfJoinState<Key, V1, V2>,
13    RhsState: HalfJoinState<Key, V2, V1>,
14{
15    lhs: I1,
16    rhs: I2,
17    lhs_state: &'a mut LhsState,
18    rhs_state: &'a mut RhsState,
19}
20
21impl<Key, I1, V1, I2, V2, LhsState, RhsState> Iterator
22    for SymmetricHashJoin<'_, Key, I1, V1, I2, V2, LhsState, RhsState>
23where
24    Key: Eq + std::hash::Hash + Clone,
25    V1: Clone,
26    V2: Clone,
27    I1: Iterator<Item = (Key, V1)>,
28    I2: Iterator<Item = (Key, V2)>,
29    LhsState: HalfJoinState<Key, V1, V2>,
30    RhsState: HalfJoinState<Key, V2, V1>,
31{
32    type Item = (Key, (V1, V2));
33
34    fn next(&mut self) -> Option<Self::Item> {
35        loop {
36            if let Some((k, v2, v1)) = self.lhs_state.pop_match() {
37                return Some((k, (v1, v2)));
38            }
39            if let Some((k, v1, v2)) = self.rhs_state.pop_match() {
40                return Some((k, (v1, v2)));
41            }
42
43            if let Some((k, v1)) = self.lhs.next() {
44                if self.lhs_state.build(k.clone(), &v1) {
45                    if let Some((k, v1, v2)) = self.rhs_state.probe(&k, &v1) {
46                        return Some((k, (v1, v2)));
47                    }
48                }
49                continue;
50            }
51            if let Some((k, v2)) = self.rhs.next() {
52                if self.rhs_state.build(k.clone(), &v2) {
53                    if let Some((k, v2, v1)) = self.lhs_state.probe(&k, &v2) {
54                        return Some((k, (v1, v2)));
55                    }
56                }
57                continue;
58            }
59
60            return None;
61        }
62    }
63}
64
65pub fn symmetric_hash_join_into_iter<'a, Key, I1, V1, I2, V2, LhsState, RhsState>(
66    mut lhs: I1,
67    mut rhs: I2,
68    lhs_state: &'a mut LhsState,
69    rhs_state: &'a mut RhsState,
70    is_new_tick: bool,
71) -> impl 'a + Iterator<Item = (Key, (V1, V2))>
72where
73    Key: 'a + Eq + std::hash::Hash + Clone,
74    V1: 'a + Clone,
75    V2: 'a + Clone,
76    I1: 'a + Iterator<Item = (Key, V1)>,
77    I2: 'a + Iterator<Item = (Key, V2)>,
78    LhsState: HalfJoinState<Key, V1, V2>,
79    RhsState: HalfJoinState<Key, V2, V1>,
80{
81    if is_new_tick {
82        for (k, v1) in lhs.by_ref() {
83            lhs_state.build(k.clone(), &v1);
84        }
85
86        for (k, v2) in rhs.by_ref() {
87            rhs_state.build(k.clone(), &v2);
88        }
89
90        Either::Left(if lhs_state.len() < rhs_state.len() {
91            Either::Left(lhs_state.iter().flat_map(|(k, sv)| {
92                sv.iter().flat_map(|v1| {
93                    rhs_state
94                        .full_probe(k)
95                        .map(|v2| (k.clone(), (v1.clone(), v2.clone())))
96                })
97            }))
98        } else {
99            Either::Right(rhs_state.iter().flat_map(|(k, sv)| {
100                sv.iter().flat_map(|v2| {
101                    lhs_state
102                        .full_probe(k)
103                        .map(|v1| (k.clone(), (v1.clone(), v2.clone())))
104                })
105            }))
106        })
107    } else {
108        Either::Right(SymmetricHashJoin {
109            lhs,
110            rhs,
111            lhs_state,
112            rhs_state,
113        })
114    }
115}
116
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use crate::compiled::pull::{HalfSetJoinState, symmetric_hash_join_into_iter};

    /// Joins keys `0..10` on the left with halved keys from `6..15` on the
    /// right and checks that a sample of the expected pairs is present.
    #[test]
    fn hash_join() {
        let lhs = (0..10).map(|x| (x, format!("left {}", x)));
        let rhs = (6..15).map(|x| (x / 2, format!("right {} / 2", x)));

        let mut lhs_state = HalfSetJoinState::default();
        let mut rhs_state = HalfSetJoinState::default();

        let joined: HashSet<_> =
            symmetric_hash_join_into_iter(lhs, rhs, &mut lhs_state, &mut rhs_state, true)
                .collect();

        // (key, left value, right value) triples that must appear in the output.
        let expected = [
            (3, "left 3", "right 6 / 2"),
            (3, "left 3", "right 7 / 2"),
            (4, "left 4", "right 8 / 2"),
            (4, "left 4", "right 9 / 2"),
            (5, "left 5", "right 10 / 2"),
            (5, "left 5", "right 11 / 2"),
            (6, "left 6", "right 12 / 2"),
            (7, "left 7", "right 14 / 2"),
        ];
        for &(key, left, right) in expected.iter() {
            assert!(joined.contains(&(key, (left.to_string(), right.to_string()))));
        }
    }

    /// With `is_new_tick = true` the queued items are consumed when the
    /// iterator is built: it yields the single match and then `None`.
    /// Items sent on the channels afterwards are not picked up by this
    /// iterator, so `next()` keeps returning `None`.
    #[test]
    fn hash_join_subsequent_ticks_do_produce_even_if_nothing_is_changed() {
        use std::sync::mpsc;

        let (lhs_tx, lhs_rx) = mpsc::channel::<(usize, usize)>();
        let (rhs_tx, rhs_rx) = mpsc::channel::<(usize, usize)>();

        lhs_tx.send((7, 3)).unwrap();
        rhs_tx.send((7, 3)).unwrap();

        let mut lhs_state = HalfSetJoinState::default();
        let mut rhs_state = HalfSetJoinState::default();

        let mut join = symmetric_hash_join_into_iter(
            lhs_rx.try_iter(),
            rhs_rx.try_iter(),
            &mut lhs_state,
            &mut rhs_state,
            true,
        );

        assert_eq!(join.next(), Some((7, (3, 3))));
        assert_eq!(join.next(), None);

        // Queue the same pair again after the iterator has been exhausted.
        lhs_tx.send((7, 3)).unwrap();
        rhs_tx.send((7, 3)).unwrap();

        assert_eq!(join.next(), None);
    }
}