This is an automated email from the ASF dual-hosted git repository.

lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new 792f3d2d feat(c/driver/postgresql): Use COPY for writes (#1093)
792f3d2d is described below

commit 792f3d2dd66883c97377b7660bf5fcaf86e4152a
Author: William Ayd <[email protected]>
AuthorDate: Fri Oct 13 16:07:58 2023 -0400

    feat(c/driver/postgresql): Use COPY for writes (#1093)
    
    closes https://github.com/apache/arrow-adbc/issues/1037
---
 c/driver/postgresql/postgres_copy_reader.h  |  4 +-
 c/driver/postgresql/postgresql_benchmark.cc | 84 ++++++++++++++++++--------
 c/driver/postgresql/statement.cc            | 92 +++++++++++++++++++++++++----
 c/validation/adbc_validation.cc             |  3 +-
 4 files changed, 146 insertions(+), 37 deletions(-)

diff --git a/c/driver/postgresql/postgres_copy_reader.h b/c/driver/postgresql/postgres_copy_reader.h
index d9a7cdfb..09436351 100644
--- a/c/driver/postgresql/postgres_copy_reader.h
+++ b/c/driver/postgresql/postgres_copy_reader.h
@@ -1406,9 +1406,9 @@ static inline ArrowErrorCode MakeCopyFieldWriter(
       return NANOARROW_OK;
     }
     default:
+      ArrowErrorSet(error, "COPY Writer not implemented for type %d", 
schema_view.type);
       return EINVAL;
   }
-  return NANOARROW_OK;
 }
 
 class PostgresCopyStreamWriter {
@@ -1455,7 +1455,7 @@ class PostgresCopyStreamWriter {
           NANOARROW_OK) {
         return ADBC_STATUS_INTERNAL;
       }
-      PostgresCopyFieldWriter* child_writer;
+      PostgresCopyFieldWriter* child_writer = nullptr;
       NANOARROW_RETURN_NOT_OK(MakeCopyFieldWriter(schema_view, &child_writer, 
error));
       
root_writer_.AppendChild(std::unique_ptr<PostgresCopyFieldWriter>(child_writer));
     }
diff --git a/c/driver/postgresql/postgresql_benchmark.cc b/c/driver/postgresql/postgresql_benchmark.cc
index 85a4ae7c..2f5a050f 100644
--- a/c/driver/postgresql/postgresql_benchmark.cc
+++ b/c/driver/postgresql/postgresql_benchmark.cc
@@ -22,49 +22,67 @@
 #include "adbc.h"
 #include "validation/adbc_validation_util.h"
 
