This is an automated email from the ASF dual-hosted git repository.

lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new 7226c0e6 feat(c/driver/postgresql): Initial COPY Writer design (#1110)
7226c0e6 is described below

commit 7226c0e6de53f6b2e571211521ed280d7bda7f11
Author: William Ayd <[email protected]>
AuthorDate: Thu Sep 28 12:56:10 2023 -0400

    feat(c/driver/postgresql): Initial COPY Writer design (#1110)
    
    This is a precursor to #1093; it seemed easier to work on this
    piece by piece rather than all at once.
    
    This does not yet connect the statement.cc code to the writer;
    it only sets up the test case and the general structure.
---
 c/driver/postgresql/postgres_copy_reader.h       | 183 +++++++++++++++++++++++
 c/driver/postgresql/postgres_copy_reader_test.cc |  57 +++++++
 2 files changed, 240 insertions(+)

diff --git a/c/driver/postgresql/postgres_copy_reader.h b/c/driver/postgresql/postgres_copy_reader.h
index 5a589700..0aab4690 100644
--- a/c/driver/postgresql/postgres_copy_reader.h
+++ b/c/driver/postgresql/postgres_copy_reader.h
@@ -117,6 +117,54 @@ ArrowErrorCode ReadChecked(ArrowBufferView* data, T* out, ArrowError* error) {
   return NANOARROW_OK;
 }
 
+// Write a value to a buffer without checking the buffer capacity.
+//
+// The value is written in network (big-endian) byte order at buffer->data,
+// which this writer treats as its write cursor: buffer->data is advanced by
+// sizeof(T) and buffer->size_bytes is increased by sizeof(T). Callers must
+// ensure at least sizeof(T) bytes of capacity remain (see WriteChecked) and
+// must restore buffer->data to the start of the allocation before releasing
+// the buffer.
+template <typename T>
+inline void WriteUnsafe(ArrowBuffer* buffer, T in) {
+  // Host<->network conversion is the same operation in both directions, so the
+  // SwapNetworkToHost helper is reused for writing (it is defined elsewhere;
+  // presumably a plain byte-order swap).
+  const T value = SwapNetworkToHost(in);
+  memcpy(buffer->data, &value, sizeof(T));
+  buffer->data += sizeof(T);
+  buffer->size_bytes += sizeof(T);
+}
+
+// A single byte has no byte order, so it is stored as-is.
+template <>
+inline void WriteUnsafe(ArrowBuffer* buffer, int8_t in) {
+  buffer->data[0] = static_cast<uint8_t>(in);
+  buffer->data += sizeof(int8_t);
+  buffer->size_bytes += sizeof(int8_t);
+}
+
+// Signed integers are written through the unsigned type of the same width,
+// which preserves their two's-complement bit pattern.
+template <>
+inline void WriteUnsafe(ArrowBuffer* buffer, int16_t in) {
+  WriteUnsafe<uint16_t>(buffer, static_cast<uint16_t>(in));
+}
+
+template <>
+inline void WriteUnsafe(ArrowBuffer* buffer, int32_t in) {
+  WriteUnsafe<uint32_t>(buffer, static_cast<uint32_t>(in));
+}
+
+template <>
+inline void WriteUnsafe(ArrowBuffer* buffer, int64_t in) {
+  WriteUnsafe<uint64_t>(buffer, static_cast<uint64_t>(in));
+}
+
+// Write a value to a buffer after verifying that at least sizeof(T) bytes of
+// capacity remain. The buffer is never grown here: if it is too small, EINVAL
+// is returned and `error` describes the shortfall.
+template <typename T>
+ArrowErrorCode WriteChecked(ArrowBuffer* buffer, T in, ArrowError* error) {
+  // Compare the remaining capacity rather than size_bytes + sizeof(T) so the
+  // check itself cannot overflow.
+  const int64_t available_bytes = buffer->capacity_bytes - buffer->size_bytes;
+  if (available_bytes < static_cast<int64_t>(sizeof(T))) {
+    const int64_t required_bytes =
+        buffer->size_bytes + static_cast<int64_t>(sizeof(T));
+    ArrowErrorSet(error,
+                  "Insufficient buffer capacity (expected %" PRId64
+                  " bytes but found %" PRId64 ")",
+                  required_bytes, buffer->capacity_bytes);
+
+    return EINVAL;
+  }
+
+  WriteUnsafe<T>(buffer, in);
+  return NANOARROW_OK;
+}
+
 class PostgresCopyFieldReader {
  public:
   PostgresCopyFieldReader() : validity_(nullptr), offsets_(nullptr), 
data_(nullptr) {
@@ -1058,4 +1106,139 @@ class PostgresCopyStreamReader {
   int64_t array_size_approx_bytes_;
 };
 
+// Base class for objects that serialize one field (column) of an Arrow array
+// into the PostgreSQL COPY binary format.
+class PostgresCopyFieldWriter {
+ public:
+  virtual ~PostgresCopyFieldWriter() = default;
+
+  // Attach the array view values are read from. The view is borrowed, not
+  // owned; it is dereferenced by Write() and so must outlive those calls.
+  void Init(struct ArrowArrayView* array_view) { array_view_ = array_view; }
+
+  // Append the value at `index` to `buffer`. The base implementation handles
+  // no type and returns ENOTSUP; subclasses override it.
+  virtual ArrowErrorCode Write(ArrowBuffer* buffer, int64_t index, ArrowError* error) {
+    return ENOTSUP;
+  }
+
+ protected:
+  // Null until Init() is called.
+  struct ArrowArrayView* array_view_ = nullptr;
+  std::vector<std::unique_ptr<PostgresCopyFieldWriter>> children_;
+};
+
+// Writes one row (tuple): a 16-bit field count followed by the output of each
+// child writer for the same row index.
+class PostgresCopyFieldTupleWriter : public PostgresCopyFieldWriter {
+ public:
+  // Take ownership of the writer for the next child column. Init() must have
+  // been called first: the child at position i is bound to
+  // array_view_->children[i].
+  void AppendChild(std::unique_ptr<PostgresCopyFieldWriter> child) {
+    const int64_t child_i = static_cast<int64_t>(children_.size());
+    children_.push_back(std::move(child));
+    children_[child_i]->Init(array_view_->children[child_i]);
+  }
+
+  // Returns ENODATA once `index` is past the end of the array (callers use this
+  // as the end-of-stream signal). Errors from child writers are propagated.
+  ArrowErrorCode Write(ArrowBuffer* buffer, int64_t index, ArrowError* error) override {
+    if (index >= array_view_->length) {
+      return ENODATA;
+    }
+
+    // The field count is written as an int16_t, so reject anything larger.
+    if (children_.size() > static_cast<size_t>(INT16_MAX)) {
+      ArrowErrorSet(error, "Too many fields for a COPY tuple: %" PRId64,
+                    static_cast<int64_t>(children_.size()));
+      return EINVAL;
+    }
+
+    const int16_t n_fields = static_cast<int16_t>(children_.size());
+    NANOARROW_RETURN_NOT_OK(WriteChecked<int16_t>(buffer, n_fields, error));
+
+    for (const auto& child : children_) {
+      NANOARROW_RETURN_NOT_OK(child->Write(buffer, index, error));
+    }
+
+    return NANOARROW_OK;
+  }
+};
+
+// Writes a boolean field as a 4-byte length (1, or -1 for NULL) followed, for
+// non-NULL values, by one byte holding the value's integer representation.
+class PostgresCopyBooleanFieldWriter : public PostgresCopyFieldWriter {
+ public:
+  ArrowErrorCode Write(ArrowBuffer* buffer, int64_t index, ArrowError* error) override {
+    const int8_t is_null = ArrowArrayViewIsNull(array_view_, index);
+    const int32_t field_size_bytes = is_null ? -1 : 1;
+    NANOARROW_RETURN_NOT_OK(WriteChecked<int32_t>(buffer, field_size_bytes, error));
+    if (is_null) {
+      // A NULL field has no payload after its -1 length.
+      return NANOARROW_OK;
+    }
+
+    const int8_t value =
+        static_cast<int8_t>(ArrowArrayViewGetIntUnsafe(array_view_, index));
+    NANOARROW_RETURN_NOT_OK(WriteChecked<int8_t>(buffer, value, error));
+
+    return NANOARROW_OK;
+  }
+};
+
+// Create a field writer for `arrow_type`. On success *out receives a newly
+// allocated writer owned by the caller. Unsupported types return EINVAL with a
+// message in `error`, and *out is left untouched.
+static inline ArrowErrorCode MakeCopyFieldWriter(const enum ArrowType arrow_type,
+                                                 PostgresCopyFieldWriter** out,
+                                                 ArrowError* error) {
+  switch (arrow_type) {
+    case NANOARROW_TYPE_BOOL:
+      *out = new PostgresCopyBooleanFieldWriter();
+      return NANOARROW_OK;
+    default:
+      ArrowErrorSet(error, "COPY writer does not support Arrow type %d",
+                    static_cast<int>(arrow_type));
+      return EINVAL;
+  }
+}
+
+// Serializes the rows of an Arrow array whose top-level children are the
+// columns described by `schema` into the PostgreSQL COPY binary format, one
+// record per WriteRecord() call.
+//
+// NOTE: writing advances buffer->data past the bytes written (the header in
+// WriteHeader(), records via WriteUnsafe), so callers must keep the original
+// buffer->data and restore it before releasing the buffer.
+class PostgresCopyStreamWriter {
+ public:
+  ~PostgresCopyStreamWriter() {
+    if (array_view_ != nullptr) {
+      ArrowArrayViewReset(array_view_.get());
+    }
+  }
+
+  // Bind the writer to `schema` and `array`. Both are borrowed (stored or viewed,
+  // not copied) and must outlive this writer. Call InitFieldWriters() next.
+  ArrowErrorCode Init(struct ArrowSchema* schema, struct ArrowArray* array,
+                      ArrowError* error = nullptr) {
+    schema_ = schema;
+    NANOARROW_RETURN_NOT_OK(
+        ArrowArrayViewInitFromSchema(array_view_.get(), schema, error));
+    NANOARROW_RETURN_NOT_OK(ArrowArrayViewSetArray(array_view_.get(), array, error));
+    root_writer_.Init(array_view_.get());
+    return NANOARROW_OK;
+  }
+
+  // Append the fixed COPY binary header: the signature, a 32-bit flags field and
+  // a 32-bit header-extension length (both zero). buffer->data is then advanced
+  // past the header so that records are written after it.
+  ArrowErrorCode WriteHeader(ArrowBuffer* buffer, ArrowError* error) {
+    NANOARROW_RETURN_NOT_OK(ArrowBufferAppend(buffer, kPgCopyBinarySignature,
+                                              sizeof(kPgCopyBinarySignature)));
+
+    // Both values are zero, so their byte order does not matter.
+    const uint32_t flag_fields = 0;
+    NANOARROW_RETURN_NOT_OK(
+        ArrowBufferAppend(buffer, &flag_fields, sizeof(flag_fields)));
+
+    const uint32_t extension_bytes = 0;
+    NANOARROW_RETURN_NOT_OK(
+        ArrowBufferAppend(buffer, &extension_bytes, sizeof(extension_bytes)));
+
+    const int64_t header_bytes =
+        sizeof(kPgCopyBinarySignature) + sizeof(flag_fields) + sizeof(extension_bytes);
+    buffer->data += header_bytes;
+
+    return NANOARROW_OK;
+  }
+
+  // Write the next row. Returns ENODATA once every row has been written.
+  ArrowErrorCode WriteRecord(ArrowBuffer* buffer, ArrowError* error) {
+    NANOARROW_RETURN_NOT_OK(root_writer_.Write(buffer, records_written_, error));
+    records_written_++;
+    return NANOARROW_OK;
+  }
+
+  // Create one field writer per top-level schema child. Must follow Init().
+  // Schema-parsing and unsupported-type errors are propagated with `error` set.
+  ArrowErrorCode InitFieldWriters(ArrowError* error) {
+    if (schema_ == nullptr || schema_->release == nullptr) {
+      return EINVAL;
+    }
+
+    for (int64_t i = 0; i < schema_->n_children; i++) {
+      struct ArrowSchemaView schema_view;
+      NANOARROW_RETURN_NOT_OK(
+          ArrowSchemaViewInit(&schema_view, schema_->children[i], error));
+      const ArrowType arrow_type = schema_view.type;
+      PostgresCopyFieldWriter* child_writer = nullptr;
+      NANOARROW_RETURN_NOT_OK(MakeCopyFieldWriter(arrow_type, &child_writer, error));
+      root_writer_.AppendChild(std::unique_ptr<PostgresCopyFieldWriter>(child_writer));
+    }
+
+    return NANOARROW_OK;
+  }
+
+ private:
+  PostgresCopyFieldTupleWriter root_writer_;
+  // Null until Init() is called.
+  struct ArrowSchema* schema_ = nullptr;
+  // Value-initialized (zeroed) so the destructor never passes indeterminate
+  // memory to ArrowArrayViewReset() when Init() was skipped or failed early.
+  std::unique_ptr<struct ArrowArrayView> array_view_{new struct ArrowArrayView()};
+  int64_t records_written_ = 0;
+};
+
 }  // namespace adbcpq
diff --git a/c/driver/postgresql/postgres_copy_reader_test.cc b/c/driver/postgresql/postgres_copy_reader_test.cc
index 44ad0601..f389d2a6 100644
--- a/c/driver/postgresql/postgres_copy_reader_test.cc
+++ b/c/driver/postgresql/postgres_copy_reader_test.cc
@@ -15,10 +15,13 @@
 // specific language governing permissions and limitations
 // under the License.
 
+#include <optional>
+
 #include <gtest/gtest.h>
 #include <nanoarrow/nanoarrow.hpp>
 
 #include "postgres_copy_reader.h"
+#include "validation/adbc_validation_util.h"
 
 namespace adbcpq {
 
@@ -52,6 +55,30 @@ class PostgresCopyStreamTester {
   PostgresCopyStreamReader reader_;
 };
 
+// Test helper that drives a PostgresCopyStreamWriter over an entire array.
+class PostgresCopyStreamWriteTester {
+ public:
+  // Bind the writer to the schema/array pair, then build its field writers.
+  ArrowErrorCode Init(struct ArrowSchema* schema, struct ArrowArray* array,
+                      ArrowError* error = nullptr) {
+    NANOARROW_RETURN_NOT_OK(writer_.Init(schema, array));
+    return writer_.InitFieldWriters(error);
+  }
+
+  // Emit the header, then records until the writer stops returning
+  // NANOARROW_OK; that first non-OK status is returned (ENODATA when every
+  // record was written).
+  ArrowErrorCode WriteAll(struct ArrowBuffer* buffer, ArrowError* error = nullptr) {
+    NANOARROW_RETURN_NOT_OK(writer_.WriteHeader(buffer, error));
+
+    for (;;) {
+      const ArrowErrorCode status = writer_.WriteRecord(buffer, error);
+      if (status != NANOARROW_OK) {
+        return status;
+      }
+    }
+  }
+
+ private:
+  PostgresCopyStreamWriter writer_;
+};
+
 // COPY (SELECT CAST("col" AS BOOLEAN) AS "col" FROM (  VALUES (TRUE), (FALSE), (NULL)) AS
 // drvd("col")) TO STDOUT;
 static uint8_t kTestPgCopyBoolean[] = {
@@ -96,6 +123,36 @@ TEST(PostgresCopyUtilsTest, PostgresCopyReadBoolean) {
   ASSERT_FALSE(ArrowBitGet(data_buffer, 2));
 }
 
+TEST(PostgresCopyUtilsTest, PostgresCopyWriteBoolean) {
+  adbc_validation::Handle<struct ArrowSchema> schema;
+  adbc_validation::Handle<struct ArrowArray> array;
+  struct ArrowError na_error;
+  ASSERT_EQ(adbc_validation::MakeSchema(&schema.value, {{"col", NANOARROW_TYPE_BOOL}}),
+            ADBC_STATUS_OK);
+  ASSERT_EQ(adbc_validation::MakeBatch<bool>(&schema.value, &array.value, &na_error,
+                                             {true, false, std::nullopt}),
+            ADBC_STATUS_OK);
+
+  PostgresCopyStreamWriteTester tester;
+  ASSERT_EQ(tester.Init(&schema.value, &array.value), NANOARROW_OK);
+
+  struct ArrowBuffer buffer;
+  ArrowBufferInit(&buffer);
+  ASSERT_EQ(ArrowBufferReserve(&buffer, sizeof(kTestPgCopyBoolean)), NANOARROW_OK);
+  // The writer advances buffer.data as it writes, so remember where the
+  // allocation starts.
+  uint8_t* cursor = buffer.data;
+
+  // EXPECT rather than ASSERT so that buffer.data is always restored and the
+  // buffer released below, even if the write fails.
+  EXPECT_EQ(tester.WriteAll(&buffer, nullptr), ENODATA);
+  buffer.data = cursor;
+
+  // The writer does not emit the end-of-data trailer (it can be transmitted
+  // separately, e.g. via PQputCopyData), so the tail of the expected bytes is
+  // not compared.
+  // NOTE(review): the binary COPY trailer is a 2-byte -1, so skipping 4 bytes
+  // also skips the end of the last (NULL) field's length — confirm intent.
+  constexpr size_t kBytesToCompare = sizeof(kTestPgCopyBoolean) - 4;
+  EXPECT_GE(buffer.size_bytes, static_cast<int64_t>(kBytesToCompare));
+  for (size_t i = 0; i < kBytesToCompare; i++) {
+    EXPECT_EQ(buffer.data[i], kTestPgCopyBoolean[i]);
+  }
+
+  ArrowBufferReset(&buffer);
+}
+
 // COPY (SELECT CAST("col" AS SMALLINT) AS "col" FROM (  VALUES (-123), (-1), (1), (123),
 // (NULL)) AS drvd("col")) TO STDOUT WITH (FORMAT binary);
 static uint8_t kTestPgCopySmallInt[] = {

Reply via email to