This is an automated email from the ASF dual-hosted git repository.
lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git
The following commit(s) were added to refs/heads/main by this push:
new 17898709 feat(c/driver/sqlite): Support binding dictionary-encoded string and binary types (#1224)
17898709 is described below
commit 17898709be7a8a207726beb083afcc9d6d2c0d3d
Author: Dewey Dunnington <[email protected]>
AuthorDate: Thu Oct 26 19:47:35 2023 +0000
feat(c/driver/sqlite): Support binding dictionary-encoded string and binary types (#1224)
This PR adds support to the SQLite driver for ingesting (binding) dictionary-encoded string and binary columns.
Part of addressing #1008.
Example using the R bindings:
``` r
library(adbcdrivermanager)
db <- adbc_database_init(adbcsqlite::adbcsqlite(), uri = ":memory:")
con <- adbc_connection_init(db)
df <- data.frame(x = factor(letters[1:10]))
write_adbc(df, con, "tbl")
read_adbc(con, "SELECT * from tbl") |>
as.data.frame()
#> x
#> 1 a
#> 2 b
#> 3 c
#> 4 d
#> 5 e
#> 6 f
#> 7 g
#> 8 h
#> 9 i
#> 10 j
```
<sup>Created on 2023-10-25 with [reprex v2.0.2](https://reprex.tidyverse.org)</sup>
---
c/driver/postgresql/postgresql_test.cc | 1 +
c/driver/sqlite/sqlite_test.cc | 2 +-
c/driver/sqlite/statement_reader.c | 41 +++++++++++++++++++++-
c/validation/adbc_validation.cc | 64 ++++++++++++++++++++++++++--------
c/validation/adbc_validation.h | 7 +++-
5 files changed, 98 insertions(+), 17 deletions(-)
diff --git a/c/driver/postgresql/postgresql_test.cc
b/c/driver/postgresql/postgresql_test.cc
index d762ef5b..f6df1809 100644
--- a/c/driver/postgresql/postgresql_test.cc
+++ b/c/driver/postgresql/postgresql_test.cc
@@ -812,6 +812,7 @@ class PostgresStatementTest : public ::testing::Test,
void TestSqlIngestUInt16() { GTEST_SKIP() << "Not implemented"; }
void TestSqlIngestUInt32() { GTEST_SKIP() << "Not implemented"; }
void TestSqlIngestUInt64() { GTEST_SKIP() << "Not implemented"; }
+ void TestSqlIngestStringDictionary() { GTEST_SKIP() << "Not implemented"; }
void TestSqlPrepareErrorParamCountMismatch() { GTEST_SKIP() << "Not yet
implemented"; }
void TestSqlPrepareGetParameterSchema() { GTEST_SKIP() << "Not yet
implemented"; }
diff --git a/c/driver/sqlite/sqlite_test.cc b/c/driver/sqlite/sqlite_test.cc
index f4455a57..13da21c1 100644
--- a/c/driver/sqlite/sqlite_test.cc
+++ b/c/driver/sqlite/sqlite_test.cc
@@ -246,7 +246,7 @@ class SqliteStatementTest : public ::testing::Test,
void TestSqlIngestUInt64() {
std::vector<std::optional<uint64_t>> values = {std::nullopt, 0, INT64_MAX};
- return TestSqlIngestType(NANOARROW_TYPE_UINT64, values);
+ return TestSqlIngestType(NANOARROW_TYPE_UINT64, values,
/*dictionary_encode*/ false);
}
void TestSqlIngestDuration() {
diff --git a/c/driver/sqlite/statement_reader.c
b/c/driver/sqlite/statement_reader.c
index e3b2525b..9e02ee3b 100644
--- a/c/driver/sqlite/statement_reader.c
+++ b/c/driver/sqlite/statement_reader.c
@@ -60,7 +60,7 @@ AdbcStatusCode AdbcSqliteBinderSet(struct AdbcSqliteBinder*
binder,
struct ArrowSchemaView view = {0};
for (int i = 0; i < binder->schema.n_children; i++) {
status = ArrowSchemaViewInit(&view, binder->schema.children[i],
&arrow_error);
- if (status != 0) {
+ if (status != NANOARROW_OK) {
SetError(error, "Failed to parse schema for column %d: %s (%d): %s", i,
strerror(status), status, arrow_error.message);
return ADBC_STATUS_INVALID_ARGUMENT;
@@ -70,6 +70,31 @@ AdbcStatusCode AdbcSqliteBinderSet(struct AdbcSqliteBinder*
binder,
SetError(error, "Column %d has UNINITIALIZED type", i);
return ADBC_STATUS_INTERNAL;
}
+
+ if (view.type == NANOARROW_TYPE_DICTIONARY) {
+ struct ArrowSchemaView value_view = {0};
+ status = ArrowSchemaViewInit(&value_view,
binder->schema.children[i]->dictionary,
+ &arrow_error);
+ if (status != NANOARROW_OK) {
+ SetError(error, "Failed to parse schema for column %d->dictionary: %s
(%d): %s",
+ i, strerror(status), status, arrow_error.message);
+ return ADBC_STATUS_INVALID_ARGUMENT;
+ }
+
+ // We only support string/binary dictionary-encoded values
+ switch (value_view.type) {
+ case NANOARROW_TYPE_STRING:
+ case NANOARROW_TYPE_LARGE_STRING:
+ case NANOARROW_TYPE_BINARY:
+ case NANOARROW_TYPE_LARGE_BINARY:
+ break;
+ default:
+ SetError(error, "Column %d dictionary has unsupported type %s", i,
+ ArrowTypeString(value_view.type));
+ return ADBC_STATUS_NOT_IMPLEMENTED;
+ }
+ }
+
binder->types[i] = view.type;
}
@@ -353,6 +378,20 @@ AdbcStatusCode AdbcSqliteBinderBindNext(struct
AdbcSqliteBinder* binder, sqlite3
SQLITE_STATIC);
break;
}
+ case NANOARROW_TYPE_DICTIONARY: {
+ int64_t value_index =
+ ArrowArrayViewGetIntUnsafe(binder->batch.children[col],
binder->next_row);
+ if (ArrowArrayViewIsNull(binder->batch.children[col]->dictionary,
+ value_index)) {
+ status = sqlite3_bind_null(stmt, col + 1);
+ } else {
+ struct ArrowBufferView value = ArrowArrayViewGetBytesUnsafe(
+ binder->batch.children[col]->dictionary, value_index);
+ status = sqlite3_bind_text(stmt, col + 1, value.data.as_char,
+ value.size_bytes, SQLITE_STATIC);
+ }
+ break;
+ }
case NANOARROW_TYPE_DATE32: {
int64_t value =
ArrowArrayViewGetIntUnsafe(binder->batch.children[col],
binder->next_row);
diff --git a/c/validation/adbc_validation.cc b/c/validation/adbc_validation.cc
index 6dd3fb7b..f0f42937 100644
--- a/c/validation/adbc_validation.cc
+++ b/c/validation/adbc_validation.cc
@@ -1366,7 +1366,8 @@ void StatementTest::TestRelease() {
template <typename CType>
void StatementTest::TestSqlIngestType(ArrowType type,
- const std::vector<std::optional<CType>>&
values) {
+ const std::vector<std::optional<CType>>&
values,
+ bool dictionary_encode) {
if (!quirks()->supports_bulk_ingest(ADBC_INGEST_OPTION_MODE_CREATE)) {
GTEST_SKIP();
}
@@ -1381,6 +1382,38 @@ void StatementTest::TestSqlIngestType(ArrowType type,
ASSERT_THAT(MakeBatch<CType>(&schema.value, &array.value, &na_error, values),
IsOkErrno());
+ if (dictionary_encode) {
+ // Create a dictionary-encoded version of the target schema
+ Handle<struct ArrowSchema> dict_schema;
+ ASSERT_THAT(ArrowSchemaInitFromType(&dict_schema.value,
NANOARROW_TYPE_INT32),
+ IsOkErrno());
+ ASSERT_THAT(ArrowSchemaSetName(&dict_schema.value,
schema.value.children[0]->name),
+ IsOkErrno());
+ ASSERT_THAT(ArrowSchemaSetName(schema.value.children[0], nullptr),
IsOkErrno());
+
+ // Swap it into the target schema
+ ASSERT_THAT(ArrowSchemaAllocateDictionary(&dict_schema.value),
IsOkErrno());
+ ArrowSchemaMove(schema.value.children[0], dict_schema.value.dictionary);
+ ArrowSchemaMove(&dict_schema.value, schema.value.children[0]);
+
+ // Create a dictionary-encoded array with easy 0...n indices so that the
+ // matched values will be the same.
+ Handle<struct ArrowArray> dict_array;
+ ASSERT_THAT(ArrowArrayInitFromType(&dict_array.value,
NANOARROW_TYPE_INT32),
+ IsOkErrno());
+ ASSERT_THAT(ArrowArrayStartAppending(&dict_array.value), IsOkErrno());
+ for (size_t i = 0; i < values.size(); i++) {
+ ASSERT_THAT(ArrowArrayAppendInt(&dict_array.value,
static_cast<int64_t>(i)),
+ IsOkErrno());
+ }
+ ASSERT_THAT(ArrowArrayFinishBuildingDefault(&dict_array.value, nullptr),
IsOkErrno());
+
+ // Swap it into the target batch
+ ASSERT_THAT(ArrowArrayAllocateDictionary(&dict_array.value), IsOkErrno());
+ ArrowArrayMove(array.value.children[0], dict_array.value.dictionary);
+ ArrowArrayMove(&dict_array.value, array.value.children[0]);
+ }
+
ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error),
IsOkStatus(&error));
ASSERT_THAT(AdbcStatementSetOption(&statement,
ADBC_INGEST_OPTION_TARGET_TABLE,
"bulk_ingest", &error),
@@ -1448,7 +1481,7 @@ void StatementTest::TestSqlIngestNumericType(ArrowType
type) {
values.push_back(std::numeric_limits<CType>::max());
}
- return TestSqlIngestType(type, values);
+ return TestSqlIngestType(type, values, false);
}
void StatementTest::TestSqlIngestBool() {
@@ -1497,25 +1530,23 @@ void StatementTest::TestSqlIngestFloat64() {
void StatementTest::TestSqlIngestString() {
ASSERT_NO_FATAL_FAILURE(TestSqlIngestType<std::string>(
- NANOARROW_TYPE_STRING, {std::nullopt, "", "", "1234", "例"}));
+ NANOARROW_TYPE_STRING, {std::nullopt, "", "", "1234", "例"}, false));
}
void StatementTest::TestSqlIngestLargeString() {
ASSERT_NO_FATAL_FAILURE(TestSqlIngestType<std::string>(
- NANOARROW_TYPE_LARGE_STRING, {std::nullopt, "", "", "1234", "例"}));
+ NANOARROW_TYPE_LARGE_STRING, {std::nullopt, "", "", "1234", "例"},
false));
}
void StatementTest::TestSqlIngestBinary() {
ASSERT_NO_FATAL_FAILURE(TestSqlIngestType<std::vector<std::byte>>(
NANOARROW_TYPE_BINARY,
- {
- std::nullopt, std::vector<std::byte>{},
- std::vector<std::byte>{std::byte{0x00}, std::byte{0x01}},
- std::vector<std::byte>{
- std::byte{0x01}, std::byte{0x02}, std::byte{0x03}, std::byte{0x04}
- },
- std::vector<std::byte>{std::byte{0xfe}, std::byte{0xff}}
- }));
+ {std::nullopt, std::vector<std::byte>{},
+ std::vector<std::byte>{std::byte{0x00}, std::byte{0x01}},
+ std::vector<std::byte>{std::byte{0x01}, std::byte{0x02},
std::byte{0x03},
+ std::byte{0x04}},
+ std::vector<std::byte>{std::byte{0xfe}, std::byte{0xff}}},
+ false));
}
void StatementTest::TestSqlIngestDate32() {
@@ -1737,6 +1768,12 @@ void StatementTest::TestSqlIngestInterval() {
ASSERT_THAT(AdbcStatementRelease(&statement, &error), IsOkStatus(&error));
}
+void StatementTest::TestSqlIngestStringDictionary() {
+ ASSERT_NO_FATAL_FAILURE(TestSqlIngestType<std::string>(
+ NANOARROW_TYPE_STRING, {std::nullopt, "", "", "1234", "例"},
+ /*dictionary_encode*/ true));
+}
+
void StatementTest::TestSqlIngestTableEscaping() {
std::string name = "create_table_escaping";
@@ -2112,8 +2149,7 @@ void StatementTest::TestSqlIngestErrors() {
{"coltwo", NANOARROW_TYPE_INT64}}),
IsOkErrno());
ASSERT_THAT(
- (MakeBatch<int64_t, int64_t>(&schema.value, &array.value, &na_error,
- {-42}, {-42})),
+ (MakeBatch<int64_t, int64_t>(&schema.value, &array.value, &na_error,
{-42}, {-42})),
IsOkErrno(&na_error));
ASSERT_THAT(AdbcStatementBind(&statement, &array.value, &schema.value,
&error),
diff --git a/c/validation/adbc_validation.h b/c/validation/adbc_validation.h
index 2e4c894d..e2b5d434 100644
--- a/c/validation/adbc_validation.h
+++ b/c/validation/adbc_validation.h
@@ -327,6 +327,9 @@ class StatementTest {
void TestSqlIngestTimestampTz();
void TestSqlIngestInterval();
+ // Dictionary-encoded
+ void TestSqlIngestStringDictionary();
+
// ---- End Type-specific tests ----------------
void TestSqlIngestTableEscaping();
@@ -387,7 +390,8 @@ class StatementTest {
struct AdbcStatement statement;
template <typename CType>
- void TestSqlIngestType(ArrowType type, const
std::vector<std::optional<CType>>& values);
+ void TestSqlIngestType(ArrowType type, const
std::vector<std::optional<CType>>& values,
+ bool dictionary_encode);
template <typename CType>
void TestSqlIngestNumericType(ArrowType type);
@@ -424,6 +428,7 @@ class StatementTest {
TEST_F(FIXTURE, SqlIngestTimestamp) { TestSqlIngestTimestamp(); }
\
TEST_F(FIXTURE, SqlIngestTimestampTz) { TestSqlIngestTimestampTz(); }
\
TEST_F(FIXTURE, SqlIngestInterval) { TestSqlIngestInterval(); }
\
+ TEST_F(FIXTURE, SqlIngestStringDictionary) {
TestSqlIngestStringDictionary(); } \
TEST_F(FIXTURE, SqlIngestTableEscaping) { TestSqlIngestTableEscaping(); }
\
TEST_F(FIXTURE, SqlIngestColumnEscaping) { TestSqlIngestColumnEscaping(); }
\
TEST_F(FIXTURE, SqlIngestAppend) { TestSqlIngestAppend(); }
\