This is an automated email from the ASF dual-hosted git repository.

lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new da33e6c7 fix(c/driver/postgresql): Fix segfault associated with uninitialized copy_reader_ (#964)
da33e6c7 is described below

commit da33e6c7f0fb3086b0b8d864f4ad276df43a13d4
Author: Solomon Choe <[email protected]>
AuthorDate: Thu Aug 10 11:56:30 2023 -0700

    fix(c/driver/postgresql): Fix segfault associated with uninitialized copy_reader_ (#964)
    
    Fixes #958.
---
 c/driver/postgresql/statement.cc                  | 39 ++++++++---------------
 c/driver/postgresql/statement.h                   |  5 +--
 python/adbc_driver_postgresql/tests/test_dbapi.py | 28 ++++++++++++++++
 3 files changed, 44 insertions(+), 28 deletions(-)

diff --git a/c/driver/postgresql/statement.cc b/c/driver/postgresql/statement.cc
index d8c474bd..0b8f1fc0 100644
--- a/c/driver/postgresql/statement.cc
+++ b/c/driver/postgresql/statement.cc
@@ -18,6 +18,7 @@
 #include "statement.h"
 
 #include <array>
+#include <cassert>
 #include <cerrno>
 #include <cinttypes>
 #include <cstring>
@@ -511,6 +512,8 @@ struct BindStream {
 }  // namespace
 
 int TupleReader::GetSchema(struct ArrowSchema* out) {
+  assert(copy_reader_ != nullptr);
+
   int na_res = copy_reader_->GetSchema(out);
   if (out->release == nullptr) {
     StringBuilderAppend(&error_builder_,
@@ -525,8 +528,6 @@ int TupleReader::GetSchema(struct ArrowSchema* out) {
 }
 
 int TupleReader::InitQueryAndFetchFirst(struct ArrowError* error) {
-  ResetQuery();
-
   // Fetch + parse the header
   int get_copy_res = PQgetCopyData(conn_, &pgbuf_, /*async=*/0);
   data_.size_bytes = get_copy_res;
@@ -601,27 +602,8 @@ int TupleReader::BuildOutput(struct ArrowArray* out, struct ArrowError* error) {
   return NANOARROW_OK;
 }
 
-void TupleReader::ResetQuery() {
-  // Clear result
-  if (result_) {
-    PQclear(result_);
-    result_ = nullptr;
-  }
-
-  // Reset result buffer
-  if (pgbuf_ != nullptr) {
-    PQfreemem(pgbuf_);
-    pgbuf_ = nullptr;
-  }
-
-  // Clear the error builder
-  error_builder_.size = 0;
-
-  row_id_ = -1;
-}
-
 int TupleReader::GetNext(struct ArrowArray* out) {
-  if (!copy_reader_) {
+  if (is_finished_) {
     out->release = nullptr;
     return 0;
   }
@@ -649,15 +631,14 @@ int TupleReader::GetNext(struct ArrowArray* out) {
     return na_res;
   }
 
+  is_finished_ = true;
+
   // Finish the result properly and return the last result. Note that BuildOutput() may
   // set tmp.release = nullptr if there were zero rows in the copy reader (can
   // occur in an overflow scenario).
   struct ArrowArray tmp;
   NANOARROW_RETURN_NOT_OK(BuildOutput(&tmp, &error));
 
-  // Clear the copy reader to mark this reader as finished
-  copy_reader_.reset();
-
   // Check the server-side response
   result_ = PQgetResult(conn_);
   const int pq_status = PQresultStatus(result_);
@@ -672,7 +653,6 @@ int TupleReader::GetNext(struct ArrowArray* out) {
     return EIO;
   }
 
-  ResetQuery();
   ArrowArrayMove(&tmp, out);
   return NANOARROW_OK;
 }
@@ -689,6 +669,13 @@ void TupleReader::Release() {
     PQfreemem(pgbuf_);
     pgbuf_ = nullptr;
   }
+
+  if (copy_reader_) {
+    copy_reader_.reset();
+  }
+
+  is_finished_ = false;
+  row_id_ = -1;
 }
 
 void TupleReader::ExportTo(struct ArrowArrayStream* stream) {
diff --git a/c/driver/postgresql/statement.h b/c/driver/postgresql/statement.h
index 62af2457..0326e80e 100644
--- a/c/driver/postgresql/statement.h
+++ b/c/driver/postgresql/statement.h
@@ -46,7 +46,8 @@ class TupleReader final {
         pgbuf_(nullptr),
         copy_reader_(nullptr),
         row_id_(-1),
-        batch_size_hint_bytes_(16777216) {
+        batch_size_hint_bytes_(16777216),
+        is_finished_(false) {
     StringBuilderInit(&error_builder_, 0);
     data_.data.as_char = nullptr;
     data_.size_bytes = 0;
@@ -70,7 +71,6 @@ class TupleReader final {
   int InitQueryAndFetchFirst(struct ArrowError* error);
   int AppendRowAndFetchNext(struct ArrowError* error);
   int BuildOutput(struct ArrowArray* out, struct ArrowError* error);
-  void ResetQuery();
 
   static int GetSchemaTrampoline(struct ArrowArrayStream* self, struct ArrowSchema* out);
   static int GetNextTrampoline(struct ArrowArrayStream* self, struct ArrowArray* out);
@@ -85,6 +85,7 @@ class TupleReader final {
   std::unique_ptr<PostgresCopyStreamReader> copy_reader_;
   int64_t row_id_;
   int64_t batch_size_hint_bytes_;
+  bool is_finished_;
 };
 
 class PostgresStatement {
diff --git a/python/adbc_driver_postgresql/tests/test_dbapi.py b/python/adbc_driver_postgresql/tests/test_dbapi.py
index c50cad1e..80339d60 100644
--- a/python/adbc_driver_postgresql/tests/test_dbapi.py
+++ b/python/adbc_driver_postgresql/tests/test_dbapi.py
@@ -78,3 +78,31 @@ def test_ddl(postgres: dbapi.Connection):
 
         cur.execute("SELECT * FROM test_ddl")
         assert cur.fetchone() == (1,)
+
+
+def test_crash(postgres: dbapi.Connection) -> None:
+    with postgres.cursor() as cur:
+        cur.execute("SELECT 1")
+        assert cur.fetchone() == (1,)
+
+
+def test_reuse(postgres: dbapi.Connection) -> None:
+    with postgres.cursor() as cur:
+        cur.execute("DROP TABLE IF EXISTS test_batch_size")
+        cur.execute("CREATE TABLE test_batch_size (ints INT)")
+        cur.execute(
+            """
+            INSERT INTO test_batch_size (ints)
+            SELECT generated :: INT
+            FROM GENERATE_SERIES(1, 65536) temp(generated)
+        """
+        )
+
+        cur.execute("SELECT * FROM test_batch_size ORDER BY ints ASC")
+        assert cur.fetchone() == (1,)
+
+        cur.execute("SELECT 1")
+        assert cur.fetchone() == (1,)
+
+        cur.execute("SELECT 2")
+        assert cur.fetchone() == (2,)

Reply via email to