This is an automated email from the ASF dual-hosted git repository.
npr pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new 0b020a1 ARROW-11683: [R] Support dplyr::mutate()
0b020a1 is described below
commit 0b020a13a9eab8b3c2a4e97720bebe88c06ab4e8
Author: Neal Richardson <[email protected]>
AuthorDate: Thu Feb 25 10:31:08 2021 -0800
ARROW-11683: [R] Support dplyr::mutate()
First steps:
* Rework `selected_columns` to hold field_refs instead of string column
names; add code to back out the string field names where needed (e.g. dataset
`Project()`)
* Create an `array_ref` pseudo-function to do the same as `field_ref` for
`array_expressions`
* Add a `data` argument to `eval_array_expression` in order to bind
`array_ref`s to the actual Arrays before evaluating
* Refactor `filter()` NSE code for reuse in `mutate()`
* Split up dplyr tests because we're going to be adding lots more
Then:
* Basic `mutate()` and `transmute()` (done in
https://github.com/apache/arrow/pull/9521/commits/578d4929264858916b94e8dc632123dfb85816d2)
* Go through the examples in the dplyr::mutate() docs and add tests for all
cases. Where possible they're implemented in arrow fully; where we don't
support the functions, it falls back to the current behavior of pulling the
data into R first.
Followup JIRAs:
* ARROW-11704: Wire up dplyr::mutate() for datasets
* ARROW-11699: Implement dplyr::across() and autosplicing
* ARROW-11700: Internationalize error handling in tidy eval
* ARROW-11701: Implement dplyr::relocate()
* ARROW-11702: Enable ungrouped aggregations in non-Dataset expressions
* ARROW-11658: Handle mutate/rename inside group_by
* ARROW-11705: Support scalar value recycling in RecordBatch/Table$create()
* ARROW-11754: Support dplyr::compute()
* ARROW-11752: Replace usage of testthat::expect_is()
* ARROW-11755: Add tests from dplyr/test-mutate.r
* ARROW-11785: Fallback when filtering Table with if_any() expression fails
Closes #9521 from nealrichardson/mutate
Authored-by: Neal Richardson <[email protected]>
Signed-off-by: Neal Richardson <[email protected]>
---
r/NEWS.md | 7 +
r/R/arrow-package.R | 2 +-
r/R/arrowExports.R | 4 +
r/R/dataset-scan.R | 10 +
r/R/dataset-write.R | 6 +-
r/R/dplyr.R | 287 +++++++++++++++++++++------
r/R/expression.R | 43 +++-
r/src/arrowExports.cpp | 9 +
r/src/expression.cpp | 7 +
r/tests/testthat/helper-expectation.R | 63 ++++++
r/tests/testthat/test-RecordBatch.R | 8 +
r/tests/testthat/test-dplyr-filter.R | 287 +++++++++++++++++++++++++++
r/tests/testthat/test-dplyr-mutate.R | 350 +++++++++++++++++++++++++++++++++
r/tests/testthat/test-dplyr.R | 356 +---------------------------------
r/tests/testthat/test-expression.R | 12 ++
15 files changed, 1030 insertions(+), 421 deletions(-)
diff --git a/r/NEWS.md b/r/NEWS.md
index 65c4e22..a008088 100644
--- a/r/NEWS.md
+++ b/r/NEWS.md
@@ -19,6 +19,13 @@
# arrow 3.0.0.9000
+## dplyr methods
+
+* `dplyr::mutate()` on Arrow `Table` and `RecordBatch` is now supported in
Arrow for many applications. Where not yet supported, the implementation falls
back to pulling data into an R `data.frame` first.
+* String functions `nchar()`, `tolower()`, and `toupper()`, along with their
`stringr` spellings `str_length()`, `str_to_lower()`, and `str_to_upper()`, are
supported in Arrow `dplyr` calls. `str_trim()` is also supported.
+
+## Other improvements
+
* `value_counts()` to tabulate values in an `Array` or `ChunkedArray`, similar
to `base::table()`.
* `StructArray` objects gain data.frame-like methods, including `names()`,
`$`, `[[`, and `dim()`.
* RecordBatch columns can now be added, replaced, or removed by assigning
(`<-`) with either `$` or `[[`
diff --git a/r/R/arrow-package.R b/r/R/arrow-package.R
index 66694a9..818d85c 100644
--- a/r/R/arrow-package.R
+++ b/r/R/arrow-package.R
@@ -30,7 +30,7 @@
"dplyr::",
c(
"select", "filter", "collect", "summarise", "group_by", "groups",
- "group_vars", "ungroup", "mutate", "arrange", "rename", "pull"
+ "group_vars", "ungroup", "mutate", "transmute", "arrange", "rename",
"pull"
)
)
for (cl in c("Dataset", "ArrowTabular", "arrow_dplyr_query")) {
diff --git a/r/R/arrowExports.R b/r/R/arrowExports.R
index 3d0f31c..790232c 100644
--- a/r/R/arrowExports.R
+++ b/r/R/arrowExports.R
@@ -744,6 +744,10 @@ dataset___expr__field_ref <- function(name){
.Call(`_arrow_dataset___expr__field_ref`, name)
}
+dataset___expr__get_field_ref_name <- function(ref){
+ .Call(`_arrow_dataset___expr__get_field_ref_name`, ref)
+}
+
dataset___expr__scalar <- function(x){
.Call(`_arrow_dataset___expr__scalar`, x)
}
diff --git a/r/R/dataset-scan.R b/r/R/dataset-scan.R
index 45fc968..ec6f85c 100644
--- a/r/R/dataset-scan.R
+++ b/r/R/dataset-scan.R
@@ -69,6 +69,10 @@ Scanner$create <- function(dataset,
batch_size = NULL,
...) {
if (inherits(dataset, "arrow_dplyr_query")) {
+ if (inherits(dataset$.data, "ArrowTabular")) {
+ # To handle mutate() on Table/RecordBatch, we need to
collect(as_data_frame=FALSE) now
+ dataset <- dplyr::collect(dataset, as_data_frame = FALSE)
+ }
return(Scanner$create(
dataset$.data,
dataset$selected_columns,
@@ -152,6 +156,12 @@ map_batches <- function(X, FUN, ..., .data.frame = TRUE) {
ScannerBuilder <- R6Class("ScannerBuilder", inherit = ArrowObject,
public = list(
Project = function(cols) {
+ # cols is either a character vector or a named list of Expressions
+ if (!is.character(cols)) {
+ # We don't yet support mutate() on datasets, so this is just a list
+ # of FieldRefs, and we need to back out the field names
+ cols <- get_field_names(cols)
+ }
assert_is(cols, "character")
dataset___ScannerBuilder__Project(self, cols)
self
diff --git a/r/R/dataset-write.R b/r/R/dataset-write.R
index c5c9292..5078bc3 100644
--- a/r/R/dataset-write.R
+++ b/r/R/dataset-write.R
@@ -62,8 +62,12 @@ write_dataset <- function(dataset,
hive_style = TRUE,
...) {
if (inherits(dataset, "arrow_dplyr_query")) {
+ if (inherits(dataset$.data, "ArrowTabular")) {
+ # collect() to materialize any mutate/rename
+ dataset <- dplyr::collect(dataset, as_data_frame = FALSE)
+ }
# We can select a subset of columns but we can't rename them
- if (!all(dataset$selected_columns == names(dataset$selected_columns))) {
+ if (!all(get_field_names(dataset) == names(dataset$selected_columns))) {
stop("Renaming columns when writing a dataset is not yet supported",
call. = FALSE)
}
# partitioning vars need to be in the `select` schema
diff --git a/r/R/dplyr.R b/r/R/dplyr.R
index 3271374..2bd8170 100644
--- a/r/R/dplyr.R
+++ b/r/R/dplyr.R
@@ -33,11 +33,11 @@ arrow_dplyr_query <- function(.data) {
structure(
list(
.data = .data$clone(),
- # selected_columns is a named character vector:
- # * vector contents are the names of the columns in the data
- # * vector names are the names they should be in the end (i.e. this
+ # selected_columns is a named list:
+ # * contents are references/expressions pointing to the data
+ # * names are the names they should be in the end (i.e. this
# records any renaming)
- selected_columns = set_names(names(.data)),
+ selected_columns = make_field_refs(names(.data), dataset =
inherits(.data, "Dataset")),
# filtered_rows will be an Expression
filtered_rows = TRUE,
# group_by_vars is a character vector of columns (as renamed)
@@ -51,8 +51,15 @@ arrow_dplyr_query <- function(.data) {
#' @export
print.arrow_dplyr_query <- function(x, ...) {
schm <- x$.data$schema
- cols <- x$selected_columns
- fields <- map_chr(cols, ~schm$GetFieldByName(.)$ToString())
+ cols <- get_field_names(x)
+ # If cols are expressions, they won't be in the schema and will be "" in cols
+ fields <- map_chr(cols, function(name) {
+ if (nzchar(name)) {
+ schm$GetFieldByName(name)$ToString()
+ } else {
+ "expr"
+ }
+ })
# Strip off the field names as they are in the dataset and add the renamed
ones
fields <- paste(names(cols), sub("^.*?: ", "", fields), sep = ": ", collapse
= "\n")
cat(class(x$.data)[1], " (query)\n", sep = "")
@@ -73,6 +80,33 @@ print.arrow_dplyr_query <- function(x, ...) {
invisible(x)
}
+get_field_names <- function(selected_cols) {
+ if (inherits(selected_cols, "arrow_dplyr_query")) {
+ selected_cols <- selected_cols$selected_columns
+ }
+ map_chr(selected_cols, function(x) {
+ if (inherits(x, "Expression")) {
+ out <- x$field_name
+ } else if (inherits(x, "array_expression")) {
+ out <- x$args$field_name
+ } else {
+ out <- NULL
+ }
+ # If x isn't some kind of field reference, out is NULL,
+ # but we always need to return a string
+ out %||% ""
+ })
+}
+
+make_field_refs <- function(field_names, dataset = TRUE) {
+ if (dataset) {
+ out <- lapply(field_names, Expression$field_ref)
+ } else {
+ out <- lapply(field_names, function(x) array_expression("array_ref",
field_name = x))
+ }
+ set_names(out, field_names)
+}
+
# These are the names reflecting all select/rename, not what is in Arrow
#' @export
names.arrow_dplyr_query <- function(x) names(x$selected_columns)
@@ -89,7 +123,7 @@ dim.arrow_dplyr_query <- function(x) {
rows <- NA_integer_
} else {
# Evaluate the filter expression to a BooleanArray and count
- rows <- as.integer(sum(eval_array_expression(x$filtered_rows), na.rm =
TRUE))
+ rows <- as.integer(sum(eval_array_expression(x$filtered_rows, x$.data),
na.rm = TRUE))
}
c(rows, cols)
}
@@ -187,29 +221,8 @@ filter.arrow_dplyr_query <- function(.data, ..., .preserve
= FALSE) {
}
.data <- arrow_dplyr_query(.data)
- # The filter() method works by evaluating the filters to generate Expressions
- # with references to Arrays (if .data is Table/RecordBatch) or Fields (if
- # .data is a Dataset).
- dm <- filter_mask(.data)
- filters <- lapply(filts, function (f) {
- # This should yield an Expression as long as the filter function(s) are
- # implemented in Arrow.
- tryCatch(eval_tidy(f, dm), error = function(e) {
- # Look for the cases where bad input was given, i.e. this would fail
- # in regular dplyr anyway, and let those raise those as errors;
- # else, for things not supported by Arrow return a "try-error",
- # which we'll handle differently
- msg <- conditionMessage(e)
- # TODO: internationalization?
- if (grepl("object '.*'.not.found", msg)) {
- stop(e)
- }
- if (grepl('could not find function ".*"', msg)) {
- stop(e)
- }
- invisible(structure(msg, class = "try-error", condition = e))
- })
- })
+ # tidy-eval the filter expressions inside an Arrow data_mask
+ filters <- lapply(filts, arrow_eval, arrow_mask(.data))
bad_filters <- map_lgl(filters, ~inherits(., "try-error"))
if (any(bad_filters)) {
bads <- oxford_paste(map_chr(filts, as_label)[bad_filters], quote = FALSE)
@@ -238,6 +251,30 @@ filter.arrow_dplyr_query <- function(.data, ..., .preserve
= FALSE) {
}
filter.Dataset <- filter.ArrowTabular <- filter.arrow_dplyr_query
+arrow_eval <- function (expr, mask) {
+ # filter(), mutate(), etc. work by evaluating the quoted `exprs` to generate
Expressions
+ # with references to Arrays (if .data is Table/RecordBatch) or Fields (if
+ # .data is a Dataset).
+
+ # This yields an Expression as long as the `exprs` are implemented in Arrow.
+ # Otherwise, it returns a try-error
+ tryCatch(eval_tidy(expr, mask), error = function(e) {
+ # Look for the cases where bad input was given, i.e. this would fail
+ # in regular dplyr anyway, and let those raise those as errors;
+ # else, for things not supported by Arrow return a "try-error",
+ # which we'll handle differently
+ msg <- conditionMessage(e)
+ # TODO(ARROW-11700): internationalization
+ if (grepl("object '.*'.not.found", msg)) {
+ stop(e)
+ }
+ if (grepl('could not find function ".*"', msg)) {
+ stop(e)
+ }
+ invisible(structure(msg, class = "try-error", condition = e))
+ })
+}
+
# Helper to assemble the functions that go in the NSE data mask
# The only difference between the Dataset and the Table/RecordBatch versions
# is that they use a different wrapping function (FUN) to hold the unevaluated
@@ -271,23 +308,32 @@ build_function_list <- function(FUN) {
dataset_function_list <- build_function_list(build_dataset_expression)
array_function_list <- build_function_list(build_array_expression)
-# Create a data mask for evaluating a filter expression
-filter_mask <- function(.data) {
+# Create a data mask for evaluating a dplyr expression
+arrow_mask <- function(.data) {
if (query_on_dataset(.data)) {
f_env <- new_environment(dataset_function_list)
- var_binder <- function(x) Expression$field_ref(x)
} else {
f_env <- new_environment(array_function_list)
- var_binder <- function(x) .data$.data[[x]]
}
- # Add the column references
- # Renaming is handled automatically by the named list
- data_pronoun <- lapply(.data$selected_columns, var_binder)
- env_bind(f_env, !!!data_pronoun)
- # Then bind the data pronoun
- env_bind(f_env, .data = data_pronoun)
- new_data_mask(f_env)
+ # Add functions that need to error hard and clear.
+ # Some R functions will still try to evaluate on an Expression
+ # and return NA with a warning
+ fail <- function(...) stop("Not implemented")
+ for (f in c("mean")) {
+ f_env[[f]] <- fail
+ }
+
+ # Add the column references and make the mask
+ out <- new_data_mask(
+ new_environment(.data$selected_columns, parent = f_env),
+ f_env
+ )
+ # Then insert the data pronoun
+ # TODO: figure out what rlang::as_data_pronoun does/why we should use it
+ # (because if we do we get `Error: Can't modify the data pronoun` in
mutate())
+ out$.data <- .data$selected_columns
+ out
}
set_filters <- function(.data, expressions) {
@@ -309,8 +355,27 @@ collect.arrow_dplyr_query <- function(x, as_data_frame =
TRUE, ...) {
# See dataset.R for Dataset and Scanner(Builder) classes
tab <- Scanner$create(x)$ToTable()
} else {
- # This is a Table/RecordBatch. See record-batch.R for the [ method
- tab <- x$.data[x$filtered_rows, x$selected_columns, keep_na = FALSE]
+ # This is a Table or RecordBatch
+
+ # Filter and select the data referenced in selected columns
+ if (isTRUE(x$filtered_rows)) {
+ filter <- TRUE
+ } else {
+ filter <- eval_array_expression(x$filtered_rows, x$.data)
+ }
+ # TODO: shortcut if identical(names(x$.data),
find_array_refs(x$selected_columns))?
+ tab <- x$.data[filter, find_array_refs(x$selected_columns), keep_na =
FALSE]
+ # Now evaluate those expressions on the filtered table
+ cols <- lapply(x$selected_columns, eval_array_expression, data = tab)
+ if (length(cols) == 0) {
+ tab <- tab[, integer(0)]
+ } else {
+ if (inherits(x$.data, "Table")) {
+ tab <- Table$create(!!!cols)
+ } else {
+ tab <- RecordBatch$create(!!!cols)
+ }
+ }
}
if (as_data_frame) {
df <- as.data.frame(tab)
@@ -327,7 +392,13 @@ ensure_group_vars <- function(x) {
if (inherits(x, "arrow_dplyr_query")) {
# Before pulling data from Arrow, make sure all group vars are in the
projection
gv <- set_names(setdiff(dplyr::group_vars(x), names(x)))
- x$selected_columns <- c(x$selected_columns, gv)
+ if (length(gv)) {
+ # Add them back
+ x$selected_columns <- c(
+ x$selected_columns,
+ make_field_refs(gv, dataset = query_on_dataset(.data))
+ )
+ }
}
x
}
@@ -337,21 +408,20 @@ restore_dplyr_features <- function(df, query) {
# After calling collect(), make sure these features are carried over
grouped <- length(query$group_by_vars) > 0
- renamed <- !identical(names(df), names(query))
- if (is.data.frame(df)) {
+ renamed <- ncol(df) && !identical(names(df), names(query))
+ if (renamed) {
# In case variables were renamed, apply those names
- if (renamed && ncol(df)) {
- names(df) <- names(query)
- }
+ names(df) <- names(query)
+ }
+ if (grouped) {
# Preserve groupings, if present
- if (grouped) {
+ if (is.data.frame(df)) {
df <- dplyr::grouped_df(df, dplyr::group_vars(query))
+ } else {
+ # This is a Table, via collect(as_data_frame = FALSE)
+ df <- arrow_dplyr_query(df)
+ df$group_by_vars <- query$group_by_vars
}
- } else if (grouped || renamed) {
- # This is a Table, via collect(as_data_frame = FALSE)
- df <- arrow_dplyr_query(df)
- names(df$selected_columns) <- names(query)
- df$group_by_vars <- query$group_by_vars
}
df
}
@@ -423,26 +493,117 @@ ungroup.arrow_dplyr_query <- function(x, ...) {
}
ungroup.Dataset <- ungroup.ArrowTabular <- force
-mutate.arrow_dplyr_query <- function(.data, ...) {
+mutate.arrow_dplyr_query <- function(.data,
+ ...,
+ .keep = c("all", "used", "unused",
"none"),
+ .before = NULL,
+ .after = NULL) {
+ call <- match.call()
+ exprs <- quos(...)
+ if (length(exprs) == 0) {
+ # Nothing to do
+ return(.data)
+ }
+
.data <- arrow_dplyr_query(.data)
if (query_on_dataset(.data)) {
not_implemented_for_dataset("mutate()")
}
- # TODO: see if we can defer evaluating the expressions and not collect here.
- # It's different from filters (as currently implemented) because the basic
- # vector transformation functions aren't yet implemented in Arrow C++.
- dplyr::mutate(dplyr::collect(.data), ...)
+
+ .keep <- match.arg(.keep)
+ .before <- enquo(.before)
+ .after <- enquo(.after)
+ # Restrict the cases we support for now
+ if (!quo_is_null(.before) || !quo_is_null(.after)) {
+ # TODO(ARROW-11701)
+ return(abandon_ship(call, .data, '.before and .after arguments are not
supported in Arrow'))
+ } else if (length(group_vars(.data)) > 0) {
+ # mutate() on a grouped dataset does calculations within groups
+ # This doesn't matter on scalar ops (arithmetic etc.) but it does
+ # for things with aggregations (e.g. subtracting the mean)
+ return(abandon_ship(call, .data, 'mutate() on grouped data not supported
in Arrow'))
+ }
+
+ # Check for unnamed expressions and fix if any
+ unnamed <- !nzchar(names(exprs))
+ # Deparse and take the first element in case they're long expressions
+ names(exprs)[unnamed] <- map_chr(exprs[unnamed], as_label)
+
+ mask <- arrow_mask(.data)
+ results <- list()
+ for (i in seq_along(exprs)) {
+ # Iterate over the indices and not the names because names may be repeated
+ # (which overwrites the previous name)
+ new_var <- names(exprs)[i]
+ results[[new_var]] <- arrow_eval(exprs[[i]], mask)
+ if (inherits(results[[new_var]], "try-error")) {
+ msg <- paste('Expression', as_label(exprs[[i]]), 'not supported in
Arrow')
+ return(abandon_ship(call, .data, msg))
+ }
+ # Put it in the data mask too
+ mask[[new_var]] <- mask$.data[[new_var]] <- results[[new_var]]
+ }
+
+ # Assign the new columns into the .data$selected_columns, respecting the
.keep param
+ if (.keep == "none") {
+ .data$selected_columns <- results
+ } else {
+ if (.keep != "all") {
+ # "used" or "unused"
+ used_vars <- unlist(lapply(exprs, all.vars), use.names = FALSE)
+ old_vars <- names(.data$selected_columns)
+ if (.keep == "used") {
+ .data$selected_columns <- .data$selected_columns[intersect(old_vars,
used_vars)]
+ } else {
+ # "unused"
+ .data$selected_columns <- .data$selected_columns[setdiff(old_vars,
used_vars)]
+ }
+ }
+ # Note that this is names(exprs) not names(results):
+ # if results$new_var is NULL, that means we are supposed to remove it
+ for (new_var in names(exprs)) {
+ .data$selected_columns[[new_var]] <- results[[new_var]]
+ }
+ }
+ # Even if "none", we still keep group vars
+ ensure_group_vars(.data)
}
mutate.Dataset <- mutate.ArrowTabular <- mutate.arrow_dplyr_query
-# TODO: add transmute() that does what summarise() does (select only the vars
we need)
+
+transmute.arrow_dplyr_query <- function(.data, ...) dplyr::mutate(.data, ...,
.keep = "none")
+transmute.Dataset <- transmute.ArrowTabular <- transmute.arrow_dplyr_query
+
+# Helper to handle unsupported dplyr features
+# * For Table/RecordBatch, we collect() and then call the dplyr method in R
+# * For Dataset, we just error
+abandon_ship <- function(call, .data, msg = NULL) {
+ dplyr_fun_name <- sub("^(.*?)\\..*", "\\1", as.character(call[[1]]))
+ if (query_on_dataset(.data)) {
+ if (is.null(msg)) {
+ # Default message: function not implemented
+ not_implemented_for_dataset(paste0(dplyr_fun_name, "()"))
+ } else {
+ stop(msg, call. = FALSE)
+ }
+ }
+
+ # else, collect and call dplyr method
+ if (!is.null(msg)) {
+ warning(msg, "; pulling data into R", immediate. = TRUE, call. = FALSE)
+ }
+ call$.data <- dplyr::collect(.data)
+ call[[1]] <- get(dplyr_fun_name, envir = asNamespace("dplyr"))
+ eval.parent(call, 2)
+}
arrange.arrow_dplyr_query <- function(.data, ...) {
.data <- arrow_dplyr_query(.data)
if (query_on_dataset(.data)) {
not_implemented_for_dataset("arrange()")
}
-
- dplyr::arrange(dplyr::collect(.data), ...)
+ # TODO(ARROW-11703) move this to Arrow
+ call <- match.call()
+ abandon_ship(call, .data)
}
arrange.Dataset <- arrange.ArrowTabular <- arrange.arrow_dplyr_query
diff --git a/r/R/expression.R b/r/R/expression.R
index 878b800..74c1aef 100644
--- a/r/R/expression.R
+++ b/r/R/expression.R
@@ -143,7 +143,14 @@ cast_array_expression <- function(x, to_type, safe = TRUE,
...) {
.array_function_map <- c(.unary_function_map, .binary_function_map)
-eval_array_expression <- function(x) {
+eval_array_expression <- function(x, data = NULL) {
+ if (!is.null(data)) {
+ x <- bind_array_refs(x, data)
+ }
+ if (!inherits(x, "array_expression")) {
+ # Nothing to evaluate
+ return(x)
+ }
x$args <- lapply(x$args, function (a) {
if (inherits(a, "array_expression")) {
eval_array_expression(a)
@@ -154,6 +161,27 @@ eval_array_expression <- function(x) {
call_function(x$fun, args = x$args, options = x$options %||%
empty_named_list())
}
+find_array_refs <- function(x) {
+ if (identical(x$fun, "array_ref")) {
+ out <- x$args$field_name
+ } else {
+ out <- lapply(x$args, find_array_refs)
+ }
+ unlist(out)
+}
+
+# Take an array_expression and replace array_refs with arrays/chunkedarrays
from data
+bind_array_refs <- function(x, data) {
+ if (inherits(x, "array_expression")) {
+ if (identical(x$fun, "array_ref")) {
+ x <- data[[x$args$field_name]]
+ } else {
+ x$args <- lapply(x$args, bind_array_refs, data)
+ }
+ }
+ x
+}
+
#' @export
is.na.array_expression <- function(x) array_expression("is.na", x)
@@ -181,9 +209,13 @@ print.array_expression <- function(x, ...) {
deparse(arg)
}
})
- # Prune this for readability
- function_name <- sub("_kleene", "", x$fun)
- paste0(function_name, "(", paste(printed_args, collapse = ", "), ")")
+ if (identical(x$fun, "array_ref")) {
+ x$args$field_name
+ } else {
+ # Prune this for readability
+ function_name <- sub("_kleene", "", x$fun)
+ paste0(function_name, "(", paste(printed_args, collapse = ", "), ")")
+ }
}
###########
@@ -217,6 +249,9 @@ Expression <- R6Class("Expression", inherit = ArrowObject,
)
Expression$create("cast", self, options = modifyList(opts, list(...)))
}
+ ),
+ active = list(
+ field_name = function() dataset___expr__get_field_ref_name(self)
)
)
Expression$create <- function(function_name,
diff --git a/r/src/arrowExports.cpp b/r/src/arrowExports.cpp
index 839c9d6..73ee648 100644
--- a/r/src/arrowExports.cpp
+++ b/r/src/arrowExports.cpp
@@ -1569,6 +1569,14 @@ BEGIN_CPP11
END_CPP11
}
// expression.cpp
+std::string dataset___expr__get_field_ref_name(const
std::shared_ptr<ds::Expression>& ref);
+extern "C" SEXP _arrow_dataset___expr__get_field_ref_name(SEXP ref_sexp){
+BEGIN_CPP11
+ arrow::r::Input<const std::shared_ptr<ds::Expression>&>::type
ref(ref_sexp);
+ return cpp11::as_sexp(dataset___expr__get_field_ref_name(ref));
+END_CPP11
+}
+// expression.cpp
std::shared_ptr<ds::Expression> dataset___expr__scalar(const
std::shared_ptr<arrow::Scalar>& x);
extern "C" SEXP _arrow_dataset___expr__scalar(SEXP x_sexp){
BEGIN_CPP11
@@ -3702,6 +3710,7 @@ static const R_CallMethodDef CallEntries[] = {
{ "_arrow_FixedSizeListType__list_size", (DL_FUNC)
&_arrow_FixedSizeListType__list_size, 1},
{ "_arrow_dataset___expr__call", (DL_FUNC)
&_arrow_dataset___expr__call, 3},
{ "_arrow_dataset___expr__field_ref", (DL_FUNC)
&_arrow_dataset___expr__field_ref, 1},
+ { "_arrow_dataset___expr__get_field_ref_name", (DL_FUNC)
&_arrow_dataset___expr__get_field_ref_name, 1},
{ "_arrow_dataset___expr__scalar", (DL_FUNC)
&_arrow_dataset___expr__scalar, 1},
{ "_arrow_dataset___expr__ToString", (DL_FUNC)
&_arrow_dataset___expr__ToString, 1},
{ "_arrow_ipc___WriteFeather__Table", (DL_FUNC)
&_arrow_ipc___WriteFeather__Table, 6},
diff --git a/r/src/expression.cpp b/r/src/expression.cpp
index ddb1e72..76d8222 100644
--- a/r/src/expression.cpp
+++ b/r/src/expression.cpp
@@ -48,6 +48,13 @@ std::shared_ptr<ds::Expression>
dataset___expr__field_ref(std::string name) {
}
// [[arrow::export]]
+std::string dataset___expr__get_field_ref_name(
+ const std::shared_ptr<ds::Expression>& ref) {
+ auto refname = ref->field_ref()->name();
+ return *refname;
+}
+
+// [[arrow::export]]
std::shared_ptr<ds::Expression> dataset___expr__scalar(
const std::shared_ptr<arrow::Scalar>& x) {
return std::make_shared<ds::Expression>(ds::literal(std::move(x)));
diff --git a/r/tests/testthat/helper-expectation.R
b/r/tests/testthat/helper-expectation.R
index ce0f9de..76edea6 100644
--- a/r/tests/testthat/helper-expectation.R
+++ b/r/tests/testthat/helper-expectation.R
@@ -59,3 +59,66 @@ verify_output <- function(...) {
}
testthat::verify_output(...)
}
+
+expect_dplyr_equal <- function(expr, # A dplyr pipeline with `input` as its
start
+ tbl, # A tbl/df as reference, will make
RB/Table with
+ skip_record_batch = NULL, # Msg, if should skip
RB test
+ skip_table = NULL, # Msg, if should skip
Table test
+ ...) {
+ expr <- rlang::enquo(expr)
+ expected <- rlang::eval_tidy(expr, rlang::new_data_mask(rlang::env(input =
tbl)))
+
+ skip_msg <- NULL
+
+ if (is.null(skip_record_batch)) {
+ via_batch <- rlang::eval_tidy(
+ expr,
+ rlang::new_data_mask(rlang::env(input = record_batch(tbl)))
+ )
+ expect_equivalent(via_batch, expected, ...)
+ } else {
+ skip_msg <- c(skip_msg, skip_record_batch)
+ }
+
+ if (is.null(skip_table)) {
+ via_table <- rlang::eval_tidy(
+ expr,
+ rlang::new_data_mask(rlang::env(input = Table$create(tbl)))
+ )
+ expect_equivalent(via_table, expected, ...)
+ } else {
+ skip_msg <- c(skip_msg, skip_table)
+ }
+
+ if (!is.null(skip_msg)) {
+ skip(paste(skip_msg, collpase = "\n"))
+ }
+}
+
+expect_dplyr_error <- function(expr, # A dplyr pipeline with `input` as its
start
+ tbl, # A tbl/df as reference, will make
RB/Table with
+ ...) {
+ expr <- rlang::enquo(expr)
+ msg <- tryCatch(
+ rlang::eval_tidy(expr, rlang::new_data_mask(rlang::env(input = tbl))),
+ error = function (e) conditionMessage(e)
+ )
+ expect_is(msg, "character", label = "dplyr on data.frame did not error")
+
+ expect_error(
+ rlang::eval_tidy(
+ expr,
+ rlang::new_data_mask(rlang::env(input = record_batch(tbl)))
+ ),
+ msg,
+ ...
+ )
+ expect_error(
+ rlang::eval_tidy(
+ expr,
+ rlang::new_data_mask(rlang::env(input = Table$create(tbl)))
+ ),
+ msg,
+ ...
+ )
+}
\ No newline at end of file
diff --git a/r/tests/testthat/test-RecordBatch.R
b/r/tests/testthat/test-RecordBatch.R
index aeee66d..a017823 100644
--- a/r/tests/testthat/test-RecordBatch.R
+++ b/r/tests/testthat/test-RecordBatch.R
@@ -416,6 +416,14 @@ test_that("record_batch() handles null type (ARROW-7064)",
{
expect_equivalent(batch$schema, schema(a = int32(), n = null()))
})
+test_that("record_batch() scalar recycling", {
+ skip("Not implemented (ARROW-11705)")
+ expect_data_frame(
+ record_batch(a = 1:10, b = 5),
+ tibble::tibble(a = 1:10, b = 5)
+ )
+})
+
test_that("RecordBatch$Equals", {
df <- tibble::tibble(x = 1:10, y = letters[1:10])
a <- record_batch(df)
diff --git a/r/tests/testthat/test-dplyr-filter.R
b/r/tests/testthat/test-dplyr-filter.R
new file mode 100644
index 0000000..f735894
--- /dev/null
+++ b/r/tests/testthat/test-dplyr-filter.R
@@ -0,0 +1,287 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+library(dplyr)
+library(stringr)
+
+tbl <- example_data
+# Add some better string data
+tbl$verses <- verses[[1]]
+# c(" a ", " b ", " c ", ...) increasing padding
+# nchar = 3 5 7 9 11 13 15 17 19 21
+tbl$padded_strings <- stringr::str_pad(letters[1:10], width = 2*(1:10)+1, side
= "both")
+
+test_that("filter() on is.na()", {
+ expect_dplyr_equal(
+ input %>%
+ filter(is.na(lgl)) %>%
+ select(chr, int, lgl) %>%
+ collect(),
+ tbl
+ )
+})
+
+test_that("filter() with NAs in selection", {
+ expect_dplyr_equal(
+ input %>%
+ filter(lgl) %>%
+ select(chr, int, lgl) %>%
+ collect(),
+ tbl
+ )
+})
+
+test_that("Filter returning an empty Table should not segfault (ARROW-8354)", {
+ expect_dplyr_equal(
+ input %>%
+ filter(false) %>%
+ select(chr, int, lgl) %>%
+ collect(),
+ tbl
+ )
+})
+
+test_that("filtering with expression", {
+ char_sym <- "b"
+ expect_dplyr_equal(
+ input %>%
+ filter(chr == char_sym) %>%
+ select(string = chr, int) %>%
+ collect(),
+ tbl
+ )
+})
+
+test_that("filtering with arithmetic", {
+ expect_dplyr_equal(
+ input %>%
+ filter(dbl + 1 > 3) %>%
+ select(string = chr, int, dbl) %>%
+ collect(),
+ tbl
+ )
+
+ expect_dplyr_equal(
+ input %>%
+ filter(dbl / 2 > 3) %>%
+ select(string = chr, int, dbl) %>%
+ collect(),
+ tbl
+ )
+
+ expect_dplyr_equal(
+ input %>%
+ filter(dbl / 2L > 3) %>%
+ select(string = chr, int, dbl) %>%
+ collect(),
+ tbl
+ )
+
+ expect_dplyr_equal(
+ input %>%
+ filter(int / 2 > 3) %>%
+ select(string = chr, int, dbl) %>%
+ collect(),
+ tbl
+ )
+
+ expect_dplyr_equal(
+ input %>%
+ filter(int / 2L > 3) %>%
+ select(string = chr, int, dbl) %>%
+ collect(),
+ tbl
+ )
+
+ expect_dplyr_equal(
+ input %>%
+ filter(dbl %/% 2 > 3) %>%
+ select(string = chr, int, dbl) %>%
+ collect(),
+ tbl
+ )
+})
+
+test_that("filtering with expression + autocasting", {
+ expect_dplyr_equal(
+ input %>%
+ filter(dbl + 1 > 3L) %>% # test autocasting with comparison to 3L
+ select(string = chr, int, dbl) %>%
+ collect(),
+ tbl
+ )
+
+ expect_dplyr_equal(
+ input %>%
+ filter(int + 1 > 3) %>%
+ select(string = chr, int, dbl) %>%
+ collect(),
+ tbl
+ )
+})
+
+test_that("More complex select/filter", {
+ expect_dplyr_equal(
+ input %>%
+ filter(dbl > 2, chr == "d" | chr == "f") %>%
+ select(chr, int, lgl) %>%
+ filter(int < 5) %>%
+ select(int, chr) %>%
+ collect(),
+ tbl
+ )
+})
+
+test_that("filter() with %in%", {
+ expect_dplyr_equal(
+ input %>%
+ filter(dbl > 2, chr %in% c("d", "f")) %>%
+ collect(),
+ tbl
+ )
+})
+
+test_that("filter() with string ops", {
+ # Extra instrumentation to ensure that we're calling Arrow compute here
+ # because many base R string functions implicitly call as.character,
+ # which means they still work on Arrays but actually force data into R
+ # 1) wrapper that raises a warning if as.character is called. Can't wrap
+ # the whole test because as.character apparently gets called in other
+ # (presumably legitimate) places
+ # 2) Wrap the test in expect_warning(expr, NA) to catch the warning
+
+ with_no_as_character <- function(expr) {
+ trace(
+ "as.character",
+ tracer = quote(warning("as.character was called")),
+ print = FALSE,
+ where = toupper
+ )
+ on.exit(untrace("as.character", where = toupper))
+ force(expr)
+ }
+
+ expect_warning(
+ expect_dplyr_equal(
+ input %>%
+ filter(dbl > 2, with_no_as_character(toupper(chr)) %in% c("D", "F"))
%>%
+ collect(),
+ tbl
+ ),
+ NA)
+
+ expect_dplyr_equal(
+ input %>%
+ filter(dbl > 2, str_length(verses) > 25) %>%
+ collect(),
+ tbl
+ )
+
+ expect_dplyr_equal(
+ input %>%
+ filter(dbl > 2, str_length(str_trim(padded_strings, "left")) > 5) %>%
+ collect(),
+ tbl
+ )
+})
+
+test_that("filter environment scope", {
+ # "object 'b_var' not found"
+ expect_dplyr_error(input %>% filter(batch, chr == b_var))
+
+ b_var <- "b"
+ expect_dplyr_equal(
+ input %>%
+ filter(chr == b_var) %>%
+ collect(),
+ tbl
+ )
+ # Also for functions
+ # 'could not find function "isEqualTo"' because we haven't defined it yet
+ expect_dplyr_error(filter(batch, isEqualTo(int, 4)))
+
+ skip("Need to substitute in user defined function too")
+ # TODO: fix this: this isEqualTo function is eagerly evaluating; it should
+ # instead yield array_expressions. Probably bc the parent env of the function
+ # has the Ops.Array methods defined; we need to move it so that the parent
+ # env is the data mask we use in the dplyr eval
+ isEqualTo <- function(x, y) x == y & !is.na(x)
+ expect_dplyr_equal(
+ input %>%
+ select(-fct) %>% # factor levels aren't identical
+ filter(isEqualTo(int, 4)) %>%
+ collect(),
+ tbl
+ )
+})
+
+test_that("Filtering on a column that doesn't exist errors correctly", {
+ skip("Error handling in arrow_eval() needs to be internationalized
(ARROW-11700)")
+ expect_error(
+ batch %>% filter(not_a_col == 42) %>% collect(),
+ "object 'not_a_col' not found"
+ )
+})
+
+test_that("Filtering with a function that doesn't have an Array/expr method
still works", {
+ expect_warning(
+ expect_dplyr_equal(
+ input %>%
+ filter(int > 2, pnorm(dbl) > .99) %>%
+ collect(),
+ tbl
+ ),
+ 'Filter expression not implemented in Arrow: pnorm(dbl) > 0.99; pulling
data into R',
+ fixed = TRUE
+ )
+})
+
+test_that("filter() with .data pronoun", {
+ expect_dplyr_equal(
+ input %>%
+ filter(.data$dbl > 4) %>%
+ select(.data$chr, .data$int, .data$lgl) %>%
+ collect(),
+ tbl
+ )
+
+ expect_dplyr_equal(
+ input %>%
+ filter(is.na(.data$lgl)) %>%
+ select(.data$chr, .data$int, .data$lgl) %>%
+ collect(),
+ tbl
+ )
+
+ # and the .env pronoun too!
+ chr <- 4
+ expect_dplyr_equal(
+ input %>%
+ filter(.data$dbl > .env$chr) %>%
+ select(.data$chr, .data$int, .data$lgl) %>%
+ collect(),
+ tbl
+ )
+
+ # but there is an error if we don't override the masking with `.env`
+ expect_dplyr_error(
+ tbl %>%
+ filter(.data$dbl > chr) %>%
+ select(.data$chr, .data$int, .data$lgl) %>%
+ collect()
+ )
+})
diff --git a/r/tests/testthat/test-dplyr-mutate.R
b/r/tests/testthat/test-dplyr-mutate.R
new file mode 100644
index 0000000..56d7e36
--- /dev/null
+++ b/r/tests/testthat/test-dplyr-mutate.R
@@ -0,0 +1,350 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+library(dplyr)
+library(stringr)
+
+tbl <- example_data
+# Add some better string data
+tbl$verses <- verses[[1]]
+# c(" a ", " b ", " c ", ...) increasing padding
+# nchar = 3 5 7 9 11 13 15 17 19 21
+tbl$padded_strings <- stringr::str_pad(letters[1:10], width = 2*(1:10)+1, side
= "both")
+
+test_that("mutate() is lazy", {
+ expect_is(
+ tbl %>% record_batch() %>% mutate(int = int + 6L),
+ "arrow_dplyr_query"
+ )
+})
+
+test_that("basic mutate", {
+ expect_dplyr_equal(
+ input %>%
+ select(int, chr) %>%
+ filter(int > 5) %>%
+ mutate(int = int + 6L) %>%
+ collect(),
+ tbl
+ )
+})
+
+test_that("transmute", {
+ expect_dplyr_equal(
+ input %>%
+ select(int, chr) %>%
+ filter(int > 5) %>%
+ transmute(int = int + 6L) %>%
+ collect(),
+ tbl
+ )
+})
+
+test_that("mutate and refer to previous mutants", {
+ expect_dplyr_equal(
+ input %>%
+ select(int, padded_strings) %>%
+ mutate(
+ line_lengths = nchar(padded_strings),
+ longer = line_lengths * 10
+ ) %>%
+ filter(line_lengths > 15) %>%
+ collect(),
+ tbl
+ )
+})
+
+test_that("mutate with .data pronoun", {
+ expect_dplyr_equal(
+ input %>%
+ select(int, padded_strings) %>%
+ mutate(
+ line_lengths = nchar(padded_strings),
+ longer = .data$line_lengths * 10
+ ) %>%
+ filter(line_lengths > 15) %>%
+ collect(),
+ tbl
+ )
+})
+
+test_that("mutate with unnamed expressions", {
+ expect_dplyr_equal(
+ input %>%
+ select(int, padded_strings) %>%
+ mutate(
+ int, # bare column name
+ nchar(padded_strings) # expression
+ ) %>%
+ filter(int > 5) %>%
+ collect(),
+ tbl
+ )
+})
+
+test_that("mutate with reassigning same name", {
+ expect_dplyr_equal(
+ input %>%
+ transmute(
+ new = lgl,
+ new = chr
+ ) %>%
+ collect(),
+ tbl
+ )
+})
+
+test_that("mutate with single value for recycling", {
+  skip("Not implemented (ARROW-11705)")
+ expect_dplyr_equal(
+ input %>%
+ select(int, padded_strings) %>%
+ mutate(
+ dr_bronner = 1 # ALL ONE!
+ ) %>%
+ collect(),
+ tbl
+ )
+})
+
+test_that("dplyr::mutate's examples", {
+ # Newly created variables are available immediately
+ expect_dplyr_equal(
+ input %>%
+ select(name, mass) %>%
+ mutate(
+ mass2 = mass * 2,
+ mass2_squared = mass2 * mass2
+ ) %>%
+ collect(),
+ starwars # this is a test dataset that ships with dplyr
+ )
+
+ # As well as adding new variables, you can use mutate() to
+ # remove variables and modify existing variables.
+ expect_dplyr_equal(
+ input %>%
+ select(name, height, mass, homeworld) %>%
+ mutate(
+ mass = NULL,
+ height = height * 0.0328084 # convert to feet
+ ) %>%
+ collect(),
+ starwars
+ )
+
+ # Examples we don't support should succeed
+ # but warn that they're pulling data into R to do so
+
+ # across + autosplicing: ARROW-11699
+ expect_warning(
+ expect_dplyr_equal(
+ input %>%
+ select(name, homeworld, species) %>%
+ mutate(across(!name, as.factor)) %>%
+ collect(),
+ starwars
+ ),
+ "Expression across.*not supported in Arrow"
+ )
+
+ # group_by then mutate
+ expect_warning(
+ expect_dplyr_equal(
+ input %>%
+ select(name, mass, homeworld) %>%
+ group_by(homeworld) %>%
+ mutate(rank = min_rank(desc(mass))) %>%
+ collect(),
+ starwars
+ ),
+ "not supported in Arrow"
+ )
+
+ # `.before` and `.after` experimental args: ARROW-11701
+ df <- tibble(x = 1, y = 2)
+ expect_dplyr_equal(
+ input %>% mutate(z = x + y) %>% collect(),
+ df
+ )
+ #> # A tibble: 1 x 3
+ #> x y z
+ #> <dbl> <dbl> <dbl>
+ #> 1 1 2 3
+ expect_warning(
+ expect_dplyr_equal(
+ input %>% mutate(z = x + y, .before = 1) %>% collect(),
+ df
+ ),
+ "not supported in Arrow"
+ )
+ #> # A tibble: 1 x 3
+ #> z x y
+ #> <dbl> <dbl> <dbl>
+ #> 1 3 1 2
+ expect_warning(
+ expect_dplyr_equal(
+ input %>% mutate(z = x + y, .after = x) %>% collect(),
+ df
+ ),
+ "not supported in Arrow"
+ )
+ #> # A tibble: 1 x 3
+ #> x z y
+ #> <dbl> <dbl> <dbl>
+ #> 1 1 3 2
+
+ # By default, mutate() keeps all columns from the input data.
+ # Experimental: You can override with `.keep`
+ df <- tibble(x = 1, y = 2, a = "a", b = "b")
+ expect_dplyr_equal(
+ input %>% mutate(z = x + y, .keep = "all") %>% collect(), # the default
+ df
+ )
+ #> # A tibble: 1 x 5
+ #> x y a b z
+ #> <dbl> <dbl> <chr> <chr> <dbl>
+ #> 1 1 2 a b 3
+ expect_dplyr_equal(
+ input %>% mutate(z = x + y, .keep = "used") %>% collect(),
+ df
+ )
+ #> # A tibble: 1 x 3
+ #> x y z
+ #> <dbl> <dbl> <dbl>
+ #> 1 1 2 3
+ expect_dplyr_equal(
+ input %>% mutate(z = x + y, .keep = "unused") %>% collect(),
+ df
+ )
+ #> # A tibble: 1 x 3
+ #> a b z
+ #> <chr> <chr> <dbl>
+ #> 1 a b 3
+ expect_dplyr_equal(
+ input %>% mutate(z = x + y, .keep = "none") %>% collect(), # same as
transmute()
+ df
+ )
+ #> # A tibble: 1 x 1
+ #> z
+ #> <dbl>
+ #> 1 3
+
+ # Grouping ----------------------------------------
+ # The mutate operation may yield different results on grouped
+ # tibbles because the expressions are computed within groups.
+ # The following normalises `mass` by the global average:
+ # TODO(ARROW-11702)
+ expect_warning(
+ expect_dplyr_equal(
+ input %>%
+ select(name, mass, species) %>%
+ mutate(mass_norm = mass / mean(mass, na.rm = TRUE)) %>%
+ collect(),
+ starwars
+ ),
+ "not supported in Arrow"
+ )
+})
+
+test_that("handle bad expressions", {
+ # TODO: search for functions other than mean() (see above test)
+ # that need to be forced to fail because they error ambiguously
+
+ skip("Error handling in arrow_eval() needs to be internationalized
(ARROW-11700)")
+ expect_error(
+ Table$create(tbl) %>% mutate(newvar = NOTAVAR + 2),
+ "object 'NOTAVAR' not found"
+ )
+})
+
+test_that("print a mutated dataset", {
+ expect_output(
+ Table$create(tbl) %>%
+ select(int) %>%
+ mutate(twice = int * 2) %>%
+ print(),
+'Table (query)
+int: int32
+twice: expr
+
+See $.data for the source Arrow object',
+ fixed = TRUE)
+
+ # Handling non-expressions/edge cases
+ expect_output(
+ Table$create(tbl) %>%
+ select(int) %>%
+ mutate(again = 1:10) %>%
+ print(),
+'Table (query)
+int: int32
+again: expr
+
+See $.data for the source Arrow object',
+ fixed = TRUE)
+})
+
+test_that("mutate and write_dataset", {
+ # See related test in test-dataset.R
+
+ skip_on_os("windows") # https://issues.apache.org/jira/browse/ARROW-9651
+
+ first_date <- lubridate::ymd_hms("2015-04-29 03:12:39")
+ df1 <- tibble(
+ int = 1:10,
+ dbl = as.numeric(1:10),
+ lgl = rep(c(TRUE, FALSE, NA, TRUE, FALSE), 2),
+ chr = letters[1:10],
+ fct = factor(LETTERS[1:10]),
+ ts = first_date + lubridate::days(1:10)
+ )
+
+ second_date <- lubridate::ymd_hms("2017-03-09 07:01:02")
+ df2 <- tibble(
+ int = 101:110,
+ dbl = c(as.numeric(51:59), NaN),
+ lgl = rep(c(TRUE, FALSE, NA, TRUE, FALSE), 2),
+ chr = letters[10:1],
+ fct = factor(LETTERS[10:1]),
+ ts = second_date + lubridate::days(10:1)
+ )
+
+ dst_dir <- tempfile()
+ stacked <- record_batch(rbind(df1, df2))
+ stacked %>%
+ mutate(twice = int * 2) %>%
+ group_by(int) %>%
+ write_dataset(dst_dir, format = "feather")
+ expect_true(dir.exists(dst_dir))
+ expect_identical(dir(dst_dir), sort(paste("int", c(1:10, 101:110), sep =
"=")))
+
+ new_ds <- open_dataset(dst_dir, format = "feather")
+
+ expect_equivalent(
+ new_ds %>%
+ select(string = chr, integer = int, twice) %>%
+ filter(integer > 6 & integer < 11) %>%
+ collect() %>%
+ summarize(mean = mean(integer)),
+ df1 %>%
+ select(string = chr, integer = int) %>%
+ mutate(twice = integer * 2) %>%
+ filter(integer > 6) %>%
+ summarize(mean = mean(integer))
+ )
+})
\ No newline at end of file
diff --git a/r/tests/testthat/test-dplyr.R b/r/tests/testthat/test-dplyr.R
index 6d9945a..13610f1 100644
--- a/r/tests/testthat/test-dplyr.R
+++ b/r/tests/testthat/test-dplyr.R
@@ -15,74 +15,9 @@
# specific language governing permissions and limitations
# under the License.
-context("dplyr verbs")
-
library(dplyr)
library(stringr)
-expect_dplyr_equal <- function(expr, # A dplyr pipeline with `input` as its
start
- tbl, # A tbl/df as reference, will make
RB/Table with
- skip_record_batch = NULL, # Msg, if should skip
RB test
- skip_table = NULL, # Msg, if should skip
Table test
- ...) {
- expr <- rlang::enquo(expr)
- expected <- rlang::eval_tidy(expr, rlang::new_data_mask(rlang::env(input =
tbl)))
-
- skip_msg <- NULL
-
- if (is.null(skip_record_batch)) {
- via_batch <- rlang::eval_tidy(
- expr,
- rlang::new_data_mask(rlang::env(input = record_batch(tbl)))
- )
- expect_equivalent(via_batch, expected, ...)
- } else {
- skip_msg <- c(skip_msg, skip_record_batch)
- }
-
- if (is.null(skip_table)) {
- via_table <- rlang::eval_tidy(
- expr,
- rlang::new_data_mask(rlang::env(input = Table$create(tbl)))
- )
- expect_equivalent(via_table, expected, ...)
- } else {
- skip_msg <- c(skip_msg, skip_table)
- }
-
- if (!is.null(skip_msg)) {
- skip(paste(skip_msg, collpase = "\n"))
- }
-}
-
-expect_dplyr_error <- function(expr, # A dplyr pipeline with `input` as its
start
- tbl, # A tbl/df as reference, will make
RB/Table with
- ...) {
- expr <- rlang::enquo(expr)
- msg <- tryCatch(
- rlang::eval_tidy(expr, rlang::new_data_mask(rlang::env(input = tbl))),
- error = function (e) conditionMessage(e)
- )
- expect_is(msg, "character", label = "dplyr on data.frame did not error")
-
- expect_error(
- rlang::eval_tidy(
- expr,
- rlang::new_data_mask(rlang::env(input = record_batch(tbl)))
- ),
- msg,
- ...
- )
- expect_error(
- rlang::eval_tidy(
- expr,
- rlang::new_data_mask(rlang::env(input = Table$create(tbl)))
- ),
- msg,
- ...
- )
-}
-
tbl <- example_data
# Add some better string data
tbl$verses <- verses[[1]]
@@ -104,127 +39,6 @@ test_that("basic select/filter/collect", {
expect_identical(collect(batch), tbl)
})
-test_that("filter() on is.na()", {
- expect_dplyr_equal(
- input %>%
- filter(is.na(lgl)) %>%
- select(chr, int, lgl) %>%
- collect(),
- tbl
- )
-})
-
-test_that("filter() with NAs in selection", {
- expect_dplyr_equal(
- input %>%
- filter(lgl) %>%
- select(chr, int, lgl) %>%
- collect(),
- tbl
- )
-})
-
-test_that("Filter returning an empty Table should not segfault (ARROW-8354)", {
- expect_dplyr_equal(
- input %>%
- filter(false) %>%
- select(chr, int, lgl) %>%
- collect(),
- tbl
- )
-})
-
-test_that("filtering with expression", {
- char_sym <- "b"
- expect_dplyr_equal(
- input %>%
- filter(chr == char_sym) %>%
- select(string = chr, int) %>%
- collect(),
- tbl
- )
-})
-
-test_that("filtering with arithmetic", {
- expect_dplyr_equal(
- input %>%
- filter(dbl + 1 > 3) %>%
- select(string = chr, int, dbl) %>%
- collect(),
- tbl
- )
-
- expect_dplyr_equal(
- input %>%
- filter(dbl / 2 > 3) %>%
- select(string = chr, int, dbl) %>%
- collect(),
- tbl
- )
-
- expect_dplyr_equal(
- input %>%
- filter(dbl / 2L > 3) %>%
- select(string = chr, int, dbl) %>%
- collect(),
- tbl
- )
-
- expect_dplyr_equal(
- input %>%
- filter(int / 2 > 3) %>%
- select(string = chr, int, dbl) %>%
- collect(),
- tbl
- )
-
- expect_dplyr_equal(
- input %>%
- filter(int / 2L > 3) %>%
- select(string = chr, int, dbl) %>%
- collect(),
- tbl
- )
-
- expect_dplyr_equal(
- input %>%
- filter(dbl %/% 2 > 3) %>%
- select(string = chr, int, dbl) %>%
- collect(),
- tbl
- )
-})
-
-test_that("filtering with expression + autocasting", {
- expect_dplyr_equal(
- input %>%
- filter(dbl + 1 > 3L) %>% # test autocasting with comparison to 3L
- select(string = chr, int, dbl) %>%
- collect(),
- tbl
- )
-
- expect_dplyr_equal(
- input %>%
- filter(int + 1 > 3) %>%
- select(string = chr, int, dbl) %>%
- collect(),
- tbl
- )
-})
-
-test_that("More complex select/filter", {
- expect_dplyr_equal(
- input %>%
- filter(dbl > 2, chr == "d" | chr == "f") %>%
- select(chr, int, lgl) %>%
- filter(int < 5) %>%
- select(int, chr) %>%
- collect(),
- tbl
- )
-})
-
test_that("dim() on query", {
expect_dplyr_equal(
input %>%
@@ -247,151 +61,12 @@ test_that("Print method", {
int: int32
chr: string
-* Filter: and(and(greater(<Array>, 2), or(equal(<Array>, "d"), equal(<Array>,
"f"))), less(<Array>, 5))
+* Filter: and(and(greater(dbl, 2), or(equal(chr, "d"), equal(chr, "f"))),
less(int, 5))
See $.data for the source Arrow object',
fixed = TRUE
)
})
-test_that("filter() with %in%", {
- expect_dplyr_equal(
- input %>%
- filter(dbl > 2, chr %in% c("d", "f")) %>%
- collect(),
- tbl
- )
-})
-
-test_that("filter() with string ops", {
- # Extra instrumentation to ensure that we're calling Arrow compute here
- # because many base R string functions implicitly call as.character,
- # which means they still work on Arrays but actually force data into R
- # 1) wrapper that raises a warning if as.character is called. Can't wrap
- # the whole test because as.character apparently gets called in other
- # (presumably legitimate) places
- # 2) Wrap the test in expect_warning(expr, NA) to catch the warning
-
- with_no_as_character <- function(expr) {
- trace(
- "as.character",
- tracer = quote(warning("as.character was called")),
- print = FALSE,
- where = toupper
- )
- on.exit(untrace("as.character", where = toupper))
- force(expr)
- }
-
- expect_warning(
- expect_dplyr_equal(
- input %>%
- filter(dbl > 2, with_no_as_character(toupper(chr)) %in% c("D", "F"))
%>%
- collect(),
- tbl
- ),
- NA)
-
- expect_dplyr_equal(
- input %>%
- filter(dbl > 2, str_length(verses) > 25) %>%
- collect(),
- tbl
- )
-
- expect_dplyr_equal(
- input %>%
- filter(dbl > 2, str_length(str_trim(padded_strings, "left")) > 5) %>%
- collect(),
- tbl
- )
-})
-
-test_that("filter environment scope", {
- # "object 'b_var' not found"
- expect_dplyr_error(input %>% filter(batch, chr == b_var))
-
- b_var <- "b"
- expect_dplyr_equal(
- input %>%
- filter(chr == b_var) %>%
- collect(),
- tbl
- )
- # Also for functions
- # 'could not find function "isEqualTo"'
- expect_dplyr_error(filter(batch, isEqualTo(int, 4)))
-
- # TODO: fix this: this isEqualTo function is eagerly evaluating; it should
- # instead yield array_expressions. Probably bc the parent env of the function
- # has the Ops.Array methods defined; we need to move it so that the parent
- # env is the data mask we use in the dplyr eval
- isEqualTo <- function(x, y) x == y & !is.na(x)
- expect_dplyr_equal(
- input %>%
- select(-fct) %>% # factor levels aren't identical
- filter(isEqualTo(int, 4)) %>%
- collect(),
- tbl
- )
-})
-
-test_that("Filtering on a column that doesn't exist errors correctly", {
- skip("Error handling in filter() needs to be internationalized")
- expect_error(
- batch %>% filter(not_a_col == 42) %>% collect(),
- "object 'not_a_col' not found"
- )
-})
-
-test_that("Filtering with a function that doesn't have an Array/expr method
still works", {
- expect_warning(
- expect_dplyr_equal(
- input %>%
- filter(int > 2, pnorm(dbl) > .99) %>%
- collect(),
- tbl
- ),
- 'Filter expression not implemented in Arrow: pnorm(dbl) > 0.99; pulling
data into R',
- fixed = TRUE
- )
-})
-
-test_that("filter() with .data pronoun", {
- expect_dplyr_equal(
- input %>%
- filter(.data$dbl > 4) %>%
- select(.data$chr, .data$int, .data$lgl) %>%
- collect(),
- tbl
- )
-
- expect_dplyr_equal(
- input %>%
- filter(is.na(.data$lgl)) %>%
- select(.data$chr, .data$int, .data$lgl) %>%
- collect(),
- tbl
- )
-
- # and the .env pronoun too!
- chr <- 4
- expect_dplyr_equal(
- input %>%
- filter(.data$dbl > .env$chr) %>%
- select(.data$chr, .data$int, .data$lgl) %>%
- collect(),
- tbl
- )
-
- # but there is an error if we don't override the masking with `.env`
- expect_dplyr_error(
- tbl %>%
- filter(.data$dbl > chr) %>%
- select(.data$chr, .data$int, .data$lgl) %>%
- collect()
- )
-})
-
test_that("summarize", {
expect_dplyr_equal(
input %>%
@@ -410,29 +85,6 @@ test_that("summarize", {
)
})
-test_that("mutate", {
- expect_dplyr_equal(
- input %>%
- select(int, chr) %>%
- filter(int > 5) %>%
- mutate(int = int + 6L) %>%
- summarize(min_int = min(int)),
- tbl
- )
-})
-
-test_that("transmute", {
- skip("TODO: reimplement transmute (with dplyr 1.0, it no longer just works
via mutate)")
- expect_dplyr_equal(
- input %>%
- select(int, chr) %>%
- filter(int > 5) %>%
- transmute(int = int + 6L) %>%
- summarize(min_int = min(int)),
- tbl
- )
-})
-
test_that("group_by groupings are recorded", {
expect_dplyr_equal(
input %>%
@@ -599,7 +251,7 @@ test_that("collect(as_data_frame=FALSE)", {
select(int, strng = chr) %>%
filter(int > 5) %>%
collect(as_data_frame = FALSE)
- expect_is(b3, "arrow_dplyr_query")
+ expect_is(b3, "RecordBatch")
expect_equal(as.data.frame(b3), set_names(expected, c("int", "strng")))
b4 <- batch %>%
@@ -632,7 +284,7 @@ test_that("head", {
select(int, strng = chr) %>%
filter(int > 5) %>%
head(2)
- expect_is(b3, "arrow_dplyr_query")
+ expect_is(b3, "RecordBatch")
expect_equal(as.data.frame(b3), set_names(expected, c("int", "strng")))
b4 <- batch %>%
@@ -665,7 +317,7 @@ test_that("tail", {
select(int, strng = chr) %>%
filter(int > 5) %>%
tail(2)
- expect_is(b3, "arrow_dplyr_query")
+ expect_is(b3, "RecordBatch")
expect_equal(as.data.frame(b3), set_names(expected, c("int", "strng")))
b4 <- batch %>%
diff --git a/r/tests/testthat/test-expression.R
b/r/tests/testthat/test-expression.R
index 3c10081..3df7270 100644
--- a/r/tests/testthat/test-expression.R
+++ b/r/tests/testthat/test-expression.R
@@ -34,8 +34,20 @@ test_that("array_expression print method", {
)
})
+test_that("array_refs", {
+ tab <- Table$create(a = 1:5)
+ ex <- build_array_expression(">", array_expression("array_ref", field_name =
"a"), 4)
+ expect_is(ex, "array_expression")
+ expect_identical(ex$args[[1]]$args$field_name, "a")
+ expect_identical(find_array_refs(ex), "a")
+ out <- eval_array_expression(ex, tab)
+ expect_is(out, "ChunkedArray")
+ expect_equal(as.vector(out), c(FALSE, FALSE, FALSE, FALSE, TRUE))
+})
+
test_that("C++ expressions", {
f <- Expression$field_ref("f")
+ expect_identical(f$field_name, "f")
g <- Expression$field_ref("g")
date <- Expression$scalar(as.Date("2020-01-15"))
ts <- Expression$scalar(as.POSIXct("2020-01-17 11:11:11"))