This is an automated email from the ASF dual-hosted git repository.
jonkeane pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/main by this push:
new 1b3caf6b23 GH-29642: [R] Support for .keep_all = TRUE with distinct()
(#44652)
1b3caf6b23 is described below
commit 1b3caf6b232b7855956d3ec45ee95ede0492e78f
Author: Neal Richardson <[email protected]>
AuthorDate: Sat Dec 7 10:04:08 2024 -0500
GH-29642: [R] Support for .keep_all = TRUE with distinct() (#44652)
### Rationale for this change
Add a missing feature: `distinct(.keep_all = TRUE)`. This mostly wires up
existing Acero functionality from R, then adds docs and tests.
This largely picks up where #13934 left off and finishes it.
Thanks to @mopcup for the initial work.
### What changes are included in this PR?
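A new `one()` aggregation binding (backed by Acero's `hash_one`), some
symbol manipulation in `distinct()` to summarize the non-grouping columns,
and tests. I also removed some dplyr test shims from 2022 (workarounds for
the column-order change in dplyr 1.1.0, GH-14947).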
### Are these changes tested?
Yes. If anyone knows of edge cases in `distinct()` that these tests don't
cover, we can add more.
### Are there any user-facing changes?
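Yes. `distinct()` on Arrow data now supports `.keep_all = TRUE` instead of
raising a "not supported" error. dplyr keeps the first row's values. This
implementation instead returns a non-missing value for each remaining column
when one is present, and returns a missing value only if all values are
missing. The result has the same shape as dplyr's, but individual values
may differ.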
* GitHub Issue: #29642
---
r/R/arrow-package.R | 5 +-
r/R/dplyr-distinct.R | 25 +++++++---
r/R/dplyr-funcs-agg.R | 7 +++
r/tests/testthat/test-dplyr-distinct.R | 89 +++++++++++++++++++++++-----------
4 files changed, 90 insertions(+), 36 deletions(-)
diff --git a/r/R/arrow-package.R b/r/R/arrow-package.R
index 4c3b78e085..4b54697d4b 100644
--- a/r/R/arrow-package.R
+++ b/r/R/arrow-package.R
@@ -62,7 +62,10 @@ supported_dplyr_methods <- list(
relocate = NULL,
compute = NULL,
collapse = NULL,
- distinct = "`.keep_all = TRUE` not supported",
+ distinct = c(
+ "`.keep_all = TRUE` returns a non-missing value if present,",
+ "only returning missing values if all are missing."
+ ),
left_join = "the `copy` argument is ignored",
right_join = "the `copy` argument is ignored",
inner_join = "the `copy` argument is ignored",
diff --git a/r/R/dplyr-distinct.R b/r/R/dplyr-distinct.R
index 49948caa01..95fb837bd5 100644
--- a/r/R/dplyr-distinct.R
+++ b/r/R/dplyr-distinct.R
@@ -18,12 +18,6 @@
# The following S3 methods are registered on load if dplyr is present
distinct.arrow_dplyr_query <- function(.data, ..., .keep_all = FALSE) {
- if (.keep_all == TRUE) {
- # TODO(ARROW-14045): the function is called "hash_one" (from ARROW-13993)
- # May need to call it: `summarize(x = one(x), ...)` for x in non-group cols
- arrow_not_supported("`distinct()` with `.keep_all = TRUE`")
- }
-
original_gv <- dplyr::group_vars(.data)
if (length(quos(...))) {
# group_by() calls mutate() if there are any expressions in ...
@@ -33,11 +27,28 @@ distinct.arrow_dplyr_query <- function(.data, ...,
.keep_all = FALSE) {
.data <- dplyr::group_by(.data, !!!syms(names(.data)))
}
- out <- dplyr::summarize(.data, .groups = "drop")
+ if (isTRUE(.keep_all)) {
+ # Note: in regular dplyr, `.keep_all = TRUE` returns the first row's value.
+ # However, Acero's `hash_one` function prefers returning non-null values.
+ # So, you'll get the same shape of data, but the values may differ.
+ keeps <- names(.data)[!(names(.data) %in% .data$group_by_vars)]
+ exprs <- lapply(keeps, function(x) call2("one", sym(x)))
+ names(exprs) <- keeps
+ } else {
+ exprs <- list()
+ }
+
+ out <- dplyr::summarize(.data, !!!exprs, .groups = "drop")
+
# distinct() doesn't modify group by vars, so restore the original ones
if (length(original_gv)) {
out$group_by_vars <- original_gv
}
+ if (isTRUE(.keep_all)) {
+ # Also ensure the column order matches the original
+ # summarize() will put the group_by_vars first
+ out <- dplyr::select(out, !!!syms(names(.data)))
+ }
out
}
diff --git a/r/R/dplyr-funcs-agg.R b/r/R/dplyr-funcs-agg.R
index 340ebe7adc..275fca3654 100644
--- a/r/R/dplyr-funcs-agg.R
+++ b/r/R/dplyr-funcs-agg.R
@@ -150,6 +150,13 @@ register_bindings_aggregate <- function() {
options = list(skip_nulls = na.rm, min_count = 0L)
)
})
+ register_binding("arrow::one", function(...) {
+ set_agg(
+ fun = "one",
+ data = ensure_one_arg(list2(...), "one"),
+ options = list()
+ )
+ })
}
set_agg <- function(...) {
diff --git a/r/tests/testthat/test-dplyr-distinct.R
b/r/tests/testthat/test-dplyr-distinct.R
index 4c7f8894cd..e4d789e8e9 100644
--- a/r/tests/testthat/test-dplyr-distinct.R
+++ b/r/tests/testthat/test-dplyr-distinct.R
@@ -26,11 +26,8 @@ test_that("distinct()", {
compare_dplyr_binding(
.input %>%
distinct(some_grouping, lgl) %>%
- collect() %>%
- # GH-14947: column output order changed in dplyr 1.1.0, so we need
- # to make the column order explicit until dplyr 1.1.0 is on CRAN
- select(some_grouping, lgl) %>%
- arrange(some_grouping, lgl),
+ arrange(some_grouping, lgl) %>%
+ collect(),
tbl
)
})
@@ -60,11 +57,8 @@ test_that("distinct() can retain groups", {
.input %>%
group_by(some_grouping, int) %>%
distinct(lgl) %>%
- collect() %>%
- # GH-14947: column output order changed in dplyr 1.1.0, so we need
- # to make the column order explicit until dplyr 1.1.0 is on CRAN
- select(some_grouping, int, lgl) %>%
- arrange(lgl, int),
+ arrange(lgl, int) %>%
+ collect(),
tbl
)
@@ -73,11 +67,8 @@ test_that("distinct() can retain groups", {
.input %>%
group_by(y = some_grouping, int) %>%
distinct(x = lgl) %>%
- collect() %>%
- # GH-14947: column output order changed in dplyr 1.1.0, so we need
- # to make the column order explicit until dplyr 1.1.0 is on CRAN
- select(y, int, x) %>%
- arrange(int),
+ arrange(int) %>%
+ collect(),
tbl
)
})
@@ -95,11 +86,8 @@ test_that("distinct() can contain expressions", {
.input %>%
group_by(lgl, int) %>%
distinct(x = some_grouping + 1) %>%
- collect() %>%
- # GH-14947: column output order changed in dplyr 1.1.0, so we need
- # to make the column order explicit until dplyr 1.1.0 is on CRAN
- select(lgl, int, x) %>%
- arrange(int),
+ arrange(int) %>%
+ collect(),
tbl
)
})
@@ -115,12 +103,57 @@ test_that("across() works in distinct()", {
})
test_that("distinct() can return all columns", {
- skip("ARROW-14045")
- compare_dplyr_binding(
- .input %>%
- distinct(lgl, .keep_all = TRUE) %>%
- collect() %>%
- arrange(int),
- tbl
- )
+ # hash_one prefers to keep non-null values, which is different from
.keep_all in dplyr
+ # so we can't compare the result directly
+ expected <- tbl %>%
+ # Drop factor because of #44661:
+ # NotImplemented: Function 'hash_one' has no kernel matching input types
+ # (dictionary<values=string, indices=int8, ordered=0>, uint8)
+ select(-fct) %>%
+ distinct(lgl, .keep_all = TRUE) %>%
+ arrange(int)
+
+ with_table <- tbl %>%
+ arrow_table() %>%
+ select(-fct) %>%
+ distinct(lgl, .keep_all = TRUE) %>%
+ arrange(int) %>%
+ collect()
+
+ expect_identical(dim(with_table), dim(expected))
+ expect_identical(names(with_table), names(expected))
+
+ # Test with some mutation in there
+ expected <- tbl %>%
+ select(-fct) %>%
+ distinct(lgl, bigger = int * 10L, .keep_all = TRUE) %>%
+ arrange(int)
+
+ with_table <- tbl %>%
+ arrow_table() %>%
+ select(-fct) %>%
+ distinct(lgl, bigger = int * 10, .keep_all = TRUE) %>%
+ arrange(int) %>%
+ collect()
+
+ expect_identical(dim(with_table), dim(expected))
+ expect_identical(names(with_table), names(expected))
+ expect_identical(with_table$bigger, expected$bigger)
+
+ # Mutation that overwrites
+ expected <- tbl %>%
+ select(-fct) %>%
+ distinct(lgl, int = int * 10L, .keep_all = TRUE) %>%
+ arrange(int)
+
+ with_table <- tbl %>%
+ arrow_table() %>%
+ select(-fct) %>%
+ distinct(lgl, int = int * 10, .keep_all = TRUE) %>%
+ arrange(int) %>%
+ collect()
+
+ expect_identical(dim(with_table), dim(expected))
+ expect_identical(names(with_table), names(expected))
+ expect_identical(with_table$int, expected$int)
})