This is an automated email from the ASF dual-hosted git repository.

lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new b965f64  fix(c/driver/postgresql): properly handle NULLs (#626)
b965f64 is described below

commit b965f64886026ed17f5e60f96b819af9472fe484
Author: David Li <[email protected]>
AuthorDate: Mon May 1 10:27:19 2023 +0900

    fix(c/driver/postgresql): properly handle NULLs (#626)
    
    The driver did not handle NULL values at all. Fix this, and replace the
    homegrown array appender with nanoarrow's ArrowArrayAppend* functions,
    which are simpler and less error-prone.
    
    Fixes #557.
---
 c/driver/postgresql/postgresql_test.cc |   3 +-
 c/driver/postgresql/statement.cc       | 183 ++++++++++++++-------------------
 c/driver/postgresql/statement.h        |   2 +
 c/validation/adbc_validation.cc        |  42 +++++++-
 c/validation/adbc_validation.h         |   2 +
 5 files changed, 127 insertions(+), 105 deletions(-)

diff --git a/c/driver/postgresql/postgresql_test.cc 
b/c/driver/postgresql/postgresql_test.cc
index 26d6f6d..a395582 100644
--- a/c/driver/postgresql/postgresql_test.cc
+++ b/c/driver/postgresql/postgresql_test.cc
@@ -116,7 +116,6 @@ class PostgresStatementTest : public ::testing::Test,
   void TestSqlIngestUInt64() { GTEST_SKIP() << "Not implemented"; }
   void TestSqlIngestFloat32() { GTEST_SKIP() << "Not implemented"; }
   void TestSqlIngestFloat64() { GTEST_SKIP() << "Not implemented"; }
-  void TestSqlIngestString() { GTEST_SKIP() << "TODO(apache/arrow-adbc#557)"; }
   void TestSqlIngestBinary() { GTEST_SKIP() << "Not implemented"; }
 
   void TestSqlPrepareErrorParamCountMismatch() { GTEST_SKIP() << "Not yet 
implemented"; }
@@ -148,6 +147,8 @@ struct TypeTestCase {
   }
 };
 
+void PrintTo(const TypeTestCase& value, std::ostream* os) { (*os) << 
value.name; }
+
 class PostgresTypeTest : public ::testing::TestWithParam<TypeTestCase> {
  public:
   void SetUp() override {
diff --git a/c/driver/postgresql/statement.cc b/c/driver/postgresql/statement.cc
index c88be68..c0d3de3 100644
--- a/c/driver/postgresql/statement.cc
+++ b/c/driver/postgresql/statement.cc
@@ -41,6 +41,8 @@ constexpr std::array<char, 11> kPgCopyBinarySignature = {
     'P', 'G', 'C', 'O', 'P', 'Y', '\n', '\377', '\r', '\n', '\0'};
 /// The flag indicating to PostgreSQL that we want binary-format values.
 constexpr int kPgBinaryFormat = 1;
+// A negative field length indicates a null.
+constexpr int32_t kNullFieldLength = -1;
 
 /// One-value ArrowArrayStream used to unify the implementations of Bind
 struct OneValueStream {
@@ -370,10 +372,6 @@ int TupleReader::GetNext(struct ArrowArray* out) {
       if (out->release) out->release(out);
       return na_res;
     }
-
-    struct ArrowBitmap validity_bitmap;
-    ArrowBitmapInit(&validity_bitmap);
-    ArrowArraySetValidityBitmap(out->children[col], &validity_bitmap);
   }
 
   // TODO: we need to always PQgetResult
@@ -430,11 +428,11 @@ int TupleReader::GetNext(struct ArrowArray* out) {
     out->children[col]->length = num_rows;
   }
   out->length = num_rows;
-  na_res = ArrowArrayFinishBuildingDefault(out, 0);
+  na_res = ArrowArrayFinishBuildingDefault(out, &error);
   if (na_res != 0) {
     result_code = na_res;
     if (!last_error_.empty()) last_error_ += '\n';
-    last_error_ += StringBuilder("[libpq] Failed to build result array");
+    last_error_ += StringBuilder("[libpq] Failed to build result array: ", 
error.message);
   }
 
   // Check the server-side response
@@ -496,109 +494,88 @@ int TupleReader::AppendNext(struct ArrowSchemaView* 
fields, const char* buf, int
     int32_t field_length = LoadNetworkInt32(buf);
     buf += sizeof(int32_t);
 
-    struct ArrowBitmap* bitmap = ArrowArrayValidityBitmap(out->children[col]);
-
     // TODO: set error message here
-    CHECK_NA(ArrowBitmapAppend(bitmap, field_length >= 0, 1));
-
-    switch (fields[col].type) {
-      case NANOARROW_TYPE_BOOL: {
-        // DCHECK_EQ(field_length, 1);
-        struct ArrowBuffer* buffer = ArrowArrayBuffer(out->children[col], 1);
-        uint8_t raw_value = buf[0];
-        buf += 1;
-
-        if (raw_value != 0 && raw_value != 1) {
-          last_error_ = StringBuilder("[libpq] Column #", col + 1, " (\"",
-                                      schema_.children[col]->name,
-                                      "\"): invalid BOOL value ", raw_value);
-          return EINVAL;
-        }
+    if (field_length != kNullFieldLength) {
+      CHECK_NA(AppendValue(fields, buf, col, *row_count, field_length, out));
+      buf += field_length;
+    } else {
+      CHECK_NA(ArrowArrayAppendNull(out->children[col], 1));
+    }
+  }
+  (*row_count)++;
+  return 0;
+}
 
-        int64_t bytes_required = _ArrowRoundUpToMultipleOf8(*row_count + 1) / 
8;
-        if (bytes_required > buffer->size_bytes) {
-          CHECK_NA(ArrowBufferAppendFill(buffer, 0, bytes_required - 
buffer->size_bytes));
-        }
-        ArrowBitsSetTo(buffer->data, *row_count, 1, raw_value);
-        break;
-      }
-      case NANOARROW_TYPE_DOUBLE: {
-        // DCHECK_EQ(field_length, 8);
-        static_assert(sizeof(double) == sizeof(uint64_t),
-                      "Float is not same size as uint64_t");
-        struct ArrowBuffer* buffer = ArrowArrayBuffer(out->children[col], 1);
-        uint64_t raw_value = LoadNetworkUInt64(buf);
-        buf += sizeof(uint64_t);
-        double value = 0.0;
-        std::memcpy(&value, &raw_value, sizeof(double));
-        CHECK_NA(ArrowBufferAppendDouble(buffer, value));
-        break;
-      }
-      case NANOARROW_TYPE_FLOAT: {
-        // DCHECK_EQ(field_length, 4);
-        static_assert(sizeof(float) == sizeof(uint32_t),
-                      "Float is not same size as uint32_t");
-        struct ArrowBuffer* buffer = ArrowArrayBuffer(out->children[col], 1);
-        uint32_t raw_value = LoadNetworkUInt32(buf);
-        buf += sizeof(uint32_t);
-        float value = 0.0;
-        std::memcpy(&value, &raw_value, sizeof(float));
-        CHECK_NA(ArrowBufferAppendFloat(buffer, value));
-        break;
-      }
-      case NANOARROW_TYPE_INT16: {
-        // DCHECK_EQ(field_length, 2);
-        struct ArrowBuffer* buffer = ArrowArrayBuffer(out->children[col], 1);
-        int32_t value = LoadNetworkInt16(buf);
-        buf += sizeof(int32_t);
-        CHECK_NA(ArrowBufferAppendInt16(buffer, value));
-        break;
-      }
-      case NANOARROW_TYPE_INT32: {
-        // DCHECK_EQ(field_length, 4);
-        struct ArrowBuffer* buffer = ArrowArrayBuffer(out->children[col], 1);
-        int32_t value = LoadNetworkInt32(buf);
-        buf += sizeof(int32_t);
-        CHECK_NA(ArrowBufferAppendInt32(buffer, value));
-        break;
-      }
-      case NANOARROW_TYPE_INT64: {
-        // DCHECK_EQ(field_length, 8);
-        struct ArrowBuffer* buffer = ArrowArrayBuffer(out->children[col], 1);
-        int64_t value = field_length < 0 ? 0 : LoadNetworkInt64(buf);
-        buf += sizeof(int64_t);
-        CHECK_NA(ArrowBufferAppendInt64(buffer, value));
-        break;
-      }
-      case NANOARROW_TYPE_BINARY: {
-        struct ArrowBuffer* offset = ArrowArrayBuffer(out->children[col], 1);
-        struct ArrowBuffer* data = ArrowArrayBuffer(out->children[col], 2);
-        const int32_t last_offset =
-            reinterpret_cast<const int32_t*>(offset->data)[*row_count];
-        CHECK_NA(ArrowBufferAppendInt32(offset, last_offset + field_length));
-        CHECK_NA(ArrowBufferAppend(data, buf, field_length));
-        buf += field_length;
-        break;
-      }
-      case NANOARROW_TYPE_STRING: {
-        // textsend() in varlena.c
-        struct ArrowBuffer* offset = ArrowArrayBuffer(out->children[col], 1);
-        struct ArrowBuffer* data = ArrowArrayBuffer(out->children[col], 2);
-        const int32_t last_offset =
-            reinterpret_cast<const int32_t*>(offset->data)[*row_count];
-        CHECK_NA(ArrowBufferAppendInt32(offset, last_offset + field_length));
-        CHECK_NA(ArrowBufferAppend(data, buf, field_length));
-        buf += field_length;
-        break;
-      }
-      default:
+int TupleReader::AppendValue(struct ArrowSchemaView* fields, const char* buf, 
int col,
+                             int64_t row_count, int32_t field_length,
+                             struct ArrowArray* out) {
+  switch (fields[col].type) {
+    case NANOARROW_TYPE_BOOL: {
+      uint8_t raw_value = buf[0];
+      buf += 1;
+
+      if (raw_value != 0 && raw_value != 1) {
         last_error_ = StringBuilder("[libpq] Column #", col + 1, " (\"",
                                     schema_.children[col]->name,
-                                    "\") has unsupported type ", 
fields[col].type);
-        return ENOTSUP;
+                                    "\"): invalid BOOL value ", raw_value);
+        return EINVAL;
+      }
+      CHECK_NA(ArrowArrayAppendInt(out->children[col], raw_value));
+      break;
+    }
+    case NANOARROW_TYPE_DOUBLE: {
+      static_assert(sizeof(double) == sizeof(uint64_t),
+                    "Float is not same size as uint64_t");
+      uint64_t raw_value = LoadNetworkUInt64(buf);
+      buf += sizeof(uint64_t);
+      double value = 0.0;
+      std::memcpy(&value, &raw_value, sizeof(double));
+      CHECK_NA(ArrowArrayAppendDouble(out->children[col], value));
+      break;
+    }
+    case NANOARROW_TYPE_FLOAT: {
+      static_assert(sizeof(float) == sizeof(uint32_t),
+                    "Float is not same size as uint32_t");
+      uint32_t raw_value = LoadNetworkUInt32(buf);
+      buf += sizeof(uint32_t);
+      float value = 0.0;
+      std::memcpy(&value, &raw_value, sizeof(float));
+      CHECK_NA(ArrowArrayAppendDouble(out->children[col], value));
+      break;
     }
+    case NANOARROW_TYPE_INT16: {
+      int32_t value = LoadNetworkInt16(buf);
+      buf += sizeof(int32_t);
+      CHECK_NA(ArrowArrayAppendInt(out->children[col], value));
+      break;
+    }
+    case NANOARROW_TYPE_INT32: {
+      int32_t value = LoadNetworkInt32(buf);
+      buf += sizeof(int32_t);
+      CHECK_NA(ArrowArrayAppendInt(out->children[col], value));
+      break;
+    }
+    case NANOARROW_TYPE_INT64: {
+      int64_t value = LoadNetworkInt64(buf);
+      buf += sizeof(int64_t);
+      CHECK_NA(ArrowArrayAppendInt(out->children[col], value));
+      break;
+    }
+    case NANOARROW_TYPE_BINARY: {
+      CHECK_NA(ArrowArrayAppendBytes(out->children[col], {buf, field_length}));
+      break;
+    }
+    case NANOARROW_TYPE_STRING: {
+      // textsend() in varlena.c
+      CHECK_NA(ArrowArrayAppendString(out->children[col], {buf, 
field_length}));
+      break;
+    }
+    default:
+      last_error_ =
+          StringBuilder("[libpq] Column #", col + 1, " (\"", 
schema_.children[col]->name,
+                        "\") has unsupported type ", fields[col].type);
+      return ENOTSUP;
   }
-  (*row_count)++;
   return 0;
 }
 
diff --git a/c/driver/postgresql/statement.h b/c/driver/postgresql/statement.h
index 1ad4d82..5dfe04f 100644
--- a/c/driver/postgresql/statement.h
+++ b/c/driver/postgresql/statement.h
@@ -46,6 +46,8 @@ class TupleReader final {
 
   int AppendNext(struct ArrowSchemaView* fields, const char* buf, int buf_size,
                  int64_t* row_count, struct ArrowArray* out);
+  int AppendValue(struct ArrowSchemaView* fields, const char* buf, int col,
+                  int64_t row_count, int32_t field_length, struct ArrowArray* 
out);
   void ExportTo(struct ArrowArrayStream* stream);
 
  private:
diff --git a/c/validation/adbc_validation.cc b/c/validation/adbc_validation.cc
index 0dc17f5..886ed48 100644
--- a/c/validation/adbc_validation.cc
+++ b/c/validation/adbc_validation.cc
@@ -1007,7 +1007,7 @@ void StatementTest::TestSqlIngestString() {
 
 void StatementTest::TestSqlIngestBinary() {
   ASSERT_NO_FATAL_FAILURE(TestSqlIngestType<std::string>(
-      NANOARROW_TYPE_BINARY, {std::nullopt, "", "\x00\x01\x02\x04", "", 
"\xFE\xFF"}));
+      NANOARROW_TYPE_BINARY, {std::nullopt, "", "\x00\x01\x02\x04", 
"\xFE\xFF"}));
 }
 
 void StatementTest::TestSqlIngestAppend() {
@@ -1239,6 +1239,46 @@ void StatementTest::TestSqlIngestMultipleConnections() {
   }
 }
 
+void StatementTest::TestSqlIngestSample() {
+  if (!quirks()->supports_bulk_ingest()) {
+    GTEST_SKIP();
+  }
+
+  ASSERT_THAT(quirks()->EnsureSampleTable(&connection, "bulk_ingest", &error),
+              IsOkStatus(&error));
+
+  ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), 
IsOkStatus(&error));
+  ASSERT_THAT(
+      AdbcStatementSetSqlQuery(
+          &statement, "SELECT * FROM bulk_ingest ORDER BY \"int64s\" ASC NULLS 
FIRST",
+          &error),
+      IsOkStatus(&error));
+  StreamReader reader;
+  ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value,
+                                        &reader.rows_affected, &error),
+              IsOkStatus(&error));
+  ASSERT_THAT(reader.rows_affected,
+              ::testing::AnyOf(::testing::Eq(3), ::testing::Eq(-1)));
+
+  ASSERT_NO_FATAL_FAILURE(reader.GetSchema());
+  ASSERT_NO_FATAL_FAILURE(CompareSchema(&reader.schema.value,
+                                        {{"int64s", NANOARROW_TYPE_INT64, 
NULLABLE},
+                                         {"strings", NANOARROW_TYPE_STRING, 
NULLABLE}}));
+
+  ASSERT_NO_FATAL_FAILURE(reader.Next());
+  ASSERT_NE(nullptr, reader.array->release);
+  ASSERT_EQ(3, reader.array->length);
+  ASSERT_EQ(2, reader.array->n_children);
+
+  ASSERT_NO_FATAL_FAILURE(
+      CompareArray<int64_t>(reader.array_view->children[0], {std::nullopt, 
-42, 42}));
+  
ASSERT_NO_FATAL_FAILURE(CompareArray<std::string>(reader.array_view->children[1],
+                                                    {"", std::nullopt, 
"foo"}));
+
+  ASSERT_NO_FATAL_FAILURE(reader.Next());
+  ASSERT_EQ(nullptr, reader.array->release);
+}
+
 void StatementTest::TestSqlPartitionedInts() {
   ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), 
