1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
use proc_macro2::Literal;
use quote::quote_spanned;

use super::{
    OperatorCategory, OperatorConstraints, OperatorWriteOutput, WriteContextArgs,
    RANGE_0, RANGE_1,
};

/// > Arguments: First, the source code for a python module, second, the name of a function
/// > defined within the module source code.
///
/// **Requires the "python" feature to be enabled.**
///
/// An operator which allows you to run a python udf. Each input item must be a tuple whose
/// elements implement [`IntoPy`](https://docs.rs/pyo3/latest/pyo3/conversion/trait.IntoPy.html).
/// The tuple's elements are passed to the python function as positional arguments, so the
/// function's arity must match the tuple's length (wrap a single value as `(x,)`).
/// See the [relevant pyo3 docs here](https://pyo3.rs/latest/conversions/tables#mapping-of-rust-types-to-python-types).
///
/// The module source is compiled once, when the operator's prologue runs. If it fails to compile,
/// or does not define the named attribute, the flow panics at that point.
///
/// Output items are of type `PyResult<Py<PyAny>>`; an exception raised by the python function is
/// returned as an `Err` item rather than panicking. Rust native types can be extracted using
/// `.extract()`, see the [relevant pyo3 docs here](https://pyo3.rs/latest/conversions/traits#extract-and-the-frompyobject-trait)
/// or the examples below.
///
/// ```dfir
/// use pyo3::prelude::*;
///
/// source_iter(0..10)
///     -> map(|x| (x,))
///     -> py_udf("
/// def fib(n):
///     if n < 2:
///         return n
///     else:
///         return fib(n - 2) + fib(n - 1)
/// ", "fib")
///     -> map(|x: PyResult<Py<PyAny>>| Python::with_gil(|py| {
///         usize::extract(x.unwrap().as_ref(py)).unwrap()
///     }))
///     -> assert_eq([0, 1, 1, 2, 3, 5, 8, 13, 21, 34]);
/// ```
///
/// ```dfir
/// use pyo3::prelude::*;
///
/// source_iter([(5,1)])
///     -> py_udf("
/// def add(a, b):
///     return a + b
/// ", "add")
///     -> map(|x: PyResult<Py<PyAny>>| Python::with_gil(|py| {
///         usize::extract(x.unwrap().as_ref(py)).unwrap()
///     }))
///     -> assert_eq([6]);
/// ```
pub const PY_UDF: OperatorConstraints = OperatorConstraints {
    name: "py_udf",
    categories: &[OperatorCategory::Map],
    hard_range_inn: RANGE_1,
    soft_range_inn: RANGE_1,
    hard_range_out: RANGE_1,
    soft_range_out: RANGE_1,
    num_args: 2,
    persistence_args: RANGE_0,
    type_args: RANGE_0,
    is_external_input: false,
    has_singleton_output: false,
    flo_type: None,
    ports_inn: None,
    ports_out: None,
    input_delaytype_fn: |_| None,
    write_fn: |wc @ &WriteContextArgs {
                   root,
                   op_span,
                   context,
                   hydroflow,
                   ident,
                   inputs,
                   outputs,
                   is_pull,
                   op_name,
                   arguments,
                   ..
               },
               _| {
        // Argument 0 is the python module source, argument 1 the name of the function to call.
        let py_src = &arguments[0];
        let py_func_name = &arguments[1];

        // Local identifier that the prologue binds to the state handle of the compiled function.
        let py_func_ident = wc.make_ident("py_func");

        let err_lit = Literal::string(&format!(
            "Hydroflow 'python' feature must be enabled to use `{op_name}`"
        ));

        // Prologue: compile the module once, look up the named attribute, and store the
        // resulting python object in operator state. Without the "python" feature, emit a
        // compile error instead.
        let write_prologue = quote_spanned! {op_span=>
            #root::__python_feature_gate! {
                {
                    let #py_func_ident = {
                        #root::pyo3::prepare_freethreaded_python();
                        let func = #root::pyo3::Python::with_gil::<_, #root::pyo3::PyResult<#root::pyo3::Py<#root::pyo3::PyAny>>>(|py| {
                            Ok(#root::pyo3::types::PyModule::from_code(
                                py,
                                #py_src,
                                "_filename",
                                "_modulename",
                            )?
                            .getattr(#py_func_name)?
                            .into())
                        }).expect("Failed to compile python.");
                        #hydroflow.add_state(func)
                    };
                },
                {
                    ::std::compile_error!(#err_lit);
                }
            }
        };

        // Per-item closure: acquire the GIL and call the stored function with the input
        // tuple's elements as positional arguments, yielding a `PyResult`.
        let closure = quote_spanned! {op_span=>
            |x| {
                #root::__python_feature_gate! {
                    {
                        // TODO(mingwei): maybe this can be outside the closure?
                        let py_func = #context.state_ref(#py_func_ident);
                        #root::pyo3::Python::with_gil(|py| py_func.call1(py, x))
                    },
                    {
                        // Without the "python" feature the prologue above emits
                        // `compile_error!`, so this branch can never be executed.
                        ::std::unreachable!()
                    }
                }
            }
        };

        let write_iterator = if is_pull {
            let input = &inputs[0];
            quote_spanned! {op_span=>
                let #ident = #input.map(#closure);
            }
        } else {
            let output = &outputs[0];
            quote_spanned! {op_span=>
                let #ident = #root::pusherator::map::Map::new(#closure, #output);
            }
        };

        Ok(OperatorWriteOutput {
            write_prologue,
            write_iterator,
            ..Default::default()
        })
    },
};