dfir_rs/compiled/pull/
flatten.rs

1use std::pin::Pin;
2use std::task::{Context, Poll, ready};
3
4use futures::stream::Stream;
5use pin_project_lite::pin_project;
6
7pin_project! {
8    /// Same as [`Iterator::flatten`] but as a [`Stream`].
9    ///
10    /// Flattens a stream of iterables into a stream of their items.
11    #[must_use = "streams do nothing unless polled"]
12    pub struct Flatten<St, IntoIter> {
13        #[pin]
14        stream: St,
15        // Current iterator being consumed
16        current_iter: Option<IntoIter>,
17    }
18}
19
20impl<St, IntoIter> Flatten<St, IntoIter::IntoIter>
21where
22    St: Stream<Item = IntoIter>,
23    IntoIter: IntoIterator,
24{
25    /// Create with source `stream`.
26    pub fn new(stream: St) -> Self {
27        Self {
28            stream,
29            current_iter: None,
30        }
31    }
32}
33
34impl<St, IntoIter> Stream for Flatten<St, IntoIter::IntoIter>
35where
36    St: Stream<Item = IntoIter>,
37    IntoIter: IntoIterator,
38{
39    type Item = IntoIter::Item;
40
41    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
42        let mut this = self.project();
43
44        loop {
45            // First, try to get the next item from the current iterator
46            if let Some(iter) = this.current_iter.as_mut() {
47                if let Some(item) = iter.next() {
48                    return Poll::Ready(Some(item));
49                }
50                // Current iterator is exhausted, clear it
51                *this.current_iter = None;
52            }
53
54            // Get the next item from the stream and create a new iterator
55            if let Some(stream_item) = ready!(this.stream.as_mut().poll_next(cx)) {
56                *this.current_iter = Some(stream_item.into_iter());
57                // Loop back to try getting an item from the new iterator
58            } else {
59                // Stream is exhausted
60                return Poll::Ready(None);
61            }
62        }
63    }
64
65    fn size_hint(&self) -> (usize, Option<usize>) {
66        let lower = self
67            .current_iter
68            .as_ref()
69            .map_or(0, |iter| iter.size_hint().0);
70        (lower, None)
71    }
72}
73
74#[cfg(test)]
75mod tests {
76    use futures::stream::{self, StreamExt};
77
78    use super::*;
79
80    #[tokio::test]
81    async fn test_flatten_basic() {
82        let stream = stream::iter(vec![vec![1, 2], vec![3, 4, 5], vec![]]);
83        let flattened = Flatten::new(stream);
84        let result: Vec<i32> = flattened.collect().await;
85        assert_eq!(result, vec![1, 2, 3, 4, 5]);
86    }
87
88    #[tokio::test]
89    async fn test_flatten_empty() {
90        let stream = stream::iter(Vec::<Vec<i32>>::new());
91        let flattened = Flatten::new(stream);
92        let result: Vec<i32> = flattened.collect().await;
93        assert_eq!(result, Vec::<i32>::new());
94    }
95
96    #[tokio::test]
97    async fn test_flatten_strings() {
98        let stream = stream::iter(vec![
99            "hello".chars().collect::<Vec<_>>(),
100            "world".chars().collect::<Vec<_>>(),
101        ]);
102        let flattened = Flatten::new(stream);
103        let result: Vec<char> = flattened.collect().await;
104        assert_eq!(
105            result,
106            vec!['h', 'e', 'l', 'l', 'o', 'w', 'o', 'r', 'l', 'd']
107        );
108    }
109
110    #[tokio::test]
111    async fn test_flatten_options() {
112        let stream = stream::iter(vec![Some(1), None, Some(2), Some(3)]);
113        let flattened = Flatten::new(stream);
114        let result: Vec<i32> = flattened.collect().await;
115        assert_eq!(result, vec![1, 2, 3]);
116    }
117}