IsOkStatus(&error));
   ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, "SELECT 42", &error),
diff --git a/c/validation/adbc_validation.h b/c/validation/adbc_validation.h
index 400e8f1..7f3f175 100644
--- a/c/validation/adbc_validation.h
+++ b/c/validation/adbc_validation.h
@@ -205,6 +205,7 @@ class StatementTest {
   void TestSqlIngestAppend();
   void TestSqlIngestErrors();
   void TestSqlIngestMultipleConnections();
+  void TestSqlIngestSample();
 
   void TestSqlPartitionedInts();
 
@@ -261,6 +262,7 @@ class StatementTest {
   TEST_F(FIXTURE, SqlIngestAppend) { TestSqlIngestAppend(); }                  
         \
   TEST_F(FIXTURE, SqlIngestErrors) { TestSqlIngestErrors(); }                  
         \
   TEST_F(FIXTURE, SqlIngestMultipleConnections) { 
TestSqlIngestMultipleConnections(); } \
+  TEST_F(FIXTURE, SqlIngestSample) { TestSqlIngestSample(); }                  
         \
   TEST_F(FIXTURE, SqlPartitionedInts) { TestSqlPartitionedInts(); }            
         \
   TEST_F(FIXTURE, SqlPrepareGetParameterSchema) { 
TestSqlPrepareGetParameterSchema(); } \
   TEST_F(FIXTURE, SqlPrepareSelectNoParams) { TestSqlPrepareSelectNoParams(); 
}         \

Reply via email to