This is an automated email from the ASF dual-hosted git repository.

npr pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new 0368e410be GH-33892: [R] Map `dplyr::n()` to `count_all` kernel (#33917)
0368e410be is described below

commit 0368e410be4dac30eada13d307b415165aedc6a7
Author: Ian Cook <[email protected]>
AuthorDate: Mon Feb 13 10:16:03 2023 -0500

    GH-33892: [R] Map `dplyr::n()` to `count_all` kernel (#33917)
    
    ### Rationale for this change
    
    This PR is a follow-up to #15083. It allows the R package to register
    bindings to nullary aggregation functions, and it remaps `dplyr::n()` to the
    nullary aggregation function `count_all`.
    
    This PR also:
    - Prepares the R bindings to support aggregation functions with 2+
      arguments, although none yet exist in the C++ library
    - Removes the heuristics that were used to infer the data types of
      aggregates, replacing them with actual type determination
    
    ### Are these changes tested?
    
    Yes, through existing tests.
    
    ### Are there any user-facing changes?
    
    No.
    * Closes: #33892
    * Closes: #33960
    
    Authored-by: Ian Cook <[email protected]>
    Signed-off-by: Neal Richardson <[email protected]>
---
 r/R/dplyr-collect.R                     |  18 +++---
 r/R/dplyr-funcs.R                       |   2 +-
 r/R/dplyr-summarize.R                   | 102 +++++++++++++++++++++-----------
 r/R/query-engine.R                      |  12 ++--
 r/man/register_binding.Rd               |   2 +-
 r/src/compute-exec.cpp                  |   8 ++-
 r/tests/testthat/test-dplyr-collapse.R  |   4 +-
 r/tests/testthat/test-dplyr-summarize.R |  10 +++-
 8 files changed, 103 insertions(+), 55 deletions(-)

diff --git a/r/R/dplyr-collect.R b/r/R/dplyr-collect.R
index 395026ce78..f45a9886ea 100644
--- a/r/R/dplyr-collect.R
+++ b/r/R/dplyr-collect.R
@@ -179,19 +179,15 @@ implicit_schema <- function(.data) {
       new_fields <- c(left_fields, right_fields)
     }
   } else {
-    # The output schema is based on the aggregations and any group_by vars
-    new_fields <- map(summarize_projection(.data), ~ .$type(old_schm))
-    # * Put group_by_vars first (this can't be done by summarize,
-    #   they have to be last per the aggregate node signature,
-    #   and they get projected to this order after aggregation)
-    # * Infer the output types from the aggregations
-    group_fields <- new_fields[.data$group_by_vars]
     hash <- length(.data$group_by_vars) > 0
-    agg_fields <- imap(
-      new_fields[setdiff(names(new_fields), .data$group_by_vars)],
-      ~ agg_fun_output_type(.data$aggregations[[.y]][["fun"]], .x, hash)
+    # The output schema is based on the aggregations and any group_by vars.
+    # The group_by vars come first (this can't be done by summarize; they have
+    # to be last per the aggregate node signature, and they get projected to
+    # this order after aggregation)
+    new_fields <- c(
+      group_types(.data, old_schm),
+      aggregate_types(.data, hash, old_schm)
     )
-    new_fields <- c(group_fields, agg_fields)
   }
   schema(!!!new_fields)
 }
diff --git a/r/R/dplyr-funcs.R b/r/R/dplyr-funcs.R
index ce88e25bcb..2728a64539 100644
--- a/r/R/dplyr-funcs.R
+++ b/r/R/dplyr-funcs.R
@@ -49,7 +49,7 @@ NULL
 #'   aggregate function. This function must accept `Expression` objects as
 #'   arguments and return a `list()` with components:
 #'   - `fun`: string function name
-#'   - `data`: `Expression` (these are all currently a single field)
+#'   - `data`: list of 0 or more `Expression`s
 #'   - `options`: list of function options, as passed to call_function
 #' @param update_cache Update .cache$functions at the time of registration.
 #'   the default is FALSE because the majority of usage is to register
diff --git a/r/R/dplyr-summarize.R b/r/R/dplyr-summarize.R
index 5e670538f6..184c0aade4 100644
--- a/r/R/dplyr-summarize.R
+++ b/r/R/dplyr-summarize.R
@@ -18,7 +18,7 @@
 # Aggregation functions
 # These all return a list of:
 # @param fun string function name
-# @param data Expression (these are all currently a single field)
+# @param data list of 0 or more Expressions
 # @param options list of function options, as passed to call_function
 # For group-by aggregation, `hash_` gets prepended to the function name.
 # So to see a list of available hash aggregation functions,
@@ -31,28 +31,7 @@ ensure_one_arg <- function(args, fun) {
   } else if (length(args) > 1) {
     arrow_not_supported(paste0("Multiple arguments to ", fun, "()"))
   }
-  args[[1]]
-}
-
-agg_fun_output_type <- function(fun, input_type, hash) {
-  # These are quick and dirty heuristics.
-  if (fun %in% c("any", "all")) {
-    bool()
-  } else if (fun %in% "sum") {
-    # It may upcast to a bigger type but this is close enough
-    input_type
-  } else if (fun %in% c("mean", "stddev", "variance", "approximate_median")) {
-    float64()
-  } else if (fun %in% "tdigest") {
-    if (hash) {
-      fixed_size_list_of(float64(), 1L)
-    } else {
-      float64()
-    }
-  } else {
-    # Just so things don't error, assume the resulting type is the same
-    input_type
-  }
+  args
 }
 
 register_bindings_aggregate <- function() {
@@ -80,21 +59,21 @@ register_bindings_aggregate <- function() {
   register_binding_agg("base::mean", function(x, na.rm = FALSE) {
     list(
       fun = "mean",
-      data = x,
+      data = list(x),
       options = list(skip_nulls = na.rm, min_count = 0L)
     )
   })
   register_binding_agg("stats::sd", function(x, na.rm = FALSE, ddof = 1) {
     list(
       fun = "stddev",
-      data = x,
+      data = list(x),
       options = list(skip_nulls = na.rm, min_count = 0L, ddof = ddof)
     )
   })
   register_binding_agg("stats::var", function(x, na.rm = FALSE, ddof = 1) {
     list(
       fun = "variance",
-      data = x,
+      data = list(x),
       options = list(skip_nulls = na.rm, min_count = 0L, ddof = ddof)
     )
   })
@@ -114,7 +93,7 @@ register_bindings_aggregate <- function() {
       )
       list(
         fun = "tdigest",
-        data = x,
+        data = list(x),
         options = list(skip_nulls = na.rm, q = probs)
       )
     },
@@ -136,7 +115,7 @@ register_bindings_aggregate <- function() {
       )
       list(
         fun = "approximate_median",
-        data = x,
+        data = list(x),
         options = list(skip_nulls = na.rm)
       )
     },
