This is an automated email from the ASF dual-hosted git repository.

icexelloss pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/main by this push:
     new b1e85a6d0c GH-36672: [Python][C++] Add support for vector function UDF (#36673)
b1e85a6d0c is described below

commit b1e85a6d0cd7a57f93b97d74bb13e89517e3d92e
Author: Li Jin <[email protected]>
AuthorDate: Tue Aug 8 18:25:06 2023 -0400

    GH-36672: [Python][C++] Add support for vector function UDF (#36673)
    
    
    
    ### Rationale for this change
    In Arrow compute, there are four main types of functions: Scalar, Vector, 
ScalarAggregate and HashAggregate.
    
    Previous work added support for Scalar, ScalarAggregate
    (https://github.com/apache/arrow/issues/35515) and HashAggregate
    (https://github.com/apache/arrow/issues/36252) UDFs. I think it makes
    sense to add support for vector functions as well, completing support
    for all non-decomposable UDF kernels.
    
    Internally, we plan to extend Acero with a "SegmentVectorNode" that
    would use this API to invoke a vector UDF on a segment-by-segment
    basis. This will make it possible to compute things like "rank the
    values across all rows per segment using a Python UDF" in constant
    memory.
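    
    As a quick illustration, here is a minimal sketch of the new API
    (mirroring the docstring example added in this PR; the flattening
    function is just an example):
    
    ```python
    import pyarrow as pa
    import pyarrow.compute as pc
    
    def list_flatten_udf(ctx, x):
        # A vector UDF receives a UdfContext followed by the input arrays.
        return pc.list_flatten(x)
    
    pc.register_vector_function(
        list_flatten_udf, "list_flatten_udf",
        {"summary": "flatten lists", "description": "flatten list of lists"},
        {"array": pa.list_(pa.int64())}, pa.int64())
    
    pc.call_function("list_flatten_udf", [pa.array([[1, 2], [3, 4]])])
    # -> [1, 2, 3, 4]
    ```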
    
    ### What changes are included in this PR?
    The change is very similar to the support for aggregate functions: it
    includes code to register the vector UDF, and a kernel that invokes
    the vector UDF on the given inputs.
    
    ### Are these changes tested?
    Yes. Added new tests.
    
    ### Are there any user-facing changes?
    Yes. This adds a user-facing API to register vector functions.
    
    * Closes: #36672
    
    Authored-by: Li Jin <[email protected]>
    Signed-off-by: Li Jin <[email protected]>
---
 python/pyarrow/_compute.pyx            | 84 +++++++++++++++++++++++++++++++++-
 python/pyarrow/compute.py              |  1 +
 python/pyarrow/includes/libarrow.pxd   |  4 ++
 python/pyarrow/src/arrow/python/udf.cc | 41 ++++++++++-------
 python/pyarrow/src/arrow/python/udf.h  |  5 ++
 python/pyarrow/tests/test_udf.py       | 70 ++++++++++++++++++++++++++++
 6 files changed, 188 insertions(+), 17 deletions(-)

diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx
index ac7efeff41..bc3b9e8c55 100644
--- a/python/pyarrow/_compute.pyx
+++ b/python/pyarrow/_compute.pyx
@@ -1964,7 +1964,7 @@ class CumulativeOptions(_CumulativeOptions):
     Parameters
     ----------
     start : Scalar, default None
-        Starting value for the cumulative operation. If none is given, 
+        Starting value for the cumulative operation. If none is given,
         a default value depending on the operation and input type is used.
     skip_nulls : bool, default False
         When false, the first encountered null is propagated.
@@ -2707,6 +2707,11 @@ cdef get_register_aggregate_function():
     reg.register_func = RegisterAggregateFunction
     return reg
 
+cdef get_register_vector_function():
+    cdef RegisterUdf reg = RegisterUdf.__new__(RegisterUdf)
+    reg.register_func = RegisterVectorFunction
+    return reg
+
 
 def register_scalar_function(func, function_name, function_doc, in_types, out_type,
                              func_registry=None):
@@ -2789,6 +2794,83 @@ def register_scalar_function(func, function_name, function_doc, in_types, out_ty
                                            out_type, func_registry)
 
 
+def register_vector_function(func, function_name, function_doc, in_types, out_type,
+                             func_registry=None):
+    """
+    Register a user-defined vector function.
+
+    This API is EXPERIMENTAL.
+
+    A vector function is a function that executes vector
+    operations on arrays. Vector functions are often used
+    when the computation doesn't fit other, more specific
+    types of functions (e.g., scalar and aggregate).
+
+    Parameters
+    ----------
+    func : callable
+        A callable implementing the user-defined function.
+        The first argument is the context argument of type
+        UdfContext.
+        Then, it must take arguments equal to the number of
+        in_types defined. It must return an Array or Scalar
+        matching the out_type. It must return a Scalar if
+        all arguments are scalar, else it must return an Array.
+
+        To define a varargs function, pass a callable that takes
+        *args. The last in_type will be the type of all varargs
+        arguments.
+    function_name : str
+        Name of the function. There should only be one function
+        registered with this name in the function registry.
+    function_doc : dict
+        A dictionary object with keys "summary" (str),
+        and "description" (str).
+    in_types : Dict[str, DataType]
+        A dictionary mapping function argument names to
+        their respective DataType.
+        The argument names will be used to generate
+        documentation for the function. The number of
+        arguments specified here determines the function
+        arity.
+    out_type : DataType
+        Output type of the function.
+    func_registry : FunctionRegistry
+        Optional function registry to use instead of the default global one.
+
+    Examples
+    --------
+    >>> import pyarrow as pa
+    >>> import pyarrow.compute as pc
+    >>>
+    >>> func_doc = {}
+    >>> func_doc["summary"] = "flatten lists"
+    >>> func_doc["description"] = "flatten list of lists"
+    >>>
+    >>> def list_flatten_udf(ctx, x):
+    ...     return pc.list_flatten(x)
+    >>>
+    >>> func_name = "list_flatten_udf"
+    >>> in_types = {"array": pa.list_(pa.int64())}
+    >>> out_type = pa.int64()
+    >>> pc.register_vector_function(list_flatten_udf, func_name, func_doc,
+    ...                   in_types, out_type)
+    >>>
+    >>> answer = pc.call_function(func_name, [pa.array([[1, 2], [3, 4]])])
+    >>> answer
+    <pyarrow.lib.Int64Array object at ...>
+    [
+      1,
+      2,
+      3,
+      4
+    ]
+    """
+    return _register_user_defined_function(get_register_vector_function(),
+                                           func, function_name, function_doc, in_types,
+                                           out_type, func_registry)
+
+
 def register_aggregate_function(func, function_name, function_doc, in_types, out_type,
                                 func_registry=None):
     """
diff --git a/python/pyarrow/compute.py b/python/pyarrow/compute.py
index 0fefa18dd1..7b8983cbb9 100644
--- a/python/pyarrow/compute.py
+++ b/python/pyarrow/compute.py
@@ -87,6 +87,7 @@ from pyarrow._compute import (  # noqa
     register_scalar_function,
     register_tabular_function,
     register_aggregate_function,
+    register_vector_function,
     UdfContext,
     # Expressions
     Expression,
diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd
index da46cdcb75..f4d6541fa7 100644
--- a/python/pyarrow/includes/libarrow.pxd
+++ b/python/pyarrow/includes/libarrow.pxd
@@ -2815,5 +2815,9 @@ cdef extern from "arrow/python/udf.h" namespace "arrow::py" nogil:
                                      function[CallbackUdf] wrapper, const CUdfOptions& options,
                                       CFunctionRegistry* registry)
 
+    CStatus RegisterVectorFunction(PyObject* function,
+                                   function[CallbackUdf] wrapper, const CUdfOptions& options,
+                                   CFunctionRegistry* registry)
+
     CResult[shared_ptr[CRecordBatchReader]] CallTabularFunction(
        const c_string& func_name, const vector[CDatum]& args, CFunctionRegistry* registry)
diff --git a/python/pyarrow/src/arrow/python/udf.cc b/python/pyarrow/src/arrow/python/udf.cc
index 435c89f596..f7761a9277 100644
--- a/python/pyarrow/src/arrow/python/udf.cc
+++ b/python/pyarrow/src/arrow/python/udf.cc
@@ -292,14 +292,14 @@ struct PythonUdfHashAggregatorImpl : public HashUdfAggregator {
     return out;
   }
 
-  Status Resize(KernelContext* ctx, int64_t new_num_groups) {
+  Status Resize(KernelContext* ctx, int64_t new_num_groups) override {
     // We only need to change num_groups in resize
     // similar to other hash aggregate kernels
     num_groups = new_num_groups;
     return Status::OK();
   }
 
-  Status Consume(KernelContext* ctx, const ExecSpan& batch) {
+  Status Consume(KernelContext* ctx, const ExecSpan& batch) override {
     ARROW_ASSIGN_OR_RAISE(
         std::shared_ptr<RecordBatch> rb,
         batch.ToExecBatch().ToRecordBatch(input_schema, ctx->memory_pool()));
@@ -316,7 +316,7 @@ struct PythonUdfHashAggregatorImpl : public HashUdfAggregator {
     return Status::OK();
   }
   Status Merge(KernelContext* ctx, KernelState&& other_state,
-               const ArrayData& group_id_mapping) {
+               const ArrayData& group_id_mapping) override {
     // This is similar to GroupedListImpl
     auto& other = checked_cast<PythonUdfHashAggregatorImpl&>(other_state);
     auto& other_values = other.values;
@@ -336,7 +336,7 @@ struct PythonUdfHashAggregatorImpl : public HashUdfAggregator {
     return Status::OK();
   }
 
-  Status Finalize(KernelContext* ctx, Datum* out) {
+  Status Finalize(KernelContext* ctx, Datum* out) override {
     // Exclude the last column which is the group id
     const int num_args = input_schema->num_fields() - 1;
 
@@ -484,24 +484,25 @@ Status PythonUdfExec(compute::KernelContext* ctx, const compute::ExecSpan& batch
   return SafeCallIntoPython([&]() -> Status { return udf->Exec(ctx, batch, out); });
 }
 
-Status RegisterUdf(PyObject* user_function, compute::KernelInit kernel_init,
-                   UdfWrapperCallback wrapper, const UdfOptions& options,
+template <class Function, class Kernel>
+Status RegisterUdf(PyObject* function, compute::KernelInit kernel_init,
+                   UdfWrapperCallback cb, const UdfOptions& options,
                    compute::FunctionRegistry* registry) {
-  if (!PyCallable_Check(user_function)) {
+  if (!PyCallable_Check(function)) {
     return Status::TypeError("Expected a callable Python object.");
   }
-  auto scalar_func = std::make_shared<compute::ScalarFunction>(
-      options.func_name, options.arity, options.func_doc);
-  Py_INCREF(user_function);
+  auto scalar_func =
+      std::make_shared<Function>(options.func_name, options.arity, options.func_doc);
+  Py_INCREF(function);
   std::vector<compute::InputType> input_types;
   for (const auto& in_dtype : options.input_types) {
     input_types.emplace_back(in_dtype);
   }
   compute::OutputType output_type(options.output_type);
   auto udf_data = std::make_shared<PythonUdf>(
-      std::make_shared<OwnedRefNoGIL>(user_function), wrapper,
+      std::make_shared<OwnedRefNoGIL>(function), cb,
       TypeHolder::FromTypes(options.input_types), options.output_type);
-  compute::ScalarKernel kernel(
+  Kernel kernel(
       compute::KernelSignature::Make(std::move(input_types), std::move(output_type),
                                      options.arity.is_varargs),
       PythonUdfExec, kernel_init);
@@ -522,9 +523,17 @@ Status RegisterUdf(PyObject* user_function, compute::KernelInit kernel_init,
 Status RegisterScalarFunction(PyObject* function, UdfWrapperCallback cb,
                               const UdfOptions& options,
                               compute::FunctionRegistry* registry) {
-  return RegisterUdf(function,
-                     PythonUdfKernelInit{std::make_shared<OwnedRefNoGIL>(function)}, cb,
-                     options, registry);
+  return RegisterUdf<compute::ScalarFunction, compute::ScalarKernel>(
+      function, PythonUdfKernelInit{std::make_shared<OwnedRefNoGIL>(function)}, cb,
+      options, registry);
+}
+
+Status RegisterVectorFunction(PyObject* function, UdfWrapperCallback cb,
+                              const UdfOptions& options,
+                              compute::FunctionRegistry* registry) {
+  return RegisterUdf<compute::VectorFunction, compute::VectorKernel>(
+      function, PythonUdfKernelInit{std::make_shared<OwnedRefNoGIL>(function)}, cb,
+      options, registry);
 }
 
 Status RegisterTabularFunction(PyObject* function, UdfWrapperCallback cb,
@@ -536,7 +545,7 @@ Status RegisterTabularFunction(PyObject* function, UdfWrapperCallback cb,
   if (options.output_type->id() != Type::type::STRUCT) {
     return Status::Invalid("tabular function with non-struct output");
   }
-  return RegisterUdf(
+  return RegisterUdf<compute::ScalarFunction, compute::ScalarKernel>(
      function, PythonTableUdfKernelInit{std::make_shared<OwnedRefNoGIL>(function), cb},
       cb, options, registry);
 }
diff --git a/python/pyarrow/src/arrow/python/udf.h b/python/pyarrow/src/arrow/python/udf.h
index 682cbb2ffe..d8c4e430e5 100644
--- a/python/pyarrow/src/arrow/python/udf.h
+++ b/python/pyarrow/src/arrow/python/udf.h
@@ -67,6 +67,11 @@ Status ARROW_PYTHON_EXPORT RegisterAggregateFunction(
    PyObject* user_function, UdfWrapperCallback wrapper, const UdfOptions& options,
     compute::FunctionRegistry* registry = NULLPTR);
 
+/// \brief register a Vector user-defined-function from Python
+Status ARROW_PYTHON_EXPORT RegisterVectorFunction(
+    PyObject* user_function, UdfWrapperCallback wrapper, const UdfOptions& options,
+    compute::FunctionRegistry* registry = NULLPTR);
+
 Result<std::shared_ptr<RecordBatchReader>> ARROW_PYTHON_EXPORT
 CallTabularFunction(const std::string& func_name, const std::vector<Datum>& args,
                     compute::FunctionRegistry* registry = NULLPTR);
diff --git a/python/pyarrow/tests/test_udf.py b/python/pyarrow/tests/test_udf.py
index 5631e19455..62d1eb5baf 100644
--- a/python/pyarrow/tests/test_udf.py
+++ b/python/pyarrow/tests/test_udf.py
@@ -299,6 +299,44 @@ def raising_func_fixture():
     return raising_func, func_name
 
 
+@pytest.fixture(scope="session")
+def unary_vector_func_fixture():
+    """
+    Register a vector function
+    """
+    def pct_rank(ctx, x):
+        # copy here to get around pandas 1.0 issue
+        return pa.array(x.to_pandas().copy().rank(pct=True))
+
+    func_name = "y=pct_rank(x)"
+    doc = empty_udf_doc
+    pc.register_vector_function(pct_rank, func_name, doc, {
+                                'x': pa.float64()}, pa.float64())
+
+    return pct_rank, func_name
+
+
+@pytest.fixture(scope="session")
+def struct_vector_func_fixture():
+    """
+    Register a vector function that returns a struct array
+    """
+    def pivot(ctx, k, v, c):
+        df = pa.RecordBatch.from_arrays([k, v, c], names=['k', 'v', 'c']).to_pandas()
+        df_pivot = df.pivot(columns='c', values='v', index='k').reset_index()
+        return pa.RecordBatch.from_pandas(df_pivot).to_struct_array()
+
+    func_name = "y=pivot(x)"
+    doc = empty_udf_doc
+    pc.register_vector_function(
+        pivot, func_name, doc,
+        {'k': pa.int64(), 'v': pa.float64(), 'c': pa.utf8()},
+        pa.struct([('k', pa.int64()), ('v1', pa.float64()), ('v2', pa.float64())])
+    )
+
+    return pivot, func_name
+
+
 def check_scalar_function(func_fixture,
                           inputs, *,
                           run_in_dataset=True,
@@ -797,3 +835,35 @@ def test_hash_agg_random(sum_agg_func_fixture):
         [("value", "sum")]).rename_columns(['id', 'value_sum_udf'])
 
     assert result.sort_by('id') == expected.sort_by('id')
+
+
+@pytest.mark.pandas
+def test_vector_basic(unary_vector_func_fixture):
+    arr = pa.array([10.0, 20.0, 30.0, 40.0, 50.0], pa.float64())
+    result = pc.call_function("y=pct_rank(x)", [arr])
+    expected = unary_vector_func_fixture[0](None, arr)
+    assert result == expected
+
+
+@pytest.mark.pandas
+def test_vector_empty(unary_vector_func_fixture):
+    arr = pa.array([1], pa.float64())
+    result = pc.call_function("y=pct_rank(x)", [arr])
+    expected = unary_vector_func_fixture[0](None, arr)
+    assert result == expected
+
+
+@pytest.mark.pandas
+def test_vector_struct(struct_vector_func_fixture):
+    k = pa.array(
+        [1, 1, 2, 2], pa.int64()
+    )
+    v = pa.array(
+        [1.0, 2.0, 3.0, 4.0], pa.float64()
+    )
+    c = pa.array(
+        ['v1', 'v2', 'v1', 'v2']
+    )
+    result = pc.call_function("y=pivot(x)", [k, v, c])
+    expected = struct_vector_func_fixture[0](None, k, v, c)
+    assert result == expected
