This is an automated email from the ASF dual-hosted git repository.

jonkeane pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/main by this push:
     new 1b3caf6b23 GH-29642: [R] Support for .keep_all = TRUE with distinct() (#44652)
1b3caf6b23 is described below

commit 1b3caf6b232b7855956d3ec45ee95ede0492e78f
Author: Neal Richardson <[email protected]>
AuthorDate: Sat Dec 7 10:04:08 2024 -0500

    GH-29642: [R] Support for .keep_all = TRUE with distinct() (#44652)
    
    ### Rationale for this change
    
    Support a missing feature, just wiring up some stuff from R to Acero,
    then adding docs and tests.
    
    This is mostly picking up where #13934 started and finishing it out.
    Thanks @mopcup for the initial lift.
    
    ### What changes are included in this PR?
    
    An aggregation binding, some symbol manipulation, and tests. I also
    cleaned up some dplyr test shims from 2022.
    
    ### Are these changes tested?
    
    Yes, though if anyone knows of odd corners in `distinct()` that aren't
    covered by this, we can add more
    
    ### Are there any user-facing changes?
    
    Yes indeed.
    * GitHub Issue: #29642
---
 r/R/arrow-package.R                    |  5 +-
 r/R/dplyr-distinct.R                   | 25 +++++++---
 r/R/dplyr-funcs-agg.R                  |  7 +++
 r/tests/testthat/test-dplyr-distinct.R | 89 +++++++++++++++++++++++-----------
 4 files changed, 90 insertions(+), 36 deletions(-)

diff --git a/r/R/arrow-package.R b/r/R/arrow-package.R
index 4c3b78e085..4b54697d4b 100644
--- a/r/R/arrow-package.R
+++ b/r/R/arrow-package.R
@@ -62,7 +62,10 @@ supported_dplyr_methods <- list(
   relocate = NULL,
   compute = NULL,
   collapse = NULL,
-  distinct = "`.keep_all = TRUE` not supported",
+  distinct = c(
+    "`.keep_all = TRUE` returns a non-missing value if present,",
+    "only returning missing values if all are missing."
+  ),
   left_join = "the `copy` argument is ignored",
   right_join = "the `copy` argument is ignored",
   inner_join = "the `copy` argument is ignored",
diff --git a/r/R/dplyr-distinct.R b/r/R/dplyr-distinct.R
index 49948caa01..95fb837bd5 100644
--- a/r/R/dplyr-distinct.R
+++ b/r/R/dplyr-distinct.R
@@ -18,12 +18,6 @@
 # The following S3 methods are registered on load if dplyr is present
 
 distinct.arrow_dplyr_query <- function(.data, ..., .keep_all = FALSE) {
-  if (.keep_all == TRUE) {
-    # TODO(ARROW-14045): the function is called "hash_one" (from ARROW-13993)
-    # May need to call it: `summarize(x = one(x), ...)` for x in non-group cols
-    arrow_not_supported("`distinct()` with `.keep_all = TRUE`")
-  }
-
   original_gv <- dplyr::group_vars(.data)
   if (length(quos(...))) {
     # group_by() calls mutate() if there are any expressions in ...
@@ -33,11 +27,28 @@ distinct.arrow_dplyr_query <- function(.data, ..., .keep_all = FALSE) {
     .data <- dplyr::group_by(.data, !!!syms(names(.data)))
   }
 
-  out <- dplyr::summarize(.data, .groups = "drop")
+  if (isTRUE(.keep_all)) {
+    # Note: in regular dplyr, `.keep_all = TRUE` returns the first row's value.
+    # However, Acero's `hash_one` function prefers returning non-null values.
+    # So, you'll get the same shape of data, but the values may differ.
+    keeps <- names(.data)[!(names(.data) %in% .data$group_by_vars)]
+    exprs <- lapply(keeps, function(x) call2("one", sym(x)))
+    names(exprs) <- keeps
+  } else {
+    exprs <- list()
+  }
+
+  out <- dplyr::summarize(.data, !!!exprs, .groups = "drop")
+
   # distinct() doesn't modify group by vars, so restore the original ones
   if (length(original_gv)) {
     out$group_by_vars <- original_gv
   }
+  if (isTRUE(.keep_all)) {
+    # Also ensure the column order matches the original
+    # summarize() will put the group_by_vars first
+    out <- dplyr::select(out, !!!syms(names(.data)))
+  }
   out
 }
 
diff --git a/r/R/dplyr-funcs-agg.R b/r/R/dplyr-funcs-agg.R
index 340ebe7adc..275fca3654 100644
--- a/r/R/dplyr-funcs-agg.R
+++ b/r/R/dplyr-funcs-agg.R
@@ -150,6 +150,13 @@ register_bindings_aggregate <- function() {
       options = list(skip_nulls = na.rm, min_count = 0L)
     )
   })
+  register_binding("arrow::one", function(...) {
+    set_agg(
+      fun = "one",
+      data = ensure_one_arg(list2(...), "one"),
+      options = list()
+    )
+  })
 }
 
 set_agg <- function(...) {
diff --git a/r/tests/testthat/test-dplyr-distinct.R b/r/tests/testthat/test-dplyr-distinct.R
index 4c7f8894cd..e4d789e8e9 100644
--- a/r/tests/testthat/test-dplyr-distinct.R
+++ b/r/tests/testthat/test-dplyr-distinct.R
@@ -26,11 +26,8 @@ test_that("distinct()", {
   compare_dplyr_binding(
     .input %>%
       distinct(some_grouping, lgl) %>%
-      collect() %>%
-      # GH-14947: column output order changed in dplyr 1.1.0, so we need
-      # to make the column order explicit until dplyr 1.1.0 is on CRAN
-      select(some_grouping, lgl) %>%
-      arrange(some_grouping, lgl),
+      arrange(some_grouping, lgl) %>%
+      collect(),
     tbl
   )
 })
@@ -60,11 +57,8 @@ test_that("distinct() can retain groups", {
     .input %>%
       group_by(some_grouping, int) %>%
       distinct(lgl) %>%
-      collect() %>%
-      # GH-14947: column output order changed in dplyr 1.1.0, so we need
-      # to make the column order explicit until dplyr 1.1.0 is on CRAN
-      select(some_grouping, int, lgl) %>%
-      arrange(lgl, int),
+      arrange(lgl, int) %>%
+      collect(),
     tbl
   )
 
@@ -73,11 +67,8 @@ test_that("distinct() can retain groups", {
     .input %>%
       group_by(y = some_grouping, int) %>%
       distinct(x = lgl) %>%
-      collect() %>%
-      # GH-14947: column output order changed in dplyr 1.1.0, so we need
-      # to make the column order explicit until dplyr 1.1.0 is on CRAN
-      select(y, int, x) %>%
-      arrange(int),
+      arrange(int) %>%
+      collect(),
     tbl
   )
 })
@@ -95,11 +86,8 @@ test_that("distinct() can contain expressions", {
     .input %>%
       group_by(lgl, int) %>%
       distinct(x = some_grouping + 1) %>%
-      collect() %>%
-      # GH-14947: column output order changed in dplyr 1.1.0, so we need
-      # to make the column order explicit until dplyr 1.1.0 is on CRAN
-      select(lgl, int, x) %>%
-      arrange(int),
+      arrange(int) %>%
+      collect(),
     tbl
   )
 })
@@ -115,12 +103,57 @@ test_that("across() works in distinct()", {
 })
 
 test_that("distinct() can return all columns", {
-  skip("ARROW-14045")
-  compare_dplyr_binding(
-    .input %>%
-      distinct(lgl, .keep_all = TRUE) %>%
-      collect() %>%
-      arrange(int),
-    tbl
-  )
+  # hash_one prefers to keep non-null values, which is different from .keep_all in dplyr
+  # so we can't compare the result directly
+  expected <- tbl %>%
+    # Drop factor because of #44661:
+    # NotImplemented: Function 'hash_one' has no kernel matching input types
+    #   (dictionary<values=string, indices=int8, ordered=0>, uint8)
+    select(-fct) %>%
+    distinct(lgl, .keep_all = TRUE) %>%
+    arrange(int)
+
+  with_table <- tbl %>%
+    arrow_table() %>%
+    select(-fct) %>%
+    distinct(lgl, .keep_all = TRUE) %>%
+    arrange(int) %>%
+    collect()
+
+  expect_identical(dim(with_table), dim(expected))
+  expect_identical(names(with_table), names(expected))
+
+  # Test with some mutation in there
+  expected <- tbl %>%
+    select(-fct) %>%
+    distinct(lgl, bigger = int * 10L, .keep_all = TRUE) %>%
+    arrange(int)
+
+  with_table <- tbl %>%
+    arrow_table() %>%
+    select(-fct) %>%
+    distinct(lgl, bigger = int * 10, .keep_all = TRUE) %>%
+    arrange(int) %>%
+    collect()
+
+  expect_identical(dim(with_table), dim(expected))
+  expect_identical(names(with_table), names(expected))
+  expect_identical(with_table$bigger, expected$bigger)
+
+  # Mutation that overwrites
+  expected <- tbl %>%
+    select(-fct) %>%
+    distinct(lgl, int = int * 10L, .keep_all = TRUE) %>%
+    arrange(int)
+
+  with_table <- tbl %>%
+    arrow_table() %>%
+    select(-fct) %>%
+    distinct(lgl, int = int * 10, .keep_all = TRUE) %>%
+    arrange(int) %>%
+    collect()
+
+  expect_identical(dim(with_table), dim(expected))
+  expect_identical(names(with_table), names(expected))
+  expect_identical(with_table$int, expected$int)
 })

Reply via email to