1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
use std::io;
use std::sync::{Arc, Mutex};
use std::time::Duration;

use anyhow::Result;
use futures::{Future, Stream, StreamExt};
use tokio::sync::{mpsc, oneshot};

/// Runs `thunk` until it yields `Ok`, making at most `count` attempts.
///
/// `thunk` is called once per attempt to build a fresh future. After every
/// failed attempt except the last one, the task sleeps for `delay` before
/// trying again. The outcome of the final attempt is returned unchanged, so on
/// total failure the caller sees the error of the last attempt.
///
/// At least one attempt is always made, even when `count` is `0`.
pub async fn async_retry<T, F: Future<Output = Result<T>>>(
    mut thunk: impl FnMut() -> F,
    count: usize,
    delay: Duration,
) -> Result<T> {
    // Number of additional attempts still allowed after the current one.
    let mut retries_left = count.saturating_sub(1);

    loop {
        let outcome = thunk().await;
        if outcome.is_ok() || retries_left == 0 {
            return outcome;
        }

        retries_left -= 1;
        tokio::time::sleep(delay).await;
    }
}

/// Pair of shared handles returned by [`prioritized_broadcast`].
///
/// * `.0` - slot holding an optional one-shot sender. When the slot is filled,
///   the broadcast task takes the sender out and offers it the next line
///   before any regular subscriber sees it.
/// * `.1` - list of unbounded senders; each one is sent a copy of every line
///   that was not consumed by the one-shot sender.
// NOTE(review): the alias name is misspelled ("Broadcacst"). Renaming it also
// requires updating the return type of `prioritized_broadcast`.
type PriorityBroadcacst = (
    Arc<Mutex<Option<oneshot::Sender<String>>>>,
    Arc<Mutex<Vec<mpsc::UnboundedSender<String>>>>,
);

/// Spawns a Tokio task that fans out every line read from `lines`.
///
/// For each `Ok` line the task proceeds as follows:
///
/// 1. If a one-shot "priority" sender is currently installed in the first
///    returned handle, it is taken out of the slot and offered the line. When
///    that send succeeds the line is consumed and nobody else sees it.
/// 2. Otherwise closed subscribers are pruned from the second returned handle
///    and a copy of the line is sent to every remaining subscriber.
/// 3. If no subscriber accepted the line, it is passed to `default`.
///
/// The task stops when the stream ends, when the stream yields an `Err`, or
/// when the subscriber handle has been dropped by all owners (the line being
/// processed at that moment is discarded). On exit it drops any pending
/// priority sender and clears the subscriber list, so listeners observe their
/// channels closing.
///
/// Only weak references to the two handles are kept by the task, so the
/// returned `Arc`s are what keeps the shared state alive.
///
/// # Panics
///
/// Uses `tokio::spawn`, so it must be called from within a Tokio runtime.
/// The spawned task panics if either mutex has been poisoned.
pub fn prioritized_broadcast<T: Stream<Item = io::Result<String>> + Send + Unpin + 'static>(
    mut lines: T,
    default: impl Fn(String) + Send + 'static,
) -> PriorityBroadcacst {
    let priority_receivers = Arc::new(Mutex::new(None::<oneshot::Sender<String>>));
    let receivers = Arc::new(Mutex::new(Vec::<mpsc::UnboundedSender<String>>::new()));

    let weak_priority_receivers = Arc::downgrade(&priority_receivers);
    let weak_receivers = Arc::downgrade(&receivers);

    tokio::spawn(async move {
        while let Some(Result::Ok(line)) = lines.next().await {
            // Offer the line to the priority listener first. `oneshot::Sender::send`
            // hands the value back on failure, so the line is moved rather than
            // cloned and recovered if the listener has already gone away.
            let line = if let Some(priority_slot) = weak_priority_receivers.upgrade() {
                // Taking the sender in its own statement releases the lock
                // before the send happens.
                let pending = priority_slot.lock().unwrap().take();
                match pending {
                    Some(sender) => match sender.send(line) {
                        Ok(()) => continue,
                        Err(unsent) => unsent,
                    },
                    None => line,
                }
            } else {
                line
            };

            let delivered = if let Some(receivers) = weak_receivers.upgrade() {
                let mut receivers = receivers.lock().unwrap();
                receivers.retain(|receiver| !receiver.is_closed());

                // `|=` (not `||`) so that every subscriber is sent the line
                // even after one send has already succeeded.
                let mut delivered = false;
                for receiver in receivers.iter() {
                    delivered |= receiver.send(line.clone()).is_ok();
                }
                delivered
            } else {
                // Nobody holds the subscriber list any more: stop broadcasting.
                break;
            };

            // The subscriber lock is released at this point, so the fallback
            // can be slow or touch the subscriber list without blocking or
            // deadlocking against this task.
            if !delivered {
                (default)(line);
            }
        }

        // Drop every sender that is still registered so that listeners see
        // their channel close once the source is finished.
        if let Some(priority_slot) = weak_priority_receivers.upgrade() {
            let mut priority_slot = priority_slot.lock().unwrap();
            drop(priority_slot.take());
        }

        if let Some(receivers) = weak_receivers.upgrade() {
            let mut receivers = receivers.lock().unwrap();
            receivers.clear();
        }
    });

    (priority_receivers, receivers)
}

#[cfg(test)]
mod test {
    use tokio_stream::wrappers::UnboundedReceiverStream;

    use super::*;

    /// Once the line source is closed, a registered listener must see its own
    /// channel close (`recv` resolves to `None`) instead of waiting forever.
    #[tokio::test]
    async fn broadcast_listeners_close_when_source_does() {
        let (source_tx, source_rx) = mpsc::unbounded_channel();
        let source = UnboundedReceiverStream::new(source_rx);
        let (_, listeners) = prioritized_broadcast(source, |_| {});

        // Register a single regular listener.
        let (listener_tx, mut listener_rx) = mpsc::unbounded_channel();
        listeners.lock().unwrap().push(listener_tx);

        // A line sent while the source is open reaches the listener.
        source_tx.send(Ok("hello".to_string())).unwrap();
        let first = listener_rx.recv().await;
        assert_eq!(first, Some("hello".to_string()));

        // Start waiting for another line, then close the source.
        let pending_recv = tokio::spawn(async move { listener_rx.recv().await });
        drop(source_tx);

        let second = pending_recv.await.unwrap();
        assert_eq!(second, None);
    }
}