This is an automated email from the ASF dual-hosted git repository.

thisisnic pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new d5f80cbe2b ARROW-11699: [R] Implement dplyr::across() for mutate()
d5f80cbe2b is described below

commit d5f80cbe2b2e8801127639b15fd24f829478ea84
Author: Nic Crane <[email protected]>
AuthorDate: Thu Sep 1 16:57:22 2022 +0100

    ARROW-11699: [R] Implement dplyr::across() for mutate()
    
    This PR introduces a partial implementation for `dplyr::across()` when 
called within `dplyr::mutate()`.
    
    ``` r
    arrow_table(iris) %>%
      mutate(across(starts_with("Sepal"), list(round, sqrt)))
    #> Table (query)
    #> Sepal.Length: double
    #> Sepal.Width: double
    #> Petal.Length: double
    #> Petal.Width: double
    #> Species: dictionary<values=string, indices=int8>
    #> Sepal.Length_1: double (round(Sepal.Length, {ndigits=0, 
round_mode=HALF_TO_EVEN}))
    #> Sepal.Length_2: double (sqrt_checked(Sepal.Length))
    #> Sepal.Width_1: double (round(Sepal.Width, {ndigits=0, 
round_mode=HALF_TO_EVEN}))
    #> Sepal.Width_2: double (sqrt_checked(Sepal.Width))
    #>
    #> See $.data for the source Arrow object
    ```
    
    I've opened a number of follow-up tickets for the tasks needed to be done 
