paleolimbot commented on code in PR #13397:
URL: https://github.com/apache/arrow/pull/13397#discussion_r906623430


##########
r/R/compute.R:
##########
@@ -307,3 +307,158 @@ cast_options <- function(safe = TRUE, ...) {
   )
   modifyList(opts, list(...))
 }
+
+#' Register user-defined functions
+#'
+#' These functions support calling R code from query engine execution
+#' (i.e., a [dplyr::mutate()] or [dplyr::filter()] on a [Table] or [Dataset]).
+#' Use [arrow_scalar_function()] to define an R function that accepts and
+#' returns R objects; use [arrow_base_scalar_function()] to define a
+#' lower-level function that operates directly on Arrow objects.
+#'
+#' @param name The function name to be used in the dplyr bindings
+#' @param scalar_function An object created with [arrow_scalar_function()]
+#'   or [arrow_base_scalar_function()].
+#' @param registry_name The function name to be used in the Arrow C++
+#'   compute function registry. This may be different from `name`.
+#' @param in_type A [DataType] of the input type or a [schema()]
+#'   for functions with more than one argument. This signature will be used
+#'   to determine if this function is appropriate for a given set of arguments.
+#'   If this function is appropriate for more than one signature, pass a
+#'   `list()` of the above.
+#' @param out_type A [DataType] of the output type or a function accepting
+#'   a single argument (`types`), which is a `list()` of [DataType]s. If a
+#'   function it must return a [DataType].
+#' @param fun An R function or rlang-style lambda expression. This function
+#'   will be called with R objects as arguments and must return an object
+#'   that can be converted to an [Array] using [as_arrow_array()]. Function
+#'   authors must take care to return an array castable to the output data
+#'   type specified by `out_type`.
+#' @param base_fun An R function or rlang-style lambda expression. This
+#'   function will be called with exactly two arguments: `kernel_context`,
+#'   which is a `list()` of objects giving information about the
+#'   execution context and `args`, which is a list of [Array] or [Scalar]
+#'   objects corresponding to the input arguments.
+#'
+#' @return
+#'   - `register_scalar_function()`: `NULL`, invisibly
+#'   - `arrow_scalar_function()`: returns an object of class
+#'     "arrow_base_scalar_function" that can be passed to
+#'     `register_scalar_function()`.
+#' @export
+#'
+#' @examplesIf .Machine$sizeof.pointer >= 8
+#' fun_wrapper <- arrow_scalar_function(
+#'   schema(x = float64(), y = float64(), z = float64()),
+#'   float64(),
+#'   function(x, y, z) x + y + z
+#' )
+#' register_scalar_function("example_add3", fun_wrapper)
+#'
+#' call_function(
+#'   "example_add3",
+#'   Scalar$create(1),
+#'   Scalar$create(2),
+#'   Array$create(3)
+#' )
+#'
+#' # use arrow_base_scalar_function() for a lower-level interface
+#' base_fun_wrapper <- arrow_base_scalar_function(
+#'   schema(x = float64(), y = float64(), z = float64()),
+#'   float64(),
+#'   function(context, args) {
+#'     args[[1]] + args[[2]] + args[[3]]
+#'   }
+#' )
+#' register_scalar_function("example_add3", base_fun_wrapper)
+#'
+#' call_function(
+#'   "example_add3",
+#'   Scalar$create(1),
+#'   Scalar$create(2),
+#'   Array$create(3)
+#' )
+#'
+register_scalar_function <- function(name, scalar_function, registry_name = name) {
+  assert_that(
+    is.string(name),
+    is.string(registry_name),
+    inherits(scalar_function, "arrow_base_scalar_function")
+  )
+
+  # register with Arrow C++
+  RegisterScalarUDF(registry_name, scalar_function)
+
+  # register with dplyr bindings
+  register_binding(
+    name,
+    function(...) build_expr(registry_name, ...)
+  )
+
+  # recreate dplyr binding cache
+  create_binding_cache()
+
+  invisible(NULL)
+}
+
+#' @rdname register_scalar_function
+#' @export
+arrow_scalar_function <- function(in_type, out_type, fun) {
+  fun <- rlang::as_function(fun)
+  base_fun <- function(context, args) {
+    args <- lapply(args, as.vector)
+    result <- do.call(fun, args)
+    as_arrow_array(result, type = context$output_type)
+  }
+
+  arrow_base_scalar_function(in_type, out_type, base_fun)
+}
+
+#' @rdname register_scalar_function
+#' @export
+arrow_base_scalar_function <- function(in_type, out_type, base_fun) {

Review Comment:
   There's almost certainly a better name... the 'base' version only deals with 
Arrow objects and gives some additional information from the execution context; 
the non-'base' version converts to R vectors and back automatically and is 
probably what most users want. For the geoarrow use case, I specifically want 
to avoid the conversion to R vectors and the `do.call()`, both of which add 
overhead (which I should probably measure before adding the confusion of two 
scalar function constructors...).
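
A rough way to measure that overhead without registering anything is to time the two code paths directly. This is only a sketch under assumptions: `via_r_vectors()` and `via_arrow()` are hypothetical helpers that mimic what the wrapper in `arrow_scalar_function()` and a hand-written base function would do for the `x + y + z` example above; the array size and repetition count are arbitrary, and the output type is hard-coded to `float64()` because no kernel context is available outside the registry.

```r
library(arrow)

set.seed(1)
n <- 1e5

# Three double arrays standing in for the `args` list a kernel would receive
args <- list(
  Array$create(runif(n)),
  Array$create(runif(n)),
  Array$create(runif(n))
)

# Hypothetical helper: the path taken by the arrow_scalar_function() wrapper
# (Arrow -> R vectors -> do.call() -> Arrow)
via_r_vectors <- function(args, fun) {
  r_args <- lapply(args, as.vector)
  result <- do.call(fun, r_args)
  as_arrow_array(result, type = float64())
}

# Hypothetical helper: the path a base function would take,
# staying in Arrow arrays the whole time
via_arrow <- function(args) {
  args[[1]] + args[[2]] + args[[3]]
}

add3 <- function(x, y, z) x + y + z

# Time a handful of repetitions of each path
time_r_vectors <- system.time(for (i in 1:10) via_r_vectors(args, add3))
time_arrow <- system.time(for (i in 1:10) via_arrow(args))

print(time_r_vectors)
print(time_arrow)
```

The timings will vary by machine and data type, and they leave out whatever cost the C++ to R callback itself adds, so this would at best indicate whether the conversion is large enough to justify two constructors.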



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
