This is an automated email from the ASF dual-hosted git repository.

thisisnic pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/main by this push:
     new 6bd0050811 GH-35445: [R] Behavior something like group_by(foo) |> 
across(everything()) is different from dplyr (#35473)
6bd0050811 is described below

commit 6bd00508116edea5afcdc4e3e11cd9fa789b70a3
Author: eitsupi <[email protected]>
AuthorDate: Thu May 18 22:11:05 2023 +0900

    GH-35445: [R] Behavior something like group_by(foo) |> across(everything()) 
is different from dplyr (#35473)
    
    ### Rationale for this change
    
    The argument `.cols` of the `dplyr::across` function has the following 
description.
    
    > You can't select grouping columns because they are already automatically 
handled by the verb (i.e. summarise() or mutate()).
    
    However, the `arrow` package does not currently reproduce this behavior: 
an error occurs when `everything()` selects a column that is used for 
grouping.
    
    ``` r
    mtcars |>
      arrow::as_arrow_table() |>
      dplyr::group_by(cyl) |>
      dplyr::summarise(dplyr::across(everything(), sum)) |>
      dplyr::collect()
    #> Error in `compute.arrow_dplyr_query()`:
    #> ! Invalid: Multiple matches for FieldRef.Name(cyl) in mpg: double
    #> cyl: double
    #> disp: double
    #> hp: double
    #> drat: double
    #> wt: double
    #> qsec: double
    #> vs: double
    #> am: double
    #> gear: double
    #> carb: double
    #> cyl: double
    #> Backtrace:
    #>     ▆
    #>  1. ├─dplyr::collect(...)
    #>  2. └─arrow:::collect.arrow_dplyr_query(...)
    #>  3.   └─arrow:::compute.arrow_dplyr_query(x)
    #>  4.     └─base::tryCatch(...)
    #>  5.       └─base (local) tryCatchList(expr, classes, parentenv, handlers)
    #>  6.         └─base (local) tryCatchOne(expr, names, parentenv, 
handlers[[1L]])
    #>  7.           └─value[[3L]](cond)
    #>  8.             └─arrow:::augment_io_error_msg(e, call, schema = 
schema())
    #>  9.               └─rlang::abort(msg, call = call)
    ```
    
    <sup>Created on 2023-05-05 with [reprex 
v2.0.2](https://reprex.tidyverse.org)</sup>
    
    This PR changes the behavior of `arrow` to match that of dplyr.
    
    ### What changes are included in this PR?
    
    - Automatically exclude grouping columns from `across` in `mutate`, 
`transmute`, and `summarise`.
    - The `.data` argument of the internal function `expand_across` must now be 
an `arrow_dplyr_query`.
      Some tests have been slightly modified to accommodate this change.
    - `mutate`, `transmute`, `arrange`, and `filter` now always return an 
`arrow_dplyr_query`.
      Before this change, an `arrow_dplyr_query` was not returned in cases such 
as the following, which was inconsistent.
      ```r
      mtcars |> arrow::arrow_table() |> dplyr::mutate()
      ```
    - Correct the order of columns in results of `group_by(foo) |> mutate(.keep 
= "none")`
      Before this change, the following query returned the grouping columns at 
the end of the result, which differs from the behavior of dplyr.
      ```r
      mtcars |> arrow::arrow_table() |> dplyr::group_by(cyl) |> 
dplyr::mutate(am, .keep = "none") |> dplyr::collect()
      ```
    - Correct the order of columns in results of `group_by(foo) |> transmute()`
      Before this change, the following query returned the grouping columns at 
the end of the result, which differs from the behavior of dplyr.
      ```r
      mtcars |> arrow::arrow_table() |> dplyr::group_by(cyl) |> 
dplyr::transmute(mpg) |> dplyr::collect()
      ```
      After `transmute`, the grouping columns should be moved to the front. 
(This differs from `mutate(.keep = "none")`, which keeps them in their original 
positions.)
    
    ### Are these changes tested?
    
    Yes.
    
    ### Are there any user-facing changes?
    
    Yes.
    
    * Closes: #35445
    
    Authored-by: SHIMA Tatsuya <[email protected]>
    Signed-off-by: Nic Crane <[email protected]>
---
 r/R/dplyr-across.R                      | 10 +++---
 r/R/dplyr-arrange.R                     |  1 +
 r/R/dplyr-filter.R                      |  1 +
 r/R/dplyr-mutate.R                      | 14 ++++----
 r/R/dplyr-summarize.R                   |  2 +-
 r/tests/testthat/test-dplyr-across.R    |  4 +--
 r/tests/testthat/test-dplyr-mutate.R    | 58 +++++++++++++++++++++++++++++++++
 r/tests/testthat/test-dplyr-summarize.R | 21 ++++++++++++
 8 files changed, 98 insertions(+), 13 deletions(-)

diff --git a/r/R/dplyr-across.R b/r/R/dplyr-across.R
index 5b816a0719..da61353b22 100644
--- a/r/R/dplyr-across.R
+++ b/r/R/dplyr-across.R
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 
-expand_across <- function(.data, quos_in) {
+expand_across <- function(.data, quos_in, exclude_cols = NULL) {
   quos_out <- list()
   # retrieve items using their values to preserve naming of quos other than 
across
   for (quo_i in seq_along(quos_in)) {
@@ -49,7 +49,8 @@ expand_across <- function(.data, quos_in) {
         names = across_call[[".names"]],
         .caller_env = quo_env,
         mask = .data,
-        inline = TRUE
+        inline = TRUE,
+        exclude_cols = exclude_cols
       )
 
       new_quos <- quosures_from_setup(setup, quo_env)
@@ -106,10 +107,11 @@ quosures_from_setup <- function(setup, quo_env) {
   set_names(new_quo_list, setup$names)
 }
 
-across_setup <- function(cols, fns, names, .caller_env, mask, inline = FALSE) {
+across_setup <- function(cols, fns, names, .caller_env, mask, inline = FALSE, 
exclude_cols = NULL) {
   cols <- enquo(cols)
 
-  vars <- names(dplyr::select(mask, !!cols))
+  sim_df <- dplyr::select(as.data.frame(implicit_schema(mask)), 
!(!!exclude_cols))
+  vars <- names(dplyr::select(sim_df, !!cols))
 
   if (is.null(fns)) {
     if (!is.null(names)) {
diff --git a/r/R/dplyr-arrange.R b/r/R/dplyr-arrange.R
index 39388394d5..e3e20f2cb3 100644
--- a/r/R/dplyr-arrange.R
+++ b/r/R/dplyr-arrange.R
@@ -20,6 +20,7 @@
 
 arrange.arrow_dplyr_query <- function(.data, ..., .by_group = FALSE) {
   call <- match.call()
+  .data <- as_adq(.data)
   exprs <- expand_across(.data, quos(...))
 
   if (.by_group) {
diff --git a/r/R/dplyr-filter.R b/r/R/dplyr-filter.R
index 1ef2b6d7e5..a864f1e9ce 100644
--- a/r/R/dplyr-filter.R
+++ b/r/R/dplyr-filter.R
@@ -19,6 +19,7 @@
 # The following S3 methods are registered on load if dplyr is present
 
 filter.arrow_dplyr_query <- function(.data, ..., .preserve = FALSE) {
+  .data <- as_adq(.data)
   # TODO something with the .preserve argument
   filts <- expand_across(.data, quos(...))
   if (length(filts) == 0) {
diff --git a/r/R/dplyr-mutate.R b/r/R/dplyr-mutate.R
index b75de1e4db..638a4566ab 100644
--- a/r/R/dplyr-mutate.R
+++ b/r/R/dplyr-mutate.R
@@ -24,8 +24,10 @@ mutate.arrow_dplyr_query <- function(.data,
                                      .before = NULL,
                                      .after = NULL) {
   call <- match.call()
+  .data <- as_adq(.data)
+  grv <- .data$group_by_vars
 
-  expression_list <- expand_across(.data, quos(...))
+  expression_list <- expand_across(.data, quos(...), exclude_cols = grv)
   exprs <- ensure_named_exprs(expression_list)
 
   .keep <- match.arg(.keep)
@@ -37,8 +39,6 @@ mutate.arrow_dplyr_query <- function(.data,
     return(.data)
   }
 
-  .data <- as_adq(.data)
-
   # Restrict the cases we support for now
   has_aggregations <- any(unlist(lapply(exprs, all_funs)) %in% 
names(agg_funcs))
   if (has_aggregations) {
@@ -86,7 +86,7 @@ mutate.arrow_dplyr_query <- function(.data,
   }
 
   # Deduplicate new_vars and remove NULL columns from new_vars
-  new_vars <- intersect(new_vars, names(.data$selected_columns))
+  new_vars <- intersect(union(new_vars, grv), names(.data$selected_columns))
 
   # Respect .before and .after
   if (!quo_is_null(.before) || !quo_is_null(.after)) {
@@ -117,7 +117,9 @@ mutate.Dataset <- mutate.ArrowTabular <- 
mutate.RecordBatchReader <- mutate.arro
 
 transmute.arrow_dplyr_query <- function(.data, ...) {
   dots <- check_transmute_args(...)
-  expression_list <- expand_across(.data, dots)
+  .data <- as_adq(.data)
+  grv <- .data$group_by_vars
+  expression_list <- expand_across(.data, dots, exclude_cols = grv)
 
   has_null <- map_lgl(expression_list, quo_is_null)
   .data <- dplyr::mutate(.data, !!!expression_list, .keep = "none")
@@ -129,7 +131,7 @@ transmute.arrow_dplyr_query <- function(.data, ...) {
   cur_exprs <- map_chr(expression_list, as_label)
   transmute_order <- names(cur_exprs)
   transmute_order[!nzchar(transmute_order)] <- 
cur_exprs[!nzchar(transmute_order)]
-  dplyr::select(.data, all_of(transmute_order))
+  dplyr::select(.data, all_of(c(grv, transmute_order)))
 }
 transmute.Dataset <- transmute.ArrowTabular <- transmute.RecordBatchReader <- 
transmute.arrow_dplyr_query
 
diff --git a/r/R/dplyr-summarize.R b/r/R/dplyr-summarize.R
index 5d943633a8..c02b6ee522 100644
--- a/r/R/dplyr-summarize.R
+++ b/r/R/dplyr-summarize.R
@@ -172,7 +172,7 @@ agg_funcs[["::"]] <- function(lhs, rhs) {
 summarise.arrow_dplyr_query <- function(.data, ..., .groups = NULL) {
   call <- match.call()
   .data <- as_adq(.data)
-  exprs <- expand_across(.data, quos(...))
+  exprs <- expand_across(.data, quos(...), exclude_cols = .data$group_by_vars)
   # Only retain the columns we need to do our aggregations
   vars_to_keep <- unique(c(
     unlist(lapply(exprs, all.vars)), # vars referenced in summarise
diff --git a/r/tests/testthat/test-dplyr-across.R 
b/r/tests/testthat/test-dplyr-across.R
index edf74dcbdb..eebb8a23ea 100644
--- a/r/tests/testthat/test-dplyr-across.R
+++ b/r/tests/testthat/test-dplyr-across.R
@@ -120,7 +120,7 @@ test_that("expand_across correctly expands quosures", {
   # ellipses (...) are a deprecated argument
   expect_error(
     expand_across(
-      example_data,
+      as_adq(example_data),
       quos(across(c(dbl, dbl2), round, digits = -1))
     ),
     regexp = "`...` argument to `across()` is deprecated in dplyr and not 
supported in Arrow",
@@ -206,7 +206,7 @@ test_that("expand_across correctly expands quosures", {
   # dodgy .names specification
   expect_error(
     expand_across(
-      example_data,
+      as_adq(example_data),
       quos(across(c(dbl, dbl2), list(round, "my_exp" = exp), .names = "zarg"))
     ),
     regexp = "`.names` specification must produce (number of columns * number 
of functions) names.",
diff --git a/r/tests/testthat/test-dplyr-mutate.R 
b/r/tests/testthat/test-dplyr-mutate.R
index ab37747458..79554059cb 100644
--- a/r/tests/testthat/test-dplyr-mutate.R
+++ b/r/tests/testthat/test-dplyr-mutate.R
@@ -74,6 +74,17 @@ test_that("transmute", {
   )
 })
 
+test_that("transmute after group_by", {
+  compare_dplyr_binding(
+    .input %>%
+      select(int, dbl, chr) %>%
+      group_by(chr, int) %>%
+      transmute(dbl + 1) %>%
+      collect(),
+    tbl
+  )
+})
+
 test_that("transmute respect bespoke dplyr implementation", {
   ## see: https://github.com/tidyverse/dplyr/issues/6086
   compare_dplyr_binding(
@@ -397,6 +408,15 @@ test_that("Can mutate after group_by as long as there are 
no aggregations", {
       collect(),
     tbl
   )
+  # Check the column order when .keep = "none"
+  compare_dplyr_binding(
+    .input %>%
+      select(chr, int) %>%
+      group_by(chr) %>%
+      mutate(int + 1, .keep = "none") %>%
+      collect(),
+    tbl
+  )
   expect_warning(
     tbl %>%
       Table$create() %>%
@@ -652,3 +672,41 @@ test_that("Can use across() within transmute()", {
     example_data
   )
 })
+
+test_that("across() does not select grouping variables within mutate()", {
+  compare_dplyr_binding(
+    .input %>%
+      select(int, dbl, chr) %>%
+      group_by(chr) %>%
+      mutate(across(everything(), round)) %>%
+      collect(),
+    example_data
+  )
+
+  expect_error(
+    example_data %>%
+      arrow_table() %>%
+      group_by(chr) %>%
+      mutate(across(chr, as.character)),
+    "Column `chr` doesn't exist"
+  )
+})
+
+test_that("across() does not select grouping variables within transmute()", {
+  compare_dplyr_binding(
+    .input %>%
+      select(int, dbl, chr) %>%
+      group_by(chr) %>%
+      transmute(across(everything(), round)) %>%
+      collect(),
+    example_data
+  )
+
+  expect_error(
+    example_data %>%
+      arrow_table() %>%
+      group_by(chr) %>%
+      transmute(across(chr, as.character)),
+    "Column `chr` doesn't exist"
+  )
+})
diff --git a/r/tests/testthat/test-dplyr-summarize.R 
b/r/tests/testthat/test-dplyr-summarize.R
index 3eb1a6ed2b..09f50986d7 100644
--- a/r/tests/testthat/test-dplyr-summarize.R
+++ b/r/tests/testthat/test-dplyr-summarize.R
@@ -1173,3 +1173,24 @@ test_that("Can use across() within summarise()", {
     regexp = "Expression int is not an aggregate expression or is not 
supported in Arrow; pulling data into R"
   )
 })
+
+test_that("across() does not select grouping variables within summarise()", {
+  compare_dplyr_binding(
+    .input %>%
+      select(int, dbl, chr) %>%
+      group_by(chr) %>%
+      summarise(across(everything(), sum)) %>%
+      arrange(chr) %>%
+      collect(),
+    example_data
+  )
+
+  expect_error(
+    example_data %>%
+      select(int, dbl) %>%
+      arrow_table() %>%
+      group_by(int) %>%
+      summarise(across(int, sum)),
+    "Column `int` doesn't exist"
+  )
+})

Reply via email to