This is an automated email from the ASF dual-hosted git repository.

lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new 4e0c2525 feat(c/driver/postgresql,c/driver/sqlite): implement BOOL support in drivers (#1091)
4e0c2525 is described below

commit 4e0c25252923f5f4d827dc1cbfd9f84f0107e63a
Author: William Ayd <[email protected]>
AuthorDate: Thu Sep 21 09:22:41 2023 -0400

    feat(c/driver/postgresql,c/driver/sqlite): implement BOOL support in drivers (#1091)
---
 c/driver/postgresql/statement.cc             | 14 ++++++++++++++
 c/driver/sqlite/sqlite.c                     |  1 +
 c/driver/sqlite/sqlite_test.cc               |  1 +
 c/driver/sqlite/statement_reader.c           |  1 +
 c/driver_manager/adbc_driver_manager_test.cc |  1 +
 c/validation/adbc_validation.cc              |  4 ++++
 c/validation/adbc_validation.h               |  3 +++
 c/validation/adbc_validation_util.h          |  8 ++++++--
 8 files changed, 31 insertions(+), 2 deletions(-)

diff --git a/c/driver/postgresql/statement.cc b/c/driver/postgresql/statement.cc
index 3910378a..1fa03116 100644
--- a/c/driver/postgresql/statement.cc
+++ b/c/driver/postgresql/statement.cc
@@ -194,6 +194,10 @@ struct BindStream {
     for (size_t i = 0; i < bind_schema_fields.size(); i++) {
       PostgresTypeId type_id;
       switch (bind_schema_fields[i].type) {
+        case ArrowType::NANOARROW_TYPE_BOOL:
+          type_id = PostgresTypeId::kBool;
+          param_lengths[i] = 1;
+          break;
         case ArrowType::NANOARROW_TYPE_INT8:
         case ArrowType::NANOARROW_TYPE_INT16:
           type_id = PostgresTypeId::kInt2;
@@ -358,6 +362,13 @@ struct BindStream {
             param_values[col] = param_values_buffer.data() + 
param_values_offsets[col];
           }
           switch (bind_schema_fields[col].type) {
+            case ArrowType::NANOARROW_TYPE_BOOL: {
+              const int8_t val = ArrowBitGet(
+                  array_view->children[col]->buffer_views[1].data.as_uint8, 
row);
+              std::memcpy(param_values[col], &val, sizeof(int8_t));
+              break;
+            }
+
             case ArrowType::NANOARROW_TYPE_INT8: {
               const int16_t val =
                   array_view->children[col]->buffer_views[1].data.as_int8[row];
@@ -934,6 +945,9 @@ AdbcStatusCode PostgresStatement::CreateBulkTable(
     PQfreemem(escaped);
 
     switch (source_schema_fields[i].type) {
+      case ArrowType::NANOARROW_TYPE_BOOL:
+        create += " BOOLEAN";
+        break;
       case ArrowType::NANOARROW_TYPE_INT8:
       case ArrowType::NANOARROW_TYPE_INT16:
         create += " SMALLINT";
diff --git a/c/driver/sqlite/sqlite.c b/c/driver/sqlite/sqlite.c
index 83cebec0..b47336fe 100644
--- a/c/driver/sqlite/sqlite.c
+++ b/c/driver/sqlite/sqlite.c
@@ -1170,6 +1170,7 @@ AdbcStatusCode SqliteStatementInitIngest(struct SqliteStatement* stmt,
     }
 
     switch (view.type) {
+      case NANOARROW_TYPE_BOOL:
       case NANOARROW_TYPE_UINT8:
       case NANOARROW_TYPE_UINT16:
       case NANOARROW_TYPE_UINT32:
diff --git a/c/driver/sqlite/sqlite_test.cc b/c/driver/sqlite/sqlite_test.cc
index c95a3f1d..617fcb01 100644
--- a/c/driver/sqlite/sqlite_test.cc
+++ b/c/driver/sqlite/sqlite_test.cc
@@ -69,6 +69,7 @@ class SqliteQuirks : public adbc_validation::DriverQuirks {
 
   ArrowType IngestSelectRoundTripType(ArrowType ingest_type) const override {
     switch (ingest_type) {
+      case NANOARROW_TYPE_BOOL:
       case NANOARROW_TYPE_INT8:
       case NANOARROW_TYPE_INT16:
       case NANOARROW_TYPE_INT32:
diff --git a/c/driver/sqlite/statement_reader.c b/c/driver/sqlite/statement_reader.c
index 08bd27d4..c609e1e4 100644
--- a/c/driver/sqlite/statement_reader.c
+++ b/c/driver/sqlite/statement_reader.c
@@ -312,6 +312,7 @@ AdbcStatusCode AdbcSqliteBinderBindNext(struct AdbcSqliteBinder* binder, sqlite3
                                      SQLITE_STATIC);
           break;
         }
+        case NANOARROW_TYPE_BOOL:
         case NANOARROW_TYPE_UINT8:
         case NANOARROW_TYPE_UINT16:
         case NANOARROW_TYPE_UINT32:
diff --git a/c/driver_manager/adbc_driver_manager_test.cc b/c/driver_manager/adbc_driver_manager_test.cc
index 100feab7..18e0a87d 100644
--- a/c/driver_manager/adbc_driver_manager_test.cc
+++ b/c/driver_manager/adbc_driver_manager_test.cc
@@ -177,6 +177,7 @@ class SqliteQuirks : public adbc_validation::DriverQuirks {
 
   ArrowType IngestSelectRoundTripType(ArrowType ingest_type) const override {
     switch (ingest_type) {
+      case NANOARROW_TYPE_BOOL:
       case NANOARROW_TYPE_INT8:
       case NANOARROW_TYPE_INT16:
       case NANOARROW_TYPE_INT32:
diff --git a/c/validation/adbc_validation.cc b/c/validation/adbc_validation.cc
index d25f236b..cae3598d 100644
--- a/c/validation/adbc_validation.cc
+++ b/c/validation/adbc_validation.cc
@@ -1208,6 +1208,10 @@ void StatementTest::TestSqlIngestNumericType(ArrowType type) {
   return TestSqlIngestType(type, values);
 }
 
+void StatementTest::TestSqlIngestBool() {
+  ASSERT_NO_FATAL_FAILURE(TestSqlIngestNumericType<bool>(NANOARROW_TYPE_BOOL));
+}
+
 void StatementTest::TestSqlIngestUInt8() {
   
ASSERT_NO_FATAL_FAILURE(TestSqlIngestNumericType<uint8_t>(NANOARROW_TYPE_UINT8));
 }
diff --git a/c/validation/adbc_validation.h b/c/validation/adbc_validation.h
index 0d936de7..d1c23a03 100644
--- a/c/validation/adbc_validation.h
+++ b/c/validation/adbc_validation.h
@@ -271,6 +271,8 @@ class StatementTest {
 
   // ---- Type-specific tests --------------------
 
+  void TestSqlIngestBool();
+
   // Integers
   void TestSqlIngestInt8();
   void TestSqlIngestInt16();
@@ -370,6 +372,7 @@ class StatementTest {
                 ADBCV_STRINGIFY(FIXTURE) " must inherit from StatementTest");  
         \
   TEST_F(FIXTURE, NewInit) { TestNewInit(); }                                  
         \
   TEST_F(FIXTURE, Release) { TestRelease(); }                                  
         \
+  TEST_F(FIXTURE, SqlIngestBool) { TestSqlIngestBool(); }                      
         \
   TEST_F(FIXTURE, SqlIngestInt8) { TestSqlIngestInt8(); }                      
         \
   TEST_F(FIXTURE, SqlIngestInt16) { TestSqlIngestInt16(); }                    
         \
   TEST_F(FIXTURE, SqlIngestInt32) { TestSqlIngestInt32(); }                    
         \
diff --git a/c/validation/adbc_validation_util.h b/c/validation/adbc_validation_util.h
index b6376593..59cc3887 100644
--- a/c/validation/adbc_validation_util.h
+++ b/c/validation/adbc_validation_util.h
@@ -239,8 +239,9 @@ int MakeArray(struct ArrowArray* parent, struct ArrowArray* array,
               const std::vector<std::optional<T>>& values) {
   for (const auto& v : values) {
     if (v.has_value()) {
-      if constexpr (std::is_same<T, int8_t>::value || std::is_same<T, 
int16_t>::value ||
-                    std::is_same<T, int32_t>::value || std::is_same<T, 
int64_t>::value) {
+      if constexpr (std::is_same<T, bool>::value || std::is_same<T, 
int8_t>::value ||
+                    std::is_same<T, int16_t>::value || std::is_same<T, 
int32_t>::value ||
+                    std::is_same<T, int64_t>::value) {
         if (int errno_res = ArrowArrayAppendInt(array, *v); errno_res != 0) {
           return errno_res;
         }
@@ -352,6 +353,9 @@ void CompareArray(struct ArrowArrayView* array,
       } else if constexpr (std::is_same<T, float>::value) {
         ASSERT_NE(array->buffer_views[1].data.data, nullptr);
         ASSERT_EQ(*v, array->buffer_views[1].data.as_float[i]);
+      } else if constexpr (std::is_same<T, bool>::value) {
+        ASSERT_NE(array->buffer_views[1].data.data, nullptr);
+        ASSERT_EQ(*v, ArrowBitGet(array->buffer_views[1].data.as_uint8, i));
       } else if constexpr (std::is_same<T, int8_t>::value) {
         ASSERT_NE(array->buffer_views[1].data.data, nullptr);
         ASSERT_EQ(*v, array->buffer_views[1].data.as_int8[i]);

Reply via email to