This is an automated email from the ASF dual-hosted git repository.
thisisnic pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/main by this push:
new 6bd0050811 GH-35445: [R] Behavior something like group_by(foo) |>
across(everything()) is different from dplyr (#35473)
6bd0050811 is described below
commit 6bd00508116edea5afcdc4e3e11cd9fa789b70a3
Author: eitsupi <[email protected]>
AuthorDate: Thu May 18 22:11:05 2023 +0900
GH-35445: [R] Behavior something like group_by(foo) |> across(everything())
is different from dplyr (#35473)
### Rationale for this change
The documentation for the `.cols` argument of `dplyr::across()` says:
> You can't select grouping columns because they are already automatically
> handled by the verb (i.e. summarise() or mutate()).
However, the `arrow` package does not currently reproduce this behavior.
Instead, selecting the grouping column (for example, via `everything()`)
raises an error, as shown below.
``` r
mtcars |>
arrow::as_arrow_table() |>
dplyr::group_by(cyl) |>
dplyr::summarise(dplyr::across(everything(), sum)) |>
dplyr::collect()
#> Error in `compute.arrow_dplyr_query()`:
#> ! Invalid: Multiple matches for FieldRef.Name(cyl) in mpg: double
#> cyl: double
#> disp: double
#> hp: double
#> drat: double
#> wt: double
#> qsec: double
#> vs: double
#> am: double
#> gear: double
#> carb: double
#> cyl: double
#> Backtrace:
#> ▆
#> 1. ├─dplyr::collect(...)
#> 2. └─arrow:::collect.arrow_dplyr_query(...)
#> 3. └─arrow:::compute.arrow_dplyr_query(x)
#> 4. └─base::tryCatch(...)
#> 5. └─base (local) tryCatchList(expr, classes, parentenv, handlers)
#> 6. └─base (local) tryCatchOne(expr, names, parentenv, handlers[[1L]])
#> 7. └─value[[3L]](cond)
#> 8. └─arrow:::augment_io_error_msg(e, call, schema = schema())
#> 9. └─rlang::abort(msg, call = call)
```
<sup>Created on 2023-05-05 with [reprex
v2.0.2](https://reprex.tidyverse.org)</sup>
This PR changes arrow's behavior to match dplyr's.
### What changes are included in this PR?
- Automatically exclude grouping columns from `across()` in `mutate()`,
`transmute()`, and `summarise()`.
- The `.data` argument of the internal function `expand_across()` must now be
an `arrow_dplyr_query`.
Some tests have been modified slightly to accommodate this change.
- `mutate()`, `transmute()`, `arrange()`, and `filter()` now always return an
`arrow_dplyr_query`.
Previously, some calls, such as the following, did not return an
`arrow_dplyr_query`, which was inconsistent:
```r
mtcars |> arrow::arrow_table() |> dplyr::mutate()
```
- Correct the column order in the results of
`group_by(foo) |> mutate(.keep = "none")`.
Previously, in the results of the following query, the grouping columns were
moved to the end, which differs from dplyr's behavior:
```r
mtcars |> arrow::arrow_table() |> dplyr::group_by(cyl) |>
dplyr::mutate(am, .keep = "none") |> dplyr::collect()
```
- Correct the column order in the results of `group_by(foo) |> transmute()`.
Previously, in the results of the following query, the grouping columns were
moved to the end, which differs from dplyr's behavior:
```r
mtcars |> arrow::arrow_table() |> dplyr::group_by(cyl) |>
dplyr::transmute(mpg) |> dplyr::collect()
```
After `transmute()`, the grouping columns should be moved to the front
(leftmost positions). This differs from `mutate(.keep = "none")`, which keeps
the grouping columns in their original positions.
### Are these changes tested?
Yes.
### Are there any user-facing changes?
Yes.
* Closes: #35445
Authored-by: SHIMA Tatsuya <[email protected]>
Signed-off-by: Nic Crane <[email protected]>
---
r/R/dplyr-across.R | 10 +++---
r/R/dplyr-arrange.R | 1 +
r/R/dplyr-filter.R | 1 +
r/R/dplyr-mutate.R | 14 ++++----
r/R/dplyr-summarize.R | 2 +-
r/tests/testthat/test-dplyr-across.R | 4 +--
r/tests/testthat/test-dplyr-mutate.R | 58 +++++++++++++++++++++++++++++++++
r/tests/testthat/test-dplyr-summarize.R | 21 ++++++++++++
8 files changed, 98 insertions(+), 13 deletions(-)
diff --git a/r/R/dplyr-across.R b/r/R/dplyr-across.R
index 5b816a0719..da61353b22 100644
--- a/r/R/dplyr-across.R
+++ b/r/R/dplyr-across.R
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
-expand_across <- function(.data, quos_in) {
+expand_across <- function(.data, quos_in, exclude_cols = NULL) {
quos_out <- list()
# retrieve items using their values to preserve naming of quos other than
across
for (quo_i in seq_along(quos_in)) {
@@ -49,7 +49,8 @@ expand_across <- function(.data, quos_in) {
names = across_call[[".names"]],
.caller_env = quo_env,
mask = .data,
- inline = TRUE
+ inline = TRUE,
+ exclude_cols = exclude_cols
)
new_quos <- quosures_from_setup(setup, quo_env)
@@ -106,10 +107,11 @@ quosures_from_setup <- function(setup, quo_env) {
set_names(new_quo_list, setup$names)
}
-across_setup <- function(cols, fns, names, .caller_env, mask, inline = FALSE) {
+across_setup <- function(cols, fns, names, .caller_env, mask, inline = FALSE,
exclude_cols = NULL) {
cols <- enquo(cols)
- vars <- names(dplyr::select(mask, !!cols))
+ sim_df <- dplyr::select(as.data.frame(implicit_schema(mask)),
!(!!exclude_cols))
+ vars <- names(dplyr::select(sim_df, !!cols))
if (is.null(fns)) {
if (!is.null(names)) {
diff --git a/r/R/dplyr-arrange.R b/r/R/dplyr-arrange.R
index 39388394d5..e3e20f2cb3 100644
--- a/r/R/dplyr-arrange.R
+++ b/r/R/dplyr-arrange.R
@@ -20,6 +20,7 @@
arrange.arrow_dplyr_query <- function(.data, ..., .by_group = FALSE) {
call <- match.call()
+ .data <- as_adq(.data)
exprs <- expand_across(.data, quos(...))
if (.by_group) {
diff --git a/r/R/dplyr-filter.R b/r/R/dplyr-filter.R
index 1ef2b6d7e5..a864f1e9ce 100644
--- a/r/R/dplyr-filter.R
+++ b/r/R/dplyr-filter.R
@@ -19,6 +19,7 @@
# The following S3 methods are registered on load if dplyr is present
filter.arrow_dplyr_query <- function(.data, ..., .preserve = FALSE) {
+ .data <- as_adq(.data)
# TODO something with the .preserve argument
filts <- expand_across(.data, quos(...))
if (length(filts) == 0) {
diff --git a/r/R/dplyr-mutate.R b/r/R/dplyr-mutate.R
index b75de1e4db..638a4566ab 100644
--- a/r/R/dplyr-mutate.R
+++ b/r/R/dplyr-mutate.R
@@ -24,8 +24,10 @@ mutate.arrow_dplyr_query <- function(.data,
.before = NULL,
.after = NULL) {
call <- match.call()
+ .data <- as_adq(.data)
+ grv <- .data$group_by_vars
- expression_list <- expand_across(.data, quos(...))
+ expression_list <- expand_across(.data, quos(...), exclude_cols = grv)
exprs <- ensure_named_exprs(expression_list)
.keep <- match.arg(.keep)
@@ -37,8 +39,6 @@ mutate.arrow_dplyr_query <- function(.data,
return(.data)
}
- .data <- as_adq(.data)
-
# Restrict the cases we support for now
has_aggregations <- any(unlist(lapply(exprs, all_funs)) %in%
names(agg_funcs))
if (has_aggregations) {
@@ -86,7 +86,7 @@ mutate.arrow_dplyr_query <- function(.data,
}
# Deduplicate new_vars and remove NULL columns from new_vars
- new_vars <- intersect(new_vars, names(.data$selected_columns))
+ new_vars <- intersect(union(new_vars, grv), names(.data$selected_columns))
# Respect .before and .after
if (!quo_is_null(.before) || !quo_is_null(.after)) {
@@ -117,7 +117,9 @@ mutate.Dataset <- mutate.ArrowTabular <-
mutate.RecordBatchReader <- mutate.arro
transmute.arrow_dplyr_query <- function(.data, ...) {
dots <- check_transmute_args(...)
- expression_list <- expand_across(.data, dots)
+ .data <- as_adq(.data)
+ grv <- .data$group_by_vars
+ expression_list <- expand_across(.data, dots, exclude_cols = grv)
has_null <- map_lgl(expression_list, quo_is_null)
.data <- dplyr::mutate(.data, !!!expression_list, .keep = "none")
@@ -129,7 +131,7 @@ transmute.arrow_dplyr_query <- function(.data, ...) {
cur_exprs <- map_chr(expression_list, as_label)
transmute_order <- names(cur_exprs)
transmute_order[!nzchar(transmute_order)] <-
cur_exprs[!nzchar(transmute_order)]
- dplyr::select(.data, all_of(transmute_order))
+ dplyr::select(.data, all_of(c(grv, transmute_order)))
}
transmute.Dataset <- transmute.ArrowTabular <- transmute.RecordBatchReader <-
transmute.arrow_dplyr_query
diff --git a/r/R/dplyr-summarize.R b/r/R/dplyr-summarize.R
index 5d943633a8..c02b6ee522 100644
--- a/r/R/dplyr-summarize.R
+++ b/r/R/dplyr-summarize.R
@@ -172,7 +172,7 @@ agg_funcs[["::"]] <- function(lhs, rhs) {
summarise.arrow_dplyr_query <- function(.data, ..., .groups = NULL) {
call <- match.call()
.data <- as_adq(.data)
- exprs <- expand_across(.data, quos(...))
+ exprs <- expand_across(.data, quos(...), exclude_cols = .data$group_by_vars)
# Only retain the columns we need to do our aggregations
vars_to_keep <- unique(c(
unlist(lapply(exprs, all.vars)), # vars referenced in summarise
diff --git a/r/tests/testthat/test-dplyr-across.R
b/r/tests/testthat/test-dplyr-across.R
index edf74dcbdb..eebb8a23ea 100644
--- a/r/tests/testthat/test-dplyr-across.R
+++ b/r/tests/testthat/test-dplyr-across.R
@@ -120,7 +120,7 @@ test_that("expand_across correctly expands quosures", {
# ellipses (...) are a deprecated argument
expect_error(
expand_across(
- example_data,
+ as_adq(example_data),
quos(across(c(dbl, dbl2), round, digits = -1))
),
regexp = "`...` argument to `across()` is deprecated in dplyr and not
supported in Arrow",
@@ -206,7 +206,7 @@ test_that("expand_across correctly expands quosures", {
# dodgy .names specification
expect_error(
expand_across(
- example_data,
+ as_adq(example_data),
quos(across(c(dbl, dbl2), list(round, "my_exp" = exp), .names = "zarg"))
),
regexp = "`.names` specification must produce (number of columns * number
of functions) names.",
diff --git a/r/tests/testthat/test-dplyr-mutate.R
b/r/tests/testthat/test-dplyr-mutate.R
index ab37747458..79554059cb 100644
--- a/r/tests/testthat/test-dplyr-mutate.R
+++ b/r/tests/testthat/test-dplyr-mutate.R
@@ -74,6 +74,17 @@ test_that("transmute", {
)
})
+test_that("transmute after group_by", {
+ compare_dplyr_binding(
+ .input %>%
+ select(int, dbl, chr) %>%
+ group_by(chr, int) %>%
+ transmute(dbl + 1) %>%
+ collect(),
+ tbl
+ )
+})
+
test_that("transmute respect bespoke dplyr implementation", {
## see: https://github.com/tidyverse/dplyr/issues/6086
compare_dplyr_binding(
@@ -397,6 +408,15 @@ test_that("Can mutate after group_by as long as there are
no aggregations", {
collect(),
tbl
)
+ # Check the column order when .keep = "none"
+ compare_dplyr_binding(
+ .input %>%
+ select(chr, int) %>%
+ group_by(chr) %>%
+ mutate(int + 1, .keep = "none") %>%
+ collect(),
+ tbl
+ )
expect_warning(
tbl %>%
Table$create() %>%
@@ -652,3 +672,41 @@ test_that("Can use across() within transmute()", {
example_data
)
})
+
+test_that("across() does not select grouping variables within mutate()", {
+ compare_dplyr_binding(
+ .input %>%
+ select(int, dbl, chr) %>%
+ group_by(chr) %>%
+ mutate(across(everything(), round)) %>%
+ collect(),
+ example_data
+ )
+
+ expect_error(
+ example_data %>%
+ arrow_table() %>%
+ group_by(chr) %>%
+ mutate(across(chr, as.character)),
+ "Column `chr` doesn't exist"
+ )
+})
+
+test_that("across() does not select grouping variables within transmute()", {
+ compare_dplyr_binding(
+ .input %>%
+ select(int, dbl, chr) %>%
+ group_by(chr) %>%
+ transmute(across(everything(), round)) %>%
+ collect(),
+ example_data
+ )
+
+ expect_error(
+ example_data %>%
+ arrow_table() %>%
+ group_by(chr) %>%
+ transmute(across(chr, as.character)),
+ "Column `chr` doesn't exist"
+ )
+})
diff --git a/r/tests/testthat/test-dplyr-summarize.R
b/r/tests/testthat/test-dplyr-summarize.R
index 3eb1a6ed2b..09f50986d7 100644
--- a/r/tests/testthat/test-dplyr-summarize.R
+++ b/r/tests/testthat/test-dplyr-summarize.R
@@ -1173,3 +1173,24 @@ test_that("Can use across() within summarise()", {
regexp = "Expression int is not an aggregate expression or is not
supported in Arrow; pulling data into R"
)
})
+
+test_that("across() does not select grouping variables within summarise()", {
+ compare_dplyr_binding(
+ .input %>%
+ select(int, dbl, chr) %>%
+ group_by(chr) %>%
+ summarise(across(everything(), sum)) %>%
+ arrange(chr) %>%
+ collect(),
+ example_data
+ )
+
+ expect_error(
+ example_data %>%
+ select(int, dbl) %>%
+ arrow_table() %>%
+ group_by(int) %>%
+ summarise(across(int, sum)),
+ "Column `int` doesn't exist"
+ )
+})