jonkeane commented on code in PR #41223:
URL: https://github.com/apache/arrow/pull/41223#discussion_r1566460890
##########
r/R/dplyr-eval.R:
##########
@@ -20,6 +20,8 @@ arrow_eval <- function(expr, mask) {
# with references to Arrays (if .data is Table/RecordBatch) or Fields (if
# .data is a Dataset).
+ add_user_functions_to_mask(expr, mask)
Review Comment:
It might be nice to have a comment here explaining what this is — it's right
below on line 54, so maybe that's enough, but something quick and descriptive
would help (mostly that this is _not_ about UDFs and is instead about R funcs
in the parent/global environment).
##########
r/R/dplyr-summarize.R:
##########
@@ -221,25 +257,27 @@ do_arrow_summarize <- function(.data, ..., .groups =
NULL) {
# It's more complex than other places because a single summarize() expr
# may result in multiple query nodes (Aggregate, Project),
# and we have to walk through the expressions to disentangle them.
- ctx <- env(
- mask = arrow_mask(.data, aggregation = TRUE),
- aggregations = empty_named_list(),
- post_mutate = empty_named_list()
- )
+
+ # Agg functions pull out the aggregation info and append it here
+ ..aggregations <- empty_named_list()
+ # And if there are any transformations after the aggregation, they go here
+ ..post_mutate <- empty_named_list()
+ mask <- arrow_mask(.data, aggregation = TRUE)
+
for (i in seq_along(exprs)) {
# Iterate over the indices and not the names because names may be repeated
# (which overwrites the previous name)
summarize_eval(
names(exprs)[i],
exprs[[i]],
- ctx,
+ mask,
length(.data$group_by_vars) > 0
)
}
# Apply the results to the .data object.
# First, the aggregations
- .data$aggregations <- ctx$aggregations
+ .data$aggregations <- ..aggregations
Review Comment:
I'm curious to know more about going from the `ctx` object to storing these
as `..aggregations`. I'm not at all opposed, and think this looks more natural
given some of our other machinery — but I can't tell directly from here
if/why that's necessary.
##########
r/tests/testthat/test-dplyr-summarize.R:
##########
@@ -1083,11 +1121,6 @@ test_that("summarise() can handle scalars and literal
values", {
tibble(y = 1L)
)
- expect_identical(
- record_batch(tbl) %>% summarise(y = Expression$scalar(1L)) %>% collect(),
- tibble(y = 1L)
- )
Review Comment:
This is the same as the one above it, yeah?
##########
r/R/dplyr-eval.R:
##########
@@ -48,6 +50,43 @@ arrow_eval <- function(expr, mask) {
})
}
+add_user_functions_to_mask <- function(expr, mask) {
+ # Look for user-defined R functions that are not in the mask,
+ # see if we can add them to the mask and set their parent env to the mask
+ # so that they can reference other functions in the mask
+ if (is_quosure(expr)) {
+ # case_when evaluates regular formulas not quosures, which don't have
+ # their own environment, so let's just skip them for now
+ function_env <- parent.env(parent.env(mask))
+ quo_expr <- quo_get_expr(expr)
+ funs_in_expr <- all_funs(quo_expr)
+ quo_env <- quo_get_env(expr)
+ # Enumerate the things we have bindings for, and add anything else that we
+ # explicitly want to block from trying to add to the function environment
+ known_funcs <- c(ls(function_env, all.names = TRUE), "~", "[", ":")
+ unknown <- setdiff(funs_in_expr, known_funcs)
+ for (i in unknown) {
+ if (exists(i, quo_env)) {
+ user_fun <- get(i, quo_env)
+ if (!is.null(environment(user_fun)) &&
!rlang::is_namespace(environment(user_fun))) {
+ # Primitives don't have an environment
+ if (getOption("arrow.debug", FALSE)) {
+ print(paste("Adding", i, "to the function environment"))
+ }
+ function_env[[i]] <- user_fun
+ # Also set the enclosing environment to be the function environment.
+ # This allows the function to reference other functions in the env.
+ # This may have other undesired side effects?
+ environment(function_env[[i]]) <- function_env
+ }
+ }
+ }
+ }
Review Comment:
`i` here is the func name as a string, yeah? Maybe it would be clearer to
call it `func_name` or simply `func`?
##########
r/R/dplyr-eval.R:
##########
@@ -88,10 +127,14 @@ arrow_mask <- function(.data, aggregation = FALSE) {
}
if (aggregation) {
+ pf <- parent.frame()
# This should probably be done with an environment inside an environment
# but a first attempt at that had scoping problems (ARROW-13499)
for (f in names(agg_funcs)) {
f_env[[f]] <- agg_funcs[[f]]
+ # Make sure that ..aggregations and ..post_mutate are in the search path
+ # This assumes being called from summarize
Review Comment:
Mostly for my own education: "This assumes being called from summarize" is
because we're in `if (aggregation) {...}`, yeah?
##########
r/R/dplyr-summarize.R:
##########
@@ -129,35 +129,71 @@ register_bindings_aggregate <- function() {
notes = "approximate median (t-digest) is computed"
)
register_binding_agg("dplyr::n_distinct", function(..., na.rm = FALSE) {
- list(
+ set_agg(
fun = "count_distinct",
data = ensure_one_arg(list2(...), "n_distinct"),
options = list(na.rm = na.rm)
)
})
register_binding_agg("dplyr::n", function() {
- list(
+ set_agg(
fun = "count_all",
data = list(),
options = list()
)
})
register_binding_agg("base::min", function(..., na.rm = FALSE) {
- list(
+ set_agg(
fun = "min",
data = ensure_one_arg(list2(...), "min"),
options = list(skip_nulls = na.rm, min_count = 0L)
)
})
register_binding_agg("base::max", function(..., na.rm = FALSE) {
- list(
+ set_agg(
fun = "max",
data = ensure_one_arg(list2(...), "max"),
options = list(skip_nulls = na.rm, min_count = 0L)
)
})
}
+set_agg <- function(...) {
+ agg_data <- list2(...)
+ # Find the environment where ..aggregations is stored
+ target <- find_aggregations_env()
+ aggs <- get("..aggregations", target)
+ lapply(agg_data[["data"]], function(expr) {
+ # If any of the fields referenced in the expression are in ..aggregations,
+ # then we can't aggregate over them.
+ # This is mainly for combinations of dataset columns and aggregations,
+ # like sum(x - mean(x)), i.e. window functions.
+ # This will reject (sum(sum(x)) as well, but that's not a useful operation.
+ if (any(expr$field_names_in_expression() %in% names(aggs))) {
+ # TODO: support in ARROW-13926
+ arrow_not_supported("aggregate within aggregate expression")
+ }
+ })
+
+ # Record the (fun, data, options) in ..aggregations
+ # and return a FieldRef pointing to it
+ tmpname <- paste0("..temp", length(aggs))
+ aggs[[tmpname]] <- agg_data
+ assign("..aggregations", aggs, envir = target)
+ Expression$field_ref(tmpname)
+}
+
+find_aggregations_env <- function() {
+ # Find the environment where ..aggregations is stored,
+ # it's in parent.env of something in the call stack
+ for (f in sys.frames()) {
+ if (exists("..aggregations", envir = f)) {
+ return(f)
+ }
+ }
+ stop("Could not find ..aggregations")
Review Comment:
This should never happen in the normal course of execution, yeah?
##########
r/R/dplyr-summarize.R:
##########
@@ -254,12 +292,35 @@ do_arrow_summarize <- function(.data, ..., .groups =
NULL) {
# select(-starts_with("..temp"))
# If this is the case, there will be expressions in post_mutate
# nolint end
- if (length(ctx$post_mutate)) {
+ if (length(..post_mutate)) {
+ # One last check: it's possible that an expression like y - mean(y) would
+ # successfully evaluate, but it's not supported. It gets transformed to:
+ # nolint start
+ # summarize(..temp0 = mean(y)) %>%
+ # mutate(y - ..temp0)
+ # nolint end
+ # but y is not in the schema of the data after summarize(). To catch this
+ # in the expression evaluation step, we'd have to remove all data variables
+ # from the mask, which would be a bit tortured (even for me).
+ # So we'll check here.
+ for (post in names(..post_mutate)) {
+ # We can tell the expression is invalid if it references fields not in
+ # the schema of the data after summarize(). Evaulating its type will
+ # throw an error if it's invalid.
+ tryCatch(..post_mutate[[post]]$type(out$.data$schema), error =
function(e) {
+ msg <- paste(
+ "Expression", as_label(exprs[[post]]),
+ "is not a valid aggregation expression or is"
+ )
+ arrow_not_supported(msg)
+ })
Review Comment:
Sneaky!
##########
r/R/dplyr-eval.R:
##########
@@ -20,6 +20,8 @@ arrow_eval <- function(expr, mask) {
# with references to Arrays (if .data is Table/RecordBatch) or Fields (if
# .data is a Dataset).
+ add_user_functions_to_mask(expr, mask)
Review Comment:
And here it might be nice to describe why: we do this because some functions
can be reduced to functions that do have arrow bindings without needing to
register one yourself or the like
##########
r/tests/testthat/test-dplyr-across.R:
##########
@@ -279,7 +277,17 @@ test_that("purrr-style lambda functions are supported", {
)
})
-test_that("ARROW-14071 - function(x)-style lambda functions are not
supported", {
+test_that("ARROW-14071 - user-defined R functions", {
Review Comment:
```suggestion
test_that("ARROW-14071 - R functions from a user's environment", {
```
Just to be super clear this _isn't_ about UDFs
##########
r/tests/testthat/test-dplyr-summarize.R:
##########
@@ -836,28 +835,68 @@ test_that("Expressions on aggregations", {
expect_warning(
record_batch(tbl) %>% summarise(any(any(lgl))),
paste(
- "Aggregate within aggregate expression",
- "any\\(any\\(lgl\\)\\) not supported in Arrow"
+ "In any\\(any\\(lgl\\)\\), aggregate within aggregate expression",
+ "not supported in Arrow"
)
)
# Check aggregates on aggregates with more complex calls
expect_warning(
record_batch(tbl) %>% summarise(any(any(!lgl))),
paste(
- "Aggregate within aggregate expression",
- "any\\(any\\(!lgl\\)\\) not supported in Arrow"
+ "In any\\(any\\(!lgl\\)\\), aggregate within aggregate expression",
+ "not supported in Arrow"
)
)
expect_warning(
record_batch(tbl) %>% summarise(!any(any(lgl))),
paste(
- "Aggregate within aggregate expression",
- "any\\(any\\(lgl\\)\\) not supported in Arrow"
+ "In \\!any\\(any\\(lgl\\)\\), aggregate within aggregate expression",
+ "not supported in Arrow"
)
)
})
+test_that("Weighted mean", {
+ compare_dplyr_binding(
+ .input %>%
+ group_by(some_grouping) %>%
+ summarize(
+ weighted_mean = sum(int * dbl) / sum(dbl)
+ ) %>%
+ collect(),
+ tbl
+ )
+
+ division <- function(x, y) x / y
+ compare_dplyr_binding(
+ .input %>%
+ group_by(some_grouping) %>%
+ summarize(
+ weighted_mean = division(sum(int * dbl), sum(dbl))
+ ) %>%
+ collect(),
+ tbl
+ )
+
+ # We can also define functions that call supported aggregation functions
+ # and it just works
+ wtd_mean <- function(x, w) sum(x * w) / sum(w)
+ withr::local_options(list(arrow.debug = TRUE))
Review Comment:
Nice, this is a helpful catch / test that honestly I could see being useful
when debugging this too — and it makes it super clear that it's hitting this
code and not some other path.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]