This is an automated email from the ASF dual-hosted git repository.

lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new f0ae5194 fix(c/driver): be explicit about columns in ingestion (#1238)
f0ae5194 is described below

commit f0ae5194806b9c09a96b0de11856c36b16c998d7
Author: David Li <[email protected]>
AuthorDate: Wed Nov 1 15:25:50 2023 -0400

    fix(c/driver): be explicit about columns in ingestion (#1238)
    
    Fixes #1077.
---
 c/driver/postgresql/postgresql_test.cc |   8 +++
 c/driver/postgresql/statement.cc       |  26 ++++++--
 c/driver/postgresql/statement.h        |   3 +-
 c/driver/sqlite/sqlite.c               |  40 +++++++++---
 c/driver/sqlite/sqlite_test.cc         |   8 +++
 c/validation/adbc_validation.cc        | 109 +++++++++++++++++++++++++++++++++
 c/validation/adbc_validation.h         |  15 +++++
 7 files changed, 194 insertions(+), 15 deletions(-)

diff --git a/c/driver/postgresql/postgresql_test.cc b/c/driver/postgresql/postgresql_test.cc
index f6df1809..343c729d 100644
--- a/c/driver/postgresql/postgresql_test.cc
+++ b/c/driver/postgresql/postgresql_test.cc
@@ -107,6 +107,14 @@ class PostgresQuirks : public adbc_validation::DriverQuirks {
     return ddl;
   }
 
+  std::optional<std::string> PrimaryKeyIngestTableDdl(
+      std::string_view name) const override {
+    std::string ddl = "CREATE TABLE ";
+    ddl += name;
+    ddl += " (id BIGSERIAL PRIMARY KEY, value BIGINT)";
+    return ddl;
+  }
+
   std::optional<std::string> CompositePrimaryKeyTableDdl(
       std::string_view name) const override {
     std::string ddl = "CREATE TABLE ";
diff --git a/c/driver/postgresql/statement.cc b/c/driver/postgresql/statement.cc
index e5691b52..eac7eded 100644
--- a/c/driver/postgresql/statement.cc
+++ b/c/driver/postgresql/statement.cc
@@ -887,7 +887,8 @@ AdbcStatusCode PostgresStatement::Cancel(struct AdbcError* error) {
 AdbcStatusCode PostgresStatement::CreateBulkTable(
     const std::string& current_schema, const struct ArrowSchema& source_schema,
     const std::vector<struct ArrowSchemaView>& source_schema_fields,
-    std::string* escaped_table, struct AdbcError* error) {
+    std::string* escaped_table, std::string* escaped_field_list,
+    struct AdbcError* error) {
   PGconn* conn = connection_->conn();
 
   if (!ingest_.db_schema.empty() && ingest_.temporary) {
@@ -944,10 +945,9 @@ AdbcStatusCode PostgresStatement::CreateBulkTable(
 
   switch (ingest_.mode) {
     case IngestMode::kCreate:
+    case IngestMode::kAppend:
       // Nothing to do
       break;
-    case IngestMode::kAppend:
-      return ADBC_STATUS_OK;
     case IngestMode::kReplace: {
       std::string drop = "DROP TABLE IF EXISTS " + *escaped_table;
       PGresult* result = PQexecParams(conn, drop.c_str(), /*nParams=*/0,
@@ -972,7 +972,10 @@ AdbcStatusCode PostgresStatement::CreateBulkTable(
   create += " (";
 
   for (size_t i = 0; i < source_schema_fields.size(); i++) {
-    if (i > 0) create += ", ";
+    if (i > 0) {
+      create += ", ";
+      *escaped_field_list += ", ";
+    }
 
     const char* unescaped = source_schema.children[i]->name;
     char* escaped = PQescapeIdentifier(conn, unescaped, std::strlen(unescaped));
@@ -982,6 +985,7 @@ AdbcStatusCode PostgresStatement::CreateBulkTable(
       return ADBC_STATUS_INTERNAL;
     }
     create += escaped;
+    *escaped_field_list += escaped;
     PQfreemem(escaped);
 
     switch (source_schema_fields[i].type) {
@@ -1034,6 +1038,10 @@ AdbcStatusCode PostgresStatement::CreateBulkTable(
     }
   }
 
+  if (ingest_.mode == IngestMode::kAppend) {
+    return ADBC_STATUS_OK;
+  }
+
   create += ")";
   SetError(error, "%s%s", "[libpq] ", create.c_str());
   PGresult* result = PQexecParams(conn, create.c_str(), /*nParams=*/0,
@@ -1203,15 +1211,21 @@ AdbcStatusCode PostgresStatement::ExecuteUpdateBulk(int64_t* rows_affected,
   BindStream bind_stream(std::move(bind_));
   std::memset(&bind_, 0, sizeof(bind_));
   std::string escaped_table;
+  std::string escaped_field_list;
   RAISE_ADBC(bind_stream.Begin(
       [&]() -> AdbcStatusCode {
         return CreateBulkTable(current_schema, bind_stream.bind_schema.value,
-                               bind_stream.bind_schema_fields, &escaped_table, error);
+                               bind_stream.bind_schema_fields, &escaped_table,
+                               &escaped_field_list, error);
       },
       error));
   RAISE_ADBC(bind_stream.SetParamTypes(*type_resolver_, error));
 
-  std::string query = "COPY " + escaped_table + " FROM STDIN WITH (FORMAT binary)";
+  std::string query = "COPY ";
+  query += escaped_table;
+  query += " (";
+  query += escaped_field_list;
+  query += ") FROM STDIN WITH (FORMAT binary)";
   PGresult* result = PQexec(connection_->conn(), query.c_str());
   if (PQresultStatus(result) != PGRES_COPY_IN) {
     AdbcStatusCode code =
diff --git a/c/driver/postgresql/statement.h b/c/driver/postgresql/statement.h
index 20bb3b7a..c822390d 100644
--- a/c/driver/postgresql/statement.h
+++ b/c/driver/postgresql/statement.h
@@ -128,7 +128,8 @@ class PostgresStatement {
   AdbcStatusCode CreateBulkTable(
       const std::string& current_schema, const struct ArrowSchema& source_schema,
       const std::vector<struct ArrowSchemaView>& source_schema_fields,
-      std::string* escaped_table, struct AdbcError* error);
+      std::string* escaped_table, std::string* escaped_field_list,
+      struct AdbcError* error);
   AdbcStatusCode ExecuteUpdateBulk(int64_t* rows_affected, struct AdbcError* error);
   AdbcStatusCode ExecuteUpdateQuery(int64_t* rows_affected, struct AdbcError* error);
   AdbcStatusCode ExecutePreparedStatement(struct ArrowArrayStream* stream,
diff --git a/c/driver/sqlite/sqlite.c b/c/driver/sqlite/sqlite.c
index e4928018..a94b83f7 100644
--- a/c/driver/sqlite/sqlite.c
+++ b/c/driver/sqlite/sqlite.c
@@ -1136,7 +1136,7 @@ AdbcStatusCode SqliteStatementInitIngest(struct SqliteStatement* stmt,
     goto cleanup;
   }
 
-  sqlite3_str_appendf(insert_query, "INSERT INTO %s VALUES (", table);
+  sqlite3_str_appendf(insert_query, "INSERT INTO %s (", table);
   if (sqlite3_str_errcode(insert_query)) {
     SetError(error, "[SQLite] Failed to build INSERT: %s", sqlite3_errmsg(stmt->conn));
     code = ADBC_STATUS_INTERNAL;
@@ -1154,6 +1154,14 @@ AdbcStatusCode SqliteStatementInitIngest(struct SqliteStatement* stmt,
         code = ADBC_STATUS_INTERNAL;
         goto cleanup;
       }
+
+      sqlite3_str_appendf(insert_query, "%s", ", ");
+      if (sqlite3_str_errcode(insert_query)) {
+        SetError(error, "[SQLite] Failed to build INSERT: %s",
+                 sqlite3_errmsg(stmt->conn));
+        code = ADBC_STATUS_INTERNAL;
+        goto cleanup;
+      }
     }
 
     sqlite3_str_appendf(create_query, "\"%w\"", stmt->binder.schema.children[i]->name);
@@ -1163,6 +1171,13 @@ AdbcStatusCode SqliteStatementInitIngest(struct SqliteStatement* stmt,
       goto cleanup;
     }
 
+    sqlite3_str_appendf(insert_query, "\"%w\"", stmt->binder.schema.children[i]->name);
+    if (sqlite3_str_errcode(insert_query)) {
+      SetError(error, "[SQLite] Failed to build INSERT: %s", sqlite3_errmsg(stmt->conn));
+      code = ADBC_STATUS_INTERNAL;
+      goto cleanup;
+    }
+
     int status =
         ArrowSchemaViewInit(&view, stmt->binder.schema.children[i], &arrow_error);
     if (status != 0) {
@@ -1199,13 +1214,6 @@ AdbcStatusCode SqliteStatementInitIngest(struct SqliteStatement* stmt,
       default:
         break;
     }
-
-    sqlite3_str_appendf(insert_query, "%s?", (i > 0 ? ", " : ""));
-    if (sqlite3_str_errcode(insert_query)) {
-      SetError(error, "[SQLite] Failed to build INSERT: %s", sqlite3_errmsg(stmt->conn));
-      code = ADBC_STATUS_INTERNAL;
-      goto cleanup;
-    }
   }
 
   sqlite3_str_appendchar(create_query, 1, ')');
@@ -1215,6 +1223,22 @@ AdbcStatusCode SqliteStatementInitIngest(struct SqliteStatement* stmt,
     goto cleanup;
   }
 
+  sqlite3_str_appendall(insert_query, ") VALUES (");
+  if (sqlite3_str_errcode(insert_query)) {
+    SetError(error, "[SQLite] Failed to build INSERT: %s", sqlite3_errmsg(stmt->conn));
+    code = ADBC_STATUS_INTERNAL;
+    goto cleanup;
+  }
+
+  for (int i = 0; i < stmt->binder.schema.n_children; i++) {
+    sqlite3_str_appendf(insert_query, "%s?", (i > 0 ? ", " : ""));
+    if (sqlite3_str_errcode(insert_query)) {
+      SetError(error, "[SQLite] Failed to build INSERT: %s", sqlite3_errmsg(stmt->conn));
+      code = ADBC_STATUS_INTERNAL;
+      goto cleanup;
+    }
+  }
+
   sqlite3_str_appendchar(insert_query, 1, ')');
   if (sqlite3_str_errcode(insert_query)) {
     SetError(error, "[SQLite] Failed to build INSERT: %s", sqlite3_errmsg(stmt->conn));
diff --git a/c/driver/sqlite/sqlite_test.cc b/c/driver/sqlite/sqlite_test.cc
index 13da21c1..db318917 100644
--- a/c/driver/sqlite/sqlite_test.cc
+++ b/c/driver/sqlite/sqlite_test.cc
@@ -98,6 +98,14 @@ class SqliteQuirks : public adbc_validation::DriverQuirks {
     return ddl;
   }
 
+  std::optional<std::string> PrimaryKeyIngestTableDdl(
+      std::string_view name) const override {
+    std::string ddl = "CREATE TABLE ";
+    ddl += name;
+    ddl += " (id INTEGER PRIMARY KEY, value BIGINT)";
+    return ddl;
+  }
+
   std::optional<std::string> CompositePrimaryKeyTableDdl(
       std::string_view name) const override {
     std::string ddl = "CREATE TABLE ";
diff --git a/c/validation/adbc_validation.cc b/c/validation/adbc_validation.cc
index f0f42937..d30aa0a9 100644
--- a/c/validation/adbc_validation.cc
+++ b/c/validation/adbc_validation.cc
@@ -2803,6 +2803,115 @@ void StatementTest::TestSqlIngestTemporaryExclusive() {
   }
 }
 
+void StatementTest::TestSqlIngestPrimaryKey() {
+  std::string name = "pkeytest";
+  auto ddl = quirks()->PrimaryKeyIngestTableDdl(name);
+  if (!ddl) {
+    GTEST_SKIP();
+  }
+  ASSERT_THAT(quirks()->DropTable(&connection, name, &error), IsOkStatus(&error));
+
+  // Create table
+  {
+    Handle<struct AdbcStatement> statement;
+    StreamReader reader;
+    ASSERT_THAT(AdbcStatementNew(&connection, &statement.value, &error),
+                IsOkStatus(&error));
+    ASSERT_THAT(AdbcStatementSetSqlQuery(&statement.value, ddl->c_str(), &error),
+                IsOkStatus(&error));
+    ASSERT_THAT(AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr, &error),
+                IsOkStatus(&error));
+    ASSERT_THAT(AdbcStatementRelease(&statement.value, &error), IsOkStatus(&error));
+  }
+
+  // Ingest without the primary key
+  {
+    Handle<struct ArrowSchema> schema;
+    Handle<struct ArrowArray> array;
+    struct ArrowError na_error;
+    ASSERT_THAT(MakeSchema(&schema.value, {{"value", NANOARROW_TYPE_INT64}}),
+                IsOkErrno());
+    ASSERT_THAT((MakeBatch<int64_t>(&schema.value, &array.value, &na_error,
+                                    {42, -42, std::nullopt})),
+                IsOkErrno());
+
+    Handle<struct AdbcStatement> statement;
+    ASSERT_THAT(AdbcStatementNew(&connection, &statement.value, &error),
+                IsOkStatus(&error));
+    ASSERT_THAT(AdbcStatementSetOption(&statement.value, ADBC_INGEST_OPTION_TARGET_TABLE,
+                                       name.c_str(), &error),
+                IsOkStatus(&error));
+    ASSERT_THAT(AdbcStatementSetOption(&statement.value, ADBC_INGEST_OPTION_MODE,
+                                       ADBC_INGEST_OPTION_MODE_APPEND, &error),
+                IsOkStatus(&error));
+    ASSERT_THAT(AdbcStatementBind(&statement.value, &array.value, &schema.value, &error),
+                IsOkStatus(&error));
+    ASSERT_THAT(AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr, &error),
+                IsOkStatus(&error));
+    ASSERT_THAT(AdbcStatementRelease(&statement.value, &error), IsOkStatus(&error));
+  }
+
+  // Ingest with the primary key
+  {
+    Handle<struct ArrowSchema> schema;
+    Handle<struct ArrowArray> array;
+    struct ArrowError na_error;
+    ASSERT_THAT(MakeSchema(&schema.value,
+                           {
+                               {"id", NANOARROW_TYPE_INT64},
+                               {"value", NANOARROW_TYPE_INT64},
+                           }),
+                IsOkErrno());
+    ASSERT_THAT((MakeBatch<int64_t, int64_t>(&schema.value, &array.value, &na_error,
+                                             {4, 5, 6}, {1, 0, -1})),
+                IsOkErrno());
+
+    Handle<struct AdbcStatement> statement;
+    ASSERT_THAT(AdbcStatementNew(&connection, &statement.value, &error),
+                IsOkStatus(&error));
+    ASSERT_THAT(AdbcStatementSetOption(&statement.value, ADBC_INGEST_OPTION_TARGET_TABLE,
+                                       name.c_str(), &error),
+                IsOkStatus(&error));
+    ASSERT_THAT(AdbcStatementSetOption(&statement.value, ADBC_INGEST_OPTION_MODE,
+                                       ADBC_INGEST_OPTION_MODE_APPEND, &error),
+                IsOkStatus(&error));
+    ASSERT_THAT(AdbcStatementBind(&statement.value, &array.value, &schema.value, &error),
+                IsOkStatus(&error));
+    ASSERT_THAT(AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr, &error),
+                IsOkStatus(&error));
+    ASSERT_THAT(AdbcStatementRelease(&statement.value, &error), IsOkStatus(&error));
+  }
+
+  // Get the data
+  {
+    Handle<struct AdbcStatement> statement;
+    StreamReader reader;
+    ASSERT_THAT(AdbcStatementNew(&connection, &statement.value, &error),
+                IsOkStatus(&error));
+    ASSERT_THAT(AdbcStatementSetSqlQuery(
+                    &statement.value, "SELECT * FROM pkeytest ORDER BY id ASC", &error),
+                IsOkStatus(&error));
+    ASSERT_THAT(AdbcStatementExecuteQuery(&statement.value, &reader.stream.value, nullptr,
+                                          &error),
+                IsOkStatus(&error));
+
+    ASSERT_NO_FATAL_FAILURE(reader.GetSchema());
+    ASSERT_EQ(2, reader.schema->n_children);
+    ASSERT_NO_FATAL_FAILURE(reader.Next());
+    ASSERT_NE(nullptr, reader.array->release);
+    ASSERT_EQ(6, reader.array->length);
+    ASSERT_EQ(2, reader.array->n_children);
+
+    // Different databases start numbering at 0 or 1 for the primary key
+    // column, so can't compare it
+    // TODO(https://github.com/apache/arrow-adbc/issues/938): if the test
+    // helpers converted data to plain C++ values we could do a more
+    // sophisticated assertion
+    ASSERT_NO_FATAL_FAILURE(CompareArray<int64_t>(reader.array_view->children[1],
+                                                  {42, -42, std::nullopt, 1, 0, -1}));
+  }
+}
+
 void StatementTest::TestSqlPartitionedInts() {
   ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error));
   ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, "SELECT 42", &error),
diff --git a/c/validation/adbc_validation.h b/c/validation/adbc_validation.h
index e2b5d434..874d9a05 100644
--- a/c/validation/adbc_validation.h
+++ b/c/validation/adbc_validation.h
@@ -86,6 +86,19 @@ class DriverQuirks {
     return std::nullopt;
   }
 
+  /// \brief Get the statement to create a table with a primary key, or
+  ///   nullopt if not supported.  This is used to test ingestion into a table
+  ///   with an auto-incrementing primary key (which should not require the
+  ///   data to contain the primary key).
+  ///
+  /// The table should have two columns:
+  /// - "id" which should be an auto-incrementing primary key compatible with int64
+  /// - "value" with Arrow type int64
+  virtual std::optional<std::string> PrimaryKeyIngestTableDdl(
+      std::string_view name) const {
+    return std::nullopt;
+  }
+
   /// \brief Get the statement to create a table with a composite primary key,
   /// or nullopt if not supported.
   ///
@@ -347,6 +360,7 @@ class StatementTest {
   void TestSqlIngestTemporaryAppend();
   void TestSqlIngestTemporaryReplace();
   void TestSqlIngestTemporaryExclusive();
+  void TestSqlIngestPrimaryKey();
 
   void TestSqlPartitionedInts();
 
@@ -444,6 +458,7 @@ class StatementTest {
   TEST_F(FIXTURE, SqlIngestTemporaryAppend) { TestSqlIngestTemporaryAppend(); }         \
   TEST_F(FIXTURE, SqlIngestTemporaryReplace) { TestSqlIngestTemporaryReplace(); }       \
   TEST_F(FIXTURE, SqlIngestTemporaryExclusive) { TestSqlIngestTemporaryExclusive(); }   \
+  TEST_F(FIXTURE, SqlIngestPrimaryKey) { TestSqlIngestPrimaryKey(); }   \
   TEST_F(FIXTURE, SqlPartitionedInts) { TestSqlPartitionedInts(); }                     \
   TEST_F(FIXTURE, SqlPrepareGetParameterSchema) { TestSqlPrepareGetParameterSchema(); } \
   TEST_F(FIXTURE, SqlPrepareSelectNoParams) { TestSqlPrepareSelectNoParams(); }         \

Reply via email to