This is an automated email from the ASF dual-hosted git repository.

lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new 5ca9c29e fix(c/driver/postgresql): support catalog arg of GetTableSchema (#1387)
5ca9c29e is described below

commit 5ca9c29e684c364ab38184a9ae3ba4090cd8b9ed
Author: David Li <[email protected]>
AuthorDate: Wed Dec 20 12:49:56 2023 -0500

    fix(c/driver/postgresql): support catalog arg of GetTableSchema (#1387)
    
    Fixes #1339.
---
 c/driver/postgresql/connection.cc                  | 37 +++++-----------
 c/driver/postgresql/postgresql_test.cc             | 34 ++++++++++++---
 c/validation/adbc_validation.cc                    | 51 ++++++++++++++++++++--
 c/validation/adbc_validation.h                     | 33 +++++++++++++-
 ci/conda_env_cpp_lint.txt                          |  4 +-
 .../python/recipe/postgresql_get_table_schema.py   | 26 ++++++++++-
 6 files changed, 147 insertions(+), 38 deletions(-)

diff --git a/c/driver/postgresql/connection.cc 
b/c/driver/postgresql/connection.cc
index d389a66c..deae3171 100644
--- a/c/driver/postgresql/connection.cc
+++ b/c/driver/postgresql/connection.cc
@@ -1147,38 +1147,37 @@ AdbcStatusCode PostgresConnection::GetTableSchema(const 
char* catalog,
                                                   struct ArrowSchema* schema,
                                                   struct AdbcError* error) {
   AdbcStatusCode final_status = ADBC_STATUS_OK;
-  struct StringBuilder query;
-  std::memset(&query, 0, sizeof(query));
-  std::vector<std::string> params;
-  if (StringBuilderInit(&query, /*initial_size=*/256) != 0) return 
ADBC_STATUS_INTERNAL;
 
-  if (StringBuilderAppend(
-          &query, "%s",
-          "SELECT attname, atttypid "
-          "FROM pg_catalog.pg_class AS cls "
-          "INNER JOIN pg_catalog.pg_attribute AS attr ON cls.oid = 
attr.attrelid "
-          "INNER JOIN pg_catalog.pg_type AS typ ON attr.atttypid = typ.oid "
-          "WHERE attr.attnum >= 0 AND cls.oid = ") != 0)
-    return ADBC_STATUS_INTERNAL;
+  // ORDER BY attnum so the Arrow schema lists columns in table definition order;
+  // without it PostgreSQL does not guarantee any particular row order.
+  std::string query =
+      "SELECT attname, atttypid "
+      "FROM pg_catalog.pg_class AS cls "
+      "INNER JOIN pg_catalog.pg_attribute AS attr ON cls.oid = attr.attrelid "
+      "INNER JOIN pg_catalog.pg_type AS typ ON attr.atttypid = typ.oid "
+      "WHERE attr.attnum >= 0 AND cls.oid = $1::regclass::oid "
+      "ORDER BY attr.attnum";
 
+  // The parameter is parsed by regclass as a (possibly schema-qualified) SQL
+  // name, so quote each part: unquoted names would be case-folded and split on
+  // '.'.  An embedded double quote is escaped by doubling it.
+  auto quote_identifier = [](const char* ident) {
+    std::string quoted = "\"";
+    for (const char* c = ident; *c != '\0'; ++c) {
+      if (*c == '"') quoted += '"';
+      quoted += *c;
+    }
+    quoted += '"';
+    return quoted;
+  };
+  std::vector<std::string> params;
   if (db_schema != nullptr) {
-    if (StringBuilderAppend(&query, "%s", "$1.")) {
-      StringBuilderReset(&query);
-      return ADBC_STATUS_INTERNAL;
-    }
-    params.push_back(db_schema);
-  }
-
-  if (StringBuilderAppend(&query, "%s%" PRIu64 "%s", "$",
-                          static_cast<uint64_t>(params.size() + 1), 
"::regclass::oid")) {
-    StringBuilderReset(&query);
-    return ADBC_STATUS_INTERNAL;
+    params.push_back(quote_identifier(db_schema) + "." + quote_identifier(table_name));
+  } else {
+    params.push_back(quote_identifier(table_name));
   }
-  params.push_back(table_name);
 
-  PqResultHelper result_helper =
-      PqResultHelper{conn_, std::string(query.buffer), params, error};
-  StringBuilderReset(&query);
+  PqResultHelper result_helper = PqResultHelper{conn_, query, params, error};
 
   RAISE_ADBC(result_helper.Prepare());
   auto result = result_helper.Execute();
diff --git a/c/driver/postgresql/postgresql_test.cc 
b/c/driver/postgresql/postgresql_test.cc
index 5e04b455..2327767a 100644
--- a/c/driver/postgresql/postgresql_test.cc
+++ b/c/driver/postgresql/postgresql_test.cc
@@ -61,6 +61,21 @@ class PostgresQuirks : public adbc_validation::DriverQuirks {
     return AdbcStatementRelease(&statement.value, error);
   }
 
+  AdbcStatusCode DropTable(struct AdbcConnection* connection, const std::string& name,
+                           const std::string& db_schema,
+                           struct AdbcError* error) const override {
+    Handle<struct AdbcStatement> statement;
+    RAISE_ADBC(AdbcStatementNew(connection, &statement.value, error));
+
+    // An empty db_schema means the default schema: leave the table name unqualified.
+    std::string query = "DROP TABLE IF EXISTS ";
+    if (!db_schema.empty()) query += "\"" + db_schema + "\".";
+    query += "\"" + name + "\"";
+    RAISE_ADBC(AdbcStatementSetSqlQuery(&statement.value, query.c_str(), error));
+    RAISE_ADBC(AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr, error));
+    return AdbcStatementRelease(&statement.value, error);
+  }
+
   AdbcStatusCode DropTempTable(struct AdbcConnection* connection, const 
std::string& name,
                                struct AdbcError* error) const override {
     Handle<struct AdbcStatement> statement;
@@ -83,6 +98,18 @@ class PostgresQuirks : public adbc_validation::DriverQuirks {
     return AdbcStatementRelease(&statement.value, error);
   }
 
+  AdbcStatusCode EnsureDbSchema(struct AdbcConnection* connection,
+                                const std::string& name,
+                                struct AdbcError* error) const override {
+    Handle<struct AdbcStatement> statement;
+    RAISE_ADBC(AdbcStatementNew(connection, &statement.value, error));
+
+    std::string query = "CREATE SCHEMA IF NOT EXISTS \"" + name + "\"";
+    RAISE_ADBC(AdbcStatementSetSqlQuery(&statement.value, query.c_str(), 
error));
+    RAISE_ADBC(AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr, 
error));
+    return AdbcStatementRelease(&statement.value, error);
+  }
+
   std::string BindParameter(int index) const override {
     return "$" + std::to_string(index + 1);
   }
@@ -343,7 +370,7 @@ TEST_F(PostgresConnectionTest, 
GetObjectsGetAllFindsPrimaryKey) {
     ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value,
                                           &reader.rows_affected, &error),
                 IsOkStatus(&error));
-    ASSERT_EQ(reader.rows_affected, 0);
+    ASSERT_EQ(reader.rows_affected, -1);
     ASSERT_NO_FATAL_FAILURE(reader.GetSchema());
     ASSERT_NO_FATAL_FAILURE(reader.Next());
     ASSERT_EQ(reader.array->release, nullptr);
@@ -416,7 +443,7 @@ TEST_F(PostgresConnectionTest, 
GetObjectsGetAllFindsForeignKey) {
     ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value,
                                           &reader.rows_affected, &error),
                 IsOkStatus(&error));
