Skip to main content

sinktools/
lazy_sink_source.rs

1//! [`LazySinkSource`], and related items.
2#![cfg(feature = "std")]
3
4use std::cell::RefCell;
5use std::marker::PhantomData;
6use std::pin::Pin;
7use std::rc::Rc;
8use std::sync::{Arc, Mutex};
9use std::task::{Context, Poll, Wake, Waker};
10
11use futures_util::{Sink, Stream, ready};
12
/// Fans a single wake-up out to several registered [`Waker`]s.
///
/// Every caller that polls the shared initialization future registers its own waker here. When
/// the future signals progress, all registered wakers are woken and the list is cleared, so each
/// caller has to register again on its next poll.
struct MultiWaker {
    /// Wakers registered since the last wake-up.
    wakers: Mutex<Vec<Waker>>,
}

impl MultiWaker {
    /// Creates a `MultiWaker` that initially notifies only `waker`.
    fn new(waker: &Waker) -> Self {
        MultiWaker {
            wakers: Mutex::new(vec![waker.clone()]),
        }
    }

    /// Registers `waker` to be notified on the next wake-up.
    ///
    /// A waker for which [`Waker::will_wake`] already matches a registered waker is skipped, so
    /// polling repeatedly from the same task between wake-ups does not grow the list or trigger
    /// redundant wakes. `will_wake` is best-effort; a missed match only results in a harmless
    /// duplicate entry.
    fn push(&self, waker: &Waker) {
        let mut guard = self.wakers.lock().unwrap();
        if !guard.iter().any(|registered| registered.will_wake(waker)) {
            guard.push(waker.clone());
        }
    }
}

