1use core::pin::Pin;
3use core::task::{Context, Poll, ready};
4
5use pin_project_lite::pin_project;
6
7use crate::{Sink, SinkBuild};
8
pin_project! {
    /// A [`Sink`] adapter that accepts `IntoIter` values and forwards each of their elements,
    /// one at a time, to the wrapped sink `Si`.
    ///
    /// An accepted value is turned into an iterator in `start_send`; its elements are then
    /// pushed into the inner sink whenever the adapter is polled for readiness, flush or close.
    #[must_use = "sinks do nothing unless polled"]
    pub struct Flatten<Si, IntoIter>
    where
        IntoIter: IntoIterator,
    {
        // The wrapped sink that receives the individual elements. Structurally pinned.
        #[pin]
        sink: Si,
        // The partially drained iterator together with the element already pulled from it and
        // still waiting to be sent. `None` when nothing is buffered.
        iter_next: Option<(IntoIter::IntoIter, IntoIter::Item)>,
    }
}
24
25impl<Si, IntoIter> Flatten<Si, IntoIter>
26where
27 IntoIter: IntoIterator,
28{
29 pub fn new(sink: Si) -> Self
31 where
32 Self: Sink<IntoIter>,
33 {
34 Self {
35 sink,
36 iter_next: None,
37 }
38 }
39
40 fn poll_ready_impl(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Si::Error>>
41 where
42 Si: Sink<IntoIter::Item>,
43 {
44 let mut this = self.project();
45
46 while this.iter_next.is_some() {
47 ready!(this.sink.as_mut().poll_ready(cx))?;
49
50 let (mut iter, next) = this.iter_next.take().unwrap();
52 this.sink.as_mut().start_send(next)?;
53
54 *this.iter_next = iter.next().map(|next| (iter, next));
56 }
57
58 Poll::Ready(Ok(()))
59 }
60}
61
62impl<Si, IntoIter> Sink<IntoIter> for Flatten<Si, IntoIter>
63where
64 Si: Sink<IntoIter::Item>,
65 IntoIter: IntoIterator,
66{
67 type Error = Si::Error;
68
69 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
70 self.poll_ready_impl(cx)
71 }
72
73 fn start_send(self: Pin<&mut Self>, item: IntoIter) -> Result<(), Self::Error> {
74 let this = self.project();
75
76 assert!(
77 this.iter_next.is_none(),
78 "Sink not ready: `poll_ready` must be called and return `Ready` before `start_send` is called."
79 );
80 let mut iter = item.into_iter();
81 *this.iter_next = iter.next().map(|next| (iter, next));
82 Ok(())
83 }
84
85 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
86 ready!(self.as_mut().poll_ready_impl(cx)?);
87 self.project().sink.poll_flush(cx)
88 }
89
90 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
91 ready!(self.as_mut().poll_ready_impl(cx)?);
92 self.project().sink.poll_close(cx)
93 }
94}
95
/// Builder stage that inserts a [`Flatten`] adapter in front of the downstream sink.
///
/// Its [`SinkBuild`] implementation wraps the sink passed to `send_to` in a [`Flatten`] and
/// hands the result to `prev`.
pub struct FlattenBuilder<Prev> {
    // The preceding builder stage, which receives the `Flatten`-wrapped sink.
    pub(crate) prev: Prev,
}
100impl<Prev> SinkBuild for FlattenBuilder<Prev>
101where
102 Prev: SinkBuild,
103 Prev::Item: IntoIterator,
104{
105 type Item = <Prev::Item as IntoIterator>::Item;
106
107 type Output<Next: Sink<Self::Item>> = Prev::Output<Flatten<Next, Prev::Item>>;
108
109 fn send_to<Next>(self, next: Next) -> Self::Output<Next>
110 where
111 Next: Sink<Self::Item>,
112 {
113 self.prev.send_to(Flatten::new(next))
114 }
115}
116
#[cfg(test)]
mod tests {
    use futures_util::stream::StreamExt;
    use tokio::sync::mpsc::channel;
    use tokio_stream::wrappers::ReceiverStream;
    use tokio_util::sync::PollSender;

    use super::*;
    use crate::sink::SinkExt;

    /// Sends several vectors through a [`Flatten`] sink and checks that the receiver observes
    /// their elements in order as one flat sequence.
    ///
    /// The channel capacity (2) is smaller than most of the batches (3 elements), so a batch
    /// cannot be forwarded in one go unless the receiver drains concurrently.
    #[tokio::test]
    async fn test_flatten() {
        let (out_send, out_recv) = channel(2);
        let out_send = PollSender::new(out_send);
        let mut out_recv = ReceiverStream::new(out_recv);

        let mut sink = Flatten::new(out_send);

        // The sender runs in its own task so the receiver below can drain while it sends.
        let a = tokio::task::spawn(async move {
            sink.send(vec![0, 1, 2]).await.unwrap();
            sink.send(vec![3, 4, 5]).await.unwrap();
            sink.send(vec![6, 7, 8]).await.unwrap();
            sink.send(vec![9]).await.unwrap();
        });

        assert_eq!(
            &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
            &*out_recv.by_ref().collect::<Vec<_>>().await
        );
        a.await.unwrap();
    }
}