-    ASSERT_EQ(reader.rows_affected, 0);
+    ASSERT_EQ(reader.rows_affected, -1);
     ASSERT_NO_FATAL_FAILURE(reader.GetSchema());
     ASSERT_NO_FATAL_FAILURE(reader.Next());
     ASSERT_EQ(reader.array->release, nullptr);
@@ -435,7 +462,7 @@ TEST_F(PostgresConnectionTest, 
GetObjectsGetAllFindsForeignKey) {
     ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value,
                                           &reader.rows_affected, &error),
                 IsOkStatus(&error));
-    ASSERT_EQ(reader.rows_affected, 0);
+    ASSERT_EQ(reader.rows_affected, -1);
     ASSERT_NO_FATAL_FAILURE(reader.GetSchema());
     ASSERT_NO_FATAL_FAILURE(reader.Next());
     ASSERT_EQ(reader.array->release, nullptr);
@@ -1162,7 +1189,7 @@ TEST_F(PostgresStatementTest, UpdateInExecuteQuery) {
     ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value,
                                           &reader.rows_affected, &error),
                 IsOkStatus(&error));
-    ASSERT_EQ(reader.rows_affected, 0);
+    ASSERT_EQ(reader.rows_affected, -1);
     ASSERT_NO_FATAL_FAILURE(reader.GetSchema());
     ASSERT_NO_FATAL_FAILURE(reader.Next());
     ASSERT_EQ(reader.array->release, nullptr);
