This is an automated email from the ASF dual-hosted git repository.
lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git
The following commit(s) were added to refs/heads/main by this push:
new b965f64 fix(c/driver/postgresql): properly handle NULLs (#626)
b965f64 is described below
commit b965f64886026ed17f5e60f96b819af9472fe484
Author: David Li <[email protected]>
AuthorDate: Mon May 1 10:27:19 2023 +0900
fix(c/driver/postgresql): properly handle NULLs (#626)
The driver didn't handle NULL values at all! Fix this, and also, instead
of a homegrown array appender, just use nanoarrow. Simpler and less
error-prone!
Fixes #557.
---
c/driver/postgresql/postgresql_test.cc | 3 +-
c/driver/postgresql/statement.cc | 183 ++++++++++++++-------------------
c/driver/postgresql/statement.h | 2 +
c/validation/adbc_validation.cc | 42 +++++++-
c/validation/adbc_validation.h | 2 +
5 files changed, 127 insertions(+), 105 deletions(-)
diff --git a/c/driver/postgresql/postgresql_test.cc
b/c/driver/postgresql/postgresql_test.cc
index 26d6f6d..a395582 100644
--- a/c/driver/postgresql/postgresql_test.cc
+++ b/c/driver/postgresql/postgresql_test.cc
@@ -116,7 +116,6 @@ class PostgresStatementTest : public ::testing::Test,
void TestSqlIngestUInt64() { GTEST_SKIP() << "Not implemented"; }
void TestSqlIngestFloat32() { GTEST_SKIP() << "Not implemented"; }
void TestSqlIngestFloat64() { GTEST_SKIP() << "Not implemented"; }
- void TestSqlIngestString() { GTEST_SKIP() << "TODO(apache/arrow-adbc#557)"; }
void TestSqlIngestBinary() { GTEST_SKIP() << "Not implemented"; }
void TestSqlPrepareErrorParamCountMismatch() { GTEST_SKIP() << "Not yet
implemented"; }
@@ -148,6 +147,8 @@ struct TypeTestCase {
}
};
+void PrintTo(const TypeTestCase& value, std::ostream* os) { (*os) <<
value.name; }
+
class PostgresTypeTest : public ::testing::TestWithParam<TypeTestCase> {
public:
void SetUp() override {
diff --git a/c/driver/postgresql/statement.cc b/c/driver/postgresql/statement.cc
index c88be68..c0d3de3 100644
--- a/c/driver/postgresql/statement.cc
+++ b/c/driver/postgresql/statement.cc
@@ -41,6 +41,8 @@ constexpr std::array<char, 11> kPgCopyBinarySignature = {
'P', 'G', 'C', 'O', 'P', 'Y', '\n', '\377', '\r', '\n', '\0'};
/// The flag indicating to PostgreSQL that we want binary-format values.
constexpr int kPgBinaryFormat = 1;
+// A negative field length indicates a null.
+constexpr int32_t kNullFieldLength = -1;
/// One-value ArrowArrayStream used to unify the implementations of Bind
struct OneValueStream {
@@ -370,10 +372,6 @@ int TupleReader::GetNext(struct ArrowArray* out) {
if (out->release) out->release(out);
return na_res;
}
-
- struct ArrowBitmap validity_bitmap;
- ArrowBitmapInit(&validity_bitmap);
- ArrowArraySetValidityBitmap(out->children[col], &validity_bitmap);
}
// TODO: we need to always PQgetResult
@@ -430,11 +428,11 @@ int TupleReader::GetNext(struct ArrowArray* out) {
out->children[col]->length = num_rows;
}
out->length = num_rows;
- na_res = ArrowArrayFinishBuildingDefault(out, 0);
+ na_res = ArrowArrayFinishBuildingDefault(out, &error);
if (na_res != 0) {
result_code = na_res;
if (!last_error_.empty()) last_error_ += '\n';
- last_error_ += StringBuilder("[libpq] Failed to build result array");
+ last_error_ += StringBuilder("[libpq] Failed to build result array: ",
error.message);
}
// Check the server-side response
@@ -496,109 +494,88 @@ int TupleReader::AppendNext(struct ArrowSchemaView*
fields, const char* buf, int
int32_t field_length = LoadNetworkInt32(buf);
buf += sizeof(int32_t);
- struct ArrowBitmap* bitmap = ArrowArrayValidityBitmap(out->children[col]);
-
// TODO: set error message here
- CHECK_NA(ArrowBitmapAppend(bitmap, field_length >= 0, 1));
-
- switch (fields[col].type) {
- case NANOARROW_TYPE_BOOL: {
- // DCHECK_EQ(field_length, 1);
- struct ArrowBuffer* buffer = ArrowArrayBuffer(out->children[col], 1);
- uint8_t raw_value = buf[0];
- buf += 1;
-
- if (raw_value != 0 && raw_value != 1) {
- last_error_ = StringBuilder("[libpq] Column #", col + 1, " (\"",
- schema_.children[col]->name,
- "\"): invalid BOOL value ", raw_value);
- return EINVAL;
- }
+ if (field_length != kNullFieldLength) {
+ CHECK_NA(AppendValue(fields, buf, col, *row_count, field_length, out));
+ buf += field_length;
+ } else {
+ CHECK_NA(ArrowArrayAppendNull(out->children[col], 1));
+ }
+ }
+ (*row_count)++;
+ return 0;
+}
- int64_t bytes_required = _ArrowRoundUpToMultipleOf8(*row_count + 1) /
8;
- if (bytes_required > buffer->size_bytes) {
- CHECK_NA(ArrowBufferAppendFill(buffer, 0, bytes_required -
buffer->size_bytes));
- }
- ArrowBitsSetTo(buffer->data, *row_count, 1, raw_value);
- break;
- }
- case NANOARROW_TYPE_DOUBLE: {
- // DCHECK_EQ(field_length, 8);
- static_assert(sizeof(double) == sizeof(uint64_t),
- "Float is not same size as uint64_t");
- struct ArrowBuffer* buffer = ArrowArrayBuffer(out->children[col], 1);
- uint64_t raw_value = LoadNetworkUInt64(buf);
- buf += sizeof(uint64_t);
- double value = 0.0;
- std::memcpy(&value, &raw_value, sizeof(double));
- CHECK_NA(ArrowBufferAppendDouble(buffer, value));
- break;
- }
- case NANOARROW_TYPE_FLOAT: {
- // DCHECK_EQ(field_length, 4);
- static_assert(sizeof(float) == sizeof(uint32_t),
- "Float is not same size as uint32_t");
- struct ArrowBuffer* buffer = ArrowArrayBuffer(out->children[col], 1);
- uint32_t raw_value = LoadNetworkUInt32(buf);
- buf += sizeof(uint32_t);
- float value = 0.0;
- std::memcpy(&value, &raw_value, sizeof(float));
- CHECK_NA(ArrowBufferAppendFloat(buffer, value));
- break;
- }
- case NANOARROW_TYPE_INT16: {
- // DCHECK_EQ(field_length, 2);
- struct ArrowBuffer* buffer = ArrowArrayBuffer(out->children[col], 1);
- int32_t value = LoadNetworkInt16(buf);
- buf += sizeof(int32_t);
- CHECK_NA(ArrowBufferAppendInt16(buffer, value));
- break;
- }
- case NANOARROW_TYPE_INT32: {
- // DCHECK_EQ(field_length, 4);
- struct ArrowBuffer* buffer = ArrowArrayBuffer(out->children[col], 1);
- int32_t value = LoadNetworkInt32(buf);
- buf += sizeof(int32_t);
- CHECK_NA(ArrowBufferAppendInt32(buffer, value));
- break;
- }
- case NANOARROW_TYPE_INT64: {
- // DCHECK_EQ(field_length, 8);
- struct ArrowBuffer* buffer = ArrowArrayBuffer(out->children[col], 1);
- int64_t value = field_length < 0 ? 0 : LoadNetworkInt64(buf);
- buf += sizeof(int64_t);
- CHECK_NA(ArrowBufferAppendInt64(buffer, value));
- break;
- }
- case NANOARROW_TYPE_BINARY: {
- struct ArrowBuffer* offset = ArrowArrayBuffer(out->children[col], 1);
- struct ArrowBuffer* data = ArrowArrayBuffer(out->children[col], 2);
- const int32_t last_offset =
- reinterpret_cast<const int32_t*>(offset->data)[*row_count];
- CHECK_NA(ArrowBufferAppendInt32(offset, last_offset + field_length));
- CHECK_NA(ArrowBufferAppend(data, buf, field_length));
- buf += field_length;
- break;
- }
- case NANOARROW_TYPE_STRING: {
- // textsend() in varlena.c
- struct ArrowBuffer* offset = ArrowArrayBuffer(out->children[col], 1);
- struct ArrowBuffer* data = ArrowArrayBuffer(out->children[col], 2);
- const int32_t last_offset =
- reinterpret_cast<const int32_t*>(offset->data)[*row_count];
- CHECK_NA(ArrowBufferAppendInt32(offset, last_offset + field_length));
- CHECK_NA(ArrowBufferAppend(data, buf, field_length));
- buf += field_length;
- break;
- }
- default:
+int TupleReader::AppendValue(struct ArrowSchemaView* fields, const char* buf,
int col,
+ int64_t row_count, int32_t field_length,
+ struct ArrowArray* out) {
+ switch (fields[col].type) {
+ case NANOARROW_TYPE_BOOL: {
+ uint8_t raw_value = buf[0];
+ buf += 1;
+
+ if (raw_value != 0 && raw_value != 1) {
last_error_ = StringBuilder("[libpq] Column #", col + 1, " (\"",
schema_.children[col]->name,
- "\") has unsupported type ",
fields[col].type);
- return ENOTSUP;
+ "\"): invalid BOOL value ", raw_value);
+ return EINVAL;
+ }
+ CHECK_NA(ArrowArrayAppendInt(out->children[col], raw_value));
+ break;
+ }
+ case NANOARROW_TYPE_DOUBLE: {
+ static_assert(sizeof(double) == sizeof(uint64_t),
+ "Float is not same size as uint64_t");
+ uint64_t raw_value = LoadNetworkUInt64(buf);
+ buf += sizeof(uint64_t);
+ double value = 0.0;
+ std::memcpy(&value, &raw_value, sizeof(double));
+ CHECK_NA(ArrowArrayAppendDouble(out->children[col], value));
+ break;
+ }
+ case NANOARROW_TYPE_FLOAT: {
+ static_assert(sizeof(float) == sizeof(uint32_t),
+ "Float is not same size as uint32_t");
+ uint32_t raw_value = LoadNetworkUInt32(buf);
+ buf += sizeof(uint32_t);
+ float value = 0.0;
+ std::memcpy(&value, &raw_value, sizeof(float));
+ CHECK_NA(ArrowArrayAppendDouble(out->children[col], value));
+ break;
}
+ case NANOARROW_TYPE_INT16: {
+ int32_t value = LoadNetworkInt16(buf);
+ buf += sizeof(int32_t);
+ CHECK_NA(ArrowArrayAppendInt(out->children[col], value));
+ break;
+ }
+ case NANOARROW_TYPE_INT32: {
+ int32_t value = LoadNetworkInt32(buf);
+ buf += sizeof(int32_t);
+ CHECK_NA(ArrowArrayAppendInt(out->children[col], value));
+ break;
+ }
+ case NANOARROW_TYPE_INT64: {
+ int64_t value = LoadNetworkInt64(buf);
+ buf += sizeof(int64_t);
+ CHECK_NA(ArrowArrayAppendInt(out->children[col], value));
+ break;
+ }
+ case NANOARROW_TYPE_BINARY: {
+ CHECK_NA(ArrowArrayAppendBytes(out->children[col], {buf, field_length}));
+ break;
+ }
+ case NANOARROW_TYPE_STRING: {
+ // textsend() in varlena.c
+ CHECK_NA(ArrowArrayAppendString(out->children[col], {buf,
field_length}));
+ break;
+ }
+ default:
+ last_error_ =
+ StringBuilder("[libpq] Column #", col + 1, " (\"",
schema_.children[col]->name,
+ "\") has unsupported type ", fields[col].type);
+ return ENOTSUP;
}
- (*row_count)++;
return 0;
}
diff --git a/c/driver/postgresql/statement.h b/c/driver/postgresql/statement.h
index 1ad4d82..5dfe04f 100644
--- a/c/driver/postgresql/statement.h
+++ b/c/driver/postgresql/statement.h
@@ -46,6 +46,8 @@ class TupleReader final {
int AppendNext(struct ArrowSchemaView* fields, const char* buf, int buf_size,
int64_t* row_count, struct ArrowArray* out);
+ int AppendValue(struct ArrowSchemaView* fields, const char* buf, int col,
+ int64_t row_count, int32_t field_length, struct ArrowArray*
out);
void ExportTo(struct ArrowArrayStream* stream);
private:
diff --git a/c/validation/adbc_validation.cc b/c/validation/adbc_validation.cc
index 0dc17f5..886ed48 100644
--- a/c/validation/adbc_validation.cc
+++ b/c/validation/adbc_validation.cc
@@ -1007,7 +1007,7 @@ void StatementTest::TestSqlIngestString() {
void StatementTest::TestSqlIngestBinary() {
ASSERT_NO_FATAL_FAILURE(TestSqlIngestType<std::string>(
- NANOARROW_TYPE_BINARY, {std::nullopt, "", "\x00\x01\x02\x04", "",
"\xFE\xFF"}));
+ NANOARROW_TYPE_BINARY, {std::nullopt, "", "\x00\x01\x02\x04",
"\xFE\xFF"}));
}
void StatementTest::TestSqlIngestAppend() {
@@ -1239,6 +1239,46 @@ void StatementTest::TestSqlIngestMultipleConnections() {
}
}
+void StatementTest::TestSqlIngestSample() {
+ if (!quirks()->supports_bulk_ingest()) {
+ GTEST_SKIP();
+ }
+
+ ASSERT_THAT(quirks()->EnsureSampleTable(&connection, "bulk_ingest", &error),
+ IsOkStatus(&error));
+
+ ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error),
IsOkStatus(&error));
+ ASSERT_THAT(
+ AdbcStatementSetSqlQuery(
+ &statement, "SELECT * FROM bulk_ingest ORDER BY \"int64s\" ASC NULLS
FIRST",
+ &error),
+ IsOkStatus(&error));
+ StreamReader reader;
+ ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value,
+ &reader.rows_affected, &error),
+ IsOkStatus(&error));
+ ASSERT_THAT(reader.rows_affected,
+ ::testing::AnyOf(::testing::Eq(3), ::testing::Eq(-1)));
+
+ ASSERT_NO_FATAL_FAILURE(reader.GetSchema());
+ ASSERT_NO_FATAL_FAILURE(CompareSchema(&reader.schema.value,
+ {{"int64s", NANOARROW_TYPE_INT64,
NULLABLE},
+ {"strings", NANOARROW_TYPE_STRING,
NULLABLE}}));
+
+ ASSERT_NO_FATAL_FAILURE(reader.Next());
+ ASSERT_NE(nullptr, reader.array->release);
+ ASSERT_EQ(3, reader.array->length);
+ ASSERT_EQ(2, reader.array->n_children);
+
+ ASSERT_NO_FATAL_FAILURE(
+ CompareArray<int64_t>(reader.array_view->children[0], {std::nullopt,
-42, 42}));
+
ASSERT_NO_FATAL_FAILURE(CompareArray<std::string>(reader.array_view->children[1],
+ {"", std::nullopt,
"foo"}));
+
+ ASSERT_NO_FATAL_FAILURE(reader.Next());
+ ASSERT_EQ(nullptr, reader.array->release);
+}
+
void StatementTest::TestSqlPartitionedInts() {
ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error),
IsOkStatus(&error));
ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, "SELECT 42", &error),
diff --git a/c/validation/adbc_validation.h b/c/validation/adbc_validation.h
index 400e8f1..7f3f175 100644
--- a/c/validation/adbc_validation.h
+++ b/c/validation/adbc_validation.h
@@ -205,6 +205,7 @@ class StatementTest {
void TestSqlIngestAppend();
void TestSqlIngestErrors();
void TestSqlIngestMultipleConnections();
+ void TestSqlIngestSample();
void TestSqlPartitionedInts();
@@ -261,6 +262,7 @@ class StatementTest {
TEST_F(FIXTURE, SqlIngestAppend) { TestSqlIngestAppend(); }
\
TEST_F(FIXTURE, SqlIngestErrors) { TestSqlIngestErrors(); }
\
TEST_F(FIXTURE, SqlIngestMultipleConnections) {
TestSqlIngestMultipleConnections(); } \
+ TEST_F(FIXTURE, SqlIngestSample) { TestSqlIngestSample(); }
\
TEST_F(FIXTURE, SqlPartitionedInts) { TestSqlPartitionedInts(); }
\
TEST_F(FIXTURE, SqlPrepareGetParameterSchema) {
TestSqlPrepareGetParameterSchema(); } \
TEST_F(FIXTURE, SqlPrepareSelectNoParams) { TestSqlPrepareSelectNoParams();
} \