1use core::marker::PhantomData;
4use core::pin::Pin;
5use core::task::{Context, Poll, Waker};
6use std::cell::RefCell;
7use std::rc::Rc;
8use std::sync::{Arc, Mutex};
9use std::task::Wake;
10
11use futures_util::{Sink, Stream, ready};
12
/// A waker that fans one wake-up out to every task waiting on the shared
/// initialization future.
///
/// A future only has to wake the waker passed to its most recent `poll`, but both
/// halves of a `LazySinkSource` may poll the same initialization future from
/// different tasks. This type collects the waker of every task that polled since the
/// last wake-up and wakes all of them together.
struct MultiWaker {
    /// Wakers registered since the last wake-up.
    wakers: Mutex<Vec<Waker>>,
}

impl MultiWaker {
    /// Creates a `MultiWaker` whose only registered waker is `waker`.
    fn new(waker: &Waker) -> Self {
        MultiWaker {
            wakers: Mutex::new(vec![waker.clone()]),
        }
    }

    /// Registers `waker` to be woken on the next wake-up.
    ///
    /// A waker that would wake the same task as one already registered (per
    /// [`Waker::will_wake`]) is skipped. Without this check, every re-poll from the
    /// same task appends another clone and the list grows until the next wake-up.
    fn push(&self, waker: &Waker) {
        let mut guard = self.wakers.lock().unwrap();
        if !guard.iter().any(|registered| registered.will_wake(waker)) {
            guard.push(waker.clone());
        }
    }
}

