This is an automated email from the ASF dual-hosted git repository.
lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git
The following commit(s) were added to refs/heads/main by this push:
new 88849ddf refactor(c/driver/postgresql): hardcode overflow checks (#1051)
88849ddf is described below
commit 88849ddfb77a819730e368e49a19f46240e0ca62
Author: David Li <[email protected]>
AuthorDate: Mon Sep 11 14:04:48 2023 -0400
refactor(c/driver/postgresql): hardcode overflow checks (#1051)
Fixes #1017.
---
c/driver/postgresql/postgres_copy_reader.h | 29 +++++++-
c/driver/postgresql/postgresql_test.cc | 106 +++++++++++++++++++++++++++++
c/driver/postgresql/statement.cc | 21 +++---
r/adbcpostgresql/bootstrap.R | 9 +--
4 files changed, 149 insertions(+), 16 deletions(-)
diff --git a/c/driver/postgresql/postgres_copy_reader.h b/c/driver/postgresql/postgres_copy_reader.h
index 5c7214dc..5a589700 100644
--- a/c/driver/postgresql/postgres_copy_reader.h
+++ b/c/driver/postgresql/postgres_copy_reader.h
@@ -30,7 +30,6 @@
#include "postgres_type.h"
#include "postgres_util.h"
-#include "vendor/portable-snippets/safe-math.h"
// R 3.6 / Windows builds on a very old toolchain that does not define ENODATA
#if defined(_WIN32) && !defined(MSVC) && !defined(ENODATA)
@@ -44,6 +43,30 @@ static int8_t kPgCopyBinarySignature[] = {0x50, 0x47, 0x43,
0x4F,
0x50, 0x59, 0x0A,
static_cast<int8_t>(0xFF),
0x0D, 0x0A, 0x00};
+// The maximum value in seconds that can be converted into microseconds
+// without overflow
+constexpr int64_t kMaxSafeSecondsToMicros = 9223372036854L;
+
+// The minimum value in seconds that can be converted into microseconds
+// without overflow
+constexpr int64_t kMinSafeSecondsToMicros = -9223372036854L;
+
+// The maximum value in milliseconds that can be converted into microseconds
+// without overflow
+constexpr int64_t kMaxSafeMillisToMicros = 9223372036854775L;
+
+// The minimum value in milliseconds that can be converted into microseconds
+// without overflow
+constexpr int64_t kMinSafeMillisToMicros = -9223372036854775L;
+
+// The maximum value in microseconds that can be converted into nanoseconds
+// without overflow
+constexpr int64_t kMaxSafeMicrosToNanos = 9223372036854775L;
+
+// The minimum value in microseconds that can be converted into nanoseconds
+// without overflow
+constexpr int64_t kMinSafeMicrosToNanos = -9223372036854775L;
+
// Read a value from the buffer without checking the buffer size. Advances
// the cursor of data and reduces its size by sizeof(T).
template <typename T>
@@ -234,7 +257,7 @@ class PostgresCopyIntervalFieldReader : public PostgresCopyFieldReader {
const int64_t time_usec = ReadUnsafe<int64_t>(data);
int64_t time;
- if (!psnip_safe_int64_mul(&time, time_usec, 1000)) {
+ if (time_usec > kMaxSafeMicrosToNanos || time_usec <
kMinSafeMicrosToNanos) {
ArrowErrorSet(error,
"[libpq] Interval with time value %" PRId64
" usec would overflow when converting to nanoseconds",
@@ -242,6 +265,8 @@ class PostgresCopyIntervalFieldReader : public PostgresCopyFieldReader {
return EINVAL;
}
+ time = time_usec * 1000;
+
const int32_t days = ReadUnsafe<int32_t>(data);
const int32_t months = ReadUnsafe<int32_t>(data);
diff --git a/c/driver/postgresql/postgresql_test.cc b/c/driver/postgresql/postgresql_test.cc
index 3b776991..aabb8133 100644
--- a/c/driver/postgresql/postgresql_test.cc
+++ b/c/driver/postgresql/postgresql_test.cc
@@ -878,6 +878,112 @@ class PostgresStatementTest : public ::testing::Test,
};
ADBCV_TEST_STATEMENT(PostgresStatementTest)
+TEST_F(PostgresStatementTest, SqlIngestTimestampOverflow) {
+ ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error),
IsOkStatus(&error));
+
+ {
+ adbc_validation::Handle<struct ArrowSchema> schema;
+ adbc_validation::Handle<struct ArrowArray> batch;
+
+ ArrowSchemaInit(&schema.value);
+ ASSERT_THAT(ArrowSchemaSetTypeStruct(&schema.value, 1),
adbc_validation::IsOkErrno());
+ ASSERT_THAT(ArrowSchemaSetName(schema->children[0], "$1"),
+ adbc_validation::IsOkErrno());
+ ASSERT_THAT(ArrowSchemaSetTypeDateTime(schema->children[0],
NANOARROW_TYPE_TIMESTAMP,
+ NANOARROW_TIME_UNIT_SECOND,
nullptr),
+ adbc_validation::IsOkErrno());
+
+ ASSERT_THAT((adbc_validation::MakeBatch<int64_t>(
+ &schema.value, &batch.value, static_cast<struct
ArrowError*>(nullptr),
+ {std::numeric_limits<int64_t>::max()})),
+ adbc_validation::IsOkErrno());
+
+ ASSERT_THAT(
+ AdbcStatementSetSqlQuery(&statement, "SELECT CAST($1 AS TIMESTAMP)",
&error),
+ IsOkStatus(&error));
+ ASSERT_THAT(AdbcStatementBind(&statement, &batch.value, &schema.value,
&error),
+ IsOkStatus(&error));
+ ASSERT_THAT(AdbcStatementPrepare(&statement, &error), IsOkStatus(&error));
+ ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, nullptr,
&error),
+ IsStatus(ADBC_STATUS_INVALID_ARGUMENT, &error));
+ ASSERT_THAT(error.message,
+ ::testing::HasSubstr("Row #1 has value '9223372036854775807'
which "
+ "exceeds PostgreSQL timestamp limits"));
+ }
+
+ {
+ adbc_validation::Handle<struct ArrowSchema> schema;
+ adbc_validation::Handle<struct ArrowArray> batch;
+
+ ArrowSchemaInit(&schema.value);
+ ASSERT_THAT(ArrowSchemaSetTypeStruct(&schema.value, 1),
adbc_validation::IsOkErrno());
+ ASSERT_THAT(ArrowSchemaSetName(schema->children[0], "$1"),
+ adbc_validation::IsOkErrno());
+ ASSERT_THAT(ArrowSchemaSetTypeDateTime(schema->children[0],
NANOARROW_TYPE_TIMESTAMP,
+ NANOARROW_TIME_UNIT_SECOND,
nullptr),
+ adbc_validation::IsOkErrno());
+
+ ASSERT_THAT((adbc_validation::MakeBatch<int64_t>(
+ &schema.value, &batch.value, static_cast<struct
ArrowError*>(nullptr),
+ {std::numeric_limits<int64_t>::min()})),
+ adbc_validation::IsOkErrno());
+
+ ASSERT_THAT(
+ AdbcStatementSetSqlQuery(&statement, "SELECT CAST($1 AS TIMESTAMP)",
&error),
+ IsOkStatus(&error));
+ ASSERT_THAT(AdbcStatementBind(&statement, &batch.value, &schema.value,
&error),
+ IsOkStatus(&error));
+ ASSERT_THAT(AdbcStatementPrepare(&statement, &error), IsOkStatus(&error));
+ ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, nullptr,
&error),
+ IsStatus(ADBC_STATUS_INVALID_ARGUMENT, &error));
+ ASSERT_THAT(error.message,
+ ::testing::HasSubstr("Row #1 has value '-9223372036854775808'
which "
+ "exceeds PostgreSQL timestamp limits"));
+ }
+}
+
+TEST_F(PostgresStatementTest, SqlReadIntervalOverflow) {
+ ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error),
IsOkStatus(&error));
+
+ {
+ ASSERT_THAT(
+ AdbcStatementSetSqlQuery(
+ &statement, "SELECT CAST('P0Y0M0DT2562048H0M0S' AS INTERVAL)",
&error),
+ IsOkStatus(&error));
+ adbc_validation::StreamReader reader;
+ ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value,
+ &reader.rows_affected, &error),
+ IsOkStatus(&error));
+ ASSERT_EQ(reader.rows_affected, -1);
+ ASSERT_NO_FATAL_FAILURE(reader.GetSchema());
+ ASSERT_THAT(reader.MaybeNext(),
+ adbc_validation::IsErrno(EINVAL, &reader.stream.value,
nullptr));
+ ASSERT_THAT(reader.stream->get_last_error(&reader.stream.value),
+ ::testing::HasSubstr("Interval with time value
9223372800000000 usec "
+ "would overflow when converting to
nanoseconds"));
+ ASSERT_EQ(reader.array->release, nullptr);
+ }
+
+ {
+ ASSERT_THAT(
+ AdbcStatementSetSqlQuery(
+ &statement, "SELECT CAST('P0Y0M0DT-2562048H0M0S' AS INTERVAL)",
&error),
+ IsOkStatus(&error));
+ adbc_validation::StreamReader reader;
+ ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value,
+ &reader.rows_affected, &error),
+ IsOkStatus(&error));
+ ASSERT_EQ(reader.rows_affected, -1);
+ ASSERT_NO_FATAL_FAILURE(reader.GetSchema());
+ ASSERT_THAT(reader.MaybeNext(),
+ adbc_validation::IsErrno(EINVAL, &reader.stream.value,
nullptr));
+ ASSERT_THAT(reader.stream->get_last_error(&reader.stream.value),
+ ::testing::HasSubstr("Interval with time value
-9223372800000000 usec "
+ "would overflow when converting to
nanoseconds"));
+ ASSERT_EQ(reader.array->release, nullptr);
+ }
+}
+
TEST_F(PostgresStatementTest, UpdateInExecuteQuery) {
ASSERT_THAT(quirks()->DropTable(&connection, "adbc_test", &error),
IsOkStatus(&error));
diff --git a/c/driver/postgresql/statement.cc b/c/driver/postgresql/statement.cc
index c68ceeb2..e38bf8f4 100644
--- a/c/driver/postgresql/statement.cc
+++ b/c/driver/postgresql/statement.cc
@@ -37,7 +37,6 @@
#include "postgres_copy_reader.h"
#include "postgres_type.h"
#include "postgres_util.h"
-#include "vendor/portable-snippets/safe-math.h"
namespace adbcpq {
@@ -426,16 +425,20 @@ struct BindStream {
// 2000-01-01 00:00:00.000000 in microseconds
constexpr int64_t kPostgresTimestampEpoch = 946684800000000;
- psnip_safe_bool overflow_safe = true;
+ bool overflow_safe = true;
auto unit = bind_schema_fields[col].time_unit;
switch (unit) {
case NANOARROW_TIME_UNIT_SECOND:
- overflow_safe = psnip_safe_int64_mul(&val, val, 1000000);
+ overflow_safe =
+ val <= kMaxSafeSecondsToMicros && val >=
kMinSafeSecondsToMicros;
+ val *= 1000000;
break;
case NANOARROW_TIME_UNIT_MILLI:
- overflow_safe = psnip_safe_int64_mul(&val, val, 1000);
+ overflow_safe =
+ val <= kMaxSafeMillisToMicros && val >=
kMinSafeMillisToMicros;
+ val *= 1000;
break;
case NANOARROW_TIME_UNIT_MICRO:
break;
@@ -445,10 +448,12 @@ struct BindStream {
}
if (!overflow_safe) {
- SetError(error, "[libpq] Field #%" PRId64 "%s%s%s%" PRId64
"%s", col + 1,
- " (' ", bind_schema->children[col]->name, " ') Row #
", row + 1,
- " has value which exceeds postgres timestamp limits");
-
+ SetError(error,
+ "[libpq] Field #%" PRId64 " ('%s') Row #%" PRId64
+ " has value '%" PRIi64
+ "' which exceeds PostgreSQL timestamp limits",
+ col + 1, bind_schema->children[col]->name, row + 1,
+
array_view->children[col]->buffer_views[1].data.as_int64[row]);
return ADBC_STATUS_INVALID_ARGUMENT;
}
diff --git a/r/adbcpostgresql/bootstrap.R b/r/adbcpostgresql/bootstrap.R
index 2670f760..d2473a79 100644
--- a/r/adbcpostgresql/bootstrap.R
+++ b/r/adbcpostgresql/bootstrap.R
@@ -35,8 +35,7 @@ files_to_vendor <- c(
"../../c/driver/common/utils.c",
"../../c/vendor/nanoarrow/nanoarrow.h",
"../../c/vendor/nanoarrow/nanoarrow.hpp",
- "../../c/vendor/nanoarrow/nanoarrow.c",
- "../../c/vendor/portable-snippets/safe-math.h"
+ "../../c/vendor/nanoarrow/nanoarrow.c"
)
if (all(file.exists(files_to_vendor))) {
@@ -61,16 +60,14 @@ if (all(file.exists(files_to_vendor))) {
"src/nanoarrow.h",
"src/nanoarrow.hpp",
"src/utils.c",
- "src/utils.h",
- "src/safe-math.h"
+ "src/utils.h"
),
c(
"src/nanoarrow/nanoarrow.c",
"src/nanoarrow/nanoarrow.h",
"src/nanoarrow/nanoarrow.hpp",
"src/common/utils.c",
- "src/common/utils.h",
- "src/vendor/portable-snippets/safe-math.h"
+ "src/common/utils.h"
)
)
cat("All files successfully copied to src/\n")