@@ -151,8 +130,8 @@ register_bindings_aggregate <- function() {
   })
   register_binding_agg("dplyr::n", function() {
     list(
-      fun = "sum",
-      data = Expression$scalar(1L),
+      fun = "count_all",
+      data = list(),
       options = list()
     )
   })
@@ -322,15 +301,72 @@ arrow_eval_or_stop <- function(expr, mask) {
   out
 }
 
+# This function returns a list of expressions which is used to project the data
+# before an aggregation. This list includes the fields used in the aggregation
+# expressions (the "targets") and the group fields. The names of the returned
+# list are used to ensure that the projection node is wired up correctly to the
+# aggregation node.
 summarize_projection <- function(.data) {
   c(
-    map(.data$aggregations, ~ .$data),
+    unlist(unname(imap(
+      .data$aggregations,
+      ~set_names(
+        .x$data,
+        aggregate_target_names(.x$data, .y)
+      )
+    ))),
     .data$selected_columns[.data$group_by_vars]
   )
 }
 
+# This function determines what names to give to the fields used in an
+# aggregation expression (the "targets"). When an aggregate function takes 2 or
+# more fields as targets, this function gives the fields unique names by
+# appending `..1`, `..2`, etc. When an aggregate function is nullary, this
+# function returns a zero-length character vector.
+aggregate_target_names <- function(data, name) {
+  if (length(data) > 1) {
+    paste(name, seq_along(data), sep = "..")
+  } else if (length(data) > 0) {
+    name
+  } else {
+    character(0)
+  }
+}
+
+# This function returns a named list of the data types of the aggregate columns
+# returned by an aggregation
+aggregate_types <- function(.data, hash, schema = NULL) {
+  if (hash) dummy_groups <- Scalar$create(1L, uint32())
+  map(
+    .data$aggregations,
+    ~if (hash) {
+      Expression$create(
+        paste0("hash_", .$fun),
+        # hash aggregate kernels must be passed an additional argument
+        # representing the groups, so we pass in a dummy scalar, since the
+        # groups will not affect the type that an aggregation returns
+        args = c(.$data, dummy_groups),
+        options = .$options
+      )$type(schema)
+    } else {
+      Expression$create(
+        .$fun,
+        args = .$data,
+        options = .$options
+      )$type(schema)
+    }
+  )
+}
+
+# This function returns a named list of the data types of the group columns
+# returned by an aggregation
+group_types <- function(.data, schema = NULL) {
+  map(.data$selected_columns[.data$group_by_vars], ~.$type(schema))
+}
+
 format_aggregation <- function(x) {
-  paste0(x$fun, "(", x$data$ToString(), ")")
+  paste0(x$fun, "(", paste(map(x$data, ~.$ToString()), collapse = ","), ")")
 }
 
 # This function handles each summarize expression and turns it into the
