dfir_rs/util/unsync/
mpsc.rs
1use std::cell::RefCell;
4use std::collections::VecDeque;
5use std::num::NonZeroUsize;
6use std::pin::Pin;
7use std::rc::{Rc, Weak};
8use std::task::{Context, Poll, Waker};
9
10use futures::{Sink, Stream, ready};
11use smallvec::SmallVec;
12#[doc(inline)]
13pub use tokio::sync::mpsc::error::{SendError, TrySendError};
14
15pub struct Sender<T> {
17 weak: Weak<RefCell<Shared<T>>>,
18}
19impl<T> Sender<T> {
20 pub async fn send(&self, item: T) -> Result<(), SendError<T>> {
22 let mut item = Some(item);
23 std::future::poll_fn(move |ctx| {
24 if let Some(strong) = Weak::upgrade(&self.weak) {
25 let mut shared = strong.borrow_mut();
26 if shared
27 .capacity
28 .is_some_and(|cap| cap.get() <= shared.buffer.len())
29 {
30 shared.send_wakers.push(ctx.waker().clone());
32 Poll::Pending
33 } else {
34 shared.buffer.push_back(item.take().unwrap());
35 shared.wake_receiver();
36 Poll::Ready(Ok(()))
37 }
38 } else {
39 Poll::Ready(Err(SendError(item.take().unwrap())))
41 }
42 })
43 .await
44 }
45
46 pub fn try_send(&self, item: T) -> Result<(), TrySendError<T>> {
52 if let Some(strong) = Weak::upgrade(&self.weak) {
53 let mut shared = strong.borrow_mut();
54 if shared
55 .capacity
56 .is_some_and(|cap| cap.get() <= shared.buffer.len())
57 {
58 Err(TrySendError::Full(item))
59 } else {
60 shared.buffer.push_back(item);
61 shared.wake_receiver();
62 Ok(())
63 }
64 } else {
65 Err(TrySendError::Closed(item))
66 }
67 }
68
69 pub fn close_this_sender(&mut self) {
74 self.weak = Weak::new();
75 }
76
77 pub fn is_closed(&self) -> bool {
79 0 == self.weak.strong_count()
80 }
81}
82impl<T> Clone for Sender<T> {
83 fn clone(&self) -> Self {
84 Self {
85 weak: self.weak.clone(),
86 }
87 }
88}
89impl<T> Drop for Sender<T> {
90 fn drop(&mut self) {
91 if let Some(strong) = self.weak.upgrade() {
94 strong.borrow_mut().wake_receiver();
95 }
96 }
97}
98
99impl<T> Sink<T> for Sender<T> {
100 type Error = TrySendError<Option<T>>;
101
102 fn poll_ready(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
103 if let Some(strong) = Weak::upgrade(&self.weak) {
104 let mut shared = strong.borrow_mut();
105 if shared
106 .capacity
107 .is_some_and(|cap| cap.get() <= shared.buffer.len())
108 {
109 shared.send_wakers.push(ctx.waker().clone());
111 Poll::Pending
112 } else {
113 Poll::Ready(Ok(()))
115 }
116 } else {
117 Poll::Ready(Err(TrySendError::Closed(None)))
119 }
120 }
121
122 fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
123 self.try_send(item).map_err(|e| match e {
124 TrySendError::Full(item) => TrySendError::Full(Some(item)),
125 TrySendError::Closed(item) => TrySendError::Closed(Some(item)),
126 })
127 }
128
129 fn poll_flush(self: Pin<&mut Self>, _ctx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
130 Poll::Ready(Ok(()))
131 }
132
133 fn poll_close(
134 mut self: Pin<&mut Self>,
135 ctx: &mut Context<'_>,
136 ) -> Poll<Result<(), Self::Error>> {
137 ready!(self.as_mut().poll_flush(ctx))?;
138 Pin::into_inner(self).close_this_sender();
139 Poll::Ready(Ok(()))
140 }
141}
142
143pub struct Receiver<T> {
145 strong: Rc<RefCell<Shared<T>>>,
146}
147impl<T> Receiver<T> {
148 pub async fn recv(&mut self) -> Option<T> {
150 std::future::poll_fn(|ctx| self.poll_recv(ctx)).await
151 }
152
153 pub fn poll_recv(&mut self, ctx: &Context<'_>) -> Poll<Option<T>> {
156 let mut shared = self.strong.borrow_mut();
157 if let Some(value) = shared.buffer.pop_front() {
158 shared.wake_sender();
159 Poll::Ready(Some(value))
160 } else if 0 == Rc::weak_count(&self.strong) {
161 Poll::Ready(None) } else {
163 shared.recv_waker = Some(ctx.waker().clone());
164 Poll::Pending
165 }
166 }
167
168 pub fn close(&mut self) {
170 assert_eq!(
171 1,
172 Rc::strong_count(&self.strong),
173 "BUG: receiver has non-exclusive Rc."
174 );
175
176 let new_shared = {
177 let mut shared = self.strong.borrow_mut();
178 shared.wake_all_senders();
179
180 Shared {
181 buffer: std::mem::take(&mut shared.buffer),
182 ..Default::default()
183 }
184 };
185 self.strong = Rc::new(RefCell::new(new_shared));
186 }
188}
189impl<T> Drop for Receiver<T> {
190 fn drop(&mut self) {
191 self.close()
192 }
193}
194impl<T> Stream for Receiver<T> {
195 type Item = T;
196
197 fn poll_next(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
198 self.poll_recv(ctx)
199 }
200}
201
202struct Shared<T> {
204 buffer: VecDeque<T>,
205 capacity: Option<NonZeroUsize>,
206 send_wakers: SmallVec<[Waker; 1]>,
207 recv_waker: Option<Waker>,
208}
209impl<T> Shared<T> {
210 pub fn wake_sender(&mut self) {
212 if let Some(waker) = self.send_wakers.pop() {
213 waker.wake();
214 }
215 }
216 pub fn wake_all_senders(&mut self) {
218 self.send_wakers.drain(..).for_each(Waker::wake);
219 }
220 pub fn wake_receiver(&mut self) {
222 if let Some(waker) = self.recv_waker.take() {
223 waker.wake();
224 }
225 }
226}
227impl<T> Default for Shared<T> {
228 fn default() -> Self {
229 let (buffer, capacity, send_wakers, recv_waker) = Default::default();
230 Self {
231 buffer,
232 capacity,
233 send_wakers,
234 recv_waker,
235 }
236 }
237}
238
239pub fn channel<T>(capacity: Option<NonZeroUsize>) -> (Sender<T>, Receiver<T>) {
241 let (buffer, send_wakers, recv_waker) = Default::default();
242 let shared = Rc::new(RefCell::new(Shared {
243 buffer,
244 capacity,
245 send_wakers,
246 recv_waker,
247 }));
248 let sender = Sender {
249 weak: Rc::downgrade(&shared),
250 };
251 let receiver = Receiver { strong: shared };
252 (sender, receiver)
253}
254
255pub fn bounded<T>(capacity: usize) -> (Sender<T>, Receiver<T>) {
257 let capacity = NonZeroUsize::new(capacity);
258 assert!(capacity.is_some(), "Capacity cannot be zero.");
259 channel(capacity)
260}
261
262pub fn unbounded<T>() -> (Sender<T>, Receiver<T>) {
264 channel(None)
265}
266
#[cfg(test)]
mod test {
    use std::time::Duration;

