This is an automated email from the ASF dual-hosted git repository.
paleolimbot pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new 010b5921bd ARROW-16444: [R] Implement user-defined scalar functions in
R bindings (#13397)
010b5921bd is described below
commit 010b5921bd33e9849820f981647561f19936b899
Author: Dewey Dunnington <[email protected]>
AuthorDate: Fri Jul 22 09:24:00 2022 -0300
ARROW-16444: [R] Implement user-defined scalar functions in R bindings
(#13397)
Authored-by: Dewey Dunnington <[email protected]>
Signed-off-by: Dewey Dunnington <[email protected]>
---
r/NAMESPACE | 3 +
r/R/arrowExports.R | 20 +-
r/R/compute.R | 176 ++++++++++++++++++
r/R/dplyr-collect.R | 2 +-
r/R/dplyr-funcs.R | 47 ++++-
r/R/feather.R | 4 +-
r/R/query-engine.R | 28 ++-
r/R/table.R | 15 ++
r/_pkgdown.yml | 1 +
r/man/as_arrow_table.Rd | 6 +
r/man/register_binding.Rd | 10 +-
r/man/register_scalar_function.Rd | 70 +++++++
r/src/arrowExports.cpp | 50 ++++-
r/src/compute-exec.cpp | 54 +++++-
r/src/compute.cpp | 167 +++++++++++++++++
r/src/csv.cpp | 11 +-
r/src/extension-impl.cpp | 25 +--
r/src/feather.cpp | 39 +---
r/src/io.cpp | 68 ++++---
r/src/safe-call-into-r-impl.cpp | 15 ++
r/src/safe-call-into-r.h | 131 +++++++++----
r/tests/testthat/_snaps/compute.md | 4 +
r/tests/testthat/test-compute.R | 305 +++++++++++++++++++++++++++++++
r/tests/testthat/test-csv.R | 4 +-
r/tests/testthat/test-dplyr-funcs.R | 7 +-
r/tests/testthat/test-extension.R | 1 +
r/tests/testthat/test-feather.R | 6 +-
r/tests/testthat/test-safe-call-into-r.R | 8 +-
28 files changed, 1103 insertions(+), 174 deletions(-)
diff --git a/r/NAMESPACE b/r/NAMESPACE
index 750a815f9f..0a120dc97a 100644
--- a/r/NAMESPACE
+++ b/r/NAMESPACE
@@ -45,7 +45,9 @@ S3method(as_arrow_array,data.frame)
S3method(as_arrow_array,default)
S3method(as_arrow_array,pyarrow.lib.Array)
S3method(as_arrow_table,RecordBatch)
+S3method(as_arrow_table,RecordBatchReader)
S3method(as_arrow_table,Table)
+S3method(as_arrow_table,arrow_dplyr_query)
S3method(as_arrow_table,data.frame)
S3method(as_arrow_table,default)
S3method(as_arrow_table,pyarrow.lib.RecordBatch)
@@ -344,6 +346,7 @@ export(read_schema)
export(read_tsv_arrow)
export(record_batch)
export(register_extension_type)
+export(register_scalar_function)
export(reregister_extension_type)
export(s3_bucket)
export(schema)
diff --git a/r/R/arrowExports.R b/r/R/arrowExports.R
index 84f6ee54fc..dfe0db614a 100644
--- a/r/R/arrowExports.R
+++ b/r/R/arrowExports.R
@@ -408,6 +408,10 @@ ExecPlan_run <- function(plan, final_node, sort_options,
metadata, head) {
.Call(`_arrow_ExecPlan_run`, plan, final_node, sort_options, metadata, head)
}
+ExecPlan_read_table <- function(plan, final_node, sort_options, metadata,
head) {
+ .Call(`_arrow_ExecPlan_read_table`, plan, final_node, sort_options,
metadata, head)
+}
+
ExecPlan_StopProducing <- function(plan) {
invisible(.Call(`_arrow_ExecPlan_StopProducing`, plan))
}
@@ -480,6 +484,10 @@ compute__GetFunctionNames <- function() {
.Call(`_arrow_compute__GetFunctionNames`)
}
+RegisterScalarUDF <- function(name, func_sexp) {
+ invisible(.Call(`_arrow_RegisterScalarUDF`, name, func_sexp))
+}
+
build_info <- function() {
.Call(`_arrow_build_info`)
}
@@ -1108,12 +1116,12 @@ ipc___feather___Reader__version <- function(reader) {
.Call(`_arrow_ipc___feather___Reader__version`, reader)
}
-ipc___feather___Reader__Read <- function(reader, columns, on_old_windows) {
- .Call(`_arrow_ipc___feather___Reader__Read`, reader, columns, on_old_windows)
+ipc___feather___Reader__Read <- function(reader, columns) {
+ .Call(`_arrow_ipc___feather___Reader__Read`, reader, columns)
}
-ipc___feather___Reader__Open <- function(stream, on_old_windows) {
- .Call(`_arrow_ipc___feather___Reader__Open`, stream, on_old_windows)
+ipc___feather___Reader__Open <- function(stream) {
+ .Call(`_arrow_ipc___feather___Reader__Open`, stream)
}
ipc___feather___Reader__schema <- function(reader) {
@@ -1792,6 +1800,10 @@ InitializeMainRThread <- function() {
invisible(.Call(`_arrow_InitializeMainRThread`))
}
+CanRunWithCapturedR <- function() {
+ .Call(`_arrow_CanRunWithCapturedR`)
+}
+
TestSafeCallIntoR <- function(r_fun_that_returns_a_string, opt) {
.Call(`_arrow_TestSafeCallIntoR`, r_fun_that_returns_a_string, opt)
}
diff --git a/r/R/compute.R b/r/R/compute.R
index 1cd12f2e29..0985e73a5f 100644
--- a/r/R/compute.R
+++ b/r/R/compute.R
@@ -306,3 +306,179 @@ cast_options <- function(safe = TRUE, ...) {
)
modifyList(opts, list(...))
}
+
+#' Register user-defined functions
+#'
+#' These functions support calling R code from query engine execution
+#' (e.g., a [dplyr::mutate()] or [dplyr::filter()] on a [Table] or [Dataset]).
+#' Use [register_scalar_function()] to attach Arrow input and output types to
+#' an R function and make it available for use in the dplyr interface and/or
+#' [call_function()]. Scalar functions are currently the only type of
+#' user-defined function supported. In Arrow, scalar functions must be
+#' stateless and return output with the same shape (i.e., the same number
+#' of rows) as the input.
+#'
+#' @param name The function name to be used in the dplyr bindings
+#' @param in_type A [DataType] of the input type or a [schema()]
+#' for functions with more than one argument. This signature will be used
+#' to determine if this function is appropriate for a given set of arguments.
+#' If this function is appropriate for more than one signature, pass a
+#' `list()` of the above.
+#' @param out_type A [DataType] of the output type or a function accepting
+#' a single argument (`types`), which is a `list()` of [DataType]s. If a
+#' function it must return a [DataType].
+#' @param fun An R function or rlang-style lambda expression. The function
+#' will be called with a first argument `context` which is a `list()`
+#' with elements `batch_size` (the expected length of the output) and
+#' `output_type` (the required [DataType] of the output) that may be used
+#' to ensure that the output has the correct type and length. Subsequent
+#' arguments are passed by position as specified by `in_type`. If
+#' `auto_convert` is `TRUE`, subsequent arguments are converted to
+#' R vectors before being passed to `fun` and the output is automatically
+#' constructed with the expected output type via [as_arrow_array()].
+#' @param auto_convert Use `TRUE` to convert inputs before passing to `fun`
+#' and construct an Array of the correct type from the output. Use this
+#' option to write functions of R objects as opposed to functions of
+#' Arrow R6 objects.
+#'
+#' @return `NULL`, invisibly
+#' @export
+#'
+#' @examplesIf arrow_with_dataset()
+#' library(dplyr, warn.conflicts = FALSE)
+#'
+#' some_model <- lm(mpg ~ disp + cyl, data = mtcars)
+#' register_scalar_function(
+#' "mtcars_predict_mpg",
+#' function(context, disp, cyl) {
+#' predict(some_model, newdata = data.frame(disp, cyl))
+#' },
+#' in_type = schema(disp = float64(), cyl = float64()),
+#' out_type = float64(),
+#' auto_convert = TRUE
+#' )
+#'
+#' as_arrow_table(mtcars) %>%
+#' transmute(mpg, mpg_predicted = mtcars_predict_mpg(disp, cyl)) %>%
+#' collect() %>%
+#' head()
+#'
+register_scalar_function <- function(name, fun, in_type, out_type,
+ auto_convert = FALSE) {
+ assert_that(is.string(name))
+
+ scalar_function <- arrow_scalar_function(
+ fun,
+ in_type,
+ out_type,
+ auto_convert = auto_convert
+ )
+
+ # register with Arrow C++ function registry (enables its use in
+ # call_function() and Expression$create())
+ RegisterScalarUDF(name, scalar_function)
+
+ # register with dplyr binding (enables its use in mutate(), filter(), etc.)
+ register_binding(
+ name,
+ function(...) build_expr(name, ...),
+ update_cache = TRUE
+ )
+
+ invisible(NULL)
+}
+
+arrow_scalar_function <- function(fun, in_type, out_type, auto_convert =
FALSE) {
+ assert_that(is.function(fun))
+
+ # Create a small wrapper function that is easier to call from C++.
+ # TODO(ARROW-17148): This wrapper could be implemented in C/C++ to
+ # reduce evaluation overhead and generate prettier backtraces when
+ # errors occur (probably using a similar approach to purrr).
+ if (auto_convert) {
+ wrapper_fun <- function(context, args) {
+ args <- lapply(args, as.vector)
+ result <- do.call(fun, c(list(context), args))
+ as_arrow_array(result, type = context$output_type)
+ }
+ } else {
+ wrapper_fun <- function(context, args) {
+ do.call(fun, c(list(context), args))
+ }
+ }
+
+ # in_type can be a list() if registering multiple kernels at once
+ if (is.list(in_type)) {
+ in_type <- lapply(in_type, in_type_as_schema)
+ } else {
+ in_type <- list(in_type_as_schema(in_type))
+ }
+
+ # out_type can be a list() if registering multiple kernels at once
+ if (is.list(out_type)) {
+ out_type <- lapply(out_type, out_type_as_function)
+ } else {
+ out_type <- list(out_type_as_function(out_type))
+ }
+
+ # recycle out_type (which is frequently length 1 even if multiple kernels
+ # are being registered at once)
+ out_type <- rep_len(out_type, length(in_type))
+
+ # check n_kernels and number of args in fun
+ n_kernels <- length(in_type)
+ if (n_kernels == 0) {
+ abort("Can't register user-defined scalar function with 0 kernels")
+ }
+
+ expected_n_args <- in_type[[1]]$num_fields + 1L
+ fun_formals_have_dots <- any(names(formals(fun)) == "...")
+ if (!fun_formals_have_dots && length(formals(fun)) != expected_n_args) {
+ abort(
+ sprintf(
+ paste0(
+ "Expected `fun` to accept %d argument(s)\n",
+ "but found a function that acccepts %d argument(s)\n",
+ "Did you forget to include `context` as the first argument?"
+ ),
+ expected_n_args,
+ length(formals(fun))
+ )
+ )
+ }
+
+ structure(
+ list(
+ wrapper_fun = wrapper_fun,
+ in_type = in_type,
+ out_type = out_type
+ ),
+ class = "arrow_scalar_function"
+ )
+}
+
+# This function sanitizes the in_type argument for arrow_scalar_function(),
+# which can be a data type (e.g., int32()), a field for a unary function
+# or a schema() for functions accepting more than one argument. C++ expects
+# a schema().
+in_type_as_schema <- function(x) {
+ if (inherits(x, "Field")) {
+ schema(x)
+ } else if (inherits(x, "DataType")) {
+ schema(field("", x))
+ } else {
+ as_schema(x)
+ }
+}
+
+# This function sanitizes the out_type argument for arrow_scalar_function(),
+# which can be a data type (e.g., int32()) or a function of the input types.
+# C++ currently expects a function.
+out_type_as_function <- function(x) {
+ if (is.function(x)) {
+ x
+ } else {
+ x <- as_data_type(x)
+ function(types) x
+ }
+}
diff --git a/r/R/dplyr-collect.R b/r/R/dplyr-collect.R
index 7f10ed307e..3e83475a8c 100644
--- a/r/R/dplyr-collect.R
+++ b/r/R/dplyr-collect.R
@@ -20,7 +20,7 @@
collect.arrow_dplyr_query <- function(x, as_data_frame = TRUE, ...) {
tryCatch(
- out <- as_record_batch_reader(x)$read_table(),
+ out <- as_arrow_table(x),
# n = 4 because we want the error to show up as being from collect()
# and not handle_csv_read_error()
error = function(e, call = caller_env(n = 4)) {
diff --git a/r/R/dplyr-funcs.R b/r/R/dplyr-funcs.R
index 7c4ed99e2e..c1dcdd1774 100644
--- a/r/R/dplyr-funcs.R
+++ b/r/R/dplyr-funcs.R
@@ -50,6 +50,13 @@ NULL
#' - `fun`: string function name
#' - `data`: `Expression` (these are all currently a single field)
#' - `options`: list of function options, as passed to call_function
+#' @param update_cache Update .cache$functions at the time of registration.
+#' The default is FALSE because the majority of usage is to register
+#' bindings at package load, after which we create the cache once. The
+#' reason why .cache$functions is needed in addition to nse_funcs for
+#' non-aggregate functions could be revisited; it is currently used
+#' as the data mask in mutate, filter, and aggregate (but not
+#' summarise) because the data mask has to be a list.
#' @param registry An environment in which the functions should be
#' assigned.
#'
@@ -57,13 +64,14 @@ NULL
#' registered function existed.
#' @keywords internal
#'
-register_binding <- function(fun_name, fun, registry = nse_funcs) {
+register_binding <- function(fun_name, fun, registry = nse_funcs,
+ update_cache = FALSE) {
unqualified_name <- sub("^.*?:{+}", "", fun_name)
previous_fun <- registry[[unqualified_name]]
# if the unqualified name exists in the registry, warn
- if (!is.null(fun) && !is.null(previous_fun)) {
+ if (!is.null(previous_fun)) {
warn(
paste0(
"A \"",
@@ -73,11 +81,36 @@ register_binding <- function(fun_name, fun, registry =
nse_funcs) {
}
# register both as `pkg::fun` and as `fun` if `qualified_name` is prefixed
- if (grepl("::", fun_name)) {
- registry[[unqualified_name]] <- fun
- registry[[fun_name]] <- fun
- } else {
- registry[[unqualified_name]] <- fun
+ # unqualified_name and fun_name will be the same if not prefixed
+ registry[[unqualified_name]] <- fun
+ registry[[fun_name]] <- fun
+
+ if (update_cache) {
+ fun_cache <- .cache$functions
+ fun_cache[[unqualified_name]] <- fun
+ fun_cache[[fun_name]] <- fun
+ .cache$functions <- fun_cache
+ }
+
+ invisible(previous_fun)
+}
+
+unregister_binding <- function(fun_name, registry = nse_funcs,
+ update_cache = FALSE) {
+ unqualified_name <- sub("^.*?:{+}", "", fun_name)
+ previous_fun <- registry[[unqualified_name]]
+
+ rm(
+ list = unique(c(fun_name, unqualified_name)),
+ envir = registry,
+ inherits = FALSE
+ )
+
+ if (update_cache) {
+ fun_cache <- .cache$functions
+ fun_cache[[unqualified_name]] <- NULL
+ fun_cache[[fun_name]] <- NULL
+ .cache$functions <- fun_cache
}
invisible(previous_fun)
diff --git a/r/R/feather.R b/r/R/feather.R
index 46863c98a1..03c8a7b5f0 100644
--- a/r/R/feather.R
+++ b/r/R/feather.R
@@ -222,7 +222,7 @@ FeatherReader <- R6Class("FeatherReader",
inherit = ArrowObject,
public = list(
Read = function(columns) {
- ipc___feather___Reader__Read(self, columns, on_old_windows())
+ ipc___feather___Reader__Read(self, columns)
},
print = function(...) {
cat("FeatherReader:\n")
@@ -243,5 +243,5 @@ names.FeatherReader <- function(x) x$column_names
FeatherReader$create <- function(file) {
assert_is(file, "RandomAccessFile")
- ipc___feather___Reader__Open(file, on_old_windows())
+ ipc___feather___Reader__Open(file)
}
diff --git a/r/R/query-engine.R b/r/R/query-engine.R
index 511bf3dbc2..e63fa75ebf 100644
--- a/r/R/query-engine.R
+++ b/r/R/query-engine.R
@@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+# nolint start: cyclocomp_linter,
ExecPlan <- R6Class("ExecPlan",
inherit = ArrowObject,
public = list(
@@ -191,7 +192,7 @@ ExecPlan <- R6Class("ExecPlan",
}
node
},
- Run = function(node) {
+ Run = function(node, as_table = FALSE) {
assert_is(node, "ExecNode")
# Sorting and head/tail (if sorted) are handled in the SinkNode,
@@ -209,7 +210,14 @@ ExecPlan <- R6Class("ExecPlan",
sorting$orders <- as.integer(sorting$orders)
}
- out <- ExecPlan_run(
+ # If we are going to return a Table anyway, we do this in one step and
+ # entirely in one C++ call to ensure that we can execute user-defined
+ # functions from the worker threads spawned by the ExecPlan. If not, we
+ # use ExecPlan_run which returns a RecordBatchReader that can be
+ # manipulated in R code (but that right now won't work with
+ # user-defined functions).
+ exec_fun <- if (as_table) ExecPlan_read_table else ExecPlan_run
+ out <- exec_fun(
self,
node,
sorting,
@@ -232,10 +240,12 @@ ExecPlan <- R6Class("ExecPlan",
} else if (!is.null(node$extras$tail)) {
# TODO(ARROW-16630): proper BottomK support
# Reverse the row order to get back what we expect
- out <- out$read_table()
+ out <- as_arrow_table(out)
out <- out[rev(seq_len(nrow(out))), , drop = FALSE]
# Put back into RBR
- out <- as_record_batch_reader(out)
+ if (!as_table) {
+ out <- as_record_batch_reader(out)
+ }
}
# If arrange() created $temp_columns, make sure to omit them from the
result
@@ -243,9 +253,13 @@ ExecPlan <- R6Class("ExecPlan",
# happens in the end (SinkNode) so nothing comes after it.
# TODO(ARROW-16631): move into ExecPlan
if (length(node$extras$sort$temp_columns) > 0) {
- tab <- out$read_table()
+ tab <- as_arrow_table(out)
tab <- tab[, setdiff(names(tab), node$extras$sort$temp_columns), drop
= FALSE]
- out <- as_record_batch_reader(tab)
+ if (!as_table) {
+ out <- as_record_batch_reader(tab)
+ } else {
+ out <- tab
+ }
}
out
@@ -262,6 +276,8 @@ ExecPlan <- R6Class("ExecPlan",
Stop = function() ExecPlan_StopProducing(self)
)
)
+# nolint end.
+
ExecPlan$create <- function(use_threads = option_use_threads()) {
ExecPlan_create(use_threads)
}
diff --git a/r/R/table.R b/r/R/table.R
index 305f305129..5579c676d5 100644
--- a/r/R/table.R
+++ b/r/R/table.R
@@ -318,3 +318,18 @@ as_arrow_table.RecordBatch <- function(x, ..., schema =
NULL) {
as_arrow_table.data.frame <- function(x, ..., schema = NULL) {
Table$create(x, schema = schema)
}
+
+#' @rdname as_arrow_table
+#' @export
+as_arrow_table.RecordBatchReader <- function(x, ...) {
+ x$read_table()
+}
+
+#' @rdname as_arrow_table
+#' @export
+as_arrow_table.arrow_dplyr_query <- function(x, ...) {
+ # See query-engine.R for ExecPlan/Nodes
+ plan <- ExecPlan$create()
+ final_node <- plan$Build(x)
+ plan$Run(final_node, as_table = TRUE)
+}
diff --git a/r/_pkgdown.yml b/r/_pkgdown.yml
index c0f599fb8a..b04cab8195 100644
--- a/r/_pkgdown.yml
+++ b/r/_pkgdown.yml
@@ -219,6 +219,7 @@ reference:
- match_arrow
- value_counts
- list_compute_functions
+ - register_scalar_function
- title: Connections to other systems
contents:
- to_arrow
diff --git a/r/man/as_arrow_table.Rd b/r/man/as_arrow_table.Rd
index 0ba563f581..aac4495e7c 100644
--- a/r/man/as_arrow_table.Rd
+++ b/r/man/as_arrow_table.Rd
@@ -6,6 +6,8 @@
\alias{as_arrow_table.Table}
\alias{as_arrow_table.RecordBatch}
\alias{as_arrow_table.data.frame}
+\alias{as_arrow_table.RecordBatchReader}
+\alias{as_arrow_table.arrow_dplyr_query}
\title{Convert an object to an Arrow Table}
\usage{
as_arrow_table(x, ..., schema = NULL)
@@ -17,6 +19,10 @@ as_arrow_table(x, ..., schema = NULL)
\method{as_arrow_table}{RecordBatch}(x, ..., schema = NULL)
\method{as_arrow_table}{data.frame}(x, ..., schema = NULL)
+
+\method{as_arrow_table}{RecordBatchReader}(x, ...)
+
+\method{as_arrow_table}{arrow_dplyr_query}(x, ...)
}
\arguments{
\item{x}{An object to convert to an Arrow Table}
diff --git a/r/man/register_binding.Rd b/r/man/register_binding.Rd
index e776e7b3f5..c53df70751 100644
--- a/r/man/register_binding.Rd
+++ b/r/man/register_binding.Rd
@@ -4,7 +4,7 @@
\alias{register_binding}
\title{Register compute bindings}
\usage{
-register_binding(fun_name, fun, registry = nse_funcs)
+register_binding(fun_name, fun, registry = nse_funcs, update_cache = FALSE)
}
\arguments{
\item{fun_name}{A string containing a function name in the form
\code{"function"} or
@@ -18,6 +18,14 @@ This function must accept \code{Expression} objects as
arguments and return
\item{registry}{An environment in which the functions should be
assigned.}
+\item{update_cache}{Update .cache$functions at the time of registration.
+The default is FALSE because the majority of usage is to register
+bindings at package load, after which we create the cache once. The
+reason why .cache$functions is needed in addition to nse_funcs for
+non-aggregate functions could be revisited; it is currently used
+as the data mask in mutate, filter, and aggregate (but not
+summarise) because the data mask has to be a list.}
+
\item{agg_fun}{An aggregate function or \code{NULL} to un-register a previous
aggregate function. This function must accept \code{Expression} objects as
arguments and return a \code{list()} with components:
diff --git a/r/man/register_scalar_function.Rd
b/r/man/register_scalar_function.Rd
new file mode 100644
index 0000000000..4da8f54f64
--- /dev/null
+++ b/r/man/register_scalar_function.Rd
@@ -0,0 +1,70 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/compute.R
+\name{register_scalar_function}
+\alias{register_scalar_function}
+\title{Register user-defined functions}
+\usage{
+register_scalar_function(name, fun, in_type, out_type, auto_convert = FALSE)
+}
+\arguments{
+\item{name}{The function name to be used in the dplyr bindings}
+
+\item{fun}{An R function or rlang-style lambda expression. The function
+will be called with a first argument \code{context} which is a \code{list()}
+with elements \code{batch_size} (the expected length of the output) and
+\code{output_type} (the required \link{DataType} of the output) that may be
used
+to ensure that the output has the correct type and length. Subsequent
+arguments are passed by position as specified by \code{in_type}. If
+\code{auto_convert} is \code{TRUE}, subsequent arguments are converted to
+R vectors before being passed to \code{fun} and the output is automatically
+constructed with the expected output type via
\code{\link[=as_arrow_array]{as_arrow_array()}}.}
+
+\item{in_type}{A \link{DataType} of the input type or a
\code{\link[=schema]{schema()}}
+for functions with more than one argument. This signature will be used
+to determine if this function is appropriate for a given set of arguments.
+If this function is appropriate for more than one signature, pass a
+\code{list()} of the above.}
+
+\item{out_type}{A \link{DataType} of the output type or a function accepting
+a single argument (\code{types}), which is a \code{list()} of
\link{DataType}s. If a
+function it must return a \link{DataType}.}
+
+\item{auto_convert}{Use \code{TRUE} to convert inputs before passing to
\code{fun}
+and construct an Array of the correct type from the output. Use this
+option to write functions of R objects as opposed to functions of
+Arrow R6 objects.}
+}
+\value{
+\code{NULL}, invisibly
+}
+\description{
+These functions support calling R code from query engine execution
+(e.g., a \code{\link[dplyr:mutate]{dplyr::mutate()}} or
\code{\link[dplyr:filter]{dplyr::filter()}} on a \link{Table} or
\link{Dataset}).
+Use \code{\link[=register_scalar_function]{register_scalar_function()}} to
attach Arrow input and output types to an
+R function and make it available for use in the dplyr interface and/or
+\code{\link[=call_function]{call_function()}}. Scalar functions are currently
the only type of
+user-defined function supported. In Arrow, scalar functions must be
+stateless and return output with the same shape (i.e., the same number
+of rows) as the input.
+}
+\examples{
+\dontshow{if (arrow_with_dataset()) (if (getRversion() >= "3.4") withAutoprint
else force)(\{ # examplesIf}
+library(dplyr, warn.conflicts = FALSE)
+
+some_model <- lm(mpg ~ disp + cyl, data = mtcars)
+register_scalar_function(
+ "mtcars_predict_mpg",
+ function(context, disp, cyl) {
+ predict(some_model, newdata = data.frame(disp, cyl))
+ },
+ in_type = schema(disp = float64(), cyl = float64()),
+ out_type = float64(),
+ auto_convert = TRUE
+)
+
+as_arrow_table(mtcars) \%>\%
+ transmute(mpg, mpg_predicted = mtcars_predict_mpg(disp, cyl)) \%>\%
+ collect() \%>\%
+ head()
+\dontshow{\}) # examplesIf}
+}
diff --git a/r/src/arrowExports.cpp b/r/src/arrowExports.cpp
index e89718144a..fd9f92e5d1 100644
--- a/r/src/arrowExports.cpp
+++ b/r/src/arrowExports.cpp
@@ -881,6 +881,18 @@ BEGIN_CPP11
END_CPP11
}
// compute-exec.cpp
+std::shared_ptr<arrow::Table> ExecPlan_read_table(const
std::shared_ptr<compute::ExecPlan>& plan, const
std::shared_ptr<compute::ExecNode>& final_node, cpp11::list sort_options,
cpp11::strings metadata, int64_t head);
+extern "C" SEXP _arrow_ExecPlan_read_table(SEXP plan_sexp, SEXP
final_node_sexp, SEXP sort_options_sexp, SEXP metadata_sexp, SEXP head_sexp){
+BEGIN_CPP11
+ arrow::r::Input<const std::shared_ptr<compute::ExecPlan>&>::type
plan(plan_sexp);
+ arrow::r::Input<const std::shared_ptr<compute::ExecNode>&>::type
final_node(final_node_sexp);
+ arrow::r::Input<cpp11::list>::type sort_options(sort_options_sexp);
+ arrow::r::Input<cpp11::strings>::type metadata(metadata_sexp);
+ arrow::r::Input<int64_t>::type head(head_sexp);
+ return cpp11::as_sexp(ExecPlan_read_table(plan, final_node,
sort_options, metadata, head));
+END_CPP11
+}
+// compute-exec.cpp
void ExecPlan_StopProducing(const std::shared_ptr<compute::ExecPlan>& plan);
extern "C" SEXP _arrow_ExecPlan_StopProducing(SEXP plan_sexp){
BEGIN_CPP11
@@ -1099,6 +1111,16 @@ BEGIN_CPP11
return cpp11::as_sexp(compute__GetFunctionNames());
END_CPP11
}
+// compute.cpp
+void RegisterScalarUDF(std::string name, cpp11::list func_sexp);
+extern "C" SEXP _arrow_RegisterScalarUDF(SEXP name_sexp, SEXP func_sexp_sexp){
+BEGIN_CPP11
+ arrow::r::Input<std::string>::type name(name_sexp);
+ arrow::r::Input<cpp11::list>::type func_sexp(func_sexp_sexp);
+ RegisterScalarUDF(name, func_sexp);
+ return R_NilValue;
+END_CPP11
+}
// config.cpp
std::vector<std::string> build_info();
extern "C" SEXP _arrow_build_info(){
@@ -2788,22 +2810,20 @@ BEGIN_CPP11
END_CPP11
}
// feather.cpp
-std::shared_ptr<arrow::Table> ipc___feather___Reader__Read(const
std::shared_ptr<arrow::ipc::feather::Reader>& reader, cpp11::sexp columns, bool
on_old_windows);
-extern "C" SEXP _arrow_ipc___feather___Reader__Read(SEXP reader_sexp, SEXP
columns_sexp, SEXP on_old_windows_sexp){
+std::shared_ptr<arrow::Table> ipc___feather___Reader__Read(const
std::shared_ptr<arrow::ipc::feather::Reader>& reader, cpp11::sexp columns);
+extern "C" SEXP _arrow_ipc___feather___Reader__Read(SEXP reader_sexp, SEXP
columns_sexp){
BEGIN_CPP11
arrow::r::Input<const
std::shared_ptr<arrow::ipc::feather::Reader>&>::type reader(reader_sexp);
arrow::r::Input<cpp11::sexp>::type columns(columns_sexp);
- arrow::r::Input<bool>::type on_old_windows(on_old_windows_sexp);
- return cpp11::as_sexp(ipc___feather___Reader__Read(reader, columns,
on_old_windows));
+ return cpp11::as_sexp(ipc___feather___Reader__Read(reader, columns));
END_CPP11
}
// feather.cpp
-std::shared_ptr<arrow::ipc::feather::Reader>
ipc___feather___Reader__Open(const
std::shared_ptr<arrow::io::RandomAccessFile>& stream, bool on_old_windows);
-extern "C" SEXP _arrow_ipc___feather___Reader__Open(SEXP stream_sexp, SEXP
on_old_windows_sexp){
+std::shared_ptr<arrow::ipc::feather::Reader>
ipc___feather___Reader__Open(const
std::shared_ptr<arrow::io::RandomAccessFile>& stream);
+extern "C" SEXP _arrow_ipc___feather___Reader__Open(SEXP stream_sexp){
BEGIN_CPP11
arrow::r::Input<const
std::shared_ptr<arrow::io::RandomAccessFile>&>::type stream(stream_sexp);
- arrow::r::Input<bool>::type on_old_windows(on_old_windows_sexp);
- return cpp11::as_sexp(ipc___feather___Reader__Open(stream,
on_old_windows));
+ return cpp11::as_sexp(ipc___feather___Reader__Open(stream));
END_CPP11
}
// feather.cpp
@@ -4601,6 +4621,13 @@ BEGIN_CPP11
END_CPP11
}
// safe-call-into-r-impl.cpp
+bool CanRunWithCapturedR();
+extern "C" SEXP _arrow_CanRunWithCapturedR(){
+BEGIN_CPP11
+ return cpp11::as_sexp(CanRunWithCapturedR());
+END_CPP11
+}
+// safe-call-into-r-impl.cpp
std::string TestSafeCallIntoR(cpp11::function r_fun_that_returns_a_string,
std::string opt);
extern "C" SEXP _arrow_TestSafeCallIntoR(SEXP
r_fun_that_returns_a_string_sexp, SEXP opt_sexp){
BEGIN_CPP11
@@ -5240,6 +5267,7 @@ static const R_CallMethodDef CallEntries[] = {
{ "_arrow_io___CompressedInputStream__Make", (DL_FUNC)
&_arrow_io___CompressedInputStream__Make, 2},
{ "_arrow_ExecPlan_create", (DL_FUNC) &_arrow_ExecPlan_create,
1},
{ "_arrow_ExecPlan_run", (DL_FUNC) &_arrow_ExecPlan_run, 5},
+ { "_arrow_ExecPlan_read_table", (DL_FUNC)
&_arrow_ExecPlan_read_table, 5},
{ "_arrow_ExecPlan_StopProducing", (DL_FUNC)
&_arrow_ExecPlan_StopProducing, 1},
{ "_arrow_ExecNode_output_schema", (DL_FUNC)
&_arrow_ExecNode_output_schema, 1},
{ "_arrow_ExecNode_Scan", (DL_FUNC) &_arrow_ExecNode_Scan, 4},
@@ -5258,6 +5286,7 @@ static const R_CallMethodDef CallEntries[] = {
{ "_arrow_Table__cast", (DL_FUNC) &_arrow_Table__cast, 3},
{ "_arrow_compute__CallFunction", (DL_FUNC)
&_arrow_compute__CallFunction, 3},
{ "_arrow_compute__GetFunctionNames", (DL_FUNC)
&_arrow_compute__GetFunctionNames, 0},
+ { "_arrow_RegisterScalarUDF", (DL_FUNC)
&_arrow_RegisterScalarUDF, 2},
{ "_arrow_build_info", (DL_FUNC) &_arrow_build_info, 0},
{ "_arrow_runtime_info", (DL_FUNC) &_arrow_runtime_info, 0},
{ "_arrow_set_timezone_database", (DL_FUNC)
&_arrow_set_timezone_database, 1},
@@ -5415,8 +5444,8 @@ static const R_CallMethodDef CallEntries[] = {
{ "_arrow_arrow__UnregisterRExtensionType", (DL_FUNC)
&_arrow_arrow__UnregisterRExtensionType, 1},
{ "_arrow_ipc___WriteFeather__Table", (DL_FUNC)
&_arrow_ipc___WriteFeather__Table, 6},
{ "_arrow_ipc___feather___Reader__version", (DL_FUNC)
&_arrow_ipc___feather___Reader__version, 1},
- { "_arrow_ipc___feather___Reader__Read", (DL_FUNC)
&_arrow_ipc___feather___Reader__Read, 3},
- { "_arrow_ipc___feather___Reader__Open", (DL_FUNC)
&_arrow_ipc___feather___Reader__Open, 2},
+ { "_arrow_ipc___feather___Reader__Read", (DL_FUNC)
&_arrow_ipc___feather___Reader__Read, 2},
+ { "_arrow_ipc___feather___Reader__Open", (DL_FUNC)
&_arrow_ipc___feather___Reader__Open, 1},
{ "_arrow_ipc___feather___Reader__schema", (DL_FUNC)
&_arrow_ipc___feather___Reader__schema, 1},
{ "_arrow_Field__initialize", (DL_FUNC)
&_arrow_Field__initialize, 3},
{ "_arrow_Field__ToString", (DL_FUNC) &_arrow_Field__ToString,
1},
@@ -5586,6 +5615,7 @@ static const R_CallMethodDef CallEntries[] = {
{ "_arrow_ipc___RecordBatchFileWriter__Open", (DL_FUNC)
&_arrow_ipc___RecordBatchFileWriter__Open, 4},
{ "_arrow_ipc___RecordBatchStreamWriter__Open", (DL_FUNC)
&_arrow_ipc___RecordBatchStreamWriter__Open, 4},
{ "_arrow_InitializeMainRThread", (DL_FUNC)
&_arrow_InitializeMainRThread, 0},
+ { "_arrow_CanRunWithCapturedR", (DL_FUNC)
&_arrow_CanRunWithCapturedR, 0},
{ "_arrow_TestSafeCallIntoR", (DL_FUNC)
&_arrow_TestSafeCallIntoR, 2},
{ "_arrow_Array__GetScalar", (DL_FUNC)
&_arrow_Array__GetScalar, 2},
{ "_arrow_Scalar__ToString", (DL_FUNC)
&_arrow_Scalar__ToString, 1},
diff --git a/r/src/compute-exec.cpp b/r/src/compute-exec.cpp
index 76112b4cef..e348675fc1 100644
--- a/r/src/compute-exec.cpp
+++ b/r/src/compute-exec.cpp
@@ -16,6 +16,7 @@
// under the License.
#include "./arrow_types.h"
+#include "./safe-call-into-r.h"
#include <arrow/compute/api.h>
#include <arrow/compute/exec/exec_plan.h>
@@ -55,11 +56,10 @@ std::shared_ptr<compute::ExecNode> MakeExecNodeOrStop(
});
}
-// [[arrow::export]]
-std::shared_ptr<arrow::RecordBatchReader> ExecPlan_run(
- const std::shared_ptr<compute::ExecPlan>& plan,
- const std::shared_ptr<compute::ExecNode>& final_node, cpp11::list
sort_options,
- cpp11::strings metadata, int64_t head = -1) {
+std::pair<std::shared_ptr<compute::ExecPlan>,
std::shared_ptr<arrow::RecordBatchReader>>
+ExecPlan_prepare(const std::shared_ptr<compute::ExecPlan>& plan,
+ const std::shared_ptr<compute::ExecNode>& final_node,
+ cpp11::list sort_options, cpp11::strings metadata, int64_t
head = -1) {
// For now, don't require R to construct SinkNodes.
// Instead, just pass the node we should collect as an argument.
arrow::AsyncGenerator<arrow::util::optional<compute::ExecBatch>> sink_gen;
@@ -89,7 +89,6 @@ std::shared_ptr<arrow::RecordBatchReader> ExecPlan_run(
}
StopIfNotOk(plan->Validate());
- StopIfNotOk(plan->StartProducing());
// If the generator is destroyed before being completely drained, inform plan
std::shared_ptr<void> stop_producing{nullptr, [plan](...) {
@@ -109,9 +108,40 @@ std::shared_ptr<arrow::RecordBatchReader> ExecPlan_run(
auto kv = strings_to_kvm(metadata);
out_schema = out_schema->WithMetadata(kv);
}
- return compute::MakeGeneratorReader(
+
+ std::pair<std::shared_ptr<compute::ExecPlan>,
std::shared_ptr<arrow::RecordBatchReader>>
+ out;
+ out.first = plan;
+ out.second = compute::MakeGeneratorReader(
out_schema, [stop_producing, plan, sink_gen] { return sink_gen(); },
gc_memory_pool());
+ return out;
+}
+
+// Prepares `plan` with ExecPlan_prepare(), starts it on the calling thread and
+// returns a streaming reader over the output of `final_node`. The plan is not
+// started inside RunWithCapturedR(), so background threads cannot use
+// SafeCallIntoR() while the returned reader is consumed.
+// [[arrow::export]]
+std::shared_ptr<arrow::RecordBatchReader> ExecPlan_run(
+    const std::shared_ptr<compute::ExecPlan>& plan,
+    const std::shared_ptr<compute::ExecNode>& final_node, cpp11::list sort_options,
+    cpp11::strings metadata, int64_t head = -1) {
+  auto prepared = ExecPlan_prepare(plan, final_node, sort_options, metadata, head);
+  const std::shared_ptr<compute::ExecPlan>& plan_to_start = prepared.first;
+  StopIfNotOk(plan_to_start->StartProducing());
+  return std::move(prepared.second);
+}
+
+// Prepares `plan`, then starts it and drains its output into a Table. Both
+// steps run inside RunWithCapturedRIfPossible(). When that is supported,
+// background threads used by the plan (e.g. for R user-defined functions) can
+// call back into R through SafeCallIntoR().
+// [[arrow::export]]
+std::shared_ptr<arrow::Table> ExecPlan_read_table(
+    const std::shared_ptr<compute::ExecPlan>& plan,
+    const std::shared_ptr<compute::ExecNode>& final_node, cpp11::list sort_options,
+    cpp11::strings metadata, int64_t head = -1) {
+  auto prepared_plan = ExecPlan_prepare(plan, final_node, sort_options, metadata, head);
+
+  auto start_and_collect = [&]() -> arrow::Result<std::shared_ptr<arrow::Table>> {
+    ARROW_RETURN_NOT_OK(prepared_plan.first->StartProducing());
+    return prepared_plan.second->ToTable();
+  };
+
+  return ValueOrStop(
+      RunWithCapturedRIfPossible<std::shared_ptr<arrow::Table>>(start_and_collect));
}
// [[arrow::export]]
@@ -196,8 +226,14 @@ void ExecPlan_Write(
ds::WriteNodeOptions{std::move(opts), std::move(kv)});
StopIfNotOk(plan->Validate());
- StopIfNotOk(plan->StartProducing());
- StopIfNotOk(plan->finished().status());
+
+ arrow::Status result = RunWithCapturedRIfPossibleVoid([&]() {
+ RETURN_NOT_OK(plan->StartProducing());
+ RETURN_NOT_OK(plan->finished().status());
+ return arrow::Status::OK();
+ });
+
+ StopIfNotOk(result);
}
#endif
diff --git a/r/src/compute.cpp b/r/src/compute.cpp
index 885af3f7ab..1ed949e729 100644
--- a/r/src/compute.cpp
+++ b/r/src/compute.cpp
@@ -16,7 +16,9 @@
// under the License.
#include "./arrow_types.h"
+#include "./safe-call-into-r.h"
+#include <arrow/array/util.h>
#include <arrow/compute/api.h>
#include <arrow/record_batch.h>
#include <arrow/table.h>
@@ -603,3 +605,168 @@ SEXP compute__CallFunction(std::string func_name,
cpp11::list args, cpp11::list
std::vector<std::string> compute__GetFunctionNames() {
return arrow::compute::GetFunctionRegistry()->GetFunctionNames();
}
+
+// Per-kernel state for a scalar function implemented in R. RegisterScalarUDF()
+// attaches one instance to each ScalarKernel's `data` member.
+class RScalarUDFKernelState : public arrow::compute::KernelState {
+ public:
+ RScalarUDFKernelState(cpp11::sexp exec_func, cpp11::sexp resolver)
+ : exec_func_(exec_func), resolver_(resolver) {}
+
+ // R function that computes the kernel's result. CallRScalarUDF() calls it
+ // with (context, args).
+ cpp11::function exec_func_;
+ // R function that maps a list of input DataTypes to the output DataType.
+ // ResolveScalarUDFOutputType() calls it.
+ // NOTE(review): both members hold R objects, so this state presumably must
+ // only be destroyed on the main R thread -- confirm.
+ cpp11::function resolver_;
+};
+
+// Output-type resolver used by every kernel that RegisterScalarUDF() creates.
+// On the main R thread (via SafeCallIntoR()), it wraps the input types as R6
+// DataType objects, calls the kernel's R resolver with them, and checks that
+// the resolver returned a DataType.
+arrow::Result<arrow::TypeHolder> ResolveScalarUDFOutputType(
+ arrow::compute::KernelContext* context,
+ const std::vector<arrow::TypeHolder>& input_types) {
+ return SafeCallIntoR<arrow::TypeHolder>(
+ [&]() -> arrow::TypeHolder {
+ // This resolver is only attached to kernels built by RegisterScalarUDF(),
+ // whose `data` member is an RScalarUDFKernelState.
+ auto kernel =
+ reinterpret_cast<const
arrow::compute::ScalarKernel*>(context->kernel());
+ auto state =
std::dynamic_pointer_cast<RScalarUDFKernelState>(kernel->data);
+
+ // Pass the input types to R as a list of R6 DataType objects.
+ cpp11::writable::list input_types_sexp(input_types.size());
+ for (size_t i = 0; i < input_types.size(); i++) {
+ input_types_sexp[i] =
+ cpp11::to_r6<arrow::DataType>(input_types[i].GetSharedPtr());
+ }
+
+ cpp11::sexp output_type_sexp = state->resolver_(input_types_sexp);
+ // Reject anything that is not a DataType before converting it to C++.
+ if (!Rf_inherits(output_type_sexp, "DataType")) {
+ cpp11::stop(
+ "Function specified as arrow_scalar_function() out_type argument
must "
+ "return a DataType");
+ }
+
+ return arrow::TypeHolder(
+ cpp11::as_cpp<std::shared_ptr<arrow::DataType>>(output_type_sexp));
+ },
+ "resolve scalar user-defined function output data type");
+}
+
+// Exec function used by every kernel that RegisterScalarUDF() creates. On the
+// main R thread (via SafeCallIntoRVoid()) it:
+// 1. wraps each input as an R6 Array or Scalar;
+// 2. calls the kernel's R function with a context list
+//    (batch_length, output_type) and the argument list;
+// 3. stores the R result in `result` as ArrayData.
+// An Array result must already have the output type. A Scalar result of the
+// output type is expanded to span.length values. ArraySpan results are not
+// supported and return NotImplemented.
+arrow::Status CallRScalarUDF(arrow::compute::KernelContext* context,
+ const arrow::compute::ExecSpan& span,
+ arrow::compute::ExecResult* result) {
+ if (result->is_array_span()) {
+ return arrow::Status::NotImplemented("ArraySpan result from R scalar UDF");
+ }
+
+ return SafeCallIntoRVoid(
+ [&]() {
+ // The kernel's `data` member is the RScalarUDFKernelState set up in
+ // RegisterScalarUDF().
+ auto kernel =
+ reinterpret_cast<const
arrow::compute::ScalarKernel*>(context->kernel());
+ auto state =
std::dynamic_pointer_cast<RScalarUDFKernelState>(kernel->data);
+
+ // Wrap each input value as the matching R6 object (Array or Scalar).
+ cpp11::writable::list args_sexp(span.num_values());
+
+ for (int i = 0; i < span.num_values(); i++) {
+ const arrow::compute::ExecValue& exec_val = span[i];
+ if (exec_val.is_array()) {
+ args_sexp[i] =
cpp11::to_r6<arrow::Array>(exec_val.array.ToArray());
+ } else if (exec_val.is_scalar()) {
+ args_sexp[i] =
cpp11::to_r6<arrow::Scalar>(exec_val.scalar->GetSharedPtr());
+ }
+ }
+
+ // Named list passed to the R function as its first (context) argument.
+ cpp11::sexp batch_length_sexp = cpp11::as_sexp(span.length);
+
+ std::shared_ptr<arrow::DataType> output_type =
result->type()->GetSharedPtr();
+ cpp11::sexp output_type_sexp =
cpp11::to_r6<arrow::DataType>(output_type);
+ cpp11::writable::list udf_context = {batch_length_sexp,
output_type_sexp};
+ udf_context.names() = {"batch_length", "output_type"};
+
+ cpp11::sexp func_result_sexp = state->exec_func_(udf_context,
args_sexp);
+
+ if (Rf_inherits(func_result_sexp, "Array")) {
+ auto array =
cpp11::as_cpp<std::shared_ptr<arrow::Array>>(func_result_sexp);
+
+ // Error for an Array result of the wrong type
+ if (!result->type()->Equals(array->type())) {
+ return cpp11::stop(
+ "Expected return Array or Scalar with type '%s' from
user-defined "
+ "function but got Array with type '%s'",
+ result->type()->ToString().c_str(),
array->type()->ToString().c_str());
+ }
+
+ result->value = std::move(array->data());
+ } else if (Rf_inherits(func_result_sexp, "Scalar")) {
+ auto scalar =
cpp11::as_cpp<std::shared_ptr<arrow::Scalar>>(func_result_sexp);
+
+ // handle a Scalar result of the wrong type
+ if (!result->type()->Equals(scalar->type)) {
+ return cpp11::stop(
+ "Expected return Array or Scalar with type '%s' from
user-defined "
+ "function but got Scalar with type '%s'",
+ result->type()->ToString().c_str(),
scalar->type->ToString().c_str());
+ }
+
+ // Broadcast the scalar to an array as long as the input batch.
+ auto array = ValueOrStop(
+ arrow::MakeArrayFromScalar(*scalar, span.length,
context->memory_pool()));
+ result->value = std::move(array->data());
+ } else {
+ cpp11::stop("arrow_scalar_function must return an Array or Scalar");
+ }
+ },
+ "execute scalar user-defined function");
+}
+
+// Registers an R user-defined scalar function called `name` in the default
+// Arrow compute function registry. Any existing function with that name is
+// replaced.
+//
+// `func_sexp` is a list with (at least) these elements:
+// - in_type: a list with one Schema per kernel; its fields are the kernel's
+//   input types.
+// - out_type: a list with one R function per kernel. Each is called with the
+//   list of input DataTypes and resolves the output DataType.
+// - wrapper_fun: the R function that executes every kernel.
+// [[arrow::export]]
+void RegisterScalarUDF(std::string name, cpp11::list func_sexp) {
+  cpp11::list in_type_r(func_sexp["in_type"]);
+  cpp11::list out_type_r(func_sexp["out_type"]);
+  R_xlen_t n_kernels = in_type_r.size();
+
+  if (n_kernels == 0) {
+    cpp11::stop("Can't register user-defined function with zero kernels");
+  }
+
+  // out_type_r is indexed in parallel with in_type_r below. A shorter list
+  // would be read past its end, so reject mismatched lengths up front.
+  if (out_type_r.size() != n_kernels) {
+    cpp11::stop(
+        "Expected the same number of input and output type specifications for "
+        "user-defined function");
+  }
+
+  // Compute the Arity from the list of input kernels. We don't currently
+  // handle variable numbers of arguments in a user-defined function.
+  int64_t n_args =
+      cpp11::as_cpp<std::shared_ptr<arrow::Schema>>(in_type_r[0])->num_fields();
+  for (R_xlen_t i = 1; i < n_kernels; i++) {
+    auto in_types = cpp11::as_cpp<std::shared_ptr<arrow::Schema>>(in_type_r[i]);
+    if (in_types->num_fields() != n_args) {
+      cpp11::stop(
+          "Kernels for user-defined function must accept the same number of "
+          "arguments");
+    }
+  }
+
+  arrow::compute::Arity arity(n_args, false);
+
+  // The function documentation isn't currently accessible from R but is
+  // required for the C++ function constructor.
+  std::vector<std::string> dummy_argument_names(n_args, "arg");
+  const arrow::compute::FunctionDoc dummy_function_doc{
+      "A user-defined R function", "returns something",
+      std::move(dummy_argument_names)};
+
+  auto func = std::make_shared<arrow::compute::ScalarFunction>(name, arity,
+                                                               dummy_function_doc);
+
+  for (R_xlen_t i = 0; i < n_kernels; i++) {
+    auto in_types = cpp11::as_cpp<std::shared_ptr<arrow::Schema>>(in_type_r[i]);
+    cpp11::sexp out_type_func = out_type_r[i];
+
+    std::vector<arrow::compute::InputType> compute_in_types(in_types->num_fields());
+    for (int64_t j = 0; j < in_types->num_fields(); j++) {
+      compute_in_types[j] = arrow::compute::InputType(in_types->field(j)->type());
+    }
+
+    // The output type is computed at call time by this kernel's R out_type
+    // function (see ResolveScalarUDFOutputType()).
+    arrow::compute::OutputType out_type((&ResolveScalarUDFOutputType));
+
+    // NOTE(review): the signature is declared varargs (`true`) although the
+    // function's arity above is fixed; confirm this is intended.
+    auto signature = std::make_shared<arrow::compute::KernelSignature>(
+        std::move(compute_in_types), std::move(out_type), true);
+    arrow::compute::ScalarKernel kernel(signature, &CallRScalarUDF);
+    // CallRScalarUDF() stores a result produced by R, so Arrow must not
+    // preallocate output buffers or a validity bitmap.
+    kernel.mem_allocation = arrow::compute::MemAllocation::NO_PREALLOCATE;
+    kernel.null_handling = arrow::compute::NullHandling::COMPUTED_NO_PREALLOCATE;
+    kernel.data = std::make_shared<RScalarUDFKernelState>(func_sexp["wrapper_fun"],
+                                                          out_type_func);
+
+    StopIfNotOk(func->AddKernel(std::move(kernel)));
+  }
+
+  // allow_overwrite = true: registering an existing name replaces it.
+  auto registry = arrow::compute::GetFunctionRegistry();
+  StopIfNotOk(registry->AddFunction(std::move(func), true));
+}
diff --git a/r/src/csv.cpp b/r/src/csv.cpp
index d031cc87ca..7ce55feb5f 100644
--- a/r/src/csv.cpp
+++ b/r/src/csv.cpp
@@ -162,16 +162,9 @@ std::shared_ptr<arrow::csv::TableReader>
csv___TableReader__Make(
// [[arrow::export]]
std::shared_ptr<arrow::Table> csv___TableReader__Read(
const std::shared_ptr<arrow::csv::TableReader>& table_reader) {
-#if !defined(HAS_SAFE_CALL_INTO_R)
- return ValueOrStop(table_reader->Read());
-#else
- const auto& io_context = arrow::io::default_io_context();
- auto result = RunWithCapturedR<std::shared_ptr<arrow::Table>>([&]() {
- return DeferNotOk(
- io_context.executor()->Submit([&]() { return table_reader->Read(); }));
- });
+ // Read inside RunWithCapturedRIfPossible(). When that is supported, an input
+ // stream backed by R (e.g. an R connection) can call back into R from
+ // Arrow's worker threads via SafeCallIntoR().
+ auto result = RunWithCapturedRIfPossible<std::shared_ptr<arrow::Table>>(
+ [&]() { return table_reader->Read(); });
return ValueOrStop(result);
-#endif
}
// [[arrow::export]]
diff --git a/r/src/extension-impl.cpp b/r/src/extension-impl.cpp
index efb9f0f467..e6efcf3647 100644
--- a/r/src/extension-impl.cpp
+++ b/r/src/extension-impl.cpp
@@ -38,18 +38,19 @@ bool RExtensionType::ExtensionEquals(const
arrow::ExtensionType& other) const {
// With any ambiguity, we need to materialize the R6 instance and call its
// ExtensionEquals method. We can't do this on the non-R thread.
- // After ARROW-15841, we can use SafeCallIntoR.
- arrow::Result<bool> result = SafeCallIntoR<bool>([&]() {
- cpp11::environment instance = r6_instance();
- cpp11::function instance_ExtensionEquals(instance["ExtensionEquals"]);
-
- std::shared_ptr<DataType> other_shared =
- ValueOrStop(other.Deserialize(other.storage_type(),
other.Serialize()));
- cpp11::sexp other_r6 = cpp11::to_r6<DataType>(other_shared,
"ExtensionType");
-
- cpp11::logicals result(instance_ExtensionEquals(other_r6));
- return cpp11::as_cpp<bool>(result);
- });
+ arrow::Result<bool> result = SafeCallIntoR<bool>(
+ [&]() {
+ cpp11::environment instance = r6_instance();
+ cpp11::function instance_ExtensionEquals(instance["ExtensionEquals"]);
+
+ std::shared_ptr<DataType> other_shared =
+ ValueOrStop(other.Deserialize(other.storage_type(),
other.Serialize()));
+ cpp11::sexp other_r6 = cpp11::to_r6<DataType>(other_shared,
"ExtensionType");
+
+ cpp11::logicals result(instance_ExtensionEquals(other_r6));
+ return cpp11::as_cpp<bool>(result);
+ },
+ "RExtensionType$ExtensionEquals()");
if (!result.ok()) {
throw std::runtime_error(result.status().message());
diff --git a/r/src/feather.cpp b/r/src/feather.cpp
index debabe4968..cf68faef1b 100644
--- a/r/src/feather.cpp
+++ b/r/src/feather.cpp
@@ -49,8 +49,7 @@ int ipc___feather___Reader__version(
// [[arrow::export]]
std::shared_ptr<arrow::Table> ipc___feather___Reader__Read(
- const std::shared_ptr<arrow::ipc::feather::Reader>& reader, cpp11::sexp
columns,
- bool on_old_windows) {
+ const std::shared_ptr<arrow::ipc::feather::Reader>& reader, cpp11::sexp
columns) {
bool use_names = columns != R_NilValue;
std::vector<std::string> names;
if (use_names) {
@@ -61,7 +60,7 @@ std::shared_ptr<arrow::Table> ipc___feather___Reader__Read(
}
}
- auto read_table = [&]() {
+ auto result =
RunWithCapturedRIfPossible<std::shared_ptr<arrow::Table>>([&]() {
std::shared_ptr<arrow::Table> table;
arrow::Status read_result;
if (use_names) {
@@ -75,39 +74,17 @@ std::shared_ptr<arrow::Table> ipc___feather___Reader__Read(
} else {
return arrow::Result<std::shared_ptr<arrow::Table>>(read_result);
}
- };
+ });
-#if !defined(HAS_SAFE_CALL_INTO_R)
- return ValueOrStop(read_table());
-#else
- if (!on_old_windows) {
- const auto& io_context = arrow::io::default_io_context();
- auto result = RunWithCapturedR<std::shared_ptr<arrow::Table>>(
- [&]() { return DeferNotOk(io_context.executor()->Submit(read_table));
});
- return ValueOrStop(result);
- } else {
- return ValueOrStop(read_table());
- }
-#endif
+ return ValueOrStop(result);
}
// [[arrow::export]]
std::shared_ptr<arrow::ipc::feather::Reader> ipc___feather___Reader__Open(
- const std::shared_ptr<arrow::io::RandomAccessFile>& stream, bool
on_old_windows) {
-#if !defined(HAS_SAFE_CALL_INTO_R)
- return ValueOrStop(arrow::ipc::feather::Reader::Open(stream));
-#else
- if (!on_old_windows) {
- const auto& io_context = arrow::io::default_io_context();
- auto result =
RunWithCapturedR<std::shared_ptr<arrow::ipc::feather::Reader>>([&]() {
- return DeferNotOk(io_context.executor()->Submit(
- [&]() { return arrow::ipc::feather::Reader::Open(stream); }));
- });
- return ValueOrStop(result);
- } else {
- return ValueOrStop(arrow::ipc::feather::Reader::Open(stream));
- }
-#endif
+ const std::shared_ptr<arrow::io::RandomAccessFile>& stream) {
+ // Open inside RunWithCapturedRIfPossible(). When that is supported, a
+ // `stream` backed by R (e.g. an R connection) can be read from a background
+ // thread via SafeCallIntoR(). Otherwise the open runs on the main R thread.
+ auto result =
RunWithCapturedRIfPossible<std::shared_ptr<arrow::ipc::feather::Reader>>(
+ [&]() { return arrow::ipc::feather::Reader::Open(stream); });
+ return ValueOrStop(result);
}
// [[arrow::export]]
diff --git a/r/src/io.cpp b/r/src/io.cpp
index 42766ddd2f..321b1b17fe 100644
--- a/r/src/io.cpp
+++ b/r/src/io.cpp
@@ -223,8 +223,8 @@ class RConnectionFileInterface : public virtual
arrow::io::FileInterface {
closed_ = true;
- return SafeCallIntoRVoid(
- [&]() { cpp11::package("base")["close"](connection_sexp_); });
+ return SafeCallIntoRVoid([&]() {
cpp11::package("base")["close"](connection_sexp_); },
+ "close() on R connection");
}
arrow::Result<int64_t> Tell() const {
@@ -232,10 +232,12 @@ class RConnectionFileInterface : public virtual
arrow::io::FileInterface {
return arrow::Status::IOError("R connection is closed");
}
- return SafeCallIntoR<int64_t>([&]() {
- cpp11::sexp result = cpp11::package("base")["seek"](connection_sexp_);
- return cpp11::as_cpp<int64_t>(result);
- });
+ return SafeCallIntoR<int64_t>(
+ [&]() {
+ cpp11::sexp result =
cpp11::package("base")["seek"](connection_sexp_);
+ return cpp11::as_cpp<int64_t>(result);
+ },
+ "tell() on R connection");
}
bool closed() const { return closed_; }
@@ -251,17 +253,19 @@ class RConnectionFileInterface : public virtual
arrow::io::FileInterface {
return arrow::Status::IOError("R connection is closed");
}
- return SafeCallIntoR<int64_t>([&] {
- cpp11::function read_bin = cpp11::package("base")["readBin"];
- cpp11::writable::raws ptype((R_xlen_t)0);
- cpp11::integers n = cpp11::as_sexp<int>(nbytes);
+ return SafeCallIntoR<int64_t>(
+ [&] {
+ cpp11::function read_bin = cpp11::package("base")["readBin"];
+ cpp11::writable::raws ptype((R_xlen_t)0);
+ cpp11::integers n = cpp11::as_sexp<int>(nbytes);
- cpp11::sexp result = read_bin(connection_sexp_, ptype, n);
+ cpp11::sexp result = read_bin(connection_sexp_, ptype, n);
- int64_t result_size = cpp11::safe[Rf_xlength](result);
- memcpy(out, cpp11::safe[RAW](result), result_size);
- return result_size;
- });
+ int64_t result_size = cpp11::safe[Rf_xlength](result);
+ memcpy(out, cpp11::safe[RAW](result), result_size);
+ return result_size;
+ },
+ "readBin() on R connection");
}
arrow::Result<std::shared_ptr<arrow::Buffer>> ReadBase(int64_t nbytes) {
@@ -278,13 +282,15 @@ class RConnectionFileInterface : public virtual
arrow::io::FileInterface {
return arrow::Status::IOError("R connection is closed");
}
- return SafeCallIntoRVoid([&]() {
- cpp11::writable::raws data_raw(nbytes);
- memcpy(cpp11::safe[RAW](data_raw), data, nbytes);
-
- cpp11::function write_bin = cpp11::package("base")["writeBin"];
- write_bin(data_raw, connection_sexp_);
- });
+ return SafeCallIntoRVoid(
+ [&]() {
+ cpp11::writable::raws data_raw(nbytes);
+ memcpy(cpp11::safe[RAW](data_raw), data, nbytes);
+
+ cpp11::function write_bin = cpp11::package("base")["writeBin"];
+ write_bin(data_raw, connection_sexp_);
+ },
+ "writeBin() on R connection");
}
arrow::Status SeekBase(int64_t pos) {
@@ -292,9 +298,11 @@ class RConnectionFileInterface : public virtual
arrow::io::FileInterface {
return arrow::Status::IOError("R connection is closed");
}
- return SafeCallIntoRVoid([&]() {
- cpp11::package("base")["seek"](connection_sexp_,
cpp11::as_sexp<double>(pos));
- });
+ return SafeCallIntoRVoid(
+ [&]() {
+ cpp11::package("base")["seek"](connection_sexp_,
cpp11::as_sexp<double>(pos));
+ },
+ "seek() on R connection");
}
private:
@@ -305,10 +313,12 @@ class RConnectionFileInterface : public virtual
arrow::io::FileInterface {
return true;
}
- auto is_open_result = SafeCallIntoR<bool>([&]() {
- cpp11::sexp result = cpp11::package("base")["isOpen"](connection_sexp_);
- return cpp11::as_cpp<bool>(result);
- });
+ auto is_open_result = SafeCallIntoR<bool>(
+ [&]() {
+ cpp11::sexp result =
cpp11::package("base")["isOpen"](connection_sexp_);
+ return cpp11::as_cpp<bool>(result);
+ },
+ "isOpen() on R connection");
if (!is_open_result.ok()) {
closed_ = true;
diff --git a/r/src/safe-call-into-r-impl.cpp b/r/src/safe-call-into-r-impl.cpp
index 7c5e75b788..7318c81bb5 100644
--- a/r/src/safe-call-into-r-impl.cpp
+++ b/r/src/safe-call-into-r-impl.cpp
@@ -29,6 +29,21 @@ MainRThread& GetMainRThread() {
// [[arrow::export]]
void InitializeMainRThread() { GetMainRThread().Initialize(); }
+// Returns true when RunWithCapturedR() can be used in this R session. That
+// requires unwind protection (HAS_UNWIND_PROTECT) and the R-level
+// arrow:::on_old_windows() returning FALSE. The function is exported so that
+// tests can skip when the feature is unavailable.
+// [[arrow::export]]
+bool CanRunWithCapturedR() {
+#if defined(HAS_UNWIND_PROTECT)
+ // Cache the R-level answer in a function-local static; -1 means "not yet
+ // computed".
+ // NOTE(review): the first call evaluates R code and the static is not
+ // synchronized, so the first call presumably must be made on the main R
+ // thread -- confirm.
+ static int on_old_windows = -1;
+ if (on_old_windows == -1) {
+ cpp11::function on_old_windows_fun =
cpp11::package("arrow")["on_old_windows"];
+ on_old_windows = on_old_windows_fun();
+ }
+
+ return !on_old_windows;
+#else
+ return false;
+#endif
+}
+
// [[arrow::export]]
std::string TestSafeCallIntoR(cpp11::function r_fun_that_returns_a_string,
std::string opt) {
diff --git a/r/src/safe-call-into-r.h b/r/src/safe-call-into-r.h
index 0555628d7d..937163a05d 100644
--- a/r/src/safe-call-into-r.h
+++ b/r/src/safe-call-into-r.h
@@ -20,6 +20,7 @@
#include "./arrow_types.h"
+#include <arrow/io/interfaces.h>
#include <arrow/util/future.h>
#include <arrow/util/thread_pool.h>
@@ -27,11 +28,11 @@
#include <thread>
// Unwind protection was added in R 3.5 and some calls here use it
-// and crash R in older versions (ARROW-16201). We use this define
-// to make sure we don't crash on R 3.4 and lower.
-#if defined(HAS_UNWIND_PROTECT)
-#define HAS_SAFE_CALL_INTO_R
-#endif
+// and crash R in older versions (ARROW-16201). Crashes also occur
+// on 32-bit R builds on R 3.6 and lower. Implementation provided
+// in safe-call-into-r-impl.cpp so that we can skip some tests
+// when this feature is not provided.
+bool CanRunWithCapturedR();
// The MainRThread class keeps track of the thread on which it is safe
// to call the R API to facilitate its safe use (or erroring
@@ -48,7 +49,7 @@ class MainRThread {
void Initialize() {
thread_id_ = std::this_thread::get_id();
initialized_ = true;
- SetError(R_NilValue);
+ ResetError();
}
bool IsInitialized() { return initialized_; }
@@ -56,33 +57,34 @@ class MainRThread {
// Check if the current thread is the main R thread
bool IsMainThread() { return initialized_ && std::this_thread::get_id() ==
thread_id_; }
+ // Check if a SafeCallIntoR call is able to execute
+ bool CanExecuteSafeCallIntoR() { return IsMainThread() || executor_ !=
nullptr; }
+
// The Executor that is running on the main R thread, if it exists
arrow::internal::Executor*& Executor() { return executor_; }
- // Save an error token generated from a cpp11::unwind_exception
- // so that it can be properly handled after some cleanup code
- // has run (e.g., cancelling some futures or waiting for them
- // to finish).
- void SetError(cpp11::sexp token) { error_token_ = token; }
+ // Save an error (possibly with an error token generated from
+ // a cpp11::unwind_exception) so that it can be properly handled
+ // after some cleanup code has run (e.g., cancelling some futures
+ // or waiting for them to finish).
+ void SetError(arrow::Status status) { status_ = status; }
- void ResetError() { error_token_ = R_NilValue; }
+ void ResetError() { status_ = arrow::Status::OK(); }
// Check if there is a saved error
- bool HasError() { return error_token_ != R_NilValue; }
+ bool HasError() { return !status_.ok(); }
- // Throw a cpp11::unwind_exception() with the saved token if it exists
+ // Reset the saved error and, if it was not OK, raise it with StopIfNotOk()
void ClearError() {
- if (HasError()) {
- cpp11::unwind_exception e(error_token_);
- ResetError();
- throw e;
- }
+ arrow::Status maybe_error_status = status_;
+ ResetError();
+ arrow::StopIfNotOk(maybe_error_status);
}
private:
bool initialized_;
std::thread::id thread_id_;
- cpp11::sexp error_token_;
+ arrow::Status status_;
arrow::internal::Executor* executor_;
};
@@ -93,55 +95,76 @@ MainRThread& GetMainRThread();
// a SEXP (use cpp11::as_cpp<T> to convert it to a C++ type inside
// `fun`).
+// `reason` is a short description of the call; it is included in the error
+// messages below. Off the main R thread, the returned future fails with:
+// - Cancelled, if an earlier SafeCallIntoR() task already raised an R error;
+// - UnknownError, if `fun` raises an R error;
+// - NotImplemented, if no executor is available on the main R thread
+//   (i.e. when not running inside RunWithCapturedR()).
template <typename T>
-arrow::Future<T> SafeCallIntoRAsync(std::function<arrow::Result<T>(void)> fun)
{
+arrow::Future<T> SafeCallIntoRAsync(std::function<arrow::Result<T>(void)> fun,
+ std::string reason = "unspecified") {
MainRThread& main_r_thread = GetMainRThread();
if (main_r_thread.IsMainThread()) {
// If we're on the main thread, run the task immediately and let
// the cpp11::unwind_exception be thrown since it will be caught
// at the top level.
return fun();
- } else if (main_r_thread.Executor() != nullptr) {
+ } else if (main_r_thread.CanExecuteSafeCallIntoR()) {
// If we are not on the main thread and have an Executor,
// use it to run the task on the main R thread. We can't throw
// a cpp11::unwind_exception here, so we need to propagate it back
// to RunWithCapturedR through the MainRThread singleton.
- return DeferNotOk(main_r_thread.Executor()->Submit([fun]() {
+ return DeferNotOk(main_r_thread.Executor()->Submit([fun, reason]() {
+ // This occurs when some other R code that was previously scheduled to
run
+ // has errored, in which case we skip execution and let the original
+ // error surface.
if (GetMainRThread().HasError()) {
- return arrow::Result<T>(arrow::Status::UnknownError("R code execution
error"));
+ return arrow::Result<T>(
+ arrow::Status::Cancelled("Previous R code execution error (",
reason, ")"));
}
try {
return fun();
} catch (cpp11::unwind_exception& e) {
- GetMainRThread().SetError(e.token);
- return arrow::Result<T>(arrow::Status::UnknownError("R code execution
error"));
+ // Here we save the token and set the main R thread to an error state
+ GetMainRThread().SetError(arrow::StatusUnwindProtect(e.token));
+
+ // We also return an error, although it should not surface:
+ // main_r_thread.ClearError() will be called before this value can be
+ // returned and will StopIfNotOk() the saved status. The status returned
+ // here deliberately does not carry the error token, so the token is
+ // thrown only once (by ClearError()).
+ return arrow::Result<T>(
+ arrow::Status::UnknownError("R code execution error (", reason,
")"));
}
}));
} else {
return arrow::Status::NotImplemented(
- "Call to R from a non-R thread without calling RunWithCapturedR");
+ "Call to R (", reason, ") from a non-R thread from an unsupported
context");
}
}
+// Synchronous form of SafeCallIntoRAsync(). It waits for `fun` to run (on the
+// main R thread) and returns its result or the error status. `reason`
+// identifies the call in error messages.
template <typename T>
-arrow::Result<T> SafeCallIntoR(std::function<T(void)> fun) {
- arrow::Future<T> future = SafeCallIntoRAsync<T>(std::move(fun));
+arrow::Result<T> SafeCallIntoR(std::function<T(void)> fun,
+ std::string reason = "unspecified") {
+ arrow::Future<T> future = SafeCallIntoRAsync<T>(std::move(fun), reason);
return future.result();
}
+// Like SafeCallIntoR() for a `fun` that returns nothing: waits for it to run
+// on the main R thread and returns only the status. `reason` identifies the
+// call in error messages.
-static inline arrow::Status SafeCallIntoRVoid(std::function<void(void)> fun) {
- arrow::Future<bool> future = SafeCallIntoRAsync<bool>([&fun]() {
- fun();
- return true;
- });
+static inline arrow::Status SafeCallIntoRVoid(std::function<void(void)> fun,
+ std::string reason =
"unspecified") {
+ // Capturing `fun` by reference is safe: future.status() below blocks until
+ // the task has finished (or was skipped).
+ arrow::Future<bool> future = SafeCallIntoRAsync<bool>(
+ [&fun]() {
+ fun();
+ return true;
+ },
+ reason);
return future.status();
}
+// Performs an Arrow call (e.g., run an exec plan) in such a way that
background threads
+// can use SafeCallIntoR(). This version is useful for Arrow calls that already
+// return a Future<>.
template <typename T>
arrow::Result<T> RunWithCapturedR(std::function<arrow::Future<T>()>
make_arrow_call) {
-#if !defined(HAS_SAFE_CALL_INTO_R)
- return arrow::Status::NotImplemented("RunWithCapturedR() without
UnwindProtect");
-#else
+ if (!CanRunWithCapturedR()) {
+ return arrow::Status::NotImplemented(
+ "RunWithCapturedR() without UnwindProtect or on 32-bit Windows + R <=
3.6");
+ }
+
if (GetMainRThread().Executor() != nullptr) {
return arrow::Status::AlreadyExists("Attempt to use more than one R
Executor()");
}
@@ -158,7 +181,39 @@ arrow::Result<T>
RunWithCapturedR(std::function<arrow::Future<T>()> make_arrow_c
GetMainRThread().ClearError();
return result;
-#endif
+}
+
+// Runs `make_arrow_call` (e.g. reading a table or running an exec plan) so
+// that, where possible, background threads it starts can use SafeCallIntoR().
+// Use this for Arrow calls that return a Result<> rather than a Future<>.
+//
+// If CanRunWithCapturedR() is true, the call is submitted to a background
+// thread inside RunWithCapturedR(). Otherwise it runs directly on the main R
+// thread, and background threads that try to SafeCallIntoR() will error.
+template <typename T>
+arrow::Result<T> RunWithCapturedRIfPossible(
+    std::function<arrow::Result<T>()> make_arrow_call) {
+  if (!CanRunWithCapturedR()) {
+    return make_arrow_call();
+  }
+
+  // Any executor that moves the call off the main R thread would do; the
+  // default I/O context's executor is simply a convenient one.
+  const auto& io_context = arrow::io::default_io_context();
+  auto submit_in_background = [&]() {
+    return DeferNotOk(io_context.executor()->Submit(std::move(make_arrow_call)));
+  };
+  return RunWithCapturedR<T>(submit_in_background);
+}
+
+// Like RunWithCapturedRIfPossible<>() but for Arrow calls that return a plain
+// arrow::Status instead of a Result<>.
+static inline arrow::Status RunWithCapturedRIfPossibleVoid(
+    std::function<arrow::Status()> make_arrow_call) {
+  // Adapt the Status-returning call to the Result<>-based template; the bool
+  // payload is only a placeholder and is never read.
+  auto call_as_result = [&]() -> arrow::Result<bool> {
+    ARROW_RETURN_NOT_OK(make_arrow_call());
+    return true;
+  };
+  auto result = RunWithCapturedRIfPossible<bool>(call_as_result);
+  ARROW_RETURN_NOT_OK(result);
+  return arrow::Status::OK();
+}
#endif
diff --git a/r/tests/testthat/_snaps/compute.md
b/r/tests/testthat/_snaps/compute.md
new file mode 100644
index 0000000000..89506a7fbc
--- /dev/null
+++ b/r/tests/testthat/_snaps/compute.md
@@ -0,0 +1,4 @@
+# arrow_scalar_function() works
+
+ fun is not a function
+
diff --git a/r/tests/testthat/test-compute.R b/r/tests/testthat/test-compute.R
new file mode 100644
index 0000000000..946583ae00
--- /dev/null
+++ b/r/tests/testthat/test-compute.R
@@ -0,0 +1,305 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+test_that("list_compute_functions() works", {
+  fun_names <- list_compute_functions()
+  expect_type(fun_names, "character")
+  # no listed name may start with "hash_"
+  expect_false(any(grepl("^hash_", fun_names)))
+})
+
+
+test_that("arrow_scalar_function() works", {
+  # in_type is normalised to a list of schemas and out_type to a list of
+  # functions that return the output DataType
+
+  # check in/out type as schema/data type
+  fun <- arrow_scalar_function(
+    function(context, x) x$cast(int64()),
+    schema(x = int32()), int64()
+  )
+  expect_equal(fun$in_type[[1]], schema(x = int32()))
+  expect_equal(fun$out_type[[1]](), int64())
+
+  # check in/out type as data type/data type
+  fun <- arrow_scalar_function(
+    function(context, x) x$cast(int64()),
+    int32(), int64()
+  )
+  expect_equal(fun$in_type[[1]][[1]], field("", int32()))
+  expect_equal(fun$out_type[[1]](), int64())
+
+  # check in/out type as field/data type
+  fun <- arrow_scalar_function(
+    function(context, a_name) a_name$cast(int64()),
+    field("a_name", int32()),
+    int64()
+  )
+  expect_equal(fun$in_type[[1]], schema(a_name = int32()))
+  expect_equal(fun$out_type[[1]](), int64())
+
+  # check in/out type as lists
+  fun <- arrow_scalar_function(
+    function(context, x) x,
+    list(int32(), int64()),
+    list(int64(), int32()),
+    auto_convert = TRUE
+  )
+
+  expect_equal(fun$in_type[[1]][[1]], field("", int32()))
+  expect_equal(fun$in_type[[2]][[1]], field("", int64()))
+  expect_equal(fun$out_type[[1]](), int64())
+  expect_equal(fun$out_type[[2]](), int32())
+
+  expect_snapshot_error(arrow_scalar_function(NULL, int32(), int32()))
+})
+
+test_that("arrow_scalar_function() works with auto_convert = TRUE", {
+  wrapped <- arrow_scalar_function(
+    function(context, x) x * 32,
+    float64(),
+    float64(),
+    auto_convert = TRUE
+  )
+
+  # wrapper_fun() is called directly here, so an empty list stands in for
+  # the kernel context
+  result <- wrapped$wrapper_fun(list(), list(Scalar$create(2)))
+  expect_equal(result, Array$create(2 * 32))
+})
+
+test_that("register_scalar_function() adds a compute function to the
registry", {
+  skip_if_not(CanRunWithCapturedR())
+
+  register_scalar_function(
+    "times_32",
+    function(context, x) x * 32.0,
+    int32(), float64(),
+    auto_convert = TRUE
+  )
+  on.exit(unregister_binding("times_32", update_cache = TRUE))
+
+  # the function is visible in arrow's function cache and in the compute
+  # function registry
+  cached_functions <- names(asNamespace("arrow")$.cache$functions)
+  expect_true("times_32" %in% cached_functions)
+  expect_true("times_32" %in% list_compute_functions())
+
+  # it can be called directly on an Array or a Scalar...
+  expect_equal(
+    call_function("times_32", Array$create(1L, int32())),
+    Array$create(32L, float64())
+  )
+  expect_equal(
+    call_function("times_32", Scalar$create(1L, int32())),
+    Scalar$create(32L, float64())
+  )
+
+  # ...and from a dplyr query
+  collected <- record_batch(a = 1L) %>%
+    dplyr::mutate(b = times_32(a)) %>%
+    dplyr::collect()
+  expect_identical(collected, tibble::tibble(a = 1L, b = 32.0))
+})
+
+test_that("arrow_scalar_function() with bad return type errors", {
+  skip_if_not(CanRunWithCapturedR())
+
+  # Each on.exit() uses add = TRUE so that both functions are unregistered;
+  # without it the second handler would replace the first.
+  register_scalar_function(
+    "times_32_bad_return_type_array",
+    function(context, x) Array$create(x, int32()),
+    int32(),
+    float64()
+  )
+  on.exit(
+    unregister_binding("times_32_bad_return_type_array", update_cache = TRUE),
+    add = TRUE
+  )
+
+  expect_error(
+    call_function("times_32_bad_return_type_array", Array$create(1L)),
+    "Expected return Array or Scalar with type 'double'"
+  )
+
+  register_scalar_function(
+    "times_32_bad_return_type_scalar",
+    function(context, x) Scalar$create(x, int32()),
+    int32(),
+    float64()
+  )
+  on.exit(
+    unregister_binding("times_32_bad_return_type_scalar", update_cache = TRUE),
+    add = TRUE
+  )
+
+  expect_error(
+    call_function("times_32_bad_return_type_scalar", Array$create(1L)),
+    "Expected return Array or Scalar with type 'double'"
+  )
+})
+
+test_that("register_user_defined_function() can register multiple kernels", {
+  skip_if_not(CanRunWithCapturedR())
+
+  kernel_types <- list(int32(), int64(), float64())
+
+  # one kernel per input type; each returns the type of its first input
+  register_scalar_function(
+    "times_32",
+    function(context, x) x * 32L,
+    in_type = kernel_types,
+    out_type = function(in_types) in_types[[1]],
+    auto_convert = TRUE
+  )
+  on.exit(unregister_binding("times_32", update_cache = TRUE))
+
+  for (kernel_type in kernel_types) {
+    expect_equal(
+      call_function("times_32", Scalar$create(1L, kernel_type)),
+      Scalar$create(32L, kernel_type)
+    )
+  }
+})
+
+test_that("register_user_defined_function() errors for unsupported
specifications", {
+  # a function needs at least one kernel
+  expect_error(
+    register_scalar_function("no_kernels", function(...) NULL, list(), list()),
+    "Can't register user-defined scalar function with 0 kernels"
+  )
+
+  # `fun` takes the kernel context plus one argument per input type
+  expect_error(
+    register_scalar_function(
+      "wrong_n_args", function(x) NULL, int32(), int32()
+    ),
+    "Expected `fun` to accept 2 argument\\(s\\)"
+  )
+
+  # all kernels must take the same number of arguments
+  expect_error(
+    register_scalar_function(
+      "var_kernels",
+      function(...) NULL,
+      list(float64(), schema(x = float64(), y = float64())),
+      float64()
+    ),
+    "Kernels for user-defined function must accept the same number of
arguments"
+  )
+})
+
+test_that("user-defined functions work during multi-threaded execution", {
+  skip_if_not(CanRunWithCapturedR())
+  skip_if_not_available("dataset")
+
+  n_rows <- 10000
+  n_partitions <- 10
+  example_df <- expand.grid(
+    part = letters[seq_len(n_partitions)],
+    value = seq_len(n_rows),
+    stringsAsFactors = FALSE
+  )
+
+  # give each row an id so the original order can be restored after reading,
+  # and make sure values are different for each partition
+  example_df$row_num <- seq_len(nrow(example_df))
+  example_df$value <- example_df$value + match(example_df$part, letters)
+
+  tf_dataset <- tempfile()
+  tf_dest <- tempfile()
+  # write_dataset() creates directories, so remove them recursively; add = TRUE
+  # keeps this handler when the unregister handler is added below
+  on.exit(unlink(c(tf_dataset, tf_dest), recursive = TRUE), add = TRUE)
+  write_dataset(example_df, tf_dataset, partitioning = "part")
+
+  register_scalar_function(
+    "times_32",
+    function(context, x) x * 32.0,
+    int32(),
+    float64(),
+    auto_convert = TRUE
+  )
+  on.exit(unregister_binding("times_32", update_cache = TRUE), add = TRUE)
+
+  # check a regular collect()
+  result <- open_dataset(tf_dataset) %>%
+    dplyr::mutate(fun_result = times_32(value)) %>%
+    dplyr::collect() %>%
+    dplyr::arrange(row_num)
+
+  expect_identical(result$fun_result, example_df$value * 32)
+
+  # check a write_dataset()
+  open_dataset(tf_dataset) %>%
+    dplyr::mutate(fun_result = times_32(value)) %>%
+    write_dataset(tf_dest)
+
+  result2 <- open_dataset(tf_dest) %>%
+    dplyr::collect() %>%
+    dplyr::arrange(row_num)
+
+  expect_identical(result2$fun_result, example_df$value * 32)
+})
+
+test_that("user-defined error when called from an unsupported context", {
+  skip_if_not_available("dataset")
+  skip_if_not(CanRunWithCapturedR())
+
+  register_scalar_function(
+    "times_32",
+    function(context, x) x * 32.0,
+    int32(),
+    float64(),
+    auto_convert = TRUE
+  )
+  on.exit(unregister_binding("times_32", update_cache = TRUE))
+
+  # evaluates the UDF while streaming through a RecordBatchReader
+  stream_plan_with_udf <- function() {
+    record_batch(a = 1:1000) %>%
+      dplyr::mutate(b = times_32(a)) %>%
+      as_record_batch_reader() %>%
+      as_arrow_table()
+  }
+
+  # evaluates the UDF in a query limited with head()
+  collect_plan_with_head <- function() {
+    record_batch(a = 1:1000) %>%
+      dplyr::mutate(fun_result = times_32(a)) %>%
+      head(11) %>%
+      dplyr::collect()
+  }
+
+  # the expected outcome is platform dependent: both plans succeed on Windows
+  # and fail with the unsupported-context error elsewhere
+  on_windows <- identical(tolower(Sys.info()[["sysname"]]), "windows")
+  if (!on_windows) {
+    unsupported_msg <- "Call to R \\(.*?\\) from a non-R thread from an unsupported context"
+    expect_error(stream_plan_with_udf(), unsupported_msg)
+    expect_error(collect_plan_with_head(), unsupported_msg)
+  } else {
+    streamed <- stream_plan_with_udf()
+    expected <- record_batch(a = 1:1000) %>%
+      dplyr::mutate(b = times_32(a)) %>%
+      dplyr::collect(as_data_frame = FALSE)
+    expect_equal(streamed, expected)
+
+    head_result <- collect_plan_with_head()
+    expect_equal(nrow(head_result), 11)
+  }
+})
diff --git a/r/tests/testthat/test-csv.R b/r/tests/testthat/test-csv.R
index fca717cc05..d4878e6d67 100644
--- a/r/tests/testthat/test-csv.R
+++ b/r/tests/testthat/test-csv.R
@@ -293,9 +293,7 @@ test_that("more informative error when reading a CSV with headers and schema", {
})
test_that("read_csv_arrow() and write_csv_arrow() accept connection objects", {
- # connections with csv need RunWithCapturedR, which is not available
- # in R <= 3.4.4
- skip_on_r_older_than("3.5")
+ skip_if_not(CanRunWithCapturedR())
tf <- tempfile()
on.exit(unlink(tf))
diff --git a/r/tests/testthat/test-dplyr-funcs.R b/r/tests/testthat/test-dplyr-funcs.R
index 2156ad9af0..86f984dd32 100644
--- a/r/tests/testthat/test-dplyr-funcs.R
+++ b/r/tests/testthat/test-dplyr-funcs.R
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
-test_that("register_binding() works", {
+test_that("register_binding()/unregister_binding() works", {
fake_registry <- new.env(parent = emptyenv())
fun1 <- function() NULL
fun2 <- function() "Hello"
@@ -24,8 +24,9 @@ test_that("register_binding() works", {
expect_identical(fake_registry$some_fun, fun1)
expect_identical(fake_registry$`some.pkg::some_fun`, fun1)
- expect_identical(register_binding("some.pkg::some_fun", NULL,
fake_registry), fun1)
- expect_silent(expect_null(register_binding("some.pkg::some_fun", NULL,
fake_registry)))
+ expect_identical(unregister_binding("some.pkg::some_fun", fake_registry),
fun1)
+ expect_false("some.pkg::some_fun" %in% names(fake_registry))
+ expect_false("some_fun" %in% names(fake_registry))
expect_null(register_binding("somePkg::some_fun", fun1, fake_registry))
expect_identical(fake_registry$some_fun, fun1)
diff --git a/r/tests/testthat/test-extension.R b/r/tests/testthat/test-extension.R
index 638869dc8c..55a1f8d21e 100644
--- a/r/tests/testthat/test-extension.R
+++ b/r/tests/testthat/test-extension.R
@@ -312,6 +312,7 @@ test_that("Table can roundtrip extension types", {
test_that("Dataset/arrow_dplyr_query can roundtrip extension types", {
skip_if_not_available("dataset")
+ skip_if_not(CanRunWithCapturedR())
tf <- tempfile()
on.exit(unlink(tf, recursive = TRUE))
diff --git a/r/tests/testthat/test-feather.R b/r/tests/testthat/test-feather.R
index 2120f6ac72..1ef2ecf3e9 100644
--- a/r/tests/testthat/test-feather.R
+++ b/r/tests/testthat/test-feather.R
@@ -208,11 +208,7 @@ test_that("read_feather requires RandomAccessFile and errors nicely otherwise (A
})
test_that("read_feather() and write_feather() accept connection objects", {
- # connection object don't work on Windows i386 before R 4.0
- skip_if(on_old_windows())
- # connections with feather need RunWithCapturedR, which is not available
- # in R <= 3.4.4
- skip_on_r_older_than("3.5")
+ skip_if_not(CanRunWithCapturedR())
tf <- tempfile()
on.exit(unlink(tf))
diff --git a/r/tests/testthat/test-safe-call-into-r.R b/r/tests/testthat/test-safe-call-into-r.R
index a8027ac423..c07d90433f 100644
--- a/r/tests/testthat/test-safe-call-into-r.R
+++ b/r/tests/testthat/test-safe-call-into-r.R
@@ -32,7 +32,7 @@ test_that("SafeCallIntoR works from the main R thread", {
})
test_that("SafeCallIntoR works within RunWithCapturedR", {
- skip_on_r_older_than("3.5")
+ skip_if_not(CanRunWithCapturedR())
skip_on_cran()
expect_identical(
@@ -47,16 +47,16 @@ test_that("SafeCallIntoR works within RunWithCapturedR", {
})
test_that("SafeCallIntoR errors from the non-R thread", {
- skip_on_r_older_than("3.5")
+ skip_if_not(CanRunWithCapturedR())
skip_on_cran()
expect_error(
TestSafeCallIntoR(function() "string one!", opt =
"async_without_executor"),
- "Call to R from a non-R thread"
+ "Call to R \\(unspecified\\) from a non-R thread"
)
expect_error(
TestSafeCallIntoR(function() stop("an error!"), opt =
"async_without_executor"),
- "Call to R from a non-R thread"
+ "Call to R \\(unspecified\\) from a non-R thread"
)
})