This is an automated email from the ASF dual-hosted git repository.
apitrou pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new 228e268 ARROW-11162: [C++][Parquet] Fix invalid cast on Decimal256 Parquet data
228e268 is described below
commit 228e268227d6f67af600678166b673cca51480ea
Author: Antoine Pitrou <[email protected]>
AuthorDate: Tue Jan 12 19:54:44 2021 +0100
ARROW-11162: [C++][Parquet] Fix invalid cast on Decimal256 Parquet data
The invalid cast occurred when a variable-length byte array field was
decoded as Decimal256 Arrow data.
Should fix the following issue:
- https://bugs.chromium.org/p/oss-fuzz/issues/detail?id=28750
Found by OSS-Fuzz.
Closes #9125 from pitrou/ARROW-11162-parquet-decimal256-fuzz
Authored-by: Antoine Pitrou <[email protected]>
Signed-off-by: Antoine Pitrou <[email protected]>
---
cpp/src/parquet/arrow/arrow_reader_writer_test.cc | 67 +++++++++++++++++++++++
cpp/src/parquet/arrow/reader_internal.cc | 16 +++---
testing | 2 +-
3 files changed, 76 insertions(+), 9 deletions(-)
diff --git a/cpp/src/parquet/arrow/arrow_reader_writer_test.cc b/cpp/src/parquet/arrow/arrow_reader_writer_test.cc
index d7de2b0..1da379c 100644
--- a/cpp/src/parquet/arrow/arrow_reader_writer_test.cc
+++ b/cpp/src/parquet/arrow/arrow_reader_writer_test.cc
@@ -42,6 +42,7 @@
#include "arrow/testing/random.h"
#include "arrow/testing/util.h"
#include "arrow/type_traits.h"
+#include "arrow/util/checked_cast.h"
#include "arrow/util/decimal.h"
#include "arrow/util/logging.h"
#include "arrow/util/range.h"
@@ -77,6 +78,8 @@ using arrow::Status;
using arrow::Table;
using arrow::TimeUnit;
using arrow::compute::DictionaryEncode;
+using arrow::internal::checked_cast;
+using arrow::internal::checked_pointer_cast;
using arrow::io::BufferReader;
using arrow::randint;
@@ -521,6 +524,7 @@ class ParquetIOTestBase : public ::testing::Test {
ASSERT_EQ(1, chunked_out->num_chunks());
*out = chunked_out->chunk(0);
ASSERT_NE(nullptr, out->get());
+ ASSERT_OK((*out)->ValidateFull());
}
void ReadSingleColumnFileStatistics(std::unique_ptr<FileReader> file_reader,
@@ -644,6 +648,69 @@ class ParquetIOTestBase : public ::testing::Test {
std::shared_ptr<::arrow::io::BufferOutputStream> sink_;
};
+class TestReadDecimals : public ParquetIOTestBase {
+ public:
+ void CheckReadFromByteArrays(const std::shared_ptr<const LogicalType>& logical_type,
+ const std::vector<std::vector<uint8_t>>& values,
+ const Array& expected) {
+ std::vector<ByteArray> byte_arrays(values.size());
+ std::transform(values.begin(), values.end(), byte_arrays.begin(),
+ [](const std::vector<uint8_t>& bytes) {
+ return ByteArray(static_cast<uint32_t>(bytes.size()), bytes.data());
+ });
+
+ auto node = PrimitiveNode::Make("decimals", Repetition::REQUIRED, logical_type,
+ Type::BYTE_ARRAY);
+ auto schema =
+ GroupNode::Make("schema", Repetition::REQUIRED, std::vector<NodePtr>{node});
+
+ auto file_writer = MakeWriter(checked_pointer_cast<GroupNode>(schema));
+ auto column_writer = file_writer->AppendRowGroup()->NextColumn();
+ auto typed_writer = checked_cast<TypedColumnWriter<ByteArrayType>*>(column_writer);
+ typed_writer->WriteBatch(static_cast<int64_t>(byte_arrays.size()),
+ /*def_levels=*/nullptr,
+ /*rep_levels=*/nullptr, byte_arrays.data());
+ column_writer->Close();
+ file_writer->Close();
+
+ ReadAndCheckSingleColumnFile(expected);
+ }
+};
+
+// The Decimal roundtrip tests always go through the FixedLenByteArray path,
+// check the ByteArray case manually.
+
+TEST_F(TestReadDecimals, Decimal128ByteArray) {
+ const std::vector<std::vector<uint8_t>> big_endian_decimals = {
+ // 123456
+ {1, 226, 64},
+ // 987654
+ {15, 18, 6},
+ // -123456
+ {255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 254, 29, 192},
+ };
+
+ auto expected =
+ ArrayFromJSON(::arrow::decimal128(6, 3), R"(["123.456", "987.654", "-123.456"])");
+ CheckReadFromByteArrays(LogicalType::Decimal(6, 3), big_endian_decimals, *expected);
+}
+
+TEST_F(TestReadDecimals, Decimal256ByteArray) {
+ const std::vector<std::vector<uint8_t>> big_endian_decimals = {
+ // 123456
+ {1, 226, 64},
+ // 987654
+ {15, 18, 6},
+ // -123456
+ {255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
+ 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 254, 29, 192},
+ };
+
+ auto expected =
+ ArrayFromJSON(::arrow::decimal256(40, 3), R"(["123.456", "987.654", "-123.456"])");
+ CheckReadFromByteArrays(LogicalType::Decimal(40, 3), big_endian_decimals, *expected);
+}
+
template <typename TestType>
class TestParquetIO : public ParquetIOTestBase {
public:
diff --git a/cpp/src/parquet/arrow/reader_internal.cc b/cpp/src/parquet/arrow/reader_internal.cc
index 6c387df..7ec8691 100644
--- a/cpp/src/parquet/arrow/reader_internal.cc
+++ b/cpp/src/parquet/arrow/reader_internal.cc
@@ -413,7 +413,7 @@ struct DecimalConverter<DecimalArrayType, FLBAType> {
// The byte width of each decimal value
const int32_t type_length =
- static_cast<const ::arrow::DecimalType&>(*type).byte_width();
+ checked_cast<const ::arrow::DecimalType&>(*type).byte_width();
// number of elements in the entire array
const int64_t length = fixed_size_binary_array.length();
@@ -462,10 +462,10 @@ struct DecimalConverter<DecimalArrayType, ByteArrayType> {
static inline Status ConvertToDecimal(const Array& array,
const std::shared_ptr<DataType>& type,
MemoryPool* pool,
std::shared_ptr<Array>* out) {
- const auto& binary_array = static_cast<const ::arrow::BinaryArray&>(array);
+ const auto& binary_array = checked_cast<const ::arrow::BinaryArray&>(array);
const int64_t length = binary_array.length();
- const auto& decimal_type = static_cast<const ::arrow::Decimal128Type&>(*type);
+ const auto& decimal_type = checked_cast<const ::arrow::DecimalType&>(*type);
const int64_t type_length = decimal_type.byte_width();
ARROW_ASSIGN_OR_RAISE(auto data, ::arrow::AllocateBuffer(length * type_length, pool));
@@ -481,7 +481,7 @@ struct DecimalConverter<DecimalArrayType, ByteArrayType> {
const uint8_t* record_loc = binary_array.GetValue(i, &record_len);
if (record_len < 0 || record_len > type_length) {
- return Status::Invalid("Invalid BYTE_ARRAY length for Decimal128");
+ return Status::Invalid("Invalid BYTE_ARRAY length for ", type->ToString());
}
auto out_ptr_view = reinterpret_cast<uint64_t*>(out_ptr);
@@ -531,7 +531,7 @@ static Status DecimalIntegerTransfer(RecordReader* reader, MemoryPool* pool,
const auto values = reinterpret_cast<const ElementType*>(reader->values());
- const auto& decimal_type = static_cast<const ::arrow::DecimalType&>(*type);
+ const auto& decimal_type = checked_cast<const ::arrow::DecimalType&>(*type);
const int64_t type_length = decimal_type.byte_width();
ARROW_ASSIGN_OR_RAISE(auto data, ::arrow::AllocateBuffer(length * type_length, pool));
@@ -557,10 +557,10 @@ static Status DecimalIntegerTransfer(RecordReader* reader, MemoryPool* pool,
return Status::OK();
}
-/// \brief Convert an arrow::BinaryArray to an arrow::Decimal128Array
+/// \brief Convert an arrow::BinaryArray to an arrow::Decimal{128,256}Array
/// We do this by:
/// 1. Creating an arrow::BinaryArray from the RecordReader's builder
-/// 2. Allocating a buffer for the arrow::Decimal128Array
+/// 2. Allocating a buffer for the arrow::Decimal{128,256}Array
/// 3. Converting the big-endian bytes in each BinaryArray entry to two integers
/// representing the high and low bits of each decimal value.
template <typename DecimalArrayType, typename ParquetType>
@@ -677,7 +677,7 @@ Status TransferColumnData(RecordReader* reader, std::shared_ptr<DataType> value_
case ::arrow::Type::TIMESTAMP: {
const ::arrow::TimestampType& timestamp_type =
- static_cast<::arrow::TimestampType&>(*value_type);
+ checked_cast<::arrow::TimestampType&>(*value_type);
switch (timestamp_type.unit()) {
case ::arrow::TimeUnit::MILLI:
case ::arrow::TimeUnit::MICRO: {
diff --git a/testing b/testing
index d6c4deb..b4eeafd 160000
--- a/testing
+++ b/testing
@@ -1 +1 @@
-Subproject commit d6c4deb22c4b4e9e3247a2f291046e3c671ad235
+Subproject commit b4eeafdec6fb5284c4aaf269f2ebdb3be2c63ed5