This is an automated email from the ASF dual-hosted git repository.
npr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/main by this push:
new f1bc82f2b3 MINOR: [R] refactor: move aggregation function bindings to their own file (#41355)
f1bc82f2b3 is described below
commit f1bc82f2b39a317970427052c360383f983ec3f8
Author: Neal Richardson <[email protected]>
AuthorDate: Tue Apr 23 13:31:26 2024 -0400
MINOR: [R] refactor: move aggregation function bindings to their own file (#41355)
For consistency with other bindings, and to allow `dplyr-summarize.R` to
start with the summarize method, as do the other dplyr verb files.
---
r/DESCRIPTION | 1 +
r/R/dplyr-funcs-agg.R | 198 ++++++++++++++++++++++++++++++++++++++++++++++++++
r/R/dplyr-funcs.R | 16 +++-
r/R/dplyr-summarize.R | 195 -------------------------------------------------
4 files changed, 213 insertions(+), 197 deletions(-)
diff --git a/r/DESCRIPTION b/r/DESCRIPTION
index 2efaed4d6c..eeff8168b3 100644
--- a/r/DESCRIPTION
+++ b/r/DESCRIPTION
@@ -107,6 +107,7 @@ Collate:
'dplyr-distinct.R'
'dplyr-eval.R'
'dplyr-filter.R'
+ 'dplyr-funcs-agg.R'
'dplyr-funcs-augmented.R'
'dplyr-funcs-conditional.R'
'dplyr-funcs-datetime.R'
diff --git a/r/R/dplyr-funcs-agg.R b/r/R/dplyr-funcs-agg.R
new file mode 100644
index 0000000000..ab1df1d2f1
--- /dev/null
+++ b/r/R/dplyr-funcs-agg.R
@@ -0,0 +1,198 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# Aggregation functions
+#
+# These all insert into an ..aggregations list (in a parent frame) a list containing:
+# @param fun string function name
+# @param data list of 0 or more Expressions
+# @param options list of function options, as passed to call_function
+# The functions return a FieldRef pointing to the result of the aggregation.
+#
+# For group-by aggregation, `hash_` gets prepended to the function name when
+# the query is executed.
+# So to see a list of available hash aggregation functions,
+# you can use list_compute_functions("^hash_")
+
+register_bindings_aggregate <- function() {
+ register_binding_agg("base::sum", function(..., na.rm = FALSE) {
+ set_agg(
+ fun = "sum",
+ data = ensure_one_arg(list2(...), "sum"),
+ options = list(skip_nulls = na.rm, min_count = 0L)
+ )
+ })
+ register_binding_agg("base::prod", function(..., na.rm = FALSE) {
+ set_agg(
+ fun = "product",
+ data = ensure_one_arg(list2(...), "prod"),
+ options = list(skip_nulls = na.rm, min_count = 0L)
+ )
+ })
+ register_binding_agg("base::any", function(..., na.rm = FALSE) {
+ set_agg(
+ fun = "any",
+ data = ensure_one_arg(list2(...), "any"),
+ options = list(skip_nulls = na.rm, min_count = 0L)
+ )
+ })
+ register_binding_agg("base::all", function(..., na.rm = FALSE) {
+ set_agg(
+ fun = "all",
+ data = ensure_one_arg(list2(...), "all"),
+ options = list(skip_nulls = na.rm, min_count = 0L)
+ )
+ })
+ register_binding_agg("base::mean", function(x, na.rm = FALSE) {
+ set_agg(
+ fun = "mean",
+ data = list(x),
+ options = list(skip_nulls = na.rm, min_count = 0L)
+ )
+ })
+ register_binding_agg("stats::sd", function(x, na.rm = FALSE, ddof = 1) {
+ set_agg(
+ fun = "stddev",
+ data = list(x),
+ options = list(skip_nulls = na.rm, min_count = 0L, ddof = ddof)
+ )
+ })
+ register_binding_agg("stats::var", function(x, na.rm = FALSE, ddof = 1) {
+ set_agg(
+ fun = "variance",
+ data = list(x),
+ options = list(skip_nulls = na.rm, min_count = 0L, ddof = ddof)
+ )
+ })
+ register_binding_agg(
+ "stats::quantile",
+ function(x, probs, na.rm = FALSE) {
+ if (length(probs) != 1) {
+ arrow_not_supported("quantile() with length(probs) != 1")
+ }
+ # TODO: Bind to the Arrow function that returns an exact quantile and remove
+ # this warning (ARROW-14021)
+ warn(
+ "quantile() currently returns an approximate quantile in Arrow",
+ .frequency = "once",
+ .frequency_id = "arrow.quantile.approximate",
+ class = "arrow.quantile.approximate"
+ )
+ set_agg(
+ fun = "tdigest",
+ data = list(x),
+ options = list(skip_nulls = na.rm, q = probs)
+ )
+ },
+ notes = c(
+ "`probs` must be length 1;",
+ "approximate quantile (t-digest) is computed"
+ )
+ )
+ register_binding_agg(
+ "stats::median",
+ function(x, na.rm = FALSE) {
+ # TODO: Bind to the Arrow function that returns an exact median and remove
+ # this warning (ARROW-14021)
+ warn(
+ "median() currently returns an approximate median in Arrow",
+ .frequency = "once",
+ .frequency_id = "arrow.median.approximate",
+ class = "arrow.median.approximate"
+ )
+ set_agg(
+ fun = "approximate_median",
+ data = list(x),
+ options = list(skip_nulls = na.rm)
+ )
+ },
+ notes = "approximate median (t-digest) is computed"
+ )
+ register_binding_agg("dplyr::n_distinct", function(..., na.rm = FALSE) {
+ set_agg(
+ fun = "count_distinct",
+ data = ensure_one_arg(list2(...), "n_distinct"),
+ options = list(na.rm = na.rm)
+ )
+ })
+ register_binding_agg("dplyr::n", function() {
+ set_agg(
+ fun = "count_all",
+ data = list(),
+ options = list()
+ )
+ })
+ register_binding_agg("base::min", function(..., na.rm = FALSE) {
+ set_agg(
+ fun = "min",
+ data = ensure_one_arg(list2(...), "min"),
+ options = list(skip_nulls = na.rm, min_count = 0L)
+ )
+ })
+ register_binding_agg("base::max", function(..., na.rm = FALSE) {
+ set_agg(
+ fun = "max",
+ data = ensure_one_arg(list2(...), "max"),
+ options = list(skip_nulls = na.rm, min_count = 0L)
+ )
+ })
+}
+
+set_agg <- function(...) {
+ agg_data <- list2(...)
+ # Find the environment where ..aggregations is stored
+ target <- find_aggregations_env()
+ aggs <- get("..aggregations", target)
+ lapply(agg_data[["data"]], function(expr) {
+ # If any of the fields referenced in the expression are in ..aggregations,
+ # then we can't aggregate over them.
+ # This is mainly for combinations of dataset columns and aggregations,
+ # like sum(x - mean(x)), i.e. window functions.
+ # This will reject (sum(sum(x)) as well, but that's not a useful operation.
+ if (any(expr$field_names_in_expression() %in% names(aggs))) {
+ # TODO: support in ARROW-13926
+ arrow_not_supported("aggregate within aggregate expression")
+ }
+ })
+
+ # Record the (fun, data, options) in ..aggregations
+ # and return a FieldRef pointing to it
+ tmpname <- paste0("..temp", length(aggs))
+ aggs[[tmpname]] <- agg_data
+ assign("..aggregations", aggs, envir = target)
+ Expression$field_ref(tmpname)
+}
+
+find_aggregations_env <- function() {
+ # Find the environment where ..aggregations is stored,
+ # it's in parent.env of something in the call stack
+ for (f in sys.frames()) {
+ if (exists("..aggregations", envir = f)) {
+ return(f)
+ }
+ }
+ stop("Could not find ..aggregations")
+}
+
+ensure_one_arg <- function(args, fun) {
+ if (length(args) == 0) {
+ arrow_not_supported(paste0(fun, "() with 0 arguments"))
+ } else if (length(args) > 1) {
+ arrow_not_supported(paste0("Multiple arguments to ", fun, "()"))
+ }
+ args
+}
diff --git a/r/R/dplyr-funcs.R b/r/R/dplyr-funcs.R
index 956e31fe2b..abf2362d01 100644
--- a/r/R/dplyr-funcs.R
+++ b/r/R/dplyr-funcs.R
@@ -175,8 +175,7 @@ agg_funcs <- new.env(parent = emptyenv())
.cache <- new.env(parent = emptyenv())
# we register 2 versions of the "::" binding - one for use with nse_funcs
-# (registered below) and another one for use with agg_funcs (registered in
-# dplyr-summarize.R)
+# and another one for use with agg_funcs (registered in dplyr-funcs-agg.R)
nse_funcs[["::"]] <- function(lhs, rhs) {
lhs_name <- as.character(substitute(lhs))
rhs_name <- as.character(substitute(rhs))
@@ -187,3 +186,16 @@ nse_funcs[["::"]] <- function(lhs, rhs) {
# regular pkg::fun function
nse_funcs[[fun_name]] %||% asNamespace(lhs_name)[[rhs_name]]
}
+
+agg_funcs[["::"]] <- function(lhs, rhs) {
+ lhs_name <- as.character(substitute(lhs))
+ rhs_name <- as.character(substitute(rhs))
+
+ fun_name <- paste0(lhs_name, "::", rhs_name)
+
+ # if we do not have a binding for pkg::fun, then fall back on to the
+ # nse_funcs (useful when we have a regular function inside an aggregating one)
+ # and then, if searching nse_funcs fails too, fall back to the
+ # regular `pkg::fun()` function
+ agg_funcs[[fun_name]] %||% nse_funcs[[fun_name]] %||% asNamespace(lhs_name)[[rhs_name]]
+}
diff --git a/r/R/dplyr-summarize.R b/r/R/dplyr-summarize.R
index 1b625e34ad..5bb81dc2b3 100644
--- a/r/R/dplyr-summarize.R
+++ b/r/R/dplyr-summarize.R
@@ -15,201 +15,6 @@
# specific language governing permissions and limitations
# under the License.
-# Aggregation functions
-# These all return a list of:
-# @param fun string function name
-# @param data list of 0 or more Expressions
-# @param options list of function options, as passed to call_function
-# For group-by aggregation, `hash_` gets prepended to the function name.
-# So to see a list of available hash aggregation functions,
-# you can use list_compute_functions("^hash_")
-
-
-ensure_one_arg <- function(args, fun) {
- if (length(args) == 0) {
- arrow_not_supported(paste0(fun, "() with 0 arguments"))
- } else if (length(args) > 1) {
- arrow_not_supported(paste0("Multiple arguments to ", fun, "()"))
- }
- args
-}
-
-register_bindings_aggregate <- function() {
- register_binding_agg("base::sum", function(..., na.rm = FALSE) {
- set_agg(
- fun = "sum",
- data = ensure_one_arg(list2(...), "sum"),
- options = list(skip_nulls = na.rm, min_count = 0L)
- )
- })
- register_binding_agg("base::prod", function(..., na.rm = FALSE) {
- set_agg(
- fun = "product",
- data = ensure_one_arg(list2(...), "prod"),
- options = list(skip_nulls = na.rm, min_count = 0L)
- )
- })
- register_binding_agg("base::any", function(..., na.rm = FALSE) {
- set_agg(
- fun = "any",
- data = ensure_one_arg(list2(...), "any"),
- options = list(skip_nulls = na.rm, min_count = 0L)
- )
- })
- register_binding_agg("base::all", function(..., na.rm = FALSE) {
- set_agg(
- fun = "all",
- data = ensure_one_arg(list2(...), "all"),
- options = list(skip_nulls = na.rm, min_count = 0L)
- )
- })
- register_binding_agg("base::mean", function(x, na.rm = FALSE) {
- set_agg(
- fun = "mean",
- data = list(x),
- options = list(skip_nulls = na.rm, min_count = 0L)
- )
- })
- register_binding_agg("stats::sd", function(x, na.rm = FALSE, ddof = 1) {
- set_agg(
- fun = "stddev",
- data = list(x),
- options = list(skip_nulls = na.rm, min_count = 0L, ddof = ddof)
- )
- })
- register_binding_agg("stats::var", function(x, na.rm = FALSE, ddof = 1) {
- set_agg(
- fun = "variance",
- data = list(x),
- options = list(skip_nulls = na.rm, min_count = 0L, ddof = ddof)
- )
- })
- register_binding_agg(
- "stats::quantile",
- function(x, probs, na.rm = FALSE) {
- if (length(probs) != 1) {
- arrow_not_supported("quantile() with length(probs) != 1")
- }
- # TODO: Bind to the Arrow function that returns an exact quantile and remove
- # this warning (ARROW-14021)
- warn(
- "quantile() currently returns an approximate quantile in Arrow",
- .frequency = "once",
- .frequency_id = "arrow.quantile.approximate",
- class = "arrow.quantile.approximate"
- )
- set_agg(
- fun = "tdigest",
- data = list(x),
- options = list(skip_nulls = na.rm, q = probs)
- )
- },
- notes = c(
- "`probs` must be length 1;",
- "approximate quantile (t-digest) is computed"
- )
- )
- register_binding_agg(
- "stats::median",
- function(x, na.rm = FALSE) {
- # TODO: Bind to the Arrow function that returns an exact median and remove
- # this warning (ARROW-14021)
- warn(
- "median() currently returns an approximate median in Arrow",
- .frequency = "once",
- .frequency_id = "arrow.median.approximate",
- class = "arrow.median.approximate"
- )
- set_agg(
- fun = "approximate_median",
- data = list(x),
- options = list(skip_nulls = na.rm)
- )
- },
- notes = "approximate median (t-digest) is computed"
- )
- register_binding_agg("dplyr::n_distinct", function(..., na.rm = FALSE) {
- set_agg(
- fun = "count_distinct",
- data = ensure_one_arg(list2(...), "n_distinct"),
- options = list(na.rm = na.rm)
- )
- })
- register_binding_agg("dplyr::n", function() {
- set_agg(
- fun = "count_all",
- data = list(),
- options = list()
- )
- })
- register_binding_agg("base::min", function(..., na.rm = FALSE) {
- set_agg(
- fun = "min",
- data = ensure_one_arg(list2(...), "min"),
- options = list(skip_nulls = na.rm, min_count = 0L)
- )
- })
- register_binding_agg("base::max", function(..., na.rm = FALSE) {
- set_agg(
- fun = "max",
- data = ensure_one_arg(list2(...), "max"),
- options = list(skip_nulls = na.rm, min_count = 0L)
- )
- })
-}
-
-set_agg <- function(...) {
- agg_data <- list2(...)
- # Find the environment where ..aggregations is stored
- target <- find_aggregations_env()
- aggs <- get("..aggregations", target)
- lapply(agg_data[["data"]], function(expr) {
- # If any of the fields referenced in the expression are in ..aggregations,
- # then we can't aggregate over them.
- # This is mainly for combinations of dataset columns and aggregations,
- # like sum(x - mean(x)), i.e. window functions.
- # This will reject (sum(sum(x)) as well, but that's not a useful operation.
- if (any(expr$field_names_in_expression() %in% names(aggs))) {
- # TODO: support in ARROW-13926
- arrow_not_supported("aggregate within aggregate expression")
- }
- })
-
- # Record the (fun, data, options) in ..aggregations
- # and return a FieldRef pointing to it
- tmpname <- paste0("..temp", length(aggs))
- aggs[[tmpname]] <- agg_data
- assign("..aggregations", aggs, envir = target)
- Expression$field_ref(tmpname)
-}
-
-find_aggregations_env <- function() {
- # Find the environment where ..aggregations is stored,
- # it's in parent.env of something in the call stack
- for (f in sys.frames()) {
- if (exists("..aggregations", envir = f)) {
- return(f)
- }
- }
- stop("Could not find ..aggregations")
-}
-
-# we register 2 versions of the "::" binding - one for use with agg_funcs
-# (registered below) and another one for use with nse_funcs
-# (registered in dplyr-funcs.R)
-agg_funcs[["::"]] <- function(lhs, rhs) {
- lhs_name <- as.character(substitute(lhs))
- rhs_name <- as.character(substitute(rhs))
-
- fun_name <- paste0(lhs_name, "::", rhs_name)
-
- # if we do not have a binding for pkg::fun, then fall back on to the
- # nse_funcs (useful when we have a regular function inside an aggregating one)
- # and then, if searching nse_funcs fails too, fall back to the
- # regular `pkg::fun()` function
- agg_funcs[[fun_name]] %||% nse_funcs[[fun_name]] %||% asNamespace(lhs_name)[[rhs_name]]
-}
-
# The following S3 methods are registered on load if dplyr is present
summarise.arrow_dplyr_query <- function(.data, ..., .by = NULL, .groups = NULL) {