impl Wake for MultiWaker {
    /// Wakes every registered waker and clears the registration list.
    fn wake(self: Arc<Self>) {
        // Move the list out first so the lock is released before running any
        // foreign wake code.
        let wakers = std::mem::take(&mut *self.wakers.lock().unwrap());
        wakers.into_iter().for_each(Waker::wake);
    }
}
44
/// State shared by the sink half and the stream half through an `Rc<RefCell<..>>`.
///
/// Moves from `Uninit` to `Thunkulating` on the first `start_send` of the sink half or
/// the first `poll_next` of the stream half. Moves to `Done` once the initialization
/// future resolves with `Ok((stream, sink))`.
enum SharedState<Fut, St, Si, Item> {
    /// The initialization future has not been polled yet.
    Uninit {
        future: Pin<Box<Fut>>,
    },
    /// The initialization future is being polled by one or both halves.
    Thunkulating {
        future: Pin<Box<Fut>>,
        // Item accepted by `start_send` before the inner sink existed.
        item: Option<Item>,
        // Shared waker passed to `future` so that every polling task gets woken;
        // created by the first poll.
        multi_waker: Option<Arc<MultiWaker>>,
    },
    /// The initialization future produced the inner stream and sink.
    Done {
        stream: Pin<Box<St>>,
        sink: Pin<Box<Si>>,
        // `item` carried over from `Thunkulating`, not yet handed to `sink`.
        buf: Option<Item>,
    },
    /// Data-less placeholder, used while moving the future out of `Uninit` with
    /// `std::mem::replace`.
    Taken,
}
61
62pub struct LazySinkSource<Fut, St, Si, Item, Error> {
64 state: Rc<RefCell<SharedState<Fut, St, Si, Item>>>,
65 _phantom: PhantomData<Error>,
66}
67
68impl<Fut, St, Si, Item, Error> LazySinkSource<Fut, St, Si, Item, Error> {
69 pub fn new(future: Fut) -> Self {
71 Self {
72 state: Rc::new(RefCell::new(SharedState::Uninit {
73 future: Box::pin(future),
74 })),
75 _phantom: PhantomData,
76 }
77 }
78
79 #[expect(
80 clippy::type_complexity,
81 reason = "this type is actually fine and not too complex."
82 )]
83 pub fn split(
85 self,
86 ) -> (
87 LazySinkHalf<Fut, St, Si, Item, Error>,
88 LazySourceHalf<Fut, St, Si, Item, Error>,
89 ) {
90 let sink = LazySinkHalf {
91 state: Rc::clone(&self.state),
92 _phantom: PhantomData,
93 };
94 let stream = LazySourceHalf {
95 state: self.state,
96 _phantom: PhantomData,
97 };
98 (sink, stream)
99 }
100}
101
/// The sending half of a `LazySinkSource`, created by `LazySinkSource::split`.
///
/// While the connection is uninitialized, `poll_ready` reports ready. The first item passed
/// to `start_send` is stored next to the initialization future until the inner sink exists.
pub struct LazySinkHalf<Fut, St, Si, Item, Error> {
    /// State shared with the matching `LazySourceHalf`.
    state: Rc<RefCell<SharedState<Fut, St, Si, Item>>>,
    /// Marks the `Error` type parameter; no `Error` value is stored.
    _phantom: PhantomData<Error>,
}
107
/// The receiving half of a `LazySinkSource`, created by `LazySinkSource::split`.
///
/// Polling it starts initialization if nothing has started it yet. If the initialization
/// future fails, the error is discarded and the stream yields `None`.
pub struct LazySourceHalf<Fut, St, Si, Item, Error> {
    /// State shared with the matching `LazySinkHalf`.
    state: Rc<RefCell<SharedState<Fut, St, Si, Item>>>,
    /// Marks the `Error` type parameter; no `Error` value is stored.
    _phantom: PhantomData<Error>,
}
113
114impl<Fut, St, Si, Item, Error> Sink<Item> for LazySinkHalf<Fut, St, Si, Item, Error>
115where
116 Fut: Future<Output = Result<(St, Si), Error>>,
117 St: Stream,
118 Si: Sink<Item>,
119 Error: From<Si::Error>,
120{
121 type Error = Error;
122
123 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
124 let mut state = self.state.borrow_mut();
125
126 if let SharedState::Uninit { .. } = &*state {
127 return Poll::Ready(Ok(()));
128 }
129
130 if let SharedState::Thunkulating {
131 future,
132 item,
133 multi_waker,
134 } = &mut *state
135 {
136 let waker = if let Some(waker) = multi_waker {
137 waker.push(cx.waker());
138 Waker::from(waker.clone())
139 } else {
140 let waker = Arc::new(MultiWaker::new(cx.waker()));
141 *multi_waker = Some(waker.clone());
142 Waker::from(waker)
143 };
144
145 let mut new_context = Context::from_waker(&waker);
146
147 match future.as_mut().poll(&mut new_context) {
148 Poll::Ready(Ok((stream, sink))) => {
149 let buf = item.take();
150 *state = SharedState::Done {
151 stream: Box::pin(stream),
152 sink: Box::pin(sink),
153 buf,
154 };
155 }
156 Poll::Ready(Err(e)) => {
157 return Poll::Ready(Err(e));
158 }
159 Poll::Pending => {
160 return Poll::Pending;
161 }
162 }
163 }
164
165 if let SharedState::Done { sink, buf, .. } = &mut *state {
166 if buf.is_some() {
167 ready!(sink.as_mut().poll_ready(cx).map_err(From::from)?);
168 sink.as_mut().start_send(buf.take().unwrap())?;
169 }
170 let result = sink.as_mut().poll_ready(cx).map_err(From::from);
171 return result;
172 }
173
174 panic!("LazySinkHalf in invalid state.");
175 }
176
177 fn start_send(self: Pin<&mut Self>, item: Item) -> Result<(), Self::Error> {
178 let mut state = self.state.borrow_mut();
179
180 if let SharedState::Uninit { .. } = &*state {
181 let old_state = std::mem::replace(&mut *state, SharedState::Taken);
182 if let SharedState::Uninit { future } = old_state {
183 *state = SharedState::Thunkulating {
184 future,
185 item: Some(item),
186 multi_waker: None,
187 };
188
189 return Ok(());
190 }
191 }
192
193 if let SharedState::Thunkulating { .. } = &mut *state {
194 panic!("LazySinkHalf not ready.");
195 }
196
197 if let SharedState::Done { sink, buf, .. } = &mut *state {
198 debug_assert!(buf.is_none());
199 let result = sink.as_mut().start_send(item).map_err(From::from);
200 return result;
201 }
202
203 panic!("LazySinkHalf not ready.");
204 }
205
206 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
207 let mut state = self.state.borrow_mut();
208
209 if let SharedState::Uninit { .. } = &*state {
210 return Poll::Ready(Ok(()));
211 }
212
213 if let SharedState::Thunkulating {
214 future,
215 item,
216 multi_waker,
217 } = &mut *state
218 {
219 let waker = if let Some(waker) = multi_waker {
220 waker.push(cx.waker());
221 Waker::from(waker.clone())
222 } else {
223 let waker = Arc::new(MultiWaker::new(cx.waker()));
224 *multi_waker = Some(waker.clone());
225 Waker::from(waker)
226 };
227
228 let mut new_context = Context::from_waker(&waker);
229
230 match future.as_mut().poll(&mut new_context) {
231 Poll::Ready(Ok((stream, sink))) => {
232 let buf = item.take();
233 *state = SharedState::Done {
234 stream: Box::pin(stream),
235 sink: Box::pin(sink),
236 buf,
237 };
238 }
239 Poll::Ready(Err(e)) => {
240 return Poll::Ready(Err(e));
241 }
242 Poll::Pending => {
243 return Poll::Pending;
244 }
245 }
246 }
247
248 if let SharedState::Done { sink, buf, .. } = &mut *state {
249 if buf.is_some() {
250 ready!(sink.as_mut().poll_ready(cx).map_err(From::from)?);
251 sink.as_mut().start_send(buf.take().unwrap())?;
252 }
253 let result = sink.as_mut().poll_flush(cx).map_err(From::from);
254 return result;
255 }
256
257 panic!("LazySinkHalf in invalid state.");
258 }
259
260 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
261 let mut state = self.state.borrow_mut();
262
263 if let SharedState::Uninit { .. } = &*state {
264 return Poll::Ready(Ok(()));
265 }
266
267 if let SharedState::Thunkulating {
268 future,
269 item,
270 multi_waker,
271 } = &mut *state
272 {
273 let waker = if let Some(waker) = multi_waker {
274 waker.push(cx.waker());
275 Waker::from(waker.clone())
276 } else {
277 let waker = Arc::new(MultiWaker::new(cx.waker()));
278 *multi_waker = Some(waker.clone());
279 Waker::from(waker)
280 };
281
282 let mut new_context = Context::from_waker(&waker);
283
284 match future.as_mut().poll(&mut new_context) {
285 Poll::Ready(Ok((stream, sink))) => {
286 let buf = item.take();
287 *state = SharedState::Done {
288 stream: Box::pin(stream),
289 sink: Box::pin(sink),
290 buf,
291 };
292 }
293 Poll::Ready(Err(e)) => {
294 return Poll::Ready(Err(e));
295 }
296 Poll::Pending => {
297 return Poll::Pending;
298 }
299 }
300 }
301
302 if let SharedState::Done { sink, buf, .. } = &mut *state {
303 if buf.is_some() {
304 ready!(sink.as_mut().poll_ready(cx).map_err(From::from)?);
305 sink.as_mut().start_send(buf.take().unwrap())?;
306 }
307 let result = sink.as_mut().poll_close(cx).map_err(From::from);
308 return result;
309 }
310
311 panic!("LazySinkHalf in invalid state.");
312 }
313}
314
315impl<Fut, St, Si, Item, Error> Stream for LazySourceHalf<Fut, St, Si, Item, Error>
316where
317 Fut: Future<Output = Result<(St, Si), Error>>,
318 St: Stream,
319 Si: Sink<Item>,
320{
321 type Item = St::Item;
322
323 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
324 let mut state = self.state.borrow_mut();
325
326 if let SharedState::Uninit { .. } = &*state {
327 let old_state = std::mem::replace(&mut *state, SharedState::Taken);
328 if let SharedState::Uninit { future } = old_state {
329 *state = SharedState::Thunkulating {
330 future,
331 item: None,
332 multi_waker: None,
333 };
334 } else {
335 unreachable!();
336 }
337 }
338
339 if let SharedState::Thunkulating {
340 future,
341 item,
342 multi_waker,
343 } = &mut *state
344 {
345 let waker = if let Some(waker) = multi_waker {
346 waker.push(cx.waker());
347 Waker::from(waker.clone())
348 } else {
349 let waker = Arc::new(MultiWaker::new(cx.waker()));
350 *multi_waker = Some(waker.clone());
351 Waker::from(waker)
352 };
353
354 let mut new_context = Context::from_waker(&waker);
355
356 match future.as_mut().poll(&mut new_context) {
357 Poll::Ready(Ok((stream, sink))) => {
358 let buf = item.take();
359 *state = SharedState::Done {
360 stream: Box::pin(stream),
361 sink: Box::pin(sink),
362 buf,
363 };
364 }
365
366 Poll::Ready(Err(_)) => {
367 return Poll::Ready(None);
368 }
369
370 Poll::Pending => {
371 return Poll::Pending;
372 }
373 }
374 }
375
376 if let SharedState::Done { stream, .. } = &mut *state {
377 let result = stream.as_mut().poll_next(cx);
378 match &result {
379 Poll::Ready(Some(_)) => {}
380 Poll::Ready(None) => {}
381 Poll::Pending => {}
382 }
383 return result;
384 }
385
386 panic!("LazySourceHalf in invalid state.");
387 }
388}
389
#[cfg(test)]
mod test {
    use futures_util::{SinkExt, StreamExt};

