dfir_rs/util/unsync/
mpsc.rs

//! Unsync multi-producer single-consumer (MPSC) channel (i.e. a single-threaded queue with async hooks).
2
3use std::cell::RefCell;
4use std::collections::VecDeque;
5use std::num::NonZeroUsize;
6use std::pin::Pin;
7use std::rc::{Rc, Weak};
8use std::task::{Context, Poll, Waker};
9
10use futures::{Sink, Stream, ready};
11use smallvec::SmallVec;
12#[doc(inline)]
13pub use tokio::sync::mpsc::error::{SendError, TrySendError};
14
/// Send half of an unsync MPSC channel.
///
/// Holds only a [`Weak`] reference to the shared state; from this sender's point of view the
/// channel is closed once that reference can no longer be upgraded.
pub struct Sender<T> {
    weak: Weak<RefCell<Shared<T>>>,
}
19impl<T> Sender<T> {
20    /// Asynchronously sends value to the receiver.
21    pub async fn send(&self, item: T) -> Result<(), SendError<T>> {
22        let mut item = Some(item);
23        std::future::poll_fn(move |ctx| {
24            if let Some(strong) = Weak::upgrade(&self.weak) {
25                let mut shared = strong.borrow_mut();
26                if shared
27                    .capacity
28                    .is_some_and(|cap| cap.get() <= shared.buffer.len())
29                {
30                    // Full.
31                    shared.send_wakers.push(ctx.waker().clone());
32                    Poll::Pending
33                } else {
34                    shared.buffer.push_back(item.take().unwrap());
35                    shared.wake_receiver();
36                    Poll::Ready(Ok(()))
37                }
38            } else {
39                // Closed.
40                Poll::Ready(Err(SendError(item.take().unwrap())))
41            }
42        })
43        .await
44    }
45
46    /// Tries to send the value to the receiver without blocking.
47    ///
48    /// Returns an error if the destination is closed or if the buffer is at capacity.
49    ///
50    /// [`TrySendError::Full`] will never be returned if this is an unbounded channel.
51    pub fn try_send(&self, item: T) -> Result<(), TrySendError<T>> {
52        if let Some(strong) = Weak::upgrade(&self.weak) {
53            let mut shared = strong.borrow_mut();
54            if shared
55                .capacity
56                .is_some_and(|cap| cap.get() <= shared.buffer.len())
57            {
58                Err(TrySendError::Full(item))
59            } else {
60                shared.buffer.push_back(item);
61                shared.wake_receiver();
62                Ok(())
63            }
64        } else {
65            Err(TrySendError::Closed(item))
66        }
67    }
68
69    /// Close this sender. No more messages can be sent from this sender.
70    ///
71    /// Note that this only closes the channel from the view-point of this sender. The channel
72    /// remains open until all senders have gone away, or until the [`Receiver`] closes the channel.
73    pub fn close_this_sender(&mut self) {
74        self.weak = Weak::new();
75    }
76
77    /// If this sender or the corresponding [`Receiver`] is closed.
78    pub fn is_closed(&self) -> bool {
79        0 == self.weak.strong_count()
80    }
81}
82impl<T> Clone for Sender<T> {
83    fn clone(&self) -> Self {
84        Self {
85            weak: self.weak.clone(),
86        }
87    }
88}
89impl<T> Drop for Sender<T> {
90    fn drop(&mut self) {
91        // Really we should only do this if we're the very last sender,
92        // But `1 == self.weak.weak_count()` seems unreliable.
93        if let Some(strong) = self.weak.upgrade() {
94            strong.borrow_mut().wake_receiver();
95        }
96    }
97}
98
99impl<T> Sink<T> for Sender<T> {
100    type Error = TrySendError<Option<T>>;
101
102    fn poll_ready(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
103        if let Some(strong) = Weak::upgrade(&self.weak) {
104            let mut shared = strong.borrow_mut();
105            if shared
106                .capacity
107                .is_some_and(|cap| cap.get() <= shared.buffer.len())
108            {
109                // Full.
110                shared.send_wakers.push(ctx.waker().clone());
111                Poll::Pending
112            } else {
113                // Has room.
114                Poll::Ready(Ok(()))
115            }
116        } else {
117            // Closed
118            Poll::Ready(Err(TrySendError::Closed(None)))
119        }
120    }
121
122    fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
123        self.try_send(item).map_err(|e| match e {
124            TrySendError::Full(item) => TrySendError::Full(Some(item)),
125            TrySendError::Closed(item) => TrySendError::Closed(Some(item)),
126        })
127    }
128
129    fn poll_flush(self: Pin<&mut Self>, _ctx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
130        Poll::Ready(Ok(()))
131    }
132
133    fn poll_close(
134        mut self: Pin<&mut Self>,
135        ctx: &mut Context<'_>,
136    ) -> Poll<Result<(), Self::Error>> {
137        ready!(self.as_mut().poll_flush(ctx))?;
138        Pin::into_inner(self).close_this_sender();
139        Poll::Ready(Ok(()))
140    }
141}
142
/// Receiving half of an unsync MPSC channel.
///
/// Holds the strong [`Rc`] to the shared state. Senders hold only [`Weak`]s, so
/// [`Rc::weak_count`] tells the receiver whether any senders remain.
pub struct Receiver<T> {
    strong: Rc<RefCell<Shared<T>>>,
}
147impl<T> Receiver<T> {
148    /// Receive a value asynchronously.
149    pub async fn recv(&mut self) -> Option<T> {
150        std::future::poll_fn(|ctx| self.poll_recv(ctx)).await
151    }
152
153    /// Poll for a value.
154    /// NOTE: takes `&mut self` to prevent multiple concurrent receives.
155    pub fn poll_recv(&mut self, ctx: &Context<'_>) -> Poll<Option<T>> {
156        let mut shared = self.strong.borrow_mut();
157        if let Some(value) = shared.buffer.pop_front() {
158            shared.wake_sender();
159            Poll::Ready(Some(value))
160        } else if 0 == Rc::weak_count(&self.strong) {
161            Poll::Ready(None) // Empty and dropped.
162        } else {
163            shared.recv_waker = Some(ctx.waker().clone());
164            Poll::Pending
165        }
166    }
167
168    /// Closes this receiving end, not allowing more values to be sent while still allowing already-sent values to be consumed.
169    pub fn close(&mut self) {
170        assert_eq!(
171            1,
172            Rc::strong_count(&self.strong),
173            "BUG: receiver has non-exclusive Rc."
174        );
175
176        let new_shared = {
177            let mut shared = self.strong.borrow_mut();
178            shared.wake_all_senders();
179
180            Shared {
181                buffer: std::mem::take(&mut shared.buffer),
182                ..Default::default()
183            }
184        };
185        self.strong = Rc::new(RefCell::new(new_shared));
186        // Drop old `Rc`, invalidating all `Weak`s.
187    }
188}
189impl<T> Drop for Receiver<T> {
190    fn drop(&mut self) {
191        self.close()
192    }
193}
194impl<T> Stream for Receiver<T> {
195    type Item = T;
196
197    fn poll_next(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
198        self.poll_recv(ctx)
199    }
200}
201
/// State shared between the senders and the receiver.
struct Shared<T> {
    /// Values sent but not yet received, in FIFO order (pushed at the back, popped at the front).
    buffer: VecDeque<T>,
    /// Maximum number of buffered values; `None` means unbounded.
    capacity: Option<NonZeroUsize>,
    /// Wakers of senders waiting for room in a full buffer.
    send_wakers: SmallVec<[Waker; 1]>,
    /// Waker of the receiver waiting on an empty buffer, if any.
    recv_waker: Option<Waker>,
}
209impl<T> Shared<T> {
210    /// Wakes one sender (if there are any wakers), and removes the waker.
211    pub fn wake_sender(&mut self) {
212        if let Some(waker) = self.send_wakers.pop() {
213            waker.wake();
214        }
215    }
216    /// Wakes all senders and removes their wakers.
217    pub fn wake_all_senders(&mut self) {
218        self.send_wakers.drain(..).for_each(Waker::wake);
219    }
220    /// Wakes the receiver (if the waker is set) and removes it.
221    pub fn wake_receiver(&mut self) {
222        if let Some(waker) = self.recv_waker.take() {
223            waker.wake();
224        }
225    }
226}
227impl<T> Default for Shared<T> {
228    fn default() -> Self {
229        let (buffer, capacity, send_wakers, recv_waker) = Default::default();
230        Self {
231            buffer,
232            capacity,
233            send_wakers,
234            recv_waker,
235        }
236    }
237}
238
239/// Create an unsync MPSC channel, either bounded (if `capacity` is `Some`) or unbounded (if `capacity` is `None`).
240pub fn channel<T>(capacity: Option<NonZeroUsize>) -> (Sender<T>, Receiver<T>) {
241    let (buffer, send_wakers, recv_waker) = Default::default();
242    let shared = Rc::new(RefCell::new(Shared {
243        buffer,
244        capacity,
245        send_wakers,
246        recv_waker,
247    }));
248    let sender = Sender {
249        weak: Rc::downgrade(&shared),
250    };
251    let receiver = Receiver { strong: shared };
252    (sender, receiver)
253}
254
255/// Create a bounded unsync MPSC channel. Panics if capacity is zero.
256pub fn bounded<T>(capacity: usize) -> (Sender<T>, Receiver<T>) {
257    let capacity = NonZeroUsize::new(capacity);
258    assert!(capacity.is_some(), "Capacity cannot be zero.");
259    channel(capacity)
260}
261
262/// Create an unbounded unsync MPSC channel.
263pub fn unbounded<T>() -> (Sender<T>, Receiver<T>) {
264    channel(None)
265}
266
#[cfg(test)]
mod test {
    // Tests for the unsync MPSC channel. The randomized tests run many independent
    // channels concurrently, each on its own `LocalSet`, with random sleeps between
    // operations to vary the interleaving of senders and receiver.
    use std::time::Duration;

