This is an automated email from the ASF dual-hosted git repository.
npr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/main by this push:
new ea314a3f8d GH-41358: [R] Support join "na_matches" argument (#41372)
ea314a3f8d is described below
commit ea314a3f8d9d4446836aa999b66659c07421f7a4
Author: Neal Richardson <[email protected]>
AuthorDate: Fri Apr 26 18:32:32 2024 -0400
GH-41358: [R] Support join "na_matches" argument (#41372)
### Rationale for this change
This gap was noticed in #41350. I opened #41358 to implement it in C++, but it
turns out the option already existed in Acero; it was just buried a bit.
### What changes are included in this PR?
`na_matches` is mapped through to the `key_cmp` field in
`HashJoinNodeOptions`. Acero supports having a different value for this
for each of the join keys, but dplyr does not, so I kept it constant for
all key columns to match the dplyr behavior.
### Are these changes tested?
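Yes. New tests in `r/tests/testthat/test-dplyr-join.R` cover both `na_matches = "na"` (the default) and `na_matches = "never"`.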
### Are there any user-facing changes?
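Yes. The `na_matches` argument to the `dplyr::*_join()` functions is now supported instead of being ignored; it controls whether `NA` values are considered equal when joining.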
* GitHub Issue: #41358
---
r/NEWS.md | 1 +
r/R/arrow-package.R | 12 ++++++------
r/R/arrowExports.R | 4 ++--
r/R/dplyr-funcs-doc.R | 12 ++++++------
r/R/dplyr-join.R | 8 +++++---
r/R/query-engine.R | 8 +++++---
r/man/acero.Rd | 12 ++++++------
r/src/arrowExports.cpp | 11 ++++++-----
r/src/compute-exec.cpp | 18 +++++++++++++-----
r/tests/testthat/test-dplyr-join.R | 32 ++++++++++++++++++++++++++++++++
10 files changed, 82 insertions(+), 36 deletions(-)
diff --git a/r/NEWS.md b/r/NEWS.md
index 4ed9f28a28..05f934dac6 100644
--- a/r/NEWS.md
+++ b/r/NEWS.md
@@ -21,6 +21,7 @@
* R functions that users write that use functions that Arrow supports in
dataset queries now can be used in queries too. Previously, only functions that
used arithmetic operators worked. For example, `time_hours <- function(mins)
mins / 60` worked, but `time_hours_rounded <- function(mins) round(mins / 60)`
did not; now both work. These are automatic translations rather than true
user-defined functions (UDFs); for UDFs, see `register_scalar_function()`.
(#41223)
* `summarize()` supports more complex expressions, and correctly handles cases
where column names are reused in expressions.
+* The `na_matches` argument to the `dplyr::*_join()` functions is now
supported. This argument controls whether `NA` values are considered equal when
joining. (#41358)
# arrow 16.0.0
diff --git a/r/R/arrow-package.R b/r/R/arrow-package.R
index f6977e6262..7087a40c49 100644
--- a/r/R/arrow-package.R
+++ b/r/R/arrow-package.R
@@ -66,12 +66,12 @@ supported_dplyr_methods <- list(
compute = NULL,
collapse = NULL,
distinct = "`.keep_all = TRUE` not supported",
- left_join = "the `copy` and `na_matches` arguments are ignored",
- right_join = "the `copy` and `na_matches` arguments are ignored",
- inner_join = "the `copy` and `na_matches` arguments are ignored",
- full_join = "the `copy` and `na_matches` arguments are ignored",
- semi_join = "the `copy` and `na_matches` arguments are ignored",
- anti_join = "the `copy` and `na_matches` arguments are ignored",
+ left_join = "the `copy` argument is ignored",
+ right_join = "the `copy` argument is ignored",
+ inner_join = "the `copy` argument is ignored",
+ full_join = "the `copy` argument is ignored",
+ semi_join = "the `copy` argument is ignored",
+ anti_join = "the `copy` argument is ignored",
count = NULL,
tally = NULL,
rename_with = NULL,
diff --git a/r/R/arrowExports.R b/r/R/arrowExports.R
index 752d3a266b..62e2182ffc 100644
--- a/r/R/arrowExports.R
+++ b/r/R/arrowExports.R
@@ -484,8 +484,8 @@ ExecNode_Aggregate <- function(input, options, key_names) {
.Call(`_arrow_ExecNode_Aggregate`, input, options, key_names)
}
-ExecNode_Join <- function(input, join_type, right_data, left_keys, right_keys,
left_output, right_output, output_suffix_for_left, output_suffix_for_right) {
- .Call(`_arrow_ExecNode_Join`, input, join_type, right_data, left_keys,
right_keys, left_output, right_output, output_suffix_for_left,
output_suffix_for_right)
+ExecNode_Join <- function(input, join_type, right_data, left_keys, right_keys,
left_output, right_output, output_suffix_for_left, output_suffix_for_right,
na_matches) {
+ .Call(`_arrow_ExecNode_Join`, input, join_type, right_data, left_keys,
right_keys, left_output, right_output, output_suffix_for_left,
output_suffix_for_right, na_matches)
}
ExecNode_Union <- function(input, right_data) {
diff --git a/r/R/dplyr-funcs-doc.R b/r/R/dplyr-funcs-doc.R
index 2042f80014..fda77bca83 100644
--- a/r/R/dplyr-funcs-doc.R
+++ b/r/R/dplyr-funcs-doc.R
@@ -36,7 +36,7 @@
#' which returns an `arrow` [Table], or `collect()`, which pulls the resulting
#' Table into an R `tibble`.
#'
-#' * [`anti_join()`][dplyr::anti_join()]: the `copy` and `na_matches`
arguments are ignored
+#' * [`anti_join()`][dplyr::anti_join()]: the `copy` argument is ignored
#' * [`arrange()`][dplyr::arrange()]
#' * [`collapse()`][dplyr::collapse()]
#' * [`collect()`][dplyr::collect()]
@@ -45,22 +45,22 @@
#' * [`distinct()`][dplyr::distinct()]: `.keep_all = TRUE` not supported
#' * [`explain()`][dplyr::explain()]
#' * [`filter()`][dplyr::filter()]
-#' * [`full_join()`][dplyr::full_join()]: the `copy` and `na_matches`
arguments are ignored
+#' * [`full_join()`][dplyr::full_join()]: the `copy` argument is ignored
#' * [`glimpse()`][dplyr::glimpse()]
#' * [`group_by()`][dplyr::group_by()]
#' * [`group_by_drop_default()`][dplyr::group_by_drop_default()]
#' * [`group_vars()`][dplyr::group_vars()]
#' * [`groups()`][dplyr::groups()]
-#' * [`inner_join()`][dplyr::inner_join()]: the `copy` and `na_matches`
arguments are ignored
-#' * [`left_join()`][dplyr::left_join()]: the `copy` and `na_matches`
arguments are ignored
+#' * [`inner_join()`][dplyr::inner_join()]: the `copy` argument is ignored
+#' * [`left_join()`][dplyr::left_join()]: the `copy` argument is ignored
#' * [`mutate()`][dplyr::mutate()]: window functions (e.g. things that require
aggregation within groups) not currently supported
#' * [`pull()`][dplyr::pull()]: the `name` argument is not supported; returns
an R vector by default but this behavior is deprecated and will return an Arrow
[ChunkedArray] in a future release. Provide `as_vector = TRUE/FALSE` to control
this behavior, or set `options(arrow.pull_as_vector)` globally.
#' * [`relocate()`][dplyr::relocate()]
#' * [`rename()`][dplyr::rename()]
#' * [`rename_with()`][dplyr::rename_with()]
-#' * [`right_join()`][dplyr::right_join()]: the `copy` and `na_matches`
arguments are ignored
+#' * [`right_join()`][dplyr::right_join()]: the `copy` argument is ignored
#' * [`select()`][dplyr::select()]
-#' * [`semi_join()`][dplyr::semi_join()]: the `copy` and `na_matches`
arguments are ignored
+#' * [`semi_join()`][dplyr::semi_join()]: the `copy` argument is ignored
#' * [`show_query()`][dplyr::show_query()]
#' * [`slice_head()`][dplyr::slice_head()]: slicing within groups not
supported; Arrow datasets do not have row order, so head is non-deterministic;
`prop` only supported on queries where `nrow()` is knowable without evaluating
#' * [`slice_max()`][dplyr::slice_max()]: slicing within groups not supported;
`with_ties = TRUE` (dplyr default) is not supported; `prop` only supported on
queries where `nrow()` is knowable without evaluating
diff --git a/r/R/dplyr-join.R b/r/R/dplyr-join.R
index 39237f574b..e76e041a54 100644
--- a/r/R/dplyr-join.R
+++ b/r/R/dplyr-join.R
@@ -25,14 +25,15 @@ do_join <- function(x,
suffix = c(".x", ".y"),
...,
keep = FALSE,
- na_matches,
+ na_matches = c("na", "never"),
join_type) {
# TODO: handle `copy` arg: ignore?
- # TODO: handle `na_matches` arg
x <- as_adq(x)
y <- as_adq(y)
by <- handle_join_by(by, x, y)
+ na_matches <- match.arg(na_matches)
+
# For outer joins, we need to output the join keys on both sides so we
# can coalesce them afterwards.
left_output <- if (!keep && join_type == "RIGHT_OUTER") {
@@ -54,7 +55,8 @@ do_join <- function(x,
left_output = left_output,
right_output = right_output,
suffix = suffix,
- keep = keep
+ keep = keep,
+ na_matches = na_matches == "na"
)
collapse.arrow_dplyr_query(x)
}
diff --git a/r/R/query-engine.R b/r/R/query-engine.R
index 0f8a84f9b8..fb48d790fd 100644
--- a/r/R/query-engine.R
+++ b/r/R/query-engine.R
@@ -148,7 +148,8 @@ ExecPlan <- R6Class("ExecPlan",
left_output = .data$join$left_output,
right_output = .data$join$right_output,
left_suffix = .data$join$suffix[[1]],
- right_suffix = .data$join$suffix[[2]]
+ right_suffix = .data$join$suffix[[2]],
+ na_matches = .data$join$na_matches
)
}
@@ -307,7 +308,7 @@ ExecNode <- R6Class("ExecNode",
out$extras$source_schema$metadata[["r"]]$attributes <- NULL
out
},
- Join = function(type, right_node, by, left_output, right_output,
left_suffix, right_suffix) {
+ Join = function(type, right_node, by, left_output, right_output,
left_suffix, right_suffix, na_matches = TRUE) {
self$preserve_extras(
ExecNode_Join(
self,
@@ -318,7 +319,8 @@ ExecNode <- R6Class("ExecNode",
left_output = left_output,
right_output = right_output,
output_suffix_for_left = left_suffix,
- output_suffix_for_right = right_suffix
+ output_suffix_for_right = right_suffix,
+ na_matches = na_matches
)
)
},
diff --git a/r/man/acero.Rd b/r/man/acero.Rd
index 365795d9fc..ca51ef5633 100644
--- a/r/man/acero.Rd
+++ b/r/man/acero.Rd
@@ -23,7 +23,7 @@ the query on the data. To run the query, call either
\code{compute()},
which returns an \code{arrow} \link{Table}, or \code{collect()}, which pulls
the resulting
Table into an R \code{tibble}.
\itemize{
-\item \code{\link[dplyr:filter-joins]{anti_join()}}: the \code{copy} and
\code{na_matches} arguments are ignored
+\item \code{\link[dplyr:filter-joins]{anti_join()}}: the \code{copy} argument
is ignored
\item \code{\link[dplyr:arrange]{arrange()}}
\item \code{\link[dplyr:compute]{collapse()}}
\item \code{\link[dplyr:compute]{collect()}}
@@ -32,22 +32,22 @@ Table into an R \code{tibble}.
\item \code{\link[dplyr:distinct]{distinct()}}: \code{.keep_all = TRUE} not
supported
\item \code{\link[dplyr:explain]{explain()}}
\item \code{\link[dplyr:filter]{filter()}}
-\item \code{\link[dplyr:mutate-joins]{full_join()}}: the \code{copy} and
\code{na_matches} arguments are ignored
+\item \code{\link[dplyr:mutate-joins]{full_join()}}: the \code{copy} argument
is ignored
\item \code{\link[dplyr:glimpse]{glimpse()}}
\item \code{\link[dplyr:group_by]{group_by()}}
\item \code{\link[dplyr:group_by_drop_default]{group_by_drop_default()}}
\item \code{\link[dplyr:group_data]{group_vars()}}
\item \code{\link[dplyr:group_data]{groups()}}
-\item \code{\link[dplyr:mutate-joins]{inner_join()}}: the \code{copy} and
\code{na_matches} arguments are ignored
-\item \code{\link[dplyr:mutate-joins]{left_join()}}: the \code{copy} and
\code{na_matches} arguments are ignored
+\item \code{\link[dplyr:mutate-joins]{inner_join()}}: the \code{copy} argument
is ignored
+\item \code{\link[dplyr:mutate-joins]{left_join()}}: the \code{copy} argument
is ignored
\item \code{\link[dplyr:mutate]{mutate()}}: window functions (e.g. things that
require aggregation within groups) not currently supported
\item \code{\link[dplyr:pull]{pull()}}: the \code{name} argument is not
supported; returns an R vector by default but this behavior is deprecated and
will return an Arrow \link{ChunkedArray} in a future release. Provide
\code{as_vector = TRUE/FALSE} to control this behavior, or set
\code{options(arrow.pull_as_vector)} globally.
\item \code{\link[dplyr:relocate]{relocate()}}
\item \code{\link[dplyr:rename]{rename()}}
\item \code{\link[dplyr:rename]{rename_with()}}
-\item \code{\link[dplyr:mutate-joins]{right_join()}}: the \code{copy} and
\code{na_matches} arguments are ignored
+\item \code{\link[dplyr:mutate-joins]{right_join()}}: the \code{copy} argument
is ignored
\item \code{\link[dplyr:select]{select()}}
-\item \code{\link[dplyr:filter-joins]{semi_join()}}: the \code{copy} and
\code{na_matches} arguments are ignored
+\item \code{\link[dplyr:filter-joins]{semi_join()}}: the \code{copy} argument
is ignored
\item \code{\link[dplyr:explain]{show_query()}}
\item \code{\link[dplyr:slice]{slice_head()}}: slicing within groups not
supported; Arrow datasets do not have row order, so head is non-deterministic;
\code{prop} only supported on queries where \code{nrow()} is knowable without
evaluating
\item \code{\link[dplyr:slice]{slice_max()}}: slicing within groups not
supported; \code{with_ties = TRUE} (dplyr default) is not supported;
\code{prop} only supported on queries where \code{nrow()} is knowable without
evaluating
diff --git a/r/src/arrowExports.cpp b/r/src/arrowExports.cpp
index a4c4b614d6..d5aec50219 100644
--- a/r/src/arrowExports.cpp
+++ b/r/src/arrowExports.cpp
@@ -1163,8 +1163,8 @@ extern "C" SEXP _arrow_ExecNode_Aggregate(SEXP
input_sexp, SEXP options_sexp, SE
// compute-exec.cpp
#if defined(ARROW_R_WITH_ACERO)
-std::shared_ptr<acero::ExecNode> ExecNode_Join(const
std::shared_ptr<acero::ExecNode>& input, acero::JoinType join_type, const
std::shared_ptr<acero::ExecNode>& right_data, std::vector<std::string>
left_keys, std::vector<std::string> right_keys, std::vector<std::string>
left_output, std::vector<std::string> right_output, std::string
output_suffix_for_left, std::string output_suffix_for_right);
-extern "C" SEXP _arrow_ExecNode_Join(SEXP input_sexp, SEXP join_type_sexp,
SEXP right_data_sexp, SEXP left_keys_sexp, SEXP right_keys_sexp, SEXP
left_output_sexp, SEXP right_output_sexp, SEXP output_suffix_for_left_sexp,
SEXP output_suffix_for_right_sexp){
+std::shared_ptr<acero::ExecNode> ExecNode_Join(const
std::shared_ptr<acero::ExecNode>& input, acero::JoinType join_type, const
std::shared_ptr<acero::ExecNode>& right_data, std::vector<std::string>
left_keys, std::vector<std::string> right_keys, std::vector<std::string>
left_output, std::vector<std::string> right_output, std::string
output_suffix_for_left, std::string output_suffix_for_right, bool na_matches);
+extern "C" SEXP _arrow_ExecNode_Join(SEXP input_sexp, SEXP join_type_sexp,
SEXP right_data_sexp, SEXP left_keys_sexp, SEXP right_keys_sexp, SEXP
left_output_sexp, SEXP right_output_sexp, SEXP output_suffix_for_left_sexp,
SEXP output_suffix_for_right_sexp, SEXP na_matches_sexp){
BEGIN_CPP11
arrow::r::Input<const std::shared_ptr<acero::ExecNode>&>::type
input(input_sexp);
arrow::r::Input<acero::JoinType>::type join_type(join_type_sexp);
@@ -1175,11 +1175,12 @@ BEGIN_CPP11
arrow::r::Input<std::vector<std::string>>::type
right_output(right_output_sexp);
arrow::r::Input<std::string>::type
output_suffix_for_left(output_suffix_for_left_sexp);
arrow::r::Input<std::string>::type
output_suffix_for_right(output_suffix_for_right_sexp);
- return cpp11::as_sexp(ExecNode_Join(input, join_type, right_data,
left_keys, right_keys, left_output, right_output, output_suffix_for_left,
output_suffix_for_right));
+ arrow::r::Input<bool>::type na_matches(na_matches_sexp);
+ return cpp11::as_sexp(ExecNode_Join(input, join_type, right_data,
left_keys, right_keys, left_output, right_output, output_suffix_for_left,
output_suffix_for_right, na_matches));
END_CPP11
}
#else
-extern "C" SEXP _arrow_ExecNode_Join(SEXP input_sexp, SEXP join_type_sexp,
SEXP right_data_sexp, SEXP left_keys_sexp, SEXP right_keys_sexp, SEXP
left_output_sexp, SEXP right_output_sexp, SEXP output_suffix_for_left_sexp,
SEXP output_suffix_for_right_sexp){
+extern "C" SEXP _arrow_ExecNode_Join(SEXP input_sexp, SEXP join_type_sexp,
SEXP right_data_sexp, SEXP left_keys_sexp, SEXP right_keys_sexp, SEXP
left_output_sexp, SEXP right_output_sexp, SEXP output_suffix_for_left_sexp,
SEXP output_suffix_for_right_sexp, SEXP na_matches_sexp){
Rf_error("Cannot call ExecNode_Join(). See
https://arrow.apache.org/docs/r/articles/install.html for help installing Arrow
C++ libraries. ");
}
#endif
@@ -5790,7 +5791,7 @@ static const R_CallMethodDef CallEntries[] = {
{ "_arrow_ExecNode_Filter", (DL_FUNC) &_arrow_ExecNode_Filter,
2},
{ "_arrow_ExecNode_Project", (DL_FUNC)
&_arrow_ExecNode_Project, 3},
{ "_arrow_ExecNode_Aggregate", (DL_FUNC)
&_arrow_ExecNode_Aggregate, 3},
- { "_arrow_ExecNode_Join", (DL_FUNC) &_arrow_ExecNode_Join, 9},
+ { "_arrow_ExecNode_Join", (DL_FUNC) &_arrow_ExecNode_Join, 10},
{ "_arrow_ExecNode_Union", (DL_FUNC) &_arrow_ExecNode_Union,
2},
{ "_arrow_ExecNode_Fetch", (DL_FUNC) &_arrow_ExecNode_Fetch,
3},
{ "_arrow_ExecNode_OrderBy", (DL_FUNC)
&_arrow_ExecNode_OrderBy, 2},
diff --git a/r/src/compute-exec.cpp b/r/src/compute-exec.cpp
index e0b3c62c47..d0c50315c2 100644
--- a/r/src/compute-exec.cpp
+++ b/r/src/compute-exec.cpp
@@ -411,10 +411,17 @@ std::shared_ptr<acero::ExecNode> ExecNode_Join(
const std::shared_ptr<acero::ExecNode>& right_data,
std::vector<std::string> left_keys, std::vector<std::string> right_keys,
std::vector<std::string> left_output, std::vector<std::string>
right_output,
- std::string output_suffix_for_left, std::string output_suffix_for_right) {
+ std::string output_suffix_for_left, std::string output_suffix_for_right,
+ bool na_matches) {
std::vector<arrow::FieldRef> left_refs, right_refs, left_out_refs,
right_out_refs;
+ std::vector<acero::JoinKeyCmp> key_cmps;
for (auto&& name : left_keys) {
left_refs.emplace_back(std::move(name));
+ // Populate key_cmps in this loop, one for each key
+ // Note that Acero supports having different values for each key, but dplyr
+ // only supports one value for all keys, so we're only going to support
that
+ // for now.
+ key_cmps.emplace_back(na_matches ? acero::JoinKeyCmp::IS :
acero::JoinKeyCmp::EQ);
}
for (auto&& name : right_keys) {
right_refs.emplace_back(std::move(name));
@@ -434,10 +441,11 @@ std::shared_ptr<acero::ExecNode> ExecNode_Join(
return MakeExecNodeOrStop(
"hashjoin", input->plan(), {input.get(), right_data.get()},
- acero::HashJoinNodeOptions{
- join_type, std::move(left_refs), std::move(right_refs),
- std::move(left_out_refs), std::move(right_out_refs),
compute::literal(true),
- std::move(output_suffix_for_left),
std::move(output_suffix_for_right)});
+ acero::HashJoinNodeOptions{join_type, std::move(left_refs),
std::move(right_refs),
+ std::move(left_out_refs),
std::move(right_out_refs),
+ std::move(key_cmps), compute::literal(true),
+ std::move(output_suffix_for_left),
+ std::move(output_suffix_for_right)});
}
// [[acero::export]]
diff --git a/r/tests/testthat/test-dplyr-join.R
b/r/tests/testthat/test-dplyr-join.R
index e3e1e98cfc..9a1c8b7b80 100644
--- a/r/tests/testthat/test-dplyr-join.R
+++ b/r/tests/testthat/test-dplyr-join.R
@@ -441,3 +441,35 @@ test_that("full joins handle keep", {
small_dataset_df
)
})
+
+left <- tibble::tibble(
+ x = c(1, NA, 3),
+)
+right <- tibble::tibble(
+ x = c(1, NA, 3),
+ y = c("a", "b", "c")
+)
+na_matches_na <- right
+na_matches_never <- tibble::tibble(
+ x = c(1, NA, 3),
+ y = c("a", NA, "c")
+)
+test_that("na_matches argument to join: na (default)", {
+ expect_equal(
+ arrow_table(left) %>%
+ left_join(right, by = "x", na_matches = "na") %>%
+ arrange(x) %>%
+ collect(),
+ na_matches_na %>% arrange(x)
+ )
+})
+
+test_that("na_matches argument to join: never", {
+ expect_equal(
+ arrow_table(left) %>%
+ left_join(right, by = "x", na_matches = "never") %>%
+ arrange(x) %>%
+ collect(),
+ na_matches_never %>% arrange(x)
+ )
+})