Copilot commented on code in PR #1299:
URL:
https://github.com/apache/datafusion-python/pull/1299#discussion_r2705860116
##########
python/datafusion/user_defined.py:
##########
@@ -212,23 +237,25 @@ def _function(
name = func.__qualname__.lower()
else:
name = func.__class__.__name__.lower()
+ input_fields = data_types_or_fields_to_field_list(input_fields)
+ return_field = data_type_or_field_to_field(return_field, "value")
return ScalarUDF(
name=name,
func=func,
- input_types=input_types,
- return_type=return_type,
+ input_fields=input_fields,
+ return_field=return_field,
volatility=volatility,
)
def _decorator(
- input_types: list[pa.DataType],
- return_type: _R,
+ input_fields: list[pa.DataType],
Review Comment:
The type hint for the `input_fields` parameter in `_decorator` should be
`Sequence[pa.DataType | pa.Field] | pa.DataType | pa.Field` to match the
signature of `_function` and stay consistent with the types it actually
accepts.
```suggestion
input_fields: Sequence[pa.DataType | pa.Field] | pa.DataType |
pa.Field,
```
##########
src/udf.rs:
##########
@@ -15,67 +15,140 @@
// specific language governing permissions and limitations
// under the License.
+use std::any::Any;
+use std::hash::{Hash, Hasher};
use std::sync::Arc;
-use datafusion::arrow::array::{make_array, Array, ArrayData, ArrayRef};
+use arrow::datatypes::{Field, FieldRef};
+use arrow::pyarrow::ToPyArrow;
+use datafusion::arrow::array::{make_array, ArrayData};
use datafusion::arrow::datatypes::DataType;
-use datafusion::arrow::pyarrow::{FromPyArrow, PyArrowType, ToPyArrow};
+use datafusion::arrow::pyarrow::{FromPyArrow, PyArrowType};
use datafusion::error::DataFusionError;
-use datafusion::logical_expr::function::ScalarFunctionImplementation;
-use datafusion::logical_expr::{create_udf, ColumnarValue, ScalarUDF};
+use datafusion::logical_expr::{
+ ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF,
ScalarUDFImpl, Signature,
+ Volatility,
+};
use datafusion_ffi::udf::{FFI_ScalarUDF, ForeignScalarUDF};
use pyo3::prelude::*;
use pyo3::types::{PyCapsule, PyTuple};
+use crate::array::PyArrowArrayExportable;
use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionResult};
use crate::expr::PyExpr;
use crate::utils::{parse_volatility, validate_pycapsule};
-/// Create a Rust callable function from a python function that expects
pyarrow arrays
-fn pyarrow_function_to_rust(
+/// This struct holds the Python written function that is a
+/// ScalarUDF.
+#[derive(Debug)]
+struct PythonFunctionScalarUDF {
+ name: String,
func: Py<PyAny>,
-) -> impl Fn(&[ArrayRef]) -> Result<ArrayRef, DataFusionError> {
- move |args: &[ArrayRef]| -> Result<ArrayRef, DataFusionError> {
+ signature: Signature,
+ return_field: FieldRef,
+}
+
+impl PythonFunctionScalarUDF {
+ fn new(
+ name: String,
+ func: Py<PyAny>,
+ input_fields: Vec<Field>,
+ return_field: Field,
+ volatility: Volatility,
+ ) -> Self {
+ let input_types = input_fields.iter().map(|f|
f.data_type().clone()).collect();
+ let signature = Signature::exact(input_types, volatility);
+ Self {
+ name,
+ func,
+ signature,
+ return_field: Arc::new(return_field),
+ }
+ }
+}
+
+impl Eq for PythonFunctionScalarUDF {}
+impl PartialEq for PythonFunctionScalarUDF {
+ fn eq(&self, other: &Self) -> bool {
+ self.name == other.name
+ && self.signature == other.signature
+ && self.return_field == other.return_field
+ && Python::attach(|py|
self.func.bind(py).eq(other.func.bind(py)).unwrap_or(false))
+ }
+}
+
+impl Hash for PythonFunctionScalarUDF {
+ fn hash<H: Hasher>(&self, state: &mut H) {
+ self.name.hash(state);
+ self.signature.hash(state);
+ self.return_field.hash(state);
+
+ Python::attach(|py| {
+ let py_hash = self.func.bind(py).hash().unwrap_or(0); // Handle
unhashable objects
+
+ state.write_isize(py_hash);
+ });
+ }
+}
+
+impl ScalarUDFImpl for PythonFunctionScalarUDF {
+ fn as_any(&self) -> &dyn Any {
+ self
+ }
+
+ fn name(&self) -> &str {
+ &self.name
+ }
+
+ fn signature(&self) -> &Signature {
+ &self.signature
+ }
+
+ fn return_type(&self, _arg_types: &[DataType]) ->
datafusion::common::Result<DataType> {
+ unimplemented!()
Review Comment:
The `return_type` method is left as `unimplemented!()`, which panics if
DataFusion internals ever call it. Consider either implementing it by returning
`self.return_field.data_type().clone()` or adding a comment explaining why it's
safe to leave unimplemented.
```suggestion
Ok(self.return_field.data_type().clone())
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]