sinktools/
lazy_sink_source.rs

1//! [`LazySinkSource`], and related items.
2
3use core::marker::PhantomData;
4use core::pin::Pin;
5use core::task::{Context, Poll, Waker};
6use std::cell::RefCell;
7use std::rc::Rc;
8use std::sync::{Arc, Mutex};
9use std::task::Wake;
10
11use futures_util::{Sink, Stream, ready};
12
/// A [`Wake`] implementation that fans a single wakeup out to several tasks.
///
/// The lazily-initialized future is shared by both halves, but a future only keeps the waker
/// from its most recent poll. Handing it a `MultiWaker` instead ensures every task that is
/// currently waiting on initialization gets woken.
struct MultiWaker {
    /// Wakers registered since the last wakeup; drained every time this waker fires.
    wakers: Mutex<Vec<Waker>>,
}

impl MultiWaker {
    /// Creates a multi-waker with `waker` as its only registered waker.
    fn new(waker: &Waker) -> Self {
        MultiWaker {
            wakers: Mutex::new(vec![waker.clone()]),
        }
    }

    /// Registers `waker` to be woken on the next wakeup.
    ///
    /// A waker that would wake the same task as an already-registered one is skipped, so a task
    /// that polls repeatedly while initialization is pending does not grow the list (or get
    /// woken multiple times per wakeup).
    fn push(&self, waker: &Waker) {
        let mut wakers = self.wakers.lock().unwrap();
        if !wakers.iter().any(|registered| registered.will_wake(waker)) {
            wakers.push(waker.clone());
        }
    }
}

impl Wake for MultiWaker {
    fn wake(self: Arc<Self>) {
        self.wake_by_ref();
    }