@@ -1177,7 +1204,7 @@ TEST_F(PostgresStatementTest, UpdateInExecuteQuery) {
     ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value,
                                           &reader.rows_affected, &error),
                 IsOkStatus(&error));
-    ASSERT_EQ(reader.rows_affected, 0);
+    ASSERT_EQ(reader.rows_affected, -1);
     ASSERT_NO_FATAL_FAILURE(reader.GetSchema());
     ASSERT_NO_FATAL_FAILURE(reader.Next());
     ASSERT_EQ(reader.array->release, nullptr);
diff --git a/c/validation/adbc_validation.cc b/c/validation/adbc_validation.cc
index d30aa0a9..97d12be1 100644
--- a/c/validation/adbc_validation.cc
+++ b/c/validation/adbc_validation.cc
@@ -70,7 +70,9 @@ bool iequals(std::string_view s1, std::string_view s2) {
 // DriverQuirks
 
 AdbcStatusCode DoIngestSampleTable(struct AdbcConnection* connection,
-                                   const std::string& name, struct AdbcError* 
error) {
+                                   const std::string& name,
+                                   std::optional<std::string> db_schema,
+                                   struct AdbcError* error) {
   Handle<struct ArrowSchema> schema;
   Handle<struct ArrowArray> array;
   struct ArrowError na_error;
@@ -84,6 +86,10 @@ AdbcStatusCode DoIngestSampleTable(struct AdbcConnection* 
connection,
   CHECK_OK(AdbcStatementNew(connection, &statement.value, error));
   CHECK_OK(AdbcStatementSetOption(&statement.value, 
ADBC_INGEST_OPTION_TARGET_TABLE,
                                   name.c_str(), error));
+  if (db_schema.has_value()) {
+    CHECK_OK(AdbcStatementSetOption(&statement.value, 
ADBC_INGEST_OPTION_TARGET_DB_SCHEMA,
+                                    db_schema->c_str(), error));
+  }
   CHECK_OK(AdbcStatementBind(&statement.value, &array.value, &schema.value, 
error));
   CHECK_OK(AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr, 
error));
   CHECK_OK(AdbcStatementRelease(&statement.value, error));
@@ -91,7 +97,8 @@ AdbcStatusCode DoIngestSampleTable(struct AdbcConnection* 
connection,
 }
 
 void IngestSampleTable(struct AdbcConnection* connection, struct AdbcError* 
error) {
-  ASSERT_THAT(DoIngestSampleTable(connection, "bulk_ingest", error), 
IsOkStatus(error));
+  ASSERT_THAT(DoIngestSampleTable(connection, "bulk_ingest", std::nullopt, 
error),
+              IsOkStatus(error));
 }
 
 AdbcStatusCode DriverQuirks::EnsureSampleTable(struct AdbcConnection* 
connection,
@@ -107,7 +114,20 @@ AdbcStatusCode DriverQuirks::CreateSampleTable(struct 
AdbcConnection* connection
   if (!supports_bulk_ingest(ADBC_INGEST_OPTION_MODE_CREATE)) {
     return ADBC_STATUS_NOT_IMPLEMENTED;
   }
-  return DoIngestSampleTable(connection, name, error);
+  return DoIngestSampleTable(connection, name, std::nullopt, error);
+}
+
+AdbcStatusCode DriverQuirks::CreateSampleTable(struct AdbcConnection* 
connection,
+                                               const std::string& name,
+                                               const std::string& schema,
+                                               struct AdbcError* error) const {
+  if (!supports_bulk_ingest(ADBC_INGEST_OPTION_MODE_CREATE)) {
+    return ADBC_STATUS_NOT_IMPLEMENTED;
+  }
+  // An empty schema means "use the default schema", so leave the option unset.
+  std::optional<std::string> target_schema;
+  if (!schema.empty()) target_schema = schema;
+  return DoIngestSampleTable(connection, name, target_schema, error);
 }
 
 //------------------------------------------------------------
@@ -431,6 +451,37 @@ void ConnectionTest::TestMetadataGetTableSchema() {
                                     {"strings", NANOARROW_TYPE_STRING, 
NULLABLE}}));
 }
 