+
 static void BM_PostgresqlExecute(benchmark::State& state) {
   const char* uri = std::getenv("ADBC_POSTGRESQL_TEST_URI");
-  if (!uri) {
+  if (!uri || !strcmp(uri, "")) {
     state.SkipWithError("ADBC_POSTGRESQL_TEST_URI not set!");
+    return;
   }
   adbc_validation::Handle<struct AdbcDatabase> database;
   struct AdbcError error;
 
   if (AdbcDatabaseNew(&database.value, &error) != ADBC_STATUS_OK) {
-    state.SkipWithError("AdbcDatabaseNew call failed");
+    state.SkipWithError(error.message);
+    error.release(&error);
+    return;
   }
 
   if (AdbcDatabaseSetOption(&database.value, "uri", uri, &error) != 
ADBC_STATUS_OK) {
-    state.SkipWithError("Could not set database uri option");
+    state.SkipWithError(error.message);
+    error.release(&error);
+    return;
   }
 
   if (AdbcDatabaseInit(&database.value, &error) != ADBC_STATUS_OK) {
-    state.SkipWithError("AdbcDatabaseInit failed");
+state.SkipWithError(error.message);
+    error.release(&error);
+    return;
   }
 
   adbc_validation::Handle<struct AdbcConnection> connection;
   if (AdbcConnectionNew(&connection.value, &error) != ADBC_STATUS_OK) {
-    state.SkipWithError("Could not create connection object");
+    state.SkipWithError(error.message);
+    error.release(&error);
+    return;
   }
 
   if (AdbcConnectionInit(&connection.value, &database.value, &error) != 
ADBC_STATUS_OK) {
-    state.SkipWithError("Could not connect to database");
+    state.SkipWithError(error.message);
+    error.release(&error);
+    return;
   }
 
   adbc_validation::Handle<struct AdbcStatement> statement;
   if (AdbcStatementNew(&connection.value, &statement.value, &error) != 
ADBC_STATUS_OK) {
-    state.SkipWithError("Could not create statement object");
+    state.SkipWithError(error.message);
+    error.release(&error);
+    return;
   }
 
   const char* drop_query = "DROP TABLE IF EXISTS 
adbc_postgresql_ingest_benchmark";
   if (AdbcStatementSetSqlQuery(&statement.value, drop_query, &error)
       != ADBC_STATUS_OK) {
-    state.SkipWithError("Could not set DROP TABLE SQL query");
+    state.SkipWithError(error.message);
+    error.release(&error);
+    return;
   }
 
   if (AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr, &error)
       != ADBC_STATUS_OK) {
-    state.SkipWithError("Could not execute DROP TABLE SQL query");
+    state.SkipWithError(error.message);
+    error.release(&error);
+    return;
   }
 
   adbc_validation::Handle<struct ArrowSchema> schema;
@@ -79,15 +97,21 @@ static void BM_PostgresqlExecute(benchmark::State& state) {
         {"floats", NANOARROW_TYPE_FLOAT},
         {"doubles", NANOARROW_TYPE_DOUBLE},
       }) != ADBC_STATUS_OK) {
-    state.SkipWithError("Could not create benchmark schema");
+    state.SkipWithError(error.message);
+    error.release(&error);
+    return;
   }
 
   if (ArrowArrayInitFromSchema(&array.value, &schema.value, &na_error) != 
NANOARROW_OK) {
-    state.SkipWithError("Could not init array from schema");
+    state.SkipWithError(error.message);
+    error.release(&error);
+    return;
   }
 
   if (ArrowArrayStartAppending(&array.value) != NANOARROW_OK) {
-    state.SkipWithError("Could not start appending to array");
+    state.SkipWithError(error.message);
+    error.release(&error);
+    return;
   }
 
   const size_t n_zeros = 1000;
@@ -118,7 +142,9 @@ static void BM_PostgresqlExecute(benchmark::State& state) {
   array.value.length = n_zeros + n_ones;
 
   if (ArrowArrayFinishBuildingDefault(&array.value, &na_error) != 
NANOARROW_OK) {
-    state.SkipWithError("Could not finish array");
+    state.SkipWithError(error.message);
+    error.release(&error);
+    return;
   }
 
   const char* create_query =
@@ -127,50 +153,62 @@ static void BM_PostgresqlExecute(benchmark::State& state) {
 
   if (AdbcStatementSetSqlQuery(&statement.value, create_query, &error)
       != ADBC_STATUS_OK) {
-    state.SkipWithError("Could not set CREATE TABLE SQL query");
+    state.SkipWithError(error.message);
+    error.release(&error);
+    return;
   }
 
   if (AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr, &error)
       != ADBC_STATUS_OK) {
-    state.SkipWithError("Could not execute CREATE TABLE SQL query");
+    state.SkipWithError(error.message);
+    error.release(&error);
+    return;
   }
 
   adbc_validation::Handle<struct AdbcStatement> insert_stmt;
   if (AdbcStatementNew(&connection.value, &insert_stmt.value, &error) != 
ADBC_STATUS_OK) {
-    state.SkipWithError("Could not create INSERT statement object");
+    state.SkipWithError(error.message);
+    error.release(&error);
+    return;
   }
 
   if (AdbcStatementSetOption(&insert_stmt.value,
                              ADBC_INGEST_OPTION_TARGET_TABLE,
                              "adbc_postgresql_ingest_benchmark",
                              &error) != ADBC_STATUS_OK) {
-    state.SkipWithError("Could not set bulk_ingest statement option");
+    state.SkipWithError(error.message);
+    error.release(&error);
+    return;
   }
 
   if (AdbcStatementSetOption(&insert_stmt.value,
                              ADBC_INGEST_OPTION_MODE,
                              ADBC_INGEST_OPTION_MODE_APPEND,
                              &error) != ADBC_STATUS_OK) {
-    state.SkipWithError("Could not set bulk_ingest append option");
+    state.SkipWithError(error.message);
+    error.release(&error);
+    return;
   }
 
   for (auto _ : state) {
-    // Bind release the array, so if this actually loops you will get errors
-    // memory leaks
     AdbcStatementBind(&insert_stmt.value, &array.value, &schema.value, &error);
     AdbcStatementExecuteQuery(&insert_stmt.value, nullptr, nullptr, &error);
   }
 
   if (AdbcStatementSetSqlQuery(&statement.value, drop_query, &error)
       != ADBC_STATUS_OK) {
-    state.SkipWithError("Could not set DROP TABLE SQL query");
+    state.SkipWithError(error.message);
+    error.release(&error);
+    return;
   }
 
   if (AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr, &error)
       != ADBC_STATUS_OK) {
-    state.SkipWithError("Could not execute DROP TABLE SQL query");
+    state.SkipWithError(error.message);
+    error.release(&error);
+    return;
   }
 }
 
