This is an automated email from the ASF dual-hosted git repository.

lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new 61a1ce00 refactor(c/driver/postgresql): Use Prepared Statement in Result Helper (#714)
61a1ce00 is described below

commit 61a1ce00cad58392dd794924ca3e4747ae8d667f
Author: William Ayd <[email protected]>
AuthorDate: Tue May 30 10:15:28 2023 -0700

    refactor(c/driver/postgresql): Use Prepared Statement in Result Helper (#714)
---
 c/driver/postgresql/connection.cc | 242 ++++++++++++++++++++------------------
 1 file changed, 128 insertions(+), 114 deletions(-)

diff --git a/c/driver/postgresql/connection.cc b/c/driver/postgresql/connection.cc
index 4730721b..e7fa5911 100644
--- a/c/driver/postgresql/connection.cc
+++ b/c/driver/postgresql/connection.cc
@@ -22,6 +22,7 @@
 #include <cstring>
 #include <memory>
 #include <string>
+#include <utility>
 #include <vector>
 
 #include <adbc.h>
@@ -68,16 +69,52 @@ class PqResultRow {
 
 // Helper to manager the lifecycle of a PQResult. The query argument
 // will be evaluated as part of the constructor, with the desctructor handling 
cleanup
-// Caller is responsible for calling the `Status()` method to ensure results 
are
-// as expected prior to iterating
+// Caller must call Prepare then Execute, checking both for an OK 
AdbcStatusCode
+// prior to iterating
 class PqResultHelper {
  public:
-  PqResultHelper(PGconn* conn, const char* query) : conn_(conn) {
-    query_ = std::string(query);
-    result_ = PQexec(conn_, query_.c_str());
+  explicit PqResultHelper(PGconn* conn, std::string query, struct AdbcError* 
error)
+      : conn_(conn), query_(std::move(query)), error_(error) {}
+
+  explicit PqResultHelper(PGconn* conn, std::string query,
+                          std::vector<std::string> param_values, struct 
AdbcError* error)
+      : conn_(conn),
+        query_(std::move(query)),
+        param_values_(param_values),
+        error_(error) {}
+
+  AdbcStatusCode Prepare() {
+    // TODO: make stmtName a unique identifier?
+    PGresult* result =
+        PQprepare(conn_, /*stmtName=*/"", query_.c_str(), 
param_values_.size(), NULL);
+    if (PQresultStatus(result) != PGRES_COMMAND_OK) {
+      SetError(error_, "[libpq] Failed to prepare query: %s\nQuery was:%s",
+               PQerrorMessage(conn_), query_.c_str());
+      PQclear(result);
+      return ADBC_STATUS_IO;
+    }
+
+    PQclear(result);
+    return ADBC_STATUS_OK;
   }
 
-  ExecStatusType Status() { return PQresultStatus(result_); }
+  AdbcStatusCode Execute() {
+    std::vector<const char*> param_c_strs;
+
+    for (auto index = 0; index < param_values_.size(); index++) {
+      param_c_strs.push_back(param_values_[index].c_str());
+    }
+
+    result_ = PQexecPrepared(conn_, "", param_values_.size(), 
param_c_strs.data(), NULL,
+                             NULL, 0);
+
+    if (PQresultStatus(result_) != PGRES_TUPLES_OK) {
+      SetError(error_, "[libpq] Failed to execute query: %s", 
PQerrorMessage(conn_));
+      return ADBC_STATUS_IO;
+    }
+
+    return ADBC_STATUS_OK;
+  }
 
   ~PqResultHelper() {
     if (result_ != nullptr) {
@@ -124,6 +161,8 @@ class PqResultHelper {
   pg_result* result_ = nullptr;
   PGconn* conn_;
   std::string query_;
+  std::vector<std::string> param_values_;
+  struct AdbcError* error_;
 };
 
 class PqGetObjectsHelper {
@@ -146,15 +185,16 @@ class PqGetObjectsHelper {
   }
 
   AdbcStatusCode GetObjects() {
-    PqResultHelper curr_db_helper = PqResultHelper{conn_, "SELECT 
current_database()"};
-    if (curr_db_helper.Status() == PGRES_TUPLES_OK) {
-      assert(curr_db_helper.NumRows() == 1);
-      auto curr_iter = curr_db_helper.begin();
-      PqResultRow db_row = *curr_iter;
-      current_db_ = std::string(db_row[0].data);
-    } else {
-      return ADBC_STATUS_INTERNAL;
-    }
+    PqResultHelper curr_db_helper =
+        PqResultHelper{conn_, std::string("SELECT current_database()"), 
error_};
+
+    RAISE_ADBC(curr_db_helper.Prepare());
+    RAISE_ADBC(curr_db_helper.Execute());
+
+    assert(curr_db_helper.NumRows() == 1);
+    auto curr_iter = curr_db_helper.begin();
+    PqResultRow db_row = *curr_iter;
+    current_db_ = std::string(db_row[0].data);
 
     RAISE_ADBC(InitArrowArray());
 
@@ -197,41 +237,33 @@ class PqGetObjectsHelper {
         return ADBC_STATUS_INTERNAL;
       }
 
+      std::vector<std::string> params;
       if (db_schema_ != NULL) {
-        char* schema_name = PQescapeIdentifier(conn_, db_schema_, 
strlen(db_schema_));
-        if (schema_name == NULL) {
-          SetError(error_, "%s%s", "Failed to escape schema: ", 
PQerrorMessage(conn_));
+        if (StringBuilderAppend(&query, "%s", " AND nspname = $1")) {
           StringBuilderReset(&query);
-          return ADBC_STATUS_INVALID_ARGUMENT;
-        }
-
-        int res =
-            StringBuilderAppend(&query, "%s%s%s", " AND nspname ='", 
schema_name, "'");
-        PQfreemem(schema_name);
-        if (res) {
           return ADBC_STATUS_INTERNAL;
         }
+        params.push_back(db_schema_);
       }
 
-      auto result_helper = PqResultHelper{conn_, query.buffer};
+      auto result_helper =
+          PqResultHelper{conn_, std::string(query.buffer), params, error_};
       StringBuilderReset(&query);
 
-      if (result_helper.Status() == PGRES_TUPLES_OK) {
-        for (PqResultRow row : result_helper) {
-          const char* schema_name = row[0].data;
-          CHECK_NA(
-              INTERNAL,
-              ArrowArrayAppendString(db_schema_name_col_, 
ArrowCharView(schema_name)),
-              error_);
-          if (depth_ >= ADBC_OBJECT_DEPTH_TABLES) {
-            return ADBC_STATUS_NOT_IMPLEMENTED;
-          } else {
-            CHECK_NA(INTERNAL, ArrowArrayAppendNull(db_schema_tables_col_, 1), 
error_);
-          }
-          CHECK_NA(INTERNAL, 
ArrowArrayFinishElement(catalog_db_schemas_items_), error_);
+      RAISE_ADBC(result_helper.Prepare());
+      RAISE_ADBC(result_helper.Execute());
+
+      for (PqResultRow row : result_helper) {
+        const char* schema_name = row[0].data;
+        CHECK_NA(INTERNAL,
+                 ArrowArrayAppendString(db_schema_name_col_, 
ArrowCharView(schema_name)),
+                 error_);
+        if (depth_ >= ADBC_OBJECT_DEPTH_TABLES) {
+          return ADBC_STATUS_NOT_IMPLEMENTED;
+        } else {
+          CHECK_NA(INTERNAL, ArrowArrayAppendNull(db_schema_tables_col_, 1), 
error_);
         }
-      } else {
-        return ADBC_STATUS_NOT_IMPLEMENTED;
+        CHECK_NA(INTERNAL, ArrowArrayFinishElement(catalog_db_schemas_items_), 
error_);
       }
     }
 
@@ -247,40 +279,32 @@ class PqGetObjectsHelper {
       return ADBC_STATUS_INTERNAL;
     }
 
+    std::vector<std::string> params;
     if (catalog_ != NULL) {
-      char* catalog_name = PQescapeIdentifier(conn_, catalog_, 
strlen(catalog_));
-      if (catalog_name == NULL) {
-        SetError(error_, "%s%s", "Failed to escape catalog: ", 
PQerrorMessage(conn_));
+      if (StringBuilderAppend(&query, "%s", " WHERE datname = $1")) {
         StringBuilderReset(&query);
-        return ADBC_STATUS_INVALID_ARGUMENT;
-      }
-
-      int res =
-          StringBuilderAppend(&query, "%s%s%s", " WHERE datname = '", 
catalog_name, "'");
-      PQfreemem(catalog_name);
-      if (res) {
         return ADBC_STATUS_INTERNAL;
       }
+      params.push_back(catalog_);
     }
 
-    PqResultHelper result_helper = PqResultHelper{conn_, query.buffer};
+    PqResultHelper result_helper =
+        PqResultHelper{conn_, std::string(query.buffer), params, error_};
     StringBuilderReset(&query);
 
-    if (result_helper.Status() == PGRES_TUPLES_OK) {
-      for (PqResultRow row : result_helper) {
-        const char* db_name = row[0].data;
-        CHECK_NA(INTERNAL,
-                 ArrowArrayAppendString(catalog_name_col_, 
ArrowCharView(db_name)),
-                 error_);
-        if (depth_ == ADBC_OBJECT_DEPTH_CATALOGS) {
-          CHECK_NA(INTERNAL, ArrowArrayAppendNull(catalog_db_schemas_col_, 1), 
error_);
-        } else {
-          RAISE_ADBC(AppendSchemas(std::string(db_name)));
-        }
-        CHECK_NA(INTERNAL, ArrowArrayFinishElement(array_), error_);
+    RAISE_ADBC(result_helper.Prepare());
+    RAISE_ADBC(result_helper.Execute());
+
+    for (PqResultRow row : result_helper) {
+      const char* db_name = row[0].data;
+      CHECK_NA(INTERNAL,
+               ArrowArrayAppendString(catalog_name_col_, 
ArrowCharView(db_name)), error_);
+      if (depth_ == ADBC_OBJECT_DEPTH_CATALOGS) {
+        CHECK_NA(INTERNAL, ArrowArrayAppendNull(catalog_db_schemas_col_, 1), 
error_);
+      } else {
+        RAISE_ADBC(AppendSchemas(std::string(db_name)));
       }
-    } else {
-      return ADBC_STATUS_INTERNAL;
+      CHECK_NA(INTERNAL, ArrowArrayFinishElement(array_), error_);
     }
 
     return ADBC_STATUS_OK;
@@ -430,6 +454,7 @@ AdbcStatusCode PostgresConnection::GetTableSchema(const char* catalog,
                                                   struct AdbcError* error) {
   AdbcStatusCode final_status = ADBC_STATUS_OK;
   struct StringBuilder query = {0};
+  std::vector<std::string> params;
   if (StringBuilderInit(&query, /*initial_size=*/256) != 0) return 
ADBC_STATUS_INTERNAL;
 
   if (StringBuilderAppend(
@@ -438,67 +463,56 @@ AdbcStatusCode PostgresConnection::GetTableSchema(const char* catalog,
           "FROM pg_catalog.pg_class AS cls "
           "INNER JOIN pg_catalog.pg_attribute AS attr ON cls.oid = 
attr.attrelid "
           "INNER JOIN pg_catalog.pg_type AS typ ON attr.atttypid = typ.oid "
-          "WHERE attr.attnum >= 0 AND cls.oid = '") != 0)
+          "WHERE attr.attnum >= 0 AND cls.oid = ") != 0)
     return ADBC_STATUS_INTERNAL;
 
   if (db_schema != nullptr) {
-    char* schema = PQescapeIdentifier(conn_, db_schema, strlen(db_schema));
-    if (schema == NULL) {
-      SetError(error, "%s%s", "Faled to escape schema: ", 
PQerrorMessage(conn_));
-      return ADBC_STATUS_INVALID_ARGUMENT;
+    if (StringBuilderAppend(&query, "%s", "$1.")) {
+      StringBuilderReset(&query);
+      return ADBC_STATUS_INTERNAL;
     }
-
-    int ret = StringBuilderAppend(&query, "%s%s", schema, ".");
-    PQfreemem(schema);
-
-    if (ret != 0) return ADBC_STATUS_INTERNAL;
+    params.push_back(db_schema);
   }
 
-  char* table = PQescapeIdentifier(conn_, table_name, strlen(table_name));
-  if (table == NULL) {
-    SetError(error, "%s%s", "Failed to escape table: ", PQerrorMessage(conn_));
-    return ADBC_STATUS_INVALID_ARGUMENT;
+  if (StringBuilderAppend(&query, "%s%" PRIu64 "%s", "$",
+                          static_cast<uint64_t>(params.size() + 1), 
"::regclass::oid")) {
+    StringBuilderReset(&query);
+    return ADBC_STATUS_INTERNAL;
   }
+  params.push_back(table_name);
 
-  int ret = StringBuilderAppend(&query, "%s%s", table, "'::regclass::oid");
-  PQfreemem(table);
-
-  if (ret != 0) return ADBC_STATUS_INTERNAL;
-
-  PqResultHelper result_helper = PqResultHelper{conn_, query.buffer};
+  PqResultHelper result_helper =
+      PqResultHelper{conn_, std::string(query.buffer), params, error};
   StringBuilderReset(&query);
 
-  if (result_helper.Status() != PGRES_TUPLES_OK) {
-    SetError(error, "%s%s", "Failed to get table schema: ", 
PQerrorMessage(conn_));
-    final_status = ADBC_STATUS_IO;
-  } else {
-    auto uschema = nanoarrow::UniqueSchema();
-    ArrowSchemaInit(uschema.get());
-    CHECK_NA(INTERNAL, ArrowSchemaSetTypeStruct(uschema.get(), 
result_helper.NumRows()),
-             error);
+  RAISE_ADBC(result_helper.Prepare());
+  RAISE_ADBC(result_helper.Execute());
 
-    ArrowError na_error;
-    int row_counter = 0;
-    for (auto row : result_helper) {
-      const char* colname = row[0].data;
-      const Oid pg_oid = static_cast<uint32_t>(
-          std::strtol(row[1].data, /*str_end=*/nullptr, /*base=*/10));
-
-      PostgresType pg_type;
-      if (type_resolver_->Find(pg_oid, &pg_type, &na_error) != NANOARROW_OK) {
-        SetError(error, "%s%d%s%s%s%" PRIu32, "Column #", row_counter + 1, " 
(\"",
-                 colname, "\") has unknown type code ", pg_oid);
-        final_status = ADBC_STATUS_NOT_IMPLEMENTED;
-        goto loopExit;
-      }
-      CHECK_NA(INTERNAL,
-               
pg_type.WithFieldName(colname).SetSchema(uschema->children[row_counter]),
-               error);
-      row_counter++;
+  auto uschema = nanoarrow::UniqueSchema();
+  ArrowSchemaInit(uschema.get());
+  CHECK_NA(INTERNAL, ArrowSchemaSetTypeStruct(uschema.get(), 
result_helper.NumRows()),
+           error);
+
+  ArrowError na_error;
+  int row_counter = 0;
+  for (auto row : result_helper) {
+    const char* colname = row[0].data;
+    const Oid pg_oid =
+        static_cast<uint32_t>(std::strtol(row[1].data, /*str_end=*/nullptr, 
/*base=*/10));
+
+    PostgresType pg_type;
+    if (type_resolver_->Find(pg_oid, &pg_type, &na_error) != NANOARROW_OK) {
+      SetError(error, "%s%d%s%s%s%" PRIu32, "Column #", row_counter + 1, " 
(\"", colname,
+               "\") has unknown type code ", pg_oid);
+      final_status = ADBC_STATUS_NOT_IMPLEMENTED;
+      break;
     }
-    uschema.move(schema);
+    CHECK_NA(INTERNAL,
+             
pg_type.WithFieldName(colname).SetSchema(uschema->children[row_counter]),
+             error);
+    row_counter++;
   }
-loopExit:
+  uschema.move(schema);
 
   return final_status;
 }

Reply via email to