dfir_rs/scheduled/
graph_ext.rs

1//! Helper extensions for [`Dfir`].
2
3use core::task;
4use std::borrow::Cow;
5use std::sync::mpsc::SyncSender;
6use std::task::Poll;
7
8use futures::Stream;
9
10use super::SubgraphId;
11use super::context::Context;
12use super::graph::Dfir;
13use super::handoff::{CanReceive, Handoff};
14use super::input::Input;
15use super::port::{RecvCtx, RecvPort, SendCtx, SendPort};
16
/// Generates one `add_subgraph_*` method with a fixed number of receive and send ports.
///
/// Both forms take a method name, a parenthesized list of `recv_param: RecvGeneric` pairs,
/// and a parenthesized list of `send_param: SendGeneric` pairs:
///
/// * `subgraph_ext!(name, (..), (..))` emits a documented trait method *declaration*.
/// * `subgraph_ext!(impl name, (..), (..))` emits the matching method *definition*. It
///   forwards to `self.add_subgraph`, packing the ports with `crate::var_expr!` and
///   unpacking them again in the closure arguments with `crate::var_args!`.
macro_rules! subgraph_ext {
    (
        $fn_name:ident,
        ( $($recv_param:ident : $recv_generic:ident),* ),
        ( $($send_param:ident : $send_generic:ident),* )
    ) => {
        // Stringifying the whole comma-separated list renders e.g. `R1, R2` with no trailing
        // separator; an empty port list renders as an empty string.
        /// Adds a subgraph with specific topology:
        ///
        #[doc = concat!("* Inputs: ", stringify!($($recv_generic),*))]
        #[doc = concat!("* Outputs: ", stringify!($($send_generic),*))]
        fn $fn_name <Name, F, $($recv_generic,)* $($send_generic),*> (
            &mut self,
            name: Name,
            $($recv_param : RecvPort< $recv_generic >,)*
            $($send_param : SendPort< $send_generic >,)* subgraph: F
        ) -> SubgraphId
        where
            Name: Into<Cow<'static, str>>,
            F: 'static + FnMut(&Context, $(&RecvCtx< $recv_generic >,)* $(&SendCtx< $send_generic >),*),
            $($recv_generic : 'static + Handoff,)*
            $($send_generic : 'static + Handoff,)*;
    };
    (
        impl
        $fn_name:ident,
        ( $($recv_param:ident : $recv_generic:ident),* ),
        ( $($send_param:ident : $send_generic:ident),* )
    ) => {
        // `mut` is a binding pattern. It is allowed here because this method has a body; the
        // bodiless declaration arm above must keep the plain `subgraph: F` parameter.
        fn $fn_name <Name, F, $($recv_generic,)* $($send_generic),*> (
            &mut self,
            name: Name,
            $($recv_param : RecvPort< $recv_generic >,)*
            $($send_param : SendPort< $send_generic >,)* mut subgraph: F
        ) -> SubgraphId
        where
            Name: Into<Cow<'static, str>>,
            F: 'static + FnMut(&Context, $(&RecvCtx< $recv_generic >,)* $(&SendCtx< $send_generic >),*),
            $($recv_generic : 'static + Handoff,)*
            $($send_generic : 'static + Handoff,)*
        {
            // Pack the individual ports into the variadic form `add_subgraph` takes. Then
            // unpack them in the closure so the user's closure gets one argument per port.
            self.add_subgraph(
                name,
                crate::var_expr!($($recv_param),*),
                crate::var_expr!($($send_param),*),
                move |ctx, crate::var_args!($($recv_param),*), crate::var_args!($($send_param),*)|
                    (subgraph)(ctx, $($recv_param,)* $($send_param),*),
            )
        }
    };
}
68
69/// Convenience extension methods for the [`Dfir`] struct.
70pub trait GraphExt {
71    subgraph_ext!(add_subgraph_sink, (recv_port: R), ());
72    subgraph_ext!(add_subgraph_2sink, (recv_port_1: R1, recv_port_2: R2), ());
73
74    subgraph_ext!(add_subgraph_source, (), (send_port: W));
75
76    subgraph_ext!(add_subgraph_in_out, (recv_port: R), (send_port: W));
77    subgraph_ext!(
78        add_subgraph_in_2out,
79        (recv_port: R),
80        (send_port_1: W1, send_port_2: W2)
81    );
82
83    subgraph_ext!(
84        add_subgraph_2in_out,
85        (recv_port_1: R1, recv_port_2: R2),
86        (send_port: W)
87    );
88    subgraph_ext!(
89        add_subgraph_2in_2out,
90        (recv_port_1: R1, recv_port_2: R2),
91        (send_port_1: W1, send_port_2: W2)
92    );
93
94    /// Adds a channel input which sends to the `send_port`.
95    fn add_channel_input<Name, T, W>(
96        &mut self,
97        name: Name,
98        send_port: SendPort<W>,
99    ) -> Input<T, SyncSender<T>>
100    where
101        Name: Into<Cow<'static, str>>,
102        T: 'static,
103        W: 'static + Handoff + CanReceive<T>;
104
105    /// Adds an "input" operator, returning a handle to insert data into it.
106    /// TODO(justin): make this thing work better
107    fn add_input<Name, T, W>(
108        &mut self,
109        name: Name,
110        send_port: SendPort<W>,
111    ) -> Input<T, super::input::Buffer<T>>
112    where
113        Name: Into<Cow<'static, str>>,
114        T: 'static,
115        W: 'static + Handoff + CanReceive<T>;
116
117    /// Adds a subgraph which pulls from the async stream and sends to the `send_port`.
118    fn add_input_from_stream<Name, T, W, S>(
119        &mut self,
120        name: Name,
121        send_port: SendPort<W>,
122        stream: S,
123    ) where
124        Name: Into<Cow<'static, str>>,
125        S: 'static + Stream<Item = T>,
126        W: 'static + Handoff + CanReceive<T>;
127}
128
129impl GraphExt for Dfir<'_> {
130    subgraph_ext!(impl add_subgraph_sink, (recv_port: R), ());
131    subgraph_ext!(
132        impl add_subgraph_2sink,
133        (recv_port_1: R1, recv_port_2: R2),
134        ()
135    );
136
137    subgraph_ext!(impl add_subgraph_source, (), (send_port: W));
138
139    subgraph_ext!(impl add_subgraph_in_out, (recv_port: R), (send_port: W));
140    subgraph_ext!(
141        impl add_subgraph_in_2out,
142        (recv_port: R),
143        (send_port_1: W1, send_port_2: W2)
144    );
145
146    subgraph_ext!(
147        impl add_subgraph_2in_out,
148        (recv_port_1: R1, recv_port_2: R2),
149        (send_port: W)
150    );
151    subgraph_ext!(
152        impl add_subgraph_2in_2out,
153        (recv_port_1: R1, recv_port_2: R2),
154        (send_port_1: W1, send_port_2: W2)
155    );
156
157    fn add_channel_input<Name, T, W>(
158        &mut self,
159        name: Name,
160        send_port: SendPort<W>,
161    ) -> Input<T, SyncSender<T>>
162    where
163        Name: Into<Cow<'static, str>>,
164        T: 'static,
165        W: 'static + Handoff + CanReceive<T>,
166    {
167        use std::sync::mpsc;
168
169        let (sender, receiver) = mpsc::sync_channel(8000);
170        let sg_id = self.add_subgraph_source::<_, _, W>(name, send_port, move |_ctx, send| {
171            for x in receiver.try_iter() {
172                send.give(x);
173            }
174        });
175        Input::new(self.reactor(), sg_id, sender)
176    }
177
178    fn add_input<Name, T, W>(
179        &mut self,
180        name: Name,
181        send_port: SendPort<W>,
182    ) -> Input<T, super::input::Buffer<T>>
183    where
184        Name: Into<Cow<'static, str>>,
185        T: 'static,
186        W: 'static + Handoff + CanReceive<T>,
187    {
188        let input = super::input::Buffer::default();
189        let inner_input = input.clone();
190        let sg_id = self.add_subgraph_source::<_, _, W>(name, send_port, move |_ctx, send| {
191            for x in (*inner_input.0).borrow_mut().drain(..) {
192                send.give(x);
193            }
194        });
195        Input::new(self.reactor(), sg_id, input)
196    }
197
198    fn add_input_from_stream<Name, T, W, S>(
199        &mut self,
200        name: Name,
201        send_port: SendPort<W>,
202        stream: S,
203    ) where
204        Name: Into<Cow<'static, str>>,
205        S: 'static + Stream<Item = T>,
206        W: 'static + Handoff + CanReceive<T>,
207    {
208        let mut stream = Box::pin(stream);
209        self.add_subgraph_source::<_, _, W>(name, send_port, move |ctx, send| {
210            let waker = ctx.waker();
211            let mut cx = task::Context::from_waker(&waker);
212            while let Poll::Ready(Some(v)) = stream.as_mut().poll_next(&mut cx) {
213                send.give(v);
214            }
215        });
216    }
217}