1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147
use proc_macro2::Literal;
use quote::quote_spanned;
use super::{
OperatorCategory, OperatorConstraints, OperatorWriteOutput, WriteContextArgs,
RANGE_0, RANGE_1,
};
/// > Arguments: First, the source code for a python module, second, the name of a unary function
/// > defined within the module source code.
///
/// **Requires the "python" feature to be enabled.**
///
/// An operator which allows you to run a python udf. Input arguments must be a stream of tuples
/// whose items implement [`IntoPy`](https://docs.rs/pyo3/latest/pyo3/conversion/trait.IntoPy.html).
/// See the [relevant pyo3 docs here](https://pyo3.rs/latest/conversions/tables#mapping-of-rust-types-to-python-types).
///
/// Output items are of type `PyResult<Py<PyAny>>`. Rust native types can be extracted using
/// `.extract()`, see the [relevant pyo3 docs here](https://pyo3.rs/latest/conversions/traits#extract-and-the-frompyobject-trait)
/// or the examples below.
///
/// The python module source is compiled once when the operator is set up (not once per item),
/// and the named function is looked up at the same time. If either step fails, the operator
/// panics with `Failed to compile python.`. Each input tuple is then passed as the positional
/// arguments of a single call to that function, and errors raised by the python function are
/// returned as the `Err` variant of the output item rather than panicking.
///
/// ```dfir
/// use pyo3::prelude::*;
///
/// source_iter(0..10)
///     -> map(|x| (x,))
///     -> py_udf("
/// def fib(n):
///     if n < 2:
///         return n
///     else:
///         return fib(n - 2) + fib(n - 1)
/// ", "fib")
///     -> map(|x: PyResult<Py<PyAny>>| Python::with_gil(|py| {
///         usize::extract(x.unwrap().as_ref(py)).unwrap()
///     }))
///     -> assert_eq([0, 1, 1, 2, 3, 5, 8, 13, 21, 34]);
/// ```
///
/// ```dfir
/// use pyo3::prelude::*;
///
/// source_iter([(5,1)])
///     -> py_udf("
/// def add(a, b):
///     return a + b
/// ", "add")
///     -> map(|x: PyResult<Py<PyAny>>| Python::with_gil(|py| {
///         usize::extract(x.unwrap().as_ref(py)).unwrap()
///     }))
///     -> assert_eq([6]);
/// ```
pub const PY_UDF: OperatorConstraints = OperatorConstraints {
    name: "py_udf",
    categories: &[OperatorCategory::Map],
    hard_range_inn: RANGE_1,
    soft_range_inn: RANGE_1,
    hard_range_out: RANGE_1,
    soft_range_out: RANGE_1,
    num_args: 2,
    persistence_args: RANGE_0,
    type_args: RANGE_0,
    is_external_input: false,
    has_singleton_output: false,
    flo_type: None,
    ports_inn: None,
    ports_out: None,
    input_delaytype_fn: |_| None,
    write_fn: |wc @ &WriteContextArgs {
                   root,
                   op_span,
                   context,
                   hydroflow,
                   ident,
                   inputs,
                   outputs,
                   is_pull,
                   op_name,
                   arguments,
                   ..
               },
               _| {
        // Indexing assumes the framework enforced `num_args: 2` before calling `write_fn`.
        // arguments[0]: python module source; arguments[1]: name of the function to call.
        let py_src = &arguments[0];
        let py_func_name = &arguments[1];

        // Identifier bound to the state handle that stores the loaded python function.
        let py_func_ident = wc.make_ident("py_func");

        // Message reported via `compile_error!` when the `python` feature is disabled.
        let err_lit = Literal::string(&format!(
            "Hydroflow 'python' feature must be enabled to use `{}`",
            op_name
        ));

        // Prologue: initialize the interpreter, compile the module source once, look up the
        // named function, and store it in the graph state so the per-item closure can reuse it.
        let write_prologue = quote_spanned! {op_span=>
            #root::__python_feature_gate! {
                {
                    let #py_func_ident = {
                        #root::pyo3::prepare_freethreaded_python();
                        let func = #root::pyo3::Python::with_gil::<_, #root::pyo3::PyResult<#root::pyo3::Py<#root::pyo3::PyAny>>>(|py| {
                            Ok(#root::pyo3::types::PyModule::from_code(
                                py,
                                #py_src,
                                "_filename",
                                "_modulename",
                            )?
                            .getattr(#py_func_name)?
                            .into())
                        }).expect("Failed to compile python.");
                        #hydroflow.add_state(func)
                    };
                },
                {
                    ::std::compile_error!(#err_lit);
                }
            }
        };

        // Per-item closure: call the stored function with the input tuple as positional args,
        // yielding the `PyResult` unchanged so python exceptions reach downstream operators.
        let closure = quote_spanned! {op_span=>
            |x| {
                #root::__python_feature_gate! {
                    {
                        // TODO(mingwei): maybe this can be outside the closure?
                        let py_func = #context.state_ref(#py_func_ident);
                        #root::pyo3::Python::with_gil(|py| py_func.call1(py, x))
                    },
                    {
                        // Never executed: with the feature disabled the prologue already emits
                        // `compile_error!`; this arm only keeps the closure well-typed. The path
                        // is fully qualified so a user-defined `unreachable!`/`panic!` macro in
                        // scope at `op_span` cannot shadow it.
                        ::std::unreachable!()
                    }
                }
            }
        };

        // Pull side wraps the upstream iterator; push side wraps the downstream pusherator.
        let write_iterator = if is_pull {
            let input = &inputs[0];
            quote_spanned! {op_span=>
                let #ident = #input.map(#closure);
            }
        } else {
            let output = &outputs[0];
            quote_spanned! {op_span=>
                let #ident = #root::pusherator::map::Map::new(#closure, #output);
            }
        };

        Ok(OperatorWriteOutput {
            write_prologue,
            write_iterator,
            ..Default::default()
        })
    },
};