impl Wake for MultiWaker {
    /// Wakes every registered waker and leaves the registration list empty.
    fn wake(self: Arc<Self>) {
        // Move the list out while holding the lock; the guard is a temporary that is released at
        // the end of this statement, so the wakers below run without the lock held and may
        // register again immediately.
        let wakers = std::mem::take(&mut *self.wakers.lock().unwrap());

        for waker in wakers {
            waker.wake();
        }
    }
}
44
45enum SharedState<Fut, St, Si, Item> {
46    Uninit {
47        future: Pin<Box<Fut>>,
48    },
49    Thunkulating {
50        future: Pin<Box<Fut>>,
51        item: Option<Item>,
52        multi_waker: Option<Arc<MultiWaker>>,
53    },
54    Done {
55        stream: Pin<Box<St>>,
56        sink: Pin<Box<Si>>,
57        buf: Option<Item>,
58    },
59    Taken,
60}
61
62/// A lazy sink-source that can be split into a sink and a source. The internal state is initialized when the first item is attempted to be pulled from the source half, or when the first item is sent to the sink half.
63pub struct LazySinkSource<Fut, St, Si, Item, Error> {
64    state: Rc<RefCell<SharedState<Fut, St, Si, Item>>>,
65    _phantom: PhantomData<Error>,
66}
67
68impl<Fut, St, Si, Item, Error> LazySinkSource<Fut, St, Si, Item, Error> {
69    /// Creates a new `LazySinkSource` with the given initialization future.
70    pub fn new(future: Fut) -> Self {
71        Self {
72            state: Rc::new(RefCell::new(SharedState::Uninit {
73                future: Box::pin(future),
74            })),
75            _phantom: PhantomData,
76        }
77    }
78
79    /// Splits into a sink and stream that share the same underlying connection.
80    pub fn split(
81        self,
82    ) -> (
83        LazySinkHalf<Fut, St, Si, Item, Error>,
84        LazySourceHalf<Fut, St, Si, Item, Error>,
85    ) {
86        let sink = LazySinkHalf {
87            state: Rc::clone(&self.state),
88            _phantom: PhantomData,
89        };
90        let stream = LazySourceHalf {
91            state: self.state,
92            _phantom: PhantomData,
93        };
94        (sink, stream)
95    }
96}
97
98/// Sink half of the SinkSource
99pub struct LazySinkHalf<Fut, St, Si, Item, Error> {
100    state: Rc<RefCell<SharedState<Fut, St, Si, Item>>>,
101    _phantom: PhantomData<Error>,
102}
103
104/// Stream half of the SinkSource
105pub struct LazySourceHalf<Fut, St, Si, Item, Error> {
106    state: Rc<RefCell<SharedState<Fut, St, Si, Item>>>,
107    _phantom: PhantomData<Error>,
108}
109
110impl<Fut, St, Si, Item, Error> Sink<Item> for LazySinkHalf<Fut, St, Si, Item, Error>
111where
112    Fut: Future<Output = Result<(St, Si), Error>>,
113    St: Stream,
114    Si: Sink<Item>,
115    Error: From<Si::Error>,
116{
117    type Error = Error;
118
119    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
120        let mut state = self.state.borrow_mut();
121
122        if let SharedState::Uninit { .. } = &*state {
123            return Poll::Ready(Ok(()));
124        }
125
126        if let SharedState::Thunkulating {
127            future,
128            item,
129            multi_waker,
130        } = &mut *state
131        {
132            let waker = if let Some(waker) = multi_waker {
133                waker.push(cx.waker());
134                Waker::from(waker.clone())
135            } else {
136                let waker = Arc::new(MultiWaker::new(cx.waker()));
137                *multi_waker = Some(waker.clone());
138                Waker::from(waker)
139            };
140
141            let mut new_context = Context::from_waker(&waker);
142
143            match future.as_mut().poll(&mut new_context) {
144                Poll::Ready(Ok((stream, sink))) => {
145                    let buf = item.take();
146                    *state = SharedState::Done {
147                        stream: Box::pin(stream),
148                        sink: Box::pin(sink),
149                        buf,
150                    };
151                }
152                Poll::Ready(Err(e)) => {
153                    return Poll::Ready(Err(e));
154                }
155                Poll::Pending => {
156                    return Poll::Pending;
157                }
158            }
159        }
160
161        if let SharedState::Done { sink, buf, .. } = &mut *state {
162            if buf.is_some() {
163                ready!(sink.as_mut().poll_ready(cx).map_err(From::from)?);
164                sink.as_mut().start_send(buf.take().unwrap())?;
165            }
166            let result = sink.as_mut().poll_ready(cx).map_err(From::from);
167            return result;
168        }
169
170        panic!("LazySinkHalf in invalid state.");
171    }
172
173    fn start_send(self: Pin<&mut Self>, item: Item) -> Result<(), Self::Error> {
174        let mut state = self.state.borrow_mut();
175
176        if let SharedState::Uninit { .. } = &*state {
177            let old_state = std::mem::replace(&mut *state, SharedState::Taken);
178            if let SharedState::Uninit { future } = old_state {
179                *state = SharedState::Thunkulating {
180                    future,
181                    item: Some(item),
182                    multi_waker: None,
183                };
184
185                return Ok(());
186            }
187        }
188
189        if let SharedState::Thunkulating { .. } = &mut *state {
190            panic!("LazySinkHalf not ready.");
191        }
192
193        if let SharedState::Done { sink, buf, .. } = &mut *state {
194            debug_assert!(buf.is_none());
195            let result = sink.as_mut().start_send(item).map_err(From::from);
196            return result;
197        }
198
199        panic!("LazySinkHalf not ready.");
200    }
201
202    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
203        let mut state = self.state.borrow_mut();
204
205        if let SharedState::Uninit { .. } = &*state {
206            return Poll::Ready(Ok(()));
207        }
208
209        if let SharedState::Thunkulating {
210            future,
211            item,
212            multi_waker,
213        } = &mut *state
214        {
215            let waker = if let Some(waker) = multi_waker {
216                waker.push(cx.waker());
217                Waker::from(waker.clone())
218            } else {
219                let waker = Arc::new(MultiWaker::new(cx.waker()));
220                *multi_waker = Some(waker.clone());
221                Waker::from(waker)
222            };
223
224            let mut new_context = Context::from_waker(&waker);
225
226            match future.as_mut().poll(&mut new_context) {
227                Poll::Ready(Ok((stream, sink))) => {
228                    let buf = item.take();
229                    *state = SharedState::Done {
230                        stream: Box::pin(stream),
231                        sink: Box::pin(sink),
232                        buf,
233                    };
234                }
235                Poll::Ready(Err(e)) => {
236                    return Poll::Ready(Err(e));
237                }
238                Poll::Pending => {
239                    return Poll::Pending;
240                }
241            }
242        }
243
244        if let SharedState::Done { sink, buf, .. } = &mut *state {
245            if buf.is_some() {
246                ready!(sink.as_mut().poll_ready(cx).map_err(From::from)?);
247                sink.as_mut().start_send(buf.take().unwrap())?;
248            }
249            let result = sink.as_mut().poll_flush(cx).map_err(From::from);
250            return result;
251        }
252
253        panic!("LazySinkHalf in invalid state.");
254    }
255
256    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
257        let mut state = self.state.borrow_mut();
258
259        if let SharedState::Uninit { .. } = &*state {
260            return Poll::Ready(Ok(()));
261        }
262
263        if let SharedState::Thunkulating {
264            future,
265            item,
266            multi_waker,
267        } = &mut *state
268        {
269            let waker = if let Some(waker) = multi_waker {
270                waker.push(cx.waker());
271                Waker::from(waker.clone())
272            } else {
273                let waker = Arc::new(MultiWaker::new(cx.waker()));
274                *multi_waker = Some(waker.clone());
275                Waker::from(waker)
276            };
277
278            let mut new_context = Context::from_waker(&waker);
279
280            match future.as_mut().poll(&mut new_context) {
281                Poll::Ready(Ok((stream, sink))) => {
282                    let buf = item.take();
283                    *state = SharedState::Done {
284                        stream: Box::pin(stream),
285                        sink: Box::pin(sink),
286                        buf,
287                    };
288                }
289                Poll::Ready(Err(e)) => {
290                    return Poll::Ready(Err(e));
291                }
292                Poll::Pending => {
293                    return Poll::Pending;
294                }
295            }
296        }
297
298        if let SharedState::Done { sink, buf, .. } = &mut *state {
299            if buf.is_some() {
300                ready!(sink.as_mut().poll_ready(cx).map_err(From::from)?);
301                sink.as_mut().start_send(buf.take().unwrap())?;
302            }
303            let result = sink.as_mut().poll_close(cx).map_err(From::from);
304            return result;
305        }
306
307        panic!("LazySinkHalf in invalid state.");
308    }
309}
310
311impl<Fut, St, Si, Item, Error> Stream for LazySourceHalf<Fut, St, Si, Item, Error>
312where
313    Fut: Future<Output = Result<(St, Si), Error>>,
314    St: Stream,
315    Si: Sink<Item>,
316{
317    type Item = St::Item;
318
319    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
320        let mut state = self.state.borrow_mut();
321
322        if let SharedState::Uninit { .. } = &*state {
323            let old_state = std::mem::replace(&mut *state, SharedState::Taken);
324            if let SharedState::Uninit { future } = old_state {
325                *state = SharedState::Thunkulating {
326                    future,
327                    item: None,
328                    multi_waker: None,
329                };
330            } else {
331                unreachable!();
332            }
333        }
334
335        if let SharedState::Thunkulating {
336            future,
337            item,
338            multi_waker,
339        } = &mut *state
340        {
341            let waker = if let Some(waker) = multi_waker {
342                waker.push(cx.waker());
343                Waker::from(waker.clone())
344            } else {
345                let waker = Arc::new(MultiWaker::new(cx.waker()));
346                *multi_waker = Some(waker.clone());
347                Waker::from(waker)
348            };
349
350            let mut new_context = Context::from_waker(&waker);
351
352            match future.as_mut().poll(&mut new_context) {
353                Poll::Ready(Ok((stream, sink))) => {
354                    let buf = item.take();
355                    *state = SharedState::Done {
356                        stream: Box::pin(stream),
357                        sink: Box::pin(sink),
358                        buf,
359                    };
360                }
361
362                Poll::Ready(Err(_)) => {
363                    return Poll::Ready(None);
364                }
365
366                Poll::Pending => {
367                    return Poll::Pending;
368                }
369            }
370        }
371
372        if let SharedState::Done { stream, .. } = &mut *state {
373            let result = stream.as_mut().poll_next(cx);
374            match &result {
375                Poll::Ready(Some(_)) => {}
376                Poll::Ready(None) => {}
377                Poll::Pending => {}
378            }
379            return result;
380        }
381
382        panic!("LazySourceHalf in invalid state.");
383    }
384}
385
#[cfg(test)]
mod test {
    use futures_util::{SinkExt, StreamExt};

