This is an automated email from the ASF dual-hosted git repository.
paleolimbot pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new 557acf524f ARROW-17689: [R] Implement dplyr::across() inside
group_by() (#14122)
557acf524f is described below
commit 557acf524f6b73d73bdb9464e947b78b9d02fcea
Author: eitsupi <[email protected]>
AuthorDate: Fri Sep 16 10:42:34 2022 +0900
ARROW-17689: [R] Implement dplyr::across() inside group_by() (#14122)
Because the handling of the `.add = TRUE` case and of the `add` argument has
changed, test cases for both are also added.
Authored-by: SHIMA Tatsuya <[email protected]>
Signed-off-by: Dewey Dunnington <[email protected]>
---
r/R/dplyr-group-by.R | 38 +++++-------
r/tests/testthat/test-dplyr-group-by.R | 110 +++++++++++++++++++++++++++++++++
2 files changed, 126 insertions(+), 22 deletions(-)
diff --git a/r/R/dplyr-group-by.R b/r/R/dplyr-group-by.R
index c650799e8d..57cf417c9a 100644
--- a/r/R/dplyr-group-by.R
+++ b/r/R/dplyr-group-by.R
@@ -21,37 +21,31 @@
group_by.arrow_dplyr_query <- function(.data,
...,
.add = FALSE,
- add = .add,
+ add = NULL,
.drop =
dplyr::group_by_drop_default(.data)) {
+ if (!missing(add)) {
+ .Deprecated(
+ msg = paste("The `add` argument of `group_by()` is deprecated. Please use the `.add` argument instead.")
+ )
+ .add <- add
+ }
+
.data <- as_adq(.data)
- new_groups <- enquos(...)
- # ... can contain expressions (i.e. can add (or rename?) columns) and so we
- # need to identify those and add them on to the query with mutate. Specifically,
- # we want to mark as new:
- # * expressions (named or otherwise)
- # * variables that have new names
- # All others (i.e. simple references to variables) should not be (re)-added
+ expression_list <- expand_across(.data, quos(...))
+ new_groups <- ensure_named_exprs(expression_list)
- # Identify any groups with names which aren't in names of .data
- new_group_ind <- map_lgl(new_groups, ~ !(quo_name(.x) %in% names(.data)))
- # Identify any groups which don't have names
- named_group_ind <- map_lgl(names(new_groups), nzchar)
- # Retain any new groups identified above
- new_groups <- new_groups[new_group_ind | named_group_ind]
if (length(new_groups)) {
- # now either use the name that was given in ... or if that is "" then use the expr
- names(new_groups) <- imap_chr(new_groups, ~ ifelse(.y == "", quo_name(.x), .y))
-
# Add them to the data
.data <- dplyr::mutate(.data, !!!new_groups)
}
- if (".add" %in% names(formals(dplyr::group_by))) {
- # For compatibility with dplyr >= 1.0
- gv <- dplyr::group_by_prepare(.data, ..., .add = .add)$group_names
+
+ if (.add) {
+ gv <- union(dplyr::group_vars(.data), names(new_groups))
} else {
- gv <- dplyr::group_by_prepare(.data, ..., add = add)$group_names
+ gv <- names(new_groups)
}
- .data$group_by_vars <- gv
+
+ .data$group_by_vars <- gv %||% character()
.data$drop_empty_groups <- ifelse(length(gv), .drop,
dplyr::group_by_drop_default(.data))
.data
}
diff --git a/r/tests/testthat/test-dplyr-group-by.R b/r/tests/testthat/test-dplyr-group-by.R
index c7380e96ec..9bb6aa9600 100644
--- a/r/tests/testthat/test-dplyr-group-by.R
+++ b/r/tests/testthat/test-dplyr-group-by.R
@@ -166,3 +166,113 @@ test_that("group_by() with namespaced functions", {
tbl
)
})
+
+test_that("group_by() with .add", {
+ compare_dplyr_binding(
+ .input %>%
+ group_by(dbl2) %>%
+ group_by(.add = FALSE) %>%
+ collect(),
+ tbl
+ )
+ compare_dplyr_binding(
+ .input %>%
+ group_by(dbl2) %>%
+ group_by(.add = TRUE) %>%
+ collect(),
+ tbl
+ )
+ compare_dplyr_binding(
+ .input %>%
+ group_by(dbl2) %>%
+ group_by(chr, .add = FALSE) %>%
+ collect(),
+ tbl
+ )
+ compare_dplyr_binding(
+ .input %>%
+ group_by(dbl2) %>%
+ group_by(chr, .add = TRUE) %>%
+ collect(),
+ tbl
+ )
+ compare_dplyr_binding(
+ .input %>%
+ group_by(chr, .add = FALSE) %>%
+ collect(),
+ tbl %>%
+ group_by(dbl2)
+ )
+ compare_dplyr_binding(
+ .input %>%
+ group_by(chr, .add = TRUE) %>%
+ collect(),
+ tbl %>%
+ group_by(dbl2)
+ )
+ suppressWarnings(compare_dplyr_binding(
+ .input %>%
+ group_by(dbl2) %>%
+ group_by(add = FALSE) %>%
+ collect(),
+ tbl,
+ warning = "deprecated"
+ ))
+ suppressWarnings(compare_dplyr_binding(
+ .input %>%
+ group_by(dbl2) %>%
+ group_by(add = TRUE) %>%
+ collect(),
+ tbl,
+ warning = "deprecated"
+ ))
+ expect_warning(
+ tbl %>%
+ arrow_table() %>%
+ group_by(add = TRUE) %>%
+ collect(),
+ "The `add` argument of `group_by\\(\\)` is deprecated"
+ )
+ expect_error(
+ suppressWarnings(
+ tbl %>%
+ arrow_table() %>%
+ group_by(add = dbl2) %>%
+ collect()
+ ),
+ "object 'dbl2' not found"
+ )
+})
+
+test_that("Can use across() within group_by()", {
+ test_groups <- c("dbl", "int", "chr")
+ compare_dplyr_binding(
+ .input %>%
+ group_by(across()) %>%
+ collect(),
+ tbl
+ )
+ compare_dplyr_binding(
+ .input %>%
+ group_by(across(starts_with("d"))) %>%
+ collect(),
+ tbl
+ )
+ compare_dplyr_binding(
+ .input %>%
+ group_by(across({{ test_groups }})) %>%
+ collect(),
+ tbl
+ )
+
+ # ARROW-12778 - `where()` is not yet supported
+ expect_error(
+ compare_dplyr_binding(
+ .input %>%
+ group_by(across(where(is.numeric))) %>%
+ collect(),
+ tbl
+ ),
+ "Unsupported selection helper"
+ )
+})