-BENCHMARK(BM_PostgresqlExecute);
+BENCHMARK(BM_PostgresqlExecute)->Iterations(1);
 BENCHMARK_MAIN();
diff --git a/c/driver/postgresql/statement.cc b/c/driver/postgresql/statement.cc
index 1f08fce1..e5691b52 100644
--- a/c/driver/postgresql/statement.cc
+++ b/c/driver/postgresql/statement.cc
@@ -523,6 +523,77 @@ struct BindStream {
     }
     return ADBC_STATUS_OK;
   }
+
+  // Streams the batches of the bound parameter stream (`bind`) to the server as
+  // PostgreSQL binary COPY data. Presumably expects `conn` to already be in
+  // COPY IN state; the visible caller (ExecuteUpdateBulk) issues
+  // "COPY <table> FROM STDIN WITH (FORMAT binary)" before calling this.
+  //
+  // conn:          libpq connection on which the COPY was started.
+  // rows_affected: optional; reset to 0, then incremented by each batch's length.
+  // error:         receives a message on failure.
+  // Returns ADBC_STATUS_OK once the stream is exhausted; ADBC_STATUS_IO on stream
+  // read, COPY encoding, or libpq send failures; the CHECK_NA(INTERNAL, ...) code
+  // on nanoarrow/writer setup failures; otherwise the code SetError derives from
+  // the failed PGresult.
+  AdbcStatusCode ExecuteCopy(PGconn* conn, int64_t* rows_affected,
+                             struct AdbcError* error) {
+    if (rows_affected) *rows_affected = 0;
+    PGresult* result = nullptr;
+
+    while (true) {
+      Handle<struct ArrowArray> array;
+      int res = bind->get_next(&bind.value, &array.value);
+      if (res != 0) {
+        SetError(error,
+                 "[libpq] Failed to read next batch from stream of bind 
parameters: "
+                 "(%d) %s %s",
+                 res, std::strerror(res), bind->get_last_error(&bind.value));
+        return ADBC_STATUS_IO;
+      }
+      // Per the Arrow C stream interface, a released array marks end of stream.
+      if (!array->release) break;
+
+      // NOTE(review): array_view is initialised from the schema and bound to the
+      // array, but it is never read afterwards. It seems to serve only as a
+      // schema/array consistency check. Confirm, or remove it.
+      Handle<struct ArrowArrayView> array_view;
+      CHECK_NA(
+          INTERNAL,
+          ArrowArrayViewInitFromSchema(&array_view.value, &bind_schema.value, 
nullptr),
+          error);
+      CHECK_NA(INTERNAL, ArrowArrayViewSetArray(&array_view.value, 
&array.value, nullptr),
+               error);
+
+      // NOTE(review): a new writer is built for every batch, and it emits a fresh
+      // COPY header each time. Binary COPY expects a single header per COPY.
+      // Confirm how multi-batch streams are meant to work.
+      PostgresCopyStreamWriter writer;
+      CHECK_NA(INTERNAL, writer.Init(&bind_schema.value, &array.value), error);
+      CHECK_NA(INTERNAL, writer.InitFieldWriters(nullptr), error);
+
+      // build writer buffer
+      CHECK_NA(INTERNAL, writer.WriteHeader(nullptr), error);
+      int write_result;
+      do {
+        write_result = writer.WriteRecord(nullptr);
+      } while (write_result == NANOARROW_OK);
+
+      // check if not ENODATA at exit
+      // NOTE(review): this failure comes from the COPY writer, not from libpq,
+      // so PQerrorMessage(conn) is likely empty or unrelated here.
+      if (write_result != ENODATA) {
+        SetError(error, "Error occurred writing COPY data: %s", 
PQerrorMessage(conn));
+        return ADBC_STATUS_IO;
+      }
+
+      ArrowBuffer buffer = writer.WriteBuffer();
+      if (PQputCopyData(conn, reinterpret_cast<char*>(buffer.data),
+                        buffer.size_bytes) <= 0) {
+        SetError(error, "Error writing tuple field data: %s", 
PQerrorMessage(conn));
+        return ADBC_STATUS_IO;
+      }
+
+      // NOTE(review): PQputCopyEnd and PQgetResult run inside the per-batch loop.
+      // Ending the COPY after the first batch means a second batch's
+      // PQputCopyData is sent outside COPY IN state. These calls likely belong
+      // after the loop, once all batches have been sent.
+      if (PQputCopyEnd(conn, NULL) <= 0) {
+        SetError(error, "Error message returned by PQputCopyEnd: %s",
+                 PQerrorMessage(conn));
+        return ADBC_STATUS_IO;
+      }
+
+      // NOTE(review): libpq documents calling PQgetResult repeatedly until it
+      // returns NULL. Only one result is consumed here.
+      result = PQgetResult(conn);
+      ExecStatusType pg_status = PQresultStatus(result);
+      if (pg_status != PGRES_COMMAND_OK) {
+        AdbcStatusCode code =
+            SetError(error, result, "[libpq] Failed to execute COPY statement: 
%s %s",
+                     PQresStatus(pg_status), PQerrorMessage(conn));
+        PQclear(result);
+        return code;
+      }
+
+      PQclear(result);
+      if (rows_affected) *rows_affected += array->length;
+    }
+    return ADBC_STATUS_OK;
+  }
 };
 }  // namespace
 
