1use core::marker::PhantomData;
4use core::pin::Pin;
5use core::task::{Context, Poll};
6
7use pin_project_lite::pin_project;
8use sealed::sealed;
9use variadics::Variadic;
10
11use crate::{Sink, forward_sink, ready_both};
12
/// A variadic list of sinks that all accept `Item` and share the error type `Error`.
///
/// The list uses the nested-pair encoding `(S0, (S1, (S2, ())))`. In this module it is
/// implemented for the recursive case `(Si, Rest)` and the empty list `()`. `#[sealed]`
/// prevents implementations outside this crate.
///
/// The poll methods act on every sink in the list. [`start_send`](SinkVariadic::start_send)
/// routes a single item to exactly one sink, chosen by index.
#[sealed]
pub trait SinkVariadic<Item, Error>: Variadic {
    /// Polls every sink in the list for readiness to accept an item.
    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>>;

    /// Sends `item` to the sink at zero-based position `idx` in the list.
    ///
    /// # Panics
    ///
    /// Panics (in the `()` base case) if `idx` is not less than the number of sinks.
    fn start_send(self: Pin<&mut Self>, idx: usize, item: Item) -> Result<(), Error>;

    /// Flushes every sink in the list.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>>;

    /// Closes every sink in the list.
    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>>;
}
30
31#[sealed]
32impl<Si, Item, Rest> SinkVariadic<Item, Si::Error> for (Si, Rest)
33where
34 Si: Sink<Item>,
35 Rest: SinkVariadic<Item, Si::Error>,
36{
37 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Si::Error>> {
38 let (sink, rest) = pin_project_pair(self);
39 ready_both!(sink.poll_ready(cx)?, rest.poll_ready(cx)?);
40 Poll::Ready(Ok(()))
41 }
42
43 fn start_send(self: Pin<&mut Self>, idx: usize, item: Item) -> Result<(), Si::Error> {
44 let (sink, rest) = pin_project_pair(self);
45 if idx == 0 {
46 sink.start_send(item)
47 } else {
48 rest.start_send(idx - 1, item)
49 }
50 }
51
52 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Si::Error>> {
53 let (sink, rest) = pin_project_pair(self);
54 ready_both!(sink.poll_flush(cx)?, rest.poll_flush(cx)?);
55 Poll::Ready(Ok(()))
56 }
57
58 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Si::Error>> {
59 let (sink, rest) = pin_project_pair(self);
60 ready_both!(sink.poll_close(cx)?, rest.poll_close(cx)?);
61 Poll::Ready(Ok(()))
62 }
63}
64
65#[sealed]
66impl<Item, Error> SinkVariadic<Item, Error> for () {
67 fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
68 Poll::Ready(Ok(()))
69 }
70
71 fn start_send(self: Pin<&mut Self>, idx: usize, _item: Item) -> Result<(), Error> {
72 panic!("index out of bounds (len + {idx})");
73 }
74
75 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
76 Poll::Ready(Ok(()))
77 }
78
79 fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
80 Poll::Ready(Ok(()))
81 }
82}
83
/// Projects a pinned pair onto pinned references to each of its two fields.
///
/// This is structural pin projection for a plain 2-tuple. It lets the `(Si, Rest)`
/// implementation reach its head sink and the rest of the list while both stay pinned.
fn pin_project_pair<A, B>(pair: Pin<&mut (A, B)>) -> (Pin<&mut A>, Pin<&mut B>) {
    // SAFETY: the `&mut (A, B)` is used only to derive references to its two fields and
    // is never used to move the tuple or either field. Structural pinning of tuple
    // fields is sound for three reasons:
    // - tuples have no `Drop` impl that could move a field;
    // - tuples cannot be `#[repr(packed)]`;
    // - `(A, B): Unpin` holds (as an auto trait) only when both `A: Unpin` and
    //   `B: Unpin`, so safe code can never un-pin the tuple while a field is pinned.
    let (a, b) = unsafe { pair.get_unchecked_mut() };
    // SAFETY: `a` and `b` are fields of a value that stays pinned for the whole borrow
    // (see above), so they inherit its pinning guarantee.
    unsafe { (Pin::new_unchecked(a), Pin::new_unchecked(b)) }
}
91
pin_project! {
    /// Sink that accepts `(index, item)` pairs and forwards each `item` to the sink at
    /// `index` in a [`SinkVariadic`] list of sinks.
    #[must_use = "sinks do nothing unless polled"]
    pub struct DemuxVar<Sinks, Error> {
        // The variadic list of downstream sinks. It is pinned structurally so it can be
        // reached through `project()`.
        #[pin]
        sink: Sinks,
        // Carries the `Error` type parameter without storing one. Because `fn() -> Error`
        // is used, the struct does not own an `Error`, so auto traits and variance do not
        // depend on it the way they would for a stored `Error`.
        _marker: PhantomData<fn() -> Error>,
    }
}
102
103impl<Sinks, Error> DemuxVar<Sinks, Error> {
104 pub fn new<Item>(sinks: Sinks) -> Self
106 where
107 Self: Sink<Item>,
108 {
109 Self {
110 sink: sinks,
111 _marker: PhantomData,
112 }
113 }
114}
115
116impl<Sinks, Item, Error> Sink<(usize, Item)> for DemuxVar<Sinks, Error>
117where
118 Sinks: SinkVariadic<Item, Error>,
119{
120 type Error = Error;
121
122 fn start_send(self: Pin<&mut Self>, (idx, item): (usize, Item)) -> Result<(), Self::Error> {
123 self.project().sink.start_send(idx, item)
124 }
125
126 forward_sink!(poll_ready, poll_flush, poll_close);
127}
128
129pub fn demux_var<Sinks, Item, Error>(sinks: Sinks) -> DemuxVar<Sinks, Error>
131where
132 Sinks: SinkVariadic<Item, Error>,
133{
134 DemuxVar::new(sinks)
135}