    use super::*;

    /// Yields to the scheduler 20 times so that already-spawned local tasks get a chance to run.
    async fn settle() {
        for _ in 0..20 {
            tokio::task::yield_now().await
        }
    }

    /// The source half is polled first and therefore starts initialization; the sink half joins
    /// while the initialization future is still pending.
    #[tokio::test(flavor = "current_thread")]
    async fn tcp_stream_drives_initialization() {
        use tokio::net::{TcpListener, TcpStream};
        use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec};

        let (initialization_tx, initialization_rx) = tokio::sync::oneshot::channel::<()>();

        let local_set = tokio::task::LocalSet::new();
        local_set
            .run_until(async {
                let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
                let addr = listener.local_addr().unwrap();

                let lazy = LazySinkSource::new(async move {
                    // Signal that the initialization future has been polled at least once.
                    initialization_tx.send(()).unwrap();

                    let (connection, _) = listener.accept().await.unwrap();
                    let (read_half, write_half) = connection.into_split();
                    let framed_read = FramedRead::new(read_half, LengthDelimitedCodec::new());
                    let framed_write = FramedWrite::new(write_half, LengthDelimitedCodec::new());
                    Ok::<_, std::io::Error>((framed_read, framed_write))
                });

                let (mut sink, mut stream) = lazy.split();

                let stream_task = tokio::task::spawn_local(async move { stream.next().await });

                // Wait until the `stream.next()` call above has started driving initialization.
                initialization_rx.await.unwrap();

                let sink_task = tokio::task::spawn_local(async move {
                    SinkExt::send(&mut sink, bytes::Bytes::from("test2"))
                        .await
                        .unwrap();
                });

                // Give `sink_task` time to start waiting on the same initialization future.
                settle().await;

                // Connect so that the pending `accept()` inside the future can complete.
                let mut client = TcpStream::connect(addr).await.unwrap();
                let (client_rx, client_tx) = client.split();
                let mut client_tx = FramedWrite::new(client_tx, LengthDelimitedCodec::new());
                let mut client_rx = FramedRead::new(client_rx, LengthDelimitedCodec::new());

                // Give the effects of the completed initialization time to propagate.
                settle().await;

                // The client has not sent anything yet, so the stream cannot have an item.
                assert!(!stream_task.is_finished());

                // Send an item so the stream wakes up with something to yield.
                SinkExt::send(&mut client_tx, bytes::Bytes::from("test"))
                    .await
                    .unwrap();

                assert_eq!(&stream_task.await.unwrap().unwrap().unwrap()[..], b"test");
                sink_task.await.unwrap();

                assert_eq!(&client_rx.next().await.unwrap().unwrap()[..], b"test2");
            })
            .await;
    }

    /// The sink half sends first and therefore starts initialization; the source half joins
    /// while the initialization future is still pending.
    #[tokio::test(flavor = "current_thread")]
    async fn tcp_sink_drives_initialization() {
        use tokio::net::{TcpListener, TcpStream};
        use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec};

        let (initialization_tx, initialization_rx) = tokio::sync::oneshot::channel::<()>();

        let local_set = tokio::task::LocalSet::new();
        local_set
            .run_until(async {
                let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
                let addr = listener.local_addr().unwrap();

                let lazy = LazySinkSource::new(async move {
                    // Signal that the initialization future has been polled at least once.
                    initialization_tx.send(()).unwrap();

                    let (connection, _) = listener.accept().await.unwrap();
                    let (read_half, write_half) = connection.into_split();
                    let framed_read = FramedRead::new(read_half, LengthDelimitedCodec::new());
                    let framed_write = FramedWrite::new(write_half, LengthDelimitedCodec::new());
                    Ok::<_, std::io::Error>((framed_read, framed_write))
                });

                let (mut sink, mut stream) = lazy.split();

                let sink_task = tokio::task::spawn_local(async move {
                    SinkExt::send(&mut sink, bytes::Bytes::from("test2"))
                        .await
                        .unwrap();
                });

                // Wait until the send above has started driving initialization.
                initialization_rx.await.unwrap();

                let stream_task = tokio::task::spawn_local(async move { stream.next().await });

                // Give `stream_task` time to start waiting on the same initialization future.
                settle().await;

                // Nothing has connected to the listener yet, so the send cannot have completed.
                assert!(!sink_task.is_finished());

                // Connect so that the pending `accept()` inside the future can complete.
                let mut client = TcpStream::connect(addr).await.unwrap();
                let (client_rx, client_tx) = client.split();
                let mut client_tx = FramedWrite::new(client_tx, LengthDelimitedCodec::new());
                let mut client_rx = FramedRead::new(client_rx, LengthDelimitedCodec::new());

                // Give the effects of the completed initialization time to propagate.
                settle().await;

                // Initialization has completed, so the buffered item should have been sent.
                assert!(sink_task.is_finished());

                assert_eq!(&client_rx.next().await.unwrap().unwrap()[..], b"test2");

                // Send an item so the stream wakes up with something to yield.
                SinkExt::send(&mut client_tx, bytes::Bytes::from("test"))
                    .await
                    .unwrap();

                assert_eq!(&stream_task.await.unwrap().unwrap().unwrap()[..], b"test");
                sink_task.await.unwrap();
            })
            .await;
    }
}