This is an automated email from the ASF dual-hosted git repository.
npr pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new cc875e7 ARROW-6982: [R] Add bindings for compare and boolean kernels
cc875e7 is described below
commit cc875e712bec83046db6dad114e96c9111beb77d
Author: Neal Richardson <[email protected]>
AuthorDate: Wed Jul 22 11:36:08 2020 -0700
ARROW-6982: [R] Add bindings for compare and boolean kernels
The scope of this has grown to something larger than the description. In
addition to adding bindings to boolean kernels, it also changes how the dplyr
filter expressions are generated and evaluated for RecordBatch and Table.
Previously, any R function could be used to `filter()` because evaluation
happened in R by calling `as.vector` on any Arrays referenced. Now, `filter()`
translates R function names to Arrow function names, and evaluation passes the
function and arguments to `call_function` [...]
In addition to these improvements, the patch includes some extra
validation, testing, and print method upgrades.
There are a number of less-than-ideal design choices in here. Some are
related to https://issues.apache.org/jira/browse/ARROW-9001 because we have to
track/make a guess as to whether the result of `call_function` should be an
Array, ChunkedArray, etc.
There's also a bit of duplication here between the two Arrow expression
classes: this R-specific parse tree of array/compute expressions and the other
Dataset filter expressions. I think that's unavoidable at this time, but we
should, and I expect we will, rationalize this in the near future.
Closes #7668 from nealrichardson/r-kernels
Authored-by: Neal Richardson <[email protected]>
Signed-off-by: Neal Richardson <[email protected]>
---
r/R/array.R | 10 +-
r/R/arrowExports.R | 4 -
r/R/chunked-array.R | 2 +-
r/R/compute.R | 14 +-
r/R/dplyr.R | 11 +-
r/R/expression.R | 143 +++++++++++++++++++--
r/R/record-batch.R | 1 +
r/src/array.cpp | 16 ---
r/src/arrowExports.cpp | 16 ---
r/tests/testthat/test-Array.R | 12 +-
r/tests/testthat/test-chunked-array.R | 10 +-
.../{test-compute.R => test-compute-aggregate.R} | 7 +-
r/tests/testthat/test-compute-vector.R | 121 +++++++++++++++++
r/tests/testthat/test-dplyr.R | 22 ++++
r/tests/testthat/test-expression.R | 17 +--
15 files changed, 322 insertions(+), 84 deletions(-)
diff --git a/r/R/array.R b/r/R/array.R
index 061c421..8c9a29b 100644
--- a/r/R/array.R
+++ b/r/R/array.R
@@ -270,13 +270,7 @@ FixedSizeListArray <- R6Class("FixedSizeListArray",
inherit = Array,
length.Array <- function(x) x$length()
#' @export
-is.na.Array <- function(x) {
- if (x$type == null()) {
- rep(TRUE, length(x))
- } else {
- !Array__Mask(x)
- }
-}
+is.na.Array <- function(x) shared_ptr(Array, call_function("is_null", x))
#' @export
as.vector.Array <- function(x, mode) x$as_vector()
@@ -287,7 +281,7 @@ filter_rows <- function(x, i, keep_na = TRUE, ...) {
nrows <- x$num_rows %||% x$length() # Depends on whether Array or Table-like
if (inherits(i, "array_expression")) {
# Evaluate it
- i <- as.vector(i)
+ i <- eval_array_expression(i)
}
if (is.logical(i)) {
if (isTRUE(i)) {
diff --git a/r/R/arrowExports.R b/r/R/arrowExports.R
index 5a2c952..a98a6cb 100644
--- a/r/R/arrowExports.R
+++ b/r/R/arrowExports.R
@@ -60,10 +60,6 @@ Array__View <- function(array, type){
.Call(`_arrow_Array__View` , array, type)
}
-Array__Mask <- function(array){
- .Call(`_arrow_Array__Mask` , array)
-}
-
Array__Validate <- function(array){
invisible(.Call(`_arrow_Array__Validate` , array))
}
diff --git a/r/R/chunked-array.R b/r/R/chunked-array.R
index 5592d54..d2475eb 100644
--- a/r/R/chunked-array.R
+++ b/r/R/chunked-array.R
@@ -128,7 +128,7 @@ length.ChunkedArray <- function(x) x$length()
as.vector.ChunkedArray <- function(x, mode) x$as_vector()
#' @export
-is.na.ChunkedArray <- function(x) unlist(lapply(x$chunks, is.na))
+is.na.ChunkedArray <- function(x) shared_ptr(ChunkedArray,
call_function("is_null", x))
#' @export
`[.ChunkedArray` <- filter_rows
diff --git a/r/R/compute.R b/r/R/compute.R
index f242a58..60a1a46 100644
--- a/r/R/compute.R
+++ b/r/R/compute.R
@@ -19,9 +19,19 @@
#' @include chunked-array.R
#' @include scalar.R
-call_function <- function(function_name, ..., options = list()) {
+call_function <- function(function_name, ..., args = list(...), options =
empty_named_list()) {
assert_that(is.string(function_name))
- compute__CallFunction(function_name, list(...), options)
+ assert_that(is.list(options), !is.null(names(options)))
+
+ datum_classes <- c("Array", "ChunkedArray", "RecordBatch", "Table", "Scalar")
+ valid_args <- map_lgl(args, ~inherits(., datum_classes))
+ if (!all(valid_args)) {
+ # Lame, just pick one to report
+ first_bad <- min(which(!valid_args))
+ stop("Argument ", first_bad, " is of class ",
head(class(args[[first_bad]]), 1), " but it must be one of ",
oxford_paste(datum_classes, "or"), call. = FALSE)
+ }
+
+ compute__CallFunction(function_name, args, options)
}
#' @export
diff --git a/r/R/dplyr.R b/r/R/dplyr.R
index bf5d3c6..4d0cd5f 100644
--- a/r/R/dplyr.R
+++ b/r/R/dplyr.R
@@ -59,7 +59,12 @@ print.arrow_dplyr_query <- function(x, ...) {
cat(fields, "\n", sep = "")
cat("\n")
if (!isTRUE(x$filtered_rows)) {
- cat("* Filter: ", x$filtered_rows$ToString(), "\n", sep = "")
+ if (query_on_dataset(x)) {
+ filter_string <- x$filtered_rows$ToString()
+ } else {
+ filter_string <- .format_array_expression(x$filtered_rows)
+ }
+ cat("* Filter: ", filter_string, "\n", sep = "")
}
if (length(x$group_by_vars)) {
cat("* Grouped by ", paste(x$group_by_vars, collapse = ", "), "\n", sep =
"")
@@ -202,13 +207,13 @@ filter_mask <- function(.data) {
} else {
comp_func <- function(operator) {
force(operator)
- function(e1, e2) array_expression(operator, e1, e2)
+ function(e1, e2) build_array_expression(operator, e1, e2)
}
var_binder <- function(x) .data$.data[[x]]
}
# First add the functions
- func_names <- set_names(c(names(comparison_function_map), "&", "|", "%in%"))
+ func_names <- set_names(names(.array_function_map))
env_bind(f_env, !!!lapply(func_names, comp_func))
# Then add the column references
# Renaming is handled automatically by the named list
diff --git a/r/R/expression.R b/r/R/expression.R
index 338e152..092c7ca 100644
--- a/r/R/expression.R
+++ b/r/R/expression.R
@@ -17,30 +17,148 @@
#' @include arrowExports.R
-array_expression <- function(FUN, ...) {
- structure(list(fun = FUN, args = list(...)), class = "array_expression")
+array_expression <- function(FUN,
+ ...,
+ args = list(...),
+ options = empty_named_list(),
+ result_class = .guess_result_class(args[[1]])) {
+ structure(
+ list(
+ fun = FUN,
+ args = args,
+ options = options,
+ result_class = result_class
+ ),
+ class = "array_expression"
+ )
}
#' @export
-Ops.Array <- function(e1, e2) array_expression(.Generic, e1, e2)
+Ops.Array <- function(e1, e2) {
+ if (.Generic %in% names(.array_function_map)) {
+ expr <- build_array_expression(.Generic, e1, e2, result_class = "Array")
+ eval_array_expression(expr)
+ } else {
+ stop("Unsupported operation on Array: ", .Generic, call. = FALSE)
+ }
+}
#' @export
-Ops.ChunkedArray <- Ops.Array
+Ops.ChunkedArray <- function(e1, e2) {
+ if (.Generic %in% names(.array_function_map)) {
+ expr <- build_array_expression(.Generic, e1, e2, result_class =
"ChunkedArray")
+ eval_array_expression(expr)
+ } else {
+ stop("Unsupported operation on ChunkedArray: ", .Generic, call. = FALSE)
+ }
+}
#' @export
-Ops.array_expression <- Ops.Array
+Ops.array_expression <- function(e1, e2) {
+ if (.Generic == "!") {
+ build_array_expression(.Generic, e1, result_class = e1$result_class)
+ } else {
+ build_array_expression(.Generic, e1, e2, result_class = e1$result_class)
+ }
+}
+
+build_array_expression <- function(.Generic, e1, e2, ...) {
+ if (.Generic %in% names(.unary_function_map)) {
+ expr <- array_expression(.unary_function_map[[.Generic]], e1)
+ } else {
+ e1 <- .wrap_arrow(e1, .Generic, e2$type)
+ e2 <- .wrap_arrow(e2, .Generic, e1$type)
+ expr <- array_expression(.binary_function_map[[.Generic]], e1, e2, ...)
+ }
+ expr
+}
+
+.wrap_arrow <- function(arg, fun, type) {
+ if (!inherits(arg, c("ArrowObject", "array_expression"))) {
+ # TODO: Array$create if lengths are equal?
+ # TODO: these kernels should autocast like the dataset ones do (e.g. int
vs. float)
+ if (fun == "%in%") {
+ arg <- Array$create(arg, type = type)
+ } else {
+ arg <- Scalar$create(arg, type = type)
+ }
+ }
+ arg
+}
+
+.unary_function_map <- list(
+ "!" = "invert",
+ "is.na" = "is_null"
+)
+
+.binary_function_map <- list(
+ "==" = "equal",
+ "!=" = "not_equal",
+ ">" = "greater",
+ ">=" = "greater_equal",
+ "<" = "less",
+ "<=" = "less_equal",
+ "&" = "and_kleene",
+ "|" = "or_kleene",
+ "%in%" = "is_in_meta_binary"
+)
+
+.array_function_map <- c(.unary_function_map, .binary_function_map)
+
+.guess_result_class <- function(arg) {
+ # HACK HACK HACK delete this when call_function returns an ArrowObject itself
+ if (inherits(arg, "ArrowObject")) {
+ return(class(arg)[1])
+ } else if (inherits(arg, "array_expression")) {
+ return(arg$result_class)
+ } else {
+ stop("Not implemented")
+ }
+}
+
+eval_array_expression <- function(x) {
+ x$args <- lapply(x$args, function (a) {
+ if (inherits(a, "array_expression")) {
+ eval_array_expression(a)
+ } else {
+ a
+ }
+ })
+ ptr <- call_function(x$fun, args = x$args, options = x$options %||%
empty_named_list())
+ shared_ptr(get(x$result_class), ptr)
+}
#' @export
is.na.array_expression <- function(x) array_expression("is.na", x)
#' @export
as.vector.array_expression <- function(x, ...) {
- x$args <- lapply(x$args, as.vector)
- do.call(x$fun, x$args)
+ as.vector(eval_array_expression(x))
}
#' @export
-print.array_expression <- function(x, ...) print(as.vector(x))
+print.array_expression <- function(x, ...) {
+ cat(.format_array_expression(x), "\n", sep = "")
+ invisible(x)
+}
+
+.format_array_expression <- function(x) {
+ printed_args <- map_chr(x$args, function(arg) {
+ if (inherits(arg, "Scalar")) {
+ deparse(as.vector(arg))
+ } else if (inherits(arg, "ArrowObject")) {
+ paste0("<", class(arg)[1], ">")
+ } else if (inherits(arg, "array_expression")) {
+ .format_array_expression(arg)
+ } else {
+ # Should not happen
+ deparse(arg)
+ }
+ })
+ # Prune this for readability
+ function_name <- sub("_kleene", "", x$fun)
+ paste0(function_name, "(", paste(printed_args, collapse = ", "), ")")
+}
###########
@@ -130,6 +248,15 @@ make_expression <- function(operator, e1, e2) {
# In doesn't take Scalar, it takes Array
return(Expression$in_(e1, e2))
}
+
+ # Handle unary functions before touching e2
+ if (operator == "is.na") {
+ return(is.na(e1))
+ }
+ if (operator == "!") {
+ return(Expression$not(e1))
+ }
+
# Check for non-expressions and convert to Expressions
if (!inherits(e1, "Expression")) {
e1 <- Expression$scalar(e1)
diff --git a/r/R/record-batch.R b/r/R/record-batch.R
index cc68348..712000a 100644
--- a/r/R/record-batch.R
+++ b/r/R/record-batch.R
@@ -120,6 +120,7 @@ RecordBatch <- R6Class("RecordBatch", inherit = ArrowObject,
if (is.logical(i)) {
i <- Array$create(i)
}
+ assert_that(is.Array(i, "bool"))
shared_ptr(RecordBatch, call_function("filter", self, i, options =
list(keep_na = keep_na)))
},
serialize = function() ipc___SerializeRecordBatch__Raw(self),
diff --git a/r/src/array.cpp b/r/src/array.cpp
index 1ebb157..5879dc9 100644
--- a/r/src/array.cpp
+++ b/r/src/array.cpp
@@ -150,22 +150,6 @@ std::shared_ptr<arrow::Array> Array__View(const
std::shared_ptr<arrow::Array>& a
}
// [[arrow::export]]
-LogicalVector Array__Mask(const std::shared_ptr<arrow::Array>& array) {
- if (array->null_count() == 0) {
- return LogicalVector(array->length(), true);
- }
-
- auto n = array->length();
- LogicalVector res(no_init(n));
- arrow::internal::BitmapReader bitmap_reader(array->null_bitmap()->data(),
- array->offset(), n);
- for (int64_t i = 0; i < n; i++, bitmap_reader.Next()) {
- res[i] = bitmap_reader.IsSet();
- }
- return res;
-}
-
-// [[arrow::export]]
void Array__Validate(const std::shared_ptr<arrow::Array>& array) {
StopIfNotOk(array->Validate());
}
diff --git a/r/src/arrowExports.cpp b/r/src/arrowExports.cpp
index 570f126..9d0058b 100644
--- a/r/src/arrowExports.cpp
+++ b/r/src/arrowExports.cpp
@@ -243,21 +243,6 @@ RcppExport SEXP _arrow_Array__View(SEXP array_sexp, SEXP
type_sexp){
// array.cpp
#if defined(ARROW_R_WITH_ARROW)
-LogicalVector Array__Mask(const std::shared_ptr<arrow::Array>& array);
-RcppExport SEXP _arrow_Array__Mask(SEXP array_sexp){
-BEGIN_RCPP
- Rcpp::traits::input_parameter<const
std::shared_ptr<arrow::Array>&>::type array(array_sexp);
- return Rcpp::wrap(Array__Mask(array));
-END_RCPP
-}
-#else
-RcppExport SEXP _arrow_Array__Mask(SEXP array_sexp){
- Rf_error("Cannot call Array__Mask(). Please use arrow::install_arrow()
to install required runtime libraries. ");
-}
-#endif
-
-// array.cpp
-#if defined(ARROW_R_WITH_ARROW)
void Array__Validate(const std::shared_ptr<arrow::Array>& array);
RcppExport SEXP _arrow_Array__Validate(SEXP array_sexp){
BEGIN_RCPP
@@ -5940,7 +5925,6 @@ static const R_CallMethodDef CallEntries[] = {
{ "_arrow_Array__data", (DL_FUNC) &_arrow_Array__data, 1},
{ "_arrow_Array__RangeEquals", (DL_FUNC)
&_arrow_Array__RangeEquals, 5},
{ "_arrow_Array__View", (DL_FUNC) &_arrow_Array__View, 2},
- { "_arrow_Array__Mask", (DL_FUNC) &_arrow_Array__Mask, 1},
{ "_arrow_Array__Validate", (DL_FUNC) &_arrow_Array__Validate,
1},
{ "_arrow_DictionaryArray__indices", (DL_FUNC)
&_arrow_DictionaryArray__indices, 1},
{ "_arrow_DictionaryArray__dictionary", (DL_FUNC)
&_arrow_DictionaryArray__dictionary, 1},
diff --git a/r/tests/testthat/test-Array.R b/r/tests/testthat/test-Array.R
index d60fea4..ce1b5be 100644
--- a/r/tests/testthat/test-Array.R
+++ b/r/tests/testthat/test-Array.R
@@ -25,7 +25,7 @@ expect_array_roundtrip <- function(x, type, as = NULL) {
# TODO: revisit how missingness works with ListArrays
# R list objects don't handle missingness the same way as other vectors.
# Is there some vctrs thing we should do on the roundtrip back to R?
- expect_identical(is.na(a), is.na(x))
+ expect_equal(as.vector(is.na(a)), is.na(x))
}
expect_equivalent(as.vector(a), x)
# Make sure the storage mode is the same on roundtrip (esp. integer vs.
numeric)
@@ -37,7 +37,7 @@ expect_array_roundtrip <- function(x, type, as = NULL) {
expect_type_equal(a_sliced$type, type)
expect_identical(length(a_sliced), length(x_sliced))
if (!inherits(type, c("ListType", "LargeListType"))) {
- expect_identical(is.na(a_sliced), is.na(x_sliced))
+ expect_equal(as.vector(is.na(a_sliced)), is.na(x_sliced))
}
expect_equivalent(as.vector(a_sliced), x_sliced)
}
@@ -182,8 +182,8 @@ test_that("Array supports NA", {
expect_true(x_int$IsNull(10L))
expect_true(x_dbl$IsNull(10))
- expect_equal(is.na(x_int), c(rep(FALSE, 10), TRUE))
- expect_equal(is.na(x_dbl), c(rep(FALSE, 10), TRUE))
+ expect_equal(as.vector(is.na(x_int)), c(rep(FALSE, 10), TRUE))
+ expect_equal(as.vector(is.na(x_dbl)), c(rep(FALSE, 10), TRUE))
# Input validation
expect_error(x_int$IsValid("ten"), class = "Rcpp::not_compatible")
@@ -354,7 +354,7 @@ test_that("integer types casts (ARROW-3741)", {
for (type in c(int_types, uint_types)) {
casted <- a$cast(type)
expect_equal(casted$type, type)
- expect_identical(is.na(casted), c(rep(FALSE, 10), TRUE))
+ expect_identical(as.vector(is.na(casted)), c(rep(FALSE, 10), TRUE))
}
})
@@ -372,7 +372,7 @@ test_that("float types casts (ARROW-3741)", {
for (type in float_types) {
casted <- a$cast(type)
expect_equal(casted$type, type)
- expect_identical(is.na(casted), c(rep(FALSE, 3), TRUE))
+ expect_identical(as.vector(is.na(casted)), c(rep(FALSE, 3), TRUE))
expect_identical(as.vector(casted), x)
}
})
diff --git a/r/tests/testthat/test-chunked-array.R
b/r/tests/testthat/test-chunked-array.R
index 75f27aa..b4695e2 100644
--- a/r/tests/testthat/test-chunked-array.R
+++ b/r/tests/testthat/test-chunked-array.R
@@ -28,7 +28,7 @@ expect_chunked_roundtrip <- function(x, type) {
# TODO: revisit how missingness works with ListArrays
# R list objects don't handle missingness the same way as other vectors.
# Is there some vctrs thing we should do on the roundtrip back to R?
- expect_identical(is.na(a), is.na(flat_x))
+ expect_identical(as.vector(is.na(a)), is.na(flat_x))
}
expect_equal(as.vector(a), flat_x)
expect_equal(as.vector(a$chunk(0)), x[[1]])
@@ -39,7 +39,7 @@ expect_chunked_roundtrip <- function(x, type) {
expect_type_equal(a_sliced$type, type)
expect_identical(length(a_sliced), length(x_sliced))
if (!inherits(type, "ListType")) {
- expect_identical(is.na(a_sliced), is.na(x_sliced))
+ expect_identical(as.vector(is.na(a_sliced)), is.na(x_sliced))
}
expect_equal(as.vector(a_sliced), x_sliced)
}
@@ -117,10 +117,8 @@ test_that("ChunkedArray handles NA", {
expect_equal(as.vector(x), c(1:10, c(NA, 2:10), c(1:3, NA, 5)))
chunks <- x$chunks
- expect_equal(is.na(chunks[[1]]), is.na(data[[1]]))
- expect_equal(is.na(chunks[[2]]), is.na(data[[2]]))
- expect_equal(is.na(chunks[[3]]), is.na(data[[3]]))
- expect_equal(is.na(x), c(is.na(data[[1]]), is.na(data[[2]]),
is.na(data[[3]])))
+ expect_equal(as.vector(is.na(chunks[[2]])), is.na(data[[2]]))
+ expect_equal(as.vector(is.na(x)), c(is.na(data[[1]]), is.na(data[[2]]),
is.na(data[[3]])))
})
test_that("ChunkedArray supports logical vectors (ARROW-3341)", {
diff --git a/r/tests/testthat/test-compute.R
b/r/tests/testthat/test-compute-aggregate.R
similarity index 94%
rename from r/tests/testthat/test-compute.R
rename to r/tests/testthat/test-compute-aggregate.R
index 811c27d..1e5f9a4 100644
--- a/r/tests/testthat/test-compute.R
+++ b/r/tests/testthat/test-compute-aggregate.R
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
-context("compute")
+context("compute: aggregation")
test_that("sum.Array", {
ints <- 1:5
@@ -94,7 +94,10 @@ test_that("mean.Scalar", {
})
test_that("Bad input handling of call_function", {
- expect_error(call_function("sum", 2, 3), "to_datum: Not implemented for type
double")
+ expect_error(
+ call_function("sum", 2, 3),
+ 'Argument 1 is of class numeric but it must be one of "Array",
"ChunkedArray", "RecordBatch", "Table", or "Scalar"'
+ )
})
test_that("min/max.Array", {
diff --git a/r/tests/testthat/test-compute-vector.R
b/r/tests/testthat/test-compute-vector.R
new file mode 100644
index 0000000..b9097b6
--- /dev/null
+++ b/r/tests/testthat/test-compute-vector.R
@@ -0,0 +1,121 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+context("compute: vector operations")
+
+expect_bool_function_equal <- function(array_exp, r_exp, class = "Array") {
+ # Assert that the Array operation returns a boolean array
+ # and that its contents are equal to expected
+ expect_is(array_exp, class)
+ expect_type_equal(array_exp, bool())
+ expect_identical(as.vector(array_exp), r_exp)
+}
+
+expect_array_compares <- function(r_values, compared_to, Class = Array) {
+ a <- Class$create(r_values)
+ # Iterate over all comparison functions
+ expect_bool_function_equal(a == compared_to, r_values == compared_to,
class(a))
+ expect_bool_function_equal(a != compared_to, r_values != compared_to,
class(a))
+ expect_bool_function_equal(a > compared_to, r_values > compared_to, class(a))
+ expect_bool_function_equal(a >= compared_to, r_values >= compared_to,
class(a))
+ expect_bool_function_equal(a < compared_to, r_values < compared_to, class(a))
+ expect_bool_function_equal(a <= compared_to, r_values <= compared_to,
class(a))
+}
+
+expect_chunked_array_compares <- function(...) expect_array_compares(...,
Class = ChunkedArray)
+
+test_that("compare ops with Array", {
+ expect_array_compares(1:5, 4L)
+ expect_array_compares(1:5, 4) # implicit casting
+ expect_array_compares(c(NA, 1:5), 4)
+ expect_array_compares(as.numeric(c(NA, 1:5)), 4)
+})
+
+test_that("compare ops with ChunkedArray", {
+ expect_chunked_array_compares(1:5, 4L)
+ expect_chunked_array_compares(1:5, 4) # implicit casting
+ expect_chunked_array_compares(c(NA, 1:5), 4)
+ expect_chunked_array_compares(as.numeric(c(NA, 1:5)), 4)
+})
+
+test_that("logic ops with Array", {
+ truth <- expand.grid(left = c(TRUE, FALSE, NA), right = c(TRUE, FALSE, NA))
+ a_left <- Array$create(truth$left)
+ a_right <- Array$create(truth$right)
+ expect_bool_function_equal(a_left & a_right, truth$left & truth$right)
+ expect_bool_function_equal(a_left | a_right, truth$left | truth$right)
+ expect_bool_function_equal(a_left == a_right, truth$left == truth$right)
+ expect_bool_function_equal(a_left != a_right, truth$left != truth$right)
+ expect_bool_function_equal(!a_left, !truth$left)
+
+ # More complexity
+ isEqualTo <- function(x, y) x == y & !is.na(x)
+ expect_bool_function_equal(
+ isEqualTo(a_left, a_right),
+ isEqualTo(truth$left, truth$right)
+ )
+})
+
+test_that("logic ops with ChunkedArray", {
+ truth <- expand.grid(left = c(TRUE, FALSE, NA), right = c(TRUE, FALSE, NA))
+ a_left <- ChunkedArray$create(truth$left)
+ a_right <- ChunkedArray$create(truth$right)
+ expect_bool_function_equal(a_left & a_right, truth$left & truth$right,
"ChunkedArray")
+ expect_bool_function_equal(a_left | a_right, truth$left | truth$right,
"ChunkedArray")
+ expect_bool_function_equal(a_left == a_right, truth$left == truth$right,
"ChunkedArray")
+ expect_bool_function_equal(a_left != a_right, truth$left != truth$right,
"ChunkedArray")
+ expect_bool_function_equal(!a_left, !truth$left, "ChunkedArray")
+
+ # More complexity
+ isEqualTo <- function(x, y) x == y & !is.na(x)
+ expect_bool_function_equal(
+ isEqualTo(a_left, a_right),
+ isEqualTo(truth$left, truth$right),
+ "ChunkedArray"
+ )
+})
+
+test_that("call_function validation", {
+ expect_error(
+ call_function("filter", 4),
+ 'Argument 1 is of class numeric but it must be one of "Array",
"ChunkedArray", "RecordBatch", "Table", or "Scalar"'
+ )
+ expect_error(
+ call_function("filter", Array$create(1:4), 3),
+ 'Argument 2 is of class numeric'
+ )
+ expect_error(
+ call_function("filter",
+ Array$create(1:4),
+ Array$create(c(TRUE, FALSE, TRUE)),
+ options = list(keep_na = TRUE)
+ ),
+ "Array arguments must all be the same length"
+ )
+ expect_error(
+ call_function("filter",
+ record_batch(a = 1:3),
+ Array$create(c(TRUE, FALSE, TRUE)),
+ options = list(keep_na = TRUE)
+ ),
+ NA
+ )
+ expect_error(
+ call_function("filter", options = list(keep_na = TRUE)),
+ "accepts 2 arguments"
+ )
+})
diff --git a/r/tests/testthat/test-dplyr.R b/r/tests/testthat/test-dplyr.R
index 7b4afda..995ba8a 100644
--- a/r/tests/testthat/test-dplyr.R
+++ b/r/tests/testthat/test-dplyr.R
@@ -145,6 +145,24 @@ test_that("More complex select/filter", {
)
})
+test_that("Print method", {
+ expect_output(
+ record_batch(tbl) %>%
+ filter(dbl > 2, chr == "d" | chr == "f") %>%
+ select(chr, int, lgl) %>%
+ filter(int < 5) %>%
+ select(int, chr) %>%
+ print(),
+'RecordBatch (query)
+int: int32
+chr: string
+
+* Filter: and(and(greater(<Array>, 2), or(equal(<Array>, "d"), equal(<Array>,
"f"))), less(<Array>, 5L))
+See $.data for the source Arrow object',
+ fixed = TRUE
+ )
+})
+
test_that("filter() with %in%", {
expect_dplyr_equal(
input %>%
@@ -169,6 +187,10 @@ test_that("filter environment scope", {
# 'could not find function "isEqualTo"'
expect_dplyr_error(filter(batch, isEqualTo(int, 4)))
+ # TODO: fix this: this isEqualTo function is eagerly evaluating; it should
+ # instead yield array_expressions. Probably bc the parent env of the function
+ # has the Ops.Array methods defined; we need to move it so that the parent
+ # env is the data mask we use in the dplyr eval
isEqualTo <- function(x, y) x == y & !is.na(x)
expect_dplyr_equal(
input %>%
diff --git a/r/tests/testthat/test-expression.R
b/r/tests/testthat/test-expression.R
index f75926e..1bf0859 100644
--- a/r/tests/testthat/test-expression.R
+++ b/r/tests/testthat/test-expression.R
@@ -18,25 +18,18 @@
context("Expressions")
test_that("Can create an expression", {
- expect_is(Array$create(1:5) + 4, "array_expression")
-})
-
-test_that("Recursive expression generation", {
- a <- Array$create(1:5)
- expect_is(a == 4 | a == 3, "array_expression")
+ expect_is(build_array_expression(">", Array$create(1:5), 4),
"array_expression")
})
test_that("as.vector(array_expression)", {
- a <- Array$create(1:5)
- expect_equal(as.vector(a + 4), 5:9)
- expect_equal(as.vector(a == 4 | a == 3), c(FALSE, FALSE, TRUE, TRUE, FALSE))
+ expect_equal(as.vector(build_array_expression(">", Array$create(1:5), 4)),
c(FALSE, FALSE, FALSE, FALSE, TRUE))
})
test_that("array_expression print method", {
- a <- Array$create(1:5)
expect_output(
- print(a == 4 | a == 3),
- capture.output(print(c(FALSE, FALSE, TRUE, TRUE, FALSE))),
+ print(build_array_expression(">", Array$create(1:5), 4)),
+ # Not ideal but it is informative
+ "greater(<Array>, 4L)",
fixed = TRUE
)
})