1use std::sync::{Arc, Mutex, Weak};
2use std::time::Duration;
3
4use futures::{Future, Stream, StreamExt};
5use tokio::sync::{mpsc, oneshot};
6
7pub async fn async_retry<T, E, F: Future<Output = Result<T, E>>>(
8 mut thunk: impl FnMut() -> F,
9 count: usize,
10 delay: Duration,
11) -> Result<T, E> {
12 for _ in 1..count {
13 let result = thunk().await;
14 if result.is_ok() {
15 return result;
16 } else {
17 tokio::time::sleep(delay).await;
18 }
19 }
20
21 thunk().await
22}
23
24#[derive(Clone)]
25pub struct PriorityBroadcast(Weak<Mutex<PriorityBroadcastInternal>>);
26
27struct PriorityBroadcastInternal {
28 priority_sender: Option<oneshot::Sender<String>>,
29 senders: Vec<(Option<String>, mpsc::UnboundedSender<String>)>,
30}
31
32impl PriorityBroadcast {
33 pub fn receive_priority(&self) -> oneshot::Receiver<String> {
34 let (sender, receiver) = oneshot::channel::<String>();
35
36 if let Some(internal) = self.0.upgrade() {
37 let mut internal = internal.lock().unwrap();
38 let prev_sender = internal.priority_sender.replace(sender);
39 if prev_sender.is_some() {
40 panic!("Only one deploy stdout receiver is allowed at a time");
41 }
42 }
43
44 receiver
45 }
46
47 pub fn receive(&self, prefix: Option<String>) -> mpsc::UnboundedReceiver<String> {
48 let (sender, receiver) = mpsc::unbounded_channel::<String>();
49
50 if let Some(internal) = self.0.upgrade() {
51 let mut internal = internal.lock().unwrap();
52 internal.senders.push((prefix, sender));
53 }
54
55 receiver
56 }
57}
58
59pub fn prioritized_broadcast<T: Stream<Item = std::io::Result<String>> + Send + Unpin + 'static>(
60 mut lines: T,
61 fallback_receiver: impl Fn(String) + Send + 'static,
62) -> PriorityBroadcast {
63 let internal = Arc::new(Mutex::new(PriorityBroadcastInternal {
64 priority_sender: None,
65 senders: Vec::new(),
66 }));
67
68 let weak_internal = Arc::downgrade(&internal);
69
70 tokio::spawn(async move {
72 while let Some(Ok(line)) = lines.next().await {
73 let mut internal = internal.lock().unwrap();
74
75 if let Some(priority_sender) = internal.priority_sender.take()
77 && priority_sender.send(line.clone()).is_ok()
78 {
79 continue; }
81
82 internal.senders.retain(|receiver| !receiver.1.is_closed());
84
85 let mut successful_send = false;
86 for (prefix_filter, sender) in internal.senders.iter() {
87 if prefix_filter
89 .as_ref()
90 .is_none_or(|prefix| line.starts_with(prefix))
91 {
92 successful_send |= sender.send(line.clone()).is_ok();
93 }
94 }
95
96 if !successful_send {
98 (fallback_receiver)(line);
99 }
100 }
101 });
103
104 PriorityBroadcast(weak_internal)
105}
106
#[cfg(test)]
mod test {
    use tokio::sync::mpsc;
    use tokio_stream::wrappers::UnboundedReceiverStream;

    use super::*;

    /// Once the source stream ends, the broadcast task exits and drops every
    /// listener's sender, so a pending `recv` resolves to `None`.
    #[tokio::test]
    async fn broadcast_listeners_close_when_source_does() {
        let (source_tx, source_rx) = mpsc::unbounded_channel();
        let broadcast = prioritized_broadcast(UnboundedReceiverStream::new(source_rx), |_| {});

        let mut listener = broadcast.receive(None);

        source_tx.send(Ok("hello".to_string())).unwrap();
        let first_line = listener.recv().await;
        assert_eq!(first_line, Some("hello".to_string()));

        let pending_recv = tokio::spawn(async move { listener.recv().await });

        drop(source_tx);

        let after_close = pending_recv.await.unwrap();
        assert_eq!(after_close, None);
    }
}
130}