+void ConnectionTest::TestMetadataGetTableSchemaDbSchema() {
+  // The sample table is created through bulk ingestion.
+  if (!quirks()->supports_bulk_ingest(ADBC_INGEST_OPTION_MODE_CREATE)) {
+    GTEST_SKIP();
+  }
+  ASSERT_THAT(AdbcConnectionNew(&connection, &error), IsOkStatus(&error));
+  ASSERT_THAT(AdbcConnectionInit(&connection, &database, &error), 
IsOkStatus(&error));
+
+  auto status = quirks()->EnsureDbSchema(&connection, "otherschema", &error);
+  if (status == ADBC_STATUS_NOT_IMPLEMENTED) {
+    GTEST_SKIP() << "Schema not supported";
+  }
+  ASSERT_THAT(status, IsOkStatus(&error));
+
+  ASSERT_THAT(quirks()->DropTable(&connection, "bulk_ingest", "otherschema", 
&error),
+              IsOkStatus(&error));
+  ASSERT_THAT(
+      quirks()->CreateSampleTable(&connection, "bulk_ingest", "otherschema", 
&error),
+      IsOkStatus(&error));
+
+  Handle<ArrowSchema> schema;
+  ASSERT_THAT(AdbcConnectionGetTableSchema(&connection, /*catalog=*/nullptr,
+                                           /*db_schema=*/"otherschema", 
"bulk_ingest",
+                                           &schema.value, &error),
+              IsOkStatus(&error));
+
+  ASSERT_NO_FATAL_FAILURE(
+      CompareSchema(&schema.value, {{"int64s", NANOARROW_TYPE_INT64, NULLABLE},
+                                    {"strings", NANOARROW_TYPE_STRING, 
NULLABLE}}));
+}
+
 void ConnectionTest::TestMetadataGetTableSchemaEscaping() {
   if (!quirks()->supports_bulk_ingest(ADBC_INGEST_OPTION_MODE_CREATE)) {
     GTEST_SKIP();
diff --git a/c/validation/adbc_validation.h b/c/validation/adbc_validation.h
index 874d9a05..30a20491 100644
--- a/c/validation/adbc_validation.h
+++ b/c/validation/adbc_validation.h
@@ -50,6 +50,15 @@ class DriverQuirks {
     return ADBC_STATUS_OK;
   }
 
+  /// \brief Drop the given table in the given database schema, if it exists.
+  /// Used by tests to reset state.  Return NOT_IMPLEMENTED if not supported.
+  virtual AdbcStatusCode DropTable(struct AdbcConnection* connection,
+                                   const std::string& name,
+                                   const std::string& db_schema,
+                                   struct AdbcError* error) const {
+    return ADBC_STATUS_NOT_IMPLEMENTED;
+  }
+
   /// \brief Drop the given temporary table. Used by tests to reset state.
   virtual AdbcStatusCode DropTempTable(struct AdbcConnection* connection,
                                        const std::string& name,
@@ -68,13 +77,34 @@ class DriverQuirks {
                                            const std::string& name,
                                            struct AdbcError* error) const;
 
+  /// \brief Create a database schema for testing, if it does not already exist.
+  /// Return NOT_IMPLEMENTED if not supported by this backend.
+  virtual AdbcStatusCode EnsureDbSchema(struct AdbcConnection* connection,
+                                        const std::string& name,
+                                        struct AdbcError* error) const {
+    return ADBC_STATUS_NOT_IMPLEMENTED;
+  }
+
+  /// \brief Create a table of sample data with a fixed schema for testing.
+  ///
+  /// The table should have two columns:
+  /// - "int64s" with Arrow type int64.
+  /// - "strings" with Arrow type utf8.
+  virtual AdbcStatusCode CreateSampleTable(struct AdbcConnection* connection,
+                                           const std::string& name,
+                                           struct AdbcError* error) const;
+
   /// \brief Create a table of sample data with a fixed schema for testing.
   ///
+  /// Create it in the given schema.  Specify "" for the default schema.
+  /// Return NOT_IMPLEMENTED if not supported by this backend.
+  ///
   /// The table should have two columns:
   /// - "int64s" with Arrow type int64.
   /// - "strings" with Arrow type utf8.
   virtual AdbcStatusCode CreateSampleTable(struct AdbcConnection* connection,
                                            const std::string& name,
+                                           const std::string& schema,
                                            struct AdbcError* error) const;
 
   /// \brief Get the statement to create a table with a primary key, or 
nullopt if not
@@ -197,7 +227,7 @@ class DriverQuirks {
   /// \brief Default catalog to use for tests
   virtual std::string catalog() const { return ""; }
 
-  /// \brief Default Schema to use for tests
+  /// \brief Default database schema to use for tests
   virtual std::string db_schema() const { return ""; }
 };
 
@@ -243,6 +273,7 @@ class ConnectionTest {
 
   void TestMetadataGetInfo();
   void TestMetadataGetTableSchema();
+  void TestMetadataGetTableSchemaDbSchema();
   void TestMetadataGetTableSchemaEscaping();
   void TestMetadataGetTableSchemaNotFound();
   void TestMetadataGetTableTypes();
@@ -277,6 +308,9 @@ class ConnectionTest {
   TEST_F(FIXTURE, MetadataCurrentDbSchema) { TestMetadataCurrentDbSchema(); }  
         \
   TEST_F(FIXTURE, MetadataGetInfo) { TestMetadataGetInfo(); }                  
         \
   TEST_F(FIXTURE, MetadataGetTableSchema) { TestMetadataGetTableSchema(); }    
         \
+  TEST_F(FIXTURE, MetadataGetTableSchemaDbSchema) {                            
         \
+    TestMetadataGetTableSchemaDbSchema();                                      
         \
+  }                                                                            
         \
   TEST_F(FIXTURE, MetadataGetTableSchemaEscaping) {                            
         \
     TestMetadataGetTableSchemaEscaping();                                      
         \
   }                                                                            
         \
diff --git a/ci/conda_env_cpp_lint.txt b/ci/conda_env_cpp_lint.txt
index 471ef0de..7cc81c1e 100644
--- a/ci/conda_env_cpp_lint.txt
+++ b/ci/conda_env_cpp_lint.txt
@@ -15,5 +15,5 @@
 # specific language governing permissions and limitations
 # under the License.
 
-clang=14
-clang-tools=14
+clang=14.*
+clang-tools=14.*
diff --git a/docs/source/python/recipe/postgresql_get_table_schema.py 
b/docs/source/python/recipe/postgresql_get_table_schema.py
index 3f1bae72..aacbc1c2 100644
--- a/docs/source/python/recipe/postgresql_get_table_schema.py
+++ b/docs/source/python/recipe/postgresql_get_table_schema.py
@@ -28,13 +28,18 @@ import adbc_driver_postgresql.dbapi
 uri = os.environ["ADBC_POSTGRESQL_TEST_URI"]
 conn = adbc_driver_postgresql.dbapi.connect(uri)
 
-#: We'll create an example table to test.
+#: We'll create some example tables to test.
 with conn.cursor() as cur:
     cur.execute("DROP TABLE IF EXISTS example")
     cur.execute("CREATE TABLE example (ints INT, bigints BIGINT)")
 
+    cur.execute("CREATE SCHEMA IF NOT EXISTS other_schema")
+    cur.execute("DROP TABLE IF EXISTS other_schema.example")
+    cur.execute("CREATE TABLE other_schema.example (strings TEXT, values 
NUMERIC)")
+
 conn.commit()
 
+#: By default, the table is looked up in the current database's search_path.
 assert conn.adbc_get_table_schema("example") == pyarrow.schema(
     [
         ("ints", "int32"),
@@ -42,4 +47,23 @@ assert conn.adbc_get_table_schema("example") == 
pyarrow.schema(
     ]
 )
 
+#: We can explicitly specify the PostgreSQL schema to get the Arrow schema of
+#: a table in a different namespace.
+#:
+#: .. note:: In PostgreSQL, you can only query the database (catalog) that you
+#:           are connected to.  So we cannot specify the catalog here (or
+#:           rather, there is no point in doing so).
+#:
+#: Note that the NUMERIC column is read as a string: PostgreSQL NUMERIC values
+#: can have arbitrary precision and scale, unlike Arrow's fixed-precision decimals.
+assert conn.adbc_get_table_schema(
+    "example",
+    db_schema_filter="other_schema",
+) == pyarrow.schema(
+    [
+        ("strings", "string"),
+        ("values", "string"),
+    ]
+)
+
 conn.close()

Reply via email to