@@ -1140,19 +1211,18 @@ AdbcStatusCode PostgresStatement::ExecuteUpdateBulk(int64_t* rows_affected,
       error));
   RAISE_ADBC(bind_stream.SetParamTypes(*type_resolver_, error));
 
-  std::string insert = "INSERT INTO ";
-  insert += escaped_table;
-  insert += " VALUES (";
-  for (size_t i = 0; i < bind_stream.bind_schema_fields.size(); i++) {
-    if (i > 0) insert += ", ";
-    insert += "$";
-    insert += std::to_string(i + 1);
+  std::string query = "COPY " + escaped_table + " FROM STDIN WITH (FORMAT 
binary)";
+  PGresult* result = PQexec(connection_->conn(), query.c_str());
+  if (PQresultStatus(result) != PGRES_COPY_IN) {
+    AdbcStatusCode code =
+        SetError(error, result, "[libpq] COPY query failed: %s\nQuery was:%s",
+                 PQerrorMessage(connection_->conn()), query.c_str());
+    PQclear(result);
+    return code;
   }
-  insert += ")";
 
-  RAISE_ADBC(
-      bind_stream.Prepare(connection_->conn(), insert, error, 
connection_->autocommit()));
-  RAISE_ADBC(bind_stream.Execute(connection_->conn(), rows_affected, error));
+  PQclear(result);
+  RAISE_ADBC(bind_stream.ExecuteCopy(connection_->conn(), rows_affected, 
error));
   return ADBC_STATUS_OK;
 }
 
diff --git a/c/validation/adbc_validation.cc b/c/validation/adbc_validation.cc
index 2afa0caf..c2f32ff5 100644
--- a/c/validation/adbc_validation.cc
+++ b/c/validation/adbc_validation.cc
@@ -2112,7 +2112,8 @@ void StatementTest::TestSqlIngestErrors() {
                                          {"coltwo", NANOARROW_TYPE_INT64}}),
               IsOkErrno());
   ASSERT_THAT(
-      (MakeBatch<int64_t, int64_t>(&schema.value, &array.value, &na_error, {}, 
{})),
+      (MakeBatch<int64_t, int64_t>(&schema.value, &array.value, &na_error,
+                                   {-42}, {-42})),
       IsOkErrno(&na_error));
 
   ASSERT_THAT(AdbcStatementBind(&statement, &array.value, &schema.value, 
&error),

Reply via email to