kszucs commented on code in PR #45360: URL: https://github.com/apache/arrow/pull/45360#discussion_r2085268278
########## cpp/src/parquet/chunker_internal_test.cc: ########## @@ -0,0 +1,1687 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include <algorithm> +#include <iostream> +#include <memory> +#include <string> +#include <utility> +#include <vector> + +#include <gtest/gtest.h> + +#include "arrow/table.h" +#include "arrow/testing/extension_type.h" +#include "arrow/testing/generator.h" +#include "arrow/type_fwd.h" +#include "arrow/util/float16.h" +#include "parquet/arrow/reader.h" +#include "parquet/arrow/reader_internal.h" +#include "parquet/arrow/test_util.h" +#include "parquet/arrow/writer.h" +#include "parquet/chunker_internal.h" +#include "parquet/column_writer.h" +#include "parquet/file_writer.h" + +namespace parquet::internal { + +using ::arrow::Array; +using ::arrow::ChunkedArray; +using ::arrow::ConcatenateTables; +using ::arrow::DataType; +using ::arrow::default_memory_pool; +using ::arrow::Field; +using ::arrow::Result; +using ::arrow::Schema; +using ::arrow::Table; +using ::arrow::internal::checked_cast; +using ::arrow::io::BufferReader; +using ::parquet::arrow::FileReader; +using ::parquet::arrow::FileReaderBuilder; +using ::parquet::arrow::MakeSimpleTable; +using ::parquet::arrow::NonNullArray; +using 
::parquet::arrow::WriteTable; + +using ::testing::Bool; +using ::testing::Combine; +using ::testing::Values; + +// generate determinisic and platform-independent data +inline uint64_t hash(uint64_t seed, uint64_t index) { + uint64_t h = (index + seed) * 0xc4ceb9fe1a85ec53ull; + h ^= h >> 33; + h *= 0xff51afd7ed558ccdull; + h ^= h >> 33; + h *= 0xc4ceb9fe1a85ec53ull; + h ^= h >> 33; + return h; +} + +template <typename BuilderType, typename ValueFunc> +Result<std::shared_ptr<Array>> GenerateArray(const std::shared_ptr<DataType>& type, + bool nullable, int64_t length, uint64_t seed, + ValueFunc value_func) { + BuilderType builder(type, default_memory_pool()); + + if (nullable) { + for (int64_t i = 0; i < length; ++i) { + uint64_t val = hash(seed, i); + if (val % 10 == 0) { + RETURN_NOT_OK(builder.AppendNull()); + } else { + RETURN_NOT_OK(builder.Append(value_func(val))); + } + } + } else { + for (int64_t i = 0; i < length; ++i) { + uint64_t val = hash(seed, i); + RETURN_NOT_OK(builder.Append(value_func(val))); + } + } + + std::shared_ptr<Array> array; + RETURN_NOT_OK(builder.Finish(&array)); + RETURN_NOT_OK(array->ValidateFull()); + return array; +} + +#define GENERATE_CASE(TYPE_ID, BUILDER_TYPE, VALUE_EXPR) \ + case ::arrow::Type::TYPE_ID: { \ + auto value_func = [](uint64_t val) { return VALUE_EXPR; }; \ + return GenerateArray<BUILDER_TYPE>(type, nullable, length, seed, value_func); \ + } + +Result<std::shared_ptr<Array>> GenerateArray(const std::shared_ptr<Field>& field, + int64_t length, int64_t seed) { + const std::shared_ptr<DataType>& type = field->type(); + bool nullable = field->nullable(); + + switch (type->id()) { + GENERATE_CASE(BOOL, ::arrow::BooleanBuilder, (val % 2 == 0)) + + // Numeric types. 
+ GENERATE_CASE(INT8, ::arrow::Int8Builder, static_cast<int8_t>(val)) + GENERATE_CASE(INT16, ::arrow::Int16Builder, static_cast<int16_t>(val)) + GENERATE_CASE(INT32, ::arrow::Int32Builder, static_cast<int32_t>(val)) + GENERATE_CASE(INT64, ::arrow::Int64Builder, static_cast<int64_t>(val)) + GENERATE_CASE(UINT8, ::arrow::UInt8Builder, static_cast<uint8_t>(val)) + GENERATE_CASE(UINT16, ::arrow::UInt16Builder, static_cast<uint16_t>(val)) + GENERATE_CASE(UINT32, ::arrow::UInt32Builder, static_cast<uint32_t>(val)) + GENERATE_CASE(UINT64, ::arrow::UInt64Builder, static_cast<uint64_t>(val)) + GENERATE_CASE(HALF_FLOAT, ::arrow::HalfFloatBuilder, + static_cast<uint16_t>(val % 1000)) + GENERATE_CASE(FLOAT, ::arrow::FloatBuilder, static_cast<float>(val % 1000) / 1000.0f) + GENERATE_CASE(DOUBLE, ::arrow::DoubleBuilder, + static_cast<double>(val % 100000) / 1000.0) + case ::arrow::Type::DECIMAL128: { + const auto& decimal_type = checked_cast<const ::arrow::Decimal128Type&>(*type); + // Limit the value to fit within the specified precision + int32_t max_exponent = decimal_type.precision() - decimal_type.scale(); + int64_t max_value = static_cast<int64_t>(std::pow(10, max_exponent) - 1); + auto value_func = [&](uint64_t val) { + return ::arrow::Decimal128(val % max_value); + }; + return GenerateArray<::arrow::Decimal128Builder>(type, nullable, length, seed, + value_func); + } + case ::arrow::Type::DECIMAL256: { + const auto& decimal_type = checked_cast<const ::arrow::Decimal256Type&>(*type); + // Limit the value to fit within the specified precision, capped at 9 to avoid + // int64_t overflow + int32_t max_exponent = std::min(9, decimal_type.precision() - decimal_type.scale()); + int64_t max_value = static_cast<int64_t>(std::pow(10, max_exponent) - 1); + auto value_func = [&](uint64_t val) { + return ::arrow::Decimal256(val % max_value); + }; + return GenerateArray<::arrow::Decimal256Builder>(type, nullable, length, seed, + value_func); + } + + // Temporal types + 
GENERATE_CASE(DATE32, ::arrow::Date32Builder, static_cast<int32_t>(val)) + GENERATE_CASE(TIME32, ::arrow::Time32Builder, + std::abs(static_cast<int32_t>(val) % 86400000)) + GENERATE_CASE(TIME64, ::arrow::Time64Builder, + std::abs(static_cast<int64_t>(val) % 86400000000)) + GENERATE_CASE(TIMESTAMP, ::arrow::TimestampBuilder, static_cast<int64_t>(val)) + GENERATE_CASE(DURATION, ::arrow::DurationBuilder, static_cast<int64_t>(val)) + + // Binary and string types. + GENERATE_CASE(STRING, ::arrow::StringBuilder, + std::string("str_") + std::to_string(val)) + GENERATE_CASE(LARGE_STRING, ::arrow::LargeStringBuilder, + std::string("str_") + std::to_string(val)) + GENERATE_CASE(BINARY, ::arrow::BinaryBuilder, + std::string("bin_") + std::to_string(val)) + GENERATE_CASE(LARGE_BINARY, ::arrow::LargeBinaryBuilder, + std::string("bin_") + std::to_string(val)) + case ::arrow::Type::FIXED_SIZE_BINARY: { + auto size = + checked_cast<const ::arrow::FixedSizeBinaryType*>(type.get())->byte_width(); + auto value_func = [size](uint64_t val) { + return std::string("bin_") + std::to_string(val).substr(0, size - 4); + }; + return GenerateArray<::arrow::FixedSizeBinaryBuilder>(type, nullable, length, seed, + value_func); + } + + case ::arrow::Type::STRUCT: { + auto struct_type = checked_cast<const ::arrow::StructType*>(type.get()); + std::vector<std::shared_ptr<Array>> child_arrays; + for (auto i = 0; i < struct_type->num_fields(); i++) { + ARROW_ASSIGN_OR_RAISE(auto child_array, GenerateArray(struct_type->field(i), + length, seed + i * 10)); + child_arrays.push_back(child_array); + } + auto struct_array = + std::make_shared<::arrow::StructArray>(type, length, child_arrays); + return struct_array; + } + + case ::arrow::Type::LIST: { + // Repeat the same pattern in the list array: + // null, empty list, list of 1 element, list of 3 elements + if (length % 4 != 0) { + return Status::Invalid( + "Length must be divisible by 4 when generating list arrays, but got: ", + length); + } + auto 
values_array_length = length * 4; + auto list_type = checked_cast<const ::arrow::ListType*>(type.get()); + auto value_field = ::arrow::field("item", list_type->value_type()); + ARROW_ASSIGN_OR_RAISE(auto values_array, + GenerateArray(value_field, values_array_length, seed)); + auto offset_builder = ::arrow::Int32Builder(); + auto bitmap_builder = ::arrow::TypedBufferBuilder<bool>(); + + RETURN_NOT_OK(offset_builder.Reserve(length + 1)); + RETURN_NOT_OK(bitmap_builder.Reserve(length)); + + int32_t num_nulls = 0; + RETURN_NOT_OK(offset_builder.Append(0)); + for (auto offset = 0; offset < length; offset += 4) { + if (nullable) { + // add a null + RETURN_NOT_OK(bitmap_builder.Append(false)); + RETURN_NOT_OK(offset_builder.Append(offset)); + num_nulls += 1; + } else { + // add an empty list + RETURN_NOT_OK(bitmap_builder.Append(true)); + RETURN_NOT_OK(offset_builder.Append(offset)); + } + // add an empty list + RETURN_NOT_OK(bitmap_builder.Append(true)); + RETURN_NOT_OK(offset_builder.Append(offset)); + // add a list of 1 element + RETURN_NOT_OK(bitmap_builder.Append(true)); + RETURN_NOT_OK(offset_builder.Append(offset + 1)); + // add a list of 3 elements + RETURN_NOT_OK(bitmap_builder.Append(true)); + RETURN_NOT_OK(offset_builder.Append(offset + 4)); + } + + std::shared_ptr<Array> offsets_array; + RETURN_NOT_OK(offset_builder.Finish(&offsets_array)); + std::shared_ptr<Buffer> bitmap_buffer; + RETURN_NOT_OK(bitmap_builder.Finish(&bitmap_buffer)); + ARROW_ASSIGN_OR_RAISE( + auto list_array, ::arrow::ListArray::FromArrays( + type, *offsets_array, *values_array, default_memory_pool(), + bitmap_buffer, num_nulls)); + RETURN_NOT_OK(list_array->ValidateFull()); + return list_array; + } + + case ::arrow::Type::EXTENSION: { + auto extension_type = checked_cast<const ::arrow::ExtensionType*>(type.get()); + auto storage_type = extension_type->storage_type(); + auto storage_field = ::arrow::field("storage", storage_type, true); + ARROW_ASSIGN_OR_RAISE(auto storage_array, + 
GenerateArray(storage_field, length, seed)); + return ::arrow::ExtensionType::WrapArray(type, storage_array); + } + + default: + return ::arrow::Status::NotImplemented("Unsupported data type " + type->ToString()); + } +} + +TEST(TestGenerateArray, Integer) { + auto field = ::arrow::field("a", ::arrow::int32()); + ASSERT_OK_AND_ASSIGN(auto array, GenerateArray(field, /*length=*/10, /*seed=*/0)); + ASSERT_OK(array->ValidateFull()); + ASSERT_EQ(array->length(), 10); + ASSERT_TRUE(array->type()->Equals(::arrow::int32())); + ASSERT_EQ(array->null_count(), 1); +} + +TEST(TestGenerateArray, ListOfInteger) { + auto field = ::arrow::field("a", ::arrow::list(::arrow::int32())); + auto length = 12; + ASSERT_OK_AND_ASSIGN(auto array, GenerateArray(field, length, /*seed=*/0)); + ASSERT_OK(array->ValidateFull()); + ASSERT_EQ(array->length(), length); + + for (size_t i = 0; i < 12; i += 4) { + // Assert the first element is null + ASSERT_TRUE(array->IsNull(i)); + + // Assert the second element is an empty list + ASSERT_TRUE(array->IsValid(i + 1)); + auto list_array = std::static_pointer_cast<::arrow::ListArray>(array); + ASSERT_EQ(list_array->value_length(i + 1), 0); + + // Assert the third element has length 1 + ASSERT_TRUE(array->IsValid(i + 2)); + ASSERT_EQ(list_array->value_length(i + 2), 1); + + // Assert the fourth element has length 3 + ASSERT_TRUE(array->IsValid(i + 3)); + ASSERT_EQ(list_array->value_length(i + 3), 3); + } + + ASSERT_NOT_OK(GenerateArray(field, 3, /*seed=*/0)); + ASSERT_OK(GenerateArray(field, 8, /*seed=*/0)); +} + +Result<std::shared_ptr<Table>> GenerateTable( + const std::shared_ptr<::arrow::Schema>& schema, int64_t size, int64_t seed = 0) { + std::vector<std::shared_ptr<Array>> arrays; + for (const auto& field : schema->fields()) { + ARROW_ASSIGN_OR_RAISE(auto array, GenerateArray(field, size, ++seed)); + arrays.push_back(array); + } + return Table::Make(schema, arrays, size); +} + +Result<std::shared_ptr<Table>> ConcatAndCombine( + const 
std::vector<std::shared_ptr<Table>>& parts) { + // Concatenate and combine chunks so the table doesn't carry information about + // the modification points + ARROW_ASSIGN_OR_RAISE(auto table, ConcatenateTables(parts)); + return table->CombineChunks(); +} + +Result<std::shared_ptr<Table>> ReadTableFromBuffer(const std::shared_ptr<Buffer>& data) { + std::shared_ptr<Table> result; + FileReaderBuilder builder; + std::unique_ptr<FileReader> reader; + auto props = default_arrow_reader_properties(); + + RETURN_NOT_OK(builder.Open(std::make_shared<BufferReader>(data))); + RETURN_NOT_OK(builder.memory_pool(::arrow::default_memory_pool()) + ->properties(props) + ->Build(&reader)); + RETURN_NOT_OK(reader->ReadTable(&result)); + return result; +} + +Result<std::shared_ptr<Buffer>> WriteTableToBuffer( + const std::shared_ptr<Table>& table, int64_t min_chunk_size, int64_t max_chunk_size, + int64_t row_group_length = 1024 * 1024, bool enable_dictionary = false, + ParquetDataPageVersion data_page_version = ParquetDataPageVersion::V1) { + auto sink = CreateOutputStream(); + + auto builder = WriterProperties::Builder(); + builder.enable_content_defined_chunking()->content_defined_chunking_options( + {min_chunk_size, max_chunk_size, /*norm_level=*/0}); + builder.data_page_version(data_page_version); + if (enable_dictionary) { + builder.enable_dictionary(); + } else { + builder.disable_dictionary(); + } + auto write_props = builder.build(); + auto arrow_props = ArrowWriterProperties::Builder().store_schema()->build(); + RETURN_NOT_OK(WriteTable(*table, default_memory_pool(), sink, row_group_length, + write_props, arrow_props)); + ARROW_ASSIGN_OR_RAISE(auto buffer, sink->Finish()); + + // check whether the schema has extension types, if not we can easily ensure that + // the parquet seralization is roundtripable with CDC enabled + bool validate_roundtrip = true; + for (const auto& field : table->schema()->fields()) { + if (field->type()->id() == ::arrow::Type::EXTENSION) { + 
validate_roundtrip = false; Review Comment: The test case used `UUID`, which isn't supported by the parquet arrow reader: ```cpp /// Enable Parquet-supported Arrow extension types. /// /// When enabled, Parquet logical types will be mapped to their corresponding Arrow /// extension types at read time, if such exist. Currently only arrow::extension::json() /// extension type is supported. Columns whose LogicalType is JSON will be interpreted /// as arrow::extension::json(), with storage type inferred from the serialized Arrow /// schema if present, or `utf8` by default. void set_arrow_extensions_enabled(bool extensions_enabled) { arrow_extensions_enabled_ = extensions_enabled; } ``` I changed the test case to use `arrow::extension::json()` instead; it tests the same functionality. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: github-unsubscr...@arrow.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org