to provide a more complete implementation:
    * [ARROW-17362: [R] Implement dplyr::across() inside 
summarise()](https://issues.apache.org/jira/browse/ARROW-17362)
    * [ARROW-17387: [R] Implement dplyr::across() inside 
filter()](https://issues.apache.org/jira/browse/ARROW-17387)
    * ~[ARROW-17364: [R] Implement .names argument inside 
across()](https://issues.apache.org/jira/browse/ARROW-17364)~ (now done in this 
PR, will close it once this is merged)
    * [ARROW-17366: [R] Support purrr-style lambda functions in .fns argument 
to across()](https://issues.apache.org/jira/browse/ARROW-17366)
    
    Closes #13786 from thisisnic/ARROW-11699_across
    
    Authored-by: Nic Crane <[email protected]>
    Signed-off-by: Nic Crane <[email protected]>
---
 r/DESCRIPTION                         |   2 +
 r/NAMESPACE                           |   7 ++
 r/R/arrow-package.R                   |   4 +-
 r/R/dplyr-across.R                    | 177 ++++++++++++++++++++++++++
 r/R/dplyr-mutate.R                    |   4 +-
 r/tests/testthat/helper-expectation.R |   4 +
 r/tests/testthat/test-dplyr-across.R  | 226 ++++++++++++++++++++++++++++++++++
 r/tests/testthat/test-dplyr-mutate.R  |  61 ++++++++-
 8 files changed, 479 insertions(+), 6 deletions(-)

diff --git a/r/DESCRIPTION b/r/DESCRIPTION
index 95c1405869..a728be3773 100644
--- a/r/DESCRIPTION
+++ b/r/DESCRIPTION
@@ -31,6 +31,7 @@ Biarch: true
 Imports:
     assertthat,
     bit64 (>= 0.9-7),
+    glue,
     methods,
     purrr,
     R6,
@@ -91,6 +92,7 @@ Collate:
     'dataset-scan.R'
     'dataset-write.R'
     'dictionary.R'
+    'dplyr-across.R'
     'dplyr-arrange.R'
     'dplyr-collect.R'
     'dplyr-count.R'
diff --git a/r/NAMESPACE b/r/NAMESPACE
index c4c18ba16d..49db309b8e 100644
--- a/r/NAMESPACE
+++ b/r/NAMESPACE
@@ -390,6 +390,7 @@ importFrom(assertthat,assert_that)
 importFrom(assertthat,is.string)
 importFrom(bit64,print.integer64)
 importFrom(bit64,str.integer64)
+importFrom(glue,glue)
 importFrom(methods,as)
 importFrom(purrr,as_mapper)
 importFrom(purrr,flatten)
@@ -413,6 +414,7 @@ importFrom(rlang,as_function)
 importFrom(rlang,as_label)
 importFrom(rlang,as_quosure)
 importFrom(rlang,call2)
+importFrom(rlang,call_args)
 importFrom(rlang,caller_env)
 importFrom(rlang,dots_n)
 importFrom(rlang,enexpr)
@@ -425,20 +427,25 @@ importFrom(rlang,eval_tidy)
 importFrom(rlang,exec)
 importFrom(rlang,expr)
 importFrom(rlang,is_bare_character)
+importFrom(rlang,is_call)
 importFrom(rlang,is_character)
 importFrom(rlang,is_empty)
 importFrom(rlang,is_false)
+importFrom(rlang,is_formula)
 importFrom(rlang,is_integerish)
 importFrom(rlang,is_interactive)
 importFrom(rlang,is_list)
 importFrom(rlang,is_quosure)
+importFrom(rlang,is_symbol)
 importFrom(rlang,list2)
 importFrom(rlang,new_data_mask)
 importFrom(rlang,new_environment)
 importFrom(rlang,quo_get_env)
 importFrom(rlang,quo_get_expr)
+importFrom(rlang,quo_is_call)
 importFrom(rlang,quo_is_null)
 importFrom(rlang,quo_name)
+importFrom(rlang,quo_set_env)
 importFrom(rlang,quo_set_expr)
 importFrom(rlang,quos)
 importFrom(rlang,seq2)
diff --git a/r/R/arrow-package.R b/r/R/arrow-package.R
index f3e0b817d5..e8aa93f953 100644
--- a/r/R/arrow-package.R
+++ b/r/R/arrow-package.R
@@ -23,8 +23,10 @@
 #' @importFrom rlang eval_tidy new_data_mask syms env new_environment env_bind 
set_names exec
 #' @importFrom rlang is_bare_character quo_get_expr quo_get_env quo_set_expr 
.data seq2 is_interactive
 #' @importFrom rlang expr caller_env is_character quo_name is_quosure enexpr 
enexprs as_quosure
-#' @importFrom rlang is_list call2 is_empty as_function as_label arg_match
+#' @importFrom rlang is_list call2 is_empty as_function as_label arg_match 
is_symbol is_call call_args
+#' @importFrom rlang quo_set_env quo_get_env is_formula quo_is_call
 #' @importFrom tidyselect vars_pull vars_rename vars_select eval_select
+#' @importFrom glue glue
 #' @useDynLib arrow, .registration = TRUE
 #' @keywords internal
 "_PACKAGE"
diff --git a/r/R/dplyr-across.R b/r/R/dplyr-across.R
new file mode 100644
index 0000000000..01a9262b81
--- /dev/null
+++ b/r/R/dplyr-across.R
@@ -0,0 +1,177 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+expand_across <- function(.data, quos_in) {
+  quos_out <- list()
+  # retrieve items by index (not by value) to preserve naming of quos other than 
across
+  for (quo_i in seq_along(quos_in)) {
+    quo_in <- quos_in[quo_i]
+    quo_expr <- quo_get_expr(quo_in[[1]])
+    quo_env <- quo_get_env(quo_in[[1]])
+
+    if (is_call(quo_expr, "across")) {
+      new_quos <- list()
+
+      across_call <- match.call(
+        definition = dplyr::across,
+        call = quo_expr,
+        expand.dots = FALSE,
+        envir = quo_env
+      )
+
+      if (!all(names(across_call[-1]) %in% c(".cols", ".fns", ".names"))) {
+        abort("`...` argument to `across()` is deprecated in dplyr and not 
supported in Arrow")
+      }
+
+      if (!is.null(across_call[[".cols"]])) {
+        cols <- across_call[[".cols"]]
+      } else {
+        cols <- quote(everything())
+      }
+
+      setup <- across_setup(
+        cols = !!as_quosure(cols, quo_env),
+        fns = across_call[[".fns"]],
+        names = across_call[[".names"]],
+        .caller_env = quo_env,
+        mask = .data,
+        inline = TRUE
+      )
+
+      if (!is_list(setup$fns) && !is.null(setup$fns) && 
as.character(setup$fns)[[1]] == "~") {
+        abort(
+          paste(
+            "purrr-style lambda functions as `.fns` argument to `across()`",
+            "not yet supported in Arrow"
+          )
+        )
+      }
+
+      new_quos <- quosures_from_setup(setup, quo_env)
+
+      quos_out <- append(quos_out, new_quos)
+    } else {
+      quos_out <- append(quos_out, quo_in)
+    }
+  }
+
+  quos_out
+}
+
+# given a named list of functions and column names, create a list of new 
quosures
+quosures_from_setup <- function(setup, quo_env) {
+  if (!is.null(setup$fns)) {
+    func_list_full <- rep(setup$fns, length(setup$vars))
+    cols_list_full <- rep(setup$vars, each = length(setup$fns))
+
+    # get new quosures
+    new_quo_list <- map2(
+      func_list_full, cols_list_full,
+      ~ quo(!!call2(.x, sym(.y)))
+    )
+  } else {
+    # if there are no functions, just map to the variables themselves
+    new_quo_list <- map(
+      setup$vars,
+      ~ quo(!!sym(.x))
+    )
+  }
+
+  quosures <- set_names(new_quo_list, setup$names)
+  map(quosures, ~ quo_set_env(.x, quo_env))
+}
+
+across_setup <- function(cols, fns, names, .caller_env, mask, inline = FALSE) {
+  cols <- enquo(cols)
+
+  vars <- names(dplyr::select(mask, !!cols))
+
+  if (is.null(fns)) {
+    if (!is.null(names)) {
+      glue_mask <- across_glue_mask(.caller_env, .col = vars, .fn = "1")
+      names <- vctrs::vec_as_names(glue::glue(names, .envir = glue_mask), 
repair = "check_unique")
+    } else {
+      names <- vars
+    }
+
+    value <- list(vars = vars, fns = fns, names = names)
+    return(value)
+  }
+
+  # apply `.names` smart default
+  if (is.function(fns) || is_formula(fns) || is.name(fns)) {
+    names <- names %||% "{.col}"
+    fns <- list("1" = fns)
+  } else {
+    names <- names %||% "{.col}_{.fn}"
+    fns <- call_args(fns)
+  }
+
+  if (any(map_lgl(fns, is_formula))) {
+    abort(
+      paste(
+        "purrr-style lambda functions as `.fns` argument to `across()`",
+        "not yet supported in Arrow"
+      )
+    )
+  }
+
+  if (!is.list(fns)) {
+    msg <- c("`.fns` must be NULL, a function, a formula, or a list of 
functions/formulas.")
+    abort(msg)
+  }
+
+  # make sure fns has names, use number to replace unnamed
+  if (is.null(names(fns))) {
+    names_fns <- seq_along(fns)
+  } else {
+    names_fns <- names(fns)
+    empties <- which(names_fns == "")
+    if (length(empties)) {
+      names_fns[empties] <- empties
+    }
+  }
+
+  glue_mask <- across_glue_mask(.caller_env,
+    .col = rep(vars, each = length(fns)),
+    .fn  = rep(names_fns, length(vars))
+  )
+  names <- vctrs::vec_as_names(glue::glue(names, .envir = glue_mask), repair = 
"check_unique")
+
+  if (!inline) {
+    fns <- map(fns, as_function)
+  }
+
+  # ensure .names argument has resulted in the expected number of names
+  if (length(names) != (length(vars) * length(fns))) {
+    abort(
+      c(
+        "`.names` specification must produce (number of columns * number of 
functions) names.",
+        x = paste0(
+          length(vars) * length(fns), " names required (", length(vars), " 
columns * ", length(fns), " functions)\n  ",
+          length(names), " name(s) produced: ", paste(names, collapse = ",")
+        )
+      )
+    )
+  }
+
+  list(vars = vars, fns = fns, names = names)
+}
+
+across_glue_mask <- function(.col, .fn, .caller_env) {
+  env(.caller_env, .col = .col, .fn = .fn, col = .col, fn = .fn)
+}
diff --git a/r/R/dplyr-mutate.R b/r/R/dplyr-mutate.R
index 653c1e6f25..ac555fafe0 100644
--- a/r/R/dplyr-mutate.R
+++ b/r/R/dplyr-mutate.R
@@ -24,7 +24,9 @@ mutate.arrow_dplyr_query <- function(.data,
                                      .before = NULL,
                                      .after = NULL) {
   call <- match.call()
-  exprs <- ensure_named_exprs(quos(...))
+
+  expression_list <- expand_across(.data, quos(...))
+  exprs <- ensure_named_exprs(expression_list)
 
   .keep <- match.arg(.keep)
   .before <- enquo(.before)
diff --git a/r/tests/testthat/helper-expectation.R 
b/r/tests/testthat/helper-expectation.R
index eb2e6b0219..ba11700ab6 100644
--- a/r/tests/testthat/helper-expectation.R
+++ b/r/tests/testthat/helper-expectation.R
@@ -321,3 +321,7 @@ split_vector_as_list <- function(vec) {
   vec2 <- vec[seq(from = min(length(vec), vec_split + 1), to = length(vec), by 
= 1)]
   list(vec1, vec2)
 }
+
+expect_across_equal <- function(actual, expected, tbl) {
+  expect_identical(expand_across(tbl, actual), as.list(expected))
+}
diff --git a/r/tests/testthat/test-dplyr-across.R 
b/r/tests/testthat/test-dplyr-across.R
new file mode 100644
index 0000000000..8945c2a5f3
--- /dev/null
+++ b/r/tests/testthat/test-dplyr-across.R
@@ -0,0 +1,226 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+library(dplyr, warn.conflicts = FALSE)
+
+test_that("expand_across correctly expands quosures", {
+
+  # single unnamed function
+  expect_across_equal(
+    quos(across(c(dbl, dbl2), round)),
+    quos(
+      dbl = round(dbl),
+      dbl2 = round(dbl2)
+    ),
+    example_data
+  )
+
+  # multiple unnamed functions
+  expect_across_equal(
+    quos(across(c(dbl, dbl2), list(exp, sqrt))),
+    quos(
+      dbl_1 = exp(dbl),
+      dbl_2 = sqrt(dbl),
+      dbl2_1 = exp(dbl2),
+      dbl2_2 = sqrt(dbl2)
+    ),
+    example_data
+  )
+
+  # single named function
+  expect_across_equal(
+    quos(across(c(dbl, dbl2), list("fun1" = round))),
+    quos(
+      dbl_fun1 = round(dbl),
+      dbl2_fun1 = round(dbl2)
+    ),
+    example_data
+  )
+
+  # multiple named functions
+  expect_across_equal(
+    quos(across(c(dbl, dbl2), list("fun1" = round, "fun2" = sqrt))),
+    quos(
+      dbl_fun1 = round(dbl),
+      dbl_fun2 = sqrt(dbl),
+      dbl2_fun1 = round(dbl2),
+      dbl2_fun2 = sqrt(dbl2)
+    ),
+    example_data
+  )
+
+  # mix of named and unnamed functions
+  expect_across_equal(
+    quos(across(c(dbl, dbl2), list(round, "fun2" = sqrt))),
+    quos(
+      dbl_1 = round(dbl),
+      dbl_fun2 = sqrt(dbl),
+      dbl2_1 = round(dbl2),
+      dbl2_fun2 = sqrt(dbl2)
+    ),
+    example_data
+  )
+
+  # across() with no functions returns columns unchanged
+  expect_across_equal(
+    quos(across(starts_with("dbl"))),
+    quos(
+      dbl = dbl,
+      dbl2 = dbl2
+    ),
+    example_data
+  )
+
+  # across() arguments not in default order
+  expect_across_equal(
+    quos(across(.fns = round, c(dbl, dbl2))),
+    quos(
+      dbl = round(dbl),
+      dbl2 = round(dbl2)
+    ),
+    example_data
+  )
+
+  # across() with no columns named
+  expect_across_equal(
+    quos(across(.fns = round)),
+    quos(
+      int = round(int),
+      dbl = round(dbl),
+      dbl2 = round(dbl2)
+    ),
+    example_data %>% select(int, dbl, dbl2)
+  )
+
+  # column selection via dynamic variable name
+  int <- c("dbl", "dbl2")
+  expect_across_equal(
+    quos(across(all_of(int), sqrt)),
+    quos(
+      dbl = sqrt(dbl),
+      dbl2 = sqrt(dbl2)
+    ),
+    example_data
+  )
+
+  # ellipses (...) are a deprecated argument
+  expect_error(
+    expand_across(
+      example_data,
+      quos(across(c(dbl, dbl2), round, digits = -1))
+    ),
+    regexp = "`...` argument to `across()` is deprecated in dplyr and not 
supported in Arrow",
+    fixed = TRUE
+  )
+
+  # alternative ways of specifying .fns - as a list
+  expect_across_equal(
+    quos(across(1:dbl2, list(round))),
+    quos(
+      int_1 = round(int),
+      dbl_1 = round(dbl),
+      dbl2_1 = round(dbl2)
+    ),
+    example_data
+  )
+
+  # supply .fns as a one-item vector
+  expect_across_equal(
+    quos(across(1:dbl2, c(round))),
+    quos(
+      int_1 = round(int),
+      dbl_1 = round(dbl),
+      dbl2_1 = round(dbl2)
+    ),
+    example_data
+  )
+
+  # ARROW-17366: purrr-style lambda functions not yet supported
+  expect_error(
+    expand_across(
+      example_data,
+      quos(across(1:dbl2, ~ round(.x, digits = -1)))
+    ),
+    regexp = "purrr-style lambda functions as `.fns` argument to `across()` 
not yet supported in Arrow",
+    fixed = TRUE
+  )
+
+  # .names argument
+  expect_across_equal(
+    quos(across(c(dbl, dbl2), round, .names = "{.col}.{.fn}")),
+    quos(
+      dbl.1 = round(dbl),
+      dbl2.1 = round(dbl2)
+    ),
+    example_data
+  )
+
+  # names argument with custom text
+  expect_across_equal(
+    quos(across(c(dbl, dbl2), round, .names = "round_{.col}")),
+    quos(
+      round_dbl = round(dbl),
+      round_dbl2 = round(dbl2)
+    ),
+    example_data
+  )
+
+  # names argument supplied but no functions
+  expect_across_equal(
+    quos(across(starts_with("dbl"), .names = "new_{.col}")),
+    quos(
+      new_dbl = dbl,
+      new_dbl2 = dbl2
+    ),
+    example_data
+  )
+
+  # .names argument and functions named
+  expect_across_equal(
+    quos(across(c(dbl, dbl2), list("my_round" = round, "my_exp" = exp), .names 
= "{.col}.{.fn}")),
+    quos(
+      dbl.my_round = round(dbl),
+      dbl.my_exp = exp(dbl),
+      dbl2.my_round = round(dbl2),
+      dbl2.my_exp = exp(dbl2)
+    ),
+    example_data
+  )
+
+  # .names argument and mix of named and unnamed functions
+  expect_across_equal(
+    quos(across(c(dbl, dbl2), list(round, "my_exp" = exp), .names = 
"{.col}.{.fn}")),
+    quos(
+      dbl.1 = round(dbl),
+      dbl.my_exp = exp(dbl),
+      dbl2.1 = round(dbl2),
+      dbl2.my_exp = exp(dbl2)
+    ),
+    example_data
+  )
+
+  # dodgy .names specification
+  expect_error(
+    expand_across(
+      example_data,
+      quos(across(c(dbl, dbl2), list(round, "my_exp" = exp), .names = "zarg"))
+    ),
+    regexp = "`.names` specification must produce (number of columns * number 
of functions) names.",
+    fixed = TRUE
+  )
+
+})
diff --git a/r/tests/testthat/test-dplyr-mutate.R 
b/r/tests/testthat/test-dplyr-mutate.R
index 66e3b4edf0..f1de5c7045 100644
--- a/r/tests/testthat/test-dplyr-mutate.R
+++ b/r/tests/testthat/test-dplyr-mutate.R
@@ -279,14 +279,13 @@ test_that("dplyr::mutate's examples", {
   # Examples we don't support should succeed
   # but warn that they're pulling data into R to do so
 
-  # across and autosplicing: ARROW-11699
+  # test modified from version in dplyr::mutate due to ARROW-12632
   compare_dplyr_binding(
     .input %>%
-      select(name, homeworld, species) %>%
-      mutate(across(!name, as.factor)) %>%
+      select(name, height, mass) %>%
+      mutate(across(!name, as.character)) %>%
       collect(),
     starwars,
-    warning = "Expression across.*not supported in Arrow"
   )
 
   # group_by then mutate
@@ -589,3 +588,57 @@ test_that("mutate() and transmute() with namespaced 
functions", {
     tbl
   )
 })
+
+test_that("Can use across() within mutate()", {
+
+  # expressions work in the right order
+  compare_dplyr_binding(
+    .input %>%
+      mutate(
+        dbl2 = dbl * 2,
+        across(c(dbl, dbl2), round),
+        int2 = int * 2,
+        dbl = dbl + 3
+      ) %>%
+      collect(),
+    example_data
+  )
+
+  # this is valid in neither R nor Arrow
+  expect_error(
+    expect_warning(
+      compare_dplyr_binding(
+        .input %>%
+          arrow_table() %>%
+          mutate(across(c(dbl, dbl2), list("fun1" = round(sqrt(dbl))))) %>%
+          collect(),
+        example_data,
+        warning = TRUE
+      )
+    )
+  )
+
+  # ARROW-12778 - `where()` is not yet supported
+  expect_error(
+    compare_dplyr_binding(
+      .input %>%
+        mutate(across(where(is.double))) %>%
+        collect(),
+      example_data
+    ),
+    "Unsupported selection helper"
+  )
+
+  # gives the right error with window functions
+  expect_warning(
+    arrow_table(example_data) %>%
+      mutate(
+        x = int + 2,
+        across(c("int", "dbl"), list(mean = mean, sd = sd, round)),
+        exp(dbl2)
+      ) %>%
+      collect(),
+    "window functions not currently supported in Arrow; pulling data into R",
+    fixed = TRUE
+  )
+})

Reply via email to