    use futures::StreamExt;
    use rand::Rng;
    use tokio::task::LocalSet;

    use super::*;

    /// Sleeps for a random number of milliseconds drawn from `0..n`.
    async fn delay(n: u64) {
        // The RNG handle is a temporary so it is not held across the `.await`.
        let pause = Duration::from_millis(rand::thread_rng().gen_range(0..n));
        tokio::time::sleep(pause).await;
    }

    /// Two `send` futures created from the same sender can be outstanding at the same time.
    #[crate::test]
    async fn test_send_multiple_outstanding() {
        let (send, recv) = bounded::<u64>(10);

        let first = send.send(123);
        let second = send.send(234);
        futures::future::try_join(first, second).await.unwrap();
        drop(send);

        let mut received = recv.collect::<Vec<_>>().await;
        received.sort_unstable();
        assert_eq!([123, 234], &*received);
    }

    /// One producer and one consumer with random delays: values arrive in order and complete.
    #[crate::test]
    async fn test_spsc_random() {
        let runs = (0..1_000).map(|_| async {
            let (send, recv) = bounded::<u64>(10);
            let local = LocalSet::new();

            // Producer.
            local.spawn_local(async move {
                for x in 0..100 {
                    send.send(x).await.unwrap();
                    delay(4).await;
                }
            });
            // Consumer, starting after an initial delay.
            local.spawn_local(async move {
                let mut recv = recv;
                delay(5).await;

                let mut expected = 0;
                while let Some(x) = recv.recv().await {
                    assert_eq!(expected, x);
                    expected += 1;
                    delay(5).await;
                }
                assert_eq!(100, expected);
            });
            local.await;
        });
        futures::future::join_all(runs).await;
    }

    /// Three producers and one consumer with random delays: every value arrives exactly once.
    #[crate::test]
    async fn test_mpsc_random() {
        let runs = (0..1_000).map(|_| async {
            let (send, recv) = bounded::<u64>(30);
            let producers = [
                (send.clone(), 0..100),
                (send.clone(), 100..200),
                (send, 200..300),
            ];

            let local = LocalSet::new();

            for (sender, values) in producers {
                local.spawn_local(async move {
                    for x in values {
                        sender.send(x).await.unwrap();
                        delay(5).await;
                    }
                });
            }
            // Consumer, starting after an initial delay.
            local.spawn_local(async move {
                let mut recv = recv;
                delay(1).await;

                let mut received = Vec::new();
                while let Some(x) = recv.next().await {
                    received.push(x);
                    delay(1).await;
                }
                assert_eq!(300, received.len());
                received.sort_unstable();
                for (i, &x) in received.iter().enumerate() {
                    assert_eq!(i as u64, x);
                }
            });
            local.await;
        });
        futures::future::join_all(runs).await;
    }

    /// Feeds the receiver's output (incremented by one) back into the sender `N` times.
    #[crate::test]
    async fn test_stream_sink_loop() {
        use futures::{SinkExt, StreamExt};

        const N: usize = 100;

        let (mut send, mut recv) = unbounded::<usize>();
        send.send(0).await.unwrap();

        let mut incremented = recv.by_ref().map(|x| x + 1).map(Ok).take(N);
        send.send_all(&mut incremented).await.unwrap();

        assert_eq!(Some(N), recv.recv().await);
    }
}