This is an automated email from the ASF dual-hosted git repository.

lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new 995a02d5 feat(c/driver): Date32 support (#948)
995a02d5 is described below

commit 995a02d54555f71a4ea6c442426201f5358d8846
Author: William Ayd <[email protected]>
AuthorDate: Tue Aug 1 09:18:05 2023 -0400

    feat(c/driver): Date32 support (#948)
---
 c/driver/postgresql/postgresql_test.cc       |  6 +--
 c/driver/postgresql/statement.cc             | 23 +++++++++
 c/driver/snowflake/snowflake_test.cc         |  6 +--
 c/driver/sqlite/sqlite_test.cc               |  7 +--
 c/driver/sqlite/statement_reader.c           | 75 ++++++++++++++++++++++++++++
 c/driver_manager/adbc_driver_manager_test.cc |  1 +
 c/validation/adbc_validation.cc              | 45 ++++++++++-------
 c/validation/adbc_validation.h               | 10 ++--
 8 files changed, 142 insertions(+), 31 deletions(-)

diff --git a/c/driver/postgresql/postgresql_test.cc b/c/driver/postgresql/postgresql_test.cc
index 33115bcf..a826e172 100644
--- a/c/driver/postgresql/postgresql_test.cc
+++ b/c/driver/postgresql/postgresql_test.cc
@@ -576,9 +576,9 @@ class PostgresStatementTest : public ::testing::Test,
   }
 
  protected:
-  void ValidateIngestedTemporalData(struct ArrowArrayView* values,
-                                    enum ArrowTimeUnit unit,
-                                    const char* timezone) override {
+  void ValidateIngestedTimestampData(struct ArrowArrayView* values,
+                                     enum ArrowTimeUnit unit,
+                                     const char* timezone) override {
     std::vector<std::optional<int64_t>> expected;
     switch (unit) {
       case (NANOARROW_TIME_UNIT_SECOND):
diff --git a/c/driver/postgresql/statement.cc b/c/driver/postgresql/statement.cc
index dd3fd82c..3452411d 100644
--- a/c/driver/postgresql/statement.cc
+++ b/c/driver/postgresql/statement.cc
@@ -219,6 +219,10 @@ struct BindStream {
           type_id = PostgresTypeId::kBytea;
           param_lengths[i] = 0;
           break;
+        case ArrowType::NANOARROW_TYPE_DATE32:
+          type_id = PostgresTypeId::kDate;
+          param_lengths[i] = 4;
+          break;
         case ArrowType::NANOARROW_TYPE_TIMESTAMP:
           type_id = PostgresTypeId::kTimestamp;
           param_lengths[i] = 8;
@@ -389,6 +393,22 @@ struct BindStream {
               param_values[col] = const_cast<char*>(view.data.as_char);
               break;
             }
+            case ArrowType::NANOARROW_TYPE_DATE32: {
+              // Postgres DATE epoch is 2000-01-01 (Arrow DATE32: 1970-01-01)
+              constexpr int32_t kPostgresDateEpoch = 10957;
+              const int32_t raw_value =
+                  array_view->children[col]->buffer_views[1].data.as_int32[row];
+              if (raw_value < INT32_MIN + kPostgresDateEpoch) {
+                SetError(error, "[libpq] Field #%" PRId64 "%s%s%s%" PRId64 "%s", col + 1,
+                         " ('", bind_schema->children[col]->name, "') Row #", row + 1,
+                         " has value which exceeds postgres date limits");
+                return ADBC_STATUS_INVALID_ARGUMENT;
+              }
+
+              const uint32_t value = ToNetworkInt32(raw_value - kPostgresDateEpoch);
+              std::memcpy(param_values[col], &value, sizeof(int32_t));
+              break;
+            }
             case ArrowType::NANOARROW_TYPE_TIMESTAMP: {
               int64_t val = 
array_view->children[col]->buffer_views[1].data.as_int64[row];
 
@@ -801,6 +821,9 @@ AdbcStatusCode PostgresStatement::CreateBulkTable(
       case ArrowType::NANOARROW_TYPE_BINARY:
         create += " BYTEA";
         break;
+      case ArrowType::NANOARROW_TYPE_DATE32:
+        create += " DATE";
+        break;
       case ArrowType::NANOARROW_TYPE_TIMESTAMP:
         if (strcmp("", source_schema_fields[i].timezone)) {
           create += " TIMESTAMPTZ";
diff --git a/c/driver/snowflake/snowflake_test.cc b/c/driver/snowflake/snowflake_test.cc
index c8e08a56..96edb889 100644
--- a/c/driver/snowflake/snowflake_test.cc
+++ b/c/driver/snowflake/snowflake_test.cc
@@ -178,9 +178,9 @@ class SnowflakeStatementTest : public ::testing::Test,
   }
 
  protected:
-  void ValidateIngestedTemporalData(struct ArrowArrayView* values,
-                                    enum ArrowTimeUnit unit,
-                                    const char* timezone) override {
+  void ValidateIngestedTimestampData(struct ArrowArrayView* values,
+                                     enum ArrowTimeUnit unit,
+                                     const char* timezone) override {
     std::vector<std::optional<int64_t>> expected;
     switch (unit) {
       case NANOARROW_TIME_UNIT_SECOND:
diff --git a/c/driver/sqlite/sqlite_test.cc b/c/driver/sqlite/sqlite_test.cc
index 00332066..b03158cc 100644
--- a/c/driver/sqlite/sqlite_test.cc
+++ b/c/driver/sqlite/sqlite_test.cc
@@ -77,6 +77,7 @@ class SqliteQuirks : public adbc_validation::DriverQuirks {
       case NANOARROW_TYPE_FLOAT:
       case NANOARROW_TYPE_DOUBLE:
         return NANOARROW_TYPE_DOUBLE;
+      case NANOARROW_TYPE_DATE32:
       case NANOARROW_TYPE_TIMESTAMP:
         return NANOARROW_TYPE_STRING;
       default:
@@ -200,9 +201,9 @@ class SqliteStatementTest : public ::testing::Test,
   }
 
  protected:
-  void ValidateIngestedTemporalData(struct ArrowArrayView* values,
-                                    enum ArrowTimeUnit unit,
-                                    const char* timezone) override {
+  void ValidateIngestedTimestampData(struct ArrowArrayView* values,
+                                     enum ArrowTimeUnit unit,
+                                     const char* timezone) override {
     std::vector<std::optional<std::string>> expected;
     switch (unit) {
       case (NANOARROW_TIME_UNIT_SECOND):
diff --git a/c/driver/sqlite/statement_reader.c b/c/driver/sqlite/statement_reader.c
index 366c0fa1..f8483c53 100644
--- a/c/driver/sqlite/statement_reader.c
+++ b/c/driver/sqlite/statement_reader.c
@@ -93,6 +93,59 @@ AdbcStatusCode AdbcSqliteBinderSetArrayStream(struct AdbcSqliteBinder* binder,
   return AdbcSqliteBinderSet(binder, error);
 }
 
+#define SECONDS_PER_DAY 86400
+
+/*
+  Formats an Arrow DATE32 value (days since 1970-01-01) as "YYYY-MM-DD".
+  Allocates to buf on success. Caller is responsible for freeing.
+  On failure sets error and contents of buf are undefined.
+*/
+static AdbcStatusCode ArrowDate32ToIsoString(int32_t value, char** buf,
+                                             struct AdbcError* error) {
+  const int iso_len = 10;  // length of "YYYY-MM-DD", excluding the NUL
+
+#if SIZEOF_TIME_T < 8  // NB: an undefined SIZEOF_TIME_T also takes this branch
+  // Reject values whose seconds-since-epoch would overflow a 32-bit time_t
+  if ((value > INT32_MAX / SECONDS_PER_DAY) ||
+      (value < INT32_MIN / SECONDS_PER_DAY)) {
+    SetError(error, "Date %" PRId32 " exceeds platform time_t bounds", value);
+    return ADBC_STATUS_INVALID_ARGUMENT;
+  }
+  time_t time = (time_t)(value * SECONDS_PER_DAY);
+#else
+  // Widen before multiplying: days * 86400 overflows a 32-bit int past 24855 days
+  time_t time = (time_t)value * SECONDS_PER_DAY;
+#endif
+
+  struct tm broken_down_time;
+
+#if defined(_WIN32)
+  if (gmtime_s(&broken_down_time, &time) != 0) {
+    SetError(error, "Could not convert date %" PRId32 " to broken down time", value);
+    return ADBC_STATUS_INVALID_ARGUMENT;
+  }
+#else
+  if (gmtime_r(&time, &broken_down_time) != &broken_down_time) {
+    SetError(error, "Could not convert date %" PRId32 " to broken down time", value);
+    return ADBC_STATUS_INVALID_ARGUMENT;
+  }
+#endif
+
+  char* tsstr = malloc(iso_len + 1);
+  if (tsstr == NULL) {
+    SetError(error, "Could not allocate buffer for date %" PRId32, value);
+    return ADBC_STATUS_IO;
+  }
+
+  // strftime returns 0 when the result (e.g. a 5-digit year) does not fit
+  if (strftime(tsstr, iso_len + 1, "%Y-%m-%d", &broken_down_time) == 0) {
+    SetError(error, "Call to strftime for date %" PRId32 " failed", value);
+    free(tsstr);
+    return ADBC_STATUS_INVALID_ARGUMENT;
+  }
+  *buf = tsstr;
+  return ADBC_STATUS_OK;
+}
+
 /*
   Allocates to buf on success. Caller is responsible for freeing.
   On failure sets error and contents of buf are undefined.
@@ -300,6 +353,28 @@ AdbcStatusCode AdbcSqliteBinderBindNext(struct AdbcSqliteBinder* binder, sqlite3
                                      SQLITE_STATIC);
           break;
         }
+        case NANOARROW_TYPE_DATE32: {
+          int64_t value =
+              ArrowArrayViewGetIntUnsafe(binder->batch.children[col], binder->next_row);
+          char* tsstr = NULL;
+          // Guard before narrowing to int32 for ArrowDate32ToIsoString
+          if ((value > INT32_MAX) || (value < INT32_MIN)) {
+            SetError(error,
+                     "Column %d has value %" PRId64
+                     " which exceeds the expected range "
+                     "for an Arrow DATE32 type",
+                     col, value);
+            return ADBC_STATUS_INVALID_DATA;
+          }
+
+          RAISE_ADBC(ArrowDate32ToIsoString((int32_t)value, &tsstr, error));
+          // SQLITE_TRANSIENT ensures the value is copied during bind
+          status =
+              sqlite3_bind_text(stmt, col + 1, tsstr, (int)strlen(tsstr), SQLITE_TRANSIENT);
+
+          free(tsstr);
+          break;
+        }
         case NANOARROW_TYPE_TIMESTAMP: {
           struct ArrowSchemaView bind_schema_view;
           RAISE_ADBC(ArrowSchemaViewInit(&bind_schema_view, 
binder->schema.children[col],
diff --git a/c/driver_manager/adbc_driver_manager_test.cc b/c/driver_manager/adbc_driver_manager_test.cc
index 149da7c6..d3ff6f58 100644
--- a/c/driver_manager/adbc_driver_manager_test.cc
+++ b/c/driver_manager/adbc_driver_manager_test.cc
@@ -226,6 +226,7 @@ class SqliteStatementTest : public ::testing::Test,
 
   void TestSqlIngestUInt64() { GTEST_SKIP() << "Cannot ingest UINT64 (out of 
range)"; }
   void TestSqlIngestBinary() { GTEST_SKIP() << "Cannot ingest BINARY (not 
implemented)"; }
+  void TestSqlIngestDate32() { GTEST_SKIP() << "Cannot ingest DATE (not 
implemented)"; }
   void TestSqlIngestTimestamp() {
     GTEST_SKIP() << "Cannot ingest TIMESTAMP (not implemented)";
   }
diff --git a/c/validation/adbc_validation.cc b/c/validation/adbc_validation.cc
index 54f3981c..8f519b04 100644
--- a/c/validation/adbc_validation.cc
+++ b/c/validation/adbc_validation.cc
@@ -1037,6 +1037,10 @@ void StatementTest::TestSqlIngestNumericType(ArrowType type) {
     // values. Likely a bug on our side, but for now, avoid them.
     values.push_back(static_cast<CType>(-1.5));
     values.push_back(static_cast<CType>(1.5));
+  } else if (type == ArrowType::NANOARROW_TYPE_DATE32) {
+    // Windows does not seem to support negative date values
+    values.push_back(static_cast<CType>(0));
+    values.push_back(static_cast<CType>(42));
   } else {
     values.push_back(std::numeric_limits<CType>::lowest());
     values.push_back(std::numeric_limits<CType>::max());
@@ -1095,8 +1099,12 @@ void StatementTest::TestSqlIngestBinary() {
       NANOARROW_TYPE_BINARY, {std::nullopt, "", "\x00\x01\x02\x04", 
"\xFE\xFF"}));
 }
 
+void StatementTest::TestSqlIngestDate32() {
+  ASSERT_NO_FATAL_FAILURE(TestSqlIngestNumericType<int32_t>(NANOARROW_TYPE_DATE32));
+}
+
 template <enum ArrowTimeUnit TU>
-void StatementTest::TestSqlIngestTemporalType(const char* timezone) {
+void StatementTest::TestSqlIngestTimestampType(const char* timezone) {
   if (!quirks()->supports_bulk_ingest()) {
     GTEST_SKIP();
   }
@@ -1155,7 +1163,7 @@ void StatementTest::TestSqlIngestTemporalType(const char* timezone) {
     ASSERT_EQ(values.size(), reader.array->length);
     ASSERT_EQ(1, reader.array->n_children);
 
-    ValidateIngestedTemporalData(reader.array_view->children[0], TU, timezone);
+    ValidateIngestedTimestampData(reader.array_view->children[0], TU, 
timezone);
 
     ASSERT_NO_FATAL_FAILURE(reader.Next());
     ASSERT_EQ(nullptr, reader.array->release);
@@ -1164,33 +1172,34 @@ void StatementTest::TestSqlIngestTemporalType(const char* timezone) {
   ASSERT_THAT(AdbcStatementRelease(&statement, &error), IsOkStatus(&error));
 }
 
-void StatementTest::ValidateIngestedTemporalData(struct ArrowArrayView* values,
-                                                 enum ArrowTimeUnit unit,
-                                                 const char* timezone) {
-  FAIL() << "ValidateIngestedTemporalData is not implemented in the base class";
+void StatementTest::ValidateIngestedTimestampData(struct ArrowArrayView* values,
+                                                  enum ArrowTimeUnit unit,
+                                                  const char* timezone) {
+  FAIL() << "ValidateIngestedTimestampData is not implemented in the base class";
 }
 
 void StatementTest::TestSqlIngestTimestamp() {
-  ASSERT_NO_FATAL_FAILURE(TestSqlIngestTemporalType<NANOARROW_TIME_UNIT_SECOND>(nullptr));
-  ASSERT_NO_FATAL_FAILURE(TestSqlIngestTemporalType<NANOARROW_TIME_UNIT_MILLI>(nullptr));
-  ASSERT_NO_FATAL_FAILURE(TestSqlIngestTemporalType<NANOARROW_TIME_UNIT_MICRO>(nullptr));
-  ASSERT_NO_FATAL_FAILURE(TestSqlIngestTemporalType<NANOARROW_TIME_UNIT_NANO>(nullptr));
+  ASSERT_NO_FATAL_FAILURE(
+      TestSqlIngestTimestampType<NANOARROW_TIME_UNIT_SECOND>(nullptr));
+  ASSERT_NO_FATAL_FAILURE(TestSqlIngestTimestampType<NANOARROW_TIME_UNIT_MILLI>(nullptr));
+  ASSERT_NO_FATAL_FAILURE(TestSqlIngestTimestampType<NANOARROW_TIME_UNIT_MICRO>(nullptr));
+  ASSERT_NO_FATAL_FAILURE(TestSqlIngestTimestampType<NANOARROW_TIME_UNIT_NANO>(nullptr));
 }
 
 void StatementTest::TestSqlIngestTimestampTz() {
-  ASSERT_NO_FATAL_FAILURE(TestSqlIngestTemporalType<NANOARROW_TIME_UNIT_SECOND>("UTC"));
-  ASSERT_NO_FATAL_FAILURE(TestSqlIngestTemporalType<NANOARROW_TIME_UNIT_MILLI>("UTC"));
-  ASSERT_NO_FATAL_FAILURE(TestSqlIngestTemporalType<NANOARROW_TIME_UNIT_MICRO>("UTC"));
-  ASSERT_NO_FATAL_FAILURE(TestSqlIngestTemporalType<NANOARROW_TIME_UNIT_NANO>("UTC"));
+  ASSERT_NO_FATAL_FAILURE(TestSqlIngestTimestampType<NANOARROW_TIME_UNIT_SECOND>("UTC"));
+  ASSERT_NO_FATAL_FAILURE(TestSqlIngestTimestampType<NANOARROW_TIME_UNIT_MILLI>("UTC"));
+  ASSERT_NO_FATAL_FAILURE(TestSqlIngestTimestampType<NANOARROW_TIME_UNIT_MICRO>("UTC"));
+  ASSERT_NO_FATAL_FAILURE(TestSqlIngestTimestampType<NANOARROW_TIME_UNIT_NANO>("UTC"));
 
   ASSERT_NO_FATAL_FAILURE(
-      TestSqlIngestTemporalType<NANOARROW_TIME_UNIT_SECOND>("America/Los_Angeles"));
+      TestSqlIngestTimestampType<NANOARROW_TIME_UNIT_SECOND>("America/Los_Angeles"));
   ASSERT_NO_FATAL_FAILURE(
-      TestSqlIngestTemporalType<NANOARROW_TIME_UNIT_MILLI>("America/Los_Angeles"));
+      TestSqlIngestTimestampType<NANOARROW_TIME_UNIT_MILLI>("America/Los_Angeles"));
   ASSERT_NO_FATAL_FAILURE(
-      TestSqlIngestTemporalType<NANOARROW_TIME_UNIT_MICRO>("America/Los_Angeles"));
+      TestSqlIngestTimestampType<NANOARROW_TIME_UNIT_MICRO>("America/Los_Angeles"));
   ASSERT_NO_FATAL_FAILURE(
-      TestSqlIngestTemporalType<NANOARROW_TIME_UNIT_NANO>("America/Los_Angeles"));
+      TestSqlIngestTimestampType<NANOARROW_TIME_UNIT_NANO>("America/Los_Angeles"));
 }
 
 void StatementTest::TestSqlIngestInterval() {
diff --git a/c/validation/adbc_validation.h b/c/validation/adbc_validation.h
index dc5d69c2..23dacb7f 100644
--- a/c/validation/adbc_validation.h
+++ b/c/validation/adbc_validation.h
@@ -230,6 +230,7 @@ class StatementTest {
   void TestSqlIngestBinary();
 
   // Temporal
+  void TestSqlIngestDate32();
   void TestSqlIngestTimestamp();
   void TestSqlIngestTimestampTz();
   void TestSqlIngestInterval();
@@ -277,11 +278,11 @@ class StatementTest {
   void TestSqlIngestNumericType(ArrowType type);
 
   template <enum ArrowTimeUnit TU>
-  void TestSqlIngestTemporalType(const char* timezone);
+  void TestSqlIngestTimestampType(const char* timezone);
 
-  virtual void ValidateIngestedTemporalData(struct ArrowArrayView* values,
-                                            enum ArrowTimeUnit unit,
-                                            const char* timezone);
+  virtual void ValidateIngestedTimestampData(struct ArrowArrayView* values,
+                                             enum ArrowTimeUnit unit,
+                                             const char* timezone);
 };
 
 #define ADBCV_TEST_STATEMENT(FIXTURE)                                          
         \
@@ -301,6 +302,7 @@ class StatementTest {
   TEST_F(FIXTURE, SqlIngestFloat64) { TestSqlIngestFloat64(); }                
         \
   TEST_F(FIXTURE, SqlIngestString) { TestSqlIngestString(); }                  
         \
   TEST_F(FIXTURE, SqlIngestBinary) { TestSqlIngestBinary(); }                  
         \
+  TEST_F(FIXTURE, SqlIngestDate32) { TestSqlIngestDate32(); }                  
         \
   TEST_F(FIXTURE, SqlIngestTimestamp) { TestSqlIngestTimestamp(); }            
         \
   TEST_F(FIXTURE, SqlIngestTimestampTz) { TestSqlIngestTimestampTz(); }        
         \
   TEST_F(FIXTURE, SqlIngestInterval) { TestSqlIngestInterval(); }              
         \

Reply via email to