This is an automated email from the ASF dual-hosted git repository.

lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new 17898709 feat(c/driver/sqlite): Support binding dictionary-encoded string and binary types (#1224)
17898709 is described below

commit 17898709be7a8a207726beb083afcc9d6d2c0d3d
Author: Dewey Dunnington <[email protected]>
AuthorDate: Thu Oct 26 19:47:35 2023 +0000

    feat(c/driver/sqlite): Support binding dictionary-encoded string and binary types (#1224)
    
    This PR adds the ability to ingest dictionary-encoded string and binary
    columns.
    
    Part of addressing #1008.
    
    From the R bindings:
    
    ``` r
    library(adbcdrivermanager)
    
    db <- adbc_database_init(adbcsqlite::adbcsqlite(), uri = ":memory:")
    con <- adbc_connection_init(db)
    
    df <- data.frame(x = factor(letters[1:10]))
    write_adbc(df, con, "tbl")
    
    read_adbc(con, "SELECT * from tbl") |>
      as.data.frame()
    #>    x
    #> 1  a
    #> 2  b
    #> 3  c
    #> 4  d
    #> 5  e
    #> 6  f
    #> 7  g
    #> 8  h
    #> 9  i
    #> 10 j
    ```
    
    <sup>Created on 2023-10-25 with [reprex v2.0.2](https://reprex.tidyverse.org)</sup>
---
 c/driver/postgresql/postgresql_test.cc |  1 +
 c/driver/sqlite/sqlite_test.cc         |  2 +-
 c/driver/sqlite/statement_reader.c     | 41 +++++++++++++++++++++-
 c/validation/adbc_validation.cc        | 64 ++++++++++++++++++++++++++--------
 c/validation/adbc_validation.h         |  7 +++-
 5 files changed, 98 insertions(+), 17 deletions(-)

diff --git a/c/driver/postgresql/postgresql_test.cc 
b/c/driver/postgresql/postgresql_test.cc
index d762ef5b..f6df1809 100644
--- a/c/driver/postgresql/postgresql_test.cc
+++ b/c/driver/postgresql/postgresql_test.cc
@@ -812,6 +812,7 @@ class PostgresStatementTest : public ::testing::Test,
   void TestSqlIngestUInt16() { GTEST_SKIP() << "Not implemented"; }
   void TestSqlIngestUInt32() { GTEST_SKIP() << "Not implemented"; }
   void TestSqlIngestUInt64() { GTEST_SKIP() << "Not implemented"; }
+  void TestSqlIngestStringDictionary() { GTEST_SKIP() << "Not implemented"; }
 
   void TestSqlPrepareErrorParamCountMismatch() { GTEST_SKIP() << "Not yet 
implemented"; }
   void TestSqlPrepareGetParameterSchema() { GTEST_SKIP() << "Not yet 
implemented"; }
diff --git a/c/driver/sqlite/sqlite_test.cc b/c/driver/sqlite/sqlite_test.cc
index f4455a57..13da21c1 100644
--- a/c/driver/sqlite/sqlite_test.cc
+++ b/c/driver/sqlite/sqlite_test.cc
@@ -246,7 +246,7 @@ class SqliteStatementTest : public ::testing::Test,
 
   void TestSqlIngestUInt64() {
     std::vector<std::optional<uint64_t>> values = {std::nullopt, 0, INT64_MAX};
-    return TestSqlIngestType(NANOARROW_TYPE_UINT64, values);
+    return TestSqlIngestType(NANOARROW_TYPE_UINT64, values, 
/*dictionary_encode*/ false);
   }
 
   void TestSqlIngestDuration() {
diff --git a/c/driver/sqlite/statement_reader.c 
b/c/driver/sqlite/statement_reader.c
index e3b2525b..9e02ee3b 100644
--- a/c/driver/sqlite/statement_reader.c
+++ b/c/driver/sqlite/statement_reader.c
@@ -60,7 +60,7 @@ AdbcStatusCode AdbcSqliteBinderSet(struct AdbcSqliteBinder* 
binder,
   struct ArrowSchemaView view = {0};
   for (int i = 0; i < binder->schema.n_children; i++) {
     status = ArrowSchemaViewInit(&view, binder->schema.children[i], 
&arrow_error);
-    if (status != 0) {
+    if (status != NANOARROW_OK) {
       SetError(error, "Failed to parse schema for column %d: %s (%d): %s", i,
                strerror(status), status, arrow_error.message);
       return ADBC_STATUS_INVALID_ARGUMENT;
@@ -70,6 +70,31 @@ AdbcStatusCode AdbcSqliteBinderSet(struct AdbcSqliteBinder* 
binder,
       SetError(error, "Column %d has UNINITIALIZED type", i);
       return ADBC_STATUS_INTERNAL;
     }
+
+    if (view.type == NANOARROW_TYPE_DICTIONARY) {
+      struct ArrowSchemaView value_view = {0};
+      status = ArrowSchemaViewInit(&value_view, 
binder->schema.children[i]->dictionary,
+                                   &arrow_error);
+      if (status != NANOARROW_OK) {
+        SetError(error, "Failed to parse schema for column %d->dictionary: %s 
(%d): %s",
+                 i, strerror(status), status, arrow_error.message);
+        return ADBC_STATUS_INVALID_ARGUMENT;
+      }
+
+      // We only support string/binary dictionary-encoded values
+      switch (value_view.type) {
+        case NANOARROW_TYPE_STRING:
+        case NANOARROW_TYPE_LARGE_STRING:
+        case NANOARROW_TYPE_BINARY:
+        case NANOARROW_TYPE_LARGE_BINARY:
+          break;
+        default:
+          SetError(error, "Column %d dictionary has unsupported type %s", i,
+                   ArrowTypeString(value_view.type));
+          return ADBC_STATUS_NOT_IMPLEMENTED;
+      }
+    }
+
     binder->types[i] = view.type;
   }
 
@@ -353,6 +378,20 @@ AdbcStatusCode AdbcSqliteBinderBindNext(struct 
AdbcSqliteBinder* binder, sqlite3
                                      SQLITE_STATIC);
           break;
         }
+        case NANOARROW_TYPE_DICTIONARY: {
+          int64_t value_index =
+              ArrowArrayViewGetIntUnsafe(binder->batch.children[col], 
binder->next_row);
+          if (ArrowArrayViewIsNull(binder->batch.children[col]->dictionary,
+                                   value_index)) {
+            status = sqlite3_bind_null(stmt, col + 1);
+          } else {
+            struct ArrowBufferView value = ArrowArrayViewGetBytesUnsafe(
+                binder->batch.children[col]->dictionary, value_index);
+            status = sqlite3_bind_text(stmt, col + 1, value.data.as_char,
+                                       value.size_bytes, SQLITE_STATIC);
+          }
+          break;
+        }
         case NANOARROW_TYPE_DATE32: {
           int64_t value =
               ArrowArrayViewGetIntUnsafe(binder->batch.children[col], 
binder->next_row);
diff --git a/c/validation/adbc_validation.cc b/c/validation/adbc_validation.cc
index 6dd3fb7b..f0f42937 100644
--- a/c/validation/adbc_validation.cc
+++ b/c/validation/adbc_validation.cc
@@ -1366,7 +1366,8 @@ void StatementTest::TestRelease() {
 
 template <typename CType>
 void StatementTest::TestSqlIngestType(ArrowType type,
-                                      const std::vector<std::optional<CType>>& 
values) {
+                                      const std::vector<std::optional<CType>>& 
values,
+                                      bool dictionary_encode) {
   if (!quirks()->supports_bulk_ingest(ADBC_INGEST_OPTION_MODE_CREATE)) {
     GTEST_SKIP();
   }
@@ -1381,6 +1382,38 @@ void StatementTest::TestSqlIngestType(ArrowType type,
   ASSERT_THAT(MakeBatch<CType>(&schema.value, &array.value, &na_error, values),
               IsOkErrno());
 
+  if (dictionary_encode) {
+    // Create a dictionary-encoded version of the target schema
+    Handle<struct ArrowSchema> dict_schema;
+    ASSERT_THAT(ArrowSchemaInitFromType(&dict_schema.value, 
NANOARROW_TYPE_INT32),
+                IsOkErrno());
+    ASSERT_THAT(ArrowSchemaSetName(&dict_schema.value, 
schema.value.children[0]->name),
+                IsOkErrno());
+    ASSERT_THAT(ArrowSchemaSetName(schema.value.children[0], nullptr), 
IsOkErrno());
+
+    // Swap it into the target schema
+    ASSERT_THAT(ArrowSchemaAllocateDictionary(&dict_schema.value), 
IsOkErrno());
+    ArrowSchemaMove(schema.value.children[0], dict_schema.value.dictionary);
+    ArrowSchemaMove(&dict_schema.value, schema.value.children[0]);
+
+    // Create a dictionary-encoded array with easy 0...n indices so that the
+    // matched values will be the same.
+    Handle<struct ArrowArray> dict_array;
+    ASSERT_THAT(ArrowArrayInitFromType(&dict_array.value, 
NANOARROW_TYPE_INT32),
+                IsOkErrno());
+    ASSERT_THAT(ArrowArrayStartAppending(&dict_array.value), IsOkErrno());
+    for (size_t i = 0; i < values.size(); i++) {
+      ASSERT_THAT(ArrowArrayAppendInt(&dict_array.value, 
static_cast<int64_t>(i)),
+                  IsOkErrno());
+    }
+    ASSERT_THAT(ArrowArrayFinishBuildingDefault(&dict_array.value, nullptr), 
IsOkErrno());
+
+    // Swap it into the target batch
+    ASSERT_THAT(ArrowArrayAllocateDictionary(&dict_array.value), IsOkErrno());
+    ArrowArrayMove(array.value.children[0], dict_array.value.dictionary);
+    ArrowArrayMove(&dict_array.value, array.value.children[0]);
+  }
+
   ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), 
IsOkStatus(&error));
   ASSERT_THAT(AdbcStatementSetOption(&statement, 
ADBC_INGEST_OPTION_TARGET_TABLE,
                                      "bulk_ingest", &error),
@@ -1448,7 +1481,7 @@ void StatementTest::TestSqlIngestNumericType(ArrowType 
type) {
     values.push_back(std::numeric_limits<CType>::max());
   }
 
-  return TestSqlIngestType(type, values);
+  return TestSqlIngestType(type, values, false);
 }
 
 void StatementTest::TestSqlIngestBool() {
@@ -1497,25 +1530,23 @@ void StatementTest::TestSqlIngestFloat64() {
 
 void StatementTest::TestSqlIngestString() {
   ASSERT_NO_FATAL_FAILURE(TestSqlIngestType<std::string>(
-      NANOARROW_TYPE_STRING, {std::nullopt, "", "", "1234", "例"}));
+      NANOARROW_TYPE_STRING, {std::nullopt, "", "", "1234", "例"}, false));
 }
 
 void StatementTest::TestSqlIngestLargeString() {
   ASSERT_NO_FATAL_FAILURE(TestSqlIngestType<std::string>(
-      NANOARROW_TYPE_LARGE_STRING, {std::nullopt, "", "", "1234", "例"}));
+      NANOARROW_TYPE_LARGE_STRING, {std::nullopt, "", "", "1234", "例"}, 
false));
 }
 
 void StatementTest::TestSqlIngestBinary() {
   ASSERT_NO_FATAL_FAILURE(TestSqlIngestType<std::vector<std::byte>>(
       NANOARROW_TYPE_BINARY,
-      {
-        std::nullopt, std::vector<std::byte>{},
-        std::vector<std::byte>{std::byte{0x00}, std::byte{0x01}},
-        std::vector<std::byte>{
-          std::byte{0x01}, std::byte{0x02}, std::byte{0x03}, std::byte{0x04}
-        },
-        std::vector<std::byte>{std::byte{0xfe}, std::byte{0xff}}
-      }));
+      {std::nullopt, std::vector<std::byte>{},
+       std::vector<std::byte>{std::byte{0x00}, std::byte{0x01}},
+       std::vector<std::byte>{std::byte{0x01}, std::byte{0x02}, 
std::byte{0x03},
+                              std::byte{0x04}},
+       std::vector<std::byte>{std::byte{0xfe}, std::byte{0xff}}},
+      false));
 }
 
 void StatementTest::TestSqlIngestDate32() {
@@ -1737,6 +1768,12 @@ void StatementTest::TestSqlIngestInterval() {
   ASSERT_THAT(AdbcStatementRelease(&statement, &error), IsOkStatus(&error));
 }
 
+void StatementTest::TestSqlIngestStringDictionary() {
+  ASSERT_NO_FATAL_FAILURE(TestSqlIngestType<std::string>(
+      NANOARROW_TYPE_STRING, {std::nullopt, "", "", "1234", "例"},
+      /*dictionary_encode*/ true));
+}
+
 void StatementTest::TestSqlIngestTableEscaping() {
   std::string name = "create_table_escaping";
 
@@ -2112,8 +2149,7 @@ void StatementTest::TestSqlIngestErrors() {
                                          {"coltwo", NANOARROW_TYPE_INT64}}),
               IsOkErrno());
   ASSERT_THAT(
-      (MakeBatch<int64_t, int64_t>(&schema.value, &array.value, &na_error,
-                                   {-42}, {-42})),
+      (MakeBatch<int64_t, int64_t>(&schema.value, &array.value, &na_error, 
{-42}, {-42})),
       IsOkErrno(&na_error));
 
   ASSERT_THAT(AdbcStatementBind(&statement, &array.value, &schema.value, 
&error),
diff --git a/c/validation/adbc_validation.h b/c/validation/adbc_validation.h
index 2e4c894d..e2b5d434 100644
--- a/c/validation/adbc_validation.h
+++ b/c/validation/adbc_validation.h
@@ -327,6 +327,9 @@ class StatementTest {
   void TestSqlIngestTimestampTz();
   void TestSqlIngestInterval();
 
+  // Dictionary-encoded
+  void TestSqlIngestStringDictionary();
+
   // ---- End Type-specific tests ----------------
 
   void TestSqlIngestTableEscaping();
@@ -387,7 +390,8 @@ class StatementTest {
   struct AdbcStatement statement;
 
   template <typename CType>
-  void TestSqlIngestType(ArrowType type, const 
std::vector<std::optional<CType>>& values);
+  void TestSqlIngestType(ArrowType type, const 
std::vector<std::optional<CType>>& values,
+                         bool dictionary_encode);
 
   template <typename CType>
   void TestSqlIngestNumericType(ArrowType type);
@@ -424,6 +428,7 @@ class StatementTest {
   TEST_F(FIXTURE, SqlIngestTimestamp) { TestSqlIngestTimestamp(); }            
         \
   TEST_F(FIXTURE, SqlIngestTimestampTz) { TestSqlIngestTimestampTz(); }        
         \
   TEST_F(FIXTURE, SqlIngestInterval) { TestSqlIngestInterval(); }              
         \
+  TEST_F(FIXTURE, SqlIngestStringDictionary) { 
TestSqlIngestStringDictionary(); }       \
   TEST_F(FIXTURE, SqlIngestTableEscaping) { TestSqlIngestTableEscaping(); }    
         \
   TEST_F(FIXTURE, SqlIngestColumnEscaping) { TestSqlIngestColumnEscaping(); }  
         \
   TEST_F(FIXTURE, SqlIngestAppend) { TestSqlIngestAppend(); }                  
         \

Reply via email to