This is an automated email from the ASF dual-hosted git repository.
npr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/main by this push:
new c87073737b MINOR: [R] refactor arrow_mask to include aggregations list (#41414)
c87073737b is described below
commit c87073737b6ffef9715549a199499b92630e8e5f
Author: Neal Richardson <[email protected]>
AuthorDate: Mon Apr 29 11:32:01 2024 -0400
MINOR: [R] refactor arrow_mask to include aggregations list (#41414)
### Rationale for this change
Keeping the `..aggregations` list in `parent.frame()` felt a little wrong.
As we start to use this in more places (like `mutate()` in #41350, and
potentially others), I wanted to improve it. I had previously tried several
ways to put it somewhere better (such as in the mask) without success;
finally I found one that worked.
### What changes are included in this PR?
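Just a refactor: `arrow_mask()` now creates the `.aggregations` list inside
the mask itself. The aggregation functions locate that list by walking the
call stack, instead of having their enclosing environment set to the caller's
`parent.frame()`. `do_arrow_summarize()` then reads the collected aggregations
from `mask$.aggregations` rather than from a local `..aggregations` variable.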
### Are these changes tested?
Existing tests pass.
### Are there any user-facing changes?
Nope.
---
r/R/dplyr-eval.R | 8 +++-----
r/R/dplyr-funcs-agg.R | 23 ++++++++++++-----------
r/R/dplyr-summarize.R | 41 ++++++++++++++++++-----------------------
3 files changed, 33 insertions(+), 39 deletions(-)
diff --git a/r/R/dplyr-eval.R b/r/R/dplyr-eval.R
index 3aaa29696b..ff1619ce94 100644
--- a/r/R/dplyr-eval.R
+++ b/r/R/dplyr-eval.R
@@ -125,13 +125,9 @@ arrow_mask <- function(.data, aggregation = FALSE) {
f_env <- new_environment(.cache$functions)
if (aggregation) {
- # Add the aggregation functions to the environment, and set the enclosing
- # environment to the parent frame so that, when called from
summarize_eval(),
- # they can reference and assign into `..aggregations` defined there.
- pf <- parent.frame()
+ # Add the aggregation functions to the environment.
for (f in names(agg_funcs)) {
f_env[[f]] <- agg_funcs[[f]]
- environment(f_env[[f]]) <- pf
}
} else {
# Add functions that need to error hard and clear.
@@ -156,6 +152,8 @@ arrow_mask <- function(.data, aggregation = FALSE) {
# TODO: figure out what rlang::as_data_pronoun does/why we should use it
# (because if we do we get `Error: Can't modify the data pronoun` in
mutate())
out$.data <- .data$selected_columns
+ # Add the aggregations list to collect any that get pulled out when
evaluating
+ out$.aggregations <- empty_named_list()
out
}
diff --git a/r/R/dplyr-funcs-agg.R b/r/R/dplyr-funcs-agg.R
index ab1df1d2f1..d84f8f28f0 100644
--- a/r/R/dplyr-funcs-agg.R
+++ b/r/R/dplyr-funcs-agg.R
@@ -17,7 +17,7 @@
# Aggregation functions
#
-# These all insert into an ..aggregations list (in a parent frame) a list
containing:
+# These all insert into an .aggregations list in the mask, a list containing:
# @param fun string function name
# @param data list of 0 or more Expressions
# @param options list of function options, as passed to call_function
@@ -154,11 +154,11 @@ register_bindings_aggregate <- function() {
set_agg <- function(...) {
agg_data <- list2(...)
- # Find the environment where ..aggregations is stored
+ # Find the environment where .aggregations is stored
target <- find_aggregations_env()
- aggs <- get("..aggregations", target)
+ aggs <- get(".aggregations", target)
lapply(agg_data[["data"]], function(expr) {
- # If any of the fields referenced in the expression are in ..aggregations,
+ # If any of the fields referenced in the expression are in .aggregations,
# then we can't aggregate over them.
# This is mainly for combinations of dataset columns and aggregations,
# like sum(x - mean(x)), i.e. window functions.
@@ -169,23 +169,24 @@ set_agg <- function(...) {
}
})
- # Record the (fun, data, options) in ..aggregations
+ # Record the (fun, data, options) in .aggregations
# and return a FieldRef pointing to it
tmpname <- paste0("..temp", length(aggs))
aggs[[tmpname]] <- agg_data
- assign("..aggregations", aggs, envir = target)
+ assign(".aggregations", aggs, envir = target)
Expression$field_ref(tmpname)
}
find_aggregations_env <- function() {
- # Find the environment where ..aggregations is stored,
+ # Find the environment where .aggregations is stored,
# it's in parent.env of something in the call stack
- for (f in sys.frames()) {
- if (exists("..aggregations", envir = f)) {
- return(f)
+ n <- 1
+ while (TRUE) {
+ if (exists(".aggregations", envir = caller_env(n))) {
+ return(caller_env(n))
}
+ n <- n + 1
}
- stop("Could not find ..aggregations")
}
ensure_one_arg <- function(args, fun) {
diff --git a/r/R/dplyr-summarize.R b/r/R/dplyr-summarize.R
index 5bb81dc2b3..56de14db6d 100644
--- a/r/R/dplyr-summarize.R
+++ b/r/R/dplyr-summarize.R
@@ -80,34 +80,32 @@ do_arrow_summarize <- function(.data, ..., .groups = NULL) {
# ExecNode), and in the expressions, replace them with FieldRefs so that
# further operations can happen (in what will become a ProjectNode that works
# on the result of the Aggregate).
- # To do this, we create a list in this function scope, and in arrow_mask(),
- # and we make sure this environment here is the parent env of the binding
- # functions, so that when they receive an expression, they can pull out
- # aggregations and insert them into the list, which they can find because it
- # is in the parent env.
+ # To do this, arrow_mask() includes a list called .aggregations,
+ # and the aggregation functions will pull out those terms and insert into
+ # that list.
# nolint end
- ..aggregations <- empty_named_list()
-
- # We'll collect any transformations after the aggregation here
- ..post_mutate <- empty_named_list()
mask <- arrow_mask(.data, aggregation = TRUE)
+ # We'll collect any transformations after the aggregation here.
+ # summarize_eval() returns NULL when the outer expression is an aggregation,
+ # i.e. there is no projection to do after
+ post_mutate <- empty_named_list()
for (i in seq_along(exprs)) {
# Iterate over the indices and not the names because names may be repeated
# (which overwrites the previous name)
name <- names(exprs)[i]
- ..post_mutate[[name]] <- summarize_eval(name, exprs[[i]], mask)
+ post_mutate[[name]] <- summarize_eval(name, exprs[[i]], mask)
}
# Apply the results to the .data object.
# First, the aggregations
- .data$aggregations <- ..aggregations
+ .data$aggregations <- mask$.aggregations
# Then collapse the query so that the resulting query object can have
# additional operations applied to it
out <- collapse.arrow_dplyr_query(.data)
- # Now, add the projections in ..post_mutate (if any)
- for (post in names(..post_mutate)) {
+ # Now, add the projections in post_mutate (if any)
+ for (post in names(post_mutate)) {
# One last check: it's possible that an expression like y - mean(y) would
# successfully evaluate, but it's not supported. It gets transformed to:
# nolint start
@@ -121,7 +119,7 @@ do_arrow_summarize <- function(.data, ..., .groups = NULL) {
# We can tell the expression is invalid if it references fields not in
# the schema of the data after summarize(). Evaulating its type will
# throw an error if it's invalid.
- tryCatch(..post_mutate[[post]]$type(out$.data$schema), error = function(e)
{
+ tryCatch(post_mutate[[post]]$type(out$.data$schema), error = function(e) {
msg <- paste(
"Expression", as_label(exprs[[post]]),
"is not a valid aggregation expression or is"
@@ -129,7 +127,7 @@ do_arrow_summarize <- function(.data, ..., .groups = NULL) {
arrow_not_supported(msg)
})
# If it's valid, add it to the .data object
- out$selected_columns[[post]] <- ..post_mutate[[post]]
+ out$selected_columns[[post]] <- post_mutate[[post]]
}
# Make sure column order is correct (and also drop ..temp columns)
@@ -266,10 +264,10 @@ format_aggregation <- function(x) {
# This function evaluates an expression and returns the post-summarize
# projection that results, or NULL if there is none because the top-level
# expression was an aggregation. Any aggregations are pulled out and collected
-# in the ..aggregations list outside this function.
+# in the .aggregations list outside this function.
summarize_eval <- function(name, quosure, mask) {
# Add previous aggregations to the mask, so they can be referenced
- for (n in names(get("..aggregations", parent.frame()))) {
+ for (n in names(mask$.aggregations)) {
mask[[n]] <- mask$.data[[n]] <- Expression$field_ref(n)
}
# Evaluate:
@@ -286,14 +284,11 @@ summarize_eval <- function(name, quosure, mask) {
# Handle case where outer expr is ..temp field ref. This came from an
# aggregation at the top level. So the resulting name should be `name`.
# not `..tempN`. Rename the corresponding aggregation.
- post_aggs <- get("..aggregations", parent.frame())
result_field_name <- value$field_name
- if (result_field_name %in% names(post_aggs)) {
+ if (result_field_name %in% names(mask$.aggregations)) {
# Do this by assigning over `name` in case something else was in `name`
- post_aggs[[name]] <- post_aggs[[result_field_name]]
- post_aggs[[result_field_name]] <- NULL
- # Assign back into the parent environment
- assign("..aggregations", post_aggs, parent.frame())
+ mask$.aggregations[[name]] <- mask$.aggregations[[result_field_name]]
+ mask$.aggregations[[result_field_name]] <- NULL
# Return NULL because there is no post-mutate projection, it's just
# the aggregation
return(NULL)