This is an automated email from the ASF dual-hosted git repository.
lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git
The following commit(s) were added to refs/heads/main by this push:
new 4e0c2525 feat(c/driver/postgresql,c/driver/sqlite): implement BOOL support in drivers (#1091)
4e0c2525 is described below
commit 4e0c25252923f5f4d827dc1cbfd9f84f0107e63a
Author: William Ayd <[email protected]>
AuthorDate: Thu Sep 21 09:22:41 2023 -0400
feat(c/driver/postgresql,c/driver/sqlite): implement BOOL support in drivers (#1091)
---
c/driver/postgresql/statement.cc | 14 ++++++++++++++
c/driver/sqlite/sqlite.c | 1 +
c/driver/sqlite/sqlite_test.cc | 1 +
c/driver/sqlite/statement_reader.c | 1 +
c/driver_manager/adbc_driver_manager_test.cc | 1 +
c/validation/adbc_validation.cc | 4 ++++
c/validation/adbc_validation.h | 3 +++
c/validation/adbc_validation_util.h | 8 ++++++--
8 files changed, 31 insertions(+), 2 deletions(-)
diff --git a/c/driver/postgresql/statement.cc b/c/driver/postgresql/statement.cc
index 3910378a..1fa03116 100644
--- a/c/driver/postgresql/statement.cc
+++ b/c/driver/postgresql/statement.cc
@@ -194,6 +194,10 @@ struct BindStream {
for (size_t i = 0; i < bind_schema_fields.size(); i++) {
PostgresTypeId type_id;
switch (bind_schema_fields[i].type) {
+ case ArrowType::NANOARROW_TYPE_BOOL:
+ type_id = PostgresTypeId::kBool;
+ param_lengths[i] = 1;
+ break;
case ArrowType::NANOARROW_TYPE_INT8:
case ArrowType::NANOARROW_TYPE_INT16:
type_id = PostgresTypeId::kInt2;
@@ -358,6 +362,13 @@ struct BindStream {
param_values[col] = param_values_buffer.data() +
param_values_offsets[col];
}
switch (bind_schema_fields[col].type) {
+ case ArrowType::NANOARROW_TYPE_BOOL: {
+ const int8_t val = ArrowBitGet(
+ array_view->children[col]->buffer_views[1].data.as_uint8,
row);
+ std::memcpy(param_values[col], &val, sizeof(int8_t));
+ break;
+ }
+
case ArrowType::NANOARROW_TYPE_INT8: {
const int16_t val =
array_view->children[col]->buffer_views[1].data.as_int8[row];
@@ -934,6 +945,9 @@ AdbcStatusCode PostgresStatement::CreateBulkTable(
PQfreemem(escaped);
switch (source_schema_fields[i].type) {
+ case ArrowType::NANOARROW_TYPE_BOOL:
+ create += " BOOLEAN";
+ break;
case ArrowType::NANOARROW_TYPE_INT8:
case ArrowType::NANOARROW_TYPE_INT16:
create += " SMALLINT";
diff --git a/c/driver/sqlite/sqlite.c b/c/driver/sqlite/sqlite.c
index 83cebec0..b47336fe 100644
--- a/c/driver/sqlite/sqlite.c
+++ b/c/driver/sqlite/sqlite.c
@@ -1170,6 +1170,7 @@ AdbcStatusCode SqliteStatementInitIngest(struct
SqliteStatement* stmt,
}
switch (view.type) {
+ case NANOARROW_TYPE_BOOL:
case NANOARROW_TYPE_UINT8:
case NANOARROW_TYPE_UINT16:
case NANOARROW_TYPE_UINT32:
diff --git a/c/driver/sqlite/sqlite_test.cc b/c/driver/sqlite/sqlite_test.cc
index c95a3f1d..617fcb01 100644
--- a/c/driver/sqlite/sqlite_test.cc
+++ b/c/driver/sqlite/sqlite_test.cc
@@ -69,6 +69,7 @@ class SqliteQuirks : public adbc_validation::DriverQuirks {
ArrowType IngestSelectRoundTripType(ArrowType ingest_type) const override {
switch (ingest_type) {
+ case NANOARROW_TYPE_BOOL:
case NANOARROW_TYPE_INT8:
case NANOARROW_TYPE_INT16:
case NANOARROW_TYPE_INT32:
diff --git a/c/driver/sqlite/statement_reader.c
b/c/driver/sqlite/statement_reader.c
index 08bd27d4..c609e1e4 100644
--- a/c/driver/sqlite/statement_reader.c
+++ b/c/driver/sqlite/statement_reader.c
@@ -312,6 +312,7 @@ AdbcStatusCode AdbcSqliteBinderBindNext(struct
AdbcSqliteBinder* binder, sqlite3
SQLITE_STATIC);
break;
}
+ case NANOARROW_TYPE_BOOL:
case NANOARROW_TYPE_UINT8:
case NANOARROW_TYPE_UINT16:
case NANOARROW_TYPE_UINT32:
diff --git a/c/driver_manager/adbc_driver_manager_test.cc
b/c/driver_manager/adbc_driver_manager_test.cc
index 100feab7..18e0a87d 100644
--- a/c/driver_manager/adbc_driver_manager_test.cc
+++ b/c/driver_manager/adbc_driver_manager_test.cc
@@ -177,6 +177,7 @@ class SqliteQuirks : public adbc_validation::DriverQuirks {
ArrowType IngestSelectRoundTripType(ArrowType ingest_type) const override {
switch (ingest_type) {
+ case NANOARROW_TYPE_BOOL:
case NANOARROW_TYPE_INT8:
case NANOARROW_TYPE_INT16:
case NANOARROW_TYPE_INT32:
diff --git a/c/validation/adbc_validation.cc b/c/validation/adbc_validation.cc
index d25f236b..cae3598d 100644
--- a/c/validation/adbc_validation.cc
+++ b/c/validation/adbc_validation.cc
@@ -1208,6 +1208,10 @@ void StatementTest::TestSqlIngestNumericType(ArrowType
type) {
return TestSqlIngestType(type, values);
}
+void StatementTest::TestSqlIngestBool() {
+ ASSERT_NO_FATAL_FAILURE(TestSqlIngestNumericType<bool>(NANOARROW_TYPE_BOOL));
+}
+
void StatementTest::TestSqlIngestUInt8() {
ASSERT_NO_FATAL_FAILURE(TestSqlIngestNumericType<uint8_t>(NANOARROW_TYPE_UINT8));
}
diff --git a/c/validation/adbc_validation.h b/c/validation/adbc_validation.h
index 0d936de7..d1c23a03 100644
--- a/c/validation/adbc_validation.h
+++ b/c/validation/adbc_validation.h
@@ -271,6 +271,8 @@ class StatementTest {
// ---- Type-specific tests --------------------
+ void TestSqlIngestBool();
+
// Integers
void TestSqlIngestInt8();
void TestSqlIngestInt16();
@@ -370,6 +372,7 @@ class StatementTest {
ADBCV_STRINGIFY(FIXTURE) " must inherit from StatementTest");
\
TEST_F(FIXTURE, NewInit) { TestNewInit(); }
\
TEST_F(FIXTURE, Release) { TestRelease(); }
\
+ TEST_F(FIXTURE, SqlIngestBool) { TestSqlIngestBool(); }
\
TEST_F(FIXTURE, SqlIngestInt8) { TestSqlIngestInt8(); }
\
TEST_F(FIXTURE, SqlIngestInt16) { TestSqlIngestInt16(); }
\
TEST_F(FIXTURE, SqlIngestInt32) { TestSqlIngestInt32(); }
\
diff --git a/c/validation/adbc_validation_util.h
b/c/validation/adbc_validation_util.h
index b6376593..59cc3887 100644
--- a/c/validation/adbc_validation_util.h
+++ b/c/validation/adbc_validation_util.h
@@ -239,8 +239,9 @@ int MakeArray(struct ArrowArray* parent, struct ArrowArray*
array,
const std::vector<std::optional<T>>& values) {
for (const auto& v : values) {
if (v.has_value()) {
- if constexpr (std::is_same<T, int8_t>::value || std::is_same<T,
int16_t>::value ||
- std::is_same<T, int32_t>::value || std::is_same<T,
int64_t>::value) {
+ if constexpr (std::is_same<T, bool>::value || std::is_same<T,
int8_t>::value ||
+ std::is_same<T, int16_t>::value || std::is_same<T,
int32_t>::value ||
+ std::is_same<T, int64_t>::value) {
if (int errno_res = ArrowArrayAppendInt(array, *v); errno_res != 0) {
return errno_res;
}
@@ -352,6 +353,9 @@ void CompareArray(struct ArrowArrayView* array,
} else if constexpr (std::is_same<T, float>::value) {
ASSERT_NE(array->buffer_views[1].data.data, nullptr);
ASSERT_EQ(*v, array->buffer_views[1].data.as_float[i]);
+ } else if constexpr (std::is_same<T, bool>::value) {
+ ASSERT_NE(array->buffer_views[1].data.data, nullptr);
+ ASSERT_EQ(*v, ArrowBitGet(array->buffer_views[1].data.as_uint8, i));
} else if constexpr (std::is_same<T, int8_t>::value) {
ASSERT_NE(array->buffer_views[1].data.data, nullptr);
ASSERT_EQ(*v, array->buffer_views[1].data.as_int8[i]);