dfir_lang/graph/ops/
py_udf.rs

1use proc_macro2::Literal;
2use quote::quote_spanned;
3
4use super::{
5    OperatorCategory, OperatorConstraints, OperatorWriteOutput, WriteContextArgs,
6    RANGE_0, RANGE_1,
7};
8
9/// > Arguments: First, the source code for a python module, second, the name of a unary function
10/// > defined within the module source code.
11///
12/// **Requires the "python" feature to be enabled.**
13///
14/// An operator which allows you to run a python udf. Input arguments must be a stream of tuples
15/// whose items implement [`IntoPy`](https://docs.rs/pyo3/latest/pyo3/conversion/trait.IntoPy.html).
16/// See the [relevant pyo3 docs here](https://pyo3.rs/latest/conversions/tables#mapping-of-rust-types-to-python-types).
17///
18/// Output items are of type `PyResult<Py<PyAny>>`. Rust native types can be extracted using
19/// `.extract()`, see the [relevant pyo3 docs here](https://pyo3.rs/latest/conversions/traits#extract-and-the-frompyobject-trait)
20/// or the examples below.
21///
22/// ```dfir
23/// use pyo3::prelude::*;
24///
25/// source_iter(0..10)
26///     -> map(|x| (x,))
27///     -> py_udf("
28/// def fib(n):
29///     if n < 2:
30///         return n
31///     else:
32///         return fib(n - 2) + fib(n - 1)
33/// ", "fib")
34///     -> map(|x: PyResult<Py<PyAny>>| Python::with_gil(|py| {
35///         usize::extract(x.unwrap().as_ref(py)).unwrap()
36///     }))
37///     -> assert_eq([0, 1, 1, 2, 3, 5, 8, 13, 21, 34]);
38/// ```
39///
40/// ```dfir
41/// use pyo3::prelude::*;
42///
43/// source_iter([(5,1)])
44///     -> py_udf("
45/// def add(a, b):
46///     return a + b
47/// ", "add")
48///     -> map(|x: PyResult<Py<PyAny>>| Python::with_gil(|py| {
49///         usize::extract(x.unwrap().as_ref(py)).unwrap()
50///     }))
51///     -> assert_eq([6]);
52/// ```
53pub const PY_UDF: OperatorConstraints = OperatorConstraints {
54    name: "py_udf",
55    categories: &[OperatorCategory::Map],
56    hard_range_inn: RANGE_1,
57    soft_range_inn: RANGE_1,
58    hard_range_out: RANGE_1,
59    soft_range_out: RANGE_1,
60    num_args: 2,
61    persistence_args: RANGE_0,
62    type_args: RANGE_0,
63    is_external_input: false,
64    has_singleton_output: false,
65    flo_type: None,
66    ports_inn: None,
67    ports_out: None,
68    input_delaytype_fn: |_| None,
69    write_fn: |wc @ &WriteContextArgs {
70                   root,
71                   op_span,
72                   context,
73                   df_ident,
74                   ident,
75                   inputs,
76                   outputs,
77                   is_pull,
78                   op_name,
79                   arguments,
80                   ..
81               },
82               _| {
83        let py_src = &arguments[0];
84        let py_func_name = &arguments[1];
85
86        let py_func_ident = wc.make_ident("py_func");
87
88        let err_lit = Literal::string(&format!(
89            "`python` feature must be enabled to use `{}`",
90            op_name
91        ));
92
93        let write_prologue = quote_spanned! {op_span=>
94            #root::__python_feature_gate! {
95                {
96                    let #py_func_ident = {
97                        #root::pyo3::prepare_freethreaded_python();
98                        let func = #root::pyo3::Python::with_gil::<_, #root::pyo3::PyResult<#root::pyo3::Py<#root::pyo3::PyAny>>>(|py| {
99                            Ok(#root::pyo3::types::PyModule::from_code(
100                                py,
101                                #py_src,
102                                "_filename",
103                                "_modulename",
104                            )?
105                            .getattr(#py_func_name)?
106                            .into())
107                        }).expect("Failed to compile python.");
108                        #df_ident.add_state(func)
109                    };
110                },
111                {
112                    ::std::compile_error!(#err_lit);
113                }
114            }
115        };
116        let closure = quote_spanned! {op_span=>
117            |x| {
118                #root::__python_feature_gate! {
119                    {
120                        // TODO(mingwei): maybe this can be outside the closure?
121                        let py_func = unsafe {
122                            // SAFETY: handle from `#df_ident.add_state(..)`.
123                            #context.state_ref_unchecked(#py_func_ident)
124                        };
125                        #root::pyo3::Python::with_gil(|py| py_func.call1(py, x))
126                    },
127                    {
128                        panic!()
129                    }
130                }
131            }
132        };
133        let write_iterator = if is_pull {
134            let input = &inputs[0];
135            quote_spanned! {op_span=>
136                let #ident = #input.map(#closure);
137            }
138        } else {
139            let output = &outputs[0];
140            quote_spanned! {op_span=>
141                let #ident = #root::pusherator::map::Map::new(#closure, #output);
142            }
143        };
144        Ok(OperatorWriteOutput {
145            write_prologue,
146            write_iterator,
147            ..Default::default()
148        })
149    },
150};