hydro_deploy/
util.rs

1use std::io;
2use std::sync::{Arc, Mutex};
3use std::time::Duration;
4
5use anyhow::Result;
6use futures::{Future, Stream, StreamExt};
7use tokio::sync::oneshot;
8
9use crate::ssh::PrefixFilteredChannel;
10
11pub async fn async_retry<T, F: Future<Output = Result<T>>>(
12    mut thunk: impl FnMut() -> F,
13    count: usize,
14    delay: Duration,
15) -> Result<T> {
16    for _ in 1..count {
17        let result = thunk().await;
18        if result.is_ok() {
19            return result;
20        } else {
21            tokio::time::sleep(delay).await;
22        }
23    }
24
25    thunk().await
26}
27
28type PriorityBroadcacst = (
29    Arc<Mutex<Option<oneshot::Sender<String>>>>,
30    Arc<Mutex<Vec<PrefixFilteredChannel>>>,
31);
32
33pub fn prioritized_broadcast<T: Stream<Item = io::Result<String>> + Send + Unpin + 'static>(
34    mut lines: T,
35    default: impl Fn(String) + Send + 'static,
36) -> PriorityBroadcacst {
37    let priority_receivers = Arc::new(Mutex::new(None::<oneshot::Sender<String>>));
38    // Option<String> is the prefix to separate special stdout messages from regular ones
39    let receivers = Arc::new(Mutex::new(Vec::<PrefixFilteredChannel>::new()));
40
41    let weak_priority_receivers = Arc::downgrade(&priority_receivers);
42    let weak_receivers = Arc::downgrade(&receivers);
43
44    tokio::spawn(async move {
45        while let Some(Result::Ok(line)) = lines.next().await {
46            if let Some(deploy_receivers) = weak_priority_receivers.upgrade() {
47                let mut deploy_receivers = deploy_receivers.lock().unwrap();
48
49                let successful_send = if let Some(r) = deploy_receivers.take() {
50                    r.send(line.clone()).is_ok()
51                } else {
52                    false
53                };
54                drop(deploy_receivers);
55
56                if successful_send {
57                    continue;
58                }
59            }
60
61            if let Some(receivers) = weak_receivers.upgrade() {
62                let mut receivers = receivers.lock().unwrap();
63                receivers.retain(|receiver| !receiver.1.is_closed());
64
65                let mut successful_send = false;
66                // Send to specific receivers if the filter prefix matches
67                for (prefix_filter, receiver) in receivers.iter() {
68                    if prefix_filter
69                        .as_ref()
70                        .map(|prefix| line.starts_with(prefix))
71                        .unwrap_or(true)
72                    {
73                        successful_send |= receiver.send(line.clone()).is_ok();
74                    }
75                }
76                if !successful_send {
77                    (default)(line);
78                }
79            } else {
80                break;
81            }
82        }
83
84        if let Some(deploy_receivers) = weak_priority_receivers.upgrade() {
85            let mut deploy_receivers = deploy_receivers.lock().unwrap();
86            drop(deploy_receivers.take());
87        }
88
89        if let Some(receivers) = weak_receivers.upgrade() {
90            let mut receivers = receivers.lock().unwrap();
91            receivers.clear();
92        }
93    });
94
95    (priority_receivers, receivers)
96}
97
#[cfg(test)]
mod test {
    use tokio::sync::mpsc;
    use tokio_stream::wrappers::UnboundedReceiverStream;

    use super::*;

    /// Once the source stream ends, the broadcast task must clear its listener list so that
    /// every registered listener sees its channel close.
    #[tokio::test]
    async fn broadcast_listeners_close_when_source_does() {
        let (source_tx, source_rx) = mpsc::unbounded_channel();
        let (_, listeners) =
            prioritized_broadcast(UnboundedReceiverStream::new(source_rx), |_| {});

        let (listener_tx, mut listener_rx) = mpsc::unbounded_channel();
        listeners.lock().unwrap().push((None, listener_tx));

        // A line pushed into the source reaches the unfiltered listener.
        source_tx.send(Ok("hello".to_string())).unwrap();
        assert_eq!(listener_rx.recv().await, Some("hello".to_string()));

        let next_line = tokio::spawn(async move { listener_rx.recv().await });

        // Dropping the only sender ends the source stream.
        drop(source_tx);

        assert_eq!(next_line.await.unwrap(), None);
    }
}