This is an automated email from the ASF dual-hosted git repository.
lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git
The following commit(s) were added to refs/heads/main by this push:
new 9d4dcb2f refactor(c/driver/postgresql): Have Copy Writer manage its
own buffer (#1148)
9d4dcb2f is described below
commit 9d4dcb2fbc43ee72a12829e77607c1c0b9a1484f
Author: William Ayd <[email protected]>
AuthorDate: Thu Oct 5 16:20:12 2023 -0400
refactor(c/driver/postgresql): Have Copy Writer manage its own buffer
(#1148)
---
c/driver/postgresql/postgres_copy_reader.h | 42 +++++------
c/driver/postgresql/postgres_copy_reader_test.cc | 89 +++++++++---------------
c/driver/postgresql/postgres_util.h | 5 ++
c/validation/adbc_validation_util.h | 5 ++
4 files changed, 60 insertions(+), 81 deletions(-)
diff --git a/c/driver/postgresql/postgres_copy_reader.h
b/c/driver/postgresql/postgres_copy_reader.h
index 5066c49a..fce5e104 100644
--- a/c/driver/postgresql/postgres_copy_reader.h
+++ b/c/driver/postgresql/postgres_copy_reader.h
@@ -122,16 +122,12 @@ ArrowErrorCode ReadChecked(ArrowBufferView* data, T* out,
ArrowError* error) {
template <typename T>
inline void WriteUnsafe(ArrowBuffer* buffer, T in) {
const T value = SwapNetworkToHost(in);
- memcpy(buffer->data, &value, sizeof(T));
- buffer->data += sizeof(T);
- buffer->size_bytes += sizeof(T);
+ ArrowBufferAppendUnsafe(buffer, &value, sizeof(T));
}
template <>
inline void WriteUnsafe(ArrowBuffer* buffer, int8_t in) {
- buffer->data[0] = in;
- buffer->data += sizeof(int8_t);
- buffer->size_bytes += sizeof(int8_t);
+ ArrowBufferAppendUnsafe(buffer, &in, sizeof(int8_t));
}
template <>
@@ -151,16 +147,7 @@ inline void WriteUnsafe(ArrowBuffer* buffer, int64_t in) {
template <typename T>
ArrowErrorCode WriteChecked(ArrowBuffer* buffer, T in, ArrowError* error) {
- // TODO: beware of overflow here
- if (buffer->capacity_bytes < buffer->size_bytes +
static_cast<int64_t>(sizeof(T))) {
- ArrowErrorSet(error,
- "Insufficient buffer capacity (expected " PRId64
- " bytes but found " PRId64 ")",
- buffer->size_bytes + sizeof(T), buffer->capacity_bytes);
-
- return EINVAL;
- }
-
+ NANOARROW_RETURN_NOT_OK(ArrowBufferReserve(buffer, sizeof(T)));
WriteUnsafe<T>(buffer, in);
return NANOARROW_OK;
}
@@ -1215,27 +1202,27 @@ class PostgresCopyStreamWriter {
ArrowArrayViewInitFromSchema(&array_view_.value, schema, nullptr));
NANOARROW_RETURN_NOT_OK(ArrowArrayViewSetArray(&array_view_.value, array,
nullptr));
root_writer_.Init(&array_view_.value);
+ ArrowBufferInit(&buffer_.value);
return NANOARROW_OK;
}
- ArrowErrorCode WriteHeader(ArrowBuffer* buffer, ArrowError* error) {
- ArrowBufferAppend(buffer, kPgCopyBinarySignature,
sizeof(kPgCopyBinarySignature));
+ ArrowErrorCode WriteHeader(ArrowError* error) {
+ NANOARROW_RETURN_NOT_OK(ArrowBufferAppend(&buffer_.value,
kPgCopyBinarySignature,
+ sizeof(kPgCopyBinarySignature)));
const uint32_t flag_fields = 0;
- ArrowBufferAppend(buffer, &flag_fields, sizeof(flag_fields));
+ NANOARROW_RETURN_NOT_OK(
+ ArrowBufferAppend(&buffer_.value, &flag_fields, sizeof(flag_fields)));
const uint32_t extension_bytes = 0;
- ArrowBufferAppend(buffer, &extension_bytes, sizeof(extension_bytes));
-
- const int64_t header_bytes =
- sizeof(kPgCopyBinarySignature) + sizeof(flag_fields) +
sizeof(extension_bytes);
- buffer->data += header_bytes;
+ NANOARROW_RETURN_NOT_OK(
+ ArrowBufferAppend(&buffer_.value, &extension_bytes,
sizeof(extension_bytes)));
return NANOARROW_OK;
}
- ArrowErrorCode WriteRecord(ArrowBuffer* buffer, ArrowError* error) {
- NANOARROW_RETURN_NOT_OK(root_writer_.Write(buffer, records_written_,
error));
+ ArrowErrorCode WriteRecord(ArrowError* error) {
+ NANOARROW_RETURN_NOT_OK(root_writer_.Write(&buffer_.value,
records_written_, error));
records_written_++;
return NANOARROW_OK;
}
@@ -1260,10 +1247,13 @@ class PostgresCopyStreamWriter {
return NANOARROW_OK;
}
+ const struct ArrowBuffer& WriteBuffer() const { return buffer_.value; }
+
private:
PostgresCopyFieldTupleWriter root_writer_;
struct ArrowSchema* schema_;
Handle<struct ArrowArrayView> array_view_;
+ Handle<struct ArrowBuffer> buffer_;
int64_t records_written_ = 0;
};
diff --git a/c/driver/postgresql/postgres_copy_reader_test.cc
b/c/driver/postgresql/postgres_copy_reader_test.cc
index d520271a..9edb3977 100644
--- a/c/driver/postgresql/postgres_copy_reader_test.cc
+++ b/c/driver/postgresql/postgres_copy_reader_test.cc
@@ -64,17 +64,19 @@ class PostgresCopyStreamWriteTester {
return NANOARROW_OK;
}
- ArrowErrorCode WriteAll(struct ArrowBuffer* buffer, ArrowError* error =
nullptr) {
- NANOARROW_RETURN_NOT_OK(writer_.WriteHeader(buffer, error));
+ ArrowErrorCode WriteAll(ArrowError* error = nullptr) {
+ NANOARROW_RETURN_NOT_OK(writer_.WriteHeader(error));
int result;
do {
- result = writer_.WriteRecord(buffer, error);
+ result = writer_.WriteRecord(error);
} while (result == NANOARROW_OK);
return result;
}
+ const struct ArrowBuffer& WriteBuffer() const { return
writer_.WriteBuffer(); }
+
private:
PostgresCopyStreamWriter writer_;
};
@@ -126,6 +128,7 @@ TEST(PostgresCopyUtilsTest, PostgresCopyReadBoolean) {
TEST(PostgresCopyUtilsTest, PostgresCopyWriteBoolean) {
adbc_validation::Handle<struct ArrowSchema> schema;
adbc_validation::Handle<struct ArrowArray> array;
+ adbc_validation::Handle<struct ArrowBuffer> buffer;
struct ArrowError na_error;
ASSERT_EQ(adbc_validation::MakeSchema(&schema.value, {{"col",
NANOARROW_TYPE_BOOL}}),
ADBC_STATUS_OK);
@@ -135,22 +138,16 @@ TEST(PostgresCopyUtilsTest, PostgresCopyWriteBoolean) {
PostgresCopyStreamWriteTester tester;
ASSERT_EQ(tester.Init(&schema.value, &array.value), NANOARROW_OK);
+ ASSERT_EQ(tester.WriteAll(nullptr), ENODATA);
- struct ArrowBuffer buffer;
- ArrowBufferInit(&buffer);
- ArrowBufferReserve(&buffer, sizeof(kTestPgCopyBoolean));
- uint8_t* cursor = buffer.data;
-
- ASSERT_EQ(tester.WriteAll(&buffer, nullptr), ENODATA);
-
- // The last 4 bytes of a message can be transmitted via PQputCopyData
+ const struct ArrowBuffer buf = tester.WriteBuffer();
+ // The last 2 bytes of a message can be transmitted via PQputCopyData
// so no need to test those bytes from the Writer
- for (size_t i = 0; i < sizeof(kTestPgCopyBoolean) - 4; i++) {
- EXPECT_EQ(cursor[i], kTestPgCopyBoolean[i]);
+ constexpr size_t buf_size = sizeof(kTestPgCopyBoolean) - 2;
+ ASSERT_EQ(buf.size_bytes, buf_size);
+ for (size_t i = 0; i < buf_size; i++) {
+ ASSERT_EQ(buf.data[i], kTestPgCopyBoolean[i]);
}
-
- buffer.data = cursor;
- ArrowBufferReset(&buffer);
}
// COPY (SELECT CAST("col" AS SMALLINT) AS "col" FROM ( VALUES (-123), (-1),
(1), (123),
@@ -212,22 +209,16 @@ TEST(PostgresCopyUtilsTest, PostgresCopyWriteInt16) {
PostgresCopyStreamWriteTester tester;
ASSERT_EQ(tester.Init(&schema.value, &array.value), NANOARROW_OK);
+ ASSERT_EQ(tester.WriteAll(nullptr), ENODATA);
- struct ArrowBuffer buffer;
- ArrowBufferInit(&buffer);
- ArrowBufferReserve(&buffer, sizeof(kTestPgCopySmallInt));
- uint8_t* cursor = buffer.data;
-
- ASSERT_EQ(tester.WriteAll(&buffer, nullptr), ENODATA);
-
- // The last 4 bytes of a message can be transmitted via PQputCopyData
+ const struct ArrowBuffer buf = tester.WriteBuffer();
+ // The last 2 bytes of a message can be transmitted via PQputCopyData
// so no need to test those bytes from the Writer
- for (size_t i = 0; i < sizeof(kTestPgCopySmallInt) - 4; i++) {
- EXPECT_EQ(cursor[i], kTestPgCopySmallInt[i]);
+ constexpr size_t buf_size = sizeof(kTestPgCopySmallInt) - 2;
+ ASSERT_EQ(buf.size_bytes, buf_size);
+ for (size_t i = 0; i < buf_size; i++) {
+ ASSERT_EQ(buf.data[i], kTestPgCopySmallInt[i]);
}
-
- buffer.data = cursor;
- ArrowBufferReset(&buffer);
}
// COPY (SELECT CAST("col" AS INTEGER) AS "col" FROM ( VALUES (-123), (-1),
(1), (123),
@@ -289,22 +280,16 @@ TEST(PostgresCopyUtilsTest, PostgresCopyWriteInt32) {
PostgresCopyStreamWriteTester tester;
ASSERT_EQ(tester.Init(&schema.value, &array.value), NANOARROW_OK);
+ ASSERT_EQ(tester.WriteAll(nullptr), ENODATA);
- struct ArrowBuffer buffer;
- ArrowBufferInit(&buffer);
- ArrowBufferReserve(&buffer, sizeof(kTestPgCopyInteger));
- uint8_t* cursor = buffer.data;
-
- ASSERT_EQ(tester.WriteAll(&buffer, nullptr), ENODATA);
-
- // The last 4 bytes of a message can be transmitted via PQputCopyData
+ const struct ArrowBuffer buf = tester.WriteBuffer();
+ // The last 2 bytes of a message can be transmitted via PQputCopyData
// so no need to test those bytes from the Writer
- for (size_t i = 0; i < sizeof(kTestPgCopyInteger) - 4; i++) {
- EXPECT_EQ(cursor[i], kTestPgCopyInteger[i]);
+ constexpr size_t buf_size = sizeof(kTestPgCopyInteger) - 2;
+ ASSERT_EQ(buf.size_bytes, buf_size);
+ for (size_t i = 0; i < buf_size; i++) {
+ ASSERT_EQ(buf.data[i], kTestPgCopyInteger[i]);
}
-
- buffer.data = cursor;
- ArrowBufferReset(&buffer);
}
// COPY (SELECT CAST("col" AS BIGINT) AS "col" FROM ( VALUES (-123), (-1),
(1), (123),
@@ -367,22 +352,16 @@ TEST(PostgresCopyUtilsTest, PostgresCopyWriteInt64) {
PostgresCopyStreamWriteTester tester;
ASSERT_EQ(tester.Init(&schema.value, &array.value), NANOARROW_OK);
+ ASSERT_EQ(tester.WriteAll(nullptr), ENODATA);
- struct ArrowBuffer buffer;
- ArrowBufferInit(&buffer);
- ArrowBufferReserve(&buffer, sizeof(kTestPgCopyBigInt));
- uint8_t* cursor = buffer.data;
-
- ASSERT_EQ(tester.WriteAll(&buffer, nullptr), ENODATA);
-
- // The last 4 bytes of a message can be transmitted via PQputCopyData
+ const struct ArrowBuffer buf = tester.WriteBuffer();
+ // The last 2 bytes of a message can be transmitted via PQputCopyData
// so no need to test those bytes from the Writer
- for (size_t i = 0; i < sizeof(kTestPgCopyBigInt) - 4; i++) {
- EXPECT_EQ(cursor[i], kTestPgCopyBigInt[i]);
+ constexpr size_t buf_size = sizeof(kTestPgCopyBigInt) - 2;
+ ASSERT_EQ(buf.size_bytes, buf_size);
+ for (size_t i = 0; i < buf_size; i++) {
+ ASSERT_EQ(buf.data[i], kTestPgCopyBigInt[i]);
}
-
- buffer.data = cursor;
- ArrowBufferReset(&buffer);
}
// COPY (SELECT CAST("col" AS REAL) AS "col" FROM ( VALUES (-123.456), (-1),
(1),
diff --git a/c/driver/postgresql/postgres_util.h
b/c/driver/postgresql/postgres_util.h
index 8d1af084..1009d70b 100644
--- a/c/driver/postgresql/postgres_util.h
+++ b/c/driver/postgresql/postgres_util.h
@@ -146,6 +146,11 @@ struct Releaser {
}
};
+template <>
+struct Releaser<struct ArrowBuffer> {
+ static void Release(struct ArrowBuffer* buffer) { ArrowBufferReset(buffer); }
+};
+
template <>
struct Releaser<struct ArrowArrayView> {
static void Release(struct ArrowArrayView* value) {
diff --git a/c/validation/adbc_validation_util.h
b/c/validation/adbc_validation_util.h
index 59cc3887..0db6e491 100644
--- a/c/validation/adbc_validation_util.h
+++ b/c/validation/adbc_validation_util.h
@@ -56,6 +56,11 @@ struct Releaser {
}
};
+template <>
+struct Releaser<struct ArrowBuffer> {
+ static void Release(struct ArrowBuffer* buffer) { ArrowBufferReset(buffer); }
+};
+
template <>
struct Releaser<struct ArrowArrayView> {
static void Release(struct ArrowArrayView* value) {