    /// Wakes and forgets every registered waker.
    ///
    /// The list is taken out under the lock and the lock is released before any waker runs, so
    /// a waker that re-registers (calls `push`) cannot deadlock.
    fn wake_by_ref(self: &Arc<Self>) {
        let pending = std::mem::take(&mut *self.wakers.lock().unwrap());
        for waker in pending {
            waker.wake();
        }
    }
}
44
45enum SharedState<Fut, St, Si, Item> {
46    Uninit {
47        future: Pin<Box<Fut>>,
48    },
49    Thunkulating {
50        future: Pin<Box<Fut>>,
51        item: Option<Item>,
52        multi_waker: Option<Arc<MultiWaker>>,
53    },
54    Done {
55        stream: Pin<Box<St>>,
56        sink: Pin<Box<Si>>,
57        buf: Option<Item>,
58    },
59    Taken,
60}
61
62/// A lazy sink-source that can be split into a sink and a source. The internal state is initialized when the first item is attempted to be pulled from the source half, or when the first item is sent to the sink half.
63pub struct LazySinkSource<Fut, St, Si, Item, Error> {
64    state: Rc<RefCell<SharedState<Fut, St, Si, Item>>>,
65    _phantom: PhantomData<Error>,
66}
67
68impl<Fut, St, Si, Item, Error> LazySinkSource<Fut, St, Si, Item, Error> {
69    /// Creates a new `LazySinkSource` with the given initialization future.
70    pub fn new(future: Fut) -> Self {
71        Self {
72            state: Rc::new(RefCell::new(SharedState::Uninit {
73                future: Box::pin(future),
74            })),
75            _phantom: PhantomData,
76        }
77    }
78
79    #[expect(
80        clippy::type_complexity,
81        reason = "this type is actually fine and not too complex."
82    )]
83    /// Splits into a sink and stream that share the same underlying connection.
84    pub fn split(
85        self,
86    ) -> (
87        LazySinkHalf<Fut, St, Si, Item, Error>,
88        LazySourceHalf<Fut, St, Si, Item, Error>,
89    ) {
90        let sink = LazySinkHalf {
91            state: Rc::clone(&self.state),
92            _phantom: PhantomData,
93        };
94        let stream = LazySourceHalf {
95            state: self.state,
96            _phantom: PhantomData,
97        };
98        (sink, stream)
99    }
100}
101
102/// Sink half of the SinkSource
103pub struct LazySinkHalf<Fut, St, Si, Item, Error> {
104    state: Rc<RefCell<SharedState<Fut, St, Si, Item>>>,
105    _phantom: PhantomData<Error>,
106}
107
108/// Stream half of the SinkSource
109pub struct LazySourceHalf<Fut, St, Si, Item, Error> {
110    state: Rc<RefCell<SharedState<Fut, St, Si, Item>>>,
111    _phantom: PhantomData<Error>,
112}
113
114impl<Fut, St, Si, Item, Error> Sink<Item> for LazySinkHalf<Fut, St, Si, Item, Error>
115where
116    Fut: Future<Output = Result<(St, Si), Error>>,
117    St: Stream,
118    Si: Sink<Item>,
119    Error: From<Si::Error>,
120{
121    type Error = Error;
122
123    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
124        let mut state = self.state.borrow_mut();
125
126        if let SharedState::Uninit { .. } = &*state {
127            return Poll::Ready(Ok(()));
128        }
129
130        if let SharedState::Thunkulating {
131            future,
132            item,
133            multi_waker,
134        } = &mut *state
135        {
136            let waker = if let Some(waker) = multi_waker {
137                waker.push(cx.waker());
138                Waker::from(waker.clone())
139            } else {
140                let waker = Arc::new(MultiWaker::new(cx.waker()));
141                *multi_waker = Some(waker.clone());
142                Waker::from(waker)
143            };
144
145            let mut new_context = Context::from_waker(&waker);
146
147            match future.as_mut().poll(&mut new_context) {
148                Poll::Ready(Ok((stream, sink))) => {
149                    let buf = item.take();
150                    *state = SharedState::Done {
151                        stream: Box::pin(stream),
152                        sink: Box::pin(sink),
153                        buf,
154                    };
155                }
156                Poll::Ready(Err(e)) => {
157                    return Poll::Ready(Err(e));
158                }
159                Poll::Pending => {
160                    return Poll::Pending;
161                }
162            }
163        }
164
165        if let SharedState::Done { sink, buf, .. } = &mut *state {
166            if buf.is_some() {
167                ready!(sink.as_mut().poll_ready(cx).map_err(From::from)?);
168                sink.as_mut().start_send(buf.take().unwrap())?;
169            }
170            let result = sink.as_mut().poll_ready(cx).map_err(From::from);
171            return result;
172        }
173
174        panic!("LazySinkHalf in invalid state.");
175    }
176
177    fn start_send(self: Pin<&mut Self>, item: Item) -> Result<(), Self::Error> {
178        let mut state = self.state.borrow_mut();
179
180        if let SharedState::Uninit { .. } = &*state {
181            let old_state = std::mem::replace(&mut *state, SharedState::Taken);
182            if let SharedState::Uninit { future } = old_state {
183                *state = SharedState::Thunkulating {
184                    future,
185                    item: Some(item),
186                    multi_waker: None,
187                };
188
189                return Ok(());
190            }
191        }
192
193        if let SharedState::Thunkulating { .. } = &mut *state {
194            panic!("LazySinkHalf not ready.");
195        }
196
197        if let SharedState::Done { sink, buf, .. } = &mut *state {
198            debug_assert!(buf.is_none());
199            let result = sink.as_mut().start_send(item).map_err(From::from);
200            return result;
201        }
202
203        panic!("LazySinkHalf not ready.");
204    }
205
206    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
207        let mut state = self.state.borrow_mut();
208
209        if let SharedState::Uninit { .. } = &*state {
210            return Poll::Ready(Ok(()));
211        }
212
213        if let SharedState::Thunkulating {
214            future,
215            item,
216            multi_waker,
217        } = &mut *state
218        {
219            let waker = if let Some(waker) = multi_waker {
220                waker.push(cx.waker());
221                Waker::from(waker.clone())
222            } else {
223                let waker = Arc::new(MultiWaker::new(cx.waker()));
224                *multi_waker = Some(waker.clone());
225                Waker::from(waker)
226            };
227
228            let mut new_context = Context::from_waker(&waker);
229
230            match future.as_mut().poll(&mut new_context) {
231                Poll::Ready(Ok((stream, sink))) => {
232                    let buf = item.take();
233                    *state = SharedState::Done {
234                        stream: Box::pin(stream),
235                        sink: Box::pin(sink),
236                        buf,
237                    };
238                }
239                Poll::Ready(Err(e)) => {
240                    return Poll::Ready(Err(e));
241                }
242                Poll::Pending => {
243                    return Poll::Pending;
244                }
245            }
246        }
247
248        if let SharedState::Done { sink, buf, .. } = &mut *state {
249            if buf.is_some() {
250                ready!(sink.as_mut().poll_ready(cx).map_err(From::from)?);
251                sink.as_mut().start_send(buf.take().unwrap())?;
252            }
253            let result = sink.as_mut().poll_flush(cx).map_err(From::from);
254            return result;
255        }
256
257        panic!("LazySinkHalf in invalid state.");
258    }
259
260    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
261        let mut state = self.state.borrow_mut();
262
263        if let SharedState::Uninit { .. } = &*state {
264            return Poll::Ready(Ok(()));
265        }
266
267        if let SharedState::Thunkulating {
268            future,
269            item,
270            multi_waker,
271        } = &mut *state
272        {
273            let waker = if let Some(waker) = multi_waker {
274                waker.push(cx.waker());
275                Waker::from(waker.clone())
276            } else {
277                let waker = Arc::new(MultiWaker::new(cx.waker()));
278                *multi_waker = Some(waker.clone());
279                Waker::from(waker)
280            };
281
282            let mut new_context = Context::from_waker(&waker);
283
284            match future.as_mut().poll(&mut new_context) {
285                Poll::Ready(Ok((stream, sink))) => {
286                    let buf = item.take();
287                    *state = SharedState::Done {
288                        stream: Box::pin(stream),
289                        sink: Box::pin(sink),
290                        buf,
291                    };
292                }
293                Poll::Ready(Err(e)) => {
294                    return Poll::Ready(Err(e));
295                }
296                Poll::Pending => {
297                    return Poll::Pending;
298                }
299            }
300        }
301
302        if let SharedState::Done { sink, buf, .. } = &mut *state {
303            if buf.is_some() {
304                ready!(sink.as_mut().poll_ready(cx).map_err(From::from)?);
305                sink.as_mut().start_send(buf.take().unwrap())?;
306            }
307            let result = sink.as_mut().poll_close(cx).map_err(From::from);
308            return result;
309        }
310
311        panic!("LazySinkHalf in invalid state.");
312    }
313}
314
315impl<Fut, St, Si, Item, Error> Stream for LazySourceHalf<Fut, St, Si, Item, Error>
316where
317    Fut: Future<Output = Result<(St, Si), Error>>,
318    St: Stream,
319    Si: Sink<Item>,
320{
321    type Item = St::Item;
322
323    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
324        let mut state = self.state.borrow_mut();
325
326        if let SharedState::Uninit { .. } = &*state {
327            let old_state = std::mem::replace(&mut *state, SharedState::Taken);
328            if let SharedState::Uninit { future } = old_state {
329                *state = SharedState::Thunkulating {
330                    future,
331                    item: None,
332                    multi_waker: None,
333                };
334            } else {
335                unreachable!();
336            }
337        }
338
339        if let SharedState::Thunkulating {
340            future,
341            item,
342            multi_waker,
343        } = &mut *state
344        {
345            let waker = if let Some(waker) = multi_waker {
346                waker.push(cx.waker());
347                Waker::from(waker.clone())
348            } else {
349                let waker = Arc::new(MultiWaker::new(cx.waker()));
350                *multi_waker = Some(waker.clone());
351                Waker::from(waker)
352            };
353
354            let mut new_context = Context::from_waker(&waker);
355
356            match future.as_mut().poll(&mut new_context) {
357                Poll::Ready(Ok((stream, sink))) => {
358                    let buf = item.take();
359                    *state = SharedState::Done {
360                        stream: Box::pin(stream),
361                        sink: Box::pin(sink),
362                        buf,
363                    };
364                }
365
366                Poll::Ready(Err(_)) => {
367                    return Poll::Ready(None);
368                }
369
370                Poll::Pending => {
371                    return Poll::Pending;
372                }
373            }
374        }
375
376        if let SharedState::Done { stream, .. } = &mut *state {
377            let result = stream.as_mut().poll_next(cx);
378            match &result {
379                Poll::Ready(Some(_)) => {}
380                Poll::Ready(None) => {}
381                Poll::Pending => {}
382            }
383            return result;
384        }
385
386        panic!("LazySourceHalf in invalid state.");
387    }
388}
389
// Integration tests: a TCP accept inside the initialization future is driven first by one half,
// while the other half waits on the same in-flight initialization.
#[cfg(test)]
mod test {
    use futures_util::{SinkExt, StreamExt};