    use super::*;

    /// The stream half alone starts initialization. A send queued on the sink half
    /// completes only after the connection has been accepted.
    #[tokio::test(flavor = "current_thread")]
    async fn tcp_stream_drives_initialization() {
        use tokio::net::{TcpListener, TcpStream};
        use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec};

        // Signalled from inside the initialization future the first time it runs.
        let (initialization_tx, initialization_rx) = tokio::sync::oneshot::channel::<()>();

        // The halves hold `Rc<RefCell<..>>`, so they are spawned with `spawn_local` on a `LocalSet`.
        let local = tokio::task::LocalSet::new();
        local
            .run_until(async {
                let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
                let addr = listener.local_addr().unwrap();

                let sink_source = LazySinkSource::new(async move {
                    initialization_tx.send(()).unwrap();

                    // Initialization cannot finish until a client connects.
                    let (stream, _) = listener.accept().await.unwrap();
                    let (rx, tx) = stream.into_split();
                    let fr = FramedRead::new(rx, LengthDelimitedCodec::new());
                    let fw = FramedWrite::new(tx, LengthDelimitedCodec::new());
                    Ok::<_, std::io::Error>((fr, fw))
                });

                let (mut sink, mut stream) = sink_source.split();

                let stream_task = tokio::task::spawn_local(async move { stream.next().await });

                // Only the stream task has been spawned at this point, so the signal shows
                // that polling the stream half started initialization.
                initialization_rx.await.unwrap();
                let sink_task = tokio::task::spawn_local(async move {
                    SinkExt::send(&mut sink, bytes::Bytes::from("test2"))
                        .await
                        .unwrap();
                });

                // Let both tasks poll while initialization is still blocked on `accept`.
                for _ in 0..20 {
                    tokio::task::yield_now().await
                }

                let mut socket = TcpStream::connect(addr).await.unwrap();
                let (client_rx, client_tx) = socket.split();
                let mut client_tx = FramedWrite::new(client_tx, LengthDelimitedCodec::new());
                let mut client_rx = FramedRead::new(client_rx, LengthDelimitedCodec::new());

                for _ in 0..20 {
                    tokio::task::yield_now().await
                }

                // Connected, but the client has sent nothing yet, so the stream must still be waiting.
                assert!(!stream_task.is_finished());
                SinkExt::send(&mut client_tx, bytes::Bytes::from("test"))
                    .await
                    .unwrap();

                assert_eq!(&stream_task.await.unwrap().unwrap().unwrap()[..], b"test");
                sink_task.await.unwrap();

                // The item queued on the sink half reaches the client.
                assert_eq!(&client_rx.next().await.unwrap().unwrap()[..], b"test2");
            })
            .await;
    }

    /// A send on the sink half alone starts initialization. The send completes once the
    /// connection is accepted, while the stream half keeps working afterwards.
    #[tokio::test(flavor = "current_thread")]
    async fn tcp_sink_drives_initialization() {
        use tokio::net::{TcpListener, TcpStream};
        use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec};

        // Signalled from inside the initialization future the first time it runs.
        let (initialization_tx, initialization_rx) = tokio::sync::oneshot::channel::<()>();

        // The halves hold `Rc<RefCell<..>>`, so they are spawned with `spawn_local` on a `LocalSet`.
        let local = tokio::task::LocalSet::new();
        local
            .run_until(async {
                let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
                let addr = listener.local_addr().unwrap();

                let sink_source = LazySinkSource::new(async move {
                    initialization_tx.send(()).unwrap();

                    // Initialization cannot finish until a client connects.
                    let (stream, _) = listener.accept().await.unwrap();
                    let (rx, tx) = stream.into_split();
                    let fr = FramedRead::new(rx, LengthDelimitedCodec::new());
                    let fw = FramedWrite::new(tx, LengthDelimitedCodec::new());
                    Ok::<_, std::io::Error>((fr, fw))
                });

                let (mut sink, mut stream) = sink_source.split();

                let sink_task = tokio::task::spawn_local(async move {
                    SinkExt::send(&mut sink, bytes::Bytes::from("test2"))
                        .await
                        .unwrap();
                });

                // Only the sink task has been spawned at this point, so the signal shows
                // that sending on the sink half started initialization.
                initialization_rx.await.unwrap();
                let stream_task = tokio::task::spawn_local(async move { stream.next().await });

                for _ in 0..20 {
                    tokio::task::yield_now().await
                }

                // No client yet: the send cannot have completed.
                assert!(!sink_task.is_finished());
                let mut socket = TcpStream::connect(addr).await.unwrap();
                let (client_rx, client_tx) = socket.split();
                let mut client_tx = FramedWrite::new(client_tx, LengthDelimitedCodec::new());
                let mut client_rx = FramedRead::new(client_rx, LengthDelimitedCodec::new());

                for _ in 0..20 {
                    tokio::task::yield_now().await
                }

                // Once connected, the buffered item is delivered and the send completes.
                assert!(sink_task.is_finished());
                assert_eq!(&client_rx.next().await.unwrap().unwrap()[..], b"test2");

                SinkExt::send(&mut client_tx, bytes::Bytes::from("test"))
                    .await
                    .unwrap();

                assert_eq!(&stream_task.await.unwrap().unwrap().unwrap()[..], b"test");
                sink_task.await.unwrap();
            })
            .await;
    }
}