This is an automated email from the ASF dual-hosted git repository.
paleolimbot pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-nanoarrow.git
The following commit(s) were added to refs/heads/main by this push:
new 5771e5e feat(r): Implement infer schema methods (#104)
5771e5e is described below
commit 5771e5ec291da7d4e7abed1cdc671822c3c2e4f4
Author: Dewey Dunnington <[email protected]>
AuthorDate: Wed Feb 8 14:10:27 2023 -0400
feat(r): Implement infer schema methods (#104)
This PR implements `infer_nanoarrow_schema()`, which returns the schema
an object *would* have if it were converted to Arrow format. Arrow's
`arrow::infer_type()` does the same thing, and before this PR, nanoarrow
just fell back on Arrow to infer the type. Now that types can be created
natively, we don't need Arrow for this.
This PR implements schema inference using S3 dispatch. It would be
faster to implement this in C, and this probably will be done in the
future; however, for now the overhead seems acceptable for most uses
(basically, you need more than 10,000 columns in a data frame to start
noticing).
``` r
library(arrow, warn.conflicts = FALSE)
#> Some features are not enabled in this build of Arrow. Run `arrow_info()`
#> for more information.
library(nanoarrow)
make_big_df <- function(n) {
cols <- lapply(seq_len(n), function(...) integer())
names(cols) <- paste0("col", seq_len(n))
tibble::new_tibble(cols, nrow = 0)
}
df <- make_big_df(1e5)
bench::mark(
infer_nanoarrow_schema(df),
arrow::infer_type(df),
check = F
)
#> Warning: Some expressions had a GC in every iteration; so filtering is
#> disabled.
#> # A tibble: 2 × 6
#>   expression                      min   median `itr/sec` mem_alloc `gc/sec`
#>   <bch:expr>                 <bch:tm> <bch:tm>     <dbl> <bch:byt>    <dbl>
#> 1 infer_nanoarrow_schema(df)  582.4ms  582.4ms      1.72    2.43MB     13.7
#> 2 arrow::infer_type(df)        85.1ms   86.6ms     11.5     1.02MB      0
```
<sup>Created on 2023-02-08 with [reprex
v2.0.2](https://reprex.tidyverse.org)</sup>
---
r/NAMESPACE | 22 +++++++++
r/R/pkg-arrow.R | 40 +++++++++++++++++
r/R/schema.R | 94 ++++++++++++++++++++++++++++++++++++++-
r/tests/testthat/test-pkg-arrow.R | 26 +++++++++++
r/tests/testthat/test-schema.R | 58 ++++++++++++++++++++++--
5 files changed, 236 insertions(+), 4 deletions(-)
diff --git a/r/NAMESPACE b/r/NAMESPACE
index c9b29b9..bd517da 100644
--- a/r/NAMESPACE
+++ b/r/NAMESPACE
@@ -31,9 +31,31 @@ S3method(format,nanoarrow_array)
S3method(format,nanoarrow_array_stream)
S3method(format,nanoarrow_buffer)
S3method(format,nanoarrow_schema)
+S3method(infer_nanoarrow_schema,Array)
+S3method(infer_nanoarrow_schema,ArrowTabular)
+S3method(infer_nanoarrow_schema,ChunkedArray)
+S3method(infer_nanoarrow_schema,Dataset)
+S3method(infer_nanoarrow_schema,Date)
+S3method(infer_nanoarrow_schema,Expression)
+S3method(infer_nanoarrow_schema,POSIXct)
+S3method(infer_nanoarrow_schema,RecordBatchReader)
+S3method(infer_nanoarrow_schema,Scalar)
+S3method(infer_nanoarrow_schema,arrow_dplyr_query)
+S3method(infer_nanoarrow_schema,blob)
+S3method(infer_nanoarrow_schema,character)
+S3method(infer_nanoarrow_schema,data.frame)
S3method(infer_nanoarrow_schema,default)
+S3method(infer_nanoarrow_schema,difftime)
+S3method(infer_nanoarrow_schema,double)
+S3method(infer_nanoarrow_schema,factor)
+S3method(infer_nanoarrow_schema,hms)
+S3method(infer_nanoarrow_schema,integer)
+S3method(infer_nanoarrow_schema,logical)
S3method(infer_nanoarrow_schema,nanoarrow_array)
S3method(infer_nanoarrow_schema,nanoarrow_array_stream)
+S3method(infer_nanoarrow_schema,raw)
+S3method(infer_nanoarrow_schema,vctrs_list_of)
+S3method(infer_nanoarrow_schema,vctrs_unspecified)
S3method(length,nanoarrow_array)
S3method(length,nanoarrow_array_stream)
S3method(length,nanoarrow_schema)
diff --git a/r/R/pkg-arrow.R b/r/R/pkg-arrow.R
index 6125ce4..30ce5e7 100644
--- a/r/R/pkg-arrow.R
+++ b/r/R/pkg-arrow.R
@@ -108,6 +108,46 @@ as_nanoarrow_schema.Schema <- function(x, ...) {
schema
}
+#' @export
+infer_nanoarrow_schema.Array <- function(x, ...) {
+ as_nanoarrow_schema.DataType(x$type)
+}
+
+#' @export
+infer_nanoarrow_schema.Scalar <- function(x, ...) {
+ as_nanoarrow_schema.DataType(x$type)
+}
+
+#' @export
+infer_nanoarrow_schema.Expression <- function(x, ...) {
+ as_nanoarrow_schema.DataType(x$type())
+}
+
+#' @export
+infer_nanoarrow_schema.ChunkedArray <- function(x, ...) {
+ as_nanoarrow_schema.DataType(x$type)
+}
+
+#' @export
+infer_nanoarrow_schema.ArrowTabular <- function(x, ...) {
+ as_nanoarrow_schema.Schema(x$schema)
+}
+
+#' @export
+infer_nanoarrow_schema.RecordBatchReader <- function(x, ...) {
+ as_nanoarrow_schema.Schema(x$schema)
+}
+
+#' @export
+infer_nanoarrow_schema.Dataset <- function(x, ...) {
+ as_nanoarrow_schema.Schema(x$schema)
+}
+
+#' @export
+infer_nanoarrow_schema.arrow_dplyr_query <- function(x, ...) {
+ infer_nanoarrow_schema.RecordBatchReader(arrow::as_record_batch_reader(x))
+}
+
#' @export
as_nanoarrow_array.Array <- function(x, ..., schema = NULL) {
imported_schema <- nanoarrow_allocate_schema()
diff --git a/r/R/schema.R b/r/R/schema.R
index 4758519..564dc38 100644
--- a/r/R/schema.R
+++ b/r/R/schema.R
@@ -54,7 +54,99 @@ infer_nanoarrow_schema <- function(x, ...) {
#' @export
infer_nanoarrow_schema.default <- function(x, ...) {
- as_nanoarrow_schema(arrow::infer_type(x, ...))
+ cls <- paste(class(x), collapse = "/")
+ stop(sprintf("Can't infer Arrow type for object of class %s", cls))
+}
+
+#' @export
+infer_nanoarrow_schema.raw <- function(x, ...) {
+ na_uint8()
+}
+
+#' @export
+infer_nanoarrow_schema.logical <- function(x, ...) {
+ na_bool()
+}
+
+#' @export
+infer_nanoarrow_schema.integer <- function(x, ...) {
+ na_int32()
+}
+
+#' @export
+infer_nanoarrow_schema.double <- function(x, ...) {
+ na_double()
+}
+
+#' @export
+infer_nanoarrow_schema.character <- function(x, ...) {
+ if (length(x) > 0 && sum(nchar(x, type = "bytes")) > .Machine$integer.max) {
+ na_large_string()
+ } else {
+ na_string()
+ }
+}
+
+#' @export
+infer_nanoarrow_schema.factor <- function(x, ...) {
+ na_dictionary(infer_nanoarrow_schema.character(levels(x)), na_int32())
+}
+
+#' @export
+infer_nanoarrow_schema.POSIXct <- function(x, ...) {
+ tz <- attr(x, "tzone")
+ if (is.null(tz) || identical(tz, "")) {
+ tz <- Sys.timezone()
+ }
+
+ na_timestamp(timezone = tz)
+}
+
+#' @export
+infer_nanoarrow_schema.Date <- function(x, ...) {
+ na_date32()
+}
+
+#' @export
+infer_nanoarrow_schema.difftime <- function(x, ...) {
+ # A balance between safety for large time ranges (not overflowing)
+ # and safety for small time ranges (not truncating)
+ na_duration(unit = "us")
+}
+
+#' @export
+infer_nanoarrow_schema.data.frame <- function(x, ...) {
+ na_struct(lapply(x, infer_nanoarrow_schema), nullable = FALSE)
+}
+
+#' @export
+infer_nanoarrow_schema.hms <- function(x, ...) {
+ # As a default, ms is safer than s and less likely to truncate
+ na_time32(unit = "ms")
+}
+
+#' @export
+infer_nanoarrow_schema.blob <- function(x, ...) {
+ if (length(x) > 0 && sum(lengths(x)) > .Machine$integer.max) {
+ na_large_binary()
+ } else {
+ na_binary()
+ }
+}
+
+#' @export
+infer_nanoarrow_schema.vctrs_unspecified <- function(x, ...) {
+ na_na()
+}
+
+#' @export
+infer_nanoarrow_schema.vctrs_list_of <- function(x, ...) {
+ child_type <- infer_nanoarrow_schema(attr(x, "ptype"))
+ if (length(x) > 0 && sum(lengths(x)) > .Machine$integer.max) {
+ na_large_list(child_type)
+ } else {
+ na_list(child_type)
+ }
}
#' @rdname as_nanoarrow_schema
diff --git a/r/tests/testthat/test-pkg-arrow.R b/r/tests/testthat/test-pkg-arrow.R
index da7f9fd..a6f86ab 100644
--- a/r/tests/testthat/test-pkg-arrow.R
+++ b/r/tests/testthat/test-pkg-arrow.R
@@ -29,6 +29,32 @@ test_that("infer_nanoarrow_schema() works for arrow objects", {
int_schema <- infer_nanoarrow_schema(arrow::Array$create(1:10))
expect_true(arrow::as_data_type(int_schema)$Equals(arrow::int32()))
+
+ int_schema <- infer_nanoarrow_schema(arrow::Scalar$create(1L))
+ expect_true(arrow::as_data_type(int_schema)$Equals(arrow::int32()))
+
+ int_schema <- infer_nanoarrow_schema(arrow::ChunkedArray$create(1:10))
+ expect_true(arrow::as_data_type(int_schema)$Equals(arrow::int32()))
+
+ int_schema <- infer_nanoarrow_schema(arrow::Expression$scalar(1L))
+ expect_true(arrow::as_data_type(int_schema)$Equals(arrow::int32()))
+
+ tbl_schema_expected <- arrow::schema(x = arrow::int32())
+ tbl_schema <- infer_nanoarrow_schema(arrow::record_batch(x = 1L))
+ expect_true(arrow::as_schema(tbl_schema)$Equals(tbl_schema_expected))
+
+ tbl_schema <- infer_nanoarrow_schema(arrow::arrow_table(x = 1L))
+ expect_true(arrow::as_schema(tbl_schema)$Equals(tbl_schema_expected))
+
+ tbl_schema <- infer_nanoarrow_schema(
+ arrow::RecordBatchReader$create(arrow::record_batch(x = 1L))
+ )
+ expect_true(arrow::as_schema(tbl_schema)$Equals(tbl_schema_expected))
+
+ tbl_schema <- infer_nanoarrow_schema(
+ arrow::InMemoryDataset$create(arrow::record_batch(x = 1L))
+ )
+ expect_true(arrow::as_schema(tbl_schema)$Equals(tbl_schema_expected))
})
test_that("nanoarrow_array to Array works", {
diff --git a/r/tests/testthat/test-schema.R b/r/tests/testthat/test-schema.R
index 19d58fa..f625c72 100644
--- a/r/tests/testthat/test-schema.R
+++ b/r/tests/testthat/test-schema.R
@@ -34,9 +34,61 @@ test_that("as_nanoarrow_schema() works for nanoarrow_schema", {
expect_identical(as_nanoarrow_schema(schema), schema)
})
-test_that("infer_nanoarrow_schema() default method works", {
- schema <- na_int32()
- expect_true(arrow::as_data_type(schema)$Equals(arrow::int32()))
+test_that("infer_nanoarrow_schema() errors for unsupported types", {
+ expect_error(
+ infer_nanoarrow_schema(environment()),
+ "Can't infer Arrow type"
+ )
+})
+
+test_that("infer_nanoarrow_schema() methods work for built-in types", {
+ expect_identical(infer_nanoarrow_schema(raw())$format, "C")
+ expect_identical(infer_nanoarrow_schema(logical())$format, "b")
+ expect_identical(infer_nanoarrow_schema(integer())$format, "i")
+ expect_identical(infer_nanoarrow_schema(double())$format, "g")
+ expect_identical(infer_nanoarrow_schema(character())$format, "u")
+ expect_identical(infer_nanoarrow_schema(Sys.Date())$format, "tdD")
+
+ expect_identical(infer_nanoarrow_schema(factor())$format, "i")
+ expect_identical(infer_nanoarrow_schema(factor())$dictionary$format, "u")
+
+ time <- as.POSIXct("2000-01-01", tz = "UTC")
+ expect_identical(infer_nanoarrow_schema(time)$format, "tsm:UTC")
+
+ time <- as.POSIXct("2000-01-01", tz = "")
+ expect_identical(
+ infer_nanoarrow_schema(time)$format,
+ paste0("tsm:", Sys.timezone())
+ )
+
+ difftime <- as.difftime(double(), unit = "secs")
+ expect_identical(infer_nanoarrow_schema(difftime)$format, "tDu")
+
+ df_schema <- infer_nanoarrow_schema(data.frame(x = 1L))
+ expect_identical(df_schema$format, "+s")
+ expect_identical(df_schema$children$x$format, "i")
+})
+
+test_that("infer_nanoarrow_schema() methods work for blob type", {
+ skip_if_not_installed("blob")
+
+ expect_identical(infer_nanoarrow_schema(blob::blob())$format, "z")
+})
+
+test_that("infer_nanoarrow_schema() methods work for hms type", {
+ skip_if_not_installed("hms")
+
+ expect_identical(infer_nanoarrow_schema(hms::hms())$format, "ttm")
+})
+
+test_that("infer_nanoarrow_schema() methods work for vctrs types", {
+ skip_if_not_installed("vctrs")
+
+ expect_identical(infer_nanoarrow_schema(vctrs::unspecified())$format, "n")
+
+ list_schema <- infer_nanoarrow_schema(vctrs::list_of(.ptype = integer()))
+ expect_identical(list_schema$format, "+l")
+ expect_identical(list_schema$children[[1]]$format, "i")
})
test_that("nanoarrow_schema_parse() works", {