This is an automated email from the ASF dual-hosted git repository.

apitrou pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new 228e268  ARROW-11162: [C++][Parquet] Fix invalid cast on Decimal256 Parquet data
228e268 is described below

commit 228e268227d6f67af600678166b673cca51480ea
Author: Antoine Pitrou <[email protected]>
AuthorDate: Tue Jan 12 19:54:44 2021 +0100

    ARROW-11162: [C++][Parquet] Fix invalid cast on Decimal256 Parquet data
    
    The invalid cast would occur when a variable-length bytearray field would be
    decoded as Decimal256 Arrow data.
    
    Should fix the following issue:
    - https://bugs.chromium.org/p/oss-fuzz/issues/detail?id=28750
    
    Found by OSS-Fuzz.
    
    Closes #9125 from pitrou/ARROW-11162-parquet-decimal256-fuzz
    
    Authored-by: Antoine Pitrou <[email protected]>
    Signed-off-by: Antoine Pitrou <[email protected]>
---
 cpp/src/parquet/arrow/arrow_reader_writer_test.cc | 67 +++++++++++++++++++++++
 cpp/src/parquet/arrow/reader_internal.cc          | 16 +++---
 testing                                           |  2 +-
 3 files changed, 76 insertions(+), 9 deletions(-)

diff --git a/cpp/src/parquet/arrow/arrow_reader_writer_test.cc b/cpp/src/parquet/arrow/arrow_reader_writer_test.cc
index d7de2b0..1da379c 100644
--- a/cpp/src/parquet/arrow/arrow_reader_writer_test.cc
+++ b/cpp/src/parquet/arrow/arrow_reader_writer_test.cc
@@ -42,6 +42,7 @@
 #include "arrow/testing/random.h"
 #include "arrow/testing/util.h"
 #include "arrow/type_traits.h"
+#include "arrow/util/checked_cast.h"
 #include "arrow/util/decimal.h"
 #include "arrow/util/logging.h"
 #include "arrow/util/range.h"
@@ -77,6 +78,8 @@ using arrow::Status;
 using arrow::Table;
 using arrow::TimeUnit;
 using arrow::compute::DictionaryEncode;
+using arrow::internal::checked_cast;
+using arrow::internal::checked_pointer_cast;
 using arrow::io::BufferReader;
 
 using arrow::randint;
