dfir_rs/util/unsync/
mpsc.rs

//! Unsync multi-producer single-consumer channel (i.e. a single-threaded queue with async hooks).
2
3use std::cell::RefCell;
4use std::collections::VecDeque;
5use std::num::NonZeroUsize;
6use std::pin::Pin;
7use std::rc::{Rc, Weak};
8use std::task::{Context, Poll, Waker};
9
10use futures::{Sink, Stream, ready};
11use smallvec::SmallVec;
12#[doc(inline)]
13pub use tokio::sync::mpsc::error::{SendError, TrySendError};
14
15/// Send half of am unsync MPSC.
16pub struct Sender<T> {
17    weak: Weak<RefCell<Shared<T>>>,
18}
19impl<T> Sender<T> {
20    /// Asynchronously sends value to the receiver.
21    pub async fn send(&self, item: T) -> Result<(), SendError<T>> {
22        let mut item = Some(item);
23        std::future::poll_fn(move |ctx| {
24            if let Some(strong) = Weak::upgrade(&self.weak) {
25                let mut shared = strong.borrow_mut();
26                if shared
27                    .capacity
28                    .is_some_and(|cap| cap.get() <= shared.buffer.len())
29                {
30                    // Full.
31                    shared.send_wakers.push(ctx.waker().clone());
32                    Poll::Pending
33                } else {
34                    shared.buffer.push_back(item.take().unwrap());
35                    shared.wake_receiver();
36                    Poll::Ready(Ok(()))
37                }
38            } else {
39                // Closed.
40                Poll::Ready(Err(SendError(item.take().unwrap())))
41            }
42        })
43        .await
44    }
45
46    /// Tries to send the value to the receiver without blocking.
47    ///
48    /// Returns an error if the destination is closed or if the buffer is at capacity.
49    ///
50    /// [`TrySendError::Full`] will never be returned if this is an unbounded channel.
51    pub fn try_send(&self, item: T) -> Result<(), TrySendError<T>> {
52        if let Some(strong) = Weak::upgrade(&self.weak) {
53            let mut shared = strong.borrow_mut();
54            if shared
55                .capacity
56                .is_some_and(|cap| cap.get() <= shared.buffer.len())
57            {
58                Err(TrySendError::Full(item))
59            } else {
60                shared.buffer.push_back(item);
61                shared.wake_receiver();
62                Ok(())
63            }
64        } else {
65            Err(TrySendError::Closed(item))
66        }
67    }
68
69    /// Close this sender. No more messages can be sent from this sender.
70    ///
71    /// Note that this only closes the channel from the view-point of this sender. The channel
72    /// remains open until all senders have gone away, or until the [`Receiver`] closes the channel.
73    pub fn close_this_sender(&mut self) {
74        self.weak = Weak::new();
75    }
76
77    /// If this sender or the corresponding [`Receiver`] is closed.
78    pub fn is_closed(&self) -> bool {
79        0 == self.weak.strong_count()
80    }
81}
82impl<T> Clone for Sender<T> {
83    fn clone(&self) -> Self {
84        Self {
85            weak: self.weak.clone(),
86        }
87    }
88}
89impl<T> Drop for Sender<T> {
90    fn drop(&mut self) {
91        // Really we should only do this if we're the very last sender,
92        // But `1 == self.weak.weak_count()` seems unreliable.
93        if let Some(strong) = self.weak.upgrade() {
94            strong.borrow_mut().wake_receiver();
95        }
96    }
97}
98
99impl<T> Sink<T> for Sender<T> {
100    type Error = TrySendError<Option<T>>;
101
102    fn poll_ready(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
103        if let Some(strong) = Weak::upgrade(&self.weak) {
104            let mut shared = strong.borrow_mut();
105            if shared
106                .capacity
107                .is_some_and(|cap| cap.get() <= shared.buffer.len())
108            {
109                // Full.
110                shared.send_wakers.push(ctx.waker().clone());
111                Poll::Pending
112            } else {
113                // Has room.
114                Poll::Ready(Ok(()))
115            }
116        } else {
117            // Closed
118            Poll::Ready(Err(TrySendError::Closed(None)))
119        }
120    }
121
122    fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
123        self.try_send(item).map_err(|e| match e {
124            TrySendError::Full(item) => TrySendError::Full(Some(item)),
125            TrySendError::Closed(item) => TrySendError::Closed(Some(item)),
126        })
127    }
128
129    fn poll_flush(self: Pin<&mut Self>, _ctx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
130        Poll::Ready(Ok(()))
131    }
132
133    fn poll_close(
134        mut self: Pin<&mut Self>,
135        ctx: &mut Context<'_>,
136    ) -> Poll<Result<(), Self::Error>> {
137        ready!(self.as_mut().poll_flush(ctx))?;
138        Pin::into_inner(self).close_this_sender();
139        Poll::Ready(Ok(()))
140    }
141}
142
143/// Receiving half of an unsync MPSC.
144pub struct Receiver<T> {
145    strong: Rc<RefCell<Shared<T>>>,
146}
147impl<T> Receiver<T> {
148    /// Receive a value asynchronously.
149    pub async fn recv(&mut self) -> Option<T> {
150        std::future::poll_fn(|ctx| self.poll_recv(ctx)).await
151    }
152
153    /// Poll for a value.
154    /// NOTE: takes `&mut self` to prevent multiple concurrent receives.
155    pub fn poll_recv(&mut self, ctx: &Context<'_>) -> Poll<Option<T>> {
156        let mut shared = self.strong.borrow_mut();
157        if let Some(value) = shared.buffer.pop_front() {
158            shared.wake_sender();
159            Poll::Ready(Some(value))
160        } else if 0 == Rc::weak_count(&self.strong) {
161            Poll::Ready(None) // Empty and dropped.
162        } else {
163            shared.recv_waker = Some(ctx.waker().clone());
164            Poll::Pending
165        }
166    }
167
168    /// Closes this receiving end, not allowing more values to be sent while still allowing already-sent values to be consumed.
169    pub fn close(&mut self) {
170        assert_eq!(
171            1,
172            Rc::strong_count(&self.strong),
173            "BUG: receiver has non-exclusive Rc."
174        );
175
176        let new_shared = {
177            let mut shared = self.strong.borrow_mut();
178            shared.wake_all_senders();
179
180            Shared {
181                buffer: std::mem::take(&mut shared.buffer),
182                ..Default::default()
183            }
184        };
185        self.strong = Rc::new(RefCell::new(new_shared));
186        // Drop old `Rc`, invalidating all `Weak`s.
187    }
188}
189impl<T> Drop for Receiver<T> {
190    fn drop(&mut self) {
191        self.close()
192    }
193}
194impl<T> Stream for Receiver<T> {
195    type Item = T;
196
197    fn poll_next(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
198        self.poll_recv(ctx)
199    }
200}
201
202/// Struct shared between sender and receiver.
203struct Shared<T> {
204    buffer: VecDeque<T>,
205    capacity: Option<NonZeroUsize>,
206    send_wakers: SmallVec<[Waker; 1]>,
207    recv_waker: Option<Waker>,
208}
209impl<T> Shared<T> {
210    /// Wakes one sender (if there are any wakers), and removes the waker.
211    pub fn wake_sender(&mut self) {
212        if let Some(waker) = self.send_wakers.pop() {
213            waker.wake();
214        }
215    }
216    /// Wakes all senders and removes their wakers.
217    pub fn wake_all_senders(&mut self) {
218        self.send_wakers.drain(..).for_each(Waker::wake);
219    }
220    /// Wakes the receiver (if the waker is set) and removes it.
221    pub fn wake_receiver(&mut self) {
222        if let Some(waker) = self.recv_waker.take() {
223            waker.wake();
224        }
225    }
226}
227impl<T> Default for Shared<T> {
228    fn default() -> Self {
229        let (buffer, capacity, send_wakers, recv_waker) = Default::default();
230        Self {
231            buffer,
232            capacity,
233            send_wakers,
234            recv_waker,
235        }
236    }
237}
238
239/// Create an unsync MPSC channel, either bounded (if `capacity` is `Some`) or unbounded (if `capacity` is `None`).
240pub fn channel<T>(capacity: Option<NonZeroUsize>) -> (Sender<T>, Receiver<T>) {
241    let (buffer, send_wakers, recv_waker) = Default::default();
242    let shared = Rc::new(RefCell::new(Shared {
243        buffer,
244        capacity,
245        send_wakers,
246        recv_waker,
247    }));
248    let sender = Sender {
249        weak: Rc::downgrade(&shared),
250    };
251    let receiver = Receiver { strong: shared };
252    (sender, receiver)
253}
254
255/// Create a bounded unsync MPSC channel. Panics if capacity is zero.
256pub fn bounded<T>(capacity: usize) -> (Sender<T>, Receiver<T>) {
257    let capacity = NonZeroUsize::new(capacity);
258    assert!(capacity.is_some(), "Capacity cannot be zero.");
259    channel(capacity)
260}
261
262/// Create an unbounded unsync MPSC channel.
263pub fn unbounded<T>() -> (Sender<T>, Receiver<T>) {
264    channel(None)
265}
266
#[cfg(test)]
mod test {
    use futures::StreamExt;
    use rand::Rng;
    use tokio::task::LocalSet;
    use web_time::Duration;