    use futures::StreamExt;
    use rand::Rng;
    use tokio::task::LocalSet;

    use super::*;

    // Sleeps for a random whole number of milliseconds in `0..n` (exclusive upper bound).
    // `n` must be nonzero, since an empty range makes `gen_range` panic. `delay(1)` always
    // sleeps 0 ms.
    async fn delay(n: u64) {
        let millis = rand::thread_rng().gen_range(0..n);
        tokio::time::sleep(Duration::from_millis(millis)).await;
    }

    // Two `send` futures from the same sender can be outstanding at once, and both values
    // arrive (order not asserted, hence the sort).
    #[crate::test]
    async fn test_send_multiple_outstanding() {
        let (send, recv) = bounded::<u64>(10);

        let a_fut = send.send(123);
        let b_fut = send.send(234);

        futures::future::try_join(a_fut, b_fut).await.unwrap();
        // Dropping the only sender lets `collect` terminate once the buffer is drained.
        drop(send);

        let mut out: Vec<_> = recv.collect().await;
        out.sort_unstable();
        assert_eq!([123, 234], &*out);
    }

    // Single producer: values must arrive in exactly the order sent. The receiver's delay
    // range (0..5 ms) is slightly longer than the sender's (0..4 ms), so the bounded buffer
    // of 10 can fill up and exercise the waiting-send path.
    #[crate::test]
    async fn test_spsc_random() {
        let runs = (0..1_000).map(|_| async {
            let (send, recv) = bounded::<u64>(10);

            let local = LocalSet::new();

            local.spawn_local(async move {
                for x in 0..100 {
                    send.send(x).await.unwrap();
                    delay(4).await;
                }
            });
            local.spawn_local(async move {
                delay(5).await; // Delay once first.

                let mut recv = recv;
                let mut i = 0;
                // Ends with `None` once the sender task finishes and drops `send`.
                while let Some(x) = recv.recv().await {
                    assert_eq!(i, x);
                    i += 1;
                    delay(5).await;
                }
                assert_eq!(100, i);
            });
            local.await;
        });
        futures::future::join_all(runs).await;
    }

