This is an automated email from the ASF dual-hosted git repository.
icexelloss pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/main by this push:
     new b1e85a6d0c GH-36672: [Python][C++] Add support for vector function UDF (#36673)
b1e85a6d0c is described below
commit b1e85a6d0cd7a57f93b97d74bb13e89517e3d92e
Author: Li Jin <[email protected]>
AuthorDate: Tue Aug 8 18:25:06 2023 -0400
GH-36672: [Python][C++] Add support for vector function UDF (#36673)
### Rationale for this change
In Arrow compute, there are four main types of functions: Scalar, Vector,
ScalarAggregate and HashAggregate.
Some of the previous work added support for Scalar,
ScalarAggregate (https://github.com/apache/arrow/issues/35515) and
HashAggregate (https://github.com/apache/arrow/issues/36252). I think it makes
sense to add support for vector functions as well, completing support for all
non-decomposable UDF kernels.
Internally, we plan to extend Acero with a "SegmentVectorNode" that would use
this API to invoke a vector UDF on a segment-by-segment basis. This will make
it possible to compute things like "rank the value across all rows per segment
using a Python UDF" in constant memory. A minimal sketch of that intended
usage is shown below.
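The sketch assumes only the API added in this PR; the `pct_rank` name and the
per-segment driver loop are illustrative (SegmentVectorNode itself is not part
of this change):

```python
import pyarrow as pa
import pyarrow.compute as pc

def pct_rank(ctx, x):
    # Whole-array computation: each output value depends on every input
    # row of the segment. pandas is used here only for brevity.
    return pa.array(x.to_pandas().rank(pct=True))

pc.register_vector_function(
    pct_rank, "pct_rank",
    {"summary": "percent rank", "description": "compute percent rank"},
    {"x": pa.float64()}, pa.float64())

# A SegmentVectorNode would issue one call per segment, so only one
# segment's rows need to be materialized at a time.
for segment in [pa.array([1.0, 2.0]), pa.array([3.0, 1.0, 2.0])]:
    print(pc.call_function("pct_rank", [segment]))
```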
### What changes are included in this PR?
The change is very similar to the existing support for aggregate functions: it
includes code to register the vector UDF, and a kernel that invokes the vector
UDF on the given inputs. A quick end-to-end illustration follows.
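A minimal round trip through the new API (the `cumsum_udf` name and doc
strings here are illustrative only):

```python
import pyarrow as pa
import pyarrow.compute as pc

def cumsum_udf(ctx, x):
    # A vector kernel sees the whole input at once, so an output row may
    # depend on every other row (unlike a scalar kernel).
    return pc.cumulative_sum(x)

pc.register_vector_function(
    cumsum_udf, "cumsum_udf",
    {"summary": "cumulative sum", "description": "cumulative sum as a vector UDF"},
    {"x": pa.int64()}, pa.int64())

assert pc.call_function("cumsum_udf", [pa.array([1, 2, 3])]) == pa.array([1, 3, 6])
```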
### Are these changes tested?
Yes. Added new tests.
### Are there any user-facing changes?
Yes. This adds a user-facing API to register vector functions.
* Closes: #36672
Authored-by: Li Jin <[email protected]>
Signed-off-by: Li Jin <[email protected]>
---
python/pyarrow/_compute.pyx | 84 +++++++++++++++++++++++++++++++++-
python/pyarrow/compute.py | 1 +
python/pyarrow/includes/libarrow.pxd | 4 ++
python/pyarrow/src/arrow/python/udf.cc | 41 ++++++++++-------
python/pyarrow/src/arrow/python/udf.h | 5 ++
python/pyarrow/tests/test_udf.py | 70 ++++++++++++++++++++++++++++
6 files changed, 188 insertions(+), 17 deletions(-)
diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx
index ac7efeff41..bc3b9e8c55 100644
--- a/python/pyarrow/_compute.pyx
+++ b/python/pyarrow/_compute.pyx
@@ -1964,7 +1964,7 @@ class CumulativeOptions(_CumulativeOptions):
Parameters
----------
start : Scalar, default None
- Starting value for the cumulative operation. If none is given,
+        Starting value for the cumulative operation. If none is given,
a default value depending on the operation and input type is used.
skip_nulls : bool, default False
When false, the first encountered null is propagated.
@@ -2707,6 +2707,11 @@ cdef get_register_aggregate_function():
reg.register_func = RegisterAggregateFunction
return reg
+cdef get_register_vector_function():
+ cdef RegisterUdf reg = RegisterUdf.__new__(RegisterUdf)
+ reg.register_func = RegisterVectorFunction
+ return reg
+
def register_scalar_function(func, function_name, function_doc, in_types, out_type,
                             func_registry=None):
@@ -2789,6 +2794,83 @@ def register_scalar_function(func, function_name, function_doc, in_types, out_type,
                                            out_type, func_registry)
+def register_vector_function(func, function_name, function_doc, in_types, out_type,
+                             func_registry=None):
+ """
+ Register a user-defined vector function.
+
+ This API is EXPERIMENTAL.
+
+    A vector function is a function that executes vector
+    operations on arrays. Vector functions are often used when a
+    computation doesn't fit other, more specific function types
+    (e.g., scalar and aggregate).
+
+ Parameters
+ ----------
+ func : callable
+ A callable implementing the user-defined function.
+ The first argument is the context argument of type
+ UdfContext.
+ Then, it must take arguments equal to the number of
+ in_types defined. It must return an Array or Scalar
+ matching the out_type. It must return a Scalar if
+ all arguments are scalar, else it must return an Array.
+
+ To define a varargs function, pass a callable that takes
+ *args. The last in_type will be the type of all varargs
+ arguments.
+ function_name : str
+ Name of the function. There should only be one function
+ registered with this name in the function registry.
+ function_doc : dict
+ A dictionary object with keys "summary" (str),
+ and "description" (str).
+ in_types : Dict[str, DataType]
+ A dictionary mapping function argument names to
+ their respective DataType.
+ The argument names will be used to generate
+ documentation for the function. The number of
+ arguments specified here determines the function
+ arity.
+ out_type : DataType
+ Output type of the function.
+ func_registry : FunctionRegistry
+ Optional function registry to use instead of the default global one.
+
+ Examples
+ --------
+ >>> import pyarrow as pa
+ >>> import pyarrow.compute as pc
+ >>>
+    >>> func_doc = {}
+    >>> func_doc["summary"] = "flatten a list array"
+    >>> func_doc["description"] = "flatten each input list into one array"
+ >>>
+ >>> def list_flatten_udf(ctx, x):
+ ... return pc.list_flatten(x)
+ >>>
+ >>> func_name = "list_flatten_udf"
+ >>> in_types = {"array": pa.list_(pa.int64())}
+ >>> out_type = pa.int64()
+ >>> pc.register_vector_function(list_flatten_udf, func_name, func_doc,
+ ... in_types, out_type)
+ >>>
+ >>> answer = pc.call_function(func_name, [pa.array([[1, 2], [3, 4]])])
+ >>> answer
+ <pyarrow.lib.Int64Array object at ...>
+ [
+ 1,
+ 2,
+ 3,
+ 4
+ ]
+ """
+    return _register_user_defined_function(get_register_vector_function(),
+                                           func, function_name, function_doc, in_types,
+                                           out_type, func_registry)
+
+
def register_aggregate_function(func, function_name, function_doc, in_types, out_type,
                                func_registry=None):
"""
diff --git a/python/pyarrow/compute.py b/python/pyarrow/compute.py
index 0fefa18dd1..7b8983cbb9 100644
--- a/python/pyarrow/compute.py
+++ b/python/pyarrow/compute.py
@@ -87,6 +87,7 @@ from pyarrow._compute import ( # noqa
register_scalar_function,
register_tabular_function,
register_aggregate_function,
+ register_vector_function,
UdfContext,
# Expressions
Expression,
diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd
index da46cdcb75..f4d6541fa7 100644
--- a/python/pyarrow/includes/libarrow.pxd
+++ b/python/pyarrow/includes/libarrow.pxd
@@ -2815,5 +2815,9 @@ cdef extern from "arrow/python/udf.h" namespace "arrow::py" nogil:
                                      function[CallbackUdf] wrapper, const CUdfOptions& options,
                                      CFunctionRegistry* registry)
+    CStatus RegisterVectorFunction(PyObject* function,
+                                   function[CallbackUdf] wrapper, const CUdfOptions& options,
+                                   CFunctionRegistry* registry)
+
CResult[shared_ptr[CRecordBatchReader]] CallTabularFunction(
const c_string& func_name, const vector[CDatum]& args,
CFunctionRegistry* registry)
diff --git a/python/pyarrow/src/arrow/python/udf.cc b/python/pyarrow/src/arrow/python/udf.cc
index 435c89f596..f7761a9277 100644
--- a/python/pyarrow/src/arrow/python/udf.cc
+++ b/python/pyarrow/src/arrow/python/udf.cc
@@ -292,14 +292,14 @@ struct PythonUdfHashAggregatorImpl : public HashUdfAggregator {
return out;
}
- Status Resize(KernelContext* ctx, int64_t new_num_groups) {
+ Status Resize(KernelContext* ctx, int64_t new_num_groups) override {
// We only need to change num_groups in resize
// similar to other hash aggregate kernels
num_groups = new_num_groups;
return Status::OK();
}
- Status Consume(KernelContext* ctx, const ExecSpan& batch) {
+ Status Consume(KernelContext* ctx, const ExecSpan& batch) override {
ARROW_ASSIGN_OR_RAISE(
std::shared_ptr<RecordBatch> rb,
batch.ToExecBatch().ToRecordBatch(input_schema, ctx->memory_pool()));
@@ -316,7 +316,7 @@ struct PythonUdfHashAggregatorImpl : public HashUdfAggregator {
return Status::OK();
}
Status Merge(KernelContext* ctx, KernelState&& other_state,
- const ArrayData& group_id_mapping) {
+ const ArrayData& group_id_mapping) override {
// This is similar to GroupedListImpl
auto& other = checked_cast<PythonUdfHashAggregatorImpl&>(other_state);
auto& other_values = other.values;
@@ -336,7 +336,7 @@ struct PythonUdfHashAggregatorImpl : public HashUdfAggregator {
return Status::OK();
}
- Status Finalize(KernelContext* ctx, Datum* out) {
+ Status Finalize(KernelContext* ctx, Datum* out) override {
// Exclude the last column which is the group id
const int num_args = input_schema->num_fields() - 1;
@@ -484,24 +484,25 @@ Status PythonUdfExec(compute::KernelContext* ctx, const compute::ExecSpan& batch
  return SafeCallIntoPython([&]() -> Status { return udf->Exec(ctx, batch, out); });
}
-Status RegisterUdf(PyObject* user_function, compute::KernelInit kernel_init,
- UdfWrapperCallback wrapper, const UdfOptions& options,
+template <class Function, class Kernel>
+Status RegisterUdf(PyObject* function, compute::KernelInit kernel_init,
+ UdfWrapperCallback cb, const UdfOptions& options,
compute::FunctionRegistry* registry) {
- if (!PyCallable_Check(user_function)) {
+ if (!PyCallable_Check(function)) {
return Status::TypeError("Expected a callable Python object.");
}
- auto scalar_func = std::make_shared<compute::ScalarFunction>(
- options.func_name, options.arity, options.func_doc);
- Py_INCREF(user_function);
+  auto scalar_func =
+      std::make_shared<Function>(options.func_name, options.arity, options.func_doc);
+ Py_INCREF(function);
std::vector<compute::InputType> input_types;
for (const auto& in_dtype : options.input_types) {
input_types.emplace_back(in_dtype);
}
compute::OutputType output_type(options.output_type);
auto udf_data = std::make_shared<PythonUdf>(
- std::make_shared<OwnedRefNoGIL>(user_function), wrapper,
+ std::make_shared<OwnedRefNoGIL>(function), cb,
TypeHolder::FromTypes(options.input_types), options.output_type);
- compute::ScalarKernel kernel(
+ Kernel kernel(
      compute::KernelSignature::Make(std::move(input_types), std::move(output_type),
                                     options.arity.is_varargs),
PythonUdfExec, kernel_init);
@@ -522,9 +523,17 @@ Status RegisterUdf(PyObject* user_function, compute::KernelInit kernel_init,
Status RegisterScalarFunction(PyObject* function, UdfWrapperCallback cb,
const UdfOptions& options,
compute::FunctionRegistry* registry) {
-  return RegisterUdf(function,
-                     PythonUdfKernelInit{std::make_shared<OwnedRefNoGIL>(function)}, cb,
-                     options, registry);
+  return RegisterUdf<compute::ScalarFunction, compute::ScalarKernel>(
+      function, PythonUdfKernelInit{std::make_shared<OwnedRefNoGIL>(function)}, cb,
+      options, registry);
+}
+
+Status RegisterVectorFunction(PyObject* function, UdfWrapperCallback cb,
+ const UdfOptions& options,
+ compute::FunctionRegistry* registry) {
+  return RegisterUdf<compute::VectorFunction, compute::VectorKernel>(
+      function, PythonUdfKernelInit{std::make_shared<OwnedRefNoGIL>(function)}, cb,
+      options, registry);
}
Status RegisterTabularFunction(PyObject* function, UdfWrapperCallback cb,
@@ -536,7 +545,7 @@ Status RegisterTabularFunction(PyObject* function, UdfWrapperCallback cb,
if (options.output_type->id() != Type::type::STRUCT) {
return Status::Invalid("tabular function with non-struct output");
}
- return RegisterUdf(
+ return RegisterUdf<compute::ScalarFunction, compute::ScalarKernel>(
function,
PythonTableUdfKernelInit{std::make_shared<OwnedRefNoGIL>(function), cb},
cb, options, registry);
}
diff --git a/python/pyarrow/src/arrow/python/udf.h b/python/pyarrow/src/arrow/python/udf.h
index 682cbb2ffe..d8c4e430e5 100644
--- a/python/pyarrow/src/arrow/python/udf.h
+++ b/python/pyarrow/src/arrow/python/udf.h
@@ -67,6 +67,11 @@ Status ARROW_PYTHON_EXPORT RegisterAggregateFunction(
    PyObject* user_function, UdfWrapperCallback wrapper, const UdfOptions& options,
    compute::FunctionRegistry* registry = NULLPTR);
+/// \brief register a Vector user-defined-function from Python
+Status ARROW_PYTHON_EXPORT RegisterVectorFunction(
+    PyObject* user_function, UdfWrapperCallback wrapper, const UdfOptions& options,
+    compute::FunctionRegistry* registry = NULLPTR);
+
Result<std::shared_ptr<RecordBatchReader>> ARROW_PYTHON_EXPORT
CallTabularFunction(const std::string& func_name, const std::vector<Datum>& args,
                    compute::FunctionRegistry* registry = NULLPTR);
diff --git a/python/pyarrow/tests/test_udf.py b/python/pyarrow/tests/test_udf.py
index 5631e19455..62d1eb5baf 100644
--- a/python/pyarrow/tests/test_udf.py
+++ b/python/pyarrow/tests/test_udf.py
@@ -299,6 +299,44 @@ def raising_func_fixture():
return raising_func, func_name
+@pytest.fixture(scope="session")
+def unary_vector_func_fixture():
+ """
+    Register a vector function
+ """
+ def pct_rank(ctx, x):
+ # copy here to get around pandas 1.0 issue
+ return pa.array(x.to_pandas().copy().rank(pct=True))
+
+ func_name = "y=pct_rank(x)"
+ doc = empty_udf_doc
+ pc.register_vector_function(pct_rank, func_name, doc, {
+ 'x': pa.float64()}, pa.float64())
+
+ return pct_rank, func_name
+
+
+@pytest.fixture(scope="session")
+def struct_vector_func_fixture():
+ """
+    Register a vector function that returns a struct array
+ """
+ def pivot(ctx, k, v, c):
+    def pivot(ctx, k, v, c):
+        df = pa.RecordBatch.from_arrays([k, v, c], names=['k', 'v', 'c']).to_pandas()
+ df_pivot = df.pivot(columns='c', values='v', index='k').reset_index()
+ return pa.RecordBatch.from_pandas(df_pivot).to_struct_array()
+
+ func_name = "y=pivot(x)"
+ doc = empty_udf_doc
+ pc.register_vector_function(
+ pivot, func_name, doc,
+ {'k': pa.int64(), 'v': pa.float64(), 'c': pa.utf8()},
+        pa.struct([('k', pa.int64()), ('v1', pa.float64()), ('v2', pa.float64())])
+ )
+
+ return pivot, func_name
+
+
def check_scalar_function(func_fixture,
inputs, *,
run_in_dataset=True,
@@ -797,3 +835,35 @@ def test_hash_agg_random(sum_agg_func_fixture):
[("value", "sum")]).rename_columns(['id', 'value_sum_udf'])
assert result.sort_by('id') == expected.sort_by('id')
+
+
+@pytest.mark.pandas
+def test_vector_basic(unary_vector_func_fixture):
+ arr = pa.array([10.0, 20.0, 30.0, 40.0, 50.0], pa.float64())
+ result = pc.call_function("y=pct_rank(x)", [arr])
+ expected = unary_vector_func_fixture[0](None, arr)
+ assert result == expected
+
+
+@pytest.mark.pandas
+def test_vector_empty(unary_vector_func_fixture):
+ arr = pa.array([1], pa.float64())
+ result = pc.call_function("y=pct_rank(x)", [arr])
+ expected = unary_vector_func_fixture[0](None, arr)
+ assert result == expected
+
+
+@pytest.mark.pandas
+def test_vector_struct(struct_vector_func_fixture):
+ k = pa.array(
+ [1, 1, 2, 2], pa.int64()
+ )
+ v = pa.array(
+ [1.0, 2.0, 3.0, 4.0], pa.float64()
+ )
+ c = pa.array(
+ ['v1', 'v2', 'v1', 'v2']
+ )
+ result = pc.call_function("y=pivot(x)", [k, v, c])
+ expected = struct_vector_func_fixture[0](None, k, v, c)
+ assert result == expected