This is an automated email from the ASF dual-hosted git repository. zeroshade pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/iceberg-cpp.git
The following commit(s) were added to refs/heads/main by this push: new bc87f00 feat(avro): extract avro datum from arrow array (#166) bc87f00 is described below commit bc87f00f2df9155c3d17e8cc804b04c05210b636 Author: Gang Wu <ust...@gmail.com> AuthorDate: Tue Aug 19 02:01:26 2025 +0800 feat(avro): extract avro datum from arrow array (#166) --- src/iceberg/avro/avro_data_util.cc | 230 +++++++++ src/iceberg/avro/avro_data_util_internal.h | 9 + test/avro_data_test.cc | 721 +++++++++++++++++++++++++++++ 3 files changed, 960 insertions(+) diff --git a/src/iceberg/avro/avro_data_util.cc b/src/iceberg/avro/avro_data_util.cc index 16ac41a..8853f74 100644 --- a/src/iceberg/avro/avro_data_util.cc +++ b/src/iceberg/avro/avro_data_util.cc @@ -17,10 +17,13 @@ * under the License. */ +#include <ranges> + #include <arrow/array/builder_binary.h> #include <arrow/array/builder_decimal.h> #include <arrow/array/builder_nested.h> #include <arrow/array/builder_primitive.h> +#include <arrow/extension_type.h> #include <arrow/json/from_string.h> #include <arrow/type.h> #include <arrow/util/decimal.h> @@ -451,4 +454,231 @@ Status AppendDatumToBuilder(const ::avro::NodePtr& avro_node, projected_schema, array_builder); } +namespace { + +// ToAvroNodeVisitor uses 0 for null branch and 1 for value branch. 
+constexpr int64_t kNullBranch = 0;
+constexpr int64_t kValueBranch = 1;
+
+// Stores an Arrow Decimal128 value (16 little-endian bytes) into `fixed` as the
+// big-endian two's-complement bytes used by the Avro decimal logical type, resized
+// to the width of the fixed schema. High-order bytes are dropped only when they are
+// pure sign extension; otherwise the value does not fit and an error is returned.
+Status ExtractDecimal128ToFixed(std::string_view little_endian_value,
+                                ::avro::GenericFixed& fixed) {
+  auto& bytes = fixed.value();
+  bytes.assign(little_endian_value.begin(), little_endian_value.end());
+  std::ranges::reverse(bytes);
+
+  const auto fixed_size = static_cast<size_t>(fixed.schema()->fixedSize());
+  if (fixed_size < bytes.size()) {
+    const size_t excess = bytes.size() - fixed_size;
+    // The first retained byte carries the sign; every dropped byte must repeat it.
+    const bool negative = fixed_size > 0 && (bytes[excess] & 0x80) != 0;
+    const auto sign_byte = static_cast<uint8_t>(negative ? 0xFF : 0x00);
+    const auto dropped_end = bytes.begin() + static_cast<std::ptrdiff_t>(excess);
+    if (!std::all_of(bytes.begin(), dropped_end,
+                     [sign_byte](uint8_t byte) { return byte == sign_byte; })) {
+      return InvalidArgument("Decimal value does not fit in Avro fixed of size {}",
+                             fixed_size);
+    }
+    bytes.erase(bytes.begin(), dropped_end);
+  } else if (fixed_size > bytes.size()) {
+    // Sign-extend into a fixed type wider than 16 bytes.
+    const bool negative = !bytes.empty() && (bytes.front() & 0x80) != 0;
+    const auto sign_byte = static_cast<uint8_t>(negative ? 0xFF : 0x00);
+    bytes.insert(bytes.begin(), fixed_size - bytes.size(), sign_byte);
+  }
+  return {};
+}
+
+}  // namespace
+
+// Copies the element at `index` of `array` into `datum`. A null selects the null
+// branch (kNullBranch) of a union datum and is an error for a non-union datum; a
+// non-null value first selects kValueBranch when `datum` is a union. Struct, list
+// and map values recurse into their child arrays.
+Status ExtractDatumFromArray(const ::arrow::Array& array, int64_t index,
+                             ::avro::GenericDatum* datum) {
+  if (index < 0 || index >= array.length()) {
+    return InvalidArgument("Cannot extract datum from array at index {} of length {}",
+                           index, array.length());
+  }
+
+  if (array.IsNull(index)) {
+    if (!datum->isUnion()) [[unlikely]] {
+      return InvalidSchema("Cannot extract null to non-union type: {}",
+                           ::avro::toString(datum->type()));
+    }
+    datum->selectBranch(kNullBranch);
+    return {};
+  }
+
+  if (datum->isUnion()) {
+    datum->selectBranch(kValueBranch);
+  }
+
+  switch (array.type()->id()) {
+    case ::arrow::Type::BOOL: {
+      const auto& bool_array =
+          internal::checked_cast<const ::arrow::BooleanArray&>(array);
+      datum->value<bool>() = bool_array.Value(index);
+      return {};
+    }
+
+    case ::arrow::Type::INT32: {
+      const auto& int32_array = internal::checked_cast<const ::arrow::Int32Array&>(array);
+      datum->value<int32_t>() = int32_array.Value(index);
+      return {};
+    }
+
+    case ::arrow::Type::INT64: {
+      const auto& int64_array = internal::checked_cast<const ::arrow::Int64Array&>(array);
+      datum->value<int64_t>() = int64_array.Value(index);
+      return {};
+    }
+
+    case ::arrow::Type::FLOAT: {
+      const auto& float_array = internal::checked_cast<const ::arrow::FloatArray&>(array);
+      datum->value<float>() = float_array.Value(index);
+      return {};
+    }
+
+    case ::arrow::Type::DOUBLE: {
+      const auto& double_array =
+          internal::checked_cast<const ::arrow::DoubleArray&>(array);
+      datum->value<double>() = double_array.Value(index);
+      return {};
+    }
+
+    // TODO(gangwu): support LARGE_STRING.
+    case ::arrow::Type::STRING: {
+      const auto& string_array =
+          internal::checked_cast<const ::arrow::StringArray&>(array);
+      datum->value<std::string>() = string_array.GetString(index);
+      return {};
+    }
+
+    // TODO(gangwu): support LARGE_BINARY.
+    case ::arrow::Type::BINARY: {
+      const auto& binary_array =
+          internal::checked_cast<const ::arrow::BinaryArray&>(array);
+      std::string_view value = binary_array.GetView(index);
+      datum->value<std::vector<uint8_t>>().assign(
+          reinterpret_cast<const uint8_t*>(value.data()),
+          reinterpret_cast<const uint8_t*>(value.data()) + value.size());
+      return {};
+    }
+
+    case ::arrow::Type::FIXED_SIZE_BINARY: {
+      const auto& fixed_array =
+          internal::checked_cast<const ::arrow::FixedSizeBinaryArray&>(array);
+      std::string_view value = fixed_array.GetView(index);
+      auto& fixed_datum = datum->value<::avro::GenericFixed>();
+      fixed_datum.value().assign(value.begin(), value.end());
+      return {};
+    }
+
+    case ::arrow::Type::DECIMAL128: {
+      const auto& decimal_array =
+          internal::checked_cast<const ::arrow::Decimal128Array&>(array);
+      // The Avro fixed width may be narrower than Arrow's 16 bytes.
+      return ExtractDecimal128ToFixed(decimal_array.GetView(index),
+                                      datum->value<::avro::GenericFixed>());
+    }
+
+    case ::arrow::Type::DATE32: {
+      const auto& date_array = internal::checked_cast<const ::arrow::Date32Array&>(array);
+      datum->value<int32_t>() = date_array.Value(index);
+      return {};
+    }
+
+    case ::arrow::Type::TIME64: {
+      const auto& time_array = internal::checked_cast<const ::arrow::Time64Array&>(array);
+      datum->value<int64_t>() = time_array.Value(index);
+      return {};
+    }
+
+    // For both timestamp and timestamp_tz with time unit as microsecond.
+    case ::arrow::Type::TIMESTAMP: {
+      const auto& timestamp_array =
+          internal::checked_cast<const ::arrow::TimestampArray&>(array);
+      datum->value<int64_t>() = timestamp_array.Value(index);
+      return {};
+    }
+
+    case ::arrow::Type::EXTENSION: {
+      // DataType::name() of an extension type is the generic "extension"; the
+      // registered identifier (e.g. "arrow.uuid") is exposed by extension_name().
+      const auto& extension_type =
+          internal::checked_cast<const ::arrow::ExtensionType&>(*array.type());
+      if (extension_type.extension_name() == "arrow.uuid") {
+        const auto& extension_array =
+            internal::checked_cast<const ::arrow::ExtensionArray&>(array);
+        const auto& fixed_array =
+            internal::checked_cast<const ::arrow::FixedSizeBinaryArray&>(
+                *extension_array.storage());
+        std::string_view value = fixed_array.GetView(index);
+        auto& fixed_datum = datum->value<::avro::GenericFixed>();
+        fixed_datum.value().assign(value.begin(), value.end());
+        return {};
+      }
+
+      return NotSupported("Unsupported Arrow extension type: {}",
+                          extension_type.extension_name());
+    }
+
+    case ::arrow::Type::STRUCT: {
+      const auto& struct_array =
+          internal::checked_cast<const ::arrow::StructArray&>(array);
+      auto& record = datum->value<::avro::GenericRecord>();
+      // fieldAt() does not bounds-check, so reject mismatched shapes up front.
+      if (record.fieldCount() != static_cast<size_t>(struct_array.num_fields())) {
+        return InvalidArgument(
+            "Cannot extract struct with {} fields to Avro record with {} fields",
+            struct_array.num_fields(), record.fieldCount());
+      }
+      for (int i = 0; i < struct_array.num_fields(); ++i) {
+        ICEBERG_RETURN_UNEXPECTED(
+            ExtractDatumFromArray(*struct_array.field(i), index, &record.fieldAt(i)));
+      }
+      return {};
+    }
+
+    // TODO(gangwu): support LARGE_LIST.
+    case ::arrow::Type::LIST: {
+      const auto& list_array = internal::checked_cast<const ::arrow::ListArray&>(array);
+      auto& avro_array = datum->value<::avro::GenericArray>();
+      auto& elements = avro_array.value();
+
+      auto start = list_array.value_offset(index);
+      auto end = list_array.value_offset(index + 1);
+      auto length = end - start;
+
+      auto values = list_array.values();
+      elements.resize(length, ::avro::GenericDatum(avro_array.schema()->leafAt(0)));
+
+      for (int64_t i = 0; i < length; ++i) {
+        ICEBERG_RETURN_UNEXPECTED(
+            ExtractDatumFromArray(*values, start + i, &elements[i]));
+      }
+      return {};
+    }
+
+    case ::arrow::Type::MAP: {
+      const auto& map_array = internal::checked_cast<const ::arrow::MapArray&>(array);
+      auto start = map_array.value_offset(index);
+      auto end = map_array.value_offset(index + 1);
+      auto length = end - start;
+
+      auto keys = map_array.keys();
+      auto items = map_array.items();
+
+      if (datum->type() == ::avro::AVRO_MAP) {
+        // Handle regular Avro map; its keys are always strings.
+        if (keys->type_id() != ::arrow::Type::STRING) {
+          return InvalidArgument("Avro map requires string keys, got Arrow type: {}",
+                                 keys->type()->ToString());
+        }
+        auto& avro_map = datum->value<::avro::GenericMap>();
+        auto value_node = avro_map.schema()->leafAt(1);
+
+        auto& map_entries = avro_map.value();
+        map_entries.resize(
+            length, std::make_pair(std::string(), ::avro::GenericDatum(value_node)));
+
+        const auto& key_array =
+            internal::checked_cast<const ::arrow::StringArray&>(*keys);
+
+        for (int64_t i = 0; i < length; ++i) {
+          auto& map_entry = map_entries[i];
+          map_entry.first = key_array.GetString(start + i);
+          ICEBERG_RETURN_UNEXPECTED(
+              ExtractDatumFromArray(*items, start + i, &map_entry.second));
+        }
+      } else if (datum->type() == ::avro::AVRO_ARRAY) {
+        // Handle array-based map (list<struct<key, value>>)
+        auto& avro_array = datum->value<::avro::GenericArray>();
+        auto record_node = avro_array.schema()->leafAt(0);
+        if (record_node->type() != ::avro::AVRO_RECORD || record_node->leaves() != 2) {
+          return InvalidArgument(
+              "Expected Avro record with 2 fields for map value, got: {}",
+              ToString(record_node));
+        }
+
+        auto& elements = avro_array.value();
+        elements.resize(length, ::avro::GenericDatum(record_node));
+
+        for (int64_t i = 0; i < length; ++i) {
+          auto& record = elements[i].value<::avro::GenericRecord>();
+          ICEBERG_RETURN_UNEXPECTED(
+              ExtractDatumFromArray(*keys, start + i, &record.fieldAt(0)));
+          ICEBERG_RETURN_UNEXPECTED(
+              ExtractDatumFromArray(*items, start + i, &record.fieldAt(1)));
+        }
+      } else {
+        return InvalidArgument("Unsupported Avro type for map: {}",
+                               ::avro::toString(datum->type()));
+      }
+      return {};
+    }
+
+    default:
+      return InvalidArgument("Unsupported Arrow array type: {}",
+                             array.type()->ToString());
+  }
+}
+
 }  // namespace iceberg::avro
