dfir_rs/util/unsync/
mpsc.rs1use std::cell::RefCell;
4use std::collections::VecDeque;
5use std::num::NonZeroUsize;
6use std::pin::Pin;
7use std::rc::{Rc, Weak};
8use std::task::{Context, Poll, Waker};
9
10use futures::{Sink, Stream, ready};
11use smallvec::SmallVec;
12#[doc(inline)]
13pub use tokio::sync::mpsc::error::{SendError, TrySendError};
14
15pub struct Sender<T> {
17 weak: Weak<RefCell<Shared<T>>>,
18}
19impl<T> Sender<T> {
20 pub async fn send(&self, item: T) -> Result<(), SendError<T>> {
22 let mut item = Some(item);
23 std::future::poll_fn(move |ctx| {
24 if let Some(strong) = Weak::upgrade(&self.weak) {
25 let mut shared = strong.borrow_mut();
26 if shared
27 .capacity
28 .is_some_and(|cap| cap.get() <= shared.buffer.len())
29 {
30 shared.send_wakers.push(ctx.waker().clone());
32 Poll::Pending
33 } else {
34 shared.buffer.push_back(item.take().unwrap());
35 shared.wake_receiver();
36 Poll::Ready(Ok(()))
37 }
38 } else {
39 Poll::Ready(Err(SendError(item.take().unwrap())))
41 }
42 })
43 .await
44 }
45
46 pub fn try_send(&self, item: T) -> Result<(), TrySendError<T>> {
52 if let Some(strong) = Weak::upgrade(&self.weak) {
53 let mut shared = strong.borrow_mut();
54 if shared
55 .capacity
56 .is_some_and(|cap| cap.get() <= shared.buffer.len())
57 {
58 Err(TrySendError::Full(item))
59 } else {
60 shared.buffer.push_back(item);
61 shared.wake_receiver();
62 Ok(())
63 }
64 } else {
65 Err(TrySendError::Closed(item))
66 }
67 }
68
69 pub fn close_this_sender(&mut self) {
74 self.weak = Weak::new();
75 }
76
77 pub fn is_closed(&self) -> bool {
79 0 == self.weak.strong_count()
80 }
81}
82impl<T> Clone for Sender<T> {
83 fn clone(&self) -> Self {
84 Self {
85 weak: self.weak.clone(),
86 }
87 }
88}
89impl<T> Drop for Sender<T> {
90 fn drop(&mut self) {
91 if let Some(strong) = self.weak.upgrade() {
94 strong.borrow_mut().wake_receiver();
95 }
96 }
97}
98
99impl<T> Sink<T> for Sender<T> {
100 type Error = TrySendError<Option<T>>;
101
102 fn poll_ready(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
103 if let Some(strong) = Weak::upgrade(&self.weak) {
104 let mut shared = strong.borrow_mut();
105 if shared
106 .capacity
107 .is_some_and(|cap| cap.get() <= shared.buffer.len())
108 {
109 shared.send_wakers.push(ctx.waker().clone());
111 Poll::Pending
112 } else {
113 Poll::Ready(Ok(()))
115 }
116 } else {
117 Poll::Ready(Err(TrySendError::Closed(None)))
119 }
120 }
121
122 fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
123 self.try_send(item).map_err(|e| match e {
124 TrySendError::Full(item) => TrySendError::Full(Some(item)),
125 TrySendError::Closed(item) => TrySendError::Closed(Some(item)),
126 })
127 }
128
129 fn poll_flush(self: Pin<&mut Self>, _ctx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
130 Poll::Ready(Ok(()))
131 }
132
133 fn poll_close(
134 mut self: Pin<&mut Self>,
135 ctx: &mut Context<'_>,
136 ) -> Poll<Result<(), Self::Error>> {
137 ready!(self.as_mut().poll_flush(ctx))?;
138 Pin::into_inner(self).close_this_sender();
139 Poll::Ready(Ok(()))
140 }
141}
142
143pub struct Receiver<T> {
145 strong: Rc<RefCell<Shared<T>>>,
146}
147impl<T> Receiver<T> {
148 pub async fn recv(&mut self) -> Option<T> {
150 std::future::poll_fn(|ctx| self.poll_recv(ctx)).await
151 }
152
153 pub fn poll_recv(&mut self, ctx: &Context<'_>) -> Poll<Option<T>> {
156 let mut shared = self.strong.borrow_mut();
157 if let Some(value) = shared.buffer.pop_front() {
158 shared.wake_sender();
159 Poll::Ready(Some(value))
160 } else if 0 == Rc::weak_count(&self.strong) {
161 Poll::Ready(None) } else {
163 shared.recv_waker = Some(ctx.waker().clone());
164 Poll::Pending
165 }
166 }
167
168 pub fn close(&mut self) {
170 assert_eq!(
171 1,
172 Rc::strong_count(&self.strong),
173 "BUG: receiver has non-exclusive Rc."
174 );
175
176 let new_shared = {
177 let mut shared = self.strong.borrow_mut();
178 shared.wake_all_senders();
179
180 Shared {
181 buffer: std::mem::take(&mut shared.buffer),
182 ..Default::default()
183 }
184 };
185 self.strong = Rc::new(RefCell::new(new_shared));
186 }
188}
189impl<T> Drop for Receiver<T> {
190 fn drop(&mut self) {
191 self.close()
192 }
193}
194impl<T> Stream for Receiver<T> {
195 type Item = T;
196
197 fn poll_next(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
198 self.poll_recv(ctx)
199 }
200}
201
202struct Shared<T> {
204 buffer: VecDeque<T>,
205 capacity: Option<NonZeroUsize>,
206 send_wakers: SmallVec<[Waker; 1]>,
207 recv_waker: Option<Waker>,
208}
209impl<T> Shared<T> {
210 pub fn wake_sender(&mut self) {
212 if let Some(waker) = self.send_wakers.pop() {
213 waker.wake();
214 }
215 }
216 pub fn wake_all_senders(&mut self) {
218 self.send_wakers.drain(..).for_each(Waker::wake);
219 }
220 pub fn wake_receiver(&mut self) {
222 if let Some(waker) = self.recv_waker.take() {
223 waker.wake();
224 }
225 }
226}
227impl<T> Default for Shared<T> {
228 fn default() -> Self {
229 let (buffer, capacity, send_wakers, recv_waker) = Default::default();
230 Self {
231 buffer,
232 capacity,
233 send_wakers,
234 recv_waker,
235 }
236 }
237}
238
239pub fn channel<T>(capacity: Option<NonZeroUsize>) -> (Sender<T>, Receiver<T>) {
241 let (buffer, send_wakers, recv_waker) = Default::default();
242 let shared = Rc::new(RefCell::new(Shared {
243 buffer,
244 capacity,
245 send_wakers,
246 recv_waker,
247 }));
248 let sender = Sender {
249 weak: Rc::downgrade(&shared),
250 };
251 let receiver = Receiver { strong: shared };
252 (sender, receiver)
253}
254
255pub fn bounded<T>(capacity: usize) -> (Sender<T>, Receiver<T>) {
257 let capacity = NonZeroUsize::new(capacity);
258 assert!(capacity.is_some(), "Capacity cannot be zero.");
259 channel(capacity)
260}
261
262pub fn unbounded<T>() -> (Sender<T>, Receiver<T>) {
264 channel(None)
265}
266
#[cfg(test)]
mod test {
    use futures::StreamExt;
    use rand::Rng;
    use tokio::task::LocalSet;
    use web_time::Duration;

