dfir_lang/graph/ops/
resolve_futures.rs

1use quote::quote_spanned;
2use syn::Ident;
3
4use super::{
5    OperatorCategory, OperatorConstraints, OperatorWriteOutput, RANGE_0, RANGE_1, WriteContextArgs,
6};
7
8/// Given an incoming stream of `F: Future`, sends those futures to the executor being used
9/// by the DFIR runtime and emits elements whenever a future is completed. The output order
10/// is based on when futures complete, and may be different than the input order.
11pub const RESOLVE_FUTURES: OperatorConstraints = OperatorConstraints {
12    name: "resolve_futures",
13    categories: &[OperatorCategory::Map],
14    hard_range_inn: RANGE_1,
15    soft_range_inn: RANGE_1,
16    hard_range_out: RANGE_1,
17    soft_range_out: RANGE_1,
18    num_args: 0,
19    persistence_args: RANGE_0,
20    type_args: RANGE_0,
21    is_external_input: false,
22    has_singleton_output: false,
23    flo_type: None,
24    ports_inn: None,
25    ports_out: None,
26    input_delaytype_fn: |_| None,
27    write_fn: move |wc, _| {
28        resolve_futures_writer(
29            Ident::new("FuturesUnordered", wc.op_span),
30            Ident::new("push", wc.op_span),
31            wc,
32        )
33    },
34};
35
36pub fn resolve_futures_writer(
37    future_type: Ident,
38    push_fn: Ident,
39    wc @ &WriteContextArgs {
40        root,
41        context,
42        op_span,
43        ident,
44        inputs,
45        outputs,
46        is_pull,
47        work_fn,
48        ..
49    }: &WriteContextArgs,
50) -> Result<OperatorWriteOutput, ()> {
51    let futures_ident = wc.make_ident("futures");
52
53    let write_prologue = quote_spanned! {op_span=>
54        let #futures_ident = df.add_state(
55            ::std::cell::RefCell::new(
56                #root::futures::stream::#future_type::new()
57            )
58        );
59    };
60
61    let write_iterator = if is_pull {
62        let input = &inputs[0];
63        quote_spanned! {op_span=>
64            let #ident = {
65                let mut out = ::std::vec::Vec::new();
66
67                let mut state = unsafe {
68                    // SAFETY: handle from `#df_ident.add_state(..)`.
69                    #context.state_ref_unchecked(#futures_ident)
70                        .borrow_mut()
71                };
72
73                #work_fn(|| {
74                    #input
75                        .for_each(|fut| {
76                            let mut fut = ::std::boxed::Box::pin(fut);
77                            if let #root::futures::task::Poll::Ready(val) = #root::futures::Future::poll(::std::pin::Pin::as_mut(&mut fut), &mut ::std::task::Context::from_waker(&#context.waker())) {
78                                out.push(val);
79                            } else {
80                                state.#push_fn(fut);
81                            }
82                        });
83
84                    while let #root::futures::task::Poll::Ready(Some(val)) =
85                        #root::futures::Stream::poll_next(::std::pin::Pin::new(&mut *state), &mut ::std::task::Context::from_waker(&#context.waker()))
86                    {
87                        out.push(val);
88                    }
89                });
90
91                ::std::iter::IntoIterator::into_iter(out)
92            };
93        }
94    } else {
95        let output = &outputs[0];
96        quote_spanned! {op_span=>
97            let #ident = {
98                let mut out = #output;
99                let mut state = unsafe {
100                    // SAFETY: handle from `#df_ident.add_state(..)`.
101                    #context.state_ref_unchecked(#futures_ident).borrow_mut()
102                };
103
104                #work_fn(|| {
105                    while let #root::futures::task::Poll::Ready(Some(val)) =
106                        #root::futures::Stream::poll_next(::std::pin::Pin::new(&mut *state), &mut ::std::task::Context::from_waker(&#context.waker()))
107                    {
108                        #root::pusherator::Pusherator::give(&mut out, val)
109                    }
110                });
111
112                let consumer = #root::pusherator::for_each::ForEach::new(|fut| {
113                    #work_fn(|| {
114                        let fut = ::std::boxed::Box::pin(fut);
115                        unsafe {
116                            // SAFETY: handle from `#df_ident.add_state(..)`.
117                            #context.state_ref_unchecked(#futures_ident).borrow_mut()
118                        }.#push_fn(fut);
119                    });
120                    #context.schedule_subgraph(#context.current_subgraph(), true);
121                });
122
123                consumer
124            };
125        }
126    };
127
128    Ok(OperatorWriteOutput {
129        write_prologue,
130        write_iterator,
131        ..Default::default()
132    })
133}