This is an automated email from the ASF dual-hosted git repository. npr pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push: new 1d9366f19e GH-18818: [R] Create a field ref to a field in a struct (#19706) 1d9366f19e is described below commit 1d9366f19e4b9846b33cc0c7bd7941cb5f482d74 Author: Neal Richardson <neal.p.richard...@gmail.com> AuthorDate: Wed Jan 18 12:38:06 2023 -0500 GH-18818: [R] Create a field ref to a field in a struct (#19706) This PR implements `$.Expression` and `[[.Expression` methods, such that if the Expression is a FieldRef, it returns a nested FieldRef. This required revising some assumptions in a few places, particularly that if an Expression is a FieldRef, it has a `name`, and that all FieldRefs correspond to a Field in a Schema. In the case where the Expression is not a FieldRef, it will create an Expression call to `struct_field` to extract the field, iff the Expression has a knowable `type`, the [...] Things not done because they weren't needed to get this working: * `Expression$field_ref()` take a vector to construct a nested ref * Method to return vector of nested components of a field ref in R Next steps for future PRs: * Wrap this in [tidyr::unpack()](https://tidyr.tidyverse.org/reference/pack.html) method (but unfortunately, unpack() is not a generic) * https://github.com/apache/arrow/issues/33756 * https://github.com/apache/arrow/issues/33757 * https://github.com/apache/arrow/issues/33760 * Closes: #18818 Authored-by: Neal Richardson <neal.p.richard...@gmail.com> Signed-off-by: Neal Richardson <neal.p.richard...@gmail.com> --- r/NAMESPACE | 3 ++ r/R/arrow-object.R | 2 +- r/R/arrowExports.R | 9 ++++- r/R/expression.R | 55 +++++++++++++++++++++++++++++ r/R/type.R | 3 ++ r/src/arrowExports.cpp | 19 ++++++++++ r/src/compute.cpp | 14 ++++++++ r/src/expression.cpp | 40 +++++++++++++++++++-- r/tests/testthat/test-dplyr-query.R | 70 +++++++++++++++++++++++++++++++++++++ r/tests/testthat/test-expression.R | 26 ++++++++++++++ 10 files changed, 237 insertions(+), 4 deletions(-) diff --git a/r/NAMESPACE b/r/NAMESPACE index 3df107a2d8..3ab828a958 100644 --- a/r/NAMESPACE +++ b/r/NAMESPACE @@ -2,6 +2,7 @@ S3method("!=",ArrowObject) S3method("$",ArrowTabular) +S3method("$",Expression) S3method("$",Schema) S3method("$",StructArray) S3method("$",SubTreeFileSystem) @@ -14,6 +15,7 @@ S3method("[",Dataset) S3method("[",Schema) S3method("[",arrow_dplyr_query) S3method("[[",ArrowTabular) +S3method("[[",Expression) S3method("[[",Schema) S3method("[[",StructArray) S3method("[[<-",ArrowTabular) @@ -137,6 +139,7 @@ S3method(names,Scanner) S3method(names,ScannerBuilder) S3method(names,Schema) S3method(names,StructArray) +S3method(names,StructType) S3method(names,Table) S3method(names,arrow_dplyr_query) S3method(print,"arrow-enum") diff --git a/r/R/arrow-object.R b/r/R/arrow-object.R index 516f407aaf..5c2cf4691f 100644 --- a/r/R/arrow-object.R +++ b/r/R/arrow-object.R @@ -32,7 +32,7 @@ ArrowObject <- R6Class("ArrowObject", assign(".:xp:.", xp, envir = self) }, class_title = function() { - if (!is.null(self$.class_title)) { + if (".class_title" %in% ls(self, all.names = TRUE)) { # Allow subclasses to override just printing the class name first class_title <- self$.class_title() } else { diff --git a/r/R/arrowExports.R b/r/R/arrowExports.R index 38f1ecfb97..2eeca24dbd 100644 --- a/r/R/arrowExports.R +++ b/r/R/arrowExports.R @@ -1084,6 +1084,10 @@ compute___expr__call <- function(func_name, argument_list, options) { .Call(`_arrow_compute___expr__call`, func_name, argument_list, options) } +compute___expr__is_field_ref <- function(x) { + .Call(`_arrow_compute___expr__is_field_ref`, x) +} + field_names_in_expression <- function(x) { .Call(`_arrow_field_names_in_expression`, x) } @@ -1096,6 +1100,10 @@ compute___expr__field_ref <- function(name) { .Call(`_arrow_compute___expr__field_ref`, name) } +compute___expr__nested_field_ref <- function(x, name) { + .Call(`_arrow_compute___expr__nested_field_ref`, x, name) +} + compute___expr__scalar <- function(x) { .Call(`_arrow_compute___expr__scalar`, x) } @@ -2087,4 +2095,3 @@ SetIOThreadPoolCapacity <- function(threads) { Array__infer_type <- function(x) { .Call(`_arrow_Array__infer_type`, x) } - diff --git a/r/R/expression.R b/r/R/expression.R index a1163c12a8..8f84b4b31e 100644 --- a/r/R/expression.R +++ b/r/R/expression.R @@ -57,6 +57,9 @@ Expression <- R6Class("Expression", assert_that(!is.null(schema)) compute___expr__type_id(self, schema) }, + is_field_ref = function() { + compute___expr__is_field_ref(self) + }, cast = function(to_type, safe = TRUE, ...) { opts <- cast_options(safe, ...) opts$to_type <- as_type(to_type) @@ -89,7 +92,59 @@ Expression$create <- function(function_name, expr } + +#' @export +`[[.Expression` <- function(x, i, ...) get_nested_field(x, i) + +#' @export +`$.Expression` <- function(x, name, ...) { + assert_that(is.string(name)) + if (name %in% ls(x)) { + get(name, x) + } else { + get_nested_field(x, name) + } +} + +get_nested_field <- function(expr, name) { + if (expr$is_field_ref()) { + # Make a nested field ref + # TODO(#33756): integer (positional) field refs are supported in C++ + assert_that(is.string(name)) + out <- compute___expr__nested_field_ref(expr, name) + } else { + # Use the struct_field kernel if expr is a struct: + expr_type <- expr$type() # errors if no schema set + if (inherits(expr_type, "StructType")) { + # Because we have the type, we can validate that the field exists + if (!(name %in% names(expr_type))) { + stop( + "field '", name, "' not found in ", + expr_type$ToString(), + call. = FALSE + ) + } + out <- Expression$create( + "struct_field", + expr, + options = list(field_ref = Expression$field_ref(name)) + ) + } else { + # TODO(#33757): if expr is list type and name is integer or Expression, + # call list_element + stop( + "Cannot extract a field from an Expression of type ", expr_type$ToString(), + call. = FALSE + ) + } + } + # Schema bookkeeping + out$schema <- expr$schema + out +} + Expression$field_ref <- function(name) { + # TODO(#33756): allow construction of field ref from integer assert_that(is.string(name)) compute___expr__field_ref(name) } diff --git a/r/R/type.R b/r/R/type.R index d1578dd822..bd69311b25 100644 --- a/r/R/type.R +++ b/r/R/type.R @@ -641,6 +641,9 @@ StructType$create <- function(...) struct__(.fields(list(...))) #' @export struct <- StructType$create +#' @export +names.StructType <- function(x) StructType__field_names(x) + ListType <- R6Class("ListType", inherit = NestedType, public = list( diff --git a/r/src/arrowExports.cpp b/r/src/arrowExports.cpp index b7bda1870f..e918390e26 100644 --- a/r/src/arrowExports.cpp +++ b/r/src/arrowExports.cpp @@ -2732,6 +2732,14 @@ BEGIN_CPP11 END_CPP11 } // expression.cpp +bool compute___expr__is_field_ref(const std::shared_ptr<compute::Expression>& x); +extern "C" SEXP _arrow_compute___expr__is_field_ref(SEXP x_sexp){ +BEGIN_CPP11 + arrow::r::Input<const std::shared_ptr<compute::Expression>&>::type x(x_sexp); + return cpp11::as_sexp(compute___expr__is_field_ref(x)); +END_CPP11 +} +// expression.cpp std::vector<std::string> field_names_in_expression(const std::shared_ptr<compute::Expression>& x); extern "C" SEXP _arrow_field_names_in_expression(SEXP x_sexp){ BEGIN_CPP11 @@ -2756,6 +2764,15 @@ BEGIN_CPP11 END_CPP11 } // expression.cpp +std::shared_ptr<compute::Expression> compute___expr__nested_field_ref(const std::shared_ptr<compute::Expression>& x, std::string name); +extern "C" SEXP _arrow_compute___expr__nested_field_ref(SEXP x_sexp, SEXP name_sexp){ +BEGIN_CPP11 + arrow::r::Input<const std::shared_ptr<compute::Expression>&>::type x(x_sexp); + arrow::r::Input<std::string>::type name(name_sexp); + return cpp11::as_sexp(compute___expr__nested_field_ref(x, name)); +END_CPP11 +} +// expression.cpp std::shared_ptr<compute::Expression> compute___expr__scalar(const std::shared_ptr<arrow::Scalar>& x); extern "C" SEXP _arrow_compute___expr__scalar(SEXP x_sexp){ BEGIN_CPP11 @@ -5569,9 +5586,11 @@ static const R_CallMethodDef CallEntries[] = { { "_arrow_MapType__keys_sorted", (DL_FUNC) &_arrow_MapType__keys_sorted, 1}, { "_arrow_compute___expr__equals", (DL_FUNC) &_arrow_compute___expr__equals, 2}, { "_arrow_compute___expr__call", (DL_FUNC) &_arrow_compute___expr__call, 3}, + { "_arrow_compute___expr__is_field_ref", (DL_FUNC) &_arrow_compute___expr__is_field_ref, 1}, { "_arrow_field_names_in_expression", (DL_FUNC) &_arrow_field_names_in_expression, 1}, { "_arrow_compute___expr__get_field_ref_name", (DL_FUNC) &_arrow_compute___expr__get_field_ref_name, 1}, { "_arrow_compute___expr__field_ref", (DL_FUNC) &_arrow_compute___expr__field_ref, 1}, + { "_arrow_compute___expr__nested_field_ref", (DL_FUNC) &_arrow_compute___expr__nested_field_ref, 2}, { "_arrow_compute___expr__scalar", (DL_FUNC) &_arrow_compute___expr__scalar, 1}, { "_arrow_compute___expr__ToString", (DL_FUNC) &_arrow_compute___expr__ToString, 1}, { "_arrow_compute___expr__type", (DL_FUNC) &_arrow_compute___expr__type, 2}, diff --git a/r/src/compute.cpp b/r/src/compute.cpp index b4b4c5fdc8..578ce74d05 100644 --- a/r/src/compute.cpp +++ b/r/src/compute.cpp @@ -564,6 +564,20 @@ std::shared_ptr<arrow::compute::FunctionOptions> make_compute_options( return out; } + if (func_name == "struct_field") { + using Options = arrow::compute::StructFieldOptions; + if (!Rf_isNull(options["indices"])) { + return std::make_shared<Options>( + cpp11::as_cpp<std::vector<int>>(options["indices"])); + } else { + // field_ref + return std::make_shared<Options>( + *cpp11::as_cpp<std::shared_ptr<arrow::compute::Expression>>( + options["field_ref"]) + ->field_ref()); + } + } + return nullptr; } diff --git a/r/src/expression.cpp b/r/src/expression.cpp index a845137e09..d7a511e760 100644 --- a/r/src/expression.cpp +++ b/r/src/expression.cpp @@ -46,13 +46,26 @@ std::shared_ptr<compute::Expression> compute___expr__call(std::string func_name, compute::call(std::move(func_name), std::move(arguments), std::move(options_ptr))); } +// [[arrow::export]] +bool compute___expr__is_field_ref(const std::shared_ptr<compute::Expression>& x) { + return x->field_ref() != nullptr; +} + // [[arrow::export]] std::vector<std::string> field_names_in_expression( const std::shared_ptr<compute::Expression>& x) { std::vector<std::string> out; + std::vector<arrow::FieldRef> nested; + auto field_refs = FieldsInExpression(*x); for (auto f : field_refs) { - out.push_back(*f.name()); + if (f.IsNested()) { + // We keep the top-level field name. + nested = *f.nested_refs(); + out.push_back(*nested[0].name()); + } else { + out.push_back(*f.name()); + } } return out; } @@ -61,7 +74,11 @@ std::vector<std::string> field_names_in_expression( std::string compute___expr__get_field_ref_name( const std::shared_ptr<compute::Expression>& x) { if (auto field_ref = x->field_ref()) { - return *field_ref->name(); + // Exclude nested field refs because we only use this to determine if we have simple + // field refs + if (!field_ref->IsNested()) { + return *field_ref->name(); + } } return ""; } @@ -71,6 +88,25 @@ std::shared_ptr<compute::Expression> compute___expr__field_ref(std::string name) return std::make_shared<compute::Expression>(compute::field_ref(std::move(name))); } +// [[arrow::export]] +std::shared_ptr<compute::Expression> compute___expr__nested_field_ref( + const std::shared_ptr<compute::Expression>& x, std::string name) { + if (auto field_ref = x->field_ref()) { + std::vector<arrow::FieldRef> ref_vec; + if (field_ref->IsNested()) { + ref_vec = *field_ref->nested_refs(); + } else { + // There's just one + ref_vec.push_back(*field_ref); + } + // Add the new ref + ref_vec.push_back(arrow::FieldRef(std::move(name))); + return std::make_shared<compute::Expression>(compute::field_ref(std::move(ref_vec))); + } else { + cpp11::stop("'x' must be a FieldRef Expression"); + } +} + // [[arrow::export]] std::shared_ptr<compute::Expression> compute___expr__scalar( const std::shared_ptr<arrow::Scalar>& x) { diff --git a/r/tests/testthat/test-dplyr-query.R b/r/tests/testthat/test-dplyr-query.R index ee11cd6678..a91c0b6ccb 100644 --- a/r/tests/testthat/test-dplyr-query.R +++ b/r/tests/testthat/test-dplyr-query.R @@ -714,3 +714,73 @@ test_that("Scalars in expressions match the type of the field, if possible", { collect() expect_equal(result$tpc_h_1, result$as_dbl) }) + +test_that("Can use nested field refs", { + nested_data <- tibble(int = 1:5, df_col = tibble(a = 6:10, b = 11:15)) + + compare_dplyr_binding( + .input %>% + mutate( + nested = df_col$a, + times2 = df_col$a * 2 + ) %>% + filter(nested > 7) %>% + collect(), + nested_data + ) + + compare_dplyr_binding( + .input %>% + mutate( + nested = df_col$a, + times2 = df_col$a * 2 + ) %>% + filter(nested > 7) %>% + summarize(sum(times2)) %>% + collect(), + nested_data + ) + + # Now with Dataset: make sure column pushdown in ScanNode works + expect_equal( + nested_data %>% + InMemoryDataset$create() %>% + mutate( + nested = df_col$a, + times2 = df_col$a * 2 + ) %>% + filter(nested > 7) %>% + collect(), + nested_data %>% + mutate( + nested = df_col$a, + times2 = df_col$a * 2 + ) %>% + filter(nested > 7) + ) +}) + +test_that("Use struct_field for $ on non-field-ref", { + compare_dplyr_binding( + .input %>% + mutate( + df_col = tibble(i = int, d = dbl) + ) %>% + transmute( + int2 = df_col$i, + dbl2 = df_col$d + ) %>% + collect(), + example_data + ) +}) + +test_that("nested field ref error handling", { + expect_error( + example_data %>% + arrow_table() %>% + mutate(x = int$nested) %>% + compute(), + "No match" + ) +}) diff --git a/r/tests/testthat/test-expression.R b/r/tests/testthat/test-expression.R index 2b6039b04c..ccb09b9eb0 100644 --- a/r/tests/testthat/test-expression.R +++ b/r/tests/testthat/test-expression.R @@ -76,6 +76,15 @@ test_that("Field reference expression schemas and types", { expect_equal(x$type(), int32()) }) +test_that("Nested field refs", { + x <- Expression$field_ref("x") + nested <- x$y + expect_r6_class(nested, "Expression") + expect_r6_class(x[["y"]], "Expression") + expect_r6_class(nested$z, "Expression") + expect_error(Expression$scalar(42L)$y, "Cannot extract a field from an Expression of type int32") +}) + test_that("Scalar expression schemas and types", { # type() works on scalars without setting the schema expect_equal( @@ -127,3 +136,20 @@ test_that("Expression schemas and types", { int32() ) }) + +test_that("Nested field ref types", { + nested <- Expression$field_ref("x")$y + schm <- schema(x = struct(y = int32(), z = double())) + expect_equal(nested$type(schm), int32()) + # implicit casting and schema propagation + x <- Expression$field_ref("x") + x$schema <- schm + expect_equal((x$y * 2)$type(), int32()) +}) + +test_that("Nested field from a non-field-ref (struct_field kernel)", { + x <- Expression$scalar(data.frame(a = 1, b = "two")) + expect_true(inherits(x$a, "Expression")) + expect_equal(x$a$type(), float64()) + expect_error(x$c, "field 'c' not found in struct<a: double, b: string>") +})