1#![cfg(feature = "std")]
3
4use std::cell::RefCell;
5use std::marker::PhantomData;
6use std::pin::Pin;
7use std::rc::Rc;
8use std::sync::{Arc, Mutex};
9use std::task::{Context, Poll, Wake, Waker};
10
11use futures_util::{Sink, Stream, ready};
12
/// A [`Wake`] implementation that fans one wake-up out to several task wakers.
///
/// Both halves of the lazy sink/source poll the same initialization future, so the
/// future must be able to notify every task that is waiting on it.
struct MultiWaker {
    /// Wakers waiting for the next wake-up; the list is drained every time a wake-up fires.
    wakers: Mutex<Vec<Waker>>,
}

impl MultiWaker {
    /// Creates a fan-out waker whose pending list starts with a clone of `waker`.
    fn new(waker: &Waker) -> Self {
        MultiWaker {
            wakers: Mutex::new(vec![waker.clone()]),
        }
    }

    /// Registers `waker` to be notified on the next wake-up.
    ///
    /// A waker is skipped when an already-registered waker reports (via
    /// [`Waker::will_wake`]) that it wakes the same task. Without this check every
    /// re-poll of a still-pending future would append another clone, growing the list
    /// without bound and waking the same task repeatedly.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    fn push(&self, waker: &Waker) {
        let mut guard = self.wakers.lock().unwrap();
        if !guard.iter().any(|registered| registered.will_wake(waker)) {
            guard.push(waker.clone());
        }
    }
}

impl Wake for MultiWaker {
    /// Wakes every pending waker and leaves the pending list empty.
    fn wake(self: Arc<Self>) {
        self.wake_by_ref();
    }

