sinktools/
demux_map_lazy.rs1use core::fmt::Debug;
3use core::hash::Hash;
4use core::pin::Pin;
5use core::task::{Context, Poll};
6use std::collections::HashMap;
7
8use crate::{Sink, ready_both};
9
/// Sink that routes `(Key, Item)` pairs to a per-key inner sink, creating each inner sink
/// lazily (by calling `func`) the first time its key is seen.
pub struct LazyDemuxSink<Key, Si, Func> {
    /// Inner sinks created so far, one per distinct key.
    sinks: HashMap<Key, Si>,
    /// Factory called with a reference to a not-yet-seen key to build that key's sink.
    func: Func,
}

/// Formats the per-key sinks. The factory is omitted because `Func` is typically a closure,
/// and closures do not implement `Debug`.
impl<Key: Debug, Si: Debug, Func> Debug for LazyDemuxSink<Key, Si, Func> {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.debug_struct("LazyDemuxSink")
            .field("sinks", &self.sinks)
            .finish_non_exhaustive()
    }
}
15
16impl<Key, Si, Func> LazyDemuxSink<Key, Si, Func> {
17 pub fn new<Item>(func: Func) -> Self
19 where
20 Self: Sink<(Key, Item)>,
21 {
22 Self {
23 sinks: HashMap::new(),
24 func,
25 }
26 }
27}
28
29impl<Key, Si, Item, Func> Sink<(Key, Item)> for LazyDemuxSink<Key, Si, Func>
30where
31 Key: Eq + Hash + Debug + Unpin,
32 Si: Sink<Item> + Unpin,
33 Func: FnMut(&Key) -> Si + Unpin,
34{
35 type Error = Si::Error;
36
37 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
38 self.get_mut()
39 .sinks
40 .values_mut()
41 .try_fold(Poll::Ready(()), |poll, sink| {
42 ready_both!(poll, Pin::new(sink).poll_ready(cx)?);
43 Poll::Ready(Ok(()))
44 })
45 }
46
47 fn start_send(self: Pin<&mut Self>, item: (Key, Item)) -> Result<(), Self::Error> {
48 let this = self.get_mut();
49 let sink = this
50 .sinks
51 .entry(item.0)
52 .or_insert_with_key(|k| (this.func)(k));
53 Pin::new(sink).start_send(item.1)
54 }
55
56 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
57 self.get_mut()
58 .sinks
59 .values_mut()
60 .try_fold(Poll::Ready(()), |poll, sink| {
61 ready_both!(poll, Pin::new(sink).poll_flush(cx)?);
62 Poll::Ready(Ok(()))
63 })
64 }
65
66 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
67 self.get_mut()
68 .sinks
69 .values_mut()
70 .try_fold(Poll::Ready(()), |poll, sink| {
71 ready_both!(poll, Pin::new(sink).poll_close(cx)?);
72 Poll::Ready(Ok(()))
73 })
74 }
75}
76
77pub fn demux_map_lazy<Key, Si, Item, Func>(func: Func) -> LazyDemuxSink<Key, Si, Func>
81where
82 Key: Eq + Hash + Debug + Unpin,
83 Si: Sink<Item> + Unpin,
84 Func: FnMut(&Key) -> Si + Unpin,
85{
86 LazyDemuxSink::new(func)
87}
88
#[cfg(test)]
mod test {
    use core::cell::RefCell;
    use core::pin::pin;
    use std::collections::HashMap;
    use std::rc::Rc;

    use futures_util::SinkExt;

    use super::*;
    use crate::for_each::ForEach;

    /// Per-key concatenation of every byte slice delivered to the demuxed sinks.
    type Outputs = Rc<RefCell<HashMap<String, Vec<u8>>>>;

    #[tokio::test]
    async fn test_lazy_demux_sink() {
        let outputs: Outputs = Rc::new(RefCell::new(HashMap::new()));
        let outputs_clone = outputs.clone();

        let mut sink = demux_map_lazy(move |key: &String| {
            let key = key.clone();
            let outputs = outputs_clone.clone();
            ForEach::new(move |item: &[u8]| {
                outputs
                    .borrow_mut()
                    .entry(key.clone())
                    .or_default()
                    .extend_from_slice(item);
            })
        });

        sink.send(("a".to_owned(), b"test1".as_slice()))
            .await
            .unwrap();
        sink.send(("b".to_owned(), b"test2".as_slice()))
            .await
            .unwrap();
        sink.send(("a".to_owned(), b"test3".as_slice()))
            .await
            .unwrap();
        sink.flush().await.unwrap();
        sink.close().await.unwrap();

        let outputs = outputs.borrow();
        assert_eq!(outputs.get("a").unwrap().as_slice(), b"test1test3");
        assert_eq!(outputs.get("b").unwrap().as_slice(), b"test2");
    }

    /// Drives the sink by hand, following the `Sink` protocol: a successful `poll_ready`
    /// before every `start_send`.
    #[test]
    fn test_lazy_demux_sink_good() {
        let outputs: Outputs = Rc::new(RefCell::new(HashMap::new()));
        let outputs_clone = outputs.clone();

        let mut sink = pin!(demux_map_lazy(move |key: &String| {
            let outputs = outputs_clone.clone();
            let key = key.clone();
            ForEach::new(move |item: &[u8]| {
                outputs
                    .borrow_mut()
                    .entry(key.clone())
                    .or_default()
                    .extend_from_slice(item);
            })
        }));

        let cx = &mut Context::from_waker(futures_task::noop_waker_ref());

        for (key, item) in [
            ("a", b"test1".as_slice()),
            ("b", b"test2".as_slice()),
            ("a", b"test3".as_slice()),
        ] {
            assert_eq!(Poll::Ready(Ok(())), sink.as_mut().poll_ready(cx));
            assert_eq!(Ok(()), sink.as_mut().start_send((key.to_owned(), item)));
        }
        assert_eq!(Poll::Ready(Ok(())), sink.as_mut().poll_flush(cx));
        assert_eq!(Poll::Ready(Ok(())), sink.as_mut().poll_close(cx));

        let outputs = outputs.borrow();
        assert_eq!(outputs.get("a").unwrap().as_slice(), b"test1test3");
        assert_eq!(outputs.get("b").unwrap().as_slice(), b"test2");
    }

    /// The factory must run exactly once per distinct key, and only when that key first arrives.
    #[test]
    fn test_lazy_demux_sink_creates_each_sink_once() {
        let created: Rc<RefCell<Vec<String>>> = Rc::new(RefCell::new(Vec::new()));
        let created_clone = created.clone();

        let mut sink = pin!(demux_map_lazy(move |key: &String| {
            created_clone.borrow_mut().push(key.clone());
            ForEach::new(|_item: &[u8]| {})
        }));

        let cx = &mut Context::from_waker(futures_task::noop_waker_ref());

        // Nothing has been sent yet, so no sink may have been built.
        assert_eq!(Poll::Ready(Ok(())), sink.as_mut().poll_ready(cx));
        assert!(created.borrow().is_empty());

        for key in ["a", "b", "a", "a", "c", "b"] {
            assert_eq!(Poll::Ready(Ok(())), sink.as_mut().poll_ready(cx));
            assert_eq!(
                Ok(()),
                sink.as_mut().start_send((key.to_owned(), b"x".as_slice()))
            );
        }
        assert_eq!(Poll::Ready(Ok(())), sink.as_mut().poll_close(cx));

        // One factory call per distinct key, in first-seen order.
        assert_eq!(*created.borrow(), ["a", "b", "c"]);
    }
}