    use super::*;

    /// Sleeps for a random number of milliseconds in `0..max_millis`.
    async fn delay(max_millis: u64) {
        let wait = rand::thread_rng().gen_range(0..max_millis);
        tokio::time::sleep(Duration::from_millis(wait)).await;
    }

    #[crate::test]
    async fn test_send_multiple_outstanding() {
        let (tx, rx) = bounded::<u64>(10);

        // Both send futures exist before either is polled; both must complete.
        futures::future::try_join(tx.send(123), tx.send(234))
            .await
            .unwrap();
        drop(tx);

        let mut received: Vec<_> = rx.collect().await;
        received.sort_unstable();
        assert_eq!([123, 234], &*received);
    }

    #[crate::test]
    async fn test_spsc_random() {
        let runs = (0..1_000).map(|_| async {
            let (tx, rx) = bounded::<u64>(10);

            let local = LocalSet::new();

            // Producer: sends 0..100 in order with random pauses.
            local.spawn_local(async move {
                for value in 0..100 {
                    tx.send(value).await.unwrap();
                    delay(4).await;
                }
            });
            // Consumer: values must arrive in order, and all 100 must arrive.
            local.spawn_local(async move {
                delay(5).await; // Delay once first.

                let mut rx = rx;
                let mut expected = 0;
                while let Some(value) = rx.recv().await {
                    assert_eq!(expected, value);
                    expected += 1;
                    delay(5).await;
                }
                assert_eq!(100, expected);
            });
            local.await;
        });
        futures::future::join_all(runs).await;
    }