    /// Same as [`Wake::wake`], without consuming (or cloning) the `Arc`.
    fn wake_by_ref(self: &Arc<Self>) {
        // Move the wakers out while holding the lock, then release it before waking so a
        // woken task that immediately re-registers does not contend on the mutex.
        let pending = {
            let mut guard = self.wakers.lock().unwrap();
            std::mem::take(&mut *guard)
        };

        for waker in pending {
            waker.wake();
        }
    }
}
44
/// State shared by the sink half and the source half of a [`LazySinkSource`].
///
/// `Fut` is the initialization future, `St` and `Si` are the stream and sink it resolves
/// to, and `Item` is the type of value sent into the sink.
enum SharedState<Fut, St, Si, Item> {
    /// The initialization future has been stored but has not been polled yet.
    Uninit {
        /// The pending initialization future.
        future: Pin<Box<Fut>>,
    },
    /// Initialization has been requested; `future` is polled by whichever half is polled.
    Thunkulating {
        /// The initialization future currently being driven.
        future: Pin<Box<Fut>>,
        /// Item accepted by `start_send` before initialization finished, if any. It is
        /// moved into `Done::buf` when the future resolves successfully.
        item: Option<Item>,
        /// Fan-out waker handed to `future`; created on the first poll in this state.
        multi_waker: Option<Arc<MultiWaker>>,
    },
    /// The future resolved successfully; both halves delegate to the produced pair.
    Done {
        /// Stream polled by the source half.
        stream: Pin<Box<St>>,
        /// Sink driven by the sink half.
        sink: Pin<Box<Si>>,
        /// Item carried over from `Thunkulating` that still has to be sent into `sink`.
        buf: Option<Item>,
    },
    /// Payload-free placeholder; written by `std::mem::replace` when the previous state is
    /// moved out.
    Taken,
}
61
/// A combined sink/source whose underlying stream and sink are produced by a future.
///
/// The future is stored by [`LazySinkSource::new`] and is not polled there; call
/// [`LazySinkSource::split`] to obtain the two halves that share it.
pub struct LazySinkSource<Fut, St, Si, Item, Error> {
    /// State shared (single-threaded, via `Rc<RefCell<_>>`) with both halves after `split`.
    state: Rc<RefCell<SharedState<Fut, St, Si, Item>>>,
    /// Ties the `Error` type parameter to the struct; no `Error` value is stored.
    _phantom: PhantomData<Error>,
}
67
68impl<Fut, St, Si, Item, Error> LazySinkSource<Fut, St, Si, Item, Error> {
69 pub fn new(future: Fut) -> Self {
71 Self {
72 state: Rc::new(RefCell::new(SharedState::Uninit {
73 future: Box::pin(future),
74 })),
75 _phantom: PhantomData,
76 }
77 }
78
79 pub fn split(
81 self,
82 ) -> (
83 LazySinkHalf<Fut, St, Si, Item, Error>,
84 LazySourceHalf<Fut, St, Si, Item, Error>,
85 ) {
86 let sink = LazySinkHalf {
87 state: Rc::clone(&self.state),
88 _phantom: PhantomData,
89 };
90 let stream = LazySourceHalf {
91 state: self.state,
92 _phantom: PhantomData,
93 };
94 (sink, stream)
95 }
96}
97
/// The sending half produced by [`LazySinkSource::split`].
///
/// Implements `Sink<Item>` on top of the state it shares with [`LazySourceHalf`].
pub struct LazySinkHalf<Fut, St, Si, Item, Error> {
    /// State shared with the source half.
    state: Rc<RefCell<SharedState<Fut, St, Si, Item>>>,
    /// Ties the `Error` type parameter to the struct; no `Error` value is stored.
    _phantom: PhantomData<Error>,
}
103
/// The receiving half produced by [`LazySinkSource::split`].
///
/// Implements `Stream` on top of the state it shares with [`LazySinkHalf`].
pub struct LazySourceHalf<Fut, St, Si, Item, Error> {
    /// State shared with the sink half.
    state: Rc<RefCell<SharedState<Fut, St, Si, Item>>>,
    /// Ties the `Error` type parameter to the struct; no `Error` value is stored.
    _phantom: PhantomData<Error>,
}
109
110impl<Fut, St, Si, Item, Error> Sink<Item> for LazySinkHalf<Fut, St, Si, Item, Error>
111where
112 Fut: Future<Output = Result<(St, Si), Error>>,
113 St: Stream,
114 Si: Sink<Item>,
115 Error: From<Si::Error>,
116{
117 type Error = Error;
118
119 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
120 let mut state = self.state.borrow_mut();
121
122 if let SharedState::Uninit { .. } = &*state {
123 return Poll::Ready(Ok(()));
124 }
125
126 if let SharedState::Thunkulating {
127 future,
128 item,
129 multi_waker,
130 } = &mut *state
131 {
132 let waker = if let Some(waker) = multi_waker {
133 waker.push(cx.waker());
134 Waker::from(waker.clone())
135 } else {
136 let waker = Arc::new(MultiWaker::new(cx.waker()));
137 *multi_waker = Some(waker.clone());
138 Waker::from(waker)
139 };
140
141 let mut new_context = Context::from_waker(&waker);
142
143 match future.as_mut().poll(&mut new_context) {
144 Poll::Ready(Ok((stream, sink))) => {
145 let buf = item.take();
146 *state = SharedState::Done {
147 stream: Box::pin(stream),
148 sink: Box::pin(sink),
149 buf,
150 };
151 }
152 Poll::Ready(Err(e)) => {
153 return Poll::Ready(Err(e));
154 }
155 Poll::Pending => {
156 return Poll::Pending;
157 }
158 }
159 }
160
161 if let SharedState::Done { sink, buf, .. } = &mut *state {
162 if buf.is_some() {
163 ready!(sink.as_mut().poll_ready(cx).map_err(From::from)?);
164 sink.as_mut().start_send(buf.take().unwrap())?;
165 }
166 let result = sink.as_mut().poll_ready(cx).map_err(From::from);
167 return result;
168 }
169
170 panic!("LazySinkHalf in invalid state.");
171 }
172
173 fn start_send(self: Pin<&mut Self>, item: Item) -> Result<(), Self::Error> {
174 let mut state = self.state.borrow_mut();
175
176 if let SharedState::Uninit { .. } = &*state {
177 let old_state = std::mem::replace(&mut *state, SharedState::Taken);
178 if let SharedState::Uninit { future } = old_state {
179 *state = SharedState::Thunkulating {
180 future,
181 item: Some(item),
182 multi_waker: None,
183 };
184
185 return Ok(());
186 }
187 }
188
189 if let SharedState::Thunkulating { .. } = &mut *state {
190 panic!("LazySinkHalf not ready.");
191 }
192
193 if let SharedState::Done { sink, buf, .. } = &mut *state {
194 debug_assert!(buf.is_none());
195 let result = sink.as_mut().start_send(item).map_err(From::from);
196 return result;
197 }
198
199 panic!("LazySinkHalf not ready.");
200 }
201
202 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
203 let mut state = self.state.borrow_mut();
204
205 if let SharedState::Uninit { .. } = &*state {
206 return Poll::Ready(Ok(()));
207 }
208
209 if let SharedState::Thunkulating {
210 future,
211 item,
212 multi_waker,
213 } = &mut *state
214 {
215 let waker = if let Some(waker) = multi_waker {
216 waker.push(cx.waker());
217 Waker::from(waker.clone())
218 } else {
219 let waker = Arc::new(MultiWaker::new(cx.waker()));
220 *multi_waker = Some(waker.clone());
221 Waker::from(waker)
222 };
223
224 let mut new_context = Context::from_waker(&waker);
225
226 match future.as_mut().poll(&mut new_context) {
227 Poll::Ready(Ok((stream, sink))) => {
228 let buf = item.take();
229 *state = SharedState::Done {
230 stream: Box::pin(stream),
231 sink: Box::pin(sink),
232 buf,
233 };
234 }
235 Poll::Ready(Err(e)) => {
236 return Poll::Ready(Err(e));
237 }
238 Poll::Pending => {
239 return Poll::Pending;
240 }
241 }
242 }
243
244 if let SharedState::Done { sink, buf, .. } = &mut *state {
245 if buf.is_some() {
246 ready!(sink.as_mut().poll_ready(cx).map_err(From::from)?);
247 sink.as_mut().start_send(buf.take().unwrap())?;
248 }
249 let result = sink.as_mut().poll_flush(cx).map_err(From::from);
250 return result;
251 }
252
253 panic!("LazySinkHalf in invalid state.");
254 }
255
256 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
257 let mut state = self.state.borrow_mut();
258
259 if let SharedState::Uninit { .. } = &*state {
260 return Poll::Ready(Ok(()));
261 }
262
263 if let SharedState::Thunkulating {
264 future,
265 item,
266 multi_waker,
267 } = &mut *state
268 {
269 let waker = if let Some(waker) = multi_waker {
270 waker.push(cx.waker());
271 Waker::from(waker.clone())
272 } else {
273 let waker = Arc::new(MultiWaker::new(cx.waker()));
274 *multi_waker = Some(waker.clone());
275 Waker::from(waker)
276 };
277
278 let mut new_context = Context::from_waker(&waker);
279
280 match future.as_mut().poll(&mut new_context) {
281 Poll::Ready(Ok((stream, sink))) => {
282 let buf = item.take();
283 *state = SharedState::Done {
284 stream: Box::pin(stream),
285 sink: Box::pin(sink),
286 buf,
287 };
288 }
289 Poll::Ready(Err(e)) => {
290 return Poll::Ready(Err(e));
291 }
292 Poll::Pending => {
293 return Poll::Pending;
294 }
295 }
296 }
297
298 if let SharedState::Done { sink, buf, .. } = &mut *state {
299 if buf.is_some() {
300 ready!(sink.as_mut().poll_ready(cx).map_err(From::from)?);
301 sink.as_mut().start_send(buf.take().unwrap())?;
302 }
303 let result = sink.as_mut().poll_close(cx).map_err(From::from);
304 return result;
305 }
306
307 panic!("LazySinkHalf in invalid state.");
308 }
309}
310
311impl<Fut, St, Si, Item, Error> Stream for LazySourceHalf<Fut, St, Si, Item, Error>
312where
313 Fut: Future<Output = Result<(St, Si), Error>>,
314 St: Stream,
315 Si: Sink<Item>,
316{
317 type Item = St::Item;
318
319 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
320 let mut state = self.state.borrow_mut();
321
322 if let SharedState::Uninit { .. } = &*state {
323 let old_state = std::mem::replace(&mut *state, SharedState::Taken);
324 if let SharedState::Uninit { future } = old_state {
325 *state = SharedState::Thunkulating {
326 future,
327 item: None,
328 multi_waker: None,
329 };
330 } else {
331 unreachable!();
332 }
333 }
334
335 if let SharedState::Thunkulating {
336 future,
337 item,
338 multi_waker,
339 } = &mut *state
340 {
341 let waker = if let Some(waker) = multi_waker {
342 waker.push(cx.waker());
343 Waker::from(waker.clone())
344 } else {
345 let waker = Arc::new(MultiWaker::new(cx.waker()));
346 *multi_waker = Some(waker.clone());
347 Waker::from(waker)
348 };
349
350 let mut new_context = Context::from_waker(&waker);
351
352 match future.as_mut().poll(&mut new_context) {
353 Poll::Ready(Ok((stream, sink))) => {
354 let buf = item.take();
355 *state = SharedState::Done {
356 stream: Box::pin(stream),
357 sink: Box::pin(sink),
358 buf,
359 };
360 }
361
362 Poll::Ready(Err(_)) => {
363 return Poll::Ready(None);
364 }
365
366 Poll::Pending => {
367 return Poll::Pending;
368 }
369 }
370 }
371
372 if let SharedState::Done { stream, .. } = &mut *state {
373 let result = stream.as_mut().poll_next(cx);
374 match &result {
375 Poll::Ready(Some(_)) => {}
376 Poll::Ready(None) => {}
377 Poll::Pending => {}
378 }
379 return result;
380 }
381
382 panic!("LazySourceHalf in invalid state.");
383 }
384}
385
#[cfg(test)]
mod test {
    use futures_util::{SinkExt, StreamExt};

