1use std::borrow::Cow;
2use std::pin::Pin;
3use std::task::{Context, Poll, ready};
4
5use futures::future::Either as FutEither;
6use futures::stream::{FusedStream, Stream, StreamExt};
7use itertools::Either as IterEither;
8use pin_project_lite::pin_project;
9
10use super::{ForEach, HalfJoinState};
11
12pin_project! {
13 #[must_use = "streams do nothing unless polled"]
15 pub struct SymmetricHashJoin<'a, Lhs, Rhs, LhsState, RhsState>
16 {
17 #[pin]
18 lhs: Lhs,
19 #[pin]
20 rhs: Rhs,
21
22 lhs_state: &'a mut LhsState,
23 rhs_state: &'a mut RhsState,
24 }
25}
26
27impl<'a, Key, Lhs, V1, Rhs, V2, LhsState, RhsState> Stream
28 for SymmetricHashJoin<'a, Lhs, Rhs, LhsState, RhsState>
29where
30 Key: Eq + std::hash::Hash + Clone,
31 V1: Clone,
32 V2: Clone,
33 Lhs: FusedStream<Item = (Key, V1)>,
34 Rhs: FusedStream<Item = (Key, V2)>,
35 LhsState: HalfJoinState<Key, V1, V2>,
36 RhsState: HalfJoinState<Key, V2, V1>,
37{
38 type Item = (Key, (V1, V2));
39
40 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
41 let mut this = self.project();
42
43 loop {
44 if let Some((k, v2, v1)) = this.lhs_state.pop_match() {
45 return Poll::Ready(Some((k, (v1, v2))));
46 }
47 if let Some((k, v1, v2)) = this.rhs_state.pop_match() {
48 return Poll::Ready(Some((k, (v1, v2))));
49 }
50
51 let lhs_poll = this.lhs.as_mut().poll_next(cx);
52 if let Poll::Ready(Some((k, v1))) = lhs_poll {
53 if this.lhs_state.build(k.clone(), Cow::Borrowed(&v1))
54 && let Some((k, v1, v2)) = this.rhs_state.probe(&k, &v1)
55 {
56 return Poll::Ready(Some((k, (v1, v2))));
57 }
58 continue;
59 }
60
61 let rhs_poll = this.rhs.as_mut().poll_next(cx);
62 if let Poll::Ready(Some((k, v2))) = rhs_poll {
63 if this.rhs_state.build(k.clone(), Cow::Borrowed(&v2))
64 && let Some((k, v2, v1)) = this.lhs_state.probe(&k, &v2)
65 {
66 return Poll::Ready(Some((k, (v1, v2))));
67 }
68 continue;
69 }
70
71 let _none = ready!(lhs_poll);
72 let _none = ready!(rhs_poll);
73 return Poll::Ready(None);
74 }
75 }
76}
77
78pub async fn symmetric_hash_join_into_stream<'a, Key, Lhs, V1, Rhs, V2, LhsState, RhsState>(
80 lhs: Lhs,
81 rhs: Rhs,
82 lhs_state: &'a mut LhsState,
83 rhs_state: &'a mut RhsState,
84 is_new_tick: bool,
85) -> impl 'a + Stream<Item = (Key, (V1, V2))>
86where
87 Key: 'a + Eq + std::hash::Hash + Clone,
88 V1: 'a + Clone,
89 V2: 'a + Clone,
90 Lhs: 'a + Stream<Item = (Key, V1)>,
91 Rhs: 'a + Stream<Item = (Key, V2)>,
92 LhsState: HalfJoinState<Key, V1, V2>,
93 RhsState: HalfJoinState<Key, V2, V1>,
94{
95 if is_new_tick {
96 ForEach::new(lhs, |(k, v1)| {
97 lhs_state.build(k.clone(), Cow::Owned(v1));
98 })
99 .await;
100
101 ForEach::new(rhs, |(k, v2)| {
102 rhs_state.build(k.clone(), Cow::Owned(v2));
103 })
104 .await;
105
106 let iter = if lhs_state.len() < rhs_state.len() {
107 IterEither::Left(lhs_state.iter().flat_map(|(k, sv)| {
108 sv.iter().flat_map(|v1| {
109 rhs_state
110 .full_probe(k)
111 .map(|v2| (k.clone(), (v1.clone(), v2.clone())))
112 })
113 }))
114 } else {
115 IterEither::Right(rhs_state.iter().flat_map(|(k, sv)| {
116 sv.iter().flat_map(|v2| {
117 lhs_state
118 .full_probe(k)
119 .map(|v1| (k.clone(), (v1.clone(), v2.clone())))
120 })
121 }))
122 };
123 FutEither::Left(futures::stream::iter(iter))
124 } else {
125 FutEither::Right(SymmetricHashJoin {
126 lhs: lhs.fuse(),
127 rhs: rhs.fuse(),
128 lhs_state,
129 rhs_state,
130 })
131 }
132}
133
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use tokio::sync::mpsc::unbounded_channel;
    use tokio_stream::wrappers::UnboundedReceiverStream;

    use super::super::HalfSetJoinState;
    use super::*;

    // New-tick path: both inputs are fully consumed, then every key match is emitted.
    #[crate::test]
    async fn hash_join() {
        let lhs = futures::stream::iter((0..10).map(|x| (x, format!("left {}", x))));
        let rhs = futures::stream::iter((6..15).map(|x| (x / 2, format!("right {} / 2", x))));

        let mut lhs_state = HalfSetJoinState::default();
        let mut rhs_state = HalfSetJoinState::default();

        let joined: HashSet<_> =
            symmetric_hash_join_into_stream(lhs, rhs, &mut lhs_state, &mut rhs_state, true)
                .await
                .collect()
                .await;

        // (key, left value, right value) for every pairing the join must produce.
        let expected = [
            (3, "left 3", "right 6 / 2"),
            (3, "left 3", "right 7 / 2"),
            (4, "left 4", "right 8 / 2"),
            (4, "left 4", "right 9 / 2"),
            (5, "left 5", "right 10 / 2"),
            (5, "left 5", "right 11 / 2"),
            (6, "left 6", "right 12 / 2"),
            (6, "left 6", "right 13 / 2"),
            (7, "left 7", "right 14 / 2"),
        ];
        for &(key, left, right) in &expected {
            assert!(joined.contains(&(key, (left.into(), right.into()))));
        }
        // Nothing beyond the expected pairings.
        assert_eq!(expected.len(), joined.len());
    }

    // Incremental path: re-sending already-seen pairs yields no further output.
    #[crate::test]
    async fn hash_join_subsequent_ticks_do_produce_even_if_nothing_is_changed() {
        let (lhs_tx, lhs_rx) = unbounded_channel::<(usize, usize)>();
        let (rhs_tx, rhs_rx) = unbounded_channel::<(usize, usize)>();
        let lhs_stream = UnboundedReceiverStream::new(lhs_rx);
        let rhs_stream = UnboundedReceiverStream::new(rhs_rx);

        lhs_tx.send((7, 3)).unwrap();
        rhs_tx.send((7, 3)).unwrap();

        let mut lhs_state = HalfSetJoinState::default();
        let mut rhs_state = HalfSetJoinState::default();
        let mut join = symmetric_hash_join_into_stream(
            lhs_stream,
            rhs_stream,
            &mut lhs_state,
            &mut rhs_state,
            false,
        )
        .await;

        assert_eq!(join.next().await, Some((7, (3, 3))));

        // Send the same pairs again, then close both channels so the join can terminate.
        lhs_tx.send((7, 3)).unwrap();
        rhs_tx.send((7, 3)).unwrap();
        drop(lhs_tx);
        drop(rhs_tx);

        assert_eq!(join.next().await, None);
    }
}