vibhatha commented on code in PR #13397:
URL: https://github.com/apache/arrow/pull/13397#discussion_r901703174
##########
r/src/compute.cpp:
##########
@@ -574,3 +576,90 @@ SEXP compute__CallFunction(std::string func_name,
cpp11::list args, cpp11::list
std::vector<std::string> compute__GetFunctionNames() {
return arrow::compute::GetFunctionRegistry()->GetFunctionNames();
}
+
+class RScalarUDFCallable : public arrow::compute::ArrayKernelExec {
+ public:
+ RScalarUDFCallable(const std::shared_ptr<arrow::Schema>& input_types,
+ const std::shared_ptr<arrow::DataType>& output_type,
cpp11::sexp fun)
+ : input_types_(input_types), output_type_(output_type), fun_(fun) {}
+
+ arrow::Status operator()(arrow::compute::KernelContext* context,
+ const arrow::compute::ExecSpan& span,
+ arrow::compute::ExecResult* result) {
+ std::vector<std::shared_ptr<arrow::Array>> array_args;
+ for (int64_t i = 0; i < span.num_values(); i++) {
+ const arrow::compute::ExecValue& v = span[i];
+ if (v.is_array()) {
+ array_args.push_back(v.array.ToArray());
+ } else if (v.is_scalar()) {
+ auto array = ValueOrStop(arrow::MakeArrayFromScalar(*v.scalar,
span.length));
+ array_args.push_back(array);
+ }
+ }
+
+ auto batch = arrow::RecordBatch::Make(input_types_, span.length,
array_args);
+
+ auto fun_result = SafeCallIntoR<std::shared_ptr<arrow::Array>>([&]() {
+ cpp11::sexp batch_sexp = cpp11::to_r6<arrow::RecordBatch>(batch);
+ cpp11::sexp batch_length_sexp = cpp11::as_sexp(span.length);
+
+ cpp11::writable::list udf_context = {batch_length_sexp};
+ udf_context.names() = {"batch_length"};
+
+ cpp11::sexp fun_result_sexp = fun_(udf_context, batch_sexp);
+ if (!Rf_inherits(fun_result_sexp, "Array")) {
+ cpp11::stop("arrow_scalar_function must return an Array");
+ }
+
+ return cpp11::as_cpp<std::shared_ptr<arrow::Array>>(fun_result_sexp);
+ });
+
+ if (!fun_result.ok()) {
+ return fun_result.status();
+ }
+
+ result->value.emplace<std::shared_ptr<arrow::ArrayData>>(
+ fun_result.ValueUnsafe()->data());
+ return arrow::Status::OK();
+ }
+
+ private:
+ std::shared_ptr<arrow::Schema> input_types_;
+ std::shared_ptr<arrow::DataType> output_type_;
+ cpp11::function fun_;
+};
+
+// [[arrow::export]]
+void RegisterScalarUDF(std::string name, cpp11::sexp fun) {
+ const arrow::compute::FunctionDoc dummy_function_doc{
+ "A user-defined R function", "returns something", {"..."}};
+
+ auto func = std::make_shared<arrow::compute::ScalarFunction>(
+ name, arrow::compute::Arity::VarArgs(), dummy_function_doc);
Review Comment:
@paleolimbot I am not quite sure about extracting the function doc in R either.
I was merely referring to a string explicitly passed as the function doc, if
required. In Python, I was looking into the `inspect` API, which can help with
this. For Python UDFs, I did a quick check of whether we can use it to extract
all the values required to register a UDF, and it seems possible. I was just
curious how such things are planned to be handled in the R API.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]