This is an automated email from the ASF dual-hosted git repository.
lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git
The following commit(s) were added to refs/heads/main by this push:
new 7226c0e6 feat(c/driver/postgresql): Initial COPY Writer design (#1110)
7226c0e6 is described below
commit 7226c0e6de53f6b2e571211521ed280d7bda7f11
Author: William Ayd <[email protected]>
AuthorDate: Thu Sep 28 12:56:10 2023 -0400
feat(c/driver/postgresql): Initial COPY Writer design (#1110)
This is a precursor to #1093; it seemed easier to work piecewise
rather than all at once.
This does not yet connect the statement.cc code to the writer;
it only sets up the test case and the general structure.
---
c/driver/postgresql/postgres_copy_reader.h | 183 +++++++++++++++++++++++
c/driver/postgresql/postgres_copy_reader_test.cc | 57 +++++++
2 files changed, 240 insertions(+)
diff --git a/c/driver/postgresql/postgres_copy_reader.h
b/c/driver/postgresql/postgres_copy_reader.h
index 5a589700..0aab4690 100644
--- a/c/driver/postgresql/postgres_copy_reader.h
+++ b/c/driver/postgresql/postgres_copy_reader.h
@@ -117,6 +117,54 @@ ArrowErrorCode ReadChecked(ArrowBufferView* data, T* out,
ArrowError* error) {
return NANOARROW_OK;
}
+// Write a value to a buffer without checking the buffer size. Advances
+// the cursor of buffer and reduces it by sizeof(T)
+template <typename T>
+inline void WriteUnsafe(ArrowBuffer* buffer, T in) {
+ const T value = SwapNetworkToHost(in);
+ memcpy(buffer->data, &value, sizeof(T));
+ buffer->data += sizeof(T);
+ buffer->size_bytes += sizeof(T);
+}
+
+template <>
+inline void WriteUnsafe(ArrowBuffer* buffer, int8_t in) {
+ buffer->data[0] = in;
+ buffer->data += sizeof(int8_t);
+ buffer->size_bytes += sizeof(int8_t);
+}
+
+template <>
+inline void WriteUnsafe(ArrowBuffer* buffer, int16_t in) {
+ WriteUnsafe<uint16_t>(buffer, in);
+}
+
+template <>
+inline void WriteUnsafe(ArrowBuffer* buffer, int32_t in) {
+ WriteUnsafe<uint32_t>(buffer, in);
+}
+
+template <>
+inline void WriteUnsafe(ArrowBuffer* buffer, int64_t in) {
+ WriteUnsafe<uint64_t>(buffer, in);
+}
+
+template <typename T>
+ArrowErrorCode WriteChecked(ArrowBuffer* buffer, T in, ArrowError* error) {
+ // TODO: beware of overflow here
+ if (buffer->capacity_bytes < buffer->size_bytes +
static_cast<int64_t>(sizeof(T))) {
+ ArrowErrorSet(error,
+ "Insufficient buffer capacity (expected " PRId64
+ " bytes but found " PRId64 ")",
+ buffer->size_bytes + sizeof(T), buffer->capacity_bytes);
+
+ return EINVAL;
+ }
+
+ WriteUnsafe<T>(buffer, in);
+ return NANOARROW_OK;
+}
+
class PostgresCopyFieldReader {
public:
PostgresCopyFieldReader() : validity_(nullptr), offsets_(nullptr),
data_(nullptr) {
@@ -1058,4 +1106,139 @@ class PostgresCopyStreamReader {
int64_t array_size_approx_bytes_;
};
+class PostgresCopyFieldWriter {
+ public:
+ virtual ~PostgresCopyFieldWriter() {}
+
+ void Init(struct ArrowArrayView* array_view) { array_view_ = array_view; };
+
+ virtual ArrowErrorCode Write(ArrowBuffer* buffer, int64_t index, ArrowError*
error) {
+ return ENOTSUP;
+ }
+
+ protected:
+ struct ArrowArrayView* array_view_;
+ std::vector<std::unique_ptr<PostgresCopyFieldWriter>> children_;
+};
+
+class PostgresCopyFieldTupleWriter : public PostgresCopyFieldWriter {
+ public:
+ void AppendChild(std::unique_ptr<PostgresCopyFieldWriter> child) {
+ int64_t child_i = static_cast<int64_t>(children_.size());
+ children_.push_back(std::move(child));
+ children_[child_i]->Init(array_view_->children[child_i]);
+ }
+
+ ArrowErrorCode Write(ArrowBuffer* buffer, int64_t index, ArrowError* error)
override {
+ if (index >= array_view_->length) {
+ return ENODATA;
+ }
+
+ const int16_t n_fields = children_.size();
+ NANOARROW_RETURN_NOT_OK(WriteChecked<int16_t>(buffer, n_fields, error));
+
+ for (int16_t i = 0; i < n_fields; i++) {
+ children_[i]->Write(buffer, index, error);
+ }
+
+ return NANOARROW_OK;
+ }
+
+ private:
+ std::vector<std::unique_ptr<PostgresCopyFieldWriter>> children_;
+};
+
+class PostgresCopyBooleanFieldWriter : public PostgresCopyFieldWriter {
+ public:
+ ArrowErrorCode Write(ArrowBuffer* buffer, int64_t index, ArrowError* error)
override {
+ const int8_t is_null = ArrowArrayViewIsNull(array_view_, index);
+ const int32_t field_size_bytes = is_null ? -1 : 1;
+ NANOARROW_RETURN_NOT_OK(WriteChecked<int32_t>(buffer, field_size_bytes,
error));
+ if (is_null) {
+ return ADBC_STATUS_OK;
+ }
+
+ const int8_t value =
+ static_cast<int8_t>(ArrowArrayViewGetIntUnsafe(array_view_, index));
+ NANOARROW_RETURN_NOT_OK(WriteChecked<int8_t>(buffer, value, error));
+
+ return ADBC_STATUS_OK;
+ }
+};
+
+static inline ArrowErrorCode MakeCopyFieldWriter(const enum ArrowType
arrow_type,
+ PostgresCopyFieldWriter** out,
+ ArrowError* error) {
+ switch (arrow_type) {
+ case NANOARROW_TYPE_BOOL:
+ *out = new PostgresCopyBooleanFieldWriter();
+ return NANOARROW_OK;
+ default:
+ return EINVAL;
+ }
+ return NANOARROW_OK;
+}
+
+class PostgresCopyStreamWriter {
+ public:
+ ~PostgresCopyStreamWriter() { ArrowArrayViewReset(array_view_.get()); }
+
+ ArrowErrorCode Init(struct ArrowSchema* schema, struct ArrowArray* array) {
+ schema_ = schema;
+ NANOARROW_RETURN_NOT_OK(
+ ArrowArrayViewInitFromSchema(array_view_.get(), schema, nullptr));
+ NANOARROW_RETURN_NOT_OK(ArrowArrayViewSetArray(array_view_.get(), array,
nullptr));
+ root_writer_.Init(array_view_.get());
+ return NANOARROW_OK;
+ }
+
+ ArrowErrorCode WriteHeader(ArrowBuffer* buffer, ArrowError* error) {
+ ArrowBufferAppend(buffer, kPgCopyBinarySignature,
sizeof(kPgCopyBinarySignature));
+
+ const uint32_t flag_fields = 0;
+ ArrowBufferAppend(buffer, &flag_fields, sizeof(flag_fields));
+
+ const uint32_t extension_bytes = 0;
+ ArrowBufferAppend(buffer, &extension_bytes, sizeof(extension_bytes));
+
+ const int64_t header_bytes =
+ sizeof(kPgCopyBinarySignature) + sizeof(flag_fields) +
sizeof(extension_bytes);
+ buffer->data += header_bytes;
+
+ return NANOARROW_OK;
+ }
+
+ ArrowErrorCode WriteRecord(ArrowBuffer* buffer, ArrowError* error) {
+ NANOARROW_RETURN_NOT_OK(root_writer_.Write(buffer, records_written_,
error));
+ records_written_++;
+ return NANOARROW_OK;
+ }
+
+ ArrowErrorCode InitFieldWriters(ArrowError* error) {
+ if (schema_->release == nullptr) {
+ return EINVAL;
+ }
+
+ for (int64_t i = 0; i < schema_->n_children; i++) {
+ struct ArrowSchemaView schema_view;
+ if (ArrowSchemaViewInit(&schema_view, schema_->children[i], error) !=
+ NANOARROW_OK) {
+ return ADBC_STATUS_INTERNAL;
+ }
+ const ArrowType arrow_type = schema_view.type;
+ PostgresCopyFieldWriter* child_writer;
+ NANOARROW_RETURN_NOT_OK(MakeCopyFieldWriter(arrow_type, &child_writer,
error));
+
root_writer_.AppendChild(std::unique_ptr<PostgresCopyFieldWriter>(child_writer));
+ }
+
+ return NANOARROW_OK;
+ }
+
+ private:
+ PostgresCopyFieldTupleWriter root_writer_;
+ struct ArrowSchema* schema_;
+ std::unique_ptr<struct ArrowArrayView> array_view_{new struct
ArrowArrayView};
+ int64_t records_written_ = 0;
+};
+
} // namespace adbcpq
diff --git a/c/driver/postgresql/postgres_copy_reader_test.cc
b/c/driver/postgresql/postgres_copy_reader_test.cc
index 44ad0601..f389d2a6 100644
--- a/c/driver/postgresql/postgres_copy_reader_test.cc
+++ b/c/driver/postgresql/postgres_copy_reader_test.cc
@@ -15,10 +15,13 @@
// specific language governing permissions and limitations
// under the License.
+#include <optional>
+
#include <gtest/gtest.h>
#include <nanoarrow/nanoarrow.hpp>
#include "postgres_copy_reader.h"
+#include "validation/adbc_validation_util.h"
namespace adbcpq {
@@ -52,6 +55,30 @@ class PostgresCopyStreamTester {
PostgresCopyStreamReader reader_;
};
+class PostgresCopyStreamWriteTester {
+ public:
+ ArrowErrorCode Init(struct ArrowSchema* schema, struct ArrowArray* array,
+ ArrowError* error = nullptr) {
+ NANOARROW_RETURN_NOT_OK(writer_.Init(schema, array));
+ NANOARROW_RETURN_NOT_OK(writer_.InitFieldWriters(error));
+ return NANOARROW_OK;
+ }
+
+ ArrowErrorCode WriteAll(struct ArrowBuffer* buffer, ArrowError* error =
nullptr) {
+ NANOARROW_RETURN_NOT_OK(writer_.WriteHeader(buffer, error));
+
+ int result;
+ do {
+ result = writer_.WriteRecord(buffer, error);
+ } while (result == NANOARROW_OK);
+
+ return result;
+ }
+
+ private:
+ PostgresCopyStreamWriter writer_;
+};
+
// COPY (SELECT CAST("col" AS BOOLEAN) AS "col" FROM ( VALUES (TRUE),
(FALSE), (NULL)) AS
// drvd("col")) TO STDOUT;
static uint8_t kTestPgCopyBoolean[] = {
@@ -96,6 +123,36 @@ TEST(PostgresCopyUtilsTest, PostgresCopyReadBoolean) {
ASSERT_FALSE(ArrowBitGet(data_buffer, 2));
}
+TEST(PostgresCopyUtilsTest, PostgresCopyWriteBoolean) {
+ adbc_validation::Handle<struct ArrowSchema> schema;
+ adbc_validation::Handle<struct ArrowArray> array;
+ struct ArrowError na_error;
+ ASSERT_EQ(adbc_validation::MakeSchema(&schema.value, {{"col",
NANOARROW_TYPE_BOOL}}),
+ ADBC_STATUS_OK);
+ ASSERT_EQ(adbc_validation::MakeBatch<bool>(&schema.value, &array.value,
&na_error,
+ {true, false, std::nullopt}),
+ ADBC_STATUS_OK);
+
+ PostgresCopyStreamWriteTester tester;
+ ASSERT_EQ(tester.Init(&schema.value, &array.value), NANOARROW_OK);
+
+ struct ArrowBuffer buffer;
+ ArrowBufferInit(&buffer);
+ ArrowBufferReserve(&buffer, sizeof(kTestPgCopyBoolean));
+ uint8_t* cursor = buffer.data;
+
+ ASSERT_EQ(tester.WriteAll(&buffer, nullptr), ENODATA);
+
+ // The last 4 bytes of a message can be transmitted via PQputCopyData
+ // so no need to test those bytes from the Writer
+ for (size_t i = 0; i < sizeof(kTestPgCopyBoolean) - 4; i++) {
+ EXPECT_EQ(cursor[i], kTestPgCopyBoolean[i]);
+ }
+
+ buffer.data = cursor;
+ ArrowBufferReset(&buffer);
+}
+
// COPY (SELECT CAST("col" AS SMALLINT) AS "col" FROM ( VALUES (-123), (-1),
(1), (123),
// (NULL)) AS drvd("col")) TO STDOUT WITH (FORMAT binary);
static uint8_t kTestPgCopySmallInt[] = {