This is an automated email from the ASF dual-hosted git repository.
lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git
The following commit(s) were added to refs/heads/main by this push:
new f0ae5194 fix(c/driver): be explicit about columns in ingestion (#1238)
f0ae5194 is described below
commit f0ae5194806b9c09a96b0de11856c36b16c998d7
Author: David Li <[email protected]>
AuthorDate: Wed Nov 1 15:25:50 2023 -0400
fix(c/driver): be explicit about columns in ingestion (#1238)
Fixes #1077.
---
c/driver/postgresql/postgresql_test.cc | 8 +++
c/driver/postgresql/statement.cc | 26 ++++++--
c/driver/postgresql/statement.h | 3 +-
c/driver/sqlite/sqlite.c | 40 +++++++++---
c/driver/sqlite/sqlite_test.cc | 8 +++
c/validation/adbc_validation.cc | 109 +++++++++++++++++++++++++++++++++
c/validation/adbc_validation.h | 15 +++++
7 files changed, 194 insertions(+), 15 deletions(-)
diff --git a/c/driver/postgresql/postgresql_test.cc
b/c/driver/postgresql/postgresql_test.cc
index f6df1809..343c729d 100644
--- a/c/driver/postgresql/postgresql_test.cc
+++ b/c/driver/postgresql/postgresql_test.cc
@@ -107,6 +107,14 @@ class PostgresQuirks : public
adbc_validation::DriverQuirks {
return ddl;
}
+ std::optional<std::string> PrimaryKeyIngestTableDdl(
+ std::string_view name) const override {
+ std::string ddl = "CREATE TABLE ";
+ ddl += name;
+ ddl += " (id BIGSERIAL PRIMARY KEY, value BIGINT)";
+ return ddl;
+ }
+
std::optional<std::string> CompositePrimaryKeyTableDdl(
std::string_view name) const override {
std::string ddl = "CREATE TABLE ";
diff --git a/c/driver/postgresql/statement.cc b/c/driver/postgresql/statement.cc
index e5691b52..eac7eded 100644
--- a/c/driver/postgresql/statement.cc
+++ b/c/driver/postgresql/statement.cc
@@ -887,7 +887,8 @@ AdbcStatusCode PostgresStatement::Cancel(struct AdbcError*
error) {
AdbcStatusCode PostgresStatement::CreateBulkTable(
const std::string& current_schema, const struct ArrowSchema& source_schema,
const std::vector<struct ArrowSchemaView>& source_schema_fields,
- std::string* escaped_table, struct AdbcError* error) {
+ std::string* escaped_table, std::string* escaped_field_list,
+ struct AdbcError* error) {
PGconn* conn = connection_->conn();
if (!ingest_.db_schema.empty() && ingest_.temporary) {
@@ -944,10 +945,9 @@ AdbcStatusCode PostgresStatement::CreateBulkTable(
switch (ingest_.mode) {
case IngestMode::kCreate:
+ case IngestMode::kAppend:
// Nothing to do
break;
- case IngestMode::kAppend:
- return ADBC_STATUS_OK;
case IngestMode::kReplace: {
std::string drop = "DROP TABLE IF EXISTS " + *escaped_table;
PGresult* result = PQexecParams(conn, drop.c_str(), /*nParams=*/0,
@@ -972,7 +972,10 @@ AdbcStatusCode PostgresStatement::CreateBulkTable(
create += " (";
for (size_t i = 0; i < source_schema_fields.size(); i++) {
- if (i > 0) create += ", ";
+ if (i > 0) {
+ create += ", ";
+ *escaped_field_list += ", ";
+ }
const char* unescaped = source_schema.children[i]->name;
char* escaped = PQescapeIdentifier(conn, unescaped,
std::strlen(unescaped));
@@ -982,6 +985,7 @@ AdbcStatusCode PostgresStatement::CreateBulkTable(
return ADBC_STATUS_INTERNAL;
}
create += escaped;
+ *escaped_field_list += escaped;
PQfreemem(escaped);
switch (source_schema_fields[i].type) {
@@ -1034,6 +1038,10 @@ AdbcStatusCode PostgresStatement::CreateBulkTable(
}
}
+ if (ingest_.mode == IngestMode::kAppend) {
+ return ADBC_STATUS_OK;
+ }
+
create += ")";
SetError(error, "%s%s", "[libpq] ", create.c_str());
PGresult* result = PQexecParams(conn, create.c_str(), /*nParams=*/0,
@@ -1203,15 +1211,21 @@ AdbcStatusCode
PostgresStatement::ExecuteUpdateBulk(int64_t* rows_affected,
BindStream bind_stream(std::move(bind_));
std::memset(&bind_, 0, sizeof(bind_));
std::string escaped_table;
+ std::string escaped_field_list;
RAISE_ADBC(bind_stream.Begin(
[&]() -> AdbcStatusCode {
return CreateBulkTable(current_schema, bind_stream.bind_schema.value,
- bind_stream.bind_schema_fields, &escaped_table,
error);
+ bind_stream.bind_schema_fields, &escaped_table,
+ &escaped_field_list, error);
},
error));
RAISE_ADBC(bind_stream.SetParamTypes(*type_resolver_, error));
- std::string query = "COPY " + escaped_table + " FROM STDIN WITH (FORMAT
binary)";
+ std::string query = "COPY ";
+ query += escaped_table;
+ query += " (";
+ query += escaped_field_list;
+ query += ") FROM STDIN WITH (FORMAT binary)";
PGresult* result = PQexec(connection_->conn(), query.c_str());
if (PQresultStatus(result) != PGRES_COPY_IN) {
AdbcStatusCode code =
diff --git a/c/driver/postgresql/statement.h b/c/driver/postgresql/statement.h
index 20bb3b7a..c822390d 100644
--- a/c/driver/postgresql/statement.h
+++ b/c/driver/postgresql/statement.h
@@ -128,7 +128,8 @@ class PostgresStatement {
AdbcStatusCode CreateBulkTable(
const std::string& current_schema, const struct ArrowSchema&
source_schema,
const std::vector<struct ArrowSchemaView>& source_schema_fields,
- std::string* escaped_table, struct AdbcError* error);
+ std::string* escaped_table, std::string* escaped_field_list,
+ struct AdbcError* error);
AdbcStatusCode ExecuteUpdateBulk(int64_t* rows_affected, struct AdbcError*
error);
AdbcStatusCode ExecuteUpdateQuery(int64_t* rows_affected, struct AdbcError*
error);
AdbcStatusCode ExecutePreparedStatement(struct ArrowArrayStream* stream,
diff --git a/c/driver/sqlite/sqlite.c b/c/driver/sqlite/sqlite.c
index e4928018..a94b83f7 100644
--- a/c/driver/sqlite/sqlite.c
+++ b/c/driver/sqlite/sqlite.c
@@ -1136,7 +1136,7 @@ AdbcStatusCode SqliteStatementInitIngest(struct
SqliteStatement* stmt,
goto cleanup;
}
- sqlite3_str_appendf(insert_query, "INSERT INTO %s VALUES (", table);
+ sqlite3_str_appendf(insert_query, "INSERT INTO %s (", table);
if (sqlite3_str_errcode(insert_query)) {
SetError(error, "[SQLite] Failed to build INSERT: %s",
sqlite3_errmsg(stmt->conn));
code = ADBC_STATUS_INTERNAL;
@@ -1154,6 +1154,14 @@ AdbcStatusCode SqliteStatementInitIngest(struct
SqliteStatement* stmt,
code = ADBC_STATUS_INTERNAL;
goto cleanup;
}
+
+ sqlite3_str_appendf(insert_query, "%s", ", ");
+ if (sqlite3_str_errcode(insert_query)) {
+ SetError(error, "[SQLite] Failed to build INSERT: %s",
+ sqlite3_errmsg(stmt->conn));
+ code = ADBC_STATUS_INTERNAL;
+ goto cleanup;
+ }
}
sqlite3_str_appendf(create_query, "\"%w\"",
stmt->binder.schema.children[i]->name);
@@ -1163,6 +1171,13 @@ AdbcStatusCode SqliteStatementInitIngest(struct
SqliteStatement* stmt,
goto cleanup;
}
+ sqlite3_str_appendf(insert_query, "\"%w\"",
stmt->binder.schema.children[i]->name);
+ if (sqlite3_str_errcode(insert_query)) {
+ SetError(error, "[SQLite] Failed to build INSERT: %s",
sqlite3_errmsg(stmt->conn));
+ code = ADBC_STATUS_INTERNAL;
+ goto cleanup;
+ }
+
int status =
ArrowSchemaViewInit(&view, stmt->binder.schema.children[i],
&arrow_error);
if (status != 0) {
@@ -1199,13 +1214,6 @@ AdbcStatusCode SqliteStatementInitIngest(struct
SqliteStatement* stmt,
default:
break;
}
-
- sqlite3_str_appendf(insert_query, "%s?", (i > 0 ? ", " : ""));
- if (sqlite3_str_errcode(insert_query)) {
- SetError(error, "[SQLite] Failed to build INSERT: %s",
sqlite3_errmsg(stmt->conn));
- code = ADBC_STATUS_INTERNAL;
- goto cleanup;
- }
}
sqlite3_str_appendchar(create_query, 1, ')');
@@ -1215,6 +1223,22 @@ AdbcStatusCode SqliteStatementInitIngest(struct
SqliteStatement* stmt,
goto cleanup;
}
+ sqlite3_str_appendall(insert_query, ") VALUES (");
+ if (sqlite3_str_errcode(insert_query)) {
+ SetError(error, "[SQLite] Failed to build INSERT: %s",
sqlite3_errmsg(stmt->conn));
+ code = ADBC_STATUS_INTERNAL;
+ goto cleanup;
+ }
+
+ for (int i = 0; i < stmt->binder.schema.n_children; i++) {
+ sqlite3_str_appendf(insert_query, "%s?", (i > 0 ? ", " : ""));
+ if (sqlite3_str_errcode(insert_query)) {
+ SetError(error, "[SQLite] Failed to build INSERT: %s",
sqlite3_errmsg(stmt->conn));
+ code = ADBC_STATUS_INTERNAL;
+ goto cleanup;
+ }
+ }
+
sqlite3_str_appendchar(insert_query, 1, ')');
if (sqlite3_str_errcode(insert_query)) {
SetError(error, "[SQLite] Failed to build INSERT: %s",
sqlite3_errmsg(stmt->conn));
diff --git a/c/driver/sqlite/sqlite_test.cc b/c/driver/sqlite/sqlite_test.cc
index 13da21c1..db318917 100644
--- a/c/driver/sqlite/sqlite_test.cc
+++ b/c/driver/sqlite/sqlite_test.cc
@@ -98,6 +98,14 @@ class SqliteQuirks : public adbc_validation::DriverQuirks {
return ddl;
}
+ std::optional<std::string> PrimaryKeyIngestTableDdl(
+ std::string_view name) const override {
+ std::string ddl = "CREATE TABLE ";
+ ddl += name;
+ ddl += " (id INTEGER PRIMARY KEY, value BIGINT)";
+ return ddl;
+ }
+
std::optional<std::string> CompositePrimaryKeyTableDdl(
std::string_view name) const override {
std::string ddl = "CREATE TABLE ";
diff --git a/c/validation/adbc_validation.cc b/c/validation/adbc_validation.cc
index f0f42937..d30aa0a9 100644
--- a/c/validation/adbc_validation.cc
+++ b/c/validation/adbc_validation.cc
@@ -2803,6 +2803,115 @@ void StatementTest::TestSqlIngestTemporaryExclusive() {
}
}
+void StatementTest::TestSqlIngestPrimaryKey() {
+ std::string name = "pkeytest";
+ auto ddl = quirks()->PrimaryKeyIngestTableDdl(name);
+ if (!ddl) {
+ GTEST_SKIP();
+ }
+ ASSERT_THAT(quirks()->DropTable(&connection, name, &error),
IsOkStatus(&error));
+
+ // Create table
+ {
+ Handle<struct AdbcStatement> statement;
+ StreamReader reader;
+ ASSERT_THAT(AdbcStatementNew(&connection, &statement.value, &error),
+ IsOkStatus(&error));
+ ASSERT_THAT(AdbcStatementSetSqlQuery(&statement.value, ddl->c_str(),
&error),
+ IsOkStatus(&error));
+ ASSERT_THAT(AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr,
&error),
+ IsOkStatus(&error));
+ ASSERT_THAT(AdbcStatementRelease(&statement.value, &error),
IsOkStatus(&error));
+ }
+
+ // Ingest without the primary key
+ {
+ Handle<struct ArrowSchema> schema;
+ Handle<struct ArrowArray> array;
+ struct ArrowError na_error;
+ ASSERT_THAT(MakeSchema(&schema.value, {{"value", NANOARROW_TYPE_INT64}}),
+ IsOkErrno());
+ ASSERT_THAT((MakeBatch<int64_t>(&schema.value, &array.value, &na_error,
+ {42, -42, std::nullopt})),
+ IsOkErrno());
+
+ Handle<struct AdbcStatement> statement;
+ ASSERT_THAT(AdbcStatementNew(&connection, &statement.value, &error),
+ IsOkStatus(&error));
+ ASSERT_THAT(AdbcStatementSetOption(&statement.value,
ADBC_INGEST_OPTION_TARGET_TABLE,
+ name.c_str(), &error),
+ IsOkStatus(&error));
+ ASSERT_THAT(AdbcStatementSetOption(&statement.value,
ADBC_INGEST_OPTION_MODE,
+ ADBC_INGEST_OPTION_MODE_APPEND, &error),
+ IsOkStatus(&error));
+ ASSERT_THAT(AdbcStatementBind(&statement.value, &array.value,
&schema.value, &error),
+ IsOkStatus(&error));
+ ASSERT_THAT(AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr,
&error),
+ IsOkStatus(&error));
+ ASSERT_THAT(AdbcStatementRelease(&statement.value, &error),
IsOkStatus(&error));
+ }
+
+ // Ingest with the primary key
+ {
+ Handle<struct ArrowSchema> schema;
+ Handle<struct ArrowArray> array;
+ struct ArrowError na_error;
+ ASSERT_THAT(MakeSchema(&schema.value,
+ {
+ {"id", NANOARROW_TYPE_INT64},
+ {"value", NANOARROW_TYPE_INT64},
+ }),
+ IsOkErrno());
+ ASSERT_THAT((MakeBatch<int64_t, int64_t>(&schema.value, &array.value,
&na_error,
+ {4, 5, 6}, {1, 0, -1})),
+ IsOkErrno());
+
+ Handle<struct AdbcStatement> statement;
+ ASSERT_THAT(AdbcStatementNew(&connection, &statement.value, &error),
+ IsOkStatus(&error));
+ ASSERT_THAT(AdbcStatementSetOption(&statement.value,
ADBC_INGEST_OPTION_TARGET_TABLE,
+ name.c_str(), &error),
+ IsOkStatus(&error));
+ ASSERT_THAT(AdbcStatementSetOption(&statement.value,
ADBC_INGEST_OPTION_MODE,
+ ADBC_INGEST_OPTION_MODE_APPEND, &error),
+ IsOkStatus(&error));
+ ASSERT_THAT(AdbcStatementBind(&statement.value, &array.value,
&schema.value, &error),
+ IsOkStatus(&error));
+ ASSERT_THAT(AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr,
&error),
+ IsOkStatus(&error));
+ ASSERT_THAT(AdbcStatementRelease(&statement.value, &error),
IsOkStatus(&error));
+ }
+
+ // Get the data
+ {
+ Handle<struct AdbcStatement> statement;
+ StreamReader reader;
+ ASSERT_THAT(AdbcStatementNew(&connection, &statement.value, &error),
+ IsOkStatus(&error));
+ ASSERT_THAT(AdbcStatementSetSqlQuery(
+ &statement.value, "SELECT * FROM pkeytest ORDER BY id
ASC", &error),
+ IsOkStatus(&error));
+ ASSERT_THAT(AdbcStatementExecuteQuery(&statement.value,
&reader.stream.value, nullptr,
+ &error),
+ IsOkStatus(&error));
+
+ ASSERT_NO_FATAL_FAILURE(reader.GetSchema());
+ ASSERT_EQ(2, reader.schema->n_children);
+ ASSERT_NO_FATAL_FAILURE(reader.Next());
+ ASSERT_NE(nullptr, reader.array->release);
+ ASSERT_EQ(6, reader.array->length);
+ ASSERT_EQ(2, reader.array->n_children);
+
+ // Different databases start numbering at 0 or 1 for the primary key
+ // column, so can't compare it
+ // TODO(https://github.com/apache/arrow-adbc/issues/938): if the test
+ // helpers converted data to plain C++ values we could do a more
+ // sophisticated assertion
+
ASSERT_NO_FATAL_FAILURE(CompareArray<int64_t>(reader.array_view->children[1],
+ {42, -42, std::nullopt, 1,
0, -1}));
+ }
+}
+
void StatementTest::TestSqlPartitionedInts() {
ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error),
IsOkStatus(&error));
ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, "SELECT 42", &error),
diff --git a/c/validation/adbc_validation.h b/c/validation/adbc_validation.h
index e2b5d434..874d9a05 100644
--- a/c/validation/adbc_validation.h
+++ b/c/validation/adbc_validation.h
@@ -86,6 +86,19 @@ class DriverQuirks {
return std::nullopt;
}
+ /// \brief Get the statement to create a table with a primary key, or
+ /// nullopt if not supported. This is used to test ingestion into a table
+ /// with an auto-incrementing primary key (which should not require the
+ /// data to contain the primary key).
+ ///
+ /// The table should have two columns:
+ /// - "id" which should be an auto-incrementing primary key compatible with
int64
+ /// - "value" with Arrow type int64
+ virtual std::optional<std::string> PrimaryKeyIngestTableDdl(
+ std::string_view name) const {
+ return std::nullopt;
+ }
+
/// \brief Get the statement to create a table with a composite primary key,
/// or nullopt if not supported.
///
@@ -347,6 +360,7 @@ class StatementTest {
void TestSqlIngestTemporaryAppend();
void TestSqlIngestTemporaryReplace();
void TestSqlIngestTemporaryExclusive();
+ void TestSqlIngestPrimaryKey();
void TestSqlPartitionedInts();
@@ -444,6 +458,7 @@ class StatementTest {
TEST_F(FIXTURE, SqlIngestTemporaryAppend) { TestSqlIngestTemporaryAppend();
} \
TEST_F(FIXTURE, SqlIngestTemporaryReplace) {
TestSqlIngestTemporaryReplace(); } \
TEST_F(FIXTURE, SqlIngestTemporaryExclusive) {
TestSqlIngestTemporaryExclusive(); } \
+ TEST_F(FIXTURE, SqlIngestPrimaryKey) { TestSqlIngestPrimaryKey(); } \
TEST_F(FIXTURE, SqlPartitionedInts) { TestSqlPartitionedInts(); }
\
TEST_F(FIXTURE, SqlPrepareGetParameterSchema) {
TestSqlPrepareGetParameterSchema(); } \
TEST_F(FIXTURE, SqlPrepareSelectNoParams) { TestSqlPrepareSelectNoParams();
} \