paleolimbot commented on code in PR #13397:
URL: https://github.com/apache/arrow/pull/13397#discussion_r925955886
##########
r/R/compute.R:
##########
@@ -306,3 +306,145 @@ cast_options <- function(safe = TRUE, ...) {
   )
   modifyList(opts, list(...))
 }
+
+#' Register user-defined functions
+#'
+#' These functions support calling R code from query engine execution
+#' (i.e., a [dplyr::mutate()] or [dplyr::filter()] on a [Table] or [Dataset]).
+#' Use [register_scalar_function()] to attach Arrow input and output types to
+#' an R function and make it available for use in the dplyr interface and/or
+#' [call_function()]. Scalar functions are currently the only type of
+#' user-defined function supported. In Arrow, scalar functions must be
+#' stateless and return output with the same shape (i.e., the same number
+#' of rows) as the input.
+#'
+#' @param name The function name to be used in the dplyr bindings.
+#' @param in_type A [DataType] of the input type or a [schema()]
+#'   for functions with more than one argument. This signature will be used
+#'   to determine if this function is appropriate for a given set of arguments.
+#'   If this function is appropriate for more than one signature, pass a
+#'   `list()` of the above.
+#' @param out_type A [DataType] of the output type or a function accepting
+#'   a single argument (`types`), which is a `list()` of [DataType]s. If a
+#'   function, it must return a [DataType].
+#' @param fun An R function or rlang-style lambda expression. The function
+#'   will be called with a first argument `context`, which is a `list()`
+#'   with elements `batch_size` (the expected length of the output) and
+#'   `output_type` (the required [DataType] of the output). Subsequent
+#'   arguments are passed by position as specified by `in_type`. If
+#'   `auto_convert` is `TRUE`, subsequent arguments are converted to
+#'   R vectors before being passed to `fun` and the output is automatically
+#'   constructed with the expected output type via [as_arrow_array()].
+#' @param auto_convert Use `TRUE` to convert inputs before passing to `fun`
+#'   and construct an Array of the correct type from the output. Use this
+#'   option to write functions of R objects as opposed to functions of
+#'   Arrow R6 objects.
+#'
+#' @return `NULL`, invisibly
+#' @export
+#'
+#' @examplesIf arrow_with_dataset()
+#' library(dplyr, warn.conflicts = FALSE)
+#'
+#' some_model <- lm(mpg ~ disp + cyl, data = mtcars)
+#' register_scalar_function(
+#'   "mtcars_predict_mpg",
+#'   function(context, disp, cyl) {
+#'     predict(some_model, newdata = data.frame(disp, cyl))
+#'   },
+#'   in_type = schema(disp = float64(), cyl = float64()),
+#'   out_type = float64(),
+#'   auto_convert = TRUE
+#' )
+#'
+#' as_arrow_table(mtcars) %>%
+#'   transmute(mpg, mpg_predicted = mtcars_predict_mpg(disp, cyl)) %>%
+#'   collect() %>%
+#'   head()
+#'
+register_scalar_function <- function(name, fun, in_type, out_type,
+                                     auto_convert = FALSE) {
+  assert_that(is.string(name))
+
+  scalar_function <- arrow_scalar_function(
+    fun,
+    in_type,
+    out_type,
+    auto_convert = auto_convert
+  )
+
+  # register with the Arrow C++ function registry (enables its use in
+  # call_function() and Expression$create())
+  RegisterScalarUDF(name, scalar_function)
+
+  # register with the dplyr bindings (enables its use in mutate(), filter(), etc.)
+  register_binding(
+    name,
+    function(...) build_expr(name, ...),
+    update_cache = TRUE
+  )
+
+  invisible(NULL)
+}
+
+arrow_scalar_function <- function(fun, in_type, out_type, auto_convert = FALSE) {
+  fun <- as_function(fun)
+
+  # Create a small wrapper function that is easier to call from C++.
+  # This wrapper could be implemented in C/C++ to reduce evaluation
+  # overhead and generate prettier backtraces when errors occur
+  # (probably using a similar approach to purrr).
+  if (auto_convert) {
+    wrapper_fun <- function(context, args) {
+      args <- lapply(args, as.vector)
+      result <- do.call(fun, c(list(context), args))
+      as_arrow_array(result, type = context$output_type)
+    }
+  } else {
+    wrapper_fun <- function(context, args) {
+      do.call(fun, c(list(context), args))
+    }
+  }
+
+  if (is.list(in_type)) {
+    in_type <- lapply(in_type, as_scalar_function_in_type)
+  } else {
+    in_type <- list(as_scalar_function_in_type(in_type))
+  }
+
+  if (is.list(out_type)) {
+    out_type <- lapply(out_type, as_scalar_function_out_type)
+  } else {
+    out_type <- list(as_scalar_function_out_type(out_type))
+  }
+
+  out_type <- rep_len(out_type, length(in_type))
+
+  structure(
+    list(
+      wrapper_fun = wrapper_fun,
+      in_type = in_type,
+      out_type = out_type
+    ),
+    class = "arrow_scalar_function"
+  )
+}
+
+as_scalar_function_in_type <- function(x) {

Review Comment:
   It's a good point...I hesitate to add `as_schema.(DataType|Field)()` because I don't know that there's anywhere else that a `DataType` *should* be interpreted as a `Schema`. `as_schema()` might get used to sanitize arguments, in which case I would expect an error for something that can't be interpreted in this way (as opposed to, say, a Substrait schema, which should be coerced and used).
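As a sketch of the coercion being discussed (hypothetical names and behaviour, assuming only the existing `schema()`, `field()`, and `as_schema()` functions from the arrow package; not necessarily what the PR implements), `as_scalar_function_in_type()` could wrap a bare `DataType` or `Field` into a one-field schema locally rather than relying on an `as_schema.DataType()` method:

as_scalar_function_in_type <- function(x) {
  if (inherits(x, "DataType")) {
    # hypothetical: treat a bare type as a single-argument signature
    schema(field("x", x))
  } else if (inherits(x, "Field")) {
    # hypothetical: a single Field becomes a one-field schema
    schema(x)
  } else {
    # anything else must already be interpretable as a Schema
    as_schema(x)
  }
}

That keeps the DataType-to-Schema interpretation local to the UDF machinery, so a general `as_schema()` used for argument sanitization can still error on inputs that should not be treated as a schema.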
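Separately, the documentation above says registration also makes the function available to `call_function()`. A short usage sketch of that claim (assuming the `"mtcars_predict_mpg"` registration from the roxygen example has already been run):

# assumes the "mtcars_predict_mpg" UDF from the example above is registered;
# call_function() looks the name up in the C++ function registry
call_function(
  "mtcars_predict_mpg",
  Array$create(mtcars$disp),
  Array$create(mtcars$cyl)
)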