diff --git a/src/iceberg/avro/avro_data_util_internal.h b/src/iceberg/avro/avro_data_util_internal.h
index ad49368..0de383e 100644
--- a/src/iceberg/avro/avro_data_util_internal.h
+++ b/src/iceberg/avro/avro_data_util_internal.h
@@ -43,4 +43,13 @@ Status AppendDatumToBuilder(const ::avro::NodePtr& avro_node,
                             const Schema& projected_schema,
                             ::arrow::ArrayBuilder* array_builder);
 
+/// \brief Extract an Avro datum from an Arrow array.
+///
+/// \param array The Arrow array to extract from.
+/// \param index The index of the element to extract.
+/// \param datum The Avro datum to extract to. Its Avro type should be consistent with the
+/// Arrow type.
+Status ExtractDatumFromArray(const ::arrow::Array& array, int64_t index, + ::avro::GenericDatum* datum); + } // namespace iceberg::avro diff --git a/test/avro_data_test.cc b/test/avro_data_test.cc index 45811ff..b5cc1c5 100644 --- a/test/avro_data_test.cc +++ b/test/avro_data_test.cc @@ -21,6 +21,7 @@ #include <arrow/c/bridge.h> #include <arrow/json/from_string.h> +#include <arrow/util/decimal.h> #include <avro/Compiler.hh> #include <avro/Generic.hh> #include <avro/Node.hh> @@ -757,4 +758,724 @@ TEST(AppendDatumToBuilderTest, ListWithMissingOptionalElementFields) { avro_data, expected_json)); } +// One single-column case: field "a" of iceberg_type is filled from arrow_json, and value_verifier checks the datum extracted for each row index. +struct ExtractDatumParam { + std::string name; + std::shared_ptr<Type> iceberg_type; + std::string arrow_json; + std::function<void(const ::avro::GenericDatum&, int)> value_verifier; +}; + +// Derives Avro and Arrow schemas from a one-field Iceberg schema, parses arrow_json as a struct array, and runs value_verifier on the datum extracted at every row. +void VerifyExtractDatumFromArray(const ExtractDatumParam& test_case) { + Schema iceberg_schema({SchemaField::MakeRequired( + /*field_id=*/1, /*name=*/"a", test_case.iceberg_type)}); + ::avro::NodePtr avro_node; + ASSERT_THAT(ToAvroNodeVisitor{}.Visit(iceberg_schema, &avro_node), IsOk()); + + ArrowSchema arrow_c_schema; + ASSERT_THAT(ToArrowSchema(iceberg_schema, &arrow_c_schema), IsOk()); + auto arrow_schema = ::arrow::ImportSchema(&arrow_c_schema).ValueOrDie(); + auto arrow_struct_type = std::make_shared<::arrow::StructType>(arrow_schema->fields()); + auto arrow_array = + ::arrow::json::ArrayFromJSONString(arrow_struct_type, test_case.arrow_json) + .ValueOrDie(); + + for (int64_t i = 0; i < arrow_array->length(); ++i) { + ::avro::GenericDatum extracted_datum(avro_node); + ASSERT_THAT(ExtractDatumFromArray(*arrow_array, i, &extracted_datum), IsOk()) + << "Failed to extract at index " << i; + test_case.value_verifier(extracted_datum, static_cast<int>(i)); + } +} + +class ExtractDatumFromArrayTest : public ::testing::TestWithParam<ExtractDatumParam> {}; + +TEST_P(ExtractDatumFromArrayTest, PrimitiveType) { + ASSERT_NO_FATAL_FAILURE(VerifyExtractDatumFromArray(GetParam())); +} + +// NOTE(review): there is no uuid() case below, so the EXTENSION branch of ExtractDatumFromArray is not exercised by these parameterized tests. +const
std::vector<ExtractDatumParam> kExtractDatumTestCases = { + { + .name = "Boolean", + .iceberg_type = boolean(), + .arrow_json = R"([{"a": true}, {"a": false}, {"a": true}])", + .value_verifier = + [](const ::avro::GenericDatum& datum, int i) { + const auto& record = datum.value<::avro::GenericRecord>(); + bool expected = (i % 2 == 0); + EXPECT_EQ(record.fieldAt(0).value<bool>(), expected); + }, + }, + { + .name = "Int", + .iceberg_type = int32(), + .arrow_json = R"([{"a": 0}, {"a": 100}, {"a": 200}])", + .value_verifier = + [](const ::avro::GenericDatum& datum, int i) { + const auto& record = datum.value<::avro::GenericRecord>(); + EXPECT_EQ(record.fieldAt(0).value<int32_t>(), i * 100); + }, + }, + { + .name = "Long", + .iceberg_type = int64(), + .arrow_json = R"([{"a": 0}, {"a": 1000000}, {"a": 2000000}])", + .value_verifier = + [](const ::avro::GenericDatum& datum, int i) { + const auto& record = datum.value<::avro::GenericRecord>(); + EXPECT_EQ(record.fieldAt(0).value<int64_t>(), i * 1000000LL); + }, + }, + { + .name = "Float", + .iceberg_type = float32(), + .arrow_json = R"([{"a": 0.0}, {"a": 3.14}, {"a": 6.28}])", + .value_verifier = + [](const ::avro::GenericDatum& datum, int i) { + const auto& record = datum.value<::avro::GenericRecord>(); + EXPECT_FLOAT_EQ(record.fieldAt(0).value<float>(), i * 3.14f); + }, + }, + { + .name = "Double", + .iceberg_type = float64(), + .arrow_json = R"([{"a": 0.0}, {"a": 1.234567890}, {"a": 2.469135780}])", + .value_verifier = + [](const ::avro::GenericDatum& datum, int i) { + const auto& record = datum.value<::avro::GenericRecord>(); + EXPECT_DOUBLE_EQ(record.fieldAt(0).value<double>(), i * 1.234567890); + }, + }, + { + .name = "String", + .iceberg_type = string(), + .arrow_json = + R"([{"a": "test_string_0"}, {"a": "test_string_1"}, {"a": "test_string_2"}])", + .value_verifier = + [](const ::avro::GenericDatum& datum, int i) { + const auto& record = datum.value<::avro::GenericRecord>(); + std::string expected = "test_string_"
+ std::to_string(i); + EXPECT_EQ(record.fieldAt(0).value<std::string>(), expected); + }, + }, + { + .name = "Binary", + .iceberg_type = binary(), + .arrow_json = R"([{"a": "abc"}, {"a": "bcd"}, {"a": "cde"}])", + .value_verifier = + [](const ::avro::GenericDatum& datum, int i) { + const auto& record = datum.value<::avro::GenericRecord>(); + const auto& bytes = record.fieldAt(0).value<std::vector<uint8_t>>(); + EXPECT_EQ(bytes.size(), 3); + EXPECT_EQ(bytes[0], static_cast<uint8_t>('a' + i)); + EXPECT_EQ(bytes[1], static_cast<uint8_t>('b' + i)); + EXPECT_EQ(bytes[2], static_cast<uint8_t>('c' + i)); + }, + }, + { + .name = "Fixed", + .iceberg_type = fixed(4), + .arrow_json = R"([{"a": "abcd"}, {"a": "bcde"}, {"a": "cdef"}])", + .value_verifier = + [](const ::avro::GenericDatum& datum, int i) { + const auto& record = datum.value<::avro::GenericRecord>(); + const auto& fixed = record.fieldAt(0).value<::avro::GenericFixed>(); + EXPECT_EQ(fixed.value().size(), 4); + EXPECT_EQ(static_cast<char>(fixed.value()[0]), static_cast<char>('a' + i)); + EXPECT_EQ(static_cast<char>(fixed.value()[1]), static_cast<char>('b' + i)); + EXPECT_EQ(static_cast<char>(fixed.value()[2]), static_cast<char>('c' + i)); + EXPECT_EQ(static_cast<char>(fixed.value()[3]), static_cast<char>('d' + i)); + }, + }, + { + .name = "Decimal", + .iceberg_type = decimal(10, 2), + .arrow_json = R"([{"a": "0.00"}, {"a": "10.01"}, {"a": "20.02"}])", + .value_verifier = + [](const ::avro::GenericDatum& datum, int i) { + const auto& record = datum.value<::avro::GenericRecord>(); + const auto& fixed = record.fieldAt(0).value<::avro::GenericFixed>(); + + // NOTE(review): the decoded width is bytes.size(); it is never compared with the fixed size of the Avro schema. + const auto& bytes = fixed.value(); + auto decimal = + ::arrow::Decimal128::FromBigEndian( + reinterpret_cast<const uint8_t*>(bytes.data()), bytes.size()) + .ValueOrDie(); + int64_t expected_unscaled = i * 1000 + i; + EXPECT_EQ(decimal.low_bits(), static_cast<uint64_t>(expected_unscaled)); + EXPECT_EQ(decimal.high_bits(), 0); + }, + }, + { + .name = "Date",
+ .iceberg_type = date(), + .arrow_json = R"([{"a": 18000}, {"a": 18001}, {"a": 18002}])", + .value_verifier = + [](const ::avro::GenericDatum& datum, int i) { + const auto& record = datum.value<::avro::GenericRecord>(); + EXPECT_EQ(record.fieldAt(0).value<int32_t>(), 18000 + i); + }, + }, + { + .name = "Time", + .iceberg_type = time(), + .arrow_json = R"([{"a": 45045123456}, {"a": 45046123456}, {"a": 45047123456}])", + .value_verifier = + [](const ::avro::GenericDatum& datum, int i) { + const auto& record = datum.value<::avro::GenericRecord>(); + EXPECT_EQ(record.fieldAt(0).value<int64_t>(), + 45045123456LL + i * 1000000LL); + }, + }, + { + .name = "Timestamp", + .iceberg_type = timestamp(), + .arrow_json = R"([{"a": 0}, {"a": 1000000}, {"a": 2000000}])", + .value_verifier = + [](const ::avro::GenericDatum& datum, int i) { + const auto& record = datum.value<::avro::GenericRecord>(); + EXPECT_EQ(record.fieldAt(0).value<int64_t>(), i * 1000000LL); + }, + }, + { + .name = "TimestampTz", + .iceberg_type = timestamp_tz(), + .arrow_json = + R"([{"a": 1672531200000000}, {"a": 1672531201000000}, {"a": 1672531202000000}])", + .value_verifier = + [](const ::avro::GenericDatum& datum, int i) { + const auto& record = datum.value<::avro::GenericRecord>(); + EXPECT_EQ(record.fieldAt(0).value<int64_t>(), + 1672531200000000LL + i * 1000000LL); + }, + }, +}; + +INSTANTIATE_TEST_SUITE_P(AllPrimitiveTypes, ExtractDatumFromArrayTest, + ::testing::ValuesIn(kExtractDatumTestCases), + [](const ::testing::TestParamInfo<ExtractDatumParam>& info) { + return info.param.name; + }); + +TEST(ExtractDatumFromArrayTest, StructWithTwoFields) { + Schema iceberg_schema({ + SchemaField::MakeRequired(1, "id", int32()), + SchemaField::MakeRequired(2, "name", string()), + }); + ::avro::NodePtr avro_node; + ASSERT_THAT(ToAvroNodeVisitor{}.Visit(iceberg_schema, &avro_node), IsOk()); + + ArrowSchema arrow_c_schema; + ASSERT_THAT(ToArrowSchema(iceberg_schema, &arrow_c_schema), IsOk()); + auto arrow_schema =
::arrow::ImportSchema(&arrow_c_schema).ValueOrDie(); + auto arrow_struct_type = std::make_shared<::arrow::StructType>(arrow_schema->fields()); + + auto arrow_array = ::arrow::json::ArrayFromJSONString(arrow_struct_type, + R"([ + {"id": 42, "name": "Alice"}, + {"id": 43, "name": "Bob"}, + {"id": 44, "name": "Charlie"} + ])") + .ValueOrDie(); + + struct ExpectedData { + int32_t id; + std::string name; + }; + std::vector<ExpectedData> expected = {{.id = 42, .name = "Alice"}, + {.id = 43, .name = "Bob"}, + {.id = 44, .name = "Charlie"}}; + + auto verify_record = [&](int64_t index, const ExpectedData& expected_data) { + ::avro::GenericDatum extracted_datum(avro_node); + ASSERT_THAT(ExtractDatumFromArray(*arrow_array, index, &extracted_datum), IsOk()); + const auto& record = extracted_datum.value<::avro::GenericRecord>(); + EXPECT_EQ(record.fieldAt(0).value<int32_t>(), expected_data.id); + EXPECT_EQ(record.fieldAt(1).value<std::string>(), expected_data.name); + }; + + for (size_t i = 0; i < expected.size(); ++i) { + verify_record(i, expected[i]); + } +} + +TEST(ExtractDatumFromArrayTest, NestedStruct) { + Schema iceberg_schema({ + SchemaField::MakeRequired(1, "id", int32()), + SchemaField::MakeRequired(2, "person", + std::make_shared<StructType>(std::vector<SchemaField>{ + SchemaField::MakeRequired(3, "name", string()), + SchemaField::MakeRequired(4, "age", int32()), + })), + }); + + ::avro::NodePtr avro_node; + ASSERT_THAT(ToAvroNodeVisitor{}.Visit(iceberg_schema, &avro_node), IsOk()); + + ArrowSchema arrow_c_schema; + ASSERT_THAT(ToArrowSchema(iceberg_schema, &arrow_c_schema), IsOk()); + auto arrow_schema = ::arrow::ImportSchema(&arrow_c_schema).ValueOrDie(); + auto arrow_struct_type = std::make_shared<::arrow::StructType>(arrow_schema->fields()); + + const std::string arrow_json = R"([ + {"id": 1, "person": {"name": "Alice", "age": 25}}, + {"id": 2, "person": {"name": "Bob", "age": 30}}, + {"id": 3, "person": {"name": "Charlie", "age": 35}} + ])"; + auto arrow_array =
+ ::arrow::json::ArrayFromJSONString(arrow_struct_type, arrow_json).ValueOrDie(); + + struct ExpectedData { + int32_t id; + std::string name; + int32_t age; + }; + std::vector<ExpectedData> expected = {{.id = 1, .name = "Alice", .age = 25}, + {.id = 2, .name = "Bob", .age = 30}, + {.id = 3, .name = "Charlie", .age = 35}}; + + auto verify_record = [&](int64_t index, const ExpectedData& expected_data) { + ::avro::GenericDatum extracted_datum(avro_node); + ASSERT_THAT(ExtractDatumFromArray(*arrow_array, index, &extracted_datum), IsOk()); + const auto& record = extracted_datum.value<::avro::GenericRecord>(); + EXPECT_EQ(record.fieldAt(0).value<int32_t>(), expected_data.id); + const auto& person_record = record.fieldAt(1).value<::avro::GenericRecord>(); + EXPECT_EQ(person_record.fieldAt(0).value<std::string>(), expected_data.name); + EXPECT_EQ(person_record.fieldAt(1).value<int32_t>(), expected_data.age); + }; + + for (size_t i = 0; i < expected.size(); ++i) { + verify_record(i, expected[i]); + } +} + +TEST(ExtractDatumFromArrayTest, ListOfIntegers) { + Schema iceberg_schema({ + SchemaField::MakeRequired( + 1, "numbers", + std::make_shared<ListType>(SchemaField::MakeRequired(2, "element", int32()))), + }); + + ::avro::NodePtr avro_node; + ASSERT_THAT(ToAvroNodeVisitor{}.Visit(iceberg_schema, &avro_node), IsOk()); + + ArrowSchema arrow_c_schema; + ASSERT_THAT(ToArrowSchema(iceberg_schema, &arrow_c_schema), IsOk()); + auto arrow_schema = ::arrow::ImportSchema(&arrow_c_schema).ValueOrDie(); + auto arrow_struct_type = std::make_shared<::arrow::StructType>(arrow_schema->fields()); + + const std::string arrow_json = R"([ + {"numbers": [10, 11, 12]}, + {"numbers": [20, 21]}, + {"numbers": [30, 31, 32, 33]} + ])"; + auto arrow_array = + ::arrow::json::ArrayFromJSONString(arrow_struct_type, arrow_json).ValueOrDie(); + + std::vector<std::vector<int32_t>> expected = {{10, 11, 12}, {20, 21}, {30, 31, 32, 33}}; + + auto verify_record = [&](int64_t index, const std::vector<int32_t>&
expected_numbers) { + ::avro::GenericDatum extracted_datum(avro_node); + ASSERT_THAT(ExtractDatumFromArray(*arrow_array, index, &extracted_datum), IsOk()); + const auto& record = extracted_datum.value<::avro::GenericRecord>(); + const auto& array = record.fieldAt(0).value<::avro::GenericArray>(); + const auto& elements = array.value(); + + ASSERT_EQ(elements.size(), expected_numbers.size()); + for (size_t i = 0; i < expected_numbers.size(); ++i) { + EXPECT_EQ(elements[i].value<int32_t>(), expected_numbers[i]); + } + }; + + for (size_t i = 0; i < expected.size(); ++i) { + verify_record(i, expected[i]); + } +} + +TEST(ExtractDatumFromArrayTest, MapStringToInt) { + Schema iceberg_schema({ + SchemaField::MakeRequired( + 1, "scores", + std::make_shared<MapType>(SchemaField::MakeRequired(2, "key", string()), + SchemaField::MakeRequired(3, "value", int32()))), + }); + + ::avro::NodePtr avro_node; + ASSERT_THAT(ToAvroNodeVisitor{}.Visit(iceberg_schema, &avro_node), IsOk()); + + ArrowSchema arrow_c_schema; + ASSERT_THAT(ToArrowSchema(iceberg_schema, &arrow_c_schema), IsOk()); + auto arrow_schema = ::arrow::ImportSchema(&arrow_c_schema).ValueOrDie(); + auto arrow_struct_type = std::make_shared<::arrow::StructType>(arrow_schema->fields()); + + const std::string arrow_json = R"([ + {"scores": [["alice", 95], ["bob", 87]]}, + {"scores": [["charlie", 92], ["diana", 98], ["eve", 89]]}, + {"scores": [["frank", 91]]} + ])"; + auto arrow_array = + ::arrow::json::ArrayFromJSONString(arrow_struct_type, arrow_json).ValueOrDie(); + + using MapEntry = std::pair<std::string, int32_t>; + std::vector<std::vector<MapEntry>> expected = { + {{"alice", 95}, {"bob", 87}}, + {{"charlie", 92}, {"diana", 98}, {"eve", 89}}, + {{"frank", 91}}}; + + auto verify_record = [&](int64_t index, const std::vector<MapEntry>& expected_entries) { + ::avro::GenericDatum extracted_datum(avro_node); + ASSERT_THAT(ExtractDatumFromArray(*arrow_array, index, &extracted_datum), IsOk()); + const auto& record =
extracted_datum.value<::avro::GenericRecord>(); + const auto& map = record.fieldAt(0).value<::avro::GenericMap>(); + const auto& entries = map.value(); + + ASSERT_EQ(entries.size(), expected_entries.size()); + for (size_t i = 0; i < expected_entries.size(); ++i) { + EXPECT_EQ(entries[i].first, expected_entries[i].first); + EXPECT_EQ(entries[i].second.value<int32_t>(), expected_entries[i].second); + } + }; + + for (size_t i = 0; i < expected.size(); ++i) { + verify_record(i, expected[i]); + } +} + +TEST(ExtractDatumFromArrayTest, ErrorHandling) { + Schema iceberg_schema({SchemaField::MakeRequired(1, "a", int32())}); + ::avro::NodePtr avro_node; + ASSERT_THAT(ToAvroNodeVisitor{}.Visit(iceberg_schema, &avro_node), IsOk()); + + ArrowSchema arrow_c_schema; + ASSERT_THAT(ToArrowSchema(iceberg_schema, &arrow_c_schema), IsOk()); + auto arrow_schema = ::arrow::ImportSchema(&arrow_c_schema).ValueOrDie(); + auto arrow_struct_type = std::make_shared<::arrow::StructType>(arrow_schema->fields()); + + auto arrow_array = ::arrow::json::ArrayFromJSONString( + arrow_struct_type, R"([{"a": 1}, {"a": 2}, {"a": 3}])") + .ValueOrDie(); + + ::avro::GenericDatum datum(avro_node); + + // Test negative index + EXPECT_THAT(ExtractDatumFromArray(*arrow_array, -1, &datum), + HasErrorMessage("Cannot extract datum from array at index -1")); + + // Test index beyond array length + EXPECT_THAT(ExtractDatumFromArray(*arrow_array, 3, &datum), + HasErrorMessage("Cannot extract datum from array at index 3")); +} + +// An optional field is extracted into an Avro union: a value selects branch 1 and a null selects branch 0 (type AVRO_NULL). +TEST(ExtractDatumFromArrayTest, NullHandling) { + Schema iceberg_schema({SchemaField::MakeOptional(1, "a", int32())}); + ::avro::NodePtr avro_node; + ASSERT_THAT(ToAvroNodeVisitor{}.Visit(iceberg_schema, &avro_node), IsOk()); + + ArrowSchema arrow_c_schema; + ASSERT_THAT(ToArrowSchema(iceberg_schema, &arrow_c_schema), IsOk()); + auto arrow_schema = ::arrow::ImportSchema(&arrow_c_schema).ValueOrDie(); + auto arrow_struct_type = std::make_shared<::arrow::StructType>(arrow_schema->fields()); +
+ auto arrow_array = + ::arrow::json::ArrayFromJSONString(arrow_struct_type, R"([{"a": 42}, {"a": null}])") + .ValueOrDie(); + + ::avro::GenericDatum datum(avro_node); + ASSERT_THAT(ExtractDatumFromArray(*arrow_array, 0, &datum), IsOk()); + + const auto& record = datum.value<::avro::GenericRecord>(); + EXPECT_EQ(record.fieldAt(0).unionBranch(), 1); + EXPECT_EQ(record.fieldAt(0).type(), ::avro::AVRO_INT); + EXPECT_EQ(record.fieldAt(0).value<int32_t>(), 42); + + ASSERT_THAT(ExtractDatumFromArray(*arrow_array, 1, &datum), IsOk()); + const auto& record2 = datum.value<::avro::GenericRecord>(); + EXPECT_EQ(record2.fieldAt(0).unionBranch(), 0); + EXPECT_EQ(record2.fieldAt(0).type(), ::avro::AVRO_NULL); +} + +// Round trip: every row is extracted with ExtractDatumFromArray, re-appended with AppendDatumToBuilder, and the rebuilt array must equal the original. +struct RoundTripParam { + std::string name; + Schema iceberg_schema; + std::string arrow_json; +}; + +void VerifyRoundTripConversion(const RoundTripParam& test_case) { + ::avro::NodePtr avro_node; + ASSERT_THAT(ToAvroNodeVisitor{}.Visit(test_case.iceberg_schema, &avro_node), IsOk()); + + ArrowSchema arrow_c_schema; + ASSERT_THAT(ToArrowSchema(test_case.iceberg_schema, &arrow_c_schema), IsOk()); + auto arrow_schema = ::arrow::ImportSchema(&arrow_c_schema).ValueOrDie(); + auto arrow_struct_type = std::make_shared<::arrow::StructType>(arrow_schema->fields()); + + auto original_array = + ::arrow::json::ArrayFromJSONString(arrow_struct_type, test_case.arrow_json) + .ValueOrDie(); + + std::vector<::avro::GenericDatum> extracted_data; + for (int64_t i = 0; i < original_array->length(); ++i) { + ::avro::GenericDatum datum(avro_node); + ASSERT_THAT(ExtractDatumFromArray(*original_array, i, &datum), IsOk()) + << "Failed to extract datum at index " << i; + extracted_data.push_back(datum); + } + + auto projection_result = + Project(test_case.iceberg_schema, avro_node, /*prune_source=*/false); + ASSERT_THAT(projection_result, IsOk()); + auto projection = std::move(projection_result.value()); + + auto builder = ::arrow::MakeBuilder(arrow_struct_type).ValueOrDie(); + for (const
auto& datum : extracted_data) { + ASSERT_THAT(AppendDatumToBuilder(avro_node, datum, projection, + test_case.iceberg_schema, builder.get()), + IsOk()); + } + + auto rebuilt_array = builder->Finish().ValueOrDie(); + + ASSERT_TRUE(original_array->Equals(*rebuilt_array)) + << "Round-trip consistency failed!\n" + << "Original array: " << original_array->ToString() << "\n" + << "Rebuilt array: " << rebuilt_array->ToString(); +} + +class AvroRoundTripConversionTest : public ::testing::TestWithParam<RoundTripParam> {}; + +TEST_P(AvroRoundTripConversionTest, ConvertTypes) { + ASSERT_NO_FATAL_FAILURE(VerifyRoundTripConversion(GetParam())); +} + +const std::vector<RoundTripParam> kRoundTripTestCases = { + { + .name = "SimpleStruct", + .iceberg_schema = Schema({ + SchemaField::MakeRequired(1, "id", int32()), + SchemaField::MakeRequired(2, "name", string()), + SchemaField::MakeOptional(3, "age", int32()), + }), + .arrow_json = R"([ + {"id": 100, "name": "Alice", "age": 25}, + {"id": 101, "name": "Bob", "age": null}, + {"id": 102, "name": "Charlie", "age": 35} + ])", + }, + { + .name = "PrimitiveTypes", + .iceberg_schema = Schema({ + SchemaField::MakeRequired(1, "bool_field", boolean()), + SchemaField::MakeRequired(2, "int_field", int32()), + SchemaField::MakeRequired(3, "long_field", int64()), + SchemaField::MakeRequired(4, "float_field", float32()), + SchemaField::MakeRequired(5, "double_field", float64()), + SchemaField::MakeRequired(6, "string_field", string()), + }), + .arrow_json = R"([ + {"bool_field": true, "int_field": 42, "long_field": 1000000, "float_field": 3.14, "double_field": 2.718281828, "string_field": "hello"}, + {"bool_field": false, "int_field": -42, "long_field": -1000000, "float_field": -3.14, "double_field": -2.718281828, "string_field": "world"} + ])", + }, + { + .name = "NestedStruct", + .iceberg_schema = Schema({ + SchemaField::MakeRequired(1, "id", int32()), + SchemaField::MakeRequired( + 2, "person", +
std::make_shared<StructType>(std::vector<SchemaField>{ + SchemaField::MakeRequired(3, "name", string()), + SchemaField::MakeRequired(4, "age", int32()), + })), + }), + .arrow_json = R"([ + {"id": 1, "person": {"name": "Alice", "age": 30}}, + {"id": 2, "person": {"name": "Bob", "age": 25}} + ])", + }, + { + .name = "ListOfIntegers", + .iceberg_schema = Schema({ + SchemaField::MakeRequired( + 1, "numbers", + std::make_shared<ListType>( + SchemaField::MakeRequired(2, "element", int32()))), + }), + .arrow_json = R"([ + {"numbers": [1, 2, 3]}, + {"numbers": [10, 20]}, + {"numbers": []} + ])", + }, + { + .name = "MapStringToInt", + .iceberg_schema = Schema({ + SchemaField::MakeRequired( + 1, "scores", + std::make_shared<MapType>( + SchemaField::MakeRequired(2, "key", string()), + SchemaField::MakeRequired(3, "value", int32()))), + }), + .arrow_json = R"([ + {"scores": [["alice", 95], ["bob", 87]]}, + {"scores": [["charlie", 92]]}, + {"scores": []} + ])", + }, + { + .name = "ComplexNested", + .iceberg_schema = Schema({ + SchemaField::MakeRequired( + 1, "data", + std::make_shared<StructType>(std::vector<SchemaField>{ + SchemaField::MakeRequired(2, "id", int32()), + SchemaField::MakeRequired( + 3, "tags", + std::make_shared<ListType>( + SchemaField::MakeRequired(4, "element", string()))), + SchemaField::MakeOptional( + 5, "metadata", + std::make_shared<MapType>( + SchemaField::MakeRequired(6, "key", string()), + SchemaField::MakeRequired(7, "value", string()))), + })), + }), + .arrow_json = R"([ + {"data": {"id": 1, "tags": ["tag1", "tag2"], "metadata": [["key1", "value1"]]}}, + {"data": {"id": 2, "tags": [], "metadata": null}} + ])", + }, + { + .name = "NullablePrimitives", + .iceberg_schema = Schema({ + SchemaField::MakeOptional(1, "optional_bool", boolean()), + SchemaField::MakeOptional(2, "optional_int", int32()), + SchemaField::MakeOptional(3, "optional_long", int64()), + SchemaField::MakeOptional(4, "optional_string", string()), + SchemaField::MakeRequired(5,
"required_id", int32()), + }), + .arrow_json = R"([ + {"optional_bool": true, "optional_int": 42, "optional_long": 1000000, "optional_string": "hello", "required_id": 1}, + {"optional_bool": null, "optional_int": null, "optional_long": null, "optional_string": null, "required_id": 2}, + {"optional_bool": false, "optional_int": null, "optional_long": 2000000, "optional_string": null, "required_id": 3}, + {"optional_bool": null, "optional_int": 123, "optional_long": null, "optional_string": "world", "required_id": 4} + ])", + }, + { + .name = "NullableNestedStruct", + .iceberg_schema = Schema({ + SchemaField::MakeRequired(1, "id", int32()), + SchemaField::MakeOptional( + 2, "person", + std::make_shared<StructType>(std::vector<SchemaField>{ + SchemaField::MakeRequired(3, "name", string()), + SchemaField::MakeOptional(4, "age", int32()), + SchemaField::MakeOptional(5, "email", string()), + })), + SchemaField::MakeOptional(6, "department", string()), + }), + .arrow_json = R"([ + {"id": 1, "person": {"name": "Alice", "age": 30, "email": "al...@example.com"}, "department": "Engineering"}, + {"id": 2, "person": null, "department": null}, + {"id": 3, "person": {"name": "Bob", "age": null, "email": null}, "department": "Sales"}, + {"id": 4, "person": {"name": "Charlie", "age": 25, "email": null}, "department": null} + ])", + }, + { + .name = "NullableListElements", + .iceberg_schema = Schema({ + SchemaField::MakeRequired(1, "id", int32()), + SchemaField::MakeOptional( + 2, "numbers", + std::make_shared<ListType>( + SchemaField::MakeOptional(3, "element", int32()))), + SchemaField::MakeRequired( + 4, "tags", + std::make_shared<ListType>( + SchemaField::MakeOptional(5, "element", string()))), + }), + .arrow_json = R"([ + {"id": 1, "numbers": [1, null, 3], "tags": ["tag1", null, "tag3"]}, + {"id": 2, "numbers": null, "tags": ["only_tag"]}, + {"id": 3, "numbers": [null, null], "tags": [null, null, null]}, + {"id": 4, "numbers": [], "tags": []} + ])", + }, + { + .name =
"NullableMapValues", + .iceberg_schema = Schema({ + SchemaField::MakeRequired(1, "id", int32()), + SchemaField::MakeOptional( + 2, "scores", + std::make_shared<MapType>( + SchemaField::MakeRequired(3, "key", string()), + SchemaField::MakeOptional(4, "value", int32()))), + SchemaField::MakeRequired( + 5, "metadata", + std::make_shared<MapType>( + SchemaField::MakeRequired(6, "key", string()), + SchemaField::MakeOptional(7, "value", string()))), + }), + .arrow_json = R"([ + {"id": 1, "scores": [["alice", 95], ["bob", null]], "metadata": [["key1", "value1"], ["key2", null]]}, + {"id": 2, "scores": null, "metadata": [["key3", null]]}, + {"id": 3, "scores": [["charlie", null], ["diana", 98]], "metadata": []}, + {"id": 4, "scores": [], "metadata": [["key4", null], ["key5", "value5"]]} + ])", + }, + { + .name = "DeeplyNestedWithNulls", + .iceberg_schema = Schema({ + SchemaField::MakeRequired( + 1, "root", + std::make_shared<StructType>(std::vector<SchemaField>{ + SchemaField::MakeRequired(2, "id", int32()), + SchemaField::MakeOptional( + 3, "nested", + std::make_shared<StructType>(std::vector<SchemaField>{ + SchemaField::MakeOptional(4, "name", string()), + SchemaField::MakeOptional( + 5, "values", + std::make_shared<ListType>( + SchemaField::MakeOptional(6, "element", int32()))), + })), + SchemaField::MakeOptional( + 7, "tags", + std::make_shared<ListType>( + SchemaField::MakeOptional(8, "element", string()))), + })), + }), + .arrow_json = R"([ + {"root": {"id": 1, "nested": {"name": "test", "values": [1, null, 3]}, "tags": ["a", "b"]}}, + {"root": {"id": 2, "nested": null, "tags": null}}, + {"root": {"id": 3, "nested": {"name": null, "values": null}, "tags": [null, "c"]}}, + {"root": {"id": 4, "nested": {"name": "empty", "values": []}, "tags": []}} + ])", + }, + { + .name = "AllNullsVariations", + .iceberg_schema = Schema({ + SchemaField::MakeOptional(1, "always_null", string()), + SchemaField::MakeOptional(2, "sometimes_null", int32()), + SchemaField::MakeOptional( +
3, "nested_struct", + std::make_shared<StructType>(std::vector<SchemaField>{ + SchemaField::MakeOptional(4, "inner_null", string()), + SchemaField::MakeRequired(5, "inner_required", boolean()), + })), + SchemaField::MakeRequired(6, "id", int32()), + }), + .arrow_json = R"([ + {"always_null": null, "sometimes_null": 42, "nested_struct": {"inner_null": "value", "inner_required": true}, "id": 1}, + {"always_null": null, "sometimes_null": null, "nested_struct": null, "id": 2}, + {"always_null": null, "sometimes_null": 123, "nested_struct": {"inner_null": null, "inner_required": false}, "id": 3}, + {"always_null": null, "sometimes_null": null, "nested_struct": {"inner_null": null, "inner_required": true}, "id": 4} + ])", + }, +}; + +INSTANTIATE_TEST_SUITE_P(AllTypes, AvroRoundTripConversionTest, + ::testing::ValuesIn(kRoundTripTestCases), + [](const ::testing::TestParamInfo<RoundTripParam>& info) { + return info.param.name; + }); + } // namespace iceberg::avro