// dfir_lang/graph/ops/py_udf.rs
1use proc_macro2::Literal;
2use quote::quote_spanned;
3
4use super::{
5 OperatorCategory, OperatorConstraints, OperatorWriteOutput, WriteContextArgs,
6 RANGE_0, RANGE_1,
7};
8
9/// > Arguments: First, the source code for a python module, second, the name of a unary function
10/// > defined within the module source code.
11///
12/// **Requires the "python" feature to be enabled.**
13///
14/// An operator which allows you to run a python udf. Input arguments must be a stream of tuples
15/// whose items implement [`IntoPy`](https://docs.rs/pyo3/latest/pyo3/conversion/trait.IntoPy.html).
16/// See the [relevant pyo3 docs here](https://pyo3.rs/latest/conversions/tables#mapping-of-rust-types-to-python-types).
17///
18/// Output items are of type `PyResult<Py<PyAny>>`. Rust native types can be extracted using
19/// `.extract()`, see the [relevant pyo3 docs here](https://pyo3.rs/latest/conversions/traits#extract-and-the-frompyobject-trait)
20/// or the examples below.
21///
22/// ```dfir
23/// use pyo3::prelude::*;
24///
25/// source_iter(0..10)
26/// -> map(|x| (x,))
27/// -> py_udf("
28/// def fib(n):
29/// if n < 2:
30/// return n
31/// else:
32/// return fib(n - 2) + fib(n - 1)
33/// ", "fib")
34/// -> map(|x: PyResult<Py<PyAny>>| Python::with_gil(|py| {
35/// usize::extract(x.unwrap().as_ref(py)).unwrap()
36/// }))
37/// -> assert_eq([0, 1, 1, 2, 3, 5, 8, 13, 21, 34]);
38/// ```
39///
40/// ```dfir
41/// use pyo3::prelude::*;
42///
43/// source_iter([(5,1)])
44/// -> py_udf("
45/// def add(a, b):
46/// return a + b
47/// ", "add")
48/// -> map(|x: PyResult<Py<PyAny>>| Python::with_gil(|py| {
49/// usize::extract(x.unwrap().as_ref(py)).unwrap()
50/// }))
51/// -> assert_eq([6]);
52/// ```
53pub const PY_UDF: OperatorConstraints = OperatorConstraints {
54 name: "py_udf",
55 categories: &[OperatorCategory::Map],
56 hard_range_inn: RANGE_1,
57 soft_range_inn: RANGE_1,
58 hard_range_out: RANGE_1,
59 soft_range_out: RANGE_1,
60 num_args: 2,
61 persistence_args: RANGE_0,
62 type_args: RANGE_0,
63 is_external_input: false,
64 has_singleton_output: false,
65 flo_type: None,
66 ports_inn: None,
67 ports_out: None,
68 input_delaytype_fn: |_| None,
69 write_fn: |wc @ &WriteContextArgs {
70 root,
71 op_span,
72 context,
73 df_ident,
74 ident,
75 inputs,
76 outputs,
77 is_pull,
78 op_name,
79 arguments,
80 ..
81 },
82 _| {
83 let py_src = &arguments[0];
84 let py_func_name = &arguments[1];
85
86 let py_func_ident = wc.make_ident("py_func");
87
88 let err_lit = Literal::string(&format!(
89 "`python` feature must be enabled to use `{}`",
90 op_name
91 ));
92
93 let write_prologue = quote_spanned! {op_span=>
94 #root::__python_feature_gate! {
95 {
96 let #py_func_ident = {
97 #root::pyo3::prepare_freethreaded_python();
98 let func = #root::pyo3::Python::with_gil::<_, #root::pyo3::PyResult<#root::pyo3::Py<#root::pyo3::PyAny>>>(|py| {
99 Ok(#root::pyo3::types::PyModule::from_code(
100 py,
101 #py_src,
102 "_filename",
103 "_modulename",
104 )?
105 .getattr(#py_func_name)?
106 .into())
107 }).expect("Failed to compile python.");
108 #df_ident.add_state(func)
109 };
110 },
111 {
112 ::std::compile_error!(#err_lit);
113 }
114 }
115 };
116 let closure = quote_spanned! {op_span=>
117 |x| {
118 #root::__python_feature_gate! {
119 {
120 // TODO(mingwei): maybe this can be outside the closure?
121 let py_func = unsafe {
122 // SAFETY: handle from `#df_ident.add_state(..)`.
123 #context.state_ref_unchecked(#py_func_ident)
124 };
125 #root::pyo3::Python::with_gil(|py| py_func.call1(py, x))
126 },
127 {
128 panic!()
129 }
130 }
131 }
132 };
133 let write_iterator = if is_pull {
134 let input = &inputs[0];
135 quote_spanned! {op_span=>
136 let #ident = #input.map(#closure);
137 }
138 } else {
139 let output = &outputs[0];
140 quote_spanned! {op_span=>
141 let #ident = #root::pusherator::map::Map::new(#closure, #output);
142 }
143 };
144 Ok(OperatorWriteOutput {
145 write_prologue,
146 write_iterator,
147 ..Default::default()
148 })
149 },
150};