This is an automated email from the ASF dual-hosted git repository.
jonkeane pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new 5845556 ARROW-14226: [R] Handle n_distinct() (and others) with args != 1
5845556 is described below
commit 58455564fbcf67219947a3b0a9806e11c54c7318
Author: Neal Richardson <[email protected]>
AuthorDate: Thu Oct 14 15:07:46 2021 -0500
ARROW-14226: [R] Handle n_distinct() (and others) with args != 1
Turns out that we had a bunch of other aggregation functions that needed
the same validation
Closes #11407 from nealrichardson/n-distinct-args
Authored-by: Neal Richardson <[email protected]>
Signed-off-by: Jonathan Keane <[email protected]>
---
r/R/dplyr-functions.R | 37 +++++++++++-----------
r/tests/testthat/test-dplyr-summarize.R | 54 +++++++++++++++++++++++++++++++++
2 files changed, 73 insertions(+), 18 deletions(-)
diff --git a/r/R/dplyr-functions.R b/r/R/dplyr-functions.R
index d10132e..1d68092 100644
--- a/r/R/dplyr-functions.R
+++ b/r/R/dplyr-functions.R
@@ -941,24 +941,24 @@ nse_funcs$case_when <- function(...) {
# So to see a list of available hash aggregation functions,
# you can use list_compute_functions("^hash_")
agg_funcs <- list()
-agg_funcs$sum <- function(x, na.rm = FALSE) {
+agg_funcs$sum <- function(..., na.rm = FALSE) {
list(
fun = "sum",
- data = x,
+ data = ensure_one_arg(list2(...), "sum"),
options = list(skip_nulls = na.rm, min_count = 0L)
)
}
-agg_funcs$any <- function(x, na.rm = FALSE) {
+agg_funcs$any <- function(..., na.rm = FALSE) {
list(
fun = "any",
- data = x,
+ data = ensure_one_arg(list2(...), "any"),
options = list(skip_nulls = na.rm, min_count = 0L)
)
}
-agg_funcs$all <- function(x, na.rm = FALSE) {
+agg_funcs$all <- function(..., na.rm = FALSE) {
list(
fun = "all",
- data = x,
+ data = ensure_one_arg(list2(...), "all"),
options = list(skip_nulls = na.rm, min_count = 0L)
)
}
@@ -1014,10 +1014,10 @@ agg_funcs$median <- function(x, na.rm = FALSE) {
options = list(skip_nulls = na.rm)
)
}
-agg_funcs$n_distinct <- function(x, na.rm = FALSE) {
+agg_funcs$n_distinct <- function(..., na.rm = FALSE) {
list(
fun = "count_distinct",
- data = x,
+ data = ensure_one_arg(list2(...), "n_distinct"),
options = list(na.rm = na.rm)
)
}
@@ -1029,28 +1029,29 @@ agg_funcs$n <- function() {
)
}
agg_funcs$min <- function(..., na.rm = FALSE) {
- args <- list2(...)
- if (length(args) > 1) {
- arrow_not_supported("Multiple arguments to min()")
- }
list(
fun = "min",
- data = args[[1]],
+ data = ensure_one_arg(list2(...), "min"),
options = list(skip_nulls = na.rm, min_count = 0L)
)
}
agg_funcs$max <- function(..., na.rm = FALSE) {
- args <- list2(...)
- if (length(args) > 1) {
- arrow_not_supported("Multiple arguments to max()")
- }
list(
fun = "max",
- data = args[[1]],
+ data = ensure_one_arg(list2(...), "max"),
options = list(skip_nulls = na.rm, min_count = 0L)
)
}
+ensure_one_arg <- function(args, fun) {
+ if (length(args) == 0) {
+ arrow_not_supported(paste0(fun, "() with 0 arguments"))
+ } else if (length(args) > 1) {
+ arrow_not_supported(paste0("Multiple arguments to ", fun, "()"))
+ }
+ args[[1]]
+}
+
output_type <- function(fun, input_type, hash) {
# These are quick and dirty heuristics.
if (fun %in% c("any", "all")) {
diff --git a/r/tests/testthat/test-dplyr-summarize.R b/r/tests/testthat/test-dplyr-summarize.R
index aa2bf23..c13aa4f 100644
--- a/r/tests/testthat/test-dplyr-summarize.R
+++ b/r/tests/testthat/test-dplyr-summarize.R
@@ -241,6 +241,53 @@ test_that("n_distinct() on dataset", {
collect(),
tbl
)
+
+ expect_dplyr_equal(
+ input %>%
+ summarize(distinct = n_distinct(int, lgl)) %>%
+ collect(),
+ tbl,
+ warning = "Multiple arguments"
+ )
+ expect_dplyr_equal(
+ input %>%
+ group_by(some_grouping) %>%
+ summarize(distinct = n_distinct(int, lgl)) %>%
+ collect(),
+ tbl,
+ warning = "Multiple arguments"
+ )
+})
+
+test_that("Functions that take ... but we only accept a single arg", {
+ expect_dplyr_equal(
+ input %>%
+ summarize(distinct = n_distinct()) %>%
+ collect(),
+ tbl,
+ warning = "0 arguments"
+ )
+ expect_dplyr_equal(
+ input %>%
+ summarize(distinct = n_distinct(int, lgl)) %>%
+ collect(),
+ tbl,
+ warning = "Multiple arguments"
+ )
+ # Now that we've demonstrated that the whole machinery works, let's test
+ # the agg_funcs directly
+ expect_error(agg_funcs$n_distinct(), "n_distinct() with 0 arguments", fixed = TRUE)
+ expect_error(agg_funcs$sum(), "sum() with 0 arguments", fixed = TRUE)
+ expect_error(agg_funcs$any(), "any() with 0 arguments", fixed = TRUE)
+ expect_error(agg_funcs$all(), "all() with 0 arguments", fixed = TRUE)
+ expect_error(agg_funcs$min(), "min() with 0 arguments", fixed = TRUE)
+ expect_error(agg_funcs$max(), "max() with 0 arguments", fixed = TRUE)
+ expect_error(agg_funcs$n_distinct(1, 2), "Multiple arguments to n_distinct()")
+ expect_error(agg_funcs$sum(1, 2), "Multiple arguments to sum")
+ expect_error(agg_funcs$any(1, 2), "Multiple arguments to any()")
+ expect_error(agg_funcs$all(1, 2), "Multiple arguments to all()")
+ expect_error(agg_funcs$min(1, 2), "Multiple arguments to min()")
+ expect_error(agg_funcs$max(1, 2), "Multiple arguments to max()")
})
test_that("median()", {
@@ -249,6 +296,10 @@ test_that("median()", {
# output of type float64. The calls to median(int, ...) in the tests below
# are enclosed in as.double() to work around this known difference.
+ # Use old testthat behavior here so we don't have to assert the same warning
+ # over and over
+ local_edition(2)
+
# with groups
expect_dplyr_equal(
input %>%
@@ -290,6 +341,7 @@ test_that("median()", {
tbl,
warning = "median\\(\\) currently returns an approximate median in Arrow"
)
+ local_edition(3)
})
test_that("quantile()", {
@@ -315,6 +367,7 @@ test_that("quantile()", {
# return output of type float64. The calls to quantile(int, ...) in the tests
# below are enclosed in as.double() to work around this known difference.
+ local_edition(2)
# with groups
expect_warning(
expect_equal(
@@ -378,6 +431,7 @@ test_that("quantile()", {
"quantile() currently returns an approximate quantile in Arrow",
fixed = TRUE
)
+ local_edition(3)
# with a vector of 2+ probs
expect_warning(