@@ -414,7 +450,7 @@ summarize_eval <- function(name, quosure, ctx, hash) {
     # Something like: fun(agg(x), agg(y))
     # So based on the aggregations that have been extracted, mutate after
     agg_field_refs <- make_field_refs(names(ctx$aggregations))
-    agg_field_types <- lapply(ctx$aggregations, function(x) x$data$type())
+    agg_field_types <- aggregate_types(ctx$aggregations, hash)
 
     mutate_mask <- arrow_mask(
       list(
diff --git a/r/R/query-engine.R b/r/R/query-engine.R
index 7a336b7a07..ea5a3f1c57 100644
--- a/r/R/query-engine.R
+++ b/r/R/query-engine.R
@@ -101,7 +101,11 @@ ExecPlan <- R6Class("ExecPlan",
         # plus group_by_vars (last)
         # TODO: validate that none of names(aggregations) are the same as names(group_by_vars)
         # dplyr does not error on this but the result it gives isn't great
-        node <- node$Project(summarize_projection(.data))
+        projection <- summarize_projection(.data)
+        # skip projection if no grouping and all aggregate functions are nullary
+        if (length(projection)) {
+          node <- node$Project(projection)
+        }
 
         if (grouped) {
           # We need to prefix all of the aggregation function names with "hash_"
@@ -112,9 +116,9 @@ ExecPlan <- R6Class("ExecPlan",
         }
 
         .data$aggregations <- imap(.data$aggregations, function(x, name) {
-          # Embed the name inside the aggregation objects. `target` and `name`
-          # are the same because we just Project()ed the data that way above
-          x[["name"]] <- x[["target"]] <- name
+          # Embed `name` and `targets` inside the aggregation objects
+          x[["name"]] <- name
+          x[["targets"]] <- aggregate_target_names(x$data, name)
           x
         })
 
diff --git a/r/man/register_binding.Rd b/r/man/register_binding.Rd
index c526ee138c..fd857f5c67 100644
--- a/r/man/register_binding.Rd
+++ b/r/man/register_binding.Rd
@@ -40,7 +40,7 @@ aggregate function. This function must accept \code{Expression} objects as
 arguments and return a \code{list()} with components:
 \itemize{
 \item \code{fun}: string function name
-\item \code{data}: \code{Expression} (these are all currently a single field)
+\item \code{data}: list of 0 or more \code{Expression}s
 \item \code{options}: list of function options, as passed to call_function
 }}
 }
diff --git a/r/src/compute-exec.cpp b/r/src/compute-exec.cpp
index e616e3d824..5dafc21273 100644
--- a/r/src/compute-exec.cpp
+++ b/r/src/compute-exec.cpp
@@ -393,11 +393,15 @@ std::shared_ptr<compute::ExecNode> ExecNode_Aggregate(
   for (cpp11::list name_opts : options) {
     auto function = cpp11::as_cpp<std::string>(name_opts["fun"]);
     auto opts = make_compute_options(function, name_opts["options"]);
-    auto target = cpp11::as_cpp<std::string>(name_opts["target"]);
+    auto target_names = cpp11::as_cpp<std::vector<std::string>>(name_opts["targets"]);
     auto name = cpp11::as_cpp<std::string>(name_opts["name"]);
 
+    std::vector<arrow::FieldRef> targets;
+    for (auto&& target : target_names) {
+      targets.emplace_back(std::move(target));
+    }
     aggregates.push_back(arrow::compute::Aggregate{std::move(function), opts,
-                                                   std::move(target), std::move(name)});
+                                                   std::move(targets), std::move(name)});
   }
 
   std::vector<arrow::FieldRef> keys;
diff --git a/r/tests/testthat/test-dplyr-collapse.R b/r/tests/testthat/test-dplyr-collapse.R
index 6c5f4c1991..cca8412178 100644
--- a/r/tests/testthat/test-dplyr-collapse.R
+++ b/r/tests/testthat/test-dplyr-collapse.R
@@ -162,8 +162,8 @@ test_that("Properties of collapsed query", {
     print(q),
     "Table (query)
 lgl: bool
-total: int32
-extra: int32 (multiply_checked(total, 5))
+total: int64
+extra: int64 (multiply_checked(total, 5))
 
 See $.data for the source Arrow object",
     fixed = TRUE
diff --git a/r/tests/testthat/test-dplyr-summarize.R b/r/tests/testthat/test-dplyr-summarize.R
index e54e57c836..6ee8982cc2 100644
--- a/r/tests/testthat/test-dplyr-summarize.R
+++ b/r/tests/testthat/test-dplyr-summarize.R
@@ -1119,7 +1119,7 @@ test_that("We don't add unnecessary ProjectNodes when aggregating", {
     1
   )
 
-  # 0 Projections only if
+  # 0 Projections if
   # (a) input only contains the col you're aggregating, and
   # (b) the output col name is the same as the input name, and
   # (c) no grouping
@@ -1128,6 +1128,14 @@ test_that("We don't add unnecessary ProjectNodes when aggregating", {
     0
   )
 
+  # 0 Projections if
+  # (a) only nullary functions in summarize()
+  # (b) no grouping
+  expect_project_nodes(
+    tab[, "int"] %>% summarize(n()),
+    0
+  )
+
   # 2 projections: one before, and one after in order to put grouping cols first
   expect_project_nodes(
     tab %>% group_by(lgl) %>% summarize(mean(int)),

Reply via email to