This is an automated email from the ASF dual-hosted git repository.
npr pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new 838687178f ARROW-15260: [R] open_dataset - add file_name as column (#12826)
838687178f is described below
commit 838687178fda7f82e31668f502e2f94071ce8077
Author: Nic Crane <[email protected]>
AuthorDate: Wed Aug 10 01:19:40 2022 +0100
ARROW-15260: [R] open_dataset - add file_name as column (#12826)
Authored-by: Nic Crane <[email protected]>
Signed-off-by: Neal Richardson <[email protected]>
---
r/DESCRIPTION | 1 +
r/R/dataset.R | 1 +
r/R/dplyr-collect.R | 11 +++++
r/R/dplyr-funcs-augmented.R | 22 ++++++++++
r/R/dplyr-funcs.R | 1 +
r/R/dplyr.R | 3 ++
r/R/util.R | 31 +++++++++++++-
r/src/compute-exec.cpp | 8 ++--
r/tests/testthat/test-dataset.R | 94 ++++++++++++++++++++++++++++++++++++++++-
9 files changed, 164 insertions(+), 8 deletions(-)
diff --git a/r/DESCRIPTION b/r/DESCRIPTION
index 308a7ec3fa..95c1405869 100644
--- a/r/DESCRIPTION
+++ b/r/DESCRIPTION
@@ -98,6 +98,7 @@ Collate:
'dplyr-distinct.R'
'dplyr-eval.R'
'dplyr-filter.R'
+ 'dplyr-funcs-augmented.R'
'dplyr-funcs-conditional.R'
'dplyr-funcs-datetime.R'
'dplyr-funcs-math.R'
diff --git a/r/R/dataset.R b/r/R/dataset.R
index 12765fbfc0..d86962cc1d 100644
--- a/r/R/dataset.R
+++ b/r/R/dataset.R
@@ -224,6 +224,7 @@ open_dataset <- function(sources,
# and not handle_parquet_io_error()
error = function(e, call = caller_env(n = 4)) {
handle_parquet_io_error(e, format, call)
+ abort(conditionMessage(e), call = call)
}
)
}
diff --git a/r/R/dplyr-collect.R b/r/R/dplyr-collect.R
index 3e83475a8c..8049e46eb5 100644
--- a/r/R/dplyr-collect.R
+++ b/r/R/dplyr-collect.R
@@ -25,6 +25,8 @@ collect.arrow_dplyr_query <- function(x, as_data_frame = TRUE, ...) {
# and not handle_csv_read_error()
error = function(e, call = caller_env(n = 4)) {
handle_csv_read_error(e, x$.data$schema, call)
+ handle_augmented_field_misuse(e, call)
+ abort(conditionMessage(e), call = call)
}
)
@@ -104,10 +106,18 @@ add_suffix <- function(fields, common_cols, suffix) {
}
implicit_schema <- function(.data) {
+ # Get the source data schema so that we can evaluate expressions to determine
+ # the output schema. Note that we don't use source_data() because we only
+ # want to go one level up (where we may have called implicit_schema() before)
.data <- ensure_group_vars(.data)
old_schm <- .data$.data$schema
+ # Add in any augmented fields that may exist in the query but not in the
+ # real data, in case we have FieldRefs to them
+ old_schm[["__filename"]] <- string()
if (is.null(.data$aggregations)) {
+ # .data$selected_columns is a named list of Expressions (FieldRefs or
+ # something more complex). Bind them in order to determine their output type
new_fields <- map(.data$selected_columns, ~ .$type(old_schm))
if (!is.null(.data$join) && !(.data$join$type %in% JoinType[1:4])) {
# Add cols from right side, except for semi/anti joins
@@ -128,6 +138,7 @@ implicit_schema <- function(.data) {
new_fields <- c(left_fields, right_fields)
}
} else {
+ # The output schema is based on the aggregations and any group_by vars
new_fields <- map(summarize_projection(.data), ~ .$type(old_schm))
# * Put group_by_vars first (this can't be done by summarize,
# they have to be last per the aggregate node signature,
diff --git a/r/R/dplyr-funcs-augmented.R b/r/R/dplyr-funcs-augmented.R
new file mode 100644
index 0000000000..6e751d49f6
--- /dev/null
+++ b/r/R/dplyr-funcs-augmented.R
@@ -0,0 +1,22 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+register_bindings_augmented <- function() {
+ register_binding("add_filename", function() {
+ Expression$field_ref("__filename")
+ })
+}
diff --git a/r/R/dplyr-funcs.R b/r/R/dplyr-funcs.R
index c1dcdd1774..4dadff54b4 100644
--- a/r/R/dplyr-funcs.R
+++ b/r/R/dplyr-funcs.R
@@ -151,6 +151,7 @@ create_binding_cache <- function() {
register_bindings_math()
register_bindings_string()
register_bindings_type()
+ register_bindings_augmented()
# We only create the cache for nse_funcs and not agg_funcs
.cache$functions <- c(as.list(nse_funcs), arrow_funcs)
diff --git a/r/R/dplyr.R b/r/R/dplyr.R
index dd6340c4f5..dffe269199 100644
--- a/r/R/dplyr.R
+++ b/r/R/dplyr.R
@@ -110,6 +110,9 @@ make_field_refs <- function(field_names) {
#' @export
print.arrow_dplyr_query <- function(x, ...) {
schm <- x$.data$schema
+ # If we are using this augmented field, it won't be in the schema
+ schm[["__filename"]] <- string()
+
types <- map_chr(x$selected_columns, function(expr) {
name <- expr$field_name
if (nzchar(name)) {
diff --git a/r/R/util.R b/r/R/util.R
index 55ff29db73..eef69d0244 100644
--- a/r/R/util.R
+++ b/r/R/util.R
@@ -134,6 +134,10 @@ read_compressed_error <- function(e) {
stop(e)
}
+# This function was refactored in ARROW-15260 to only raise an error if
+# the appropriate string was found and so errors must be raised manually after
+# calling this if matching error not found
+# TODO: Refactor as part of ARROW-17355 to prevent potential missed errors
handle_parquet_io_error <- function(e, format, call) {
msg <- conditionMessage(e)
if (grepl("Parquet magic bytes not found in footer", msg) && length(format) > 1 && is_character(format)) {
@@ -143,8 +147,8 @@ handle_parquet_io_error <- function(e, format, call) {
msg,
i = "Did you mean to specify a 'format' other than the default (parquet)?"
)
+ abort(msg, call = call)
}
- abort(msg, call = call)
}
as_writable_table <- function(x) {
@@ -205,6 +209,10 @@ repeat_value_as_array <- function(object, n) {
return(Scalar$create(object)$as_array(n))
}
+# This function was refactored in ARROW-15260 to only raise an error if
+# the appropriate string was found and so errors must be raised manually after
+# calling this if matching error not found
+# TODO: Refactor as part of ARROW-17355 to prevent potential missed errors
handle_csv_read_error <- function(e, schema, call) {
msg <- conditionMessage(e)
@@ -217,8 +225,27 @@ handle_csv_read_error <- function(e, schema, call) {
"header being read in as data."
)
)
+ abort(msg, call = call)
+ }
+}
+
+# This function only raises an error if
+# the appropriate string was found and so errors must be raised manually after
+# calling this if matching error not found
+# TODO: Refactor as part of ARROW-17355 to prevent potential missed errors
+handle_augmented_field_misuse <- function(e, call) {
+ msg <- conditionMessage(e)
+ if (grepl("No match for FieldRef.Name(__filename)", msg, fixed = TRUE)) {
+ msg <- c(
+ msg,
+ i = paste(
+ "`add_filename()` or use of the `__filename` augmented field can only",
+ "be used with with Dataset objects, and can only be added before doing",
+ "an aggregation or a join."
+ )
+ )
+ abort(msg, call = call)
}
- abort(msg, call = call)
}
is_compressed <- function(compression) {
diff --git a/r/src/compute-exec.cpp b/r/src/compute-exec.cpp
index 91d646f0a3..f9183a3a10 100644
--- a/r/src/compute-exec.cpp
+++ b/r/src/compute-exec.cpp
@@ -222,8 +222,7 @@ std::shared_ptr<compute::ExecNode> ExecNode_Scan(
options->dataset_schema = dataset->schema();
- // ScanNode needs the filter to do predicate pushdown and skip partitions
- options->filter = ValueOrStop(filter->Bind(*dataset->schema()));
+ options->filter = *filter;
// ScanNode needs to know which fields to materialize (and which are unnecessary)
std::vector<compute::Expression> exprs;
@@ -232,9 +231,8 @@ std::shared_ptr<compute::ExecNode> ExecNode_Scan(
}
options->projection =
- ValueOrStop(call("make_struct", std::move(exprs),
- compute::MakeStructOptions{std::move(materialized_field_names)})
- .Bind(*dataset->schema()));
+ call("make_struct", std::move(exprs),
+ compute::MakeStructOptions{std::move(materialized_field_names)});
return MakeExecNodeOrStop("scan", plan.get(), {},
ds::ScanNodeOptions{dataset, options});
diff --git a/r/tests/testthat/test-dataset.R b/r/tests/testthat/test-dataset.R
index d43bb492d0..d9512ef94f 100644
--- a/r/tests/testthat/test-dataset.R
+++ b/r/tests/testthat/test-dataset.R
@@ -1131,7 +1131,6 @@ test_that("dataset to C-interface to arrow_dplyr_query with proj/filter", {
delete_arrow_array_stream(stream_ptr)
})
-
test_that("Filter parquet dataset with is.na ARROW-15312", {
ds_path <- make_temp_dir()
@@ -1349,3 +1348,96 @@ test_that("FileSystemFactoryOptions input validation", {
fixed = TRUE
)
})
+
+test_that("can add in augmented fields", {
+ ds <- open_dataset(hive_dir)
+
+ observed <- ds %>%
+ mutate(file_name = add_filename()) %>%
+ collect()
+
+ expect_named(
+ observed,
+ c("int", "dbl", "lgl", "chr", "fct", "ts", "group", "other", "file_name")
+ )
+
+ expect_equal(
+ sort(unique(observed$file_name)),
+ list.files(hive_dir, full.names = TRUE, recursive = TRUE)
+ )
+
+ error_regex <- paste(
+ "`add_filename()` or use of the `__filename` augmented field can only",
+ "be used with with Dataset objects, and can only be added before doing",
+ "an aggregation or a join."
+ )
+
+ # errors appropriately with ArrowTabular objects
+ expect_error(
+ arrow_table(mtcars) %>%
+ mutate(file = add_filename()) %>%
+ collect(),
+ regexp = error_regex,
+ fixed = TRUE
+ )
+
+ # errors appropriately with aggregation
+ expect_error(
+ ds %>%
+ summarise(max_int = max(int)) %>%
+ mutate(file_name = add_filename()) %>%
+ collect(),
+ regexp = error_regex,
+ fixed = TRUE
+ )
+
+ # joins to tables
+ another_table <- select(example_data, int, dbl2)
+ expect_error(
+ ds %>%
+ left_join(another_table, by = "int") %>%
+ mutate(file = add_filename()) %>%
+ collect(),
+ regexp = error_regex,
+ fixed = TRUE
+ )
+
+ # and on joins to datasets
+ another_dataset <- write_dataset(another_table, "another_dataset")
+ expect_error(
+ ds %>%
+ left_join(open_dataset("another_dataset"), by = "int") %>%
+ mutate(file = add_filename()) %>%
+ collect(),
+ regexp = error_regex,
+ fixed = TRUE
+ )
+
+ # this hits the implicit_schema path by joining afterwards
+ join_after <- ds %>%
+ mutate(file = add_filename()) %>%
+ left_join(open_dataset("another_dataset"), by = "int") %>%
+ collect()
+
+ expect_named(
+ join_after,
+ c("int", "dbl", "lgl", "chr", "fct", "ts", "group", "other", "file", "dbl2")
+ )
+
+ expect_equal(
+ sort(unique(join_after$file)),
+ list.files(hive_dir, full.names = TRUE, recursive = TRUE)
+ )
+
+ # another test on the explicit_schema path
+ summarise_after <- ds %>%
+ mutate(file = add_filename()) %>%
+ group_by(file) %>%
+ summarise(max_int = max(int)) %>%
+ collect()
+
+ expect_equal(
+ sort(summarise_after$file),
+ list.files(hive_dir, full.names = TRUE, recursive = TRUE)
+ )
+})