hydro_deploy/
util.rs

1use std::sync::{Arc, Mutex, Weak};
2use std::time::Duration;
3
4use futures::{Future, Stream, StreamExt};
5use tokio::sync::{mpsc, oneshot};
6
7pub async fn async_retry<T, E, F: Future<Output = Result<T, E>>>(
8    mut thunk: impl FnMut() -> F,
9    count: usize,
10    delay: Duration,
11) -> Result<T, E> {
12    for _ in 1..count {
13        let result = thunk().await;
14        if result.is_ok() {
15            return result;
16        } else {
17            tokio::time::sleep(delay).await;
18        }
19    }
20
21    thunk().await
22}
23
24#[derive(Clone)]
25pub struct PriorityBroadcast(Weak<Mutex<PriorityBroadcastInternal>>);
26
27struct PriorityBroadcastInternal {
28    priority_sender: Option<oneshot::Sender<String>>,
29    senders: Vec<(Option<String>, mpsc::UnboundedSender<String>)>,
30}
31
32impl PriorityBroadcast {
33    pub fn receive_priority(&self) -> oneshot::Receiver<String> {
34        let (sender, receiver) = oneshot::channel::<String>();
35
36        if let Some(internal) = self.0.upgrade() {
37            let mut internal = internal.lock().unwrap();
38            let prev_sender = internal.priority_sender.replace(sender);
39            if prev_sender.is_some() {
40                panic!("Only one deploy stdout receiver is allowed at a time");
41            }
42        }
43
44        receiver
45    }
46
47    pub fn receive(&self, prefix: Option<String>) -> mpsc::UnboundedReceiver<String> {
48        let (sender, receiver) = mpsc::unbounded_channel::<String>();
49
50        if let Some(internal) = self.0.upgrade() {
51            let mut internal = internal.lock().unwrap();
52            internal.senders.push((prefix, sender));
53        }
54
55        receiver
56    }
57}
58
59pub fn prioritized_broadcast<T: Stream<Item = std::io::Result<String>> + Send + Unpin + 'static>(
60    mut lines: T,
61    fallback_receiver: impl Fn(String) + Send + 'static,
62) -> PriorityBroadcast {
63    let internal = Arc::new(Mutex::new(PriorityBroadcastInternal {
64        priority_sender: None,
65        senders: Vec::new(),
66    }));
67
68    let weak_internal = Arc::downgrade(&internal);
69
70    // TODO(mingwei): eliminate the need for a separate task.
71    tokio::spawn(async move {
72        while let Some(Ok(line)) = lines.next().await {
73            let mut internal = internal.lock().unwrap();
74
75            // Priority receiver
76            if let Some(priority_sender) = internal.priority_sender.take()
77                && priority_sender.send(line.clone()).is_ok()
78            {
79                continue; // Skip regular receivers if successfully sent to the priority receiver.
80            }
81
82            // Regular receivers
83            internal.senders.retain(|receiver| !receiver.1.is_closed());
84
85            let mut successful_send = false;
86            for (prefix_filter, sender) in internal.senders.iter() {
87                // Send to specific receivers if the filter prefix matches
88                if prefix_filter
89                    .as_ref()
90                    .is_none_or(|prefix| line.starts_with(prefix))
91                {
92                    successful_send |= sender.send(line.clone()).is_ok();
93                }
94            }
95
96            // If no receivers successfully received the line, use the fallback receiver.
97            if !successful_send {
98                (fallback_receiver)(line);
99            }
100        }
101        // Dropping `internal` will close all senders because it is the only strong `Arc` reference.
102    });
103
104    PriorityBroadcast(weak_internal)
105}
106
#[cfg(test)]
mod test {
    use tokio::sync::mpsc;
    use tokio_stream::wrappers::UnboundedReceiverStream;

    use super::*;

    /// A regular listener receives forwarded lines and observes end-of-stream
    /// once the source channel is dropped.
    #[tokio::test]
    async fn broadcast_listeners_close_when_source_does() {
        let (source, source_rx) = mpsc::unbounded_channel();
        let broadcast = prioritized_broadcast(UnboundedReceiverStream::new(source_rx), |_| {});

        let mut listener = broadcast.receive(None);

        // A line pushed into the source reaches the unfiltered listener.
        source.send(Ok(String::from("hello"))).unwrap();
        assert_eq!(listener.recv().await.as_deref(), Some("hello"));

        // Start a second `recv` on its own task before closing the source.
        let pending_recv = tokio::spawn(async move { listener.recv().await });

        drop(source);

        // Closing the source must end the listener's stream.
        assert!(pending_recv.await.unwrap().is_none());
    }
}