    use super::*;

    /// The source half is polled first and therefore starts the initialization future;
    /// the sink half joins while initialization is still pending.
    #[tokio::test(flavor = "current_thread")]
    async fn tcp_stream_drives_initialization() {
        use tokio::net::{TcpListener, TcpStream};
        use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec};

        // Signalled from inside the initialization future once it has started running.
        let (initialization_tx, initialization_rx) = tokio::sync::oneshot::channel::<()>();

        // The halves share an `Rc`, so the tasks that own them must be spawned locally.
        let local = tokio::task::LocalSet::new();
        local
            .run_until(async {
                let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
                let addr = listener.local_addr().unwrap();

                let sink_source = LazySinkSource::new(async move {
                    initialization_tx.send(()).unwrap();

                    // Initialization stays pending here until a client connects.
                    let (stream, _) = listener.accept().await.unwrap();
                    let (rx, tx) = stream.into_split();
                    let fr = FramedRead::new(rx, LengthDelimitedCodec::new());
                    let fw = FramedWrite::new(tx, LengthDelimitedCodec::new());
                    Ok::<_, std::io::Error>((fr, fw))
                });

                let (mut sink, mut stream) = sink_source.split();

                // Polling the source half is what kicks off initialization in this test.
                let stream_task = tokio::task::spawn_local(async move { stream.next().await });

                // Wait until the initialization future has actually been started.
                initialization_rx.await.unwrap();

                // The sink half now joins an initialization that is already in flight.
                let sink_task = tokio::task::spawn_local(async move {
                    SinkExt::send(&mut sink, bytes::Bytes::from("test2"))
                        .await
                        .unwrap();
                });

                // Yield repeatedly to give the spawned tasks a chance to run.
                for _ in 0..20 {
                    tokio::task::yield_now().await
                }

                // Connecting lets the pending `accept` inside the future complete.
                let mut socket = TcpStream::connect(addr).await.unwrap();
                let (client_rx, client_tx) = socket.split();
                let mut client_tx = FramedWrite::new(client_tx, LengthDelimitedCodec::new());
                let mut client_rx = FramedRead::new(client_rx, LengthDelimitedCodec::new());

                for _ in 0..20 {
                    tokio::task::yield_now().await
                }

                // Nothing has been sent by the client yet, so the source must still wait.
                assert!(!stream_task.is_finished());

                SinkExt::send(&mut client_tx, bytes::Bytes::from("test"))
                    .await
                    .unwrap();

                assert_eq!(&stream_task.await.unwrap().unwrap().unwrap()[..], b"test");
                sink_task.await.unwrap();

                assert_eq!(&client_rx.next().await.unwrap().unwrap()[..], b"test2");
            })
            .await;
    }

    /// The sink half sends first and therefore starts the initialization future; the
    /// source half joins while initialization is still pending.
    #[tokio::test(flavor = "current_thread")]
    async fn tcp_sink_drives_initialization() {
        use tokio::net::{TcpListener, TcpStream};
        use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec};

        // Signalled from inside the initialization future once it has started running.
        let (initialization_tx, initialization_rx) = tokio::sync::oneshot::channel::<()>();

        // The halves share an `Rc`, so the tasks that own them must be spawned locally.
        let local = tokio::task::LocalSet::new();
        local
            .run_until(async {
                let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
                let addr = listener.local_addr().unwrap();

                let sink_source = LazySinkSource::new(async move {
                    initialization_tx.send(()).unwrap();

                    // Initialization stays pending here until a client connects.
                    let (stream, _) = listener.accept().await.unwrap();
                    let (rx, tx) = stream.into_split();
                    let fr = FramedRead::new(rx, LengthDelimitedCodec::new());
                    let fw = FramedWrite::new(tx, LengthDelimitedCodec::new());
                    Ok::<_, std::io::Error>((fr, fw))
                });

                let (mut sink, mut stream) = sink_source.split();

                // Sending through the sink half is what kicks off initialization here.
                let sink_task = tokio::task::spawn_local(async move {
                    SinkExt::send(&mut sink, bytes::Bytes::from("test2"))
                        .await
                        .unwrap();
                });

                // Wait until the initialization future has actually been started.
                initialization_rx.await.unwrap();

                // The source half now joins an initialization that is already in flight.
                let stream_task = tokio::task::spawn_local(async move { stream.next().await });

                // Yield repeatedly to give the spawned tasks a chance to run.
                for _ in 0..20 {
                    tokio::task::yield_now().await
                }

                // No client has connected yet, so the send cannot have completed.
                assert!(!sink_task.is_finished());

                let mut socket = TcpStream::connect(addr).await.unwrap();
                let (client_rx, client_tx) = socket.split();
                let mut client_tx = FramedWrite::new(client_tx, LengthDelimitedCodec::new());
                let mut client_rx = FramedRead::new(client_rx, LengthDelimitedCodec::new());

                for _ in 0..20 {
                    tokio::task::yield_now().await
                }

                // With the connection accepted, the buffered item must have been sent.
                assert!(sink_task.is_finished());
                assert_eq!(&client_rx.next().await.unwrap().unwrap()[..], b"test2");

                SinkExt::send(&mut client_tx, bytes::Bytes::from("test"))
                    .await
                    .unwrap();

                assert_eq!(&stream_task.await.unwrap().unwrap().unwrap()[..], b"test");
                sink_task.await.unwrap();
            })
            .await;
    }
}