@@ -521,6 +524,7 @@ class ParquetIOTestBase : public ::testing::Test {
     ASSERT_EQ(1, chunked_out->num_chunks());
     *out = chunked_out->chunk(0);
     ASSERT_NE(nullptr, out->get());
+    ASSERT_OK((*out)->ValidateFull());
   }
 
   void ReadSingleColumnFileStatistics(std::unique_ptr<FileReader> file_reader,
@@ -644,6 +648,69 @@ class ParquetIOTestBase : public ::testing::Test {
   std::shared_ptr<::arrow::io::BufferOutputStream> sink_;
 };
 
+class TestReadDecimals : public ParquetIOTestBase {
+ public:
+  void CheckReadFromByteArrays(const std::shared_ptr<const LogicalType>& logical_type,
+                               const std::vector<std::vector<uint8_t>>& values,
+                               const Array& expected) {
+    std::vector<ByteArray> byte_arrays(values.size());
+    std::transform(values.begin(), values.end(), byte_arrays.begin(),
+                   [](const std::vector<uint8_t>& bytes) {
+                     return ByteArray(static_cast<uint32_t>(bytes.size()), bytes.data());
+                   });
+
+    auto node = PrimitiveNode::Make("decimals", Repetition::REQUIRED, logical_type,
+                                    Type::BYTE_ARRAY);
+    auto schema =
+        GroupNode::Make("schema", Repetition::REQUIRED, std::vector<NodePtr>{node});
+
+    auto file_writer = MakeWriter(checked_pointer_cast<GroupNode>(schema));
+    auto column_writer = file_writer->AppendRowGroup()->NextColumn();
+    auto typed_writer = checked_cast<TypedColumnWriter<ByteArrayType>*>(column_writer);
+    typed_writer->WriteBatch(static_cast<int64_t>(byte_arrays.size()),
+                             /*def_levels=*/nullptr,
+                             /*rep_levels=*/nullptr, byte_arrays.data());
+    column_writer->Close();
+    file_writer->Close();
+
+    ReadAndCheckSingleColumnFile(expected);
+  }
+};
+
+// The Decimal roundtrip tests always go through the FixedLenByteArray path,
+// check the ByteArray case manually.
+
+TEST_F(TestReadDecimals, Decimal128ByteArray) {
+  const std::vector<std::vector<uint8_t>> big_endian_decimals = {
+      // 123456
+      {1, 226, 64},
+      // 987654
+      {15, 18, 6},
+      // -123456
+      {255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 254, 29, 192},
+  };
+
+  auto expected =
+      ArrayFromJSON(::arrow::decimal128(6, 3), R"(["123.456", "987.654", "-123.456"])");
+  CheckReadFromByteArrays(LogicalType::Decimal(6, 3), big_endian_decimals, *expected);
+}
+
+TEST_F(TestReadDecimals, Decimal256ByteArray) {
+  const std::vector<std::vector<uint8_t>> big_endian_decimals = {
+      // 123456
+      {1, 226, 64},
+      // 987654
+      {15, 18, 6},
+      // -123456
+      {255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
+       255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 254, 29,  192},
+  };
+
+  auto expected =
+      ArrayFromJSON(::arrow::decimal256(40, 3), R"(["123.456", "987.654", "-123.456"])");
+  CheckReadFromByteArrays(LogicalType::Decimal(40, 3), big_endian_decimals, *expected);
+}
+
 template <typename TestType>
 class TestParquetIO : public ParquetIOTestBase {
  public:
diff --git a/cpp/src/parquet/arrow/reader_internal.cc b/cpp/src/parquet/arrow/reader_internal.cc
index 6c387df..7ec8691 100644
--- a/cpp/src/parquet/arrow/reader_internal.cc
+++ b/cpp/src/parquet/arrow/reader_internal.cc
@@ -413,7 +413,7 @@ struct DecimalConverter<DecimalArrayType, FLBAType> {
 
     // The byte width of each decimal value
     const int32_t type_length =
-        static_cast<const ::arrow::DecimalType&>(*type).byte_width();
+        checked_cast<const ::arrow::DecimalType&>(*type).byte_width();
 
     // number of elements in the entire array
     const int64_t length = fixed_size_binary_array.length();
@@ -462,10 +462,10 @@ struct DecimalConverter<DecimalArrayType, ByteArrayType> {
   static inline Status ConvertToDecimal(const Array& array,
                                         const std::shared_ptr<DataType>& type,
                                         MemoryPool* pool, std::shared_ptr<Array>* out) {
-    const auto& binary_array = static_cast<const ::arrow::BinaryArray&>(array);
+    const auto& binary_array = checked_cast<const ::arrow::BinaryArray&>(array);
     const int64_t length = binary_array.length();
 
-    const auto& decimal_type = static_cast<const ::arrow::Decimal128Type&>(*type);
+    const auto& decimal_type = checked_cast<const ::arrow::DecimalType&>(*type);
     const int64_t type_length = decimal_type.byte_width();
 
     ARROW_ASSIGN_OR_RAISE(auto data, ::arrow::AllocateBuffer(length * type_length, pool));
@@ -481,7 +481,7 @@ struct DecimalConverter<DecimalArrayType, ByteArrayType> {
       const uint8_t* record_loc = binary_array.GetValue(i, &record_len);
 
       if (record_len < 0 || record_len > type_length) {
-        return Status::Invalid("Invalid BYTE_ARRAY length for Decimal128");
+        return Status::Invalid("Invalid BYTE_ARRAY length for ", type->ToString());
       }
 
       auto out_ptr_view = reinterpret_cast<uint64_t*>(out_ptr);
@@ -531,7 +531,7 @@ static Status DecimalIntegerTransfer(RecordReader* reader, MemoryPool* pool,
 
   const auto values = reinterpret_cast<const ElementType*>(reader->values());
 
-  const auto& decimal_type = static_cast<const ::arrow::DecimalType&>(*type);
+  const auto& decimal_type = checked_cast<const ::arrow::DecimalType&>(*type);
   const int64_t type_length = decimal_type.byte_width();
 
   ARROW_ASSIGN_OR_RAISE(auto data, ::arrow::AllocateBuffer(length * type_length, pool));
@@ -557,10 +557,10 @@ static Status DecimalIntegerTransfer(RecordReader* reader, MemoryPool* pool,
   return Status::OK();
 }
 
-/// \brief Convert an arrow::BinaryArray to an arrow::Decimal128Array
+/// \brief Convert an arrow::BinaryArray to an arrow::Decimal{128,256}Array
 /// We do this by:
 /// 1. Creating an arrow::BinaryArray from the RecordReader's builder
-/// 2. Allocating a buffer for the arrow::Decimal128Array
+/// 2. Allocating a buffer for the arrow::Decimal{128,256}Array
 /// 3. Converting the big-endian bytes in each BinaryArray entry to two integers
 ///    representing the high and low bits of each decimal value.
 template <typename DecimalArrayType, typename ParquetType>
@@ -677,7 +677,7 @@ Status TransferColumnData(RecordReader* reader, std::shared_ptr<DataType> value_
 
     case ::arrow::Type::TIMESTAMP: {
       const ::arrow::TimestampType& timestamp_type =
-          static_cast<::arrow::TimestampType&>(*value_type);
+          checked_cast<::arrow::TimestampType&>(*value_type);
       switch (timestamp_type.unit()) {
         case ::arrow::TimeUnit::MILLI:
         case ::arrow::TimeUnit::MICRO: {
diff --git a/testing b/testing
index d6c4deb..b4eeafd 160000
--- a/testing
+++ b/testing
@@ -1 +1 @@
-Subproject commit d6c4deb22c4b4e9e3247a2f291046e3c671ad235
+Subproject commit b4eeafdec6fb5284c4aaf269f2ebdb3be2c63ed5

Reply via email to