This is an automated email from the ASF dual-hosted git repository.
thisisnic pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/main by this push:
new 1a038adad8 GH-32282: [R] Update case_when() binding to match changes in dplyr (#35502)
1a038adad8 is described below
commit 1a038adad85be6f7e6949cc700dcb9b211feb44d
Author: Nic Crane <[email protected]>
AuthorDate: Thu May 11 10:02:46 2023 +0200
GH-32282: [R] Update case_when() binding to match changes in dplyr (#35502)
This PR implements the `.default` argument for `case_when()` and raises
an error if a non-NULL value is supplied for either `.ptype` or `.size`.
* Closes: #32282
Authored-by: Nic Crane <[email protected]>
Signed-off-by: Nic Crane <[email protected]>
---
r/R/dplyr-funcs-conditional.R | 21 +++++++++++++++--
r/R/dplyr-funcs-doc.R | 26 ++++++++++-----------
r/man/acero.Rd | 26 ++++++++++-----------
r/tests/testthat/test-dplyr-funcs-conditional.R | 31 +++++++++++++++++++++++++
4 files changed, 76 insertions(+), 28 deletions(-)
diff --git a/r/R/dplyr-funcs-conditional.R b/r/R/dplyr-funcs-conditional.R
index 37411ed261..cd0245eeee 100644
--- a/r/R/dplyr-funcs-conditional.R
+++ b/r/R/dplyr-funcs-conditional.R
@@ -90,7 +90,15 @@ register_bindings_conditional <- function() {
out
})
- register_binding("dplyr::case_when", function(...) {
+ register_binding("dplyr::case_when", function(..., .default = NULL, .ptype =
NULL, .size = NULL) {
+ if (!is.null(.ptype)) {
+ arrow_not_supported("`case_when()` with `.ptype` specified")
+ }
+
+ if (!is.null(.size)) {
+ arrow_not_supported("`case_when()` with `.size` specified")
+ }
+
formulas <- list2(...)
n <- length(formulas)
if (n == 0) {
@@ -113,6 +121,14 @@ register_bindings_conditional <- function() {
abort(handle_arrow_not_supported(value[[i]], format_expr(f[[3]])))
}
}
+ if (!is.null(.default)) {
+ if (length(.default) != 1) {
+ abort(paste0("`.default` must have size 1, not size ",
length(.default), "."))
+ }
+
+ query[n + 1] <- TRUE
+ value[n + 1] <- .default
+ }
Expression$create(
"case_when",
args = c(
@@ -124,5 +140,6 @@ register_bindings_conditional <- function() {
value
)
)
- })
+ }, notes = "`.ptype` and `.size` arguments not supported"
+ )
}
diff --git a/r/R/dplyr-funcs-doc.R b/r/R/dplyr-funcs-doc.R
index b619cfe509..a62c4f8335 100644
--- a/r/R/dplyr-funcs-doc.R
+++ b/r/R/dplyr-funcs-doc.R
@@ -83,7 +83,7 @@
#' Functions can be called either as `pkg::fun()` or just `fun()`, i.e. both
#' `str_sub()` and `stringr::str_sub()` work.
#'
-#' In addition to these functions, you can call any of Arrow's 246 compute
+#' In addition to these functions, you can call any of Arrow's 254 compute
#' functions directly. Arrow has many functions that don't map to an existing R
#' function. In other cases where there is an R function mapping, you can still
#' call the Arrow function directly if you don't want the adaptations that the
R
@@ -99,30 +99,31 @@
#'
#' ## base
#'
-#' * [`-`][-()]
#' * [`!`][!()]
#' * [`!=`][!=()]
-#' * [`*`][*()]
-#' * [`/`][/()]
-#' * [`&`][&()]
-#' * [`%/%`][%/%()]
#' * [`%%`][%%()]
+#' * [`%/%`][%/%()]
#' * [`%in%`][%in%()]
-#' * [`^`][^()]
+#' * [`&`][&()]
+#' * [`*`][*()]
#' * [`+`][+()]
+#' * [`-`][-()]
+#' * [`/`][/()]
#' * [`<`][<()]
#' * [`<=`][<=()]
#' * [`==`][==()]
#' * [`>`][>()]
#' * [`>=`][>=()]
-#' * [`|`][|()]
+#' * [`ISOdate()`][base::ISOdate()]
+#' * [`ISOdatetime()`][base::ISOdatetime()]
+#' * [`^`][^()]
#' * [`abs()`][base::abs()]
#' * [`acos()`][base::acos()]
#' * [`all()`][base::all()]
#' * [`any()`][base::any()]
-#' * [`as.character()`][base::as.character()]
#' * [`as.Date()`][base::as.Date()]: Multiple `tryFormats` not supported in
Arrow.
#' Consider using the lubridate specialised parsing functions `ymd()`,
`ymd()`, etc.
+#' * [`as.character()`][base::as.character()]
#' * [`as.difftime()`][base::as.difftime()]: only supports `units = "secs"`
(the default)
#' * [`as.double()`][base::as.double()]
#' * [`as.integer()`][base::as.integer()]
@@ -153,8 +154,6 @@
#' * [`is.na()`][base::is.na()]
#' * [`is.nan()`][base::is.nan()]
#' * [`is.numeric()`][base::is.numeric()]
-#' * [`ISOdate()`][base::ISOdate()]
-#' * [`ISOdatetime()`][base::ISOdatetime()]
#' * [`log()`][base::log()]
#' * [`log10()`][base::log10()]
#' * [`log1p()`][base::log1p()]
@@ -186,6 +185,7 @@
#' * [`tolower()`][base::tolower()]
#' * [`toupper()`][base::toupper()]
#' * [`trunc()`][base::trunc()]
+#' * [`|`][|()]
#'
#' ## bit64
#'
@@ -196,7 +196,7 @@
#'
#' * [`across()`][dplyr::across()]
#' * [`between()`][dplyr::between()]
-#' * [`case_when()`][dplyr::case_when()]
+#' * [`case_when()`][dplyr::case_when()]: `.ptype` and `.size` arguments not
supported
#' * [`coalesce()`][dplyr::coalesce()]
#' * [`desc()`][dplyr::desc()]
#' * [`if_all()`][dplyr::if_all()]
@@ -242,8 +242,8 @@
#' * [`format_ISO8601()`][lubridate::format_ISO8601()]
#' * [`hour()`][lubridate::hour()]
#' * [`is.Date()`][lubridate::is.Date()]
-#' * [`is.instant()`][lubridate::is.instant()]
#' * [`is.POSIXct()`][lubridate::is.POSIXct()]
+#' * [`is.instant()`][lubridate::is.instant()]
#' * [`is.timepoint()`][lubridate::is.timepoint()]
#' * [`isoweek()`][lubridate::isoweek()]
#' * [`isoyear()`][lubridate::isoyear()]
diff --git a/r/man/acero.Rd b/r/man/acero.Rd
index 6d4476c44c..d41029c70b 100644
--- a/r/man/acero.Rd
+++ b/r/man/acero.Rd
@@ -68,7 +68,7 @@ can assume that the function works in Acero just as it does
in R.
Functions can be called either as \code{pkg::fun()} or just \code{fun()}, i.e.
both
\code{str_sub()} and \code{stringr::str_sub()} work.
-In addition to these functions, you can call any of Arrow's 246 compute
+In addition to these functions, you can call any of Arrow's 254 compute
functions directly. Arrow has many functions that don't map to an existing R
function. In other cases where there is an R function mapping, you can still
call the Arrow function directly if you don't want the adaptations that the R
@@ -85,30 +85,31 @@ as \code{arrow_ascii_is_decimal}.
\subsection{base}{
\itemize{
-\item \code{\link[=-]{-}}
\item \code{\link[=!]{!}}
\item \code{\link[=!=]{!=}}
-\item \code{\link[=*]{*}}
-\item \code{\link[=/]{/}}
-\item \code{\link[=&]{&}}
-\item \code{\link[=\%/\%]{\%/\%}}
\item \code{\link[=\%\%]{\%\%}}
+\item \code{\link[=\%/\%]{\%/\%}}
\item \code{\link[=\%in\%]{\%in\%}}
-\item \code{\link[=^]{^}}
+\item \code{\link[=&]{&}}
+\item \code{\link[=*]{*}}
\item \code{\link[=+]{+}}
+\item \code{\link[=-]{-}}
+\item \code{\link[=/]{/}}
\item \code{\link[=<]{<}}
\item \code{\link[=<=]{<=}}
\item \code{\link[===]{==}}
\item \code{\link[=>]{>}}
\item \code{\link[=>=]{>=}}
-\item \code{\link[=|]{|}}
+\item \code{\link[base:ISOdatetime]{ISOdate()}}
+\item \code{\link[base:ISOdatetime]{ISOdatetime()}}
+\item \code{\link[=^]{^}}
\item \code{\link[base:MathFun]{abs()}}
\item \code{\link[base:Trig]{acos()}}
\item \code{\link[base:all]{all()}}
\item \code{\link[base:any]{any()}}
-\item \code{\link[base:character]{as.character()}}
\item \code{\link[base:as.Date]{as.Date()}}: Multiple \code{tryFormats} not
supported in Arrow.
Consider using the lubridate specialised parsing functions \code{ymd()},
\code{ymd()}, etc.
+\item \code{\link[base:character]{as.character()}}
\item \code{\link[base:difftime]{as.difftime()}}: only supports \code{units =
"secs"} (the default)
\item \code{\link[base:double]{as.double()}}
\item \code{\link[base:integer]{as.integer()}}
@@ -139,8 +140,6 @@ Consider using the lubridate specialised parsing functions
\code{ymd()}, \code{y
\item \code{\link[base:NA]{is.na()}}
\item \code{\link[base:is.finite]{is.nan()}}
\item \code{\link[base:numeric]{is.numeric()}}
-\item \code{\link[base:ISOdatetime]{ISOdate()}}
-\item \code{\link[base:ISOdatetime]{ISOdatetime()}}
\item \code{\link[base:Log]{log()}}
\item \code{\link[base:Log]{log10()}}
\item \code{\link[base:Log]{log1p()}}
@@ -172,6 +171,7 @@ Valid values are "s", "ms" (default), "us", "ns".
\item \code{\link[base:chartr]{tolower()}}
\item \code{\link[base:chartr]{toupper()}}
\item \code{\link[base:Round]{trunc()}}
+\item \code{\link[=|]{|}}
}
}
@@ -186,7 +186,7 @@ Valid values are "s", "ms" (default), "us", "ns".
\itemize{
\item \code{\link[dplyr:across]{across()}}
\item \code{\link[dplyr:between]{between()}}
-\item \code{\link[dplyr:case_when]{case_when()}}
+\item \code{\link[dplyr:case_when]{case_when()}}: \code{.ptype} and
\code{.size} arguments not supported
\item \code{\link[dplyr:coalesce]{coalesce()}}
\item \code{\link[dplyr:desc]{desc()}}
\item \code{\link[dplyr:across]{if_all()}}
@@ -234,8 +234,8 @@ Valid values are "s", "ms" (default), "us", "ns".
\item \code{\link[lubridate:format_ISO8601]{format_ISO8601()}}
\item \code{\link[lubridate:hour]{hour()}}
\item \code{\link[lubridate:date_utils]{is.Date()}}
-\item \code{\link[lubridate:is.instant]{is.instant()}}
\item \code{\link[lubridate:posix_utils]{is.POSIXct()}}
+\item \code{\link[lubridate:is.instant]{is.instant()}}
\item \code{\link[lubridate:is.instant]{is.timepoint()}}
\item \code{\link[lubridate:week]{isoweek()}}
\item \code{\link[lubridate:year]{isoyear()}}
diff --git a/r/tests/testthat/test-dplyr-funcs-conditional.R
b/r/tests/testthat/test-dplyr-funcs-conditional.R
index b3d86da8b4..e60712e9e6 100644
--- a/r/tests/testthat/test-dplyr-funcs-conditional.R
+++ b/r/tests/testthat/test-dplyr-funcs-conditional.R
@@ -176,6 +176,14 @@ test_that("case_when()", {
collect(),
tbl
)
+
+ compare_dplyr_binding(
+ .input %>%
+ mutate(cw = case_when(int > 5 ~ 1, .default = 0)) %>%
+ collect(),
+ tbl
+ )
+
compare_dplyr_binding(
.input %>%
transmute(cw = case_when(chr %in% letters[1:3] ~ 1L) + 41L) %>%
@@ -271,6 +279,29 @@ test_that("case_when()", {
)
)
+ expect_error(
+ expect_warning(
+ tbl %>%
+ arrow_table() %>%
+ mutate(cw = case_when(int > 5 ~ 1, .default = c(0, 1)))
+ ),
+ "`.default` must have size"
+ )
+
+ expect_warning(
+ tbl %>%
+ arrow_table() %>%
+ mutate(cw = case_when(int > 5 ~ 1, .ptype = integer())),
+ "not supported in Arrow"
+ )
+
+ expect_warning(
+ tbl %>%
+ arrow_table() %>%
+ mutate(cw = case_when(int > 5 ~ 1, .size = 10)),
+ "not supported in Arrow"
+ )
+
compare_dplyr_binding(
.input %>%
transmute(cw = case_when(lgl ~ "abc")) %>%