    // Three producers send disjoint ranges. The receiver (via the `Stream` impl) must get
    // all 300 values exactly once; cross-producer order is unspecified, hence the sort.
    #[crate::test]
    async fn test_mpsc_random() {
        let runs = (0..1_000).map(|_| async {
            let (send, recv) = bounded::<u64>(30);
            let send_a = send.clone();
            let send_b = send.clone();
            let send_c = send;

            let local = LocalSet::new();

            local.spawn_local(async move {
                for x in 0..100 {
                    send_a.send(x).await.unwrap();
                    delay(5).await;
                }
            });
            local.spawn_local(async move {
                for x in 100..200 {
                    send_b.send(x).await.unwrap();
                    delay(5).await;
                }
            });
            local.spawn_local(async move {
                for x in 200..300 {
                    send_c.send(x).await.unwrap();
                    delay(5).await;
                }
            });
            local.spawn_local(async move {
                delay(1).await; // Delay once first.

                let mut recv = recv;
                let mut vec = Vec::new();
                // Ends only after all three senders have been dropped.
                while let Some(x) = recv.next().await {
                    vec.push(x);
                    delay(1).await;
                }
                assert_eq!(300, vec.len());
                vec.sort_unstable();
                for (i, &x) in vec.iter().enumerate() {
                    assert_eq!(i as u64, x);
                }
            });
            local.await;
        });
        futures::future::join_all(runs).await;
    }

    // Feeds the channel's own output back into it through the `Stream`/`Sink` impls.
    // Starting from a buffered `0`, each of the `N` items taken is incremented and re-sent,
    // so after the loop the buffer holds exactly `N`.
    #[crate::test]
    async fn test_stream_sink_loop() {
        use futures::{SinkExt, StreamExt};

        const N: usize = 100;

        let (mut send, mut recv) = unbounded::<usize>();
        send.send(0).await.unwrap();
        // Connect it to itself
        let mut recv_ref = recv.by_ref().map(|x| x + 1).map(Ok).take(N);
        send.send_all(&mut recv_ref).await.unwrap();
        assert_eq!(Some(N), recv.recv().await);
    }
}