    #[crate::test]
    async fn test_mpsc_random() {
        let runs = (0..1_000).map(|_| async {
            let (tx, rx) = bounded::<u64>(30);
            let senders = [tx.clone(), tx.clone(), tx];

            let local = LocalSet::new();

            // Three producers, each sending its own disjoint block of 100 values.
            for (start, sender) in [0_u64, 100, 200].into_iter().zip(senders) {
                local.spawn_local(async move {
                    for value in start..start + 100 {
                        sender.send(value).await.unwrap();
                        delay(5).await;
                    }
                });
            }
            // Consumer: every value in 0..300 must arrive exactly once.
            local.spawn_local(async move {
                delay(1).await; // Delay once first.

                let mut rx = rx;
                let mut received = Vec::new();
                while let Some(value) = rx.next().await {
                    received.push(value);
                    delay(1).await;
                }
                assert_eq!(300, received.len());
                received.sort_unstable();
                assert_eq!((0..300).collect::<Vec<u64>>(), received);
            });
            local.await;
        });
        futures::future::join_all(runs).await;
    }

    #[crate::test]
    async fn test_stream_sink_loop() {
        use futures::SinkExt;

        const N: usize = 100;

        let (mut tx, mut rx) = unbounded::<usize>();
        tx.send(0).await.unwrap();
        // Feed the receiver back into its own sender, incrementing each value on the way.
        let mut looped = rx.by_ref().map(|x| x + 1).map(Ok).take(N);
        tx.send_all(&mut looped).await.unwrap();
        assert_eq!(Some(N), rx.recv().await);
    }
}