    use super::*;

    /// Sleeps for a random number of milliseconds in `0..max_millis`.
    async fn delay(max_millis: u64) {
        let pause = Duration::from_millis(rand::thread_rng().gen_range(0..max_millis));
        tokio::time::sleep(pause).await;
    }

    /// Two sends on one sender can be outstanding at once, and both values arrive.
    #[crate::test]
    async fn test_send_multiple_outstanding() {
        let (sender, receiver) = bounded::<u64>(10);

        let first = sender.send(123);
        let second = sender.send(234);
        futures::future::try_join(first, second).await.unwrap();
        drop(sender);

        let mut received: Vec<_> = receiver.collect().await;
        received.sort_unstable();
        assert_eq!([123, 234], &*received);
    }

    /// One producer, one consumer, random pacing: values must arrive in send order.
    #[crate::test]
    async fn test_spsc_random() {
        let runs = (0..1_000).map(|_| async {
            let (sender, mut receiver) = bounded::<u64>(10);
            let local = LocalSet::new();

            local.spawn_local(async move {
                for value in 0..100 {
                    sender.send(value).await.unwrap();
                    delay(4).await;
                }
            });
            local.spawn_local(async move {
                // Delay once before the first receive.
                delay(5).await;

                let mut expected = 0;
                while let Some(value) = receiver.recv().await {
                    assert_eq!(expected, value);
                    expected += 1;
                    delay(5).await;
                }
                assert_eq!(100, expected);
            });

            local.await;
        });
        futures::future::join_all(runs).await;
    }

    /// Three producers, one consumer, random pacing: every value arrives exactly once.
    #[crate::test]
    async fn test_mpsc_random() {
        let runs = (0..1_000).map(|_| async {
            let (sender, mut receiver) = bounded::<u64>(30);
            let local = LocalSet::new();

            // Producers are spawned in order, sending 0..100, 100..200 and 200..300.
            let senders = [sender.clone(), sender.clone(), sender];
            for (producer, start) in senders.into_iter().zip([0, 100, 200]) {
                local.spawn_local(async move {
                    for value in start..start + 100 {
                        producer.send(value).await.unwrap();
                        delay(5).await;
                    }
                });
            }
            local.spawn_local(async move {
                // Delay once before the first receive.
                delay(1).await;

                let mut received = Vec::new();
                while let Some(value) = receiver.next().await {
                    received.push(value);
                    delay(1).await;
                }
                assert_eq!(300, received.len());
                received.sort_unstable();
                for (index, &value) in received.iter().enumerate() {
                    assert_eq!(index as u64, value);
                }
            });

            local.await;
        });
        futures::future::join_all(runs).await;
    }

    /// Feeds the receiver back into its own sender through `SinkExt::send_all`.
    #[crate::test]
    async fn test_stream_sink_loop() {
        use futures::{SinkExt, StreamExt};

        const N: usize = 100;

        let (mut sender, mut receiver) = unbounded::<usize>();
        sender.send(0).await.unwrap();
        let mut incremented = receiver.by_ref().map(|x| x + 1).map(Ok).take(N);
        sender.send_all(&mut incremented).await.unwrap();
        assert_eq!(Some(N), receiver.recv().await);
    }
}