    use super::*;

    // The stream half starts initialization; the sink half joins while it is still pending.
    #[tokio::test(flavor = "current_thread")]
    async fn tcp_stream_drives_initialization() {
        use tokio::net::{TcpListener, TcpStream};
        use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec};

        let (initialization_tx, initialization_rx) = tokio::sync::oneshot::channel::<()>();

        let local = tokio::task::LocalSet::new();
        local
            .run_until(async {
                let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
                let addr = listener.local_addr().unwrap();

                let sink_source = LazySinkSource::new(async move {
                    // initialization is at least partially started now.
                    initialization_tx.send(()).unwrap();

                    let (stream, _) = listener.accept().await.unwrap();
                    let (rx, tx) = stream.into_split();
                    let fr = FramedRead::new(rx, LengthDelimitedCodec::new());
                    let fw = FramedWrite::new(tx, LengthDelimitedCodec::new());
                    Ok::<_, std::io::Error>((fr, fw))
                });

                let (mut sink, mut stream) = sink_source.split();

                let stream_task = tokio::task::spawn_local(async move { stream.next().await });

                initialization_rx.await.unwrap(); // ensure that the runtime starts driving initialization via the stream.next() call.

                let sink_task = tokio::task::spawn_local(async move {
                    SinkExt::send(&mut sink, bytes::Bytes::from("test2"))
                        .await
                        .unwrap();
                });

                // try to be really sure that the above sink_task is waiting on the same future to be resolved.
                for _ in 0..20 {
                    tokio::task::yield_now().await
                }

                // trigger further initialization of the future.
                let mut socket = TcpStream::connect(addr).await.unwrap();
                let (client_rx, client_tx) = socket.split();
                let mut client_tx = FramedWrite::new(client_tx, LengthDelimitedCodec::new());
                let mut client_rx = FramedRead::new(client_rx, LengthDelimitedCodec::new());

                // try to be really sure that the effects of the above initialization completing are propagated.
                for _ in 0..20 {
                    tokio::task::yield_now().await
                }

                assert!(!stream_task.is_finished()); // We haven't sent anything yet, so the stream should definitely not be resolved now.

                // Now actually send an item so that the stream will wake up and have an item ready to pull from it.
                SinkExt::send(&mut client_tx, bytes::Bytes::from("test"))
                    .await
                    .unwrap();

                assert_eq!(&stream_task.await.unwrap().unwrap().unwrap()[..], b"test");
                sink_task.await.unwrap();

                // The item sent by the sink half arrives at the client.
                assert_eq!(&client_rx.next().await.unwrap().unwrap()[..], b"test2");
            })
            .await;
    }

    // The sink half starts initialization (with its item buffered); the stream half joins while
    // it is still pending.
    #[tokio::test(flavor = "current_thread")]
    async fn tcp_sink_drives_initialization() {
        use tokio::net::{TcpListener, TcpStream};
        use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec};

        let (initialization_tx, initialization_rx) = tokio::sync::oneshot::channel::<()>();

        let local = tokio::task::LocalSet::new();
        local
            .run_until(async {
                let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
                let addr = listener.local_addr().unwrap();

                let sink_source = LazySinkSource::new(async move {
                    // initialization is at least partially started now.
                    initialization_tx.send(()).unwrap();

                    let (stream, _) = listener.accept().await.unwrap();
                    let (rx, tx) = stream.into_split();
                    let fr = FramedRead::new(rx, LengthDelimitedCodec::new());
                    let fw = FramedWrite::new(tx, LengthDelimitedCodec::new());
                    Ok::<_, std::io::Error>((fr, fw))
                });

                let (mut sink, mut stream) = sink_source.split();

                let sink_task = tokio::task::spawn_local(async move {
                    SinkExt::send(&mut sink, bytes::Bytes::from("test2"))
                        .await
                        .unwrap();
                });

                initialization_rx.await.unwrap(); // ensure that the runtime starts driving initialization via the sink's send() call.

                let stream_task = tokio::task::spawn_local(async move { stream.next().await });

                // try to be really sure that the above stream_task is waiting on the same future to be resolved.
                for _ in 0..20 {
                    tokio::task::yield_now().await
                }

                assert!(!sink_task.is_finished()); // No client has connected yet, so initialization (and therefore the send) cannot have completed.

                // trigger further initialization of the future.
                let mut socket = TcpStream::connect(addr).await.unwrap();
                let (client_rx, client_tx) = socket.split();
                let mut client_tx = FramedWrite::new(client_tx, LengthDelimitedCodec::new());
                let mut client_rx = FramedRead::new(client_rx, LengthDelimitedCodec::new());

                // try to be really sure that the effects of the above initialization completing are propagated.
                for _ in 0..20 {
                    tokio::task::yield_now().await
                }

                assert!(sink_task.is_finished()); // Initialization has completed, so the buffered item has been written and the send has finished.

                assert_eq!(&client_rx.next().await.unwrap().unwrap()[..], b"test2");

                // Now actually send an item so that the stream will wake up and have an item ready to pull from it.
                SinkExt::send(&mut client_tx, bytes::Bytes::from("test"))
                    .await
                    .unwrap();

                assert_eq!(&stream_task.await.unwrap().unwrap().unwrap()[..], b"test");
                sink_task.await.unwrap();
            })
            .await;
    }
}