This is an automated email from the ASF dual-hosted git repository.
thisisnic pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new d5f80cbe2b ARROW-11699: [R] Implement dplyr::across() for mutate()
d5f80cbe2b is described below
commit d5f80cbe2b2e8801127639b15fd24f829478ea84
Author: Nic Crane <[email protected]>
AuthorDate: Thu Sep 1 16:57:22 2022 +0100
ARROW-11699: [R] Implement dplyr::across() for mutate()
This PR introduces a partial implementation for `dplyr::across()` when
called within `dplyr::mutate()`.
``` r
arrow_table(iris) %>%
mutate(across(starts_with("Sepal"), list(round, sqrt)))
#> Table (query)
#> Sepal.Length: double
#> Sepal.Width: double
#> Petal.Length: double
#> Petal.Width: double
#> Species: dictionary<values=string, indices=int8>
#> Sepal.Length_1: double (round(Sepal.Length, {ndigits=0,
round_mode=HALF_TO_EVEN}))
#> Sepal.Length_2: double (sqrt_checked(Sepal.Length))
#> Sepal.Width_1: double (round(Sepal.Width, {ndigits=0,
round_mode=HALF_TO_EVEN}))
#> Sepal.Width_2: double (sqrt_checked(Sepal.Width))
#>
#> See $.data for the source Arrow object
```
I've opened a number of follow-up tickets for the tasks needed to be done
to provide a more complete implementation:
* [ARROW-17362: [R] Implement dplyr::across() inside
summarise()](https://issues.apache.org/jira/browse/ARROW-17362)
* [ARROW-17387: [R] Implement dplyr::across() inside
filter()](https://issues.apache.org/jira/browse/ARROW-17387)
* ~[ARROW-17364: [R] Implement .names argument inside
across()](https://issues.apache.org/jira/browse/ARROW-17364)~ (now done in this
PR, will close it once this is merged)
* [ARROW-17366: [R] Support purrr-style lambda functions in .fns argument
to across()](https://issues.apache.org/jira/browse/ARROW-17366)
Closes #13786 from thisisnic/ARROW-11699_across
Authored-by: Nic Crane <[email protected]>
Signed-off-by: Nic Crane <[email protected]>
---
r/DESCRIPTION | 2 +
r/NAMESPACE | 7 ++
r/R/arrow-package.R | 4 +-
r/R/dplyr-across.R | 177 ++++++++++++++++++++++++++
r/R/dplyr-mutate.R | 4 +-
r/tests/testthat/helper-expectation.R | 4 +
r/tests/testthat/test-dplyr-across.R | 226 ++++++++++++++++++++++++++++++++++
r/tests/testthat/test-dplyr-mutate.R | 61 ++++++++-
8 files changed, 479 insertions(+), 6 deletions(-)
diff --git a/r/DESCRIPTION b/r/DESCRIPTION
index 95c1405869..a728be3773 100644
--- a/r/DESCRIPTION
+++ b/r/DESCRIPTION
@@ -31,6 +31,7 @@ Biarch: true
Imports:
assertthat,
bit64 (>= 0.9-7),
+ glue,
methods,
purrr,
R6,
@@ -91,6 +92,7 @@ Collate:
'dataset-scan.R'
'dataset-write.R'
'dictionary.R'
+ 'dplyr-across.R'
'dplyr-arrange.R'
'dplyr-collect.R'
'dplyr-count.R'
diff --git a/r/NAMESPACE b/r/NAMESPACE
index c4c18ba16d..49db309b8e 100644
--- a/r/NAMESPACE
+++ b/r/NAMESPACE
@@ -390,6 +390,7 @@ importFrom(assertthat,assert_that)
importFrom(assertthat,is.string)
importFrom(bit64,print.integer64)
importFrom(bit64,str.integer64)
+importFrom(glue,glue)
importFrom(methods,as)
importFrom(purrr,as_mapper)
importFrom(purrr,flatten)
@@ -413,6 +414,7 @@ importFrom(rlang,as_function)
importFrom(rlang,as_label)
importFrom(rlang,as_quosure)
importFrom(rlang,call2)
+importFrom(rlang,call_args)
importFrom(rlang,caller_env)
importFrom(rlang,dots_n)
importFrom(rlang,enexpr)
@@ -425,20 +427,25 @@ importFrom(rlang,eval_tidy)
importFrom(rlang,exec)
importFrom(rlang,expr)
importFrom(rlang,is_bare_character)
+importFrom(rlang,is_call)
importFrom(rlang,is_character)
importFrom(rlang,is_empty)
importFrom(rlang,is_false)
+importFrom(rlang,is_formula)
importFrom(rlang,is_integerish)
importFrom(rlang,is_interactive)
importFrom(rlang,is_list)
importFrom(rlang,is_quosure)
+importFrom(rlang,is_symbol)
importFrom(rlang,list2)
importFrom(rlang,new_data_mask)
importFrom(rlang,new_environment)
importFrom(rlang,quo_get_env)
importFrom(rlang,quo_get_expr)
+importFrom(rlang,quo_is_call)
importFrom(rlang,quo_is_null)
importFrom(rlang,quo_name)
+importFrom(rlang,quo_set_env)
importFrom(rlang,quo_set_expr)
importFrom(rlang,quos)
importFrom(rlang,seq2)
diff --git a/r/R/arrow-package.R b/r/R/arrow-package.R
index f3e0b817d5..e8aa93f953 100644
--- a/r/R/arrow-package.R
+++ b/r/R/arrow-package.R
@@ -23,8 +23,10 @@
#' @importFrom rlang eval_tidy new_data_mask syms env new_environment env_bind
set_names exec
#' @importFrom rlang is_bare_character quo_get_expr quo_get_env quo_set_expr
.data seq2 is_interactive
#' @importFrom rlang expr caller_env is_character quo_name is_quosure enexpr
enexprs as_quosure
-#' @importFrom rlang is_list call2 is_empty as_function as_label arg_match
+#' @importFrom rlang is_list call2 is_empty as_function as_label arg_match
is_symbol is_call call_args
+#' @importFrom rlang quo_set_env quo_get_env is_formula quo_is_call
#' @importFrom tidyselect vars_pull vars_rename vars_select eval_select
+#' @importFrom glue glue
#' @useDynLib arrow, .registration = TRUE
#' @keywords internal
"_PACKAGE"
diff --git a/r/R/dplyr-across.R b/r/R/dplyr-across.R
new file mode 100644
index 0000000000..01a9262b81
--- /dev/null
+++ b/r/R/dplyr-across.R
@@ -0,0 +1,177 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+expand_across <- function(.data, quos_in) {
+ quos_out <- list()
+ # retrieve items using their values to preserve naming of quos other than
across
+ for (quo_i in seq_along(quos_in)) {
+ quo_in <- quos_in[quo_i]
+ quo_expr <- quo_get_expr(quo_in[[1]])
+ quo_env <- quo_get_env(quo_in[[1]])
+
+ if (is_call(quo_expr, "across")) {
+ new_quos <- list()
+
+ across_call <- match.call(
+ definition = dplyr::across,
+ call = quo_expr,
+ expand.dots = FALSE,
+ envir = quo_env
+ )
+
+ if (!all(names(across_call[-1]) %in% c(".cols", ".fns", ".names"))) {
+ abort("`...` argument to `across()` is deprecated in dplyr and not
supported in Arrow")
+ }
+
+ if (!is.null(across_call[[".cols"]])) {
+ cols <- across_call[[".cols"]]
+ } else {
+ cols <- quote(everything())
+ }
+
+ setup <- across_setup(
+ cols = !!as_quosure(cols, quo_env),
+ fns = across_call[[".fns"]],
+ names = across_call[[".names"]],
+ .caller_env = quo_env,
+ mask = .data,
+ inline = TRUE
+ )
+
+ if (!is_list(setup$fns) && !is.null(setup$fns) &&
as.character(setup$fns)[[1]] == "~") {
+ abort(
+ paste(
+ "purrr-style lambda functions as `.fns` argument to `across()`",
+ "not yet supported in Arrow"
+ )
+ )
+ }
+
+ new_quos <- quosures_from_setup(setup, quo_env)
+
+ quos_out <- append(quos_out, new_quos)
+ } else {
+ quos_out <- append(quos_out, quo_in)
+ }
+ }
+
+ quos_out
+}
+
+# given a named list of functions and column names, create a list of new
quosures
+quosures_from_setup <- function(setup, quo_env) {
+ if (!is.null(setup$fns)) {
+ func_list_full <- rep(setup$fns, length(setup$vars))
+ cols_list_full <- rep(setup$vars, each = length(setup$fns))
+
+ # get new quosures
+ new_quo_list <- map2(
+ func_list_full, cols_list_full,
+ ~ quo(!!call2(.x, sym(.y)))
+ )
+ } else {
+ # if there are no functions, just map to the variables themselves
+ new_quo_list <- map(
+ setup$vars,
+ ~ quo(!!sym(.x))
+ )
+ }
+
+ quosures <- set_names(new_quo_list, setup$names)
+ map(quosures, ~ quo_set_env(.x, quo_env))
+}
+
+across_setup <- function(cols, fns, names, .caller_env, mask, inline = FALSE) {
+ cols <- enquo(cols)
+
+ vars <- names(dplyr::select(mask, !!cols))
+
+ if (is.null(fns)) {
+ if (!is.null(names)) {
+ glue_mask <- across_glue_mask(.caller_env, .col = vars, .fn = "1")
+ names <- vctrs::vec_as_names(glue::glue(names, .envir = glue_mask),
repair = "check_unique")
+ } else {
+ names <- vars
+ }
+
+ value <- list(vars = vars, fns = fns, names = names)
+ return(value)
+ }
+
+ # apply `.names` smart default
+ if (is.function(fns) || is_formula(fns) || is.name(fns)) {
+ names <- names %||% "{.col}"
+ fns <- list("1" = fns)
+ } else {
+ names <- names %||% "{.col}_{.fn}"
+ fns <- call_args(fns)
+ }
+
+ if (any(map_lgl(fns, is_formula))) {
+ abort(
+ paste(
+ "purrr-style lambda functions as `.fns` argument to `across()`",
+ "not yet supported in Arrow"
+ )
+ )
+ }
+
+ if (!is.list(fns)) {
+ msg <- c("`.fns` must be NULL, a function, a formula, or a list of
functions/formulas.")
+ abort(msg)
+ }
+
+ # make sure fns has names; use the position number for unnamed functions
+ if (is.null(names(fns))) {
+ names_fns <- seq_along(fns)
+ } else {
+ names_fns <- names(fns)
+ empties <- which(names_fns == "")
+ if (length(empties)) {
+ names_fns[empties] <- empties
+ }
+ }
+
+ glue_mask <- across_glue_mask(.caller_env,
+ .col = rep(vars, each = length(fns)),
+ .fn = rep(names_fns, length(vars))
+ )
+ names <- vctrs::vec_as_names(glue::glue(names, .envir = glue_mask), repair =
"check_unique")
+
+ if (!inline) {
+ fns <- map(fns, as_function)
+ }
+
+ # ensure .names argument has resulted in the expected number of names
+ if (length(names) != (length(vars) * length(fns))) {
+ abort(
+ c(
+ "`.names` specification must produce (number of columns * number of
functions) names.",
+ x = paste0(
+ length(vars) * length(fns), " names required (", length(vars), "
columns * ", length(fns), " functions)\n ",
+ length(names), " name(s) produced: ", paste(names, collapse = ",")
+ )
+ )
+ )
+ }
+
+ list(vars = vars, fns = fns, names = names)
+}
+
+across_glue_mask <- function(.col, .fn, .caller_env) {
+ env(.caller_env, .col = .col, .fn = .fn, col = .col, fn = .fn)
+}
diff --git a/r/R/dplyr-mutate.R b/r/R/dplyr-mutate.R
index 653c1e6f25..ac555fafe0 100644
--- a/r/R/dplyr-mutate.R
+++ b/r/R/dplyr-mutate.R
@@ -24,7 +24,9 @@ mutate.arrow_dplyr_query <- function(.data,
.before = NULL,
.after = NULL) {
call <- match.call()
- exprs <- ensure_named_exprs(quos(...))
+
+ expression_list <- expand_across(.data, quos(...))
+ exprs <- ensure_named_exprs(expression_list)
.keep <- match.arg(.keep)
.before <- enquo(.before)
diff --git a/r/tests/testthat/helper-expectation.R
b/r/tests/testthat/helper-expectation.R
index eb2e6b0219..ba11700ab6 100644
--- a/r/tests/testthat/helper-expectation.R
+++ b/r/tests/testthat/helper-expectation.R
@@ -321,3 +321,7 @@ split_vector_as_list <- function(vec) {
vec2 <- vec[seq(from = min(length(vec), vec_split + 1), to = length(vec), by
= 1)]
list(vec1, vec2)
}
+
+expect_across_equal <- function(actual, expected, tbl) {
+ expect_identical(expand_across(tbl, actual), as.list(expected))
+}
diff --git a/r/tests/testthat/test-dplyr-across.R
b/r/tests/testthat/test-dplyr-across.R
new file mode 100644
index 0000000000..8945c2a5f3
--- /dev/null
+++ b/r/tests/testthat/test-dplyr-across.R
@@ -0,0 +1,226 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+library(dplyr, warn.conflicts = FALSE)
+
+test_that("expand_across correctly expands quosures", {
+
+ # single unnamed function
+ expect_across_equal(
+ quos(across(c(dbl, dbl2), round)),
+ quos(
+ dbl = round(dbl),
+ dbl2 = round(dbl2)
+ ),
+ example_data
+ )
+
+ # multiple unnamed functions
+ expect_across_equal(
+ quos(across(c(dbl, dbl2), list(exp, sqrt))),
+ quos(
+ dbl_1 = exp(dbl),
+ dbl_2 = sqrt(dbl),
+ dbl2_1 = exp(dbl2),
+ dbl2_2 = sqrt(dbl2)
+ ),
+ example_data
+ )
+
+ # single named function
+ expect_across_equal(
+ quos(across(c(dbl, dbl2), list("fun1" = round))),
+ quos(
+ dbl_fun1 = round(dbl),
+ dbl2_fun1 = round(dbl2)
+ ),
+ example_data
+ )
+
+ # multiple named functions
+ expect_across_equal(
+ quos(across(c(dbl, dbl2), list("fun1" = round, "fun2" = sqrt))),
+ quos(
+ dbl_fun1 = round(dbl),
+ dbl_fun2 = sqrt(dbl),
+ dbl2_fun1 = round(dbl2),
+ dbl2_fun2 = sqrt(dbl2)
+ ),
+ example_data
+ )
+
+ # mix of named and unnamed functions
+ expect_across_equal(
+ quos(across(c(dbl, dbl2), list(round, "fun2" = sqrt))),
+ quos(
+ dbl_1 = round(dbl),
+ dbl_fun2 = sqrt(dbl),
+ dbl2_1 = round(dbl2),
+ dbl2_fun2 = sqrt(dbl2)
+ ),
+ example_data
+ )
+
+ # across() with no functions returns columns unchanged
+ expect_across_equal(
+ quos(across(starts_with("dbl"))),
+ quos(
+ dbl = dbl,
+ dbl2 = dbl2
+ ),
+ example_data
+ )
+
+ # across() arguments not in default order
+ expect_across_equal(
+ quos(across(.fns = round, c(dbl, dbl2))),
+ quos(
+ dbl = round(dbl),
+ dbl2 = round(dbl2)
+ ),
+ example_data
+ )
+
+ # across() with no columns named
+ expect_across_equal(
+ quos(across(.fns = round)),
+ quos(
+ int = round(int),
+ dbl = round(dbl),
+ dbl2 = round(dbl2)
+ ),
+ example_data %>% select(int, dbl, dbl2)
+ )
+
+ # column selection via dynamic variable name
+ int <- c("dbl", "dbl2")
+ expect_across_equal(
+ quos(across(all_of(int), sqrt)),
+ quos(
+ dbl = sqrt(dbl),
+ dbl2 = sqrt(dbl2)
+ ),
+ example_data
+ )
+
+ # the ellipsis (...) is a deprecated argument
+ expect_error(
+ expand_across(
+ example_data,
+ quos(across(c(dbl, dbl2), round, digits = -1))
+ ),
+ regexp = "`...` argument to `across()` is deprecated in dplyr and not
supported in Arrow",
+ fixed = TRUE
+ )
+
+ # alternative ways of specifying .fns - as a list
+ expect_across_equal(
+ quos(across(1:dbl2, list(round))),
+ quos(
+ int_1 = round(int),
+ dbl_1 = round(dbl),
+ dbl2_1 = round(dbl2)
+ ),
+ example_data
+ )
+
+ # supply .fns as a one-item vector
+ expect_across_equal(
+ quos(across(1:dbl2, c(round))),
+ quos(
+ int_1 = round(int),
+ dbl_1 = round(dbl),
+ dbl2_1 = round(dbl2)
+ ),
+ example_data
+ )
+
+ # ARROW-17366: purrr-style lambda functions not yet supported
+ expect_error(
+ expand_across(
+ example_data,
+ quos(across(1:dbl2, ~ round(.x, digits = -1)))
+ ),
+ regexp = "purrr-style lambda functions as `.fns` argument to `across()`
not yet supported in Arrow",
+ fixed = TRUE
+ )
+
+ # .names argument
+ expect_across_equal(
+ quos(across(c(dbl, dbl2), round, .names = "{.col}.{.fn}")),
+ quos(
+ dbl.1 = round(dbl),
+ dbl2.1 = round(dbl2)
+ ),
+ example_data
+ )
+
+ # names argument with custom text
+ expect_across_equal(
+ quos(across(c(dbl, dbl2), round, .names = "round_{.col}")),
+ quos(
+ round_dbl = round(dbl),
+ round_dbl2 = round(dbl2)
+ ),
+ example_data
+ )
+
+ # names argument supplied but no functions
+ expect_across_equal(
+ quos(across(starts_with("dbl"), .names = "new_{.col}")),
+ quos(
+ new_dbl = dbl,
+ new_dbl2 = dbl2
+ ),
+ example_data
+ )
+
+ # .names argument and functions named
+ expect_across_equal(
+ quos(across(c(dbl, dbl2), list("my_round" = round, "my_exp" = exp), .names
= "{.col}.{.fn}")),
+ quos(
+ dbl.my_round = round(dbl),
+ dbl.my_exp = exp(dbl),
+ dbl2.my_round = round(dbl2),
+ dbl2.my_exp = exp(dbl2)
+ ),
+ example_data
+ )
+
+ # .names argument and mix of named and unnamed functions
+ expect_across_equal(
+ quos(across(c(dbl, dbl2), list(round, "my_exp" = exp), .names =
"{.col}.{.fn}")),
+ quos(
+ dbl.1 = round(dbl),
+ dbl.my_exp = exp(dbl),
+ dbl2.1 = round(dbl2),
+ dbl2.my_exp = exp(dbl2)
+ ),
+ example_data
+ )
+
+ # dodgy .names specification
+ expect_error(
+ expand_across(
+ example_data,
+ quos(across(c(dbl, dbl2), list(round, "my_exp" = exp), .names = "zarg"))
+ ),
+ regexp = "`.names` specification must produce (number of columns * number
of functions) names.",
+ fixed = TRUE
+ )
+
+})
diff --git a/r/tests/testthat/test-dplyr-mutate.R
b/r/tests/testthat/test-dplyr-mutate.R
index 66e3b4edf0..f1de5c7045 100644
--- a/r/tests/testthat/test-dplyr-mutate.R
+++ b/r/tests/testthat/test-dplyr-mutate.R
@@ -279,14 +279,13 @@ test_that("dplyr::mutate's examples", {
# Examples we don't support should succeed
# but warn that they're pulling data into R to do so
- # across and autosplicing: ARROW-11699
+ # test modified from version in dplyr::mutate due to ARROW-12632
compare_dplyr_binding(
.input %>%
- select(name, homeworld, species) %>%
- mutate(across(!name, as.factor)) %>%
+ select(name, height, mass) %>%
+ mutate(across(!name, as.character)) %>%
collect(),
starwars,
- warning = "Expression across.*not supported in Arrow"
)
# group_by then mutate
@@ -589,3 +588,57 @@ test_that("mutate() and transmute() with namespaced
functions", {
tbl
)
})
+
+test_that("Can use across() within mutate()", {
+
+ # expressions work in the right order
+ compare_dplyr_binding(
+ .input %>%
+ mutate(
+ dbl2 = dbl * 2,
+ across(c(dbl, dbl2), round),
+ int2 = int * 2,
+ dbl = dbl + 3
+ ) %>%
+ collect(),
+ example_data
+ )
+
+ # this is valid in neither R nor Arrow
+ expect_error(
+ expect_warning(
+ compare_dplyr_binding(
+ .input %>%
+ arrow_table() %>%
+ mutate(across(c(dbl, dbl2), list("fun1" = round(sqrt(dbl))))) %>%
+ collect(),
+ example_data,
+ warning = TRUE
+ )
+ )
+ )
+
+ # ARROW-12778 - `where()` is not yet supported
+ expect_error(
+ compare_dplyr_binding(
+ .input %>%
+ mutate(across(where(is.double))) %>%
+ collect(),
+ example_data
+ ),
+ "Unsupported selection helper"
+ )
+
+ # gives the right warning with window functions
+ expect_warning(
+ arrow_table(example_data) %>%
+ mutate(
+ x = int + 2,
+ across(c("int", "dbl"), list(mean = mean, sd = sd, round)),
+ exp(dbl2)
+ ) %>%
+ collect(),
+ "window functions not currently supported in Arrow; pulling data into R",
+ fixed = TRUE
+ )
+})