Skip to main content

dfir_rs/compiled/pull/
zip_longest.rs

1use std::pin::Pin;
2use std::task::{Context, Poll, ready};
3
4use futures::stream::{FusedStream, Stream};
5use itertools::EitherOrBoth;
6use pin_project_lite::pin_project;
7
8pin_project! {
9    /// Special stream for the `zip_longest` operator.
10    #[must_use = "streams do nothing unless polled"]
11    pub struct ZipLongest<St1: Stream, St2: Stream> {
12        #[pin]
13        stream1: St1,
14        #[pin]
15        stream2: St2,
16        // Buffers an item from `stream1` so it is not lost if `stream2` returns `Poll::Pending`.
17        // `None` = no buffered item (stream1 not yet polled, or result already consumed);
18        // `Some(item)` = item waiting to be paired with stream2's next value.
19        item1: Option<St1::Item>,
20    }
21}
22
23impl<St1, St2> ZipLongest<St1, St2>
24where
25    St1: FusedStream,
26    St2: FusedStream,
27{
28    /// Create a new `ZipLongest` stream from two source streams.
29    pub fn new(stream1: St1, stream2: St2) -> Self {
30        Self {
31            stream1,
32            stream2,
33            item1: None,
34        }
35    }
36}
37
38impl<St1, St2> Stream for ZipLongest<St1, St2>
39where
40    St1: FusedStream,
41    St2: FusedStream,
42{
43    type Item = EitherOrBoth<St1::Item, St2::Item>;
44
45    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
46        let mut this = self.project();
47
48        // Store `item1` so it is not dropped if `stream2` returns `Poll::Pending`.
49        if this.item1.is_none() {
50            *this.item1 = ready!(this.stream1.as_mut().poll_next(cx));
51        }
52        let item2 = ready!(this.stream2.as_mut().poll_next(cx));
53
54        Poll::Ready(match (this.item1.take(), item2) {
55            (None, None) => None,
56            (Some(item1), None) => Some(EitherOrBoth::Left(item1)),
57            (None, Some(item2)) => Some(EitherOrBoth::Right(item2)),
58            (Some(item1), Some(item2)) => Some(EitherOrBoth::Both(item1, item2)),
59        })
60    }
61}
62
63#[cfg(test)]
64mod tests {
65    use std::pin::pin;
66    use std::task::{Context, Poll, Waker};
67
68    use futures::stream::{FusedStream, Stream};
69    use itertools::EitherOrBoth;
70
71    use super::ZipLongest;
72
73    /// A stream that returns `Poll::Pending` for the first `pending_count` polls, then yields
74    /// items from `items`, then returns `Poll::Ready(None)`.
75    struct PendingThenItems<T> {
76        pending_count: usize,
77        items: std::vec::IntoIter<T>,
78        done: bool,
79    }
80
81    impl<T> PendingThenItems<T> {
82        fn new(pending_count: usize, items: Vec<T>) -> Self {
83            Self {
84                pending_count,
85                items: items.into_iter(),
86                done: false,
87            }
88        }
89    }
90
91    impl<T: Unpin> Stream for PendingThenItems<T> {
92        type Item = T;
93
94        fn poll_next(mut self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<T>> {
95            if self.pending_count > 0 {
96                self.pending_count -= 1;
97                cx.waker().wake_by_ref();
98                return Poll::Pending;
99            }
100            let item = self.items.next();
101            if item.is_none() {
102                self.done = true;
103            }
104            Poll::Ready(item)
105        }
106    }
107
108    impl<T: Unpin> FusedStream for PendingThenItems<T> {
109        fn is_terminated(&self) -> bool {
110            self.done
111        }
112    }
113
114    /// Regression test: LHS item must not be dropped when RHS returns `Poll::Pending`.
115    #[test]
116    fn test_lhs_not_dropped_when_rhs_pending() {
117        // LHS: immediately yields 1, 2, 3
118        // RHS: returns Pending once, then yields 10, 20
119        let lhs = PendingThenItems::new(0, vec![1_i32, 2, 3]);
120        let rhs = PendingThenItems::new(1, vec![10_i32, 20]);
121
122        let mut zip = pin!(ZipLongest::new(lhs, rhs));
123        let mut cx = Context::from_waker(Waker::noop());
124
125        // First poll: LHS ready(1), RHS pending -> should return Pending (with LHS buffered)
126        assert_eq!(Poll::Pending, zip.as_mut().poll_next(&mut cx));
127
128        // Second poll: LHS buffered(1), RHS ready(10) -> Both(1, 10)
129        assert_eq!(
130            Poll::Ready(Some(EitherOrBoth::Both(1, 10))),
131            zip.as_mut().poll_next(&mut cx)
132        );
133
134        // Remaining: Both(2, 20), Left(3), None
135        assert_eq!(
136            Poll::Ready(Some(EitherOrBoth::Both(2, 20))),
137            zip.as_mut().poll_next(&mut cx)
138        );
139        assert_eq!(
140            Poll::Ready(Some(EitherOrBoth::Left(3))),
141            zip.as_mut().poll_next(&mut cx)
142        );
143        assert_eq!(Poll::Ready(None), zip.as_mut().poll_next(&mut cx));
144    }
145}