This is an automated email from the ASF dual-hosted git repository.

lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new 88849ddf refactor(c/driver/postgresql): hardcode overflow checks (#1051)
88849ddf is described below

commit 88849ddfb77a819730e368e49a19f46240e0ca62
Author: David Li <[email protected]>
AuthorDate: Mon Sep 11 14:04:48 2023 -0400

    refactor(c/driver/postgresql): hardcode overflow checks (#1051)
    
    Fixes #1017.
---
 c/driver/postgresql/postgres_copy_reader.h |  29 +++++++-
 c/driver/postgresql/postgresql_test.cc     | 106 +++++++++++++++++++++++++++++
 c/driver/postgresql/statement.cc           |  21 +++---
 r/adbcpostgresql/bootstrap.R               |   9 +--
 4 files changed, 149 insertions(+), 16 deletions(-)

diff --git a/c/driver/postgresql/postgres_copy_reader.h b/c/driver/postgresql/postgres_copy_reader.h
index 5c7214dc..5a589700 100644
--- a/c/driver/postgresql/postgres_copy_reader.h
+++ b/c/driver/postgresql/postgres_copy_reader.h
@@ -30,7 +30,6 @@
 
 #include "postgres_type.h"
 #include "postgres_util.h"
-#include "vendor/portable-snippets/safe-math.h"
 
 // R 3.6 / Windows builds on a very old toolchain that does not define ENODATA
 #if defined(_WIN32) && !defined(MSVC) && !defined(ENODATA)
@@ -44,6 +43,30 @@ static int8_t kPgCopyBinarySignature[] = {0x50, 0x47, 0x43, 0x4F,
                                           0x50, 0x59, 0x0A, 
static_cast<int8_t>(0xFF),
                                           0x0D, 0x0A, 0x00};
 
+// The maximum value in seconds that can be converted into microseconds
+// without overflow
+constexpr int64_t kMaxSafeSecondsToMicros = 9223372036854L;
+
+// The minimum value in seconds that can be converted into microseconds
+// without overflow
+constexpr int64_t kMinSafeSecondsToMicros = -9223372036854L;
+
+// The maximum value in milliseconds that can be converted into microseconds
+// without overflow
+constexpr int64_t kMaxSafeMillisToMicros = 9223372036854775L;
+
+// The minimum value in milliseconds that can be converted into microseconds
+// without overflow
+constexpr int64_t kMinSafeMillisToMicros = -9223372036854775L;
+
+// The maximum value in microseconds that can be converted into nanoseconds
+// without overflow
+constexpr int64_t kMaxSafeMicrosToNanos = 9223372036854775L;
+
+// The minimum value in microseconds that can be converted into nanoseconds
+// without overflow
+constexpr int64_t kMinSafeMicrosToNanos = -9223372036854775L;
+
 // Read a value from the buffer without checking the buffer size. Advances
 // the cursor of data and reduces its size by sizeof(T).
 template <typename T>
@@ -234,7 +257,7 @@ class PostgresCopyIntervalFieldReader : public PostgresCopyFieldReader {
     const int64_t time_usec = ReadUnsafe<int64_t>(data);
     int64_t time;
 
-    if (!psnip_safe_int64_mul(&time, time_usec, 1000)) {
+    if (time_usec > kMaxSafeMicrosToNanos || time_usec < 
kMinSafeMicrosToNanos) {
       ArrowErrorSet(error,
                     "[libpq] Interval with time value %" PRId64
                     " usec would overflow when converting to nanoseconds",
@@ -242,6 +265,8 @@ class PostgresCopyIntervalFieldReader : public PostgresCopyFieldReader {
       return EINVAL;
     }
 
+    time = time_usec * 1000;
+
     const int32_t days = ReadUnsafe<int32_t>(data);
     const int32_t months = ReadUnsafe<int32_t>(data);
 
diff --git a/c/driver/postgresql/postgresql_test.cc b/c/driver/postgresql/postgresql_test.cc
index 3b776991..aabb8133 100644
--- a/c/driver/postgresql/postgresql_test.cc
+++ b/c/driver/postgresql/postgresql_test.cc
@@ -878,6 +878,112 @@ class PostgresStatementTest : public ::testing::Test,
 };
 ADBCV_TEST_STATEMENT(PostgresStatementTest)
 
+TEST_F(PostgresStatementTest, SqlIngestTimestampOverflow) {
+  ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), 
IsOkStatus(&error));
+
+  {
+    adbc_validation::Handle<struct ArrowSchema> schema;
+    adbc_validation::Handle<struct ArrowArray> batch;
+
+    ArrowSchemaInit(&schema.value);
+    ASSERT_THAT(ArrowSchemaSetTypeStruct(&schema.value, 1), 
adbc_validation::IsOkErrno());
+    ASSERT_THAT(ArrowSchemaSetName(schema->children[0], "$1"),
+                adbc_validation::IsOkErrno());
+    ASSERT_THAT(ArrowSchemaSetTypeDateTime(schema->children[0], 
NANOARROW_TYPE_TIMESTAMP,
+                                           NANOARROW_TIME_UNIT_SECOND, 
nullptr),
+                adbc_validation::IsOkErrno());
+
+    ASSERT_THAT((adbc_validation::MakeBatch<int64_t>(
+                    &schema.value, &batch.value, static_cast<struct 
ArrowError*>(nullptr),
+                    {std::numeric_limits<int64_t>::max()})),
+                adbc_validation::IsOkErrno());
+
+    ASSERT_THAT(
+        AdbcStatementSetSqlQuery(&statement, "SELECT CAST($1 AS TIMESTAMP)", 
&error),
+        IsOkStatus(&error));
+    ASSERT_THAT(AdbcStatementBind(&statement, &batch.value, &schema.value, 
&error),
+                IsOkStatus(&error));
+    ASSERT_THAT(AdbcStatementPrepare(&statement, &error), IsOkStatus(&error));
+    ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, nullptr, 
&error),
+                IsStatus(ADBC_STATUS_INVALID_ARGUMENT, &error));
+    ASSERT_THAT(error.message,
+                ::testing::HasSubstr("Row #1 has value '9223372036854775807' 
which "
+                                     "exceeds PostgreSQL timestamp limits"));
+  }
+
+  {
+    adbc_validation::Handle<struct ArrowSchema> schema;
+    adbc_validation::Handle<struct ArrowArray> batch;
+
+    ArrowSchemaInit(&schema.value);
+    ASSERT_THAT(ArrowSchemaSetTypeStruct(&schema.value, 1), 
adbc_validation::IsOkErrno());
+    ASSERT_THAT(ArrowSchemaSetName(schema->children[0], "$1"),
+                adbc_validation::IsOkErrno());
+    ASSERT_THAT(ArrowSchemaSetTypeDateTime(schema->children[0], 
NANOARROW_TYPE_TIMESTAMP,
+                                           NANOARROW_TIME_UNIT_SECOND, 
nullptr),
+                adbc_validation::IsOkErrno());
+
+    ASSERT_THAT((adbc_validation::MakeBatch<int64_t>(
+                    &schema.value, &batch.value, static_cast<struct 
ArrowError*>(nullptr),
+                    {std::numeric_limits<int64_t>::min()})),
+                adbc_validation::IsOkErrno());
+
+    ASSERT_THAT(
+        AdbcStatementSetSqlQuery(&statement, "SELECT CAST($1 AS TIMESTAMP)", 
&error),
+        IsOkStatus(&error));
+    ASSERT_THAT(AdbcStatementBind(&statement, &batch.value, &schema.value, 
&error),
+                IsOkStatus(&error));
+    ASSERT_THAT(AdbcStatementPrepare(&statement, &error), IsOkStatus(&error));
+    ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, nullptr, 
&error),
+                IsStatus(ADBC_STATUS_INVALID_ARGUMENT, &error));
+    ASSERT_THAT(error.message,
+                ::testing::HasSubstr("Row #1 has value '-9223372036854775808' 
which "
+                                     "exceeds PostgreSQL timestamp limits"));
+  }
+}
+
+TEST_F(PostgresStatementTest, SqlReadIntervalOverflow) {
+  ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), 
IsOkStatus(&error));
+
+  {
+    ASSERT_THAT(
+        AdbcStatementSetSqlQuery(
+            &statement, "SELECT CAST('P0Y0M0DT2562048H0M0S' AS INTERVAL)", 
&error),
+        IsOkStatus(&error));
+    adbc_validation::StreamReader reader;
+    ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value,
+                                          &reader.rows_affected, &error),
+                IsOkStatus(&error));
+    ASSERT_EQ(reader.rows_affected, -1);
+    ASSERT_NO_FATAL_FAILURE(reader.GetSchema());
+    ASSERT_THAT(reader.MaybeNext(),
+                adbc_validation::IsErrno(EINVAL, &reader.stream.value, 
nullptr));
+    ASSERT_THAT(reader.stream->get_last_error(&reader.stream.value),
+                ::testing::HasSubstr("Interval with time value 
9223372800000000 usec "
+                                     "would overflow when converting to 
nanoseconds"));
+    ASSERT_EQ(reader.array->release, nullptr);
+  }
+
+  {
+    ASSERT_THAT(
+        AdbcStatementSetSqlQuery(
+            &statement, "SELECT CAST('P0Y0M0DT-2562048H0M0S' AS INTERVAL)", 
&error),
+        IsOkStatus(&error));
+    adbc_validation::StreamReader reader;
+    ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value,
+                                          &reader.rows_affected, &error),
+                IsOkStatus(&error));
+    ASSERT_EQ(reader.rows_affected, -1);
+    ASSERT_NO_FATAL_FAILURE(reader.GetSchema());
+    ASSERT_THAT(reader.MaybeNext(),
+                adbc_validation::IsErrno(EINVAL, &reader.stream.value, 
nullptr));
+    ASSERT_THAT(reader.stream->get_last_error(&reader.stream.value),
+                ::testing::HasSubstr("Interval with time value 
-9223372800000000 usec "
+                                     "would overflow when converting to 
nanoseconds"));
+    ASSERT_EQ(reader.array->release, nullptr);
+  }
+}
+
 TEST_F(PostgresStatementTest, UpdateInExecuteQuery) {
   ASSERT_THAT(quirks()->DropTable(&connection, "adbc_test", &error), 
IsOkStatus(&error));
 
diff --git a/c/driver/postgresql/statement.cc b/c/driver/postgresql/statement.cc
index c68ceeb2..e38bf8f4 100644
--- a/c/driver/postgresql/statement.cc
+++ b/c/driver/postgresql/statement.cc
@@ -37,7 +37,6 @@
 #include "postgres_copy_reader.h"
 #include "postgres_type.h"
 #include "postgres_util.h"
-#include "vendor/portable-snippets/safe-math.h"
 
 namespace adbcpq {
 
@@ -426,16 +425,20 @@ struct BindStream {
 
               // 2000-01-01 00:00:00.000000 in microseconds
               constexpr int64_t kPostgresTimestampEpoch = 946684800000000;
-              psnip_safe_bool overflow_safe = true;
+              bool overflow_safe = true;
 
               auto unit = bind_schema_fields[col].time_unit;
 
               switch (unit) {
                 case NANOARROW_TIME_UNIT_SECOND:
-                  overflow_safe = psnip_safe_int64_mul(&val, val, 1000000);
+                  overflow_safe =
+                      val <= kMaxSafeSecondsToMicros && val >= 
kMinSafeSecondsToMicros;
+                  val *= 1000000;
                   break;
                 case NANOARROW_TIME_UNIT_MILLI:
-                  overflow_safe = psnip_safe_int64_mul(&val, val, 1000);
+                  overflow_safe =
+                      val <= kMaxSafeMillisToMicros && val >= 
kMinSafeMillisToMicros;
+                  val *= 1000;
                   break;
                 case NANOARROW_TIME_UNIT_MICRO:
                   break;
@@ -445,10 +448,12 @@ struct BindStream {
               }
 
               if (!overflow_safe) {
-                SetError(error, "[libpq] Field #%" PRId64 "%s%s%s%" PRId64 
"%s", col + 1,
-                         " (' ", bind_schema->children[col]->name, " ') Row # 
", row + 1,
-                         " has value which exceeds postgres timestamp limits");
-
+                SetError(error,
+                         "[libpq] Field #%" PRId64 " ('%s') Row #%" PRId64
+                         " has value '%" PRIi64
+                         "' which exceeds PostgreSQL timestamp limits",
+                         col + 1, bind_schema->children[col]->name, row + 1,
+                         
array_view->children[col]->buffer_views[1].data.as_int64[row]);
                 return ADBC_STATUS_INVALID_ARGUMENT;
               }
 
diff --git a/r/adbcpostgresql/bootstrap.R b/r/adbcpostgresql/bootstrap.R
index 2670f760..d2473a79 100644
--- a/r/adbcpostgresql/bootstrap.R
+++ b/r/adbcpostgresql/bootstrap.R
@@ -35,8 +35,7 @@ files_to_vendor <- c(
   "../../c/driver/common/utils.c",
   "../../c/vendor/nanoarrow/nanoarrow.h",
   "../../c/vendor/nanoarrow/nanoarrow.hpp",
-  "../../c/vendor/nanoarrow/nanoarrow.c",
-  "../../c/vendor/portable-snippets/safe-math.h"
+  "../../c/vendor/nanoarrow/nanoarrow.c"
 )
 
 if (all(file.exists(files_to_vendor))) {
@@ -61,16 +60,14 @@ if (all(file.exists(files_to_vendor))) {
         "src/nanoarrow.h",
         "src/nanoarrow.hpp",
         "src/utils.c",
-        "src/utils.h",
-        "src/safe-math.h"
+        "src/utils.h"
       ),
       c(
         "src/nanoarrow/nanoarrow.c",
         "src/nanoarrow/nanoarrow.h",
         "src/nanoarrow/nanoarrow.hpp",
         "src/common/utils.c",
-        "src/common/utils.h",
-        "src/vendor/portable-snippets/safe-math.h"
+        "src/common/utils.h"
       )
     )
     cat("All files successfully copied to src/\n")

Reply via email to