This is an automated email from the ASF dual-hosted git repository.

lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new a6bd37a  feat(c/driver/postgresql): handle non-SELECT statements (#707)
a6bd37a is described below

commit a6bd37a9f2189b41d6f7a70f10ee72ef6d615aab
Author: David Li <[email protected]>
AuthorDate: Thu May 25 18:00:15 2023 -0400

    feat(c/driver/postgresql): handle non-SELECT statements (#707)
    
    Before we tried to infer the query schema before COPY by wrapping it in
    a "SELECT * FROM (...) LIMIT 0". This broke if the query was (for
    example) a CREATE or UPDATE. Instead, use a prepared statement to infer
    instead. If we find that there are no result columns, then execute
    without the COPY path.
    
    Also, test what happens with "INSERT INTO ... RETURNING" (this works
    with COPY).
    
    Fixes #701.
---
 c/driver/postgresql/postgresql_test.cc            | 58 +++++++++++++
 c/driver/postgresql/statement.cc                  | 99 ++++++++++++-----------
 c/vendor/nanoarrow/nanoarrow.hpp                  |  2 +
 python/adbc_driver_postgresql/tests/test_dbapi.py | 15 ++++
 4 files changed, 127 insertions(+), 47 deletions(-)

diff --git a/c/driver/postgresql/postgresql_test.cc b/c/driver/postgresql/postgresql_test.cc
index c3e8dc7..39859e6 100644
--- a/c/driver/postgresql/postgresql_test.cc
+++ b/c/driver/postgresql/postgresql_test.cc
@@ -329,6 +329,64 @@ class PostgresStatementTest : public ::testing::Test,
 };
 ADBCV_TEST_STATEMENT(PostgresStatementTest)
 
+TEST_F(PostgresStatementTest, UpdateInExecuteQuery) {
+  ASSERT_THAT(quirks()->DropTable(&connection, "adbc_test", &error), IsOkStatus(&error));
+
+  ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error));
+
+  {
+    ASSERT_THAT(AdbcStatementSetSqlQuery(
+                    &statement,
+                    "CREATE TABLE adbc_test (ints INT, id SERIAL PRIMARY KEY)", &error),
+                IsOkStatus(&error));
+    adbc_validation::StreamReader reader;
+    ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value,
+                                          &reader.rows_affected, &error),
+                IsOkStatus(&error));
+    ASSERT_EQ(reader.rows_affected, 0);
+    ASSERT_NO_FATAL_FAILURE(reader.GetSchema());
+    ASSERT_NO_FATAL_FAILURE(reader.Next());
+    ASSERT_EQ(reader.array->release, nullptr);
+  }
+
+  {
+    // Use INSERT INTO
+    ASSERT_THAT(AdbcStatementSetSqlQuery(
+                    &statement, "INSERT INTO adbc_test (ints) VALUES (1), (2)", &error),
+                IsOkStatus(&error));
+    adbc_validation::StreamReader reader;
+    ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value,
+                                          &reader.rows_affected, &error),
+                IsOkStatus(&error));
+    ASSERT_EQ(reader.rows_affected, 0);
+    ASSERT_NO_FATAL_FAILURE(reader.GetSchema());
+    ASSERT_NO_FATAL_FAILURE(reader.Next());
+    ASSERT_EQ(reader.array->release, nullptr);
+  }
+
+  {
+    // Use INSERT INTO ... RETURNING
+    ASSERT_THAT(AdbcStatementSetSqlQuery(
+                    &statement,
+                    "INSERT INTO adbc_test (ints) VALUES (3), (4) RETURNING id", &error),
+                IsOkStatus(&error));
+    adbc_validation::StreamReader reader;
+    ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value,
+                                          &reader.rows_affected, &error),
+                IsOkStatus(&error));
+    ASSERT_EQ(reader.rows_affected, -1);
+    ASSERT_NO_FATAL_FAILURE(reader.GetSchema());
+    ASSERT_NO_FATAL_FAILURE(reader.Next());
+    ASSERT_NE(reader.array->release, nullptr);
+    ASSERT_EQ(reader.array->n_children, 1);
+    ASSERT_EQ(reader.array->length, 2);
+    ASSERT_EQ(reader.array_view->children[0]->buffer_views[1].data.as_int32[0], 3);
+    ASSERT_EQ(reader.array_view->children[0]->buffer_views[1].data.as_int32[1], 4);
+    ASSERT_NO_FATAL_FAILURE(reader.Next());
+    ASSERT_EQ(reader.array->release, nullptr);
+  }
+}
+
 struct TypeTestCase {
   std::string name;
   std::string sql_type;
diff --git a/c/driver/postgresql/statement.cc b/c/driver/postgresql/statement.cc
index 494c736..b91d2fe 100644
--- a/c/driver/postgresql/statement.cc
+++ b/c/driver/postgresql/statement.cc
@@ -239,8 +239,8 @@ struct BindStream {
     PGresult* result = PQprepare(conn, /*stmtName=*/"", query.c_str(),
                                  /*nParams=*/bind_schema->n_children, param_types.data());
     if (PQresultStatus(result) != PGRES_COMMAND_OK) {
-      SetError(error, "%s%s", "[libpq] Failed to prepare query: ", PQerrorMessage(conn));
-      SetError(error, "%s%s", "[libpq] Query: ", query.c_str());
+      SetError(error, "[libpq] Failed to prepare query: %s\nQuery was:%s",
+               PQerrorMessage(conn), query.c_str());
       PQclear(result);
       return ADBC_STATUS_IO;
     }
@@ -256,10 +256,10 @@
       Handle<struct ArrowArray> array;
       int res = bind->get_next(&bind.value, &array.value);
       if (res != 0) {
-        // TODO: include errno
-        SetError(error, "%s%s",
-                 "[libpq] Failed to read next batch from stream of bind parameters: ",
-                 bind->get_last_error(&bind.value));
+        SetError(error,
+                 "[libpq] Failed to read next batch from stream of bind parameters: "
+                 "(%d) %s %s",
+                 res, std::strerror(res), bind->get_last_error(&bind.value));
         return ADBC_STATUS_IO;
       }
       if (!array->release) break;
@@ -584,9 +584,8 @@ AdbcStatusCode PostgresStatement::CreateBulkTable(
                                   /*paramLengths=*/nullptr, /*paramFormats=*/nullptr,
                                   /*resultFormat=*/1 /*(binary)*/);
   if (PQresultStatus(result) != PGRES_COMMAND_OK) {
-    SetError(error, "%s%s",
-             "[libpq] Failed to create table: ", PQerrorMessage(connection_->conn()));
-    SetError(error, "%s%s", "[libpq] Query: ", create.c_str());
+    SetError(error, "[libpq] Failed to create table: %s\nQuery was: %s",
+             PQerrorMessage(connection_->conn()), create.c_str());
     PQclear(result);
     return ADBC_STATUS_IO;
   }
@@ -637,11 +636,8 @@ AdbcStatusCode PostgresStatement::ExecuteQuery(struct ArrowArrayStream* stream,
     // and https://stackoverflow.com/questions/69233792 suggests that
     // you can't PREPARE a query containing COPY.
   }
-  if (!stream) {
-    if (!ingest_.target.empty()) {
-      return ExecuteUpdateBulk(rows_affected, error);
-    }
-    return ExecuteUpdateQuery(rows_affected, error);
+  if (!stream && !ingest_.target.empty()) {
+    return ExecuteUpdateBulk(rows_affected, error);
   }
 
   if (query_.empty()) {
@@ -649,18 +645,26 @@ AdbcStatusCode PostgresStatement::ExecuteQuery(struct ArrowArrayStream* stream,
     return ADBC_STATUS_INVALID_STATE;
   }
 
-  // 1. Execute the query with LIMIT 0 to get the schema
+  // 1. Prepare the query to get the schema
   {
     // TODO: we should pipeline here and assume this will succeed
-    std::string schema_query = "SELECT * FROM (" + query_ + ") AS ignored LIMIT 0";
-    PGresult* result =
-        PQexecParams(connection_->conn(), query_.c_str(), /*nParams=*/0,
-                     /*paramTypes=*/nullptr, /*paramValues=*/nullptr,
-                     /*paramLengths=*/nullptr, /*paramFormats=*/nullptr, kPgBinaryFormat);
-    if (PQresultStatus(result) != PGRES_TUPLES_OK) {
-      SetError(error, "%s%s", "[libpq] Query was: ", schema_query.c_str());
-      SetError(error, "%s%s", "[libpq] Failed to execute query: could not infer schema: ",
-               PQerrorMessage(connection_->conn()));
+    PGresult* result = PQprepare(connection_->conn(), /*stmtName=*/"", query_.c_str(),
+                                 /*nParams=*/0, nullptr);
+    if (PQresultStatus(result) != PGRES_COMMAND_OK) {
+      SetError(error,
+               "[libpq] Failed to execute query: could not infer schema: failed to "
+               "prepare query: %s\nQuery was:%s",
+               PQerrorMessage(connection_->conn()), query_.c_str());
+      PQclear(result);
+      return ADBC_STATUS_IO;
+    }
+    PQclear(result);
+    result = PQdescribePrepared(connection_->conn(), /*stmtName=*/"");
+    if (PQresultStatus(result) != PGRES_COMMAND_OK) {
+      SetError(error,
+               "[libpq] Failed to execute query: could not infer schema: failed to "
+               "describe prepared statement: %s\nQuery was:%s",
+               PQerrorMessage(connection_->conn()), query_.c_str());
       PQclear(result);
       return ADBC_STATUS_IO;
     }
@@ -683,6 +687,20 @@ AdbcStatusCode PostgresStatement::ExecuteQuery(struct ArrowArrayStream* stream,
       return na_res;
     }
 
+    // If the caller did not request a result set or if there are no
+    // inferred output columns (e.g. a CREATE or UPDATE), then don't
+    // use COPY (which would fail anyways)
+    if (!stream || root_type.n_children() == 0) {
+      RAISE_ADBC(ExecuteUpdateQuery(rows_affected, error));
+      if (stream) {
+        struct ArrowSchema schema;
+        std::memset(&schema, 0, sizeof(schema));
+        RAISE_NA(reader_.copy_reader_->GetSchema(&schema));
+        nanoarrow::EmptyArrayStream::MakeUnique(&schema).move(stream);
+      }
+      return ADBC_STATUS_OK;
+    }
+
     // This resolves the reader specific to each PostgresType -> ArrowSchema
     // conversion. It is unlikely that this will fail given that we have just
     // inferred these conversions ourselves.
@@ -701,9 +719,9 @@ AdbcStatusCode PostgresStatement::ExecuteQuery(struct ArrowArrayStream* stream,
                      /*paramTypes=*/nullptr, /*paramValues=*/nullptr,
                      /*paramLengths=*/nullptr, /*paramFormats=*/nullptr, kPgBinaryFormat);
     if (PQresultStatus(reader_.result_) != PGRES_COPY_OUT) {
-      SetError(error, "%s%s", "[libpq] Query was: ", copy_query.c_str());
-      SetError(error, "%s%s", "[libpq] Failed to execute query: could not begin COPY: ",
-               PQerrorMessage(connection_->conn()));
+      SetError(error,
+               "[libpq] Failed to execute query: could not begin COPY: %s\nQuery was: %s",
+               PQerrorMessage(connection_->conn()), copy_query.c_str());
       ClearResult();
       return ADBC_STATUS_IO;
     }
@@ -753,27 +771,14 @@ AdbcStatusCode PostgresStatement::ExecuteUpdateBulk(int64_t* rows_affected,
 
 AdbcStatusCode PostgresStatement::ExecuteUpdateQuery(int64_t* rows_affected,
                                                      struct AdbcError* error) {
-  if (query_.empty()) {
-    SetError(error, "%s", "[libpq] Must SetSqlQuery before ExecuteQuery");
-    return ADBC_STATUS_INVALID_STATE;
-  }
-
-  PGresult* result = nullptr;
-
-  if (prepared_) {
-    result = PQexecPrepared(connection_->conn(), /*stmtName=*/"", /*nParams=*/0,
-                            /*paramValues=*/nullptr, /*paramLengths=*/nullptr,
-                            /*paramFormats=*/nullptr, /*resultFormat=*/kPgBinaryFormat);
-  } else {
-    result = PQexecParams(connection_->conn(), query_.c_str(), /*nParams=*/0,
-                          /*paramTypes=*/nullptr, /*paramValues=*/nullptr,
-                          /*paramLengths=*/nullptr, /*paramFormats=*/nullptr,
-                          /*resultFormat=*/kPgBinaryFormat);
-  }
+  // NOTE: must prepare first (used in ExecuteQuery)
+  PGresult* result =
+      PQexecPrepared(connection_->conn(), /*stmtName=*/"", /*nParams=*/0,
+                     /*paramValues=*/nullptr, /*paramLengths=*/nullptr,
+                     /*paramFormats=*/nullptr, /*resultFormat=*/kPgBinaryFormat);
   if (PQresultStatus(result) != PGRES_COMMAND_OK) {
-    SetError(error, "%s%s", "[libpq] Query was: ", query_.c_str());
-    SetError(error, "%s%s",
-             "[libpq] Failed to execute query: ", PQerrorMessage(connection_->conn()));
+    SetError(error, "[libpq] Failed to execute query: %s\nQuery was:%s",
+             PQerrorMessage(connection_->conn()), query_.c_str());
     PQclear(result);
     return ADBC_STATUS_IO;
   }
diff --git a/c/vendor/nanoarrow/nanoarrow.hpp b/c/vendor/nanoarrow/nanoarrow.hpp
index b01d2a6..468e911 100644
--- a/c/vendor/nanoarrow/nanoarrow.hpp
+++ b/c/vendor/nanoarrow/nanoarrow.hpp
@@ -250,6 +250,8 @@ class EmptyArrayStream {
 
   static void release_wrapper(struct ArrowArrayStream* stream) {
     delete reinterpret_cast<EmptyArrayStream*>(stream->private_data);
+    stream->release = nullptr;
+    stream->private_data = nullptr;
   }
 };
 
diff --git a/python/adbc_driver_postgresql/tests/test_dbapi.py b/python/adbc_driver_postgresql/tests/test_dbapi.py
index 7344f1a..1630108 100644
--- a/python/adbc_driver_postgresql/tests/test_dbapi.py
+++ b/python/adbc_driver_postgresql/tests/test_dbapi.py
@@ -30,3 +30,18 @@ def test_query_trivial(postgres: dbapi.Connection):
     with postgres.cursor() as cur:
         cur.execute("SELECT 1")
         assert cur.fetchone() == (1,)
+
+
+def test_ddl(postgres: dbapi.Connection):
+    with postgres.cursor() as cur:
+        cur.execute("DROP TABLE IF EXISTS test_ddl")
+        assert cur.fetchone() is None
+
+        cur.execute("CREATE TABLE test_ddl (ints INT)")
+        assert cur.fetchone() is None
+
+        cur.execute("INSERT INTO test_ddl VALUES (1) RETURNING ints")
+        assert cur.fetchone() == (1,)
+
+        cur.execute("SELECT * FROM test_ddl")
+        assert cur.fetchone() == (1,)

Reply via email to