westonpace commented on code in PR #14320:
URL: https://github.com/apache/arrow/pull/14320#discussion_r1060111846
##########
python/pyarrow/src/arrow/python/udf.cc:
##########
@@ -99,22 +99,33 @@ Status RegisterScalarFunction(PyObject* user_function, ScalarUdfWrapperCallback
auto scalar_func = std::make_shared<compute::ScalarFunction>(
options.func_name, options.arity, options.func_doc);
Py_INCREF(user_function);
- std::vector<compute::InputType> input_types;
- for (const auto& in_dtype : options.input_types) {
- input_types.emplace_back(in_dtype);
+
+ const size_t num_kernels = options.input_arg_types.size();
+ // number of input_type variations and output_types must be
+ // equal in size
+ if(num_kernels != options.output_types.size()) {
+ return Status::Invalid("input_arg_types and output_types should be equal in size");
+ }
+ // adding kernels
+ for(size_t idx=0 ; idx < num_kernels; idx++) {
+ const auto& opt_input_types = options.input_arg_types[idx];
+ std::vector<compute::InputType> input_types;
+ for (const auto& in_dtype : opt_input_types) {
+ input_types.emplace_back(in_dtype);
+ }
+ const auto opts_out_type = options.output_types[idx];
Review Comment:
```suggestion
const auto& opts_out_type = options.output_types[idx];
```
##########
python/pyarrow/_compute.pyx:
##########
@@ -2603,15 +2606,29 @@ def register_scalar_function(func, function_name, function_doc, in_types,
func_spec = inspect.getfullargspec(func)
num_args = -1
- if isinstance(in_types, dict):
- for in_type in in_types.values():
- c_in_types.push_back(
- pyarrow_unwrap_data_type(ensure_type(in_type)))
- function_doc["arg_names"] = in_types.keys()
- num_args = len(in_types)
+ if not isinstance(in_arg_types, list):
+ raise TypeError(
+ "in_arg_types must be a list of dictionaries of DataType")
+ if not isinstance(out_types, list):
+ raise TypeError("out_types must be a list of DataType")
+ # each input_type dict in input_types list must
+ # have same arg_names
+ if isinstance(in_arg_types[0], dict):
+ function_doc["arg_names"] = in_arg_types[0].keys()
+ num_args = len(in_arg_types[0])
else:
raise TypeError(
- "in_types must be a dictionary of DataType")
+ "Elements in in_arg_types must be a dictionary of DataType")
+
+ for in_types in in_arg_types:
+ if isinstance(in_types, dict):
+ for in_type in in_types.values():
+ c_in_types.push_back(
+ pyarrow_unwrap_data_type(ensure_type(in_type)))
+ else:
+ raise TypeError(
+ "Elements in in_arg_types must be a dictionary of DataType")
Review Comment:
```suggestion
"Elements in in_arg_types must be a dictionary of str:DataType")
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]