dfir_rs/scheduled/
graph_ext.rs

1//! Helper extensions for [`Dfir`].
2
3use core::task;
4use std::borrow::Cow;
5use std::sync::mpsc::SyncSender;
6use std::task::Poll;
7
8use futures::Stream;
9
10use super::SubgraphId;
11use super::context::Context;
12use super::graph::Dfir;
13use super::handoff::{CanReceive, Handoff};
14use super::input::Input;
15use super::port::{RecvCtx, RecvPort, SendCtx, SendPort};
16
// Expands to a string literal listing the given generic parameter names for the generated
// rustdoc, e.g. "`R1`, `R2`", or "none" when the list is empty. This avoids the trailing
// separator (and the blank entry for zero ports) that a plain `$( ..., ", ", )*` repetition
// leaves behind.
macro_rules! subgraph_ext_doc_list {
    () => {
        "none"
    };
    ($first:ident $(, $rest:ident)*) => {
        concat!("`", stringify!($first), "`", $( ", `", stringify!($rest), "`", )*)
    };
}

// Generates `GraphExt` convenience methods for a subgraph with a fixed topology.
//
// Two forms:
// * `subgraph_ext!(name, (recv: R, ...), (send: W, ...))` emits the trait method
//   *declaration*, with generated docs listing the input/output handoff generics.
// * `subgraph_ext!(impl name, (recv: R, ...), (send: W, ...))` emits the method
//   *definition*. It forwards to `self.add_subgraph`, packing the ports with
//   `crate::var_expr!` / `crate::var_args!`, and wraps the synchronous `subgraph` closure in a
//   closure that returns an already-ready future.
macro_rules! subgraph_ext {
    (
        $fn_name:ident,
        ( $($recv_param:ident : $recv_generic:ident),* ),
        ( $($send_param:ident : $send_generic:ident),* )
    ) => {
        /// Adds a subgraph with specific topology:
        ///
        #[doc = concat!("* Inputs: ", subgraph_ext_doc_list!($($recv_generic),*))]
        #[doc = concat!("* Outputs: ", subgraph_ext_doc_list!($($send_generic),*))]
        fn $fn_name <Name, F, $($recv_generic,)* $($send_generic),*> (
            &mut self,
            name: Name,
            $($recv_param : RecvPort< $recv_generic >,)*
            $($send_param : SendPort< $send_generic >,)* subgraph: F
        ) -> SubgraphId
        where
            Name: Into<Cow<'static, str>>,
            F: 'static + FnMut(&Context, $(&RecvCtx< $recv_generic >,)* $(&SendCtx< $send_generic >),*),
            $($recv_generic : 'static + Handoff,)*
            $($send_generic : 'static + Handoff,)*;
    };
    (
        impl
        $fn_name:ident,
        ( $($recv_param:ident : $recv_generic:ident),* ),
        ( $($send_param:ident : $send_generic:ident),* )
    ) => {
        fn $fn_name <Name, F, $($recv_generic,)* $($send_generic),*> (
            &mut self,
            name: Name,
            $($recv_param : RecvPort< $recv_generic >,)*
            $($send_param : SendPort< $send_generic >,)*
            mut subgraph: F,
        ) -> SubgraphId
        where
            Name: Into<Cow<'static, str>>,
            F: 'static + FnMut(&Context, $(&RecvCtx< $recv_generic >,)* $(&SendCtx< $send_generic >),*),
            $($recv_generic : 'static + Handoff,)*
            $($send_generic : 'static + Handoff,)*
        {
            self.add_subgraph(
                name,
                crate::var_expr!($($recv_param),*),
                crate::var_expr!($($send_param),*),
                move |ctx, crate::var_args!($($recv_param),*), crate::var_args!($($send_param),*)| {
                    // The closure handed to `add_subgraph` returns a future, but the user's
                    // closure is synchronous: run it, then return a future that is already
                    // complete.
                    subgraph(ctx, $($recv_param,)* $($send_param),*);
                    ::std::future::ready(())
                },
            )
        }
    };
}
70
71/// Convenience extension methods for the [`Dfir`] struct.
72///
73/// Prefer DFIR syntax to avoid subgraph-per-op strucuring.
74#[sealed::sealed]
75pub trait GraphExt {
76    subgraph_ext!(add_subgraph_sink, (recv_port: R), ());
77    subgraph_ext!(add_subgraph_2sink, (recv_port_1: R1, recv_port_2: R2), ());
78
79    subgraph_ext!(add_subgraph_source, (), (send_port: W));
80
81    subgraph_ext!(add_subgraph_in_out, (recv_port: R), (send_port: W));
82    subgraph_ext!(
83        add_subgraph_in_2out,
84        (recv_port: R),
85        (send_port_1: W1, send_port_2: W2)
86    );
87
88    subgraph_ext!(
89        add_subgraph_2in_out,
90        (recv_port_1: R1, recv_port_2: R2),
91        (send_port: W)
92    );
93    subgraph_ext!(
94        add_subgraph_2in_2out,
95        (recv_port_1: R1, recv_port_2: R2),
96        (send_port_1: W1, send_port_2: W2)
97    );
98
99    /// Adds a channel input which sends to the `send_port`.
100    fn add_channel_input<Name, T, W>(
101        &mut self,
102        name: Name,
103        send_port: SendPort<W>,
104    ) -> Input<T, SyncSender<T>>
105    where
106        Name: Into<Cow<'static, str>>,
107        T: 'static,
108        W: 'static + Handoff + CanReceive<T>;
109
110    /// Adds an "input" operator, returning a handle to insert data into it.
111    /// TODO(justin): make this thing work better
112    fn add_input<Name, T, W>(
113        &mut self,
114        name: Name,
115        send_port: SendPort<W>,
116    ) -> Input<T, super::input::Buffer<T>>
117    where
118        Name: Into<Cow<'static, str>>,
119        T: 'static,
120        W: 'static + Handoff + CanReceive<T>;
121
122    /// Adds a subgraph which pulls from the async stream and sends to the `send_port`.
123    fn add_input_from_stream<Name, T, W, S>(
124        &mut self,
125        name: Name,
126        send_port: SendPort<W>,
127        stream: S,
128    ) where
129        Name: Into<Cow<'static, str>>,
130        S: 'static + Stream<Item = T>,
131        W: 'static + Handoff + CanReceive<T>;
132}
133
134#[sealed::sealed]
135impl GraphExt for Dfir<'_> {
136    subgraph_ext!(impl add_subgraph_sink, (recv_port: R), ());
137    subgraph_ext!(
138        impl add_subgraph_2sink,
139        (recv_port_1: R1, recv_port_2: R2),
140        ()
141    );
142
143    subgraph_ext!(impl add_subgraph_source, (), (send_port: W));
144
145    subgraph_ext!(impl add_subgraph_in_out, (recv_port: R), (send_port: W));
146    subgraph_ext!(
147        impl add_subgraph_in_2out,
148        (recv_port: R),
149        (send_port_1: W1, send_port_2: W2)
150    );
151
152    subgraph_ext!(
153        impl add_subgraph_2in_out,
154        (recv_port_1: R1, recv_port_2: R2),
155        (send_port: W)
156    );
157    subgraph_ext!(
158        impl add_subgraph_2in_2out,
159        (recv_port_1: R1, recv_port_2: R2),
160        (send_port_1: W1, send_port_2: W2)
161    );
162
163    fn add_channel_input<Name, T, W>(
164        &mut self,
165        name: Name,
166        send_port: SendPort<W>,
167    ) -> Input<T, SyncSender<T>>
168    where
169        Name: Into<Cow<'static, str>>,
170        T: 'static,
171        W: 'static + Handoff + CanReceive<T>,
172    {
173        use std::sync::mpsc;
174
175        let (sender, receiver) = mpsc::sync_channel(8000);
176        let sg_id = self.add_subgraph_source::<_, _, W>(name, send_port, move |_ctx, send| {
177            for x in receiver.try_iter() {
178                send.give(x);
179            }
180        });
181        Input::new(self.reactor(), sg_id, sender)
182    }
183
184    fn add_input<Name, T, W>(
185        &mut self,
186        name: Name,
187        send_port: SendPort<W>,
188    ) -> Input<T, super::input::Buffer<T>>
189    where
190        Name: Into<Cow<'static, str>>,
191        T: 'static,
192        W: 'static + Handoff + CanReceive<T>,
193    {
194        let input = super::input::Buffer::default();
195        let inner_input = input.clone();
196        let sg_id = self.add_subgraph_source::<_, _, W>(name, send_port, move |_ctx, send| {
197            for x in (*inner_input.0).borrow_mut().drain(..) {
198                send.give(x);
199            }
200        });
201        Input::new(self.reactor(), sg_id, input)
202    }
203
204    fn add_input_from_stream<Name, T, W, S>(
205        &mut self,
206        name: Name,
207        send_port: SendPort<W>,
208        stream: S,
209    ) where
210        Name: Into<Cow<'static, str>>,
211        S: 'static + Stream<Item = T>,
212        W: 'static + Handoff + CanReceive<T>,
213    {
214        let mut stream = Box::pin(stream);
215        self.add_subgraph_source::<_, _, W>(name, send_port, move |ctx, send| {
216            let waker = ctx.waker();
217            let mut cx = task::Context::from_waker(&waker);
218            while let Poll::Ready(Some(v)) = stream.as_mut().poll_next(&mut cx) {
219                send.give(v);
220            }
221        });
222    }
223}