rtpsw commented on code in PR #14682:
URL: https://github.com/apache/arrow/pull/14682#discussion_r1049561525


##########
python/pyarrow/src/arrow/python/udf.cc:
##########
@@ -105,21 +158,117 @@ Status RegisterScalarFunction(PyObject* user_function, 
ScalarUdfWrapperCallback
   }
   compute::OutputType output_type(options.output_type);
   auto udf_data = std::make_shared<PythonUdf>(
-      wrapper, std::make_shared<OwnedRefNoGIL>(user_function), 
options.output_type);
+      std::make_shared<OwnedRefNoGIL>(user_function), wrapper, 
options.output_type);
   compute::ScalarKernel kernel(
       compute::KernelSignature::Make(std::move(input_types), 
std::move(output_type),
                                      options.arity.is_varargs),
-      PythonUdfExec);
+      PythonUdfExec, kernel_init);
   kernel.data = std::move(udf_data);
 
   kernel.mem_allocation = compute::MemAllocation::NO_PREALLOCATE;
   kernel.null_handling = compute::NullHandling::COMPUTED_NO_PREALLOCATE;
   RETURN_NOT_OK(scalar_func->AddKernel(std::move(kernel)));
-  auto registry = compute::GetFunctionRegistry();
+  if (registry == NULLPTR) {
+    registry = compute::GetFunctionRegistry();
+  }
   RETURN_NOT_OK(registry->AddFunction(std::move(scalar_func)));
   return Status::OK();
 }
 
-}  // namespace py
+}  // namespace
+
+Status RegisterScalarFunction(PyObject* user_function, 
ScalarUdfWrapperCallback wrapper,
+                              const ScalarUdfOptions& options,
+                              compute::FunctionRegistry* registry) {
+  return RegisterScalarLikeFunction(
+      user_function,
+      
PythonScalarUdfKernelInit{std::make_shared<OwnedRefNoGIL>(user_function)}, 
wrapper,
+      options, registry);
+}
+
+Status RegisterTabularFunction(PyObject* user_function, 
ScalarUdfWrapperCallback wrapper,
+                               const ScalarUdfOptions& options,
+                               compute::FunctionRegistry* registry) {
+  if (options.arity.num_args != 0 || options.arity.is_varargs) {
+    return Status::NotImplemented("tabular function of non-null arity");
+  }
+  if (options.output_type->id() != Type::type::STRUCT) {
+    return Status::Invalid("tabular function with non-struct output");
+  }
+  return RegisterScalarLikeFunction(
+      user_function,
+      PythonTableUdfKernelInit{std::make_shared<OwnedRefNoGIL>(user_function), 
wrapper},
+      wrapper, options, registry);
+}
 
+namespace  {
+
+Result<std::shared_ptr<RecordBatch>> RecordBatchFromArray(
+    std::shared_ptr<Schema> schema, std::shared_ptr<Array> array) {
+  auto& data = const_cast<std::shared_ptr<ArrayData>&>(array->data());
+  if (data->child_data.size() != static_cast<size_t>(schema->num_fields())) {
+    return Status::Invalid("UDF result with shape not conforming to schema");
+  }
+  return RecordBatch::Make(std::move(schema), data->length, 
std::move(data->child_data));
+}
+
+}  // namespace
+
+Result<std::shared_ptr<RecordBatchReader>> CallTabularFunction(
+    const std::string& func_name, const std::vector<Datum>& args,
+    compute::FunctionRegistry* registry) {
+  if (args.size() != 0) {
+    return Status::NotImplemented("non-empty arguments to tabular function");
+  }
+  if (registry == NULLPTR) {
+    registry = compute::GetFunctionRegistry();
+  }
+  ARROW_ASSIGN_OR_RAISE(auto func, registry->GetFunction(func_name));
+  if (func->kind() != compute::Function::SCALAR) {
+    return Status::Invalid("tabular function of non-scalar kind");
+  }
+  auto arity = func->arity();
+  if (arity.num_args != 0 || arity.is_varargs) {
+    return Status::NotImplemented("tabular function of non-null arity");
+  }
+  auto kernels =
+      
arrow::internal::checked_pointer_cast<compute::ScalarFunction>(func)->kernels();
+  if (kernels.size() != 1) {
+    return Status::NotImplemented("tabular function with non-single kernel");
+  }
+  const compute::ScalarKernel* kernel = kernels[0];
+  auto out_type = kernel->signature->out_type();
+  if (out_type.kind() != compute::OutputType::FIXED) {
+    return Status::Invalid("tabular kernel of non-fixed kind");
+  }
+  auto datatype = out_type.type();
+  if (datatype->id() != Type::type::STRUCT) {
+    return Status::Invalid("tabular kernel with non-struct output");
+  }
+  auto struct_type = 
arrow::internal::checked_cast<StructType*>(datatype.get());
+  auto schema = ::arrow::schema(struct_type->fields());
+  std::vector<TypeHolder> in_types;
+  ARROW_ASSIGN_OR_RAISE(auto func_exec,
+                        GetFunctionExecutor(func_name, in_types, NULLPTR, 
registry));
+  auto next_func =
+      [schema,
+       func_exec = std::move(func_exec)]() -> 
Result<std::shared_ptr<RecordBatch>> {
+    std::vector<Datum> args;
+    // passed_length of -1 or 0 with args.size() of 0 leads to an empty 
ExecSpanIterator
+    // in exec.cc and to never invoking the source function, so 1 is passed 
instead
+    ARROW_ASSIGN_OR_RAISE(auto datum, func_exec->Execute(args, 
/*passed_length=*/1));
+    if (!datum.is_array()) {
+      return Status::Invalid("UDF result of non-array kind");
+    }
+    std::shared_ptr<Array> array = datum.make_array();
+    if (array->length() == 0) {
+      return IterationTraits<std::shared_ptr<RecordBatch>>::End();
+    }
+    return RecordBatchFromArray(std::move(schema), std::move(array));

Review Comment:
   Looking at this again, I believe I used `StructArray` because the [original result-handling code](https://github.com/apache/arrow/blob/5c1044fce55ed0e373a622cb8ee3b97a1a34799a/python/pyarrow/src/arrow/python/udf.cc#L68-L83) deals only with arrays, so introducing a new kind of result should be done carefully and consistently with that code. Since you know the original code well - WDYT?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscr...@arrow.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

Reply via email to