This is an automated email from the ASF dual-hosted git repository.

jonkeane pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new 5845556  ARROW-14226: [R] Handle n_distinct() (and others) with args != 1
5845556 is described below

commit 58455564fbcf67219947a3b0a9806e11c54c7318
Author: Neal Richardson <[email protected]>
AuthorDate: Thu Oct 14 15:07:46 2021 -0500

    ARROW-14226: [R] Handle n_distinct() (and others) with args != 1
    
    Turns out that we had a bunch of other aggregation functions that needed the same validation
    
    Closes #11407 from nealrichardson/n-distinct-args
    
    Authored-by: Neal Richardson <[email protected]>
    Signed-off-by: Jonathan Keane <[email protected]>
---
 r/R/dplyr-functions.R                   | 37 +++++++++++-----------
 r/tests/testthat/test-dplyr-summarize.R | 54 +++++++++++++++++++++++++++++++++
 2 files changed, 73 insertions(+), 18 deletions(-)

diff --git a/r/R/dplyr-functions.R b/r/R/dplyr-functions.R
index d10132e..1d68092 100644
--- a/r/R/dplyr-functions.R
+++ b/r/R/dplyr-functions.R
@@ -941,24 +941,24 @@ nse_funcs$case_when <- function(...) {
 # So to see a list of available hash aggregation functions,
 # you can use list_compute_functions("^hash_")
 agg_funcs <- list()
-agg_funcs$sum <- function(x, na.rm = FALSE) {
+agg_funcs$sum <- function(..., na.rm = FALSE) {
   list(
     fun = "sum",
-    data = x,
+    data = ensure_one_arg(list2(...), "sum"),
     options = list(skip_nulls = na.rm, min_count = 0L)
   )
 }
-agg_funcs$any <- function(x, na.rm = FALSE) {
+agg_funcs$any <- function(..., na.rm = FALSE) {
   list(
     fun = "any",
-    data = x,
+    data = ensure_one_arg(list2(...), "any"),
     options = list(skip_nulls = na.rm, min_count = 0L)
   )
 }
-agg_funcs$all <- function(x, na.rm = FALSE) {
+agg_funcs$all <- function(..., na.rm = FALSE) {
   list(
     fun = "all",
-    data = x,
+    data = ensure_one_arg(list2(...), "all"),
     options = list(skip_nulls = na.rm, min_count = 0L)
   )
 }
@@ -1014,10 +1014,10 @@ agg_funcs$median <- function(x, na.rm = FALSE) {
     options = list(skip_nulls = na.rm)
   )
 }
-agg_funcs$n_distinct <- function(x, na.rm = FALSE) {
+agg_funcs$n_distinct <- function(..., na.rm = FALSE) {
   list(
     fun = "count_distinct",
-    data = x,
+    data = ensure_one_arg(list2(...), "n_distinct"),
     options = list(na.rm = na.rm)
   )
 }
@@ -1029,28 +1029,29 @@ agg_funcs$n <- function() {
   )
 }
 agg_funcs$min <- function(..., na.rm = FALSE) {
-  args <- list2(...)
-  if (length(args) > 1) {
-    arrow_not_supported("Multiple arguments to min()")
-  }
   list(
     fun = "min",
-    data = args[[1]],
+    data = ensure_one_arg(list2(...), "min"),
     options = list(skip_nulls = na.rm, min_count = 0L)
   )
 }
 agg_funcs$max <- function(..., na.rm = FALSE) {
-  args <- list2(...)
-  if (length(args) > 1) {
-    arrow_not_supported("Multiple arguments to max()")
-  }
   list(
     fun = "max",
-    data = args[[1]],
+    data = ensure_one_arg(list2(...), "max"),
     options = list(skip_nulls = na.rm, min_count = 0L)
   )
 }
 
+ensure_one_arg <- function(args, fun) {
+  if (length(args) == 0) {
+    arrow_not_supported(paste0(fun, "() with 0 arguments"))
+  } else if (length(args) > 1) {
+    arrow_not_supported(paste0("Multiple arguments to ", fun, "()"))
+  }
+  args[[1]]
+}
+
 output_type <- function(fun, input_type, hash) {
   # These are quick and dirty heuristics.
   if (fun %in% c("any", "all")) {
diff --git a/r/tests/testthat/test-dplyr-summarize.R b/r/tests/testthat/test-dplyr-summarize.R
index aa2bf23..c13aa4f 100644
--- a/r/tests/testthat/test-dplyr-summarize.R
+++ b/r/tests/testthat/test-dplyr-summarize.R
@@ -241,6 +241,53 @@ test_that("n_distinct() on dataset", {
       collect(),
     tbl
   )
+
+  expect_dplyr_equal(
+    input %>%
+      summarize(distinct = n_distinct(int, lgl)) %>%
+      collect(),
+    tbl,
+    warning = "Multiple arguments"
+  )
+  expect_dplyr_equal(
+    input %>%
+      group_by(some_grouping) %>%
+      summarize(distinct = n_distinct(int, lgl)) %>%
+      collect(),
+    tbl,
+    warning = "Multiple arguments"
+  )
+})
+
+test_that("Functions that take ... but we only accept a single arg", {
+  expect_dplyr_equal(
+    input %>%
+      summarize(distinct = n_distinct()) %>%
+      collect(),
+    tbl,
+    warning = "0 arguments"
+  )
+  expect_dplyr_equal(
+    input %>%
+      summarize(distinct = n_distinct(int, lgl)) %>%
+      collect(),
+    tbl,
+    warning = "Multiple arguments"
+  )
+  # Now that we've demonstrated that the whole machinery works, let's test
+  # the agg_funcs directly
+  expect_error(agg_funcs$n_distinct(), "n_distinct() with 0 arguments", fixed = TRUE)
+  expect_error(agg_funcs$sum(), "sum() with 0 arguments", fixed = TRUE)
+  expect_error(agg_funcs$any(), "any() with 0 arguments", fixed = TRUE)
+  expect_error(agg_funcs$all(), "all() with 0 arguments", fixed = TRUE)
+  expect_error(agg_funcs$min(), "min() with 0 arguments", fixed = TRUE)
+  expect_error(agg_funcs$max(), "max() with 0 arguments", fixed = TRUE)
+  expect_error(agg_funcs$n_distinct(1, 2), "Multiple arguments to n_distinct()")
+  expect_error(agg_funcs$sum(1, 2), "Multiple arguments to sum")
+  expect_error(agg_funcs$any(1, 2), "Multiple arguments to any()")
+  expect_error(agg_funcs$all(1, 2), "Multiple arguments to all()")
+  expect_error(agg_funcs$min(1, 2), "Multiple arguments to min()")
+  expect_error(agg_funcs$max(1, 2), "Multiple arguments to max()")
 })
 
 test_that("median()", {
@@ -249,6 +296,10 @@ test_that("median()", {
   # output of type float64. The calls to median(int, ...) in the tests below
   # are enclosed in as.double() to work around this known difference.
 
+  # Use old testthat behavior here so we don't have to assert the same warning
+  # over and over
+  local_edition(2)
+
   # with groups
   expect_dplyr_equal(
     input %>%
@@ -290,6 +341,7 @@ test_that("median()", {
     tbl,
     warning = "median\\(\\) currently returns an approximate median in Arrow"
   )
+  local_edition(3)
 })
 
 test_that("quantile()", {
@@ -315,6 +367,7 @@ test_that("quantile()", {
   # return output of type float64. The calls to quantile(int, ...) in the tests
   # below are enclosed in as.double() to work around this known difference.
 
+  local_edition(2)
   # with groups
   expect_warning(
     expect_equal(
@@ -378,6 +431,7 @@ test_that("quantile()", {
     "quantile() currently returns an approximate quantile in Arrow",
     fixed = TRUE
   )
+  local_edition(3)
 
   # with a vector of 2+ probs
   expect_warning(

Reply via email to