westonpace commented on code in PR #13397:
URL: https://github.com/apache/arrow/pull/13397#discussion_r918175160
##########
r/R/compute.R:
##########
@@ -307,3 +307,157 @@ cast_options <- function(safe = TRUE, ...) {
)
modifyList(opts, list(...))
}
+
+#' Register user-defined functions
+#'
+#' These functions support calling R code from query engine execution
+#' (i.e., a [dplyr::mutate()] or [dplyr::filter()] on a [Table] or [Dataset]).
+#' Use [arrow_scalar_function()] to define an R function that accepts and
+#' returns R objects; use [arrow_advanced_scalar_function()] to define a
+#' lower-level function that operates directly on Arrow objects.
Review Comment:
Do you want to document the restrictions on scalar functions, either here or
somewhere else? In particular:
* They should be stateless.
* The size of the output array must match the size of the input array (one
output per row).
##########
r/R/compute.R:
##########
@@ -307,3 +307,157 @@ cast_options <- function(safe = TRUE, ...) {
)
modifyList(opts, list(...))
}
+
+#' Register user-defined functions
+#'
+#' These functions support calling R code from query engine execution
+#' (i.e., a [dplyr::mutate()] or [dplyr::filter()] on a [Table] or [Dataset]).
+#' Use [arrow_scalar_function()] to define an R function that accepts and
+#' returns R objects; use [arrow_advanced_scalar_function()] to define a
+#' lower-level function that operates directly on Arrow objects.
+#'
+#' @param name The function name to be used in the dplyr bindings
+#' @param scalar_function An object created with [arrow_scalar_function()]
+#' or [arrow_advanced_scalar_function()].
+#' @param in_type A [DataType] of the input type or a [schema()]
+#' for functions with more than one argument. This signature will be used
+#' to determine if this function is appropriate for a given set of arguments.
+#' If this function is appropriate for more than one signature, pass a
+#' `list()` of the above.
+#' @param out_type A [DataType] of the output type or a function accepting
+#' a single argument (`types`), which is a `list()` of [DataType]s. If a
+#' function it must return a [DataType].
+#' @param fun An R function or rlang-style lambda expression. This function
+#' will be called with R objects as arguments and must return an object
+#' that can be converted to an [Array] using [as_arrow_array()]. Function
+#' authors must take care to return an array castable to the output data
+#' type specified by `out_type`.
+#' @param advanced_fun An R function or rlang-style lambda expression. This
+#' function will be called with exactly two arguments: `kernel_context`,
+#' which is a `list()` of objects giving information about the
+#' execution context and `args`, which is a list of [Array] or [Scalar]
+#' objects corresponding to the input arguments.
Review Comment:
What should this function return?
##########
r/src/compute.cpp:
##########
@@ -574,3 +576,169 @@ SEXP compute__CallFunction(std::string func_name,
cpp11::list args, cpp11::list
std::vector<std::string> compute__GetFunctionNames() {
return arrow::compute::GetFunctionRegistry()->GetFunctionNames();
}
+
+class RScalarUDFKernelState : public arrow::compute::KernelState {
+ public:
+ RScalarUDFKernelState(cpp11::sexp exec_func, cpp11::sexp resolver)
+ : exec_func_(exec_func), resolver_(resolver) {}
+
+ cpp11::function exec_func_;
+ cpp11::function resolver_;
+};
+
+arrow::Result<arrow::TypeHolder> ResolveScalarUDFOutputType(
+ arrow::compute::KernelContext* context,
+ const std::vector<arrow::TypeHolder>& input_types) {
+ return SafeCallIntoR<arrow::TypeHolder>(
+ [&]() -> arrow::TypeHolder {
+ auto kernel =
+ reinterpret_cast<const
arrow::compute::ScalarKernel*>(context->kernel());
+ auto state =
std::dynamic_pointer_cast<RScalarUDFKernelState>(kernel->data);
+
+ cpp11::writable::list input_types_sexp(input_types.size());
+ for (size_t i = 0; i < input_types.size(); i++) {
+ input_types_sexp[i] =
+ cpp11::to_r6<arrow::DataType>(input_types[i].GetSharedPtr());
+ }
+
+ cpp11::sexp output_type_sexp = state->resolver_(input_types_sexp);
+ if (!Rf_inherits(output_type_sexp, "DataType")) {
+ cpp11::stop("arrow_scalar_function resolver must return a DataType");
Review Comment:
In what situation would a user get this error? Would they understand it?
Or is this an internal error?
##########
r/src/compute.cpp:
##########
@@ -574,3 +576,169 @@ SEXP compute__CallFunction(std::string func_name,
cpp11::list args, cpp11::list
std::vector<std::string> compute__GetFunctionNames() {
return arrow::compute::GetFunctionRegistry()->GetFunctionNames();
}
+
+class RScalarUDFKernelState : public arrow::compute::KernelState {
+ public:
+ RScalarUDFKernelState(cpp11::sexp exec_func, cpp11::sexp resolver)
+ : exec_func_(exec_func), resolver_(resolver) {}
+
+ cpp11::function exec_func_;
+ cpp11::function resolver_;
+};
+
+arrow::Result<arrow::TypeHolder> ResolveScalarUDFOutputType(
+ arrow::compute::KernelContext* context,
+ const std::vector<arrow::TypeHolder>& input_types) {
+ return SafeCallIntoR<arrow::TypeHolder>(
+ [&]() -> arrow::TypeHolder {
+ auto kernel =
+ reinterpret_cast<const
arrow::compute::ScalarKernel*>(context->kernel());
+ auto state =
std::dynamic_pointer_cast<RScalarUDFKernelState>(kernel->data);
+
+ cpp11::writable::list input_types_sexp(input_types.size());
+ for (size_t i = 0; i < input_types.size(); i++) {
+ input_types_sexp[i] =
+ cpp11::to_r6<arrow::DataType>(input_types[i].GetSharedPtr());
+ }
+
+ cpp11::sexp output_type_sexp = state->resolver_(input_types_sexp);
+ if (!Rf_inherits(output_type_sexp, "DataType")) {
+ cpp11::stop("arrow_scalar_function resolver must return a DataType");
+ }
+
+ return arrow::TypeHolder(
+ cpp11::as_cpp<std::shared_ptr<arrow::DataType>>(output_type_sexp));
+ },
+ "resolve scalar user-defined function output data type");
+}
+
+arrow::Status CallRScalarUDF(arrow::compute::KernelContext* context,
+ const arrow::compute::ExecSpan& span,
+ arrow::compute::ExecResult* result) {
+ if (result->is_array_span()) {
+ return arrow::Status::NotImplemented("ArraySpan result from R scalar UDF");
+ }
+
+ return SafeCallIntoRVoid(
+ [&]() {
+ auto kernel =
+ reinterpret_cast<const
arrow::compute::ScalarKernel*>(context->kernel());
+ auto state =
std::dynamic_pointer_cast<RScalarUDFKernelState>(kernel->data);
+
+ cpp11::writable::list args_sexp(span.num_values());
+
+ for (int i = 0; i < span.num_values(); i++) {
+ const arrow::compute::ExecValue& exec_val = span[i];
+ if (exec_val.is_array()) {
+ args_sexp[i] =
cpp11::to_r6<arrow::Array>(exec_val.array.ToArray());
+ } else if (exec_val.is_scalar()) {
+ args_sexp[i] =
cpp11::to_r6<arrow::Scalar>(exec_val.scalar->GetSharedPtr());
+ }
+ }
+
+ cpp11::sexp batch_length_sexp = cpp11::as_sexp(span.length);
+
+ std::shared_ptr<arrow::DataType> output_type =
result->type()->GetSharedPtr();
+ cpp11::sexp output_type_sexp =
cpp11::to_r6<arrow::DataType>(output_type);
+ cpp11::writable::list udf_context = {batch_length_sexp,
output_type_sexp};
+ udf_context.names() = {"batch_length", "output_type"};
+
+ cpp11::sexp func_result_sexp = state->exec_func_(udf_context,
args_sexp);
+
+ if (Rf_inherits(func_result_sexp, "Array")) {
+ auto array =
cpp11::as_cpp<std::shared_ptr<arrow::Array>>(func_result_sexp);
+
+ // handle an Array result of the wrong type
+ if (!result->type()->Equals(array->type())) {
+ arrow::Datum out = ValueOrStop(arrow::compute::Cast(array,
result->type()));
Review Comment:
Do we really want to support graceful casting here? Maybe R is more
generous, but in C++ I think I'd want to give an error to the user and force
them to decide whether the cast is worth it.
##########
r/src/compute-exec.cpp:
##########
@@ -109,9 +109,50 @@ std::shared_ptr<arrow::RecordBatchReader> ExecPlan_run(
auto kv = strings_to_kvm(metadata);
out_schema = out_schema->WithMetadata(kv);
}
- return compute::MakeGeneratorReader(
+
+ std::pair<std::shared_ptr<compute::ExecPlan>,
std::shared_ptr<arrow::RecordBatchReader>>
+ out;
+ out.first = plan;
+ out.second = compute::MakeGeneratorReader(
out_schema, [stop_producing, plan, sink_gen] { return sink_gen(); },
gc_memory_pool());
+ return out;
+}
+
+// [[arrow::export]]
+std::shared_ptr<arrow::RecordBatchReader> ExecPlan_run(
Review Comment:
Streaming execution with R UDFs should someday be possible.
##########
r/src/compute.cpp:
##########
@@ -574,3 +576,169 @@ SEXP compute__CallFunction(std::string func_name,
cpp11::list args, cpp11::list
std::vector<std::string> compute__GetFunctionNames() {
return arrow::compute::GetFunctionRegistry()->GetFunctionNames();
}
+
+class RScalarUDFKernelState : public arrow::compute::KernelState {
+ public:
+ RScalarUDFKernelState(cpp11::sexp exec_func, cpp11::sexp resolver)
+ : exec_func_(exec_func), resolver_(resolver) {}
+
+ cpp11::function exec_func_;
+ cpp11::function resolver_;
+};
+
+arrow::Result<arrow::TypeHolder> ResolveScalarUDFOutputType(
+ arrow::compute::KernelContext* context,
+ const std::vector<arrow::TypeHolder>& input_types) {
+ return SafeCallIntoR<arrow::TypeHolder>(
+ [&]() -> arrow::TypeHolder {
+ auto kernel =
+ reinterpret_cast<const
arrow::compute::ScalarKernel*>(context->kernel());
+ auto state =
std::dynamic_pointer_cast<RScalarUDFKernelState>(kernel->data);
+
+ cpp11::writable::list input_types_sexp(input_types.size());
+ for (size_t i = 0; i < input_types.size(); i++) {
+ input_types_sexp[i] =
+ cpp11::to_r6<arrow::DataType>(input_types[i].GetSharedPtr());
+ }
+
+ cpp11::sexp output_type_sexp = state->resolver_(input_types_sexp);
+ if (!Rf_inherits(output_type_sexp, "DataType")) {
+ cpp11::stop("arrow_scalar_function resolver must return a DataType");
+ }
+
+ return arrow::TypeHolder(
+ cpp11::as_cpp<std::shared_ptr<arrow::DataType>>(output_type_sexp));
+ },
+ "resolve scalar user-defined function output data type");
+}
+
+arrow::Status CallRScalarUDF(arrow::compute::KernelContext* context,
+ const arrow::compute::ExecSpan& span,
+ arrow::compute::ExecResult* result) {
+ if (result->is_array_span()) {
+ return arrow::Status::NotImplemented("ArraySpan result from R scalar UDF");
+ }
+
+ return SafeCallIntoRVoid(
+ [&]() {
+ auto kernel =
+ reinterpret_cast<const
arrow::compute::ScalarKernel*>(context->kernel());
+ auto state =
std::dynamic_pointer_cast<RScalarUDFKernelState>(kernel->data);
+
+ cpp11::writable::list args_sexp(span.num_values());
+
+ for (int i = 0; i < span.num_values(); i++) {
+ const arrow::compute::ExecValue& exec_val = span[i];
+ if (exec_val.is_array()) {
+ args_sexp[i] =
cpp11::to_r6<arrow::Array>(exec_val.array.ToArray());
+ } else if (exec_val.is_scalar()) {
+ args_sexp[i] =
cpp11::to_r6<arrow::Scalar>(exec_val.scalar->GetSharedPtr());
+ }
+ }
+
+ cpp11::sexp batch_length_sexp = cpp11::as_sexp(span.length);
+
+ std::shared_ptr<arrow::DataType> output_type =
result->type()->GetSharedPtr();
+ cpp11::sexp output_type_sexp =
cpp11::to_r6<arrow::DataType>(output_type);
+ cpp11::writable::list udf_context = {batch_length_sexp,
output_type_sexp};
+ udf_context.names() = {"batch_length", "output_type"};
+
+ cpp11::sexp func_result_sexp = state->exec_func_(udf_context,
args_sexp);
+
+ if (Rf_inherits(func_result_sexp, "Array")) {
+ auto array =
cpp11::as_cpp<std::shared_ptr<arrow::Array>>(func_result_sexp);
+
+ // handle an Array result of the wrong type
+ if (!result->type()->Equals(array->type())) {
+ arrow::Datum out = ValueOrStop(arrow::compute::Cast(array,
result->type()));
+ std::shared_ptr<arrow::Array> out_array = out.make_array();
+ array.swap(out_array);
+ }
+
+ result->value = std::move(array->data());
+ } else if (Rf_inherits(func_result_sexp, "Scalar")) {
+ auto scalar =
cpp11::as_cpp<std::shared_ptr<arrow::Scalar>>(func_result_sexp);
+
+ // handle a Scalar result of the wrong type
+ if (!result->type()->Equals(scalar->type)) {
+ arrow::Datum out = ValueOrStop(arrow::compute::Cast(scalar,
result->type()));
+ std::shared_ptr<arrow::Scalar> out_scalar = out.scalar();
+ scalar.swap(out_scalar);
+ }
+
+ auto array = ValueOrStop(
+ arrow::MakeArrayFromScalar(*scalar, span.length,
context->memory_pool()));
+ result->value = std::move(array->data());
+ } else {
+ cpp11::stop("arrow_scalar_function must return an Array or Scalar");
+ }
+ },
+ "execute scalar user-defined function");
+}
+
+// [[arrow::export]]
+void RegisterScalarUDF(std::string name, cpp11::sexp func_sexp) {
+ cpp11::list in_type_r(func_sexp.attr("in_type"));
+ cpp11::list out_type_r(func_sexp.attr("out_type"));
+ R_xlen_t n_kernels = in_type_r.size();
+
+ if (n_kernels == 0) {
+ cpp11::stop("Can't register user-defined function with zero kernels");
+ }
+
+ // compute the Arity from the list of input kernels
+ std::vector<int64_t> n_args(n_kernels);
+ for (R_xlen_t i = 0; i < n_kernels; i++) {
+ auto in_types =
cpp11::as_cpp<std::shared_ptr<arrow::Schema>>(in_type_r[i]);
+ n_args[i] = in_types->num_fields();
+ }
+
+ const int64_t min_args = *std::min_element(n_args.begin(), n_args.end());
+ const int64_t max_args = *std::max_element(n_args.begin(), n_args.end());
+
+ // We can't currently handle variable numbers of arguments in a user-defined
+ // function and we don't have a mechanism for the user to specify a variable
+ // number of arguments at the end of a signature.
+ if (min_args != max_args) {
+ cpp11::stop(
+ "User-defined function with a variable number of arguments is not
supported");
+ }
Review Comment:
When I think of "varargs" or a "variable number of arguments", I picture a
single kernel that can take any number of arguments. However, this situation
is more about multiple kernels that don't have the same number of arguments,
right? Are we ever going to support something like that?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]