jonkeane commented on a change in pull request #11108:
URL: https://github.com/apache/arrow/pull/11108#discussion_r705692706



##########
File path: r/R/dplyr-summarize.R
##########
@@ -42,33 +49,63 @@ summarise.arrow_dplyr_query <- function(.data, ..., .engine 
= c("arrow", "duckdb
 }
 summarise.Dataset <- summarise.ArrowTabular <- summarise.arrow_dplyr_query
 
+# This is the Arrow summarize implementation

Review comment:
       This is great — finding these can sometimes be tricky in this code.

##########
File path: r/R/dplyr-summarize.R
##########
@@ -81,3 +118,109 @@ summarize_projection <- function(.data) {
 format_aggregation <- function(x) {
   paste0(x$fun, "(", x$data$ToString(), ")")
 }
+
+# This function handles each summarize expression and turns it into the
+# appropriate combination of (1) aggregations (possibly temporary) and
+# (2) post-aggregation transformations (mutate)
+# The function returns nothing: it assigns into the `ctx` environment
+summarize_eval <- function(name, quosure, ctx, recurse = FALSE) {
+  expr <- quo_get_expr(quosure)
+  ctx$quo_env <- quo_get_env(quosure)
+
+  funs_in_expr <- all_funs(expr)
+  if (length(funs_in_expr) == 0) {
+    # If it is a scalar or field ref, no special handling required
+    ctx$aggregations[[name]] <- arrow_eval_or_stop(quosure, ctx$mask)
+    return()
+  }
+
+  # Start inspecting the expr to see what aggregations it involves
+  agg_funs <- names(agg_funcs)
+  outer_agg <- funs_in_expr[1] %in% agg_funs
+  inner_agg <- funs_in_expr[-1] %in% agg_funs
+
+  # First, pull out any aggregations wrapped in other function calls
+  if (any(inner_agg)) {
+    expr <- extract_aggregations(expr, ctx)
+  }
+
+  # By this point, there are no more aggregation functions in expr
+  # except for possibly the outer function call:
+  # they've all been pulled out to ctx$aggregations, and in their place in expr
+  # there are variable names, which will correspond to field refs in the
+  # query object after aggregation and collapse().
+  # So if we want to know if there are any aggregations inside expr,
+  # we have to look for them by their new var names
+  inner_agg_exprs <- all_vars(expr) %in% names(ctx$aggregations)
+
+  if (outer_agg) {
+    # This is something like agg(fun(x, y)
+    # It just works by normal arrow_eval, unless there's a mix of aggs and
+    # columns in the original data like agg(fun(x, agg(x)))
+    # (but that will have been caught in extract_aggregations())
+    ctx$aggregations[[name]] <- arrow_eval_or_stop(
+      as_quosure(expr, ctx$quo_env),
+      ctx$mask
+    )
+    return()
+  } else if (all(inner_agg_exprs)) {
+    # fun(agg(x), agg(y))
+    # So based on the aggregations that have been extracted, mutate after
+    mutate_mask <- arrow_mask(
+      list(selected_columns = make_field_refs(names(ctx$aggregations)))
+    )
+    ctx$post_mutate[[name]] <- arrow_eval_or_stop(
+      as_quosure(expr, ctx$quo_env),
+      mutate_mask
+    )
+    return()
+  }
+
+  # Backstop for any other odd cases, like fun(x, y) (i.e. no aggregation),
+  # or aggregation functions that aren't supported in Arrow (not in agg_funcs)
+  stop(
+    handle_arrow_not_supported(
+      quo_get_expr(quosure),
+      as_label(quo_get_expr(quosure))
+    ),
+    call. = FALSE
+  )
+}
+
+# This function recurses through expr, pulls out any aggregation expressions,
+# and inserts a variable name (field ref) in place of the aggregation
+extract_aggregations <- function(expr, ctx) {
+  # Keep the input in case we need to raise an error message with it
+  original_expr <- expr
+  funs <- all_funs(expr)
+  if (length(funs) == 0) {
+    return(expr)
+  } else if (length(funs) > 1) {
+    # Recurse more
+    expr[-1] <- lapply(expr[-1], extract_aggregations, ctx)
+  }
+  if (funs[1] %in% names(agg_funcs)) {
+    inner_agg_exprs <- all_vars(expr) %in% names(ctx$aggregations)
+    if (any(inner_agg_exprs) & !all(inner_agg_exprs)) {
+      # We can't aggregate over a combination of dataset columns and other
+      # aggregations (e.g. sum(x - mean(x)))
+      # TODO: support in ARROW-13926
+      # TODO: Add "because" arg to explain _why_ it's not supported?
+      # TODO: this message could also say "not supported in summarize()"
+      #       since some of these expressions may be legal elsewhere

Review comment:
       I think this last TODO is important: anyone who gets a "not supported"
message will assume that the expression is not supported anywhere in Arrow,
rather than just in `summarize()`.

##########
File path: r/R/dplyr-summarize.R
##########
@@ -42,33 +49,63 @@ summarise.arrow_dplyr_query <- function(.data, ..., .engine 
= c("arrow", "duckdb
 }
 summarise.Dataset <- summarise.ArrowTabular <- summarise.arrow_dplyr_query
 
+# This is the Arrow summarize implementation
 do_arrow_summarize <- function(.data, ..., .groups = NULL) {
   if (!is.null(.groups)) {
     # ARROW-13550
     abort("`summarize()` with `.groups` argument not supported in Arrow")
   }
   exprs <- ensure_named_exprs(quos(...))
 
-  mask <- arrow_mask(.data, aggregation = TRUE)
-
-  results <- empty_named_list()
+  # Create a stateful environment for recording our evaluated expressions
+  # It's more complex than other places because a single summarize() expr
+  # may result in multiple query nodes (Aggregate, Project),
+  # and we have to walk through the expressions to disentangle them.
+  ctx <- env(
+    mask = arrow_mask(.data, aggregation = TRUE),
+    aggregations = empty_named_list(),
+    post_mutate = empty_named_list()
+  )
   for (i in seq_along(exprs)) {
     # Iterate over the indices and not the names because names may be repeated
     # (which overwrites the previous name)
-    new_var <- names(exprs)[i]
-    results[[new_var]] <- arrow_eval(exprs[[i]], mask)
-    if (inherits(results[[new_var]], "try-error")) {
-      msg <- handle_arrow_not_supported(
-        results[[new_var]],
-        as_label(exprs[[i]])
-      )
-      stop(msg, call. = FALSE)
-    }
+    summarize_eval(names(exprs)[i], exprs[[i]], ctx)
   }
 
-  .data$aggregations <- results
-  # TODO: should in-memory query evaluate eagerly?
-  collapse.arrow_dplyr_query(.data)
+  # Apply the results to the .data object.
+  # First, the aggregations
+  .data$aggregations <- ctx$aggregations
+  # Then collapse the query so that the resulting query object can have
+  # additional operations applied to it
+  out <- collapse.arrow_dplyr_query(.data)
+  # The expressions may have been translated into
+  # "first, aggregate, then transform the result further"
+  # For example,
+  #   summarize(mean = sum(x) / n())
+  # is effectively implemented as
+  #   summarize(..temp0 = sum(x), ..temp1 = n()) %>%
+  #   mutate(mean = ..temp0 / ..temp1) %>%
+  #   select(-starts_with("..temp"))
+  # If this is the case, there will be expressions in post_mutate

Review comment:
       ```suggestion
     # If this is the case, there will be expressions in post_mutate
     # nolint end
   ```

##########
File path: r/R/dplyr-summarize.R
##########
@@ -42,33 +49,63 @@ summarise.arrow_dplyr_query <- function(.data, ..., .engine 
= c("arrow", "duckdb
 }
 summarise.Dataset <- summarise.ArrowTabular <- summarise.arrow_dplyr_query
 
+# This is the Arrow summarize implementation
 do_arrow_summarize <- function(.data, ..., .groups = NULL) {
   if (!is.null(.groups)) {
     # ARROW-13550
     abort("`summarize()` with `.groups` argument not supported in Arrow")
   }
   exprs <- ensure_named_exprs(quos(...))
 
-  mask <- arrow_mask(.data, aggregation = TRUE)
-
-  results <- empty_named_list()
+  # Create a stateful environment for recording our evaluated expressions
+  # It's more complex than other places because a single summarize() expr
+  # may result in multiple query nodes (Aggregate, Project),
+  # and we have to walk through the expressions to disentangle them.
+  ctx <- env(
+    mask = arrow_mask(.data, aggregation = TRUE),
+    aggregations = empty_named_list(),
+    post_mutate = empty_named_list()
+  )
   for (i in seq_along(exprs)) {
     # Iterate over the indices and not the names because names may be repeated
     # (which overwrites the previous name)
-    new_var <- names(exprs)[i]
-    results[[new_var]] <- arrow_eval(exprs[[i]], mask)
-    if (inherits(results[[new_var]], "try-error")) {
-      msg <- handle_arrow_not_supported(
-        results[[new_var]],
-        as_label(exprs[[i]])
-      )
-      stop(msg, call. = FALSE)
-    }
+    summarize_eval(names(exprs)[i], exprs[[i]], ctx)
   }
 
-  .data$aggregations <- results
-  # TODO: should in-memory query evaluate eagerly?
-  collapse.arrow_dplyr_query(.data)
+  # Apply the results to the .data object.
+  # First, the aggregations
+  .data$aggregations <- ctx$aggregations
+  # Then collapse the query so that the resulting query object can have
+  # additional operations applied to it
+  out <- collapse.arrow_dplyr_query(.data)
+  # The expressions may have been translated into
+  # "first, aggregate, then transform the result further"
+  # For example,

Review comment:
       ```suggestion
     # nolint start
     # For example,
   ```

##########
File path: r/R/dplyr-summarize.R
##########
@@ -81,3 +118,109 @@ summarize_projection <- function(.data) {
 format_aggregation <- function(x) {
   paste0(x$fun, "(", x$data$ToString(), ")")
 }
+
+# This function handles each summarize expression and turns it into the
+# appropriate combination of (1) aggregations (possibly temporary) and
+# (2) post-aggregation transformations (mutate)
+# The function returns nothing: it assigns into the `ctx` environment
+summarize_eval <- function(name, quosure, ctx, recurse = FALSE) {
+  expr <- quo_get_expr(quosure)
+  ctx$quo_env <- quo_get_env(quosure)
+
+  funs_in_expr <- all_funs(expr)
+  if (length(funs_in_expr) == 0) {
+    # If it is a scalar or field ref, no special handling required
+    ctx$aggregations[[name]] <- arrow_eval_or_stop(quosure, ctx$mask)
+    return()
+  }
+
+  # Start inspecting the expr to see what aggregations it involves
+  agg_funs <- names(agg_funcs)
+  outer_agg <- funs_in_expr[1] %in% agg_funs
+  inner_agg <- funs_in_expr[-1] %in% agg_funs
+
+  # First, pull out any aggregations wrapped in other function calls
+  if (any(inner_agg)) {
+    expr <- extract_aggregations(expr, ctx)
+  }
+
+  # By this point, there are no more aggregation functions in expr
+  # except for possibly the outer function call:
+  # they've all been pulled out to ctx$aggregations, and in their place in expr
+  # there are variable names, which will correspond to field refs in the
+  # query object after aggregation and collapse().
+  # So if we want to know if there are any aggregations inside expr,
+  # we have to look for them by their new var names
+  inner_agg_exprs <- all_vars(expr) %in% names(ctx$aggregations)
+
+  if (outer_agg) {
+    # This is something like agg(fun(x, y)
+    # It just works by normal arrow_eval, unless there's a mix of aggs and
+    # columns in the original data like agg(fun(x, agg(x)))
+    # (but that will have been caught in extract_aggregations())
+    ctx$aggregations[[name]] <- arrow_eval_or_stop(
+      as_quosure(expr, ctx$quo_env),
+      ctx$mask
+    )
+    return()
+  } else if (all(inner_agg_exprs)) {
+    # fun(agg(x), agg(y))

Review comment:
       ```suggestion
       # Something like: fun(agg(x), agg(y))
   ```
   
   We can turn off the commented-code lintr check if this gets too annoying
(though when it catches genuinely commented-out code, it's great).

##########
File path: r/R/dplyr-summarize.R
##########
@@ -81,3 +118,109 @@ summarize_projection <- function(.data) {
 format_aggregation <- function(x) {
   paste0(x$fun, "(", x$data$ToString(), ")")
 }
+
+# This function handles each summarize expression and turns it into the
+# appropriate combination of (1) aggregations (possibly temporary) and
+# (2) post-aggregation transformations (mutate)
+# The function returns nothing: it assigns into the `ctx` environment
+summarize_eval <- function(name, quosure, ctx, recurse = FALSE) {
+  expr <- quo_get_expr(quosure)
+  ctx$quo_env <- quo_get_env(quosure)
+
+  funs_in_expr <- all_funs(expr)
+  if (length(funs_in_expr) == 0) {
+    # If it is a scalar or field ref, no special handling required
+    ctx$aggregations[[name]] <- arrow_eval_or_stop(quosure, ctx$mask)
+    return()
+  }
+
+  # Start inspecting the expr to see what aggregations it involves
+  agg_funs <- names(agg_funcs)
+  outer_agg <- funs_in_expr[1] %in% agg_funs
+  inner_agg <- funs_in_expr[-1] %in% agg_funs
+
+  # First, pull out any aggregations wrapped in other function calls
+  if (any(inner_agg)) {
+    expr <- extract_aggregations(expr, ctx)
+  }
+
+  # By this point, there are no more aggregation functions in expr
+  # except for possibly the outer function call:
+  # they've all been pulled out to ctx$aggregations, and in their place in expr
+  # there are variable names, which will correspond to field refs in the
+  # query object after aggregation and collapse().
+  # So if we want to know if there are any aggregations inside expr,
+  # we have to look for them by their new var names
+  inner_agg_exprs <- all_vars(expr) %in% names(ctx$aggregations)
+
+  if (outer_agg) {
+    # This is something like agg(fun(x, y)
+    # It just works by normal arrow_eval, unless there's a mix of aggs and
+    # columns in the original data like agg(fun(x, agg(x)))
+    # (but that will have been caught in extract_aggregations())
+    ctx$aggregations[[name]] <- arrow_eval_or_stop(
+      as_quosure(expr, ctx$quo_env),
+      ctx$mask
+    )
+    return()
+  } else if (all(inner_agg_exprs)) {
+    # fun(agg(x), agg(y))
+    # So based on the aggregations that have been extracted, mutate after
+    mutate_mask <- arrow_mask(
+      list(selected_columns = make_field_refs(names(ctx$aggregations)))
+    )
+    ctx$post_mutate[[name]] <- arrow_eval_or_stop(
+      as_quosure(expr, ctx$quo_env),
+      mutate_mask
+    )
+    return()
+  }
+
+  # Backstop for any other odd cases, like fun(x, y) (i.e. no aggregation),
+  # or aggregation functions that aren't supported in Arrow (not in agg_funcs)
+  stop(
+    handle_arrow_not_supported(
+      quo_get_expr(quosure),
+      as_label(quo_get_expr(quosure))
+    ),
+    call. = FALSE
+  )
+}
+
+# This function recurses through expr, pulls out any aggregation expressions,
+# and inserts a variable name (field ref) in place of the aggregation
+extract_aggregations <- function(expr, ctx) {
+  # Keep the input in case we need to raise an error message with it
+  original_expr <- expr
+  funs <- all_funs(expr)
+  if (length(funs) == 0) {
+    return(expr)
+  } else if (length(funs) > 1) {
+    # Recurse more
+    expr[-1] <- lapply(expr[-1], extract_aggregations, ctx)
+  }
+  if (funs[1] %in% names(agg_funcs)) {
+    inner_agg_exprs <- all_vars(expr) %in% names(ctx$aggregations)
+    if (any(inner_agg_exprs) & !all(inner_agg_exprs)) {
+      # We can't aggregate over a combination of dataset columns and other
+      # aggregations (e.g. sum(x - mean(x)))
+      # TODO: support in ARROW-13926
+      # TODO: Add "because" arg to explain _why_ it's not supported?
+      # TODO: this message could also say "not supported in summarize()"
+      #       since some of these expressions may be legal elsewhere
+      stop(
+        handle_arrow_not_supported(original_expr, as_label(original_expr)),
+        call. = FALSE
+      )
+    }
+
+    # We have an aggregation expression with no other aggregations inside it,
+    # so arrow_eval the expression on the data and give it a ..temp name 
prefix,
+    # then insert that name (symbol) back into the expression so that we can
+    # mutate() on the result of the aggregation and reference this field.
+    tmpname <- paste0("..temp", length(ctx$aggregations))

Review comment:
       This is vanishingly rare, but could we check here whether an existing
field already has the name in `tmpname`? I can't imagine anyone would have one,
but it would be better to raise an error. Or would this already error later
without a proactive check?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to