sinktools/
flatten.rs

1//! [`Flatten`] and related items.
2use core::pin::Pin;
3use core::task::{Context, Poll, ready};
4
5use pin_project_lite::pin_project;
6
7use crate::{Sink, SinkBuild};
8
9pin_project! {
10    /// Same as [`core::iterator::Flatten`] but as a [`Sink`].
11    ///
12    /// Synchronously flattens items and sends the outputs to the following sink.
13    #[must_use = "sinks do nothing unless polled"]
14    pub struct Flatten<Si, IntoIter>
15    where
16        IntoIter: IntoIterator,
17    {
18        #[pin]
19        sink: Si,
20        // Current iterator and the next item.
21        iter_next: Option<(IntoIter::IntoIter, IntoIter::Item)>,
22    }
23}
24
25impl<Si, IntoIter> Flatten<Si, IntoIter>
26where
27    IntoIter: IntoIterator,
28{
29    /// Create with next `sink`.
30    pub fn new(sink: Si) -> Self
31    where
32        Self: Sink<IntoIter>,
33    {
34        Self {
35            sink,
36            iter_next: None,
37        }
38    }
39
40    fn poll_ready_impl(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Si::Error>>
41    where
42        Si: Sink<IntoIter::Item>,
43    {
44        let mut this = self.project();
45
46        while this.iter_next.is_some() {
47            // Ensure following sink is ready.
48            ready!(this.sink.as_mut().poll_ready(cx))?;
49
50            // Send the output the next item.
51            let (mut iter, next) = this.iter_next.take().unwrap();
52            this.sink.as_mut().start_send(next)?;
53
54            // Replace the iterator and next item (if any).
55            *this.iter_next = iter.next().map(|next| (iter, next));
56        }
57
58        Poll::Ready(Ok(()))
59    }
60}
61
62impl<Si, IntoIter> Sink<IntoIter> for Flatten<Si, IntoIter>
63where
64    Si: Sink<IntoIter::Item>,
65    IntoIter: IntoIterator,
66{
67    type Error = Si::Error;
68
69    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
70        self.poll_ready_impl(cx)
71    }
72
73    fn start_send(self: Pin<&mut Self>, item: IntoIter) -> Result<(), Self::Error> {
74        let this = self.project();
75
76        assert!(
77            this.iter_next.is_none(),
78            "Sink not ready: `poll_ready` must be called and return `Ready` before `start_send` is called."
79        );
80        let mut iter = item.into_iter();
81        *this.iter_next = iter.next().map(|next| (iter, next));
82        Ok(())
83    }
84
85    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
86        ready!(self.as_mut().poll_ready_impl(cx)?);
87        self.project().sink.poll_flush(cx)
88    }
89
90    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
91        ready!(self.as_mut().poll_ready_impl(cx)?);
92        self.project().sink.poll_close(cx)
93    }
94}
95
96/// [`SinkBuild`] for [`Flatten`].
97pub struct FlattenBuilder<Prev> {
98    pub(crate) prev: Prev,
99}
100impl<Prev> SinkBuild for FlattenBuilder<Prev>
101where
102    Prev: SinkBuild,
103    Prev::Item: IntoIterator,
104{
105    type Item = <Prev::Item as IntoIterator>::Item;
106
107    type Output<Next: Sink<Self::Item>> = Prev::Output<Flatten<Next, Prev::Item>>;
108
109    fn send_to<Next>(self, next: Next) -> Self::Output<Next>
110    where
111        Next: Sink<Self::Item>,
112    {
113        self.prev.send_to(Flatten::new(next))
114    }
115}
116
#[cfg(test)]
mod tests {
    use futures_util::stream::StreamExt;
    use tokio::sync::mpsc::channel;
    use tokio_stream::wrappers::ReceiverStream;
    use tokio_util::sync::PollSender;

    use super::*;
    use crate::sink::SinkExt;

    /// Sends several `Vec`s through a [`Flatten`] and checks that the
    /// receiving end sees every element, flattened and in order.
    #[tokio::test]
    async fn test_flatten() {
        // The channel capacity (2) is smaller than the 3-element batches, so a
        // whole batch cannot sit in the channel at once.
        let (out_send, out_recv) = channel(2);
        let out_send = PollSender::new(out_send);
        let mut out_recv = ReceiverStream::new(out_recv);

        let mut sink = Flatten::new(out_send);

        // Produce on a separate task so the receiver below can drain the
        // channel concurrently.
        let producer = tokio::task::spawn(async move {
            sink.send(vec![0, 1, 2]).await.unwrap();
            sink.send(vec![3, 4, 5]).await.unwrap();
            sink.send(vec![6, 7, 8]).await.unwrap();
            sink.send(vec![9]).await.unwrap();
        });

        // Collect everything the producer sends; `sink` (and the sender it
        // owns) is moved into the producer task above.
        assert_eq!(
            &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
            &*out_recv.by_ref().collect::<Vec<_>>().await
        );

        producer.await.unwrap();
    }
}