lidavidm commented on code in PR #12590:
URL: https://github.com/apache/arrow/pull/12590#discussion_r849387881


##########
python/pyarrow/_compute.pyx:
##########
@@ -199,6 +203,87 @@ FunctionDoc = namedtuple(
      "options_required"))
 
 
+cdef wrap_input_type(const CInputType c_input_type):
+    """
+    Wrap a C++ InputType in an InputType object.
+    """
+    cdef InputType input_type = InputType.__new__(InputType)
+    input_type.init(c_input_type)
+    return input_type
+
+
+cdef class InputType(_Weakrefable):
+    """
+    An input type specification for a user-defined function.
+    """
+
+    def __init__(self):
+        raise TypeError("Do not call {}'s constructor directly"
+                        .format(self.__class__.__name__))
+
+    cdef void init(self, const CInputType &input_type):
+        self.input_type = input_type
+
+    @staticmethod
+    def scalar(data_type):
+        """
+        Create a scalar input type of the given data type.
+
+        Arguments to a UDF have both a data type and a shape,
+        either array or scalar. A scalar InputType means that
+        this argument must be passed a Scalar.  
+
+        Parameters
+        ----------
+        data_type : DataType
+            The data type of the scalar argument.
+
+        Examples
+        --------
+
+        >>> import pyarrow as pa
+        >>> from pyarrow.compute import InputType
+        >>> in_type = InputType.scalar(pa.int32())
+        >>> in_type
+        scalar[int32]
+        """
+        cdef:
+            shared_ptr[CDataType] c_data_type
+            CInputType c_input_type
+        c_data_type = pyarrow_unwrap_data_type(data_type)
+        c_input_type = CInputType.Scalar(c_data_type)
+        return wrap_input_type(c_input_type)
+
+    @staticmethod
+    def array(data_type):
+        """
+        Create an array input type of the given data type.
+
+        Arguments to a UDF have both a data type and a shape,
+        either array or scalar. An array InputType means that
+        this argument must be passed an Array.
+
+        Parameters
+        ----------
+        data_type : DataType
+            The data type of the array argument.
+
+        Examples
+        --------
+
+        >>> import pyarrow as pa
+        >>> from pyarrow.compute import InputType
+        >>> in_type = InputType.array(pa.int32())
+        >>> in_type
+        <pyarrow._compute.InputType object at 0x102ba4850>

Review Comment:
   Is this correct?
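   For reference: a bare assignment in a doctest prints nothing, and a
   Cython extension type only prints a custom repr if it defines
   `__repr__`. A minimal sketch for checking what actually prints,
   assuming the PR as written (no `__repr__` on InputType):

   ```python
   import pyarrow as pa
   from pyarrow.compute import InputType

   in_type = InputType.array(pa.int32())
   # Without a __repr__, this prints the default object repr,
   # e.g. <pyarrow._compute.InputType object at 0x...>
   print(repr(in_type))
   ```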



##########
cpp/src/arrow/python/udf.cc:
##########
@@ -0,0 +1,162 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/python/udf.h"
+
+#include <cstddef>
+#include <memory>
+#include <sstream>
+
+#include "arrow/compute/function.h"
+#include "arrow/python/common.h"
+
+namespace arrow {
+
+namespace py {
+
+Status VerifyArrayInput(const compute::ExecBatch& batch) {
+  for (const auto& value : batch.values) {
+    if (!value.is_array()) {
+      return Status::Invalid("Expected array input, but got ", value.type());
+    }
+  }
+  return Status::OK();
+}
+
+Status VerifyScalarInput(const compute::ExecBatch& batch) {
+  for (const auto& value : batch.values) {
+    if (!value.is_scalar()) {
+      return Status::Invalid("Expected scalar input, but got ", value.type());
+    }
+  }
+  return Status::OK();
+}
+
+Status VerifyArityAndInput(compute::Arity arity, const compute::ExecBatch& batch) {
+  if (!arity.is_varargs) {
+    bool match = static_cast<uint64_t>(arity.num_args) == batch.values.size();
+    if (!match) {
+      return Status::Invalid(
+          "Function Arity and Input data shape doesn't match, expected ", 
arity.num_args,
+          ", got ", batch.values.size());
+    }
+  } else {
+    bool match = static_cast<uint64_t>(arity.num_args) <= batch.values.size();
+    if (!match) {
+      return Status::Invalid("Required minimum number of arguments", 
arity.num_args,
+                             " in VarArgs function is not met.", ", Received ",
+                             batch.values.size());
+    }
+  }
+  return Status::OK();
+}
+
+Status ExecFunctionScalar(const compute::ExecBatch& batch, PyObject* function,
+                          const compute::Arity& arity, Datum* out) {
+  // For a varargs function, num_args is the number of values in the batch;
+  // for all other arities it is the declared arity.num_args.
+  int64_t num_args =
+      arity.is_varargs ? static_cast<int64_t>(batch.values.size()) : arity.num_args;
+  PyObject* arg_tuple = PyTuple_New(num_args);
+  for (int arg_id = 0; arg_id < num_args; arg_id++) {
+    if (!batch[arg_id].is_scalar()) {
+      return Status::Invalid("Input type and data type doesn't match");
+    }
+    auto c_data = batch[arg_id].scalar();
+    PyObject* data = wrap_scalar(c_data);
+    PyTuple_SetItem(arg_tuple, arg_id, data);
+  }
+  PyObject* result = PyObject_CallObject(function, arg_tuple);
+  if (result == NULL) {
+    return Status::ExecutionError("Output is null, but expected a scalar");
+  }
+  if (!is_scalar(result)) {
+    return Status::Invalid("Output from function is not a scalar");
+  }
+  ARROW_ASSIGN_OR_RAISE(auto unwrapped_result, unwrap_scalar(result));
+  *out = unwrapped_result;
+  return Status::OK();
+}
+
+Status ExecFunctionArray(const compute::ExecBatch& batch, PyObject* function,
+                         const compute::Arity& arity, Datum* out) {
+  // For a varargs function, num_args is the number of values in the batch;
+  // for all other arities it is the declared arity.num_args.
+  int64_t num_args =
+      arity.is_varargs ? static_cast<int64_t>(batch.values.size()) : arity.num_args;
+  PyObject* arg_tuple = PyTuple_New(num_args);
+  for (int arg_id = 0; arg_id < num_args; arg_id++) {
+    if (!batch[arg_id].is_array()) {
+      return Status::Invalid("Input type and data type doesn't match");
+    }
+    auto c_data = batch[arg_id].make_array();
+    PyObject* data = wrap_array(c_data);
+    PyTuple_SetItem(arg_tuple, arg_id, data);
+  }
+  PyObject* result = PyObject_CallObject(function, arg_tuple);
+  if (result == NULL) {
+    return Status::ExecutionError("Output is null, but expected an array");
+  }
+  if (!is_array(result)) {
+    return Status::Invalid("Output from function is not an array");
+  }
+  return unwrap_array(result).Value(out);
+}
+
+Status ScalarUdfBuilder::MakeFunction(PyObject* function, ScalarUdfOptions* options) {
+  if (function == NULL) {
+    return Status::Invalid("Python function cannot be null");
+  }
+  Py_INCREF(function);
+  function_.reset(function);
+  if (!PyCallable_Check(function_.obj())) {
+    return Status::TypeError("Expected a callable python object.");
+  }
+  auto doc = options->doc();
+  auto arity = options->arity();
+  scalar_func_ = std::make_shared<compute::ScalarFunction>(options->name(), arity, doc);
+  auto func = function_.obj();
+  auto exec = [func, arity](compute::KernelContext* ctx, const compute::ExecBatch& batch,
+                            Datum* out) -> Status {
+    PyAcquireGIL lock;
+    RETURN_NOT_OK(VerifyArityAndInput(arity, batch));
+    if (VerifyArrayInput(batch).ok()) {  // all inputs are arrays: take the array path
+      RETURN_NOT_OK(ExecFunctionArray(batch, func, arity, out));
+    } else if (VerifyScalarInput(batch).ok()) {
+      // all inputs are scalars: take the scalar path
+      RETURN_NOT_OK(ExecFunctionScalar(batch, func, arity, out));
+    } else {
+      return Status::Invalid("Unexpected input type, scalar or array type 
expected.");
+    }
+    return Status::OK();
+  };
+
+  compute::ScalarKernel kernel(
+      compute::KernelSignature::Make(options->input_types(), options->output_type(),
+                                     arity.is_varargs),
+      exec);

Review Comment:
   Will you file the JIRA?
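   For context, the `arity.is_varargs` flag passed to
   `KernelSignature::Make` above makes the last input type match all
   remaining arguments. From Python, a varargs registration under this
   PR's proposed API might look like the following sketch (the function
   name, doc, and types are illustrative, not taken from the PR):

   ```python
   import pyarrow as pa
   import pyarrow.compute as pc
   from pyarrow.compute import InputType, register_scalar_function

   doc = {"summary": "sum all arguments",
          "description": "fold every argument with the add kernel"}

   def sum_all(*arrays):
       # fold the inputs pairwise with the built-in "add" kernel
       result = arrays[0]
       for arr in arrays[1:]:
           result = pc.call_function("add", [result, arr])
       return result

   # Per the register_scalar_function docstring, the last in_type
   # is the type of all the varargs arguments.
   register_scalar_function("sum_all", doc,
                            {"arrays": InputType.array(pa.int64())},
                            pa.int64(), sum_all)
   ```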



##########
cpp/src/arrow/python/udf.cc:
##########
@@ -0,0 +1,126 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/python/udf.h"
+
+#include <cstddef>
+#include <memory>
+#include <sstream>
+
+#include "arrow/compute/function.h"
+#include "arrow/python/common.h"
+
+namespace arrow {
+
+namespace py {
+
+Status ExecuteFunction(const compute::ExecBatch& batch, PyObject* function,
+                       const compute::OutputType& exp_out_type, Datum* out) {
+  int64_t num_args = static_cast<int64_t>(batch.values.size());
+  PyObject* arg_tuple = PyTuple_New(num_args);
+  // wrap exec_batch objects into Python objects based on the datum type
+  for (int arg_id = 0; arg_id < num_args; arg_id++) {
+    switch (batch[arg_id].kind()) {
+      case Datum::SCALAR: {
+        auto c_data = batch[arg_id].scalar();
+        PyObject* data = wrap_scalar(c_data);
+        PyTuple_SetItem(arg_tuple, arg_id, data);
+        break;
+      }
+      case Datum::ARRAY: {
+        auto c_data = batch[arg_id].make_array();
+        PyObject* data = wrap_array(c_data);
+        PyTuple_SetItem(arg_tuple, arg_id, data);
+        break;
+      }
+      default:
+        return Status::NotImplemented(
+            "User-defined-functions are not supported for the datum kind ",
+            batch[arg_id].kind());
+    }
+  }
+  // call into Python to execute the function
+  PyObject* result = nullptr;
+  auto st = SafeCallIntoPython([&]() -> Status {
+    result = PyObject_CallObject(function, arg_tuple);
+    return CheckPyError();
+  });
+  RETURN_NOT_OK(st);
+  if (result == nullptr) {
+    return Status::ExecutionError("Output is null, but expected an array");

Review Comment:
   When is it possible for result to be None? If the Python function returns 
None, result will be Py_None, not nullptr, right?
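   To illustrate from the Python side: a UDF body that does `return
   None` hands back a live `Py_None`, so `PyObject_CallObject` returns
   non-NULL and the nullptr branch is never taken for a `None` result;
   nullptr here indicates a pending Python exception, which
   `CheckPyError` above already converts. A sketch of the `None`
   failure mode (the exact exception raised by the unwrap step is an
   assumption, not taken from this PR):

   ```python
   import pyarrow as pa
   import pyarrow.compute as pc
   from pyarrow.compute import InputType, register_scalar_function

   def returns_none(array):
       return None  # becomes Py_None, not NULL

   register_scalar_function(
       "py_returns_none",
       {"summary": "bad udf", "description": "returns None"},
       {"array": InputType.array(pa.int64())},
       pa.int64(), returns_none)

   try:
       pc.call_function("py_returns_none", [pa.array([1, 2, 3])])
   except Exception as exc:
       # Expected to fail at unwrap_array(result), not at the NULL check.
       print(type(exc).__name__, exc)
   ```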



##########
cpp/src/arrow/python/udf.h:
##########
@@ -0,0 +1,90 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include "arrow/python/platform.h"
+
+#include <cstdint>
+#include <memory>
+
+#include "arrow/compute/api_scalar.h"
+#include "arrow/compute/cast.h"
+#include "arrow/compute/exec.h"
+#include "arrow/compute/function.h"
+#include "arrow/compute/registry.h"
+#include "arrow/datum.h"
+#include "arrow/util/cpu_info.h"
+#include "arrow/util/logging.h"
+
+#include "arrow/python/common.h"
+#include "arrow/python/pyarrow.h"
+#include "arrow/python/visibility.h"
+
+namespace arrow {
+
+namespace py {
+
+/// TODO(ARROW-16041): UDF options are not yet exposed to Python
+/// users. They will be added when this API is extended with advanced
+/// options for users.
+class ARROW_PYTHON_EXPORT ScalarUdfOptions {
+ public:
+  ScalarUdfOptions(std::string func_name, compute::Arity arity,
+                   compute::FunctionDoc func_doc,
+                   std::vector<compute::InputType> in_types,
+                   compute::OutputType out_type)
+      : func_name_(std::move(func_name)),
+        kind_(compute::Function::SCALAR),
+        arity_(arity),
+        func_doc_(std::move(func_doc)),
+        in_types_(std::move(in_types)),
+        out_type_(std::move(out_type)) {}
+
+  const std::string& name() const { return func_name_; }
+
+  compute::Function::Kind kind() const { return kind_; }
+
+  const compute::Arity& arity() const { return arity_; }
+
+  const compute::FunctionDoc& doc() const { return func_doc_; }
+
+  const std::vector<compute::InputType>& input_types() const { return in_types_; }
+
+  const compute::OutputType& output_type() const { return out_type_; }
+
+ private:
+  std::string func_name_;
+  compute::Function::Kind kind_;
+  compute::Arity arity_;
+  compute::FunctionDoc func_doc_;
+  std::vector<compute::InputType> in_types_;
+  compute::OutputType out_type_;
+};
+
+class ARROW_PYTHON_EXPORT UdfBuilder {
+ public:
+  UdfBuilder() {}
+};

Review Comment:
   This is completely unnecessary now right?



##########
python/pyarrow/_compute.pyx:
##########
@@ -2251,3 +2336,162 @@ cdef CExpression _bind(Expression filter, Schema schema) except *:
 
     return GetResultValue(filter.unwrap().Bind(
         deref(pyarrow_unwrap_schema(schema).get())))
+
+
+cdef CFunctionDoc _make_function_doc(dict func_doc) except *:
+    """
+    Helper function to generate the FunctionDoc.
+    This function accepts a dictionary and expects the
+    "summary" (str), "description" (str) and "arg_names" (List[str]) keys.
+    """
+    cdef:
+        CFunctionDoc f_doc
+        vector[c_string] c_arg_names
+
+    if len(func_doc) <= 1:
+        raise ValueError(
+            "Function doc must contain a summary, a description and arg_names")
+
+    if not "summary" in func_doc.keys():
+        raise ValueError("Function doc must contain a summary")
+
+    if not "description" in func_doc.keys():
+        raise ValueError("Function doc must contain a description")
+
+    if not "arg_names" in func_doc.keys():
+        raise ValueError("Function doc must contain arg_names")
+
+    f_doc.summary = tobytes(func_doc["summary"])
+    f_doc.description = tobytes(func_doc["description"])
+    for arg_name in func_doc["arg_names"]:
+        c_arg_names.push_back(tobytes(arg_name))
+    f_doc.arg_names = c_arg_names
+    # UDFOptions integration:
+    # TODO: https://issues.apache.org/jira/browse/ARROW-16041
+    f_doc.options_class = tobytes("None")
+    f_doc.options_required = False
+    return f_doc
+
+
+def register_scalar_function(func_name, function_doc, in_types,
+                             out_type, function):
+    """
+    Register a user-defined scalar function. 
+
+    A scalar function is a function that executes elementwise
+    operations on arrays or scalars, and therefore whose results
+    generally do not depend on the order of the values in the
+    arguments. It accepts and returns arrays that are all of the
+    same size. These functions roughly correspond to the functions
+    used in SQL expressions.
+
+    Parameters
+    ----------
+    func_name : str
+        Name of the function. This name must be globally unique. 
+    function_doc : dict
+        A dictionary object with keys "summary" (str),
+        and "description" (str).
+    in_types : Dict[str, InputType]
+        A dictionary mapping input labels to the InputType
+        objects that define the arguments to the function.
+        The input label is a str that will be used to generate
+        documentation for the function. The number of arguments
+        specified here determines the function arity.
+    out_type : DataType
+        Output type of the function.
+    function : callable
+        A callable implementing the user-defined function.
+        It must take as many arguments as there are entries
+        in in_types. It must return an Array or Scalar
+        matching the out_type: a Scalar if all arguments are
+        scalar, otherwise an Array.
+
+        To define a varargs function, pass a callable that takes
+        varargs. The last in_type will be the type of all the
+        varargs arguments.
+
+    Examples
+    --------
+
+    >>> import pyarrow as pa
+    >>> import pyarrow.compute as pc
+    >>> from pyarrow.compute import register_scalar_function
+    >>> from pyarrow.compute import InputType
+    >>> 
+    >>> func_doc = {}
+    >>> func_doc["summary"] = "simple udf"
+    >>> func_doc["description"] = "add a constant to a scalar"
+    >>> 
+    >>> def add_constant(array):
+    ...     return pc.call_function("add", [array, 1])
+    ... 
+    >>> 
+    >>> func_name = "py_add_func"
+    >>> in_types = {"array": InputType.array(pa.int64())}
+    >>> out_type = pa.int64()
+    >>> register_scalar_function(func_name, func_doc,
+    ...                          in_types, out_type, add_constant)
+    >>> 
+    >>> func = pc.get_function(func_name)
+    >>> func.name
+    'py_add_func'
+    >>> ans = pc.call_function(func_name, [pa.array([20])])
+    >>> ans
+    <pyarrow.lib.Int64Array object at 0x10c22e700>
+    [
+      21
+    ]
+    """
+    cdef:
+        c_string c_func_name
+        CArity c_arity
+        CFunctionDoc c_func_doc
+        CInputType in_tmp
+        vector[CInputType] c_in_types
+        PyObject* c_function
+        shared_ptr[CDataType] c_type
+        COutputType* c_out_type
+        CStatus st
+        CScalarUdfOptions* c_options
+
+    c_func_name = tobytes(func_name)
+
+    if callable(function):
+        c_function = <PyObject*>function
+    else:
+        raise TypeError("Object must be a callable")
+
+    func_spec = inspect.getfullargspec(function)
+    num_args = -1
+    if isinstance(in_types, dict):
+        for in_type in in_types.values():
+            in_tmp = (<InputType> in_type).input_type
+            c_in_types.push_back(in_tmp)
+        function_doc["arg_names"] = in_types.keys()
+        num_args = len(in_types)
+    else:
+        if num_args == -1:

Review Comment:
   This check is redundant
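   A sketch of the simplification being suggested: `num_args` is
   initialized to `-1` and only reassigned inside the `isinstance`
   branch, so in the `else` branch the inner test is always true and
   can be dropped (`...` stands in for the bodies elided from this
   hunk):

   ```python
   num_args = -1
   if isinstance(in_types, dict):
       ...  # unwrap each InputType and set arg_names, as above
       num_args = len(in_types)
   else:
       # num_args is necessarily still -1 here; no need to re-test it
       ...
   ```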



##########
python/pyarrow/tests/test_udf.py:
##########
@@ -0,0 +1,467 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+import pytest
+
+import pyarrow as pa
+from pyarrow import compute as pc
+from pyarrow.compute import register_scalar_function
+from pyarrow.compute import InputType
+
+
+unary_doc = {"summary": "add function",
+             "description": "test add function"}
+
+
+def unary_function(scalar1):
+    return pc.call_function("add", [scalar1, 1])
+
+
+binary_doc = {"summary": "y=mx",
+              "description": "find y from y = mx"}
+
+
+def binary_function(m, x):
+    return pc.call_function("multiply", [m, x])
+
+
+ternary_doc = {"summary": "y=mx+c",
+               "description": "find y from y = mx + c"}
+
+
+def ternary_function(m, x, c):
+    mx = pc.call_function("multiply", [m, x])
+    return pc.call_function("add", [mx, c])
+
+
+varargs_doc = {"summary": "z=ax+by+c",
+               "description": "find z from z = ax + by + c"
+               }
+
+
+def varargs_function(*values):
+    base_val = values[:2]
+    res = pc.call_function("add", base_val)
+    for other_val in values[2:]:
+        res = pc.call_function("add", [res, other_val])
+    return res
+
+
+def test_scalar_udf_function_with_scalar_valued_functions():
+    function_names = [
+        "scalar_y=x+k",
+        "scalar_y=mx",
+        "scalar_y=mx+c",
+        "scalar_z=ax+by+c",
+    ]
+
+    function_input_types = [
+        {
+            "scalar": InputType.scalar(pa.int64()),
+        },
+        {
+            "scalar1": InputType.scalar(pa.int64()),
+            "scalar2": InputType.scalar(pa.int64()),
+        },
+        {
+            "scalar1": InputType.scalar(pa.int64()),
+            "scalar2": InputType.scalar(pa.int64()),
+            "scalar3": InputType.scalar(pa.int64()),
+        },
+        {
+            "scalar1": InputType.scalar(pa.int64()),
+            "scalar2": InputType.scalar(pa.int64()),
+            "scalar3": InputType.scalar(pa.int64()),
+            "scalar4": InputType.scalar(pa.int64()),
+            "scalar5": InputType.scalar(pa.int64()),
+        },
+    ]
+
+    function_output_types = [
+        pa.int64(),
+        pa.int64(),
+        pa.int64(),
+        pa.int64(),
+    ]
+
+    function_docs = [
+        unary_doc,
+        binary_doc,
+        ternary_doc,
+        varargs_doc
+    ]
+
+    functions = [
+        unary_function,
+        binary_function,
+        ternary_function,
+        varargs_function
+    ]
+
+    function_inputs = [
+        [
+            pa.scalar(10, pa.int64())
+        ],
+        [
+            pa.scalar(10, pa.int64()),
+            pa.scalar(2, pa.int64())
+        ],
+        [
+            pa.scalar(10, pa.int64()),
+            pa.scalar(2, pa.int64()),
+            pa.scalar(5, pa.int64())
+        ],
+        [
+            pa.scalar(2, pa.int64()),
+            pa.scalar(10, pa.int64()),
+            pa.scalar(3, pa.int64()),
+            pa.scalar(20, pa.int64()),
+            pa.scalar(5, pa.int64())
+        ],
+    ]
+
+    expected_outputs = [
+        unary_function(function_inputs[0][0]),
+        binary_function(function_inputs[1][0], function_inputs[1][1]),
+        ternary_function(function_inputs[2][0], function_inputs[2][1],
+                         function_inputs[2][2]),
+        varargs_function(*function_inputs[3])
+    ]

Review Comment:
   This is redundant, you can compute it in the loop below as `function(*input)`
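   A sketch of the consolidation being suggested, computing the
   expected value inside the parameter loop rather than in a separate
   list (the loop body is illustrative; the original loop is not part
   of this hunk):

   ```python
   cases = zip(function_names, function_input_types,
               function_output_types, function_docs,
               functions, function_inputs)
   for name, in_types, out_type, doc, function, inputs in cases:
       register_scalar_function(name, doc, in_types, out_type, function)
       expected = function(*inputs)  # computed inline, per the review
       result = pc.call_function(name, inputs)
       assert result == expected
   ```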



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
