This is an automated email from the ASF dual-hosted git repository.

npr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/main by this push:
     new c87073737b MINOR: [R] refactor arrow_mask to include aggregations list (#41414)
c87073737b is described below

commit c87073737b6ffef9715549a199499b92630e8e5f
Author: Neal Richardson <[email protected]>
AuthorDate: Mon Apr 29 11:32:01 2024 -0400

    MINOR: [R] refactor arrow_mask to include aggregations list (#41414)
    
    ### Rationale for this change
    
    Keeping the `..aggregations` list in parent.frame felt a little wrong.
    As we're starting to use this in more places (like mutate in #41350, and
    potentially more places), I wanted to try to improve this. I tried a
    bunch of things before to put it somewhere better (like in the mask) but
    failed. Finally I found one that worked.
    
    ### What changes are included in this PR?
    
    Just a refactor
    
    ### Are these changes tested?
    
    Existing tests pass.
    
    ### Are there any user-facing changes?
    
    Nope.
---
 r/R/dplyr-eval.R      |  8 +++-----
 r/R/dplyr-funcs-agg.R | 23 ++++++++++++-----------
 r/R/dplyr-summarize.R | 41 ++++++++++++++++++-----------------------
 3 files changed, 33 insertions(+), 39 deletions(-)

diff --git a/r/R/dplyr-eval.R b/r/R/dplyr-eval.R
index 3aaa29696b..ff1619ce94 100644
--- a/r/R/dplyr-eval.R
+++ b/r/R/dplyr-eval.R
@@ -125,13 +125,9 @@ arrow_mask <- function(.data, aggregation = FALSE) {
   f_env <- new_environment(.cache$functions)
 
   if (aggregation) {
-    # Add the aggregation functions to the environment, and set the enclosing
-    # environment to the parent frame so that, when called from summarize_eval(),
-    # they can reference and assign into `..aggregations` defined there.
-    pf <- parent.frame()
+    # Add the aggregation functions to the environment.
     for (f in names(agg_funcs)) {
       f_env[[f]] <- agg_funcs[[f]]
-      environment(f_env[[f]]) <- pf
     }
   } else {
     # Add functions that need to error hard and clear.
@@ -156,6 +152,8 @@ arrow_mask <- function(.data, aggregation = FALSE) {
   # TODO: figure out what rlang::as_data_pronoun does/why we should use it
   # (because if we do we get `Error: Can't modify the data pronoun` in mutate())
   out$.data <- .data$selected_columns
+  # Add the aggregations list to collect any that get pulled out when evaluating
+  out$.aggregations <- empty_named_list()
   out
 }
 
diff --git a/r/R/dplyr-funcs-agg.R b/r/R/dplyr-funcs-agg.R
index ab1df1d2f1..d84f8f28f0 100644
--- a/r/R/dplyr-funcs-agg.R
+++ b/r/R/dplyr-funcs-agg.R
@@ -17,7 +17,7 @@
 
 # Aggregation functions
 #
-# These all insert into an ..aggregations list (in a parent frame) a list containing:
+# These all insert into an .aggregations list in the mask, a list containing:
 # @param fun string function name
 # @param data list of 0 or more Expressions
 # @param options list of function options, as passed to call_function
@@ -154,11 +154,11 @@ register_bindings_aggregate <- function() {
 
 set_agg <- function(...) {
   agg_data <- list2(...)
-  # Find the environment where ..aggregations is stored
+  # Find the environment where .aggregations is stored
   target <- find_aggregations_env()
-  aggs <- get("..aggregations", target)
+  aggs <- get(".aggregations", target)
   lapply(agg_data[["data"]], function(expr) {
-    # If any of the fields referenced in the expression are in ..aggregations,
+    # If any of the fields referenced in the expression are in .aggregations,
     # then we can't aggregate over them.
     # This is mainly for combinations of dataset columns and aggregations,
     # like sum(x - mean(x)), i.e. window functions.
@@ -169,23 +169,24 @@ set_agg <- function(...) {
     }
   })
 
-  # Record the (fun, data, options) in ..aggregations
+  # Record the (fun, data, options) in .aggregations
   # and return a FieldRef pointing to it
   tmpname <- paste0("..temp", length(aggs))
   aggs[[tmpname]] <- agg_data
-  assign("..aggregations", aggs, envir = target)
+  assign(".aggregations", aggs, envir = target)
   Expression$field_ref(tmpname)
 }
 
 find_aggregations_env <- function() {
-  # Find the environment where ..aggregations is stored,
+  # Find the environment where .aggregations is stored,
   # it's in parent.env of something in the call stack
-  for (f in sys.frames()) {
-    if (exists("..aggregations", envir = f)) {
-      return(f)
+  n <- 1
+  while (TRUE) {
+    if (exists(".aggregations", envir = caller_env(n))) {
+      return(caller_env(n))
     }
+    n <- n + 1
   }
-  stop("Could not find ..aggregations")
 }
 
 ensure_one_arg <- function(args, fun) {
diff --git a/r/R/dplyr-summarize.R b/r/R/dplyr-summarize.R
index 5bb81dc2b3..56de14db6d 100644
--- a/r/R/dplyr-summarize.R
+++ b/r/R/dplyr-summarize.R
@@ -80,34 +80,32 @@ do_arrow_summarize <- function(.data, ..., .groups = NULL) {
   # ExecNode), and in the expressions, replace them with FieldRefs so that
   # further operations can happen (in what will become a ProjectNode that works
   # on the result of the Aggregate).
-  # To do this, we create a list in this function scope, and in arrow_mask(),
-  # and we make sure this environment here is the parent env of the binding
-  # functions, so that when they receive an expression, they can pull out
-  # aggregations and insert them into the list, which they can find because it
-  # is in the parent env.
+  # To do this, arrow_mask() includes a list called .aggregations,
+  # and the aggregation functions will pull out those terms and insert into
+  # that list.
   # nolint end
-  ..aggregations <- empty_named_list()
-
-  # We'll collect any transformations after the aggregation here
-  ..post_mutate <- empty_named_list()
   mask <- arrow_mask(.data, aggregation = TRUE)
 
+  # We'll collect any transformations after the aggregation here.
+  # summarize_eval() returns NULL when the outer expression is an aggregation,
+  # i.e. there is no projection to do after
+  post_mutate <- empty_named_list()
   for (i in seq_along(exprs)) {
     # Iterate over the indices and not the names because names may be repeated
     # (which overwrites the previous name)
     name <- names(exprs)[i]
-    ..post_mutate[[name]] <- summarize_eval(name, exprs[[i]], mask)
+    post_mutate[[name]] <- summarize_eval(name, exprs[[i]], mask)
   }
 
   # Apply the results to the .data object.
   # First, the aggregations
-  .data$aggregations <- ..aggregations
+  .data$aggregations <- mask$.aggregations
   # Then collapse the query so that the resulting query object can have
   # additional operations applied to it
   out <- collapse.arrow_dplyr_query(.data)
 
-  # Now, add the projections in ..post_mutate (if any)
-  for (post in names(..post_mutate)) {
+  # Now, add the projections in post_mutate (if any)
+  for (post in names(post_mutate)) {
     # One last check: it's possible that an expression like y - mean(y) would
     # successfully evaluate, but it's not supported. It gets transformed to:
     # nolint start
@@ -121,7 +119,7 @@ do_arrow_summarize <- function(.data, ..., .groups = NULL) {
     # We can tell the expression is invalid if it references fields not in
     # the schema of the data after summarize(). Evaulating its type will
     # throw an error if it's invalid.
-    tryCatch(..post_mutate[[post]]$type(out$.data$schema), error = function(e) {
+    tryCatch(post_mutate[[post]]$type(out$.data$schema), error = function(e) {
       msg <- paste(
         "Expression", as_label(exprs[[post]]),
         "is not a valid aggregation expression or is"
@@ -129,7 +127,7 @@ do_arrow_summarize <- function(.data, ..., .groups = NULL) {
       arrow_not_supported(msg)
     })
     # If it's valid, add it to the .data object
-    out$selected_columns[[post]] <- ..post_mutate[[post]]
+    out$selected_columns[[post]] <- post_mutate[[post]]
   }
 
   # Make sure column order is correct (and also drop ..temp columns)
@@ -266,10 +264,10 @@ format_aggregation <- function(x) {
 # This function evaluates an expression and returns the post-summarize
 # projection that results, or NULL if there is none because the top-level
 # expression was an aggregation. Any aggregations are pulled out and collected
-# in the ..aggregations list outside this function.
+# in the .aggregations list outside this function.
 summarize_eval <- function(name, quosure, mask) {
   # Add previous aggregations to the mask, so they can be referenced
-  for (n in names(get("..aggregations", parent.frame()))) {
+  for (n in names(mask$.aggregations)) {
     mask[[n]] <- mask$.data[[n]] <- Expression$field_ref(n)
   }
   # Evaluate:
@@ -286,14 +284,11 @@ summarize_eval <- function(name, quosure, mask) {
   # Handle case where outer expr is ..temp field ref. This came from an
   # aggregation at the top level. So the resulting name should be `name`.
   # not `..tempN`. Rename the corresponding aggregation.
-  post_aggs <- get("..aggregations", parent.frame())
   result_field_name <- value$field_name
-  if (result_field_name %in% names(post_aggs)) {
+  if (result_field_name %in% names(mask$.aggregations)) {
     # Do this by assigning over `name` in case something else was in `name`
-    post_aggs[[name]] <- post_aggs[[result_field_name]]
-    post_aggs[[result_field_name]] <- NULL
-    # Assign back into the parent environment
-    assign("..aggregations", post_aggs, parent.frame())
+    mask$.aggregations[[name]] <- mask$.aggregations[[result_field_name]]
+    mask$.aggregations[[result_field_name]] <- NULL
     # Return NULL because there is no post-mutate projection, it's just
     # the aggregation
     return(NULL)

Reply via email to