This is an automated email from the ASF dual-hosted git repository.
npr pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new 0368e410be GH-33892: [R] Map `dplyr::n()` to `count_all` kernel (#33917)
0368e410be is described below
commit 0368e410be4dac30eada13d307b415165aedc6a7
Author: Ian Cook <[email protected]>
AuthorDate: Mon Feb 13 10:16:03 2023 -0500
GH-33892: [R] Map `dplyr::n()` to `count_all` kernel (#33917)
### Rationale for this change
This PR is a follow-up to #15083. It allows the R package to register
bindings to nullary aggregation functions, and it remaps `dplyr::n()` to the
nullary aggregation function `count_all`.
This PR also:
- Prepares the R bindings to support aggregation functions with 2+
arguments, although none yet exist in the C++ library
- Removes the heuristics that were used to infer the data types of
aggregates, replacing them with actual type determination
### Are these changes tested?
Yes, through existing tests.
### Are there any user-facing changes?
No.
* Closes: #33892
* Closes: #33960
Authored-by: Ian Cook <[email protected]>
Signed-off-by: Neal Richardson <[email protected]>
---
r/R/dplyr-collect.R | 18 +++---
r/R/dplyr-funcs.R | 2 +-
r/R/dplyr-summarize.R | 102 +++++++++++++++++++++-----------
r/R/query-engine.R | 12 ++--
r/man/register_binding.Rd | 2 +-
r/src/compute-exec.cpp | 8 ++-
r/tests/testthat/test-dplyr-collapse.R | 4 +-
r/tests/testthat/test-dplyr-summarize.R | 10 +++-
8 files changed, 103 insertions(+), 55 deletions(-)
diff --git a/r/R/dplyr-collect.R b/r/R/dplyr-collect.R
index 395026ce78..f45a9886ea 100644
--- a/r/R/dplyr-collect.R
+++ b/r/R/dplyr-collect.R
@@ -179,19 +179,15 @@ implicit_schema <- function(.data) {
new_fields <- c(left_fields, right_fields)
}
} else {
- # The output schema is based on the aggregations and any group_by vars
- new_fields <- map(summarize_projection(.data), ~ .$type(old_schm))
- # * Put group_by_vars first (this can't be done by summarize,
- # they have to be last per the aggregate node signature,
- # and they get projected to this order after aggregation)
- # * Infer the output types from the aggregations
- group_fields <- new_fields[.data$group_by_vars]
hash <- length(.data$group_by_vars) > 0
- agg_fields <- imap(
- new_fields[setdiff(names(new_fields), .data$group_by_vars)],
- ~ agg_fun_output_type(.data$aggregations[[.y]][["fun"]], .x, hash)
+ # The output schema is based on the aggregations and any group_by vars.
+ # The group_by vars come first (this can't be done by summarize; they have
+ # to be last per the aggregate node signature, and they get projected to
+ # this order after aggregation)
+ new_fields <- c(
+ group_types(.data, old_schm),
+ aggregate_types(.data, hash, old_schm)
)
- new_fields <- c(group_fields, agg_fields)
}
schema(!!!new_fields)
}
diff --git a/r/R/dplyr-funcs.R b/r/R/dplyr-funcs.R
index ce88e25bcb..2728a64539 100644
--- a/r/R/dplyr-funcs.R
+++ b/r/R/dplyr-funcs.R
@@ -49,7 +49,7 @@ NULL
#' aggregate function. This function must accept `Expression` objects as
#' arguments and return a `list()` with components:
#' - `fun`: string function name
-#' - `data`: `Expression` (these are all currently a single field)
+#' - `data`: list of 0 or more `Expression`s
#' - `options`: list of function options, as passed to call_function
#' @param update_cache Update .cache$functions at the time of registration.
#' the default is FALSE because the majority of usage is to register
diff --git a/r/R/dplyr-summarize.R b/r/R/dplyr-summarize.R
index 5e670538f6..184c0aade4 100644
--- a/r/R/dplyr-summarize.R
+++ b/r/R/dplyr-summarize.R
@@ -18,7 +18,7 @@
# Aggregation functions
# These all return a list of:
# @param fun string function name
-# @param data Expression (these are all currently a single field)
+# @param data list of 0 or more Expressions
# @param options list of function options, as passed to call_function
# For group-by aggregation, `hash_` gets prepended to the function name.
# So to see a list of available hash aggregation functions,
@@ -31,28 +31,7 @@ ensure_one_arg <- function(args, fun) {
} else if (length(args) > 1) {
arrow_not_supported(paste0("Multiple arguments to ", fun, "()"))
}
- args[[1]]
-}
-
-agg_fun_output_type <- function(fun, input_type, hash) {
- # These are quick and dirty heuristics.
- if (fun %in% c("any", "all")) {
- bool()
- } else if (fun %in% "sum") {
- # It may upcast to a bigger type but this is close enough
- input_type
- } else if (fun %in% c("mean", "stddev", "variance", "approximate_median")) {
- float64()
- } else if (fun %in% "tdigest") {
- if (hash) {
- fixed_size_list_of(float64(), 1L)
- } else {
- float64()
- }
- } else {
- # Just so things don't error, assume the resulting type is the same
- input_type
- }
+ args
}
register_bindings_aggregate <- function() {
@@ -80,21 +59,21 @@ register_bindings_aggregate <- function() {
register_binding_agg("base::mean", function(x, na.rm = FALSE) {
list(
fun = "mean",
- data = x,
+ data = list(x),
options = list(skip_nulls = na.rm, min_count = 0L)
)
})
register_binding_agg("stats::sd", function(x, na.rm = FALSE, ddof = 1) {
list(
fun = "stddev",
- data = x,
+ data = list(x),
options = list(skip_nulls = na.rm, min_count = 0L, ddof = ddof)
)
})
register_binding_agg("stats::var", function(x, na.rm = FALSE, ddof = 1) {
list(
fun = "variance",
- data = x,
+ data = list(x),
options = list(skip_nulls = na.rm, min_count = 0L, ddof = ddof)
)
})
@@ -114,7 +93,7 @@ register_bindings_aggregate <- function() {
)
list(
fun = "tdigest",
- data = x,
+ data = list(x),
options = list(skip_nulls = na.rm, q = probs)
)
},
@@ -136,7 +115,7 @@ register_bindings_aggregate <- function() {
)
list(
fun = "approximate_median",
- data = x,
+ data = list(x),
options = list(skip_nulls = na.rm)
)
},
@@ -151,8 +130,8 @@ register_bindings_aggregate <- function() {
})
register_binding_agg("dplyr::n", function() {
list(
- fun = "sum",
- data = Expression$scalar(1L),
+ fun = "count_all",
+ data = list(),
options = list()
)
})
@@ -322,15 +301,72 @@ arrow_eval_or_stop <- function(expr, mask) {
out
}
+# This function returns a list of expressions which is used to project the data
+# before an aggregation. This list includes the fields used in the aggregation
+# expressions (the "targets") and the group fields. The names of the returned
+# list are used to ensure that the projection node is wired up correctly to the
+# aggregation node.
summarize_projection <- function(.data) {
c(
- map(.data$aggregations, ~ .$data),
+ unlist(unname(imap(
+ .data$aggregations,
+ ~set_names(
+ .x$data,
+ aggregate_target_names(.x$data, .y)
+ )
+ ))),
.data$selected_columns[.data$group_by_vars]
)
}
+# This function determines what names to give to the fields used in an
+# aggregation expression (the "targets"). When an aggregate function takes 2 or
+# more fields as targets, this function gives the fields unique names by
+# appending `..1`, `..2`, etc. When an aggregate function is nullary, this
+# function returns a zero-length character vector.
+aggregate_target_names <- function(data, name) {
+ if (length(data) > 1) {
+ paste(name, seq_along(data), sep = "..")
+ } else if (length(data) > 0) {
+ name
+ } else {
+ character(0)
+ }
+}
+
+# This function returns a named list of the data types of the aggregate columns
+# returned by an aggregation
+aggregate_types <- function(.data, hash, schema = NULL) {
+ if (hash) dummy_groups <- Scalar$create(1L, uint32())
+ map(
+ .data$aggregations,
+ ~if (hash) {
+ Expression$create(
+ paste0("hash_", .$fun),
+ # hash aggregate kernels must be passed an additional argument
+ # representing the groups, so we pass in a dummy scalar, since the
+ # groups will not affect the type that an aggregation returns
+ args = c(.$data, dummy_groups),
+ options = .$options
+ )$type(schema)
+ } else {
+ Expression$create(
+ .$fun,
+ args = .$data,
+ options = .$options
+ )$type(schema)
+ }
+ )
+}
+
+# This function returns a named list of the data types of the group columns
+# returned by an aggregation
+group_types <- function(.data, schema = NULL) {
+ map(.data$selected_columns[.data$group_by_vars], ~.$type(schema))
+}
+
format_aggregation <- function(x) {
- paste0(x$fun, "(", x$data$ToString(), ")")
+ paste0(x$fun, "(", paste(map(x$data, ~.$ToString()), collapse = ","), ")")
}
# This function handles each summarize expression and turns it into the
@@ -414,7 +450,7 @@ summarize_eval <- function(name, quosure, ctx, hash) {
# Something like: fun(agg(x), agg(y))
# So based on the aggregations that have been extracted, mutate after
agg_field_refs <- make_field_refs(names(ctx$aggregations))
- agg_field_types <- lapply(ctx$aggregations, function(x) x$data$type())
+ agg_field_types <- aggregate_types(ctx$aggregations, hash)
mutate_mask <- arrow_mask(
list(
diff --git a/r/R/query-engine.R b/r/R/query-engine.R
index 7a336b7a07..ea5a3f1c57 100644
--- a/r/R/query-engine.R
+++ b/r/R/query-engine.R
@@ -101,7 +101,11 @@ ExecPlan <- R6Class("ExecPlan",
# plus group_by_vars (last)
# TODO: validate that none of names(aggregations) are the same as names(group_by_vars)
# dplyr does not error on this but the result it gives isn't great
- node <- node$Project(summarize_projection(.data))
+ projection <- summarize_projection(.data)
+ # skip projection if no grouping and all aggregate functions are nullary
+ if (length(projection)) {
+ node <- node$Project(projection)
+ }
if (grouped) {
# We need to prefix all of the aggregation function names with "hash_"
@@ -112,9 +116,9 @@ ExecPlan <- R6Class("ExecPlan",
}
.data$aggregations <- imap(.data$aggregations, function(x, name) {
- # Embed the name inside the aggregation objects. `target` and `name`
- # are the same because we just Project()ed the data that way above
- x[["name"]] <- x[["target"]] <- name
+ # Embed `name` and `targets` inside the aggregation objects
+ x[["name"]] <- name
+ x[["targets"]] <- aggregate_target_names(x$data, name)
x
})
diff --git a/r/man/register_binding.Rd b/r/man/register_binding.Rd
index c526ee138c..fd857f5c67 100644
--- a/r/man/register_binding.Rd
+++ b/r/man/register_binding.Rd
@@ -40,7 +40,7 @@ aggregate function. This function must accept \code{Expression} objects as
arguments and return a \code{list()} with components:
\itemize{
\item \code{fun}: string function name
-\item \code{data}: \code{Expression} (these are all currently a single field)
+\item \code{data}: list of 0 or more \code{Expression}s
\item \code{options}: list of function options, as passed to call_function
}}
}
diff --git a/r/src/compute-exec.cpp b/r/src/compute-exec.cpp
index e616e3d824..5dafc21273 100644
--- a/r/src/compute-exec.cpp
+++ b/r/src/compute-exec.cpp
@@ -393,11 +393,15 @@ std::shared_ptr<compute::ExecNode> ExecNode_Aggregate(
for (cpp11::list name_opts : options) {
auto function = cpp11::as_cpp<std::string>(name_opts["fun"]);
auto opts = make_compute_options(function, name_opts["options"]);
- auto target = cpp11::as_cpp<std::string>(name_opts["target"]);
+ auto target_names =
cpp11::as_cpp<std::vector<std::string>>(name_opts["targets"]);
auto name = cpp11::as_cpp<std::string>(name_opts["name"]);
+ std::vector<arrow::FieldRef> targets;
+ for (auto&& target : target_names) {
+ targets.emplace_back(std::move(target));
+ }
aggregates.push_back(arrow::compute::Aggregate{std::move(function), opts,
- std::move(target),
std::move(name)});
+ std::move(targets),
std::move(name)});
}
std::vector<arrow::FieldRef> keys;
diff --git a/r/tests/testthat/test-dplyr-collapse.R b/r/tests/testthat/test-dplyr-collapse.R
index 6c5f4c1991..cca8412178 100644
--- a/r/tests/testthat/test-dplyr-collapse.R
+++ b/r/tests/testthat/test-dplyr-collapse.R
@@ -162,8 +162,8 @@ test_that("Properties of collapsed query", {
print(q),
"Table (query)
lgl: bool
-total: int32
-extra: int32 (multiply_checked(total, 5))
+total: int64
+extra: int64 (multiply_checked(total, 5))
See $.data for the source Arrow object",
fixed = TRUE
diff --git a/r/tests/testthat/test-dplyr-summarize.R b/r/tests/testthat/test-dplyr-summarize.R
index e54e57c836..6ee8982cc2 100644
--- a/r/tests/testthat/test-dplyr-summarize.R
+++ b/r/tests/testthat/test-dplyr-summarize.R
@@ -1119,7 +1119,7 @@ test_that("We don't add unnecessary ProjectNodes when aggregating", {
1
)
- # 0 Projections only if
+ # 0 Projections if
# (a) input only contains the col you're aggregating, and
# (b) the output col name is the same as the input name, and
# (c) no grouping
@@ -1128,6 +1128,14 @@ test_that("We don't add unnecessary ProjectNodes when aggregating", {
0
)
+ # 0 Projections if
+ # (a) only nullary functions in summarize()
+ # (b) no grouping
+ expect_project_nodes(
+ tab[, "int"] %>% summarize(n()),
+ 0
+ )
+
# 2 projections: one before, and one after in order to put grouping cols first
expect_project_nodes(
tab %>% group_by(lgl) %>% summarize(mean(int)),