dfir_rs/compiled/pull/
symmetric_hash_join.rs

1use std::borrow::Cow;
2use std::pin::Pin;
3use std::task::{Context, Poll, ready};
4
5use futures::future::Either as FutEither;
6use futures::stream::{FusedStream, Stream, StreamExt};
7use itertools::Either as IterEither;
8use pin_project_lite::pin_project;
9
10use super::{ForEach, HalfJoinState};
11
12pin_project! {
13    /// Stream combinator for symmetric hash join operations.
14    #[must_use = "streams do nothing unless polled"]
15    pub struct SymmetricHashJoin<'a, Lhs, Rhs, LhsState, RhsState>
16    {
17        #[pin]
18        lhs: Lhs,
19        #[pin]
20        rhs: Rhs,
21
22        lhs_state: &'a mut LhsState,
23        rhs_state: &'a mut RhsState,
24    }
25}
26
27impl<'a, Key, Lhs, V1, Rhs, V2, LhsState, RhsState> Stream
28    for SymmetricHashJoin<'a, Lhs, Rhs, LhsState, RhsState>
29where
30    Key: Eq + std::hash::Hash + Clone,
31    V1: Clone,
32    V2: Clone,
33    Lhs: FusedStream<Item = (Key, V1)>,
34    Rhs: FusedStream<Item = (Key, V2)>,
35    LhsState: HalfJoinState<Key, V1, V2>,
36    RhsState: HalfJoinState<Key, V2, V1>,
37{
38    type Item = (Key, (V1, V2));
39
40    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
41        let mut this = self.project();
42
43        loop {
44            if let Some((k, v2, v1)) = this.lhs_state.pop_match() {
45                return Poll::Ready(Some((k, (v1, v2))));
46            }
47            if let Some((k, v1, v2)) = this.rhs_state.pop_match() {
48                return Poll::Ready(Some((k, (v1, v2))));
49            }
50
51            let lhs_poll = this.lhs.as_mut().poll_next(cx);
52            if let Poll::Ready(Some((k, v1))) = lhs_poll {
53                if this.lhs_state.build(k.clone(), Cow::Borrowed(&v1))
54                    && let Some((k, v1, v2)) = this.rhs_state.probe(&k, &v1)
55                {
56                    return Poll::Ready(Some((k, (v1, v2))));
57                }
58                continue;
59            }
60
61            let rhs_poll = this.rhs.as_mut().poll_next(cx);
62            if let Poll::Ready(Some((k, v2))) = rhs_poll {
63                if this.rhs_state.build(k.clone(), Cow::Borrowed(&v2))
64                    && let Some((k, v2, v1)) = this.lhs_state.probe(&k, &v2)
65                {
66                    return Poll::Ready(Some((k, (v1, v2))));
67                }
68                continue;
69            }
70
71            let _none = ready!(lhs_poll);
72            let _none = ready!(rhs_poll);
73            return Poll::Ready(None);
74        }
75    }
76}
77
78/// Creates a symmetric hash join stream from two input streams and their join states.
79pub async fn symmetric_hash_join_into_stream<'a, Key, Lhs, V1, Rhs, V2, LhsState, RhsState>(
80    lhs: Lhs,
81    rhs: Rhs,
82    lhs_state: &'a mut LhsState,
83    rhs_state: &'a mut RhsState,
84    is_new_tick: bool,
85) -> impl 'a + Stream<Item = (Key, (V1, V2))>
86where
87    Key: 'a + Eq + std::hash::Hash + Clone,
88    V1: 'a + Clone,
89    V2: 'a + Clone,
90    Lhs: 'a + Stream<Item = (Key, V1)>,
91    Rhs: 'a + Stream<Item = (Key, V2)>,
92    LhsState: HalfJoinState<Key, V1, V2>,
93    RhsState: HalfJoinState<Key, V2, V1>,
94{
95    if is_new_tick {
96        ForEach::new(lhs, |(k, v1)| {
97            lhs_state.build(k.clone(), Cow::Owned(v1));
98        })
99        .await;
100
101        ForEach::new(rhs, |(k, v2)| {
102            rhs_state.build(k.clone(), Cow::Owned(v2));
103        })
104        .await;
105
106        let iter = if lhs_state.len() < rhs_state.len() {
107            IterEither::Left(lhs_state.iter().flat_map(|(k, sv)| {
108                sv.iter().flat_map(|v1| {
109                    rhs_state
110                        .full_probe(k)
111                        .map(|v2| (k.clone(), (v1.clone(), v2.clone())))
112                })
113            }))
114        } else {
115            IterEither::Right(rhs_state.iter().flat_map(|(k, sv)| {
116                sv.iter().flat_map(|v2| {
117                    lhs_state
118                        .full_probe(k)
119                        .map(|v1| (k.clone(), (v1.clone(), v2.clone())))
120                })
121            }))
122        };
123        FutEither::Left(futures::stream::iter(iter))
124    } else {
125        FutEither::Right(SymmetricHashJoin {
126            lhs: lhs.fuse(),
127            rhs: rhs.fuse(),
128            lhs_state,
129            rhs_state,
130        })
131    }
132}
133
134#[cfg(test)]
135mod tests {
136    use std::collections::HashSet;
137
138    use super::super::HalfSetJoinState;
139    use super::*;
140
141    #[crate::test]
142    async fn hash_join() {
143        let lhs = futures::stream::iter((0..10).map(|x| (x, format!("left {}", x))));
144        let rhs = futures::stream::iter((6..15).map(|x| (x / 2, format!("right {} / 2", x))));
145
146        let (mut lhs_state, mut rhs_state) =
147            (HalfSetJoinState::default(), HalfSetJoinState::default());
148        let join =
149            symmetric_hash_join_into_stream(lhs, rhs, &mut lhs_state, &mut rhs_state, true).await;
150
151        let joined = join.collect::<HashSet<_>>().await;
152
153        assert!(joined.contains(&(3, ("left 3".into(), "right 6 / 2".into()))));
154        assert!(joined.contains(&(3, ("left 3".into(), "right 7 / 2".into()))));
155        assert!(joined.contains(&(4, ("left 4".into(), "right 8 / 2".into()))));
156        assert!(joined.contains(&(4, ("left 4".into(), "right 9 / 2".into()))));
157        assert!(joined.contains(&(5, ("left 5".into(), "right 10 / 2".into()))));
158        assert!(joined.contains(&(5, ("left 5".into(), "right 11 / 2".into()))));
159        assert!(joined.contains(&(6, ("left 6".into(), "right 12 / 2".into()))));
160        assert!(joined.contains(&(6, ("left 6".into(), "right 13 / 2".into()))));
161        assert!(joined.contains(&(7, ("left 7".into(), "right 14 / 2".into()))));
162        assert_eq!(9, joined.len());
163    }
164
165    #[crate::test]
166    async fn hash_join_subsequent_ticks_do_produce_even_if_nothing_is_changed() {
167        let (lhs_tx, lhs_rx) = tokio::sync::mpsc::unbounded_channel::<(usize, usize)>();
168        let (rhs_tx, rhs_rx) = tokio::sync::mpsc::unbounded_channel::<(usize, usize)>();
169        let lhs_rx = tokio_stream::wrappers::UnboundedReceiverStream::new(lhs_rx);
170        let rhs_rx = tokio_stream::wrappers::UnboundedReceiverStream::new(rhs_rx);
171
172        lhs_tx.send((7, 3)).unwrap();
173        rhs_tx.send((7, 3)).unwrap();
174
175        let (mut lhs_state, mut rhs_state) =
176            (HalfSetJoinState::default(), HalfSetJoinState::default());
177        let mut join =
178            symmetric_hash_join_into_stream(lhs_rx, rhs_rx, &mut lhs_state, &mut rhs_state, false)
179                .await;
180
181        assert_eq!(join.next().await, Some((7, (3, 3))));
182
183        lhs_tx.send((7, 3)).unwrap();
184        rhs_tx.send((7, 3)).unwrap();
185        drop((lhs_tx, rhs_tx));
186
187        assert_eq!(join.next().await, None);
188    }
189}