dfir_rs/scheduled/
graph_ext.rs

1//! Helper extensions for [`Dfir`].
2
3use core::task;
4use std::borrow::Cow;
5use std::sync::mpsc::SyncSender;
6use std::task::Poll;
7
8use futures::Stream;
9
10use super::SubgraphId;
11use super::context::Context;
12use super::graph::Dfir;
13use super::handoff::{CanReceive, Handoff};
14use super::input::Input;
15use super::port::{RecvCtx, RecvPort, SendCtx, SendPort};
16
// Generates the fixed-arity `add_subgraph_*` convenience methods.
//
// The macro has two arms that take the same arguments (method name, list of
// `port: HandoffGeneric` receive pairs, list of send pairs):
//
// * without a leading `impl` token it emits the documented, body-less method
//   declaration (for use inside the trait);
// * with a leading `impl` token it emits the method definition, which
//   forwards to `add_subgraph`.
macro_rules! subgraph_ext {
    // Declaration arm: signature + generated rustdoc only.
    (
        $method:ident,
        ( $($in_port:ident : $in_hoff:ident),* ),
        ( $($out_port:ident : $out_hoff:ident),* )
    ) => {
        /// Adds a subgraph with specific topology:
        ///
        #[doc = concat!("* Inputs: ", $( stringify!( $in_hoff ), ", ", )*)]
        #[doc = concat!("* Outputs: ", $( stringify!( $out_hoff ), ", ", )*)]
        fn $method <Name, F, $($in_hoff,)* $($out_hoff),*> (
            &mut self,
            name: Name,
            $($in_port : RecvPort< $in_hoff >,)*
            $($out_port : SendPort< $out_hoff >,)* subgraph: F
        ) -> SubgraphId
        where
            Name: Into<Cow<'static, str>>,
            F: 'static + FnMut(&Context, $(&RecvCtx< $in_hoff >,)* $(&SendCtx< $out_hoff >),*),
            $($in_hoff : 'static + Handoff,)*
            $($out_hoff : 'static + Handoff,)*;
    };
    // Definition arm: same signature, with a body that packs the individual
    // ports into the variadic lists expected by `add_subgraph`.
    (
        impl
        $method:ident,
        ( $($in_port:ident : $in_hoff:ident),* ),
        ( $($out_port:ident : $out_hoff:ident),* )
    ) => {
        fn $method <Name, F, $($in_hoff,)* $($out_hoff),*> (
            &mut self,
            name: Name,
            $($in_port : RecvPort< $in_hoff >,)*
            // Bound mutably right here because `F` is `FnMut` and is invoked
            // from inside the closure below.
            $($out_port : SendPort< $out_hoff >,)* mut subgraph: F
        ) -> SubgraphId
        where
            Name: Into<Cow<'static, str>>,
            F: 'static + FnMut(&Context, $(&RecvCtx< $in_hoff >,)* $(&SendCtx< $out_hoff >),*),
            $($in_hoff : 'static + Handoff,)*
            $($out_hoff : 'static + Handoff,)*
        {
            self.add_subgraph(
                name,
                crate::var_expr!($($in_port),*),
                crate::var_expr!($($out_port),*),
                // The port identifiers are reused as pattern bindings: inside
                // the closure they name the per-run contexts, which are passed
                // to the user callback as flat positional arguments.
                async move |context, crate::var_args!($($in_port),*), crate::var_args!($($out_port),*)| {
                    (subgraph)(context, $($in_port,)* $($out_port),*);
                }
            )
        }
    };
}
69
70/// Convenience extension methods for the [`Dfir`] struct.
71///
72/// Prefer DFIR syntax to avoid subgraph-per-op strucuring.
73#[sealed::sealed]
74pub trait GraphExt {
75    subgraph_ext!(add_subgraph_sink, (recv_port: R), ());
76    subgraph_ext!(add_subgraph_2sink, (recv_port_1: R1, recv_port_2: R2), ());
77
78    subgraph_ext!(add_subgraph_source, (), (send_port: W));
79
80    subgraph_ext!(add_subgraph_in_out, (recv_port: R), (send_port: W));
81    subgraph_ext!(
82        add_subgraph_in_2out,
83        (recv_port: R),
84        (send_port_1: W1, send_port_2: W2)
85    );
86
87    subgraph_ext!(
88        add_subgraph_2in_out,
89        (recv_port_1: R1, recv_port_2: R2),
90        (send_port: W)
91    );
92    subgraph_ext!(
93        add_subgraph_2in_2out,
94        (recv_port_1: R1, recv_port_2: R2),
95        (send_port_1: W1, send_port_2: W2)
96    );
97
98    /// Adds a channel input which sends to the `send_port`.
99    fn add_channel_input<Name, T, W>(
100        &mut self,
101        name: Name,
102        send_port: SendPort<W>,
103    ) -> Input<T, SyncSender<T>>
104    where
105        Name: Into<Cow<'static, str>>,
106        T: 'static,
107        W: 'static + Handoff + CanReceive<T>;
108
109    /// Adds an "input" operator, returning a handle to insert data into it.
110    /// TODO(justin): make this thing work better
111    fn add_input<Name, T, W>(
112        &mut self,
113        name: Name,
114        send_port: SendPort<W>,
115    ) -> Input<T, super::input::Buffer<T>>
116    where
117        Name: Into<Cow<'static, str>>,
118        T: 'static,
119        W: 'static + Handoff + CanReceive<T>;
120
121    /// Adds a subgraph which pulls from the async stream and sends to the `send_port`.
122    fn add_input_from_stream<Name, T, W, S>(
123        &mut self,
124        name: Name,
125        send_port: SendPort<W>,
126        stream: S,
127    ) where
128        Name: Into<Cow<'static, str>>,
129        S: 'static + Stream<Item = T>,
130        W: 'static + Handoff + CanReceive<T>;
131}
132
133#[sealed::sealed]
134impl GraphExt for Dfir<'_> {
135    subgraph_ext!(impl add_subgraph_sink, (recv_port: R), ());
136    subgraph_ext!(
137        impl add_subgraph_2sink,
138        (recv_port_1: R1, recv_port_2: R2),
139        ()
140    );
141
142    subgraph_ext!(impl add_subgraph_source, (), (send_port: W));
143
144    subgraph_ext!(impl add_subgraph_in_out, (recv_port: R), (send_port: W));
145    subgraph_ext!(
146        impl add_subgraph_in_2out,
147        (recv_port: R),
148        (send_port_1: W1, send_port_2: W2)
149    );
150
151    subgraph_ext!(
152        impl add_subgraph_2in_out,
153        (recv_port_1: R1, recv_port_2: R2),
154        (send_port: W)
155    );
156    subgraph_ext!(
157        impl add_subgraph_2in_2out,
158        (recv_port_1: R1, recv_port_2: R2),
159        (send_port_1: W1, send_port_2: W2)
160    );
161
162    fn add_channel_input<Name, T, W>(
163        &mut self,
164        name: Name,
165        send_port: SendPort<W>,
166    ) -> Input<T, SyncSender<T>>
167    where
168        Name: Into<Cow<'static, str>>,
169        T: 'static,
170        W: 'static + Handoff + CanReceive<T>,
171    {
172        use std::sync::mpsc;
173
174        let (sender, receiver) = mpsc::sync_channel(8000);
175        let sg_id = self.add_subgraph_source::<_, _, W>(name, send_port, move |_ctx, send| {
176            for x in receiver.try_iter() {
177                send.give(x);
178            }
179        });
180        Input::new(self.reactor(), sg_id, sender)
181    }
182
183    fn add_input<Name, T, W>(
184        &mut self,
185        name: Name,
186        send_port: SendPort<W>,
187    ) -> Input<T, super::input::Buffer<T>>
188    where
189        Name: Into<Cow<'static, str>>,
190        T: 'static,
191        W: 'static + Handoff + CanReceive<T>,
192    {
193        let input = super::input::Buffer::default();
194        let inner_input = input.clone();
195        let sg_id = self.add_subgraph_source::<_, _, W>(name, send_port, move |_ctx, send| {
196            for x in (*inner_input.0).borrow_mut().drain(..) {
197                send.give(x);
198            }
199        });
200        Input::new(self.reactor(), sg_id, input)
201    }
202
203    fn add_input_from_stream<Name, T, W, S>(
204        &mut self,
205        name: Name,
206        send_port: SendPort<W>,
207        stream: S,
208    ) where
209        Name: Into<Cow<'static, str>>,
210        S: 'static + Stream<Item = T>,
211        W: 'static + Handoff + CanReceive<T>,
212    {
213        let mut stream = Box::pin(stream);
214        self.add_subgraph_source::<_, _, W>(name, send_port, move |ctx, send| {
215            let waker = ctx.waker();
216            let mut cx = task::Context::from_waker(&waker);
217            while let Poll::Ready(Some(v)) = stream.as_mut().poll_next(&mut cx) {
218                send.give(v);
219            }
220        });
221    }
222}