wjones127 commented on code in PR #14682:
URL: https://github.com/apache/arrow/pull/14682#discussion_r1057794296


##########
python/pyarrow/tests/test_udf.py:
##########
@@ -504,3 +504,112 @@ def test_input_lifetime(unary_func_fixture):
     # Calling a UDF should not have kept `v` alive longer than required
     v = None
     assert proxy_pool.bytes_allocated() == 0
+
+
+def _record_batch_from_iters(schema, *iters):
+    arrays = [pa.array(list(v), type=schema[i].type)
+              for i, v in enumerate(iters)]
+    return pa.RecordBatch.from_arrays(arrays=arrays, schema=schema)
+
+
+def _record_batch_for_range(schema, n):
+    return _record_batch_from_iters(schema,
+                                    range(n, n + 10),
+                                    range(n + 1, n + 11))
+
+
+def make_udt_func(schema, batch_gen):
+    def udf_func(ctx):
+        class UDT:
+            def __init__(self):
+                self.caller = None
+
+            def __call__(self, ctx):
+                try:
+                    if self.caller is None:
+                        self.caller, ctx = batch_gen(ctx).send, None
+                    batch = self.caller(ctx)
+                except StopIteration:
+                    arrays = [pa.array([], type=field.type)
+                              for field in schema]
+                    batch = pa.RecordBatch.from_arrays(
+                        arrays=arrays, schema=schema)

Review Comment:
   Is there a good reason why we have to return an empty batch? It seems like 
it would be a lot easier to just have users write `raise StopIteration` (or 
rely on the default behavior of a generator).



##########
python/pyarrow/src/arrow/python/udf.cc:
##########
@@ -105,21 +192,109 @@ Status RegisterScalarFunction(PyObject* user_function, 
ScalarUdfWrapperCallback
   }
   compute::OutputType output_type(options.output_type);
   auto udf_data = std::make_shared<PythonUdf>(
-      wrapper, std::make_shared<OwnedRefNoGIL>(user_function), 
options.output_type);
+      std::make_shared<OwnedRefNoGIL>(user_function), wrapper,
+      TypeHolder::FromTypes(options.input_types), options.output_type);
   compute::ScalarKernel kernel(
       compute::KernelSignature::Make(std::move(input_types), 
std::move(output_type),
                                      options.arity.is_varargs),
-      PythonUdfExec);
+      PythonUdfExec, kernel_init);
   kernel.data = std::move(udf_data);
 
   kernel.mem_allocation = compute::MemAllocation::NO_PREALLOCATE;
   kernel.null_handling = compute::NullHandling::COMPUTED_NO_PREALLOCATE;
   RETURN_NOT_OK(scalar_func->AddKernel(std::move(kernel)));
-  auto registry = compute::GetFunctionRegistry();
+  if (registry == NULLPTR) {
+    registry = compute::GetFunctionRegistry();
+  }
   RETURN_NOT_OK(registry->AddFunction(std::move(scalar_func)));
   return Status::OK();
 }
 
-}  // namespace py
+}  // namespace
 
+Status RegisterScalarFunction(PyObject* user_function, UdfWrapperCallback 
wrapper,
+                              const UdfOptions& options,
+                              compute::FunctionRegistry* registry) {
+  return RegisterUdf(
+      user_function,
+      PythonUdfKernelInit{std::make_shared<OwnedRefNoGIL>(user_function)}, 
wrapper,
+      options, registry);
+}
+
+Status RegisterTabularFunction(PyObject* user_function, UdfWrapperCallback 
wrapper,
+                               const UdfOptions& options,
+                               compute::FunctionRegistry* registry) {
+  if (options.arity.num_args != 0 || options.arity.is_varargs) {
+    return Status::NotImplemented("tabular function of non-null arity");
+  }
+  if (options.output_type->id() != Type::type::STRUCT) {
+    return Status::Invalid("tabular function with non-struct output");
+  }
+  return RegisterUdf(
+      user_function,
+      PythonTableUdfKernelInit{std::make_shared<OwnedRefNoGIL>(user_function), 
wrapper},
+      wrapper, options, registry);
+}
+
+Result<std::shared_ptr<RecordBatchReader>> CallTabularFunction(
+    const std::string& func_name, const std::vector<Datum>& args,
+    compute::FunctionRegistry* registry) {
+  if (args.size() != 0) {
+    return Status::NotImplemented("non-empty arguments to tabular function");
+  }
+  if (registry == NULLPTR) {
+    registry = compute::GetFunctionRegistry();
+  }
+  ARROW_ASSIGN_OR_RAISE(auto func, registry->GetFunction(func_name));
+  if (func->kind() != compute::Function::SCALAR) {
+    return Status::Invalid("tabular function of non-scalar kind");
+  }
+  auto arity = func->arity();
+  if (arity.num_args != 0 || arity.is_varargs) {
+    return Status::NotImplemented("tabular function of non-null arity");
+  }
+  auto kernels =
+      
arrow::internal::checked_pointer_cast<compute::ScalarFunction>(func)->kernels();
+  if (kernels.size() != 1) {
+    return Status::NotImplemented("tabular function with non-single kernel");
+  }
+  const compute::ScalarKernel* kernel = kernels[0];
+  auto out_type = kernel->signature->out_type();
+  if (out_type.kind() != compute::OutputType::FIXED) {
+    return Status::Invalid("tabular kernel of non-fixed kind");
+  }
+  auto datatype = out_type.type();
+  if (datatype->id() != Type::type::STRUCT) {
+    return Status::Invalid("tabular kernel with non-struct output");
+  }
+  auto struct_type = 
arrow::internal::checked_cast<StructType*>(datatype.get());
+  auto schema = ::arrow::schema(struct_type->fields());
+  std::vector<TypeHolder> in_types;
+  ARROW_ASSIGN_OR_RAISE(auto func_exec,
+                        GetFunctionExecutor(func_name, in_types, NULLPTR, 
registry));
+  auto next_func =
+      [schema,
+       func_exec = std::move(func_exec)]() -> 
Result<std::shared_ptr<RecordBatch>> {
+    std::vector<Datum> args;
+    // passed_length of -1 or 0 with args.size() of 0 leads to an empty 
ExecSpanIterator
+    // in exec.cc and to never invoking the source function, so 1 is passed 
instead
+    ARROW_ASSIGN_OR_RAISE(auto datum, func_exec->Execute(args, 
/*passed_length=*/1));

Review Comment:
   We should then document that the `batch_size` on the `ctx` is meaningless 
for UDTFs, since that is what this implies, right?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to