This is an automated email from the ASF dual-hosted git repository.
lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git
The following commit(s) were added to refs/heads/main by this push:
new a6bd37a feat(c/driver/postgresql): handle non-SELECT statements (#707)
a6bd37a is described below
commit a6bd37a9f2189b41d6f7a70f10ee72ef6d615aab
Author: David Li <[email protected]>
AuthorDate: Thu May 25 18:00:15 2023 -0400
feat(c/driver/postgresql): handle non-SELECT statements (#707)
Previously, we tried to infer the query schema before COPY by wrapping the
query in "SELECT * FROM (...) LIMIT 0". This broke if the query was (for
example) a CREATE or UPDATE. Instead, use a prepared statement to infer
the schema. If we find that there are no result columns, then execute
the query directly, without the COPY path.
Also, test what happens with "INSERT INTO ... RETURNING" (this works
with COPY).
Fixes #701.
---
c/driver/postgresql/postgresql_test.cc | 58 +++++++++++++
c/driver/postgresql/statement.cc | 99 ++++++++++++-----------
c/vendor/nanoarrow/nanoarrow.hpp | 2 +
python/adbc_driver_postgresql/tests/test_dbapi.py | 15 ++++
4 files changed, 127 insertions(+), 47 deletions(-)
diff --git a/c/driver/postgresql/postgresql_test.cc
b/c/driver/postgresql/postgresql_test.cc
index c3e8dc7..39859e6 100644
--- a/c/driver/postgresql/postgresql_test.cc
+++ b/c/driver/postgresql/postgresql_test.cc
@@ -329,6 +329,64 @@ class PostgresStatementTest : public ::testing::Test,
};
ADBCV_TEST_STATEMENT(PostgresStatementTest)
+TEST_F(PostgresStatementTest, UpdateInExecuteQuery) {
+ ASSERT_THAT(quirks()->DropTable(&connection, "adbc_test", &error),
IsOkStatus(&error));
+
+ ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error),
IsOkStatus(&error));
+
+ {
+ ASSERT_THAT(AdbcStatementSetSqlQuery(
+ &statement,
+ "CREATE TABLE adbc_test (ints INT, id SERIAL PRIMARY
KEY)", &error),
+ IsOkStatus(&error));
+ adbc_validation::StreamReader reader;
+ ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value,
+ &reader.rows_affected, &error),
+ IsOkStatus(&error));
+ ASSERT_EQ(reader.rows_affected, 0);
+ ASSERT_NO_FATAL_FAILURE(reader.GetSchema());
+ ASSERT_NO_FATAL_FAILURE(reader.Next());
+ ASSERT_EQ(reader.array->release, nullptr);
+ }
+
+ {
+ // Use INSERT INTO
+ ASSERT_THAT(AdbcStatementSetSqlQuery(
+ &statement, "INSERT INTO adbc_test (ints) VALUES (1),
(2)", &error),
+ IsOkStatus(&error));
+ adbc_validation::StreamReader reader;
+ ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value,
+ &reader.rows_affected, &error),
+ IsOkStatus(&error));
+ ASSERT_EQ(reader.rows_affected, 0);
+ ASSERT_NO_FATAL_FAILURE(reader.GetSchema());
+ ASSERT_NO_FATAL_FAILURE(reader.Next());
+ ASSERT_EQ(reader.array->release, nullptr);
+ }
+
+ {
+ // Use INSERT INTO ... RETURNING
+ ASSERT_THAT(AdbcStatementSetSqlQuery(
+ &statement,
+ "INSERT INTO adbc_test (ints) VALUES (3), (4) RETURNING
id", &error),
+ IsOkStatus(&error));
+ adbc_validation::StreamReader reader;
+ ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value,
+ &reader.rows_affected, &error),
+ IsOkStatus(&error));
+ ASSERT_EQ(reader.rows_affected, -1);
+ ASSERT_NO_FATAL_FAILURE(reader.GetSchema());
+ ASSERT_NO_FATAL_FAILURE(reader.Next());
+ ASSERT_NE(reader.array->release, nullptr);
+ ASSERT_EQ(reader.array->n_children, 1);
+ ASSERT_EQ(reader.array->length, 2);
+
ASSERT_EQ(reader.array_view->children[0]->buffer_views[1].data.as_int32[0], 3);
+
ASSERT_EQ(reader.array_view->children[0]->buffer_views[1].data.as_int32[1], 4);
+ ASSERT_NO_FATAL_FAILURE(reader.Next());
+ ASSERT_EQ(reader.array->release, nullptr);
+ }
+}
+
struct TypeTestCase {
std::string name;
std::string sql_type;
diff --git a/c/driver/postgresql/statement.cc b/c/driver/postgresql/statement.cc
index 494c736..b91d2fe 100644
--- a/c/driver/postgresql/statement.cc
+++ b/c/driver/postgresql/statement.cc
@@ -239,8 +239,8 @@ struct BindStream {
PGresult* result = PQprepare(conn, /*stmtName=*/"", query.c_str(),
/*nParams=*/bind_schema->n_children,
param_types.data());
if (PQresultStatus(result) != PGRES_COMMAND_OK) {
- SetError(error, "%s%s", "[libpq] Failed to prepare query: ",
PQerrorMessage(conn));
- SetError(error, "%s%s", "[libpq] Query: ", query.c_str());
+ SetError(error, "[libpq] Failed to prepare query: %s\nQuery was:%s",
+ PQerrorMessage(conn), query.c_str());
PQclear(result);
return ADBC_STATUS_IO;
}
@@ -256,10 +256,10 @@ struct BindStream {
Handle<struct ArrowArray> array;
int res = bind->get_next(&bind.value, &array.value);
if (res != 0) {
- // TODO: include errno
- SetError(error, "%s%s",
- "[libpq] Failed to read next batch from stream of bind
parameters: ",
- bind->get_last_error(&bind.value));
+ SetError(error,
+ "[libpq] Failed to read next batch from stream of bind
parameters: "
+ "(%d) %s %s",
+ res, std::strerror(res), bind->get_last_error(&bind.value));
return ADBC_STATUS_IO;
}
if (!array->release) break;
@@ -584,9 +584,8 @@ AdbcStatusCode PostgresStatement::CreateBulkTable(
/*paramLengths=*/nullptr,
/*paramFormats=*/nullptr,
/*resultFormat=*/1 /*(binary)*/);
if (PQresultStatus(result) != PGRES_COMMAND_OK) {
- SetError(error, "%s%s",
- "[libpq] Failed to create table: ",
PQerrorMessage(connection_->conn()));
- SetError(error, "%s%s", "[libpq] Query: ", create.c_str());
+ SetError(error, "[libpq] Failed to create table: %s\nQuery was: %s",
+ PQerrorMessage(connection_->conn()), create.c_str());
PQclear(result);
return ADBC_STATUS_IO;
}
@@ -637,11 +636,8 @@ AdbcStatusCode PostgresStatement::ExecuteQuery(struct
ArrowArrayStream* stream,
// and https://stackoverflow.com/questions/69233792 suggests that
// you can't PREPARE a query containing COPY.
}
- if (!stream) {
- if (!ingest_.target.empty()) {
- return ExecuteUpdateBulk(rows_affected, error);
- }
- return ExecuteUpdateQuery(rows_affected, error);
+ if (!stream && !ingest_.target.empty()) {
+ return ExecuteUpdateBulk(rows_affected, error);
}
if (query_.empty()) {
@@ -649,18 +645,26 @@ AdbcStatusCode PostgresStatement::ExecuteQuery(struct
ArrowArrayStream* stream,
return ADBC_STATUS_INVALID_STATE;
}
- // 1. Execute the query with LIMIT 0 to get the schema
+ // 1. Prepare the query to get the schema
{
// TODO: we should pipeline here and assume this will succeed
- std::string schema_query = "SELECT * FROM (" + query_ + ") AS ignored
LIMIT 0";
- PGresult* result =
- PQexecParams(connection_->conn(), query_.c_str(), /*nParams=*/0,
- /*paramTypes=*/nullptr, /*paramValues=*/nullptr,
- /*paramLengths=*/nullptr, /*paramFormats=*/nullptr,
kPgBinaryFormat);
- if (PQresultStatus(result) != PGRES_TUPLES_OK) {
- SetError(error, "%s%s", "[libpq] Query was: ", schema_query.c_str());
- SetError(error, "%s%s", "[libpq] Failed to execute query: could not
infer schema: ",
- PQerrorMessage(connection_->conn()));
+ PGresult* result = PQprepare(connection_->conn(), /*stmtName=*/"",
query_.c_str(),
+ /*nParams=*/0, nullptr);
+ if (PQresultStatus(result) != PGRES_COMMAND_OK) {
+ SetError(error,
+ "[libpq] Failed to execute query: could not infer schema:
failed to "
+ "prepare query: %s\nQuery was:%s",
+ PQerrorMessage(connection_->conn()), query_.c_str());
+ PQclear(result);
+ return ADBC_STATUS_IO;
+ }
+ PQclear(result);
+ result = PQdescribePrepared(connection_->conn(), /*stmtName=*/"");
+ if (PQresultStatus(result) != PGRES_COMMAND_OK) {
+ SetError(error,
+ "[libpq] Failed to execute query: could not infer schema:
failed to "
+ "describe prepared statement: %s\nQuery was:%s",
+ PQerrorMessage(connection_->conn()), query_.c_str());
PQclear(result);
return ADBC_STATUS_IO;
}
@@ -683,6 +687,20 @@ AdbcStatusCode PostgresStatement::ExecuteQuery(struct
ArrowArrayStream* stream,
return na_res;
}
+ // If the caller did not request a result set or if there are no
+ // inferred output columns (e.g. a CREATE or UPDATE), then don't
+ // use COPY (which would fail anyways)
+ if (!stream || root_type.n_children() == 0) {
+ RAISE_ADBC(ExecuteUpdateQuery(rows_affected, error));
+ if (stream) {
+ struct ArrowSchema schema;
+ std::memset(&schema, 0, sizeof(schema));
+ RAISE_NA(reader_.copy_reader_->GetSchema(&schema));
+ nanoarrow::EmptyArrayStream::MakeUnique(&schema).move(stream);
+ }
+ return ADBC_STATUS_OK;
+ }
+
// This resolves the reader specific to each PostgresType -> ArrowSchema
// conversion. It is unlikely that this will fail given that we have just
// inferred these conversions ourselves.
@@ -701,9 +719,9 @@ AdbcStatusCode PostgresStatement::ExecuteQuery(struct
ArrowArrayStream* stream,
/*paramTypes=*/nullptr, /*paramValues=*/nullptr,
/*paramLengths=*/nullptr, /*paramFormats=*/nullptr,
kPgBinaryFormat);
if (PQresultStatus(reader_.result_) != PGRES_COPY_OUT) {
- SetError(error, "%s%s", "[libpq] Query was: ", copy_query.c_str());
- SetError(error, "%s%s", "[libpq] Failed to execute query: could not
begin COPY: ",
- PQerrorMessage(connection_->conn()));
+ SetError(error,
+ "[libpq] Failed to execute query: could not begin COPY:
%s\nQuery was: %s",
+ PQerrorMessage(connection_->conn()), copy_query.c_str());
ClearResult();
return ADBC_STATUS_IO;
}
@@ -753,27 +771,14 @@ AdbcStatusCode
PostgresStatement::ExecuteUpdateBulk(int64_t* rows_affected,
AdbcStatusCode PostgresStatement::ExecuteUpdateQuery(int64_t* rows_affected,
struct AdbcError* error) {
- if (query_.empty()) {
- SetError(error, "%s", "[libpq] Must SetSqlQuery before ExecuteQuery");
- return ADBC_STATUS_INVALID_STATE;
- }
-
- PGresult* result = nullptr;
-
- if (prepared_) {
- result = PQexecPrepared(connection_->conn(), /*stmtName=*/"",
/*nParams=*/0,
- /*paramValues=*/nullptr, /*paramLengths=*/nullptr,
- /*paramFormats=*/nullptr,
/*resultFormat=*/kPgBinaryFormat);
- } else {
- result = PQexecParams(connection_->conn(), query_.c_str(), /*nParams=*/0,
- /*paramTypes=*/nullptr, /*paramValues=*/nullptr,
- /*paramLengths=*/nullptr, /*paramFormats=*/nullptr,
- /*resultFormat=*/kPgBinaryFormat);
- }
+ // NOTE: must prepare first (used in ExecuteQuery)
+ PGresult* result =
+ PQexecPrepared(connection_->conn(), /*stmtName=*/"", /*nParams=*/0,
+ /*paramValues=*/nullptr, /*paramLengths=*/nullptr,
+ /*paramFormats=*/nullptr,
/*resultFormat=*/kPgBinaryFormat);
if (PQresultStatus(result) != PGRES_COMMAND_OK) {
- SetError(error, "%s%s", "[libpq] Query was: ", query_.c_str());
- SetError(error, "%s%s",
- "[libpq] Failed to execute query: ",
PQerrorMessage(connection_->conn()));
+ SetError(error, "[libpq] Failed to execute query: %s\nQuery was:%s",
+ PQerrorMessage(connection_->conn()), query_.c_str());
PQclear(result);
return ADBC_STATUS_IO;
}
diff --git a/c/vendor/nanoarrow/nanoarrow.hpp b/c/vendor/nanoarrow/nanoarrow.hpp
index b01d2a6..468e911 100644
--- a/c/vendor/nanoarrow/nanoarrow.hpp
+++ b/c/vendor/nanoarrow/nanoarrow.hpp
@@ -250,6 +250,8 @@ class EmptyArrayStream {
static void release_wrapper(struct ArrowArrayStream* stream) {
delete reinterpret_cast<EmptyArrayStream*>(stream->private_data);
+ stream->release = nullptr;
+ stream->private_data = nullptr;
}
};
diff --git a/python/adbc_driver_postgresql/tests/test_dbapi.py
b/python/adbc_driver_postgresql/tests/test_dbapi.py
index 7344f1a..1630108 100644
--- a/python/adbc_driver_postgresql/tests/test_dbapi.py
+++ b/python/adbc_driver_postgresql/tests/test_dbapi.py
@@ -30,3 +30,18 @@ def test_query_trivial(postgres: dbapi.Connection):
with postgres.cursor() as cur:
cur.execute("SELECT 1")
assert cur.fetchone() == (1,)
+
+
+def test_ddl(postgres: dbapi.Connection):
+ with postgres.cursor() as cur:
+ cur.execute("DROP TABLE IF EXISTS test_ddl")
+ assert cur.fetchone() is None
+
+ cur.execute("CREATE TABLE test_ddl (ints INT)")
+ assert cur.fetchone() is None
+
+ cur.execute("INSERT INTO test_ddl VALUES (1) RETURNING ints")
+ assert cur.fetchone() == (1,)
+
+ cur.execute("SELECT * FROM test_ddl")
+ assert cur.fetchone() == (1,)