This is an automated email from the ASF dual-hosted git repository.

lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new 08901e8  refactor(c/driver/postgresql): implement InputIterator for ResultHelper (#683)
08901e8 is described below

commit 08901e8df4ec233e34be6cb09f51d32ceafbb9e2
Author: William Ayd <[email protected]>
AuthorDate: Wed May 17 09:18:34 2023 -0700

    refactor(c/driver/postgresql): implement InputIterator for ResultHelper (#683)
---
 c/driver/postgresql/connection.cc | 126 +++++++++++++++++++++++++++++---------
 1 file changed, 97 insertions(+), 29 deletions(-)

diff --git a/c/driver/postgresql/connection.cc b/c/driver/postgresql/connection.cc
index 684dee5..d9a5d13 100644
--- a/c/driver/postgresql/connection.cc
+++ b/c/driver/postgresql/connection.cc
@@ -17,10 +17,12 @@
 
 #include "connection.h"
 
+#include <cassert>
 #include <cinttypes>
 #include <cstring>
 #include <memory>
 #include <string>
+#include <vector>
 
 #include <adbc.h>
 #include <libpq-fe.h>
@@ -35,20 +37,89 @@ static const uint32_t kSupportedInfoCodes[] = {
     ADBC_INFO_DRIVER_VERSION, ADBC_INFO_DRIVER_ARROW_VERSION,
 };
 
+// A single field of one result row, as produced by PqResultRow::operator[].
+struct PqRecord {
+  const char* data;    // from PQgetvalue(); presumably only valid while the pg_result lives
+  const int len;       // from PQgetlength()
+  const bool is_null;  // from PQgetisnull()
+};
+
+// Used by PqResultHelper to provide index-based access to the records within each
+// row of a pg_result. A row is a lightweight view: it never frees the result,
+// so the result must outlive it.
+class PqResultRow {
+ public:
+  PqResultRow(pg_result* result, int row_num)
+      : result_(result), row_num_(row_num), ncols_(PQnfields(result)) {}
+
+  // Returns the value pointer, byte length and NULL flag of column `col_num`
+  // in this row. `col_num` must lie in [0, ncols); this is only asserted in
+  // debug builds.
+  PqRecord operator[](int col_num) const {
+    assert(col_num >= 0 && col_num < ncols_);
+    const char* data = PQgetvalue(result_, row_num_, col_num);
+    const int len = PQgetlength(result_, row_num_, col_num);
+    const bool is_null = PQgetisnull(result_, row_num_, col_num) != 0;
+
+    return PqRecord{data, len, is_null};
+  }
+
+ private:
+  pg_result* result_;
+  int row_num_;
+  int ncols_;
+};
+
+// Helper to manage the lifecycle of a PGresult. The query argument
+// will be evaluated as part of the constructor, with the destructor handling cleanup.
+// Caller is responsible for calling the `Status()` method to ensure results are
+// as expected prior to iterating.
 class PqResultHelper {
  public:
   PqResultHelper(PGconn* conn, const char* query) : conn_(conn) {
     query_ = std::string(query);
-  }
-  pg_result* Execute() {
+    // The query is executed right here via PQexec; the outcome is not checked,
+    // so callers must inspect Status() before iterating over the rows.
     result_ = PQexec(conn_, query_.c_str());
-    return result_;
   }
 
+  // Status of the query executed by the constructor; check it (e.g. against
+  // PGRES_TUPLES_OK) before iterating.
+  ExecStatusType Status() const { return PQresultStatus(result_); }
+
   ~PqResultHelper() {
-    if (result_ != nullptr) PQclear(result_);
+    if (result_ != nullptr) {
+      PQclear(result_);
+    }
   }
 
+  // The helper exclusively owns result_: a copy would PQclear it a second time.
+  // `PqResultHelper x = PqResultHelper{...}` still compiles because C++17
+  // guarantees copy elision when initializing from a prvalue.
+  PqResultHelper(const PqResultHelper&) = delete;
+  PqResultHelper& operator=(const PqResultHelper&) = delete;
+
+  // Number of rows (tuples) in the result.
+  int NumRows() const { return PQntuples(result_); }
+
+  // Number of columns (fields) in the result.
+  int NumColumns() const { return PQnfields(result_); }
+
+  // Single-pass (input) iterator over the rows of the owning PqResultHelper.
+  // Dereferencing builds a fresh PqResultRow view by value, so the helper must
+  // outlive every iterator and every row obtained from it.
+  class iterator {
+    // Pointer rather than reference so the iterator stays copy-assignable,
+    // as the iterator requirements demand.
+    const PqResultHelper* outer_;
+    int curr_row_ = 0;
+
+   public:
+    using iterator_category = std::input_iterator_tag;
+    using difference_type = std::ptrdiff_t;
+    using value_type = PqResultRow;
+    // operator* returns a temporary row, so there is no addressable element.
+    using pointer = void;
+    using reference = PqResultRow;
+
+    explicit iterator(const PqResultHelper& outer, int curr_row = 0)
+        : outer_(&outer), curr_row_(curr_row) {}
+    iterator& operator++() {
+      curr_row_++;
+      return *this;
+    }
+    iterator operator++(int) {
+      iterator retval = *this;
+      ++(*this);
+      return retval;
+    }
+    // Equal when walking the same pg_result and positioned on the same row;
+    // end() sits one past the last row.
+    bool operator==(const iterator& other) const {
+      return outer_->result_ == other.outer_->result_ && curr_row_ == other.curr_row_;
+    }
+    bool operator!=(const iterator& other) const { return !(*this == other); }
+    PqResultRow operator*() const { return PqResultRow(outer_->result_, curr_row_); }
+  };
+
+  // Range over the result rows in order; const so a `const PqResultHelper&`
+  // can be iterated too (the iterator only needs read access).
+  iterator begin() const { return iterator(*this); }
+  iterator end() const { return iterator(*this, PQntuples(result_)); }
+
  private:
   pg_result* result_ = nullptr;
   PGconn* conn_;
@@ -170,13 +241,10 @@ AdbcStatusCode PostgresConnectionGetObjectsImpl(
 
     PqResultHelper result_helper = PqResultHelper{conn, query.buffer};
     StringBuilderReset(&query);
-    pg_result* result = result_helper.Execute();
 
-    ExecStatusType pq_status = PQresultStatus(result);
-    if (pq_status == PGRES_TUPLES_OK) {
-      int num_rows = PQntuples(result);
-      for (int row = 0; row < num_rows; row++) {
-        const char* db_name = PQgetvalue(result, row, 0);
+    if (result_helper.Status() == PGRES_TUPLES_OK) {
+      for (PqResultRow row : result_helper) {
+        const char* db_name = row[0].data;
         CHECK_NA(INTERNAL,
                  ArrowArrayAppendString(catalog_name_col, 
ArrowCharView(db_name)), error);
         if (depth == ADBC_OBJECT_DEPTH_CATALOGS) {
@@ -260,39 +328,39 @@ AdbcStatusCode PostgresConnection::GetTableSchema(const char* catalog,
 
   PqResultHelper result_helper = PqResultHelper{conn_, query.buffer};
   StringBuilderReset(&query);
-  pg_result* result = result_helper.Execute();
-
-  ExecStatusType pq_status = PQresultStatus(result);
-  auto uschema = nanoarrow::UniqueSchema();
 
-  if (pq_status == PGRES_TUPLES_OK) {
-    int num_rows = PQntuples(result);
+  if (result_helper.Status() != PGRES_TUPLES_OK) {
+    SetError(error, "%s%s", "Failed to get table schema: ", 
PQerrorMessage(conn_));
+    final_status = ADBC_STATUS_IO;
+  } else {
+    auto uschema = nanoarrow::UniqueSchema();
     ArrowSchemaInit(uschema.get());
-    CHECK_NA(INTERNAL, ArrowSchemaSetTypeStruct(uschema.get(), num_rows), 
error);
+    CHECK_NA(INTERNAL, ArrowSchemaSetTypeStruct(uschema.get(), 
result_helper.NumRows()),
+             error);
 
     ArrowError na_error;
-    for (int row = 0; row < num_rows; row++) {
-      const char* colname = PQgetvalue(result, row, 0);
+    int row_counter = 0;
+    for (auto row : result_helper) {
+      const char* colname = row[0].data;
       const Oid pg_oid = static_cast<uint32_t>(
-          std::strtol(PQgetvalue(result, row, 1), /*str_end=*/nullptr, 
/*base=*/10));
+          std::strtol(row[1].data, /*str_end=*/nullptr, /*base=*/10));
 
       PostgresType pg_type;
       if (type_resolver_->Find(pg_oid, &pg_type, &na_error) != NANOARROW_OK) {
-        SetError(error, "%s%d%s%s%s%" PRIu32, "Column #", row + 1, " (\"", 
colname,
-                 "\") has unknown type code ", pg_oid);
+        SetError(error, "%s%d%s%s%s%" PRIu32, "Column #", row_counter + 1, " 
(\"",
+                 colname, "\") has unknown type code ", pg_oid);
         final_status = ADBC_STATUS_NOT_IMPLEMENTED;
-        break;
+        goto loopExit;
       }
-
-      CHECK_NA(INTERNAL, 
pg_type.WithFieldName(colname).SetSchema(uschema->children[row]),
+      CHECK_NA(INTERNAL,
+               
pg_type.WithFieldName(colname).SetSchema(uschema->children[row_counter]),
                error);
+      row_counter++;
     }
-  } else {
-    SetError(error, "%s%s", "Failed to get table schema: ", 
PQerrorMessage(conn_));
-    final_status = ADBC_STATUS_IO;
+    uschema.move(schema);
   }
+loopExit:
 
-  uschema.move(schema);
   return final_status;
 }
 

Reply via email to