This is an automated email from the ASF dual-hosted git repository.
lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git
The following commit(s) were added to refs/heads/main by this push:
new da33e6c7 fix(c/driver/postgresql): Fix segfault associated with uninitialized copy_reader_ (#964)
da33e6c7 is described below
commit da33e6c7f0fb3086b0b8d864f4ad276df43a13d4
Author: Solomon Choe <[email protected]>
AuthorDate: Thu Aug 10 11:56:30 2023 -0700
fix(c/driver/postgresql): Fix segfault associated with uninitialized copy_reader_ (#964)
Fixes #958.
---
c/driver/postgresql/statement.cc | 39 ++++++++---------------
c/driver/postgresql/statement.h | 5 +--
python/adbc_driver_postgresql/tests/test_dbapi.py | 28 ++++++++++++++++
3 files changed, 44 insertions(+), 28 deletions(-)
diff --git a/c/driver/postgresql/statement.cc b/c/driver/postgresql/statement.cc
index d8c474bd..0b8f1fc0 100644
--- a/c/driver/postgresql/statement.cc
+++ b/c/driver/postgresql/statement.cc
@@ -18,6 +18,7 @@
#include "statement.h"
#include <array>
+#include <cassert>
#include <cerrno>
#include <cinttypes>
#include <cstring>
@@ -511,6 +512,8 @@ struct BindStream {
} // namespace
int TupleReader::GetSchema(struct ArrowSchema* out) {
+ assert(copy_reader_ != nullptr);
+
int na_res = copy_reader_->GetSchema(out);
if (out->release == nullptr) {
StringBuilderAppend(&error_builder_,
@@ -525,8 +528,6 @@ int TupleReader::GetSchema(struct ArrowSchema* out) {
}
int TupleReader::InitQueryAndFetchFirst(struct ArrowError* error) {
- ResetQuery();
-
// Fetch + parse the header
int get_copy_res = PQgetCopyData(conn_, &pgbuf_, /*async=*/0);
data_.size_bytes = get_copy_res;
@@ -601,27 +602,8 @@ int TupleReader::BuildOutput(struct ArrowArray* out,
struct ArrowError* error) {
return NANOARROW_OK;
}
-void TupleReader::ResetQuery() {
- // Clear result
- if (result_) {
- PQclear(result_);
- result_ = nullptr;
- }
-
- // Reset result buffer
- if (pgbuf_ != nullptr) {
- PQfreemem(pgbuf_);
- pgbuf_ = nullptr;
- }
-
- // Clear the error builder
- error_builder_.size = 0;
-
- row_id_ = -1;
-}
-
int TupleReader::GetNext(struct ArrowArray* out) {
- if (!copy_reader_) {
+ if (is_finished_) {
out->release = nullptr;
return 0;
}
@@ -649,15 +631,14 @@ int TupleReader::GetNext(struct ArrowArray* out) {
return na_res;
}
+ is_finished_ = true;
+
// Finish the result properly and return the last result. Note that
BuildOutput() may
// set tmp.release = nullptr if there were zero rows in the copy reader (can
// occur in an overflow scenario).
struct ArrowArray tmp;
NANOARROW_RETURN_NOT_OK(BuildOutput(&tmp, &error));
- // Clear the copy reader to mark this reader as finished
- copy_reader_.reset();
-
// Check the server-side response
result_ = PQgetResult(conn_);
const int pq_status = PQresultStatus(result_);
@@ -672,7 +653,6 @@ int TupleReader::GetNext(struct ArrowArray* out) {
return EIO;
}
- ResetQuery();
ArrowArrayMove(&tmp, out);
return NANOARROW_OK;
}
@@ -689,6 +669,13 @@ void TupleReader::Release() {
PQfreemem(pgbuf_);
pgbuf_ = nullptr;
}
+
+ if (copy_reader_) {
+ copy_reader_.reset();
+ }
+
+ is_finished_ = false;
+ row_id_ = -1;
}
void TupleReader::ExportTo(struct ArrowArrayStream* stream) {
diff --git a/c/driver/postgresql/statement.h b/c/driver/postgresql/statement.h
index 62af2457..0326e80e 100644
--- a/c/driver/postgresql/statement.h
+++ b/c/driver/postgresql/statement.h
@@ -46,7 +46,8 @@ class TupleReader final {
pgbuf_(nullptr),
copy_reader_(nullptr),
row_id_(-1),
- batch_size_hint_bytes_(16777216) {
+ batch_size_hint_bytes_(16777216),
+ is_finished_(false) {
StringBuilderInit(&error_builder_, 0);
data_.data.as_char = nullptr;
data_.size_bytes = 0;
@@ -70,7 +71,6 @@ class TupleReader final {
int InitQueryAndFetchFirst(struct ArrowError* error);
int AppendRowAndFetchNext(struct ArrowError* error);
int BuildOutput(struct ArrowArray* out, struct ArrowError* error);
- void ResetQuery();
static int GetSchemaTrampoline(struct ArrowArrayStream* self, struct
ArrowSchema* out);
static int GetNextTrampoline(struct ArrowArrayStream* self, struct
ArrowArray* out);
@@ -85,6 +85,7 @@ class TupleReader final {
std::unique_ptr<PostgresCopyStreamReader> copy_reader_;
int64_t row_id_;
int64_t batch_size_hint_bytes_;
+ bool is_finished_;
};
class PostgresStatement {
diff --git a/python/adbc_driver_postgresql/tests/test_dbapi.py b/python/adbc_driver_postgresql/tests/test_dbapi.py
index c50cad1e..80339d60 100644
--- a/python/adbc_driver_postgresql/tests/test_dbapi.py
+++ b/python/adbc_driver_postgresql/tests/test_dbapi.py
@@ -78,3 +78,31 @@ def test_ddl(postgres: dbapi.Connection):
cur.execute("SELECT * FROM test_ddl")
assert cur.fetchone() == (1,)
+
+
+def test_crash(postgres: dbapi.Connection) -> None:
+ with postgres.cursor() as cur:
+ cur.execute("SELECT 1")
+ assert cur.fetchone() == (1,)
+
+
+def test_reuse(postgres: dbapi.Connection) -> None:
+ with postgres.cursor() as cur:
+ cur.execute("DROP TABLE IF EXISTS test_batch_size")
+ cur.execute("CREATE TABLE test_batch_size (ints INT)")
+ cur.execute(
+ """
+ INSERT INTO test_batch_size (ints)
+ SELECT generated :: INT
+ FROM GENERATE_SERIES(1, 65536) temp(generated)
+ """
+ )
+
+ cur.execute("SELECT * FROM test_batch_size ORDER BY ints ASC")
+ assert cur.fetchone() == (1,)
+
+ cur.execute("SELECT 1")
+ assert cur.fetchone() == (1,)
+
+ cur.execute("SELECT 2")
+ assert cur.fetchone() == (2,)