This is an automated email from the ASF dual-hosted git repository.

paleolimbot pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new 9cce5991 fix(r/adbcdrivermanager): Improve handling of integer and character list inputs (#1205)
9cce5991 is described below

commit 9cce5991a233141f6012b35c8623f59cbfbcacc2
Author: Dewey Dunnington <[email protected]>
AuthorDate: Tue Oct 17 09:23:07 2023 -0300

    fix(r/adbcdrivermanager): Improve handling of integer and character list inputs (#1205)
    
    Closes #1127. This also introduces better error messages and fixes some
    inputs that might have segfaulted (`NA_character_` input as a table
    type):
    
    ``` r
    library(adbcdrivermanager)
    con <- adbc_driver_void() |>
      adbc_database_init() |>
      adbc_connection_init()
    
    adbc_connection_get_objects(con, table_name = 5L)
    #> Error in adbc_connection_get_objects(con, table_name = 5L): Expected character(1) for conversion to const char*
    adbc_connection_get_objects(con, table_name = NA_character_)
    #> Error in adbc_connection_get_objects(con, table_name = NA_character_): Can't convert NA_character_ to const char*
    adbc_connection_get_objects(con, NA_integer_)
    #> Error in adbc_connection_get_objects(con, NA_integer_): Can't convert NA_integer_ to int
    adbc_connection_get_objects(con, NA_real_)
    #> Error in adbc_connection_get_objects(con, NA_real_): Can't convert NA_real_ to int
    ```
    
    <sup>Created on 2023-10-16 with [reprex
    v2.0.2](https://reprex.tidyverse.org)</sup>
---
 r/adbcdrivermanager/R/adbc.R                       |  4 +-
 .../man/adbc_connection_get_info.Rd                |  2 +-
 r/adbcdrivermanager/src/radbc.cc                   | 31 +++-----
 r/adbcdrivermanager/src/radbc.h                    | 89 +++++++++++++++++++++-
 r/adbcdrivermanager/tests/testthat/test-radbc.R    | 61 ++++++++++++++-
 5 files changed, 157 insertions(+), 30 deletions(-)

diff --git a/r/adbcdrivermanager/R/adbc.R b/r/adbcdrivermanager/R/adbc.R
index 53756f41..85f9676b 100644
--- a/r/adbcdrivermanager/R/adbc.R
+++ b/r/adbcdrivermanager/R/adbc.R
@@ -210,13 +210,13 @@ adbc_connection_release <- function(connection) {
 #' # (not implemented by the void driver)
 #' try(adbc_connection_get_info(con, 0))
 #'
-adbc_connection_get_info <- function(connection, info_codes) {
+adbc_connection_get_info <- function(connection, info_codes = NULL) {
   error <- adbc_allocate_error()
   out_stream <- nanoarrow::nanoarrow_allocate_array_stream()
   status <- .Call(
     RAdbcConnectionGetInfo,
     connection,
-    as.integer(info_codes),
+    info_codes,
     out_stream,
     error
   )
diff --git a/r/adbcdrivermanager/man/adbc_connection_get_info.Rd b/r/adbcdrivermanager/man/adbc_connection_get_info.Rd
index 12f5afa5..92acb78f 100644
--- a/r/adbcdrivermanager/man/adbc_connection_get_info.Rd
+++ b/r/adbcdrivermanager/man/adbc_connection_get_info.Rd
@@ -12,7 +12,7 @@
 \alias{adbc_connection_rollback}
 \title{Connection methods}
 \usage{
-adbc_connection_get_info(connection, info_codes)
+adbc_connection_get_info(connection, info_codes = NULL)
 
 adbc_connection_get_objects(
   connection,
diff --git a/r/adbcdrivermanager/src/radbc.cc b/r/adbcdrivermanager/src/radbc.cc
index f5afdd9d..c10f05c5 100644
--- a/r/adbcdrivermanager/src/radbc.cc
+++ b/r/adbcdrivermanager/src/radbc.cc
@@ -20,6 +20,7 @@
 #include <Rinternals.h>
 
 #include <string.h>
+#include <utility>
 
 #include <adbc.h>
 #include "adbc_driver_manager.h"
@@ -280,10 +281,13 @@ extern "C" SEXP RAdbcConnectionGetInfo(SEXP connection_xptr, SEXP info_codes_sex
   auto connection = adbc_from_xptr<AdbcConnection>(connection_xptr);
   auto error = adbc_from_xptr<AdbcError>(error_xptr);
   auto out_stream = adbc_from_xptr<ArrowArrayStream>(out_stream_xptr);
-  auto info_codes = reinterpret_cast<uint32_t*>(INTEGER(info_codes_sexp));
+  std::pair<SEXP, int*> info_codes = adbc_as_int_list(info_codes_sexp);
+  PROTECT(info_codes.first);
   size_t info_codes_length = Rf_xlength(info_codes_sexp);
   int status =
-      AdbcConnectionGetInfo(connection, info_codes, info_codes_length, 
out_stream, error);
+      AdbcConnectionGetInfo(connection, 
reinterpret_cast<uint32_t*>(info_codes.second),
+                            info_codes_length, out_stream, error);
+  UNPROTECT(1);
   return adbc_wrap_status(status);
 }
 
@@ -297,25 +301,8 @@ extern "C" SEXP RAdbcConnectionGetObjects(SEXP connection_xptr, SEXP depth_sexp,
   const char* catalog = adbc_as_const_char(catalog_sexp, true);
   const char* db_schema = adbc_as_const_char(db_schema_sexp, true);
   const char* table_name = adbc_as_const_char(table_name_sexp, true);
-
-  // Build the null-terminated const char** used to filter by table type
-  int table_type_length = Rf_length(table_type_sexp);
-  SEXP table_type_shelter =
-      PROTECT(Rf_allocVector(RAWSXP, (table_type_length + 1) * sizeof(const 
char*)));
-  auto table_type = reinterpret_cast<const char**>(RAW(table_type_shelter));
-  for (int i = 0; i < table_type_length; i++) {
-    table_type[i] = Rf_translateCharUTF8(STRING_ELT(table_type_sexp, i));
-  }
-  table_type[table_type_length] = nullptr;
-
-  // Ensure that R_NilValue maps to null and not a null-termianted const char**
-  // of length 0.
-  const char** table_type_maybe_null;
-  if (table_type_sexp == R_NilValue) {
-    table_type_maybe_null = nullptr;
-  } else {
-    table_type_maybe_null = table_type;
-  }
+  std::pair<SEXP, const char**> table_type = 
adbc_as_const_char_list(table_type_sexp);
+  PROTECT(table_type.first);
 
   const char* column_name = adbc_as_const_char(column_name_sexp, true);
   auto out_stream = adbc_from_xptr<ArrowArrayStream>(out_stream_xptr);
@@ -323,7 +310,7 @@ extern "C" SEXP RAdbcConnectionGetObjects(SEXP connection_xptr, SEXP depth_sexp,
 
   int status =
       AdbcConnectionGetObjects(connection, depth, catalog, db_schema, 
table_name,
-                               table_type_maybe_null, column_name, out_stream, 
error);
+                               table_type.second, column_name, out_stream, 
error);
   UNPROTECT(1);
   return adbc_wrap_status(status);
 }
diff --git a/r/adbcdrivermanager/src/radbc.h b/r/adbcdrivermanager/src/radbc.h
index 9c20686d..fa9fb5ff 100644
--- a/r/adbcdrivermanager/src/radbc.h
+++ b/r/adbcdrivermanager/src/radbc.h
@@ -20,6 +20,8 @@
 #include <R.h>
 #include <Rinternals.h>
 
+#include <utility>
+
 template <typename T>
 static inline const char* adbc_xptr_class();
 
@@ -151,16 +153,95 @@ static inline const char* adbc_as_const_char(SEXP sexp, bool nullable = false) {
 static inline int adbc_as_int(SEXP sexp) {
   if (Rf_length(sexp) == 1) {
     switch (TYPEOF(sexp)) {
-      case REALSXP:
-        return REAL(sexp)[0];
-      case INTSXP:
-        return INTEGER(sexp)[0];
+      case REALSXP: {
+        double value = REAL(sexp)[0];
+        // ISNAN() also catches NA_real_; Inf/out-of-range casts to int would be UB
+        if (ISNAN(value)) Rf_error("Can't convert NA_real_ to int");
+        if (value <= INT_MIN || value > INT_MAX) {  // INT_MIN is NA_integer_
+          Rf_error("Can't convert out-of-range double to int");
+        }
+        return static_cast<int>(value);
+      }
+
+      case INTSXP: {
+        int value = INTEGER(sexp)[0];
+        if (value == NA_INTEGER) {
+          Rf_error("Can't convert NA_integer_ to int");
+        }
+        return value;
+      }
     }
   }
 
   Rf_error("Expected integer(1) or double(1) for conversion to int");
 }
 
+// Convert NULL or a character vector to a nullptr-terminated const char**.
+// NULL maps to nullptr; NA elements raise an R error. The returned SEXP owns the
+// pointer array and is unprotected on return: the caller must PROTECT it.
+static inline std::pair<SEXP, const char**> adbc_as_const_char_list(SEXP sexp) {
+  switch (TYPEOF(sexp)) {
+    case NILSXP:
+      return {R_NilValue, nullptr};
+    case STRSXP:
+      break;
+    default:
+      Rf_error("Expected character() for conversion to const char**");
+  }
+
+  int sexp_length = Rf_length(sexp);
+  SEXP result_shelter =
+      PROTECT(Rf_allocVector(RAWSXP, (sexp_length + 1) * sizeof(const char*)));
+  auto result = reinterpret_cast<const char**>(RAW(result_shelter));
+  for (int i = 0; i < sexp_length; i++) {
+    SEXP item = STRING_ELT(sexp, i);
+    if (item == NA_STRING) Rf_error("Can't convert NA_character_ element to const char*");
+    result[i] = Rf_translateCharUTF8(item);
+  }
+  result[sexp_length] = nullptr;
+  UNPROTECT(1);
+  return {result_shelter, result};
+}
+
+// Convert NULL, integer(), or double() to int* (NULL -> nullptr; doubles are copied
+// into a new INTSXP). The returned SEXP is unprotected: the caller must PROTECT it.
+static inline std::pair<SEXP, int*> adbc_as_int_list(SEXP sexp) {
+  int result_length = Rf_length(sexp);
+  switch (TYPEOF(sexp)) {
+    case NILSXP:
+      return {R_NilValue, nullptr};
+
+    case INTSXP: {
+      int* result = INTEGER(sexp);
+      for (int i = 0; i < result_length; i++) {
+        if (result[i] == NA_INTEGER) {
+          Rf_error("Can't convert NA_integer_ element to int");
+        }
+      }
+      return {sexp, result};
+    }
+
+    case REALSXP: {
+      SEXP result_shelter = PROTECT(Rf_allocVector(INTSXP, result_length));
+      int* result = INTEGER(result_shelter);
+      for (int i = 0; i < result_length; i++) {
+        double item = REAL(sexp)[i];
+        // ISNAN() also catches NA_real_; Inf/out-of-range casts to int would be UB
+        if (ISNAN(item)) Rf_error("Can't convert NA_real_ or NaN element to int");
+        if (item <= INT_MIN || item > INT_MAX) {  // INT_MIN is NA_integer_
+          Rf_error("Can't convert out-of-range double element to int");
+        }
+        result[i] = static_cast<int>(item);
+      }
+      UNPROTECT(1);
+      return {result_shelter, result};
+    }
+
+    default:
+      Rf_error("Expected integer() or double() for conversion to int*");
+  }
+}
+
 static inline SEXP adbc_wrap_status(AdbcStatusCode code) {
   return Rf_ScalarInteger(code);
 }
diff --git a/r/adbcdrivermanager/tests/testthat/test-radbc.R b/r/adbcdrivermanager/tests/testthat/test-radbc.R
index 1ee83660..3d6e6522 100644
--- a/r/adbcdrivermanager/tests/testthat/test-radbc.R
+++ b/r/adbcdrivermanager/tests/testthat/test-radbc.R
@@ -39,6 +39,23 @@ test_that("connection methods work for the void driver", {
     "NOT_IMPLEMENTED"
   )
 
+  expect_error(
+    adbc_connection_get_info(con, double()),
+    "NOT_IMPLEMENTED"
+  )
+
+  expect_error(
+    adbc_connection_get_info(con, NULL),
+    "NOT_IMPLEMENTED"
+  )
+
+  # With defaults of NULL/0L
+  expect_error(
+    adbc_connection_get_objects(con),
+    "NOT_IMPLEMENTED"
+  )
+
+  # With explicit args
   expect_error(
     adbc_connection_get_objects(
       con, 0,
@@ -155,7 +172,7 @@ test_that("invalid parameter types generate errors", {
 
   expect_error(
     adbc_connection_get_objects(
-      con, NULL,
+      con, character(),
       "catalog", "db_schema",
       "table_name", "table_type", "column_name"
     ),
@@ -163,6 +180,33 @@ test_that("invalid parameter types generate errors", {
     fixed = TRUE
   )
 
+  expect_error(
+    adbc_connection_get_objects(
+      con, NA_integer_,
+      "catalog", "db_schema",
+      "table_name", "table_type", "column_name"
+    ),
+    "Can't convert NA_integer_"
+  )
+
+  expect_error(
+    adbc_connection_get_objects(
+      con, NA_real_,
+      "catalog", "db_schema",
+      "table_name", "table_type", "column_name"
+    ),
+    "Can't convert NA_real_"
+  )
+
+  expect_error(
+    adbc_connection_get_objects(
+      con, 0L,
+      "catalog", "db_schema",
+      "table_name", c("table_type1", NA_character_), "column_name"
+    ),
+    "Can't convert NA_character_ element"
+  )
+
   expect_error(
     adbc_statement_set_sql_query(stmt, NULL),
     "Expected character(1)",
@@ -174,6 +218,21 @@ test_that("invalid parameter types generate errors", {
     "Can't convert NA_character_"
   )
 
+  expect_error(
+    adbc_connection_get_info(con, NA_integer_),
+    "Can't convert NA_integer_ element"
+  )
+
+  expect_error(
+    adbc_connection_get_info(con, NA_real_),
+    "Can't convert NA_real_ or NaN element"
+  )
+
+  expect_error(
+    adbc_connection_get_info(con, NaN),
+    "Can't convert NA_real_ or NaN element"
+  )
+
   # (makes a NULL xptr)
   stmt2 <- unserialize(serialize(stmt, NULL))
   expect_error(

Reply via email to