This is an automated email from the ASF dual-hosted git repository.

zeroshade pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-cpp.git


The following commit(s) were added to refs/heads/main by this push:
     new bc87f00  feat(avro): extract avro datum from arrow array (#166)
bc87f00 is described below

commit bc87f00f2df9155c3d17e8cc804b04c05210b636
Author: Gang Wu <ust...@gmail.com>
AuthorDate: Tue Aug 19 02:01:26 2025 +0800

    feat(avro): extract avro datum from arrow array (#166)
---
 src/iceberg/avro/avro_data_util.cc         | 230 +++++++++
 src/iceberg/avro/avro_data_util_internal.h |   9 +
 test/avro_data_test.cc                     | 721 +++++++++++++++++++++++++++++
 3 files changed, 960 insertions(+)

diff --git a/src/iceberg/avro/avro_data_util.cc 
b/src/iceberg/avro/avro_data_util.cc
index 16ac41a..8853f74 100644
--- a/src/iceberg/avro/avro_data_util.cc
+++ b/src/iceberg/avro/avro_data_util.cc
@@ -17,10 +17,13 @@
  * under the License.
  */
 
+#include <ranges>
+
 #include <arrow/array/builder_binary.h>
 #include <arrow/array/builder_decimal.h>
 #include <arrow/array/builder_nested.h>
 #include <arrow/array/builder_primitive.h>
+#include <arrow/extension_type.h>
 #include <arrow/json/from_string.h>
 #include <arrow/type.h>
 #include <arrow/util/decimal.h>
@@ -451,4 +454,231 @@ Status AppendDatumToBuilder(const ::avro::NodePtr& 
avro_node,
                                     projected_schema, array_builder);
 }
 
+namespace {
+
+// Branch indices of the two-branch unions built by ToAvroNodeVisitor:
+// index 0 selects the null branch, index 1 selects the value branch.
+constexpr int64_t kNullBranch{0};
+constexpr int64_t kValueBranch{1};
+
+}  // namespace
+
+/// Copies the element at `index` of `array` into `datum`.
+///
+/// The Avro datum must already be constructed from an Avro node whose shape matches
+/// the Arrow type of `array`; nested structs, lists and maps are handled recursively.
+///
+/// \param array Arrow array to read from.
+/// \param index Position of the element to read; must be in [0, array.length()).
+/// \param datum Destination datum. If it is a union, the null or value branch is
+/// selected according to the validity of the Arrow element.
+/// \return An error when the index is out of range, a null is written to a non-union
+/// datum, the Arrow type is not supported, or nested shapes do not line up.
+Status ExtractDatumFromArray(const ::arrow::Array& array, int64_t index,
+                             ::avro::GenericDatum* datum) {
+  if (index < 0 || index >= array.length()) {
+    return InvalidArgument("Cannot extract datum from array at index {} of length {}",
+                           index, array.length());
+  }
+
+  if (array.IsNull(index)) {
+    // A null can only be represented by the null branch of a union.
+    if (!datum->isUnion()) [[unlikely]] {
+      return InvalidSchema("Cannot extract null to non-union type: {}",
+                           ::avro::toString(datum->type()));
+    }
+    datum->selectBranch(kNullBranch);
+    return {};
+  }
+
+  // Non-null value: make sure the value branch is active before writing into it.
+  if (datum->isUnion()) {
+    datum->selectBranch(kValueBranch);
+  }
+
+  switch (array.type()->id()) {
+    case ::arrow::Type::BOOL: {
+      const auto& bool_array =
+          internal::checked_cast<const ::arrow::BooleanArray&>(array);
+      datum->value<bool>() = bool_array.Value(index);
+      return {};
+    }
+
+    case ::arrow::Type::INT32: {
+      const auto& int32_array =
+          internal::checked_cast<const ::arrow::Int32Array&>(array);
+      datum->value<int32_t>() = int32_array.Value(index);
+      return {};
+    }
+
+    case ::arrow::Type::INT64: {
+      const auto& int64_array =
+          internal::checked_cast<const ::arrow::Int64Array&>(array);
+      datum->value<int64_t>() = int64_array.Value(index);
+      return {};
+    }
+
+    case ::arrow::Type::FLOAT: {
+      const auto& float_array =
+          internal::checked_cast<const ::arrow::FloatArray&>(array);
+      datum->value<float>() = float_array.Value(index);
+      return {};
+    }
+
+    case ::arrow::Type::DOUBLE: {
+      const auto& double_array =
+          internal::checked_cast<const ::arrow::DoubleArray&>(array);
+      datum->value<double>() = double_array.Value(index);
+      return {};
+    }
+
+    // TODO(gangwu): support LARGE_STRING.
+    case ::arrow::Type::STRING: {
+      const auto& string_array =
+          internal::checked_cast<const ::arrow::StringArray&>(array);
+      datum->value<std::string>() = string_array.GetString(index);
+      return {};
+    }
+
+    // TODO(gangwu): support LARGE_BINARY.
+    case ::arrow::Type::BINARY: {
+      const auto& binary_array =
+          internal::checked_cast<const ::arrow::BinaryArray&>(array);
+      std::string_view value = binary_array.GetView(index);
+      datum->value<std::vector<uint8_t>>().assign(
+          reinterpret_cast<const uint8_t*>(value.data()),
+          reinterpret_cast<const uint8_t*>(value.data()) + value.size());
+      return {};
+    }
+
+    case ::arrow::Type::FIXED_SIZE_BINARY: {
+      const auto& fixed_array =
+          internal::checked_cast<const ::arrow::FixedSizeBinaryArray&>(array);
+      std::string_view value = fixed_array.GetView(index);
+      auto& fixed_datum = datum->value<::avro::GenericFixed>();
+      fixed_datum.value().assign(value.begin(), value.end());
+      return {};
+    }
+
+    case ::arrow::Type::DECIMAL128: {
+      const auto& decimal_array =
+          internal::checked_cast<const ::arrow::Decimal128Array&>(array);
+      std::string_view decimal_value = decimal_array.GetView(index);
+      auto& fixed_datum = datum->value<::avro::GenericFixed>();
+      auto& bytes = fixed_datum.value();
+      bytes.assign(decimal_value.begin(), decimal_value.end());
+      // The Avro fixed is written in the reverse byte order of the Arrow buffer.
+      // NOTE(review): this presumably assumes a little-endian host - confirm.
+      std::ranges::reverse(bytes);
+      return {};
+    }
+
+    case ::arrow::Type::DATE32: {
+      const auto& date_array =
+          internal::checked_cast<const ::arrow::Date32Array&>(array);
+      datum->value<int32_t>() = date_array.Value(index);
+      return {};
+    }
+
+    case ::arrow::Type::TIME64: {
+      const auto& time_array =
+          internal::checked_cast<const ::arrow::Time64Array&>(array);
+      datum->value<int64_t>() = time_array.Value(index);
+      return {};
+    }
+
+    // For both timestamp and timestamp_tz with time unit as microsecond.
+    case ::arrow::Type::TIMESTAMP: {
+      const auto& timestamp_array =
+          internal::checked_cast<const ::arrow::TimestampArray&>(array);
+      datum->value<int64_t>() = timestamp_array.Value(index);
+      return {};
+    }
+
+    case ::arrow::Type::EXTENSION: {
+      // DataType::name() is the generic "extension" for every extension type; the
+      // registered name (e.g. "arrow.uuid") is only available via extension_name().
+      const auto& extension_type =
+          internal::checked_cast<const ::arrow::ExtensionType&>(*array.type());
+      const std::string extension_name = extension_type.extension_name();
+      if (extension_name == "arrow.uuid") {
+        const auto& extension_array =
+            internal::checked_cast<const ::arrow::ExtensionArray&>(array);
+        const auto& fixed_array =
+            internal::checked_cast<const ::arrow::FixedSizeBinaryArray&>(
+                *extension_array.storage());
+        std::string_view value = fixed_array.GetView(index);
+        auto& fixed_datum = datum->value<::avro::GenericFixed>();
+        fixed_datum.value().assign(value.begin(), value.end());
+        return {};
+      }
+
+      return NotSupported("Unsupported Arrow extension type: {}", extension_name);
+    }
+
+    case ::arrow::Type::STRUCT: {
+      const auto& struct_array =
+          internal::checked_cast<const ::arrow::StructArray&>(array);
+      auto& record = datum->value<::avro::GenericRecord>();
+      // fieldAt() is only valid for positions below fieldCount(); reject an Arrow
+      // struct that has more children than the Avro record has fields.
+      const auto num_fields = static_cast<size_t>(struct_array.num_fields());
+      if (num_fields > record.fieldCount()) {
+        return InvalidArgument(
+            "Arrow struct has {} fields but Avro record only has {} fields",
+            num_fields, record.fieldCount());
+      }
+      for (int i = 0; i < struct_array.num_fields(); ++i) {
+        ICEBERG_RETURN_UNEXPECTED(
+            ExtractDatumFromArray(*struct_array.field(i), index, &record.fieldAt(i)));
+      }
+      return {};
+    }
+
+    // TODO(gangwu): support LARGE_LIST.
+    case ::arrow::Type::LIST: {
+      const auto& list_array =
+          internal::checked_cast<const ::arrow::ListArray&>(array);
+      auto& avro_array = datum->value<::avro::GenericArray>();
+      auto& elements = avro_array.value();
+
+      // [start, end) is the slice of the child array that belongs to this row.
+      auto start = list_array.value_offset(index);
+      auto end = list_array.value_offset(index + 1);
+      auto length = end - start;
+
+      auto values = list_array.values();
+      elements.resize(length, ::avro::GenericDatum(avro_array.schema()->leafAt(0)));
+
+      for (int64_t i = 0; i < length; ++i) {
+        ICEBERG_RETURN_UNEXPECTED(
+            ExtractDatumFromArray(*values, start + i, &elements[i]));
+      }
+      return {};
+    }
+
+    case ::arrow::Type::MAP: {
+      const auto& map_array = internal::checked_cast<const ::arrow::MapArray&>(array);
+      auto start = map_array.value_offset(index);
+      auto end = map_array.value_offset(index + 1);
+      auto length = end - start;
+
+      auto keys = map_array.keys();
+      auto items = map_array.items();
+
+      if (datum->type() == ::avro::AVRO_MAP) {
+        // Handle regular Avro map
+        auto& avro_map = datum->value<::avro::GenericMap>();
+        auto value_node = avro_map.schema()->leafAt(1);
+
+        // Avro map entries are keyed by std::string, so only Arrow string keys can
+        // be copied; check before the unchecked downcast below.
+        if (keys->type_id() != ::arrow::Type::STRING) {
+          return InvalidArgument("Avro map requires string keys, got Arrow type: {}",
+                                 keys->type()->ToString());
+        }
+
+        auto& map_entries = avro_map.value();
+        map_entries.resize(
+            length, std::make_pair(std::string(), ::avro::GenericDatum(value_node)));
+
+        const auto& key_array =
+            internal::checked_cast<const ::arrow::StringArray&>(*keys);
+
+        for (int64_t i = 0; i < length; ++i) {
+          auto& map_entry = map_entries[i];
+          map_entry.first = key_array.GetString(start + i);
+          ICEBERG_RETURN_UNEXPECTED(
+              ExtractDatumFromArray(*items, start + i, &map_entry.second));
+        }
+      } else if (datum->type() == ::avro::AVRO_ARRAY) {
+        // Handle array-based map (list<struct<key, value>>)
+        auto& avro_array = datum->value<::avro::GenericArray>();
+        auto record_node = avro_array.schema()->leafAt(0);
+        if (record_node->type() != ::avro::AVRO_RECORD || record_node->leaves() != 2) {
+          return InvalidArgument(
+              "Expected Avro record with 2 fields for map value, got: {}",
+              ToString(record_node));
+        }
+
+        auto& elements = avro_array.value();
+        elements.resize(length, ::avro::GenericDatum(record_node));
+
+        for (int64_t i = 0; i < length; ++i) {
+          auto& record = elements[i].value<::avro::GenericRecord>();
+          ICEBERG_RETURN_UNEXPECTED(
+              ExtractDatumFromArray(*keys, start + i, &record.fieldAt(0)));
+          ICEBERG_RETURN_UNEXPECTED(
+              ExtractDatumFromArray(*items, start + i, &record.fieldAt(1)));
+        }
+      } else {
+        return InvalidArgument("Unsupported Avro type for map: {}",
+                               static_cast<int>(datum->type()));
+      }
+      return {};
+    }
+
+    default:
+      return InvalidArgument("Unsupported Arrow array type: {}",
+                             array.type()->ToString());
+  }
+}
+
 }  // namespace iceberg::avro
diff --git a/src/iceberg/avro/avro_data_util_internal.h 
b/src/iceberg/avro/avro_data_util_internal.h
index ad49368..0de383e 100644
--- a/src/iceberg/avro/avro_data_util_internal.h
+++ b/src/iceberg/avro/avro_data_util_internal.h
@@ -43,4 +43,13 @@ Status AppendDatumToBuilder(const ::avro::NodePtr& avro_node,
                             const Schema& projected_schema,
                             ::arrow::ArrayBuilder* array_builder);
 
+/// \brief Copy a single element of an Arrow array into an Avro datum.
+///
+/// \param array Arrow array that holds the source value.
+/// \param index Position of the element inside `array`.
+/// \param datum Avro datum that receives the value. Its Avro type should be
+/// consistent with the Arrow type of `array`.
+/// \return An error status when the element cannot be extracted.
+Status ExtractDatumFromArray(const ::arrow::Array& array, int64_t index,
+                             ::avro::GenericDatum* datum);
+
 }  // namespace iceberg::avro
diff --git a/test/avro_data_test.cc b/test/avro_data_test.cc
index 45811ff..b5cc1c5 100644
--- a/test/avro_data_test.cc
+++ b/test/avro_data_test.cc
@@ -21,6 +21,7 @@
 
 #include <arrow/c/bridge.h>
 #include <arrow/json/from_string.h>
+#include <arrow/util/decimal.h>
 #include <avro/Compiler.hh>
 #include <avro/Generic.hh>
 #include <avro/Node.hh>
@@ -757,4 +758,724 @@ TEST(AppendDatumToBuilderTest, 
ListWithMissingOptionalElementFields) {
                                                      avro_data, 
expected_json));
 }
 
+struct ExtractDatumParam {  // One parameterized case for ExtractDatumFromArrayTest.
+  std::string name;  // Returned as the gtest parameter name.
+  std::shared_ptr<Type> iceberg_type;  // Iceberg type of the single required field "a".
+  std::string arrow_json;  // JSON rows used to build the Arrow struct array.
+  std::function<void(const ::avro::GenericDatum&, int)> value_verifier;  // Invoked with (extracted datum, row index).
+};
+
+// Builds a one-column struct array (required field "a" of the case's Iceberg type) from
+// the case's JSON, extracts every row into a fresh Avro datum and passes the datum plus
+// its row index to the case's verifier.
+void VerifyExtractDatumFromArray(const ExtractDatumParam& test_case) {
+  Schema single_field_schema({SchemaField::MakeRequired(
+      /*field_id=*/1, /*name=*/"a", test_case.iceberg_type)});
+  ::avro::NodePtr root_node;
+  ASSERT_THAT(ToAvroNodeVisitor{}.Visit(single_field_schema, &root_node), IsOk());
+
+  ArrowSchema exported_schema;
+  ASSERT_THAT(ToArrowSchema(single_field_schema, &exported_schema), IsOk());
+  auto imported_schema = ::arrow::ImportSchema(&exported_schema).ValueOrDie();
+  auto struct_type = std::make_shared<::arrow::StructType>(imported_schema->fields());
+  auto rows =
+      ::arrow::json::ArrayFromJSONString(struct_type, test_case.arrow_json)
+          .ValueOrDie();
+
+  for (int64_t row = 0; row < rows->length(); ++row) {
+    // Start from a pristine datum for every row.
+    ::avro::GenericDatum row_datum(root_node);
+    ASSERT_THAT(ExtractDatumFromArray(*rows, row, &row_datum), IsOk())
+        << "Failed to extract at index " << row;
+    test_case.value_verifier(row_datum, static_cast<int>(row));
+  }
+}
+
+// Value-parameterized fixture; each parameter is one ExtractDatumParam case.
+class ExtractDatumFromArrayTest
+    : public ::testing::TestWithParam<ExtractDatumParam> {};
+
+// Runs the shared verification routine on the current parameter.
+TEST_P(ExtractDatumFromArrayTest, PrimitiveType) {
+  ASSERT_NO_FATAL_FAILURE(VerifyExtractDatumFromArray(GetParam()));
+}
+
+// One case per primitive Iceberg type. Every case provides three JSON rows for the
+// single field "a"; the verifier derives the expected value from the row index.
+const std::vector<ExtractDatumParam> kExtractDatumTestCases = {
+    {
+        .name = "Boolean",
+        .iceberg_type = boolean(),
+        .arrow_json = R"([{"a": true}, {"a": false}, {"a": true}])",
+        .value_verifier =
+            [](const ::avro::GenericDatum& extracted, int row) {
+              const auto& fields = extracted.value<::avro::GenericRecord>();
+              // Rows alternate true/false starting with true.
+              bool expected = (row % 2 == 0);
+              EXPECT_EQ(fields.fieldAt(0).value<bool>(), expected);
+            },
+    },
+    {
+        .name = "Int",
+        .iceberg_type = int32(),
+        .arrow_json = R"([{"a": 0}, {"a": 100}, {"a": 200}])",
+        .value_verifier =
+            [](const ::avro::GenericDatum& extracted, int row) {
+              const auto& fields = extracted.value<::avro::GenericRecord>();
+              EXPECT_EQ(fields.fieldAt(0).value<int32_t>(), row * 100);
+            },
+    },
+    {
+        .name = "Long",
+        .iceberg_type = int64(),
+        .arrow_json = R"([{"a": 0}, {"a": 1000000}, {"a": 2000000}])",
+        .value_verifier =
+            [](const ::avro::GenericDatum& extracted, int row) {
+              const auto& fields = extracted.value<::avro::GenericRecord>();
+              EXPECT_EQ(fields.fieldAt(0).value<int64_t>(), row * 1000000LL);
+            },
+    },
+    {
+        .name = "Float",
+        .iceberg_type = float32(),
+        .arrow_json = R"([{"a": 0.0}, {"a": 3.14}, {"a": 6.28}])",
+        .value_verifier =
+            [](const ::avro::GenericDatum& extracted, int row) {
+              const auto& fields = extracted.value<::avro::GenericRecord>();
+              EXPECT_FLOAT_EQ(fields.fieldAt(0).value<float>(), row * 3.14f);
+            },
+    },
+    {
+        .name = "Double",
+        .iceberg_type = float64(),
+        .arrow_json = R"([{"a": 0.0}, {"a": 1.234567890}, {"a": 2.469135780}])",
+        .value_verifier =
+            [](const ::avro::GenericDatum& extracted, int row) {
+              const auto& fields = extracted.value<::avro::GenericRecord>();
+              EXPECT_DOUBLE_EQ(fields.fieldAt(0).value<double>(), row * 1.234567890);
+            },
+    },
+    {
+        .name = "String",
+        .iceberg_type = string(),
+        .arrow_json =
+            R"([{"a": "test_string_0"}, {"a": "test_string_1"}, {"a": "test_string_2"}])",
+        .value_verifier =
+            [](const ::avro::GenericDatum& extracted, int row) {
+              const auto& fields = extracted.value<::avro::GenericRecord>();
+              std::string expected = "test_string_" + std::to_string(row);
+              EXPECT_EQ(fields.fieldAt(0).value<std::string>(), expected);
+            },
+    },
+    {
+        .name = "Binary",
+        .iceberg_type = binary(),
+        .arrow_json = R"([{"a": "abc"}, {"a": "bcd"}, {"a": "cde"}])",
+        .value_verifier =
+            [](const ::avro::GenericDatum& extracted, int row) {
+              const auto& fields = extracted.value<::avro::GenericRecord>();
+              const auto& payload = fields.fieldAt(0).value<std::vector<uint8_t>>();
+              // Each row shifts the three letters by one position in the alphabet.
+              EXPECT_EQ(payload.size(), 3);
+              EXPECT_EQ(payload[0], static_cast<uint8_t>('a' + row));
+              EXPECT_EQ(payload[1], static_cast<uint8_t>('b' + row));
+              EXPECT_EQ(payload[2], static_cast<uint8_t>('c' + row));
+            },
+    },
+    {
+        .name = "Fixed",
+        .iceberg_type = fixed(4),
+        .arrow_json = R"([{"a": "abcd"}, {"a": "bcde"}, {"a": "cdef"}])",
+        .value_verifier =
+            [](const ::avro::GenericDatum& extracted, int row) {
+              const auto& fields = extracted.value<::avro::GenericRecord>();
+              const auto& fixed_value = fields.fieldAt(0).value<::avro::GenericFixed>();
+              EXPECT_EQ(fixed_value.value().size(), 4);
+              EXPECT_EQ(static_cast<char>(fixed_value.value()[0]),
+                        static_cast<char>('a' + row));
+              EXPECT_EQ(static_cast<char>(fixed_value.value()[1]),
+                        static_cast<char>('b' + row));
+              EXPECT_EQ(static_cast<char>(fixed_value.value()[2]),
+                        static_cast<char>('c' + row));
+              EXPECT_EQ(static_cast<char>(fixed_value.value()[3]),
+                        static_cast<char>('d' + row));
+            },
+    },
+    {
+        .name = "Decimal",
+        .iceberg_type = decimal(10, 2),
+        .arrow_json = R"([{"a": "0.00"}, {"a": "10.01"}, {"a": "20.02"}])",
+        .value_verifier =
+            [](const ::avro::GenericDatum& extracted, int row) {
+              const auto& fields = extracted.value<::avro::GenericRecord>();
+              const auto& fixed_value = fields.fieldAt(0).value<::avro::GenericFixed>();
+
+              // Decode the fixed bytes as a big-endian two's-complement integer.
+              const auto& raw = fixed_value.value();
+              auto unscaled =
+                  ::arrow::Decimal128::FromBigEndian(
+                      reinterpret_cast<const uint8_t*>(raw.data()), raw.size())
+                      .ValueOrDie();
+              int64_t expected_unscaled = row * 1000 + row;
+              EXPECT_EQ(unscaled.low_bits(), static_cast<uint64_t>(expected_unscaled));
+              EXPECT_EQ(unscaled.high_bits(), 0);
+            },
+    },
+    {
+        .name = "Date",
+        .iceberg_type = date(),
+        .arrow_json = R"([{"a": 18000}, {"a": 18001}, {"a": 18002}])",
+        .value_verifier =
+            [](const ::avro::GenericDatum& extracted, int row) {
+              const auto& fields = extracted.value<::avro::GenericRecord>();
+              EXPECT_EQ(fields.fieldAt(0).value<int32_t>(), 18000 + row);
+            },
+    },
+    {
+        .name = "Time",
+        .iceberg_type = time(),
+        .arrow_json = R"([{"a": 45045123456}, {"a": 45046123456}, {"a": 45047123456}])",
+        .value_verifier =
+            [](const ::avro::GenericDatum& extracted, int row) {
+              const auto& fields = extracted.value<::avro::GenericRecord>();
+              EXPECT_EQ(fields.fieldAt(0).value<int64_t>(),
+                        45045123456LL + row * 1000000LL);
+            },
+    },
+    {
+        .name = "Timestamp",
+        .iceberg_type = timestamp(),
+        .arrow_json = R"([{"a": 0}, {"a": 1000000}, {"a": 2000000}])",
+        .value_verifier =
+            [](const ::avro::GenericDatum& extracted, int row) {
+              const auto& fields = extracted.value<::avro::GenericRecord>();
+              EXPECT_EQ(fields.fieldAt(0).value<int64_t>(), row * 1000000LL);
+            },
+    },
+    {
+        .name = "TimestampTz",
+        .iceberg_type = timestamp_tz(),
+        .arrow_json =
+            R"([{"a": 1672531200000000}, {"a": 1672531201000000}, {"a": 1672531202000000}])",
+        .value_verifier =
+            [](const ::avro::GenericDatum& extracted, int row) {
+              const auto& fields = extracted.value<::avro::GenericRecord>();
+              EXPECT_EQ(fields.fieldAt(0).value<int64_t>(),
+                        1672531200000000LL + row * 1000000LL);
+            },
+    },
+};
+
+// Instantiate the fixture once per table entry; the entry's `name` becomes the
+// parameter name of the generated test.
+INSTANTIATE_TEST_SUITE_P(
+    AllPrimitiveTypes, ExtractDatumFromArrayTest,
+    ::testing::ValuesIn(kExtractDatumTestCases),
+    [](const ::testing::TestParamInfo<ExtractDatumParam>& param_info) {
+      return param_info.param.name;
+    });
+
+// Extracts rows of a flat struct<id: int, name: string> and checks both fields.
+TEST(ExtractDatumFromArrayTest, StructWithTwoFields) {
+  Schema iceberg_schema({
+      SchemaField::MakeRequired(1, "id", int32()),
+      SchemaField::MakeRequired(2, "name", string()),
+  });
+  ::avro::NodePtr root_node;
+  ASSERT_THAT(ToAvroNodeVisitor{}.Visit(iceberg_schema, &root_node), IsOk());
+
+  ArrowSchema exported_schema;
+  ASSERT_THAT(ToArrowSchema(iceberg_schema, &exported_schema), IsOk());
+  auto imported_schema = ::arrow::ImportSchema(&exported_schema).ValueOrDie();
+  auto struct_type = std::make_shared<::arrow::StructType>(imported_schema->fields());
+
+  auto rows = ::arrow::json::ArrayFromJSONString(struct_type,
+                                                 R"([
+                                                          {"id": 42, "name": "Alice"},
+                                                          {"id": 43, "name": "Bob"},
+                                                          {"id": 44, "name": "Charlie"}
+                                                        ])")
+                  .ValueOrDie();
+
+  struct ExpectedData {
+    int32_t id;
+    std::string name;
+  };
+  std::vector<ExpectedData> expected = {{.id = 42, .name = "Alice"},
+                                        {.id = 43, .name = "Bob"},
+                                        {.id = 44, .name = "Charlie"}};
+
+  // Extracts one row into a fresh datum and compares it with the expectation.
+  auto check_row = [&](int64_t row, const ExpectedData& want) {
+    ::avro::GenericDatum row_datum(root_node);
+    ASSERT_THAT(ExtractDatumFromArray(*rows, row, &row_datum), IsOk());
+    const auto& fields = row_datum.value<::avro::GenericRecord>();
+    EXPECT_EQ(fields.fieldAt(0).value<int32_t>(), want.id);
+    EXPECT_EQ(fields.fieldAt(1).value<std::string>(), want.name);
+  };
+
+  for (size_t row = 0; row < expected.size(); ++row) {
+    check_row(row, expected[row]);
+  }
+}
+
+// Extracts rows of struct<id: int, person: struct<name: string, age: int>> and checks
+// that the nested record is populated as well.
+TEST(ExtractDatumFromArrayTest, NestedStruct) {
+  Schema iceberg_schema({
+      SchemaField::MakeRequired(1, "id", int32()),
+      SchemaField::MakeRequired(
+          2, "person",
+          std::make_shared<StructType>(std::vector<SchemaField>{
+              SchemaField::MakeRequired(3, "name", string()),
+              SchemaField::MakeRequired(4, "age", int32()),
+          })),
+  });
+
+  ::avro::NodePtr root_node;
+  ASSERT_THAT(ToAvroNodeVisitor{}.Visit(iceberg_schema, &root_node), IsOk());
+
+  ArrowSchema exported_schema;
+  ASSERT_THAT(ToArrowSchema(iceberg_schema, &exported_schema), IsOk());
+  auto imported_schema = ::arrow::ImportSchema(&exported_schema).ValueOrDie();
+  auto struct_type = std::make_shared<::arrow::StructType>(imported_schema->fields());
+
+  const std::string arrow_json = R"([
+    {"id": 1, "person": {"name": "Alice", "age": 25}},
+    {"id": 2, "person": {"name": "Bob", "age": 30}},
+    {"id": 3, "person": {"name": "Charlie", "age": 35}}
+  ])";
+  auto rows = ::arrow::json::ArrayFromJSONString(struct_type, arrow_json).ValueOrDie();
+
+  struct ExpectedData {
+    int32_t id;
+    std::string name;
+    int32_t age;
+  };
+  std::vector<ExpectedData> expected = {{.id = 1, .name = "Alice", .age = 25},
+                                        {.id = 2, .name = "Bob", .age = 30},
+                                        {.id = 3, .name = "Charlie", .age = 35}};
+
+  // Extracts one row into a fresh datum and compares outer and inner fields.
+  auto check_row = [&](int64_t row, const ExpectedData& want) {
+    ::avro::GenericDatum row_datum(root_node);
+    ASSERT_THAT(ExtractDatumFromArray(*rows, row, &row_datum), IsOk());
+    const auto& outer = row_datum.value<::avro::GenericRecord>();
+    EXPECT_EQ(outer.fieldAt(0).value<int32_t>(), want.id);
+    const auto& inner = outer.fieldAt(1).value<::avro::GenericRecord>();
+    EXPECT_EQ(inner.fieldAt(0).value<std::string>(), want.name);
+    EXPECT_EQ(inner.fieldAt(1).value<int32_t>(), want.age);
+  };
+
+  for (size_t row = 0; row < expected.size(); ++row) {
+    check_row(row, expected[row]);
+  }
+}
+
+// Extracts rows holding list<int> values of different lengths and checks every element.
+TEST(ExtractDatumFromArrayTest, ListOfIntegers) {
+  Schema iceberg_schema({
+      SchemaField::MakeRequired(
+          1, "numbers",
+          std::make_shared<ListType>(SchemaField::MakeRequired(2, "element", int32()))),
+  });
+
+  ::avro::NodePtr root_node;
+  ASSERT_THAT(ToAvroNodeVisitor{}.Visit(iceberg_schema, &root_node), IsOk());
+
+  ArrowSchema exported_schema;
+  ASSERT_THAT(ToArrowSchema(iceberg_schema, &exported_schema), IsOk());
+  auto imported_schema = ::arrow::ImportSchema(&exported_schema).ValueOrDie();
+  auto struct_type = std::make_shared<::arrow::StructType>(imported_schema->fields());
+
+  const std::string arrow_json = R"([
+    {"numbers": [10, 11, 12]},
+    {"numbers": [20, 21]},
+    {"numbers": [30, 31, 32, 33]}
+  ])";
+  auto rows = ::arrow::json::ArrayFromJSONString(struct_type, arrow_json).ValueOrDie();
+
+  std::vector<std::vector<int32_t>> expected = {
+      {10, 11, 12}, {20, 21}, {30, 31, 32, 33}};
+
+  // Extracts one row into a fresh datum and compares the list element by element.
+  auto check_row = [&](int64_t row, const std::vector<int32_t>& want) {
+    ::avro::GenericDatum row_datum(root_node);
+    ASSERT_THAT(ExtractDatumFromArray(*rows, row, &row_datum), IsOk());
+    const auto& fields = row_datum.value<::avro::GenericRecord>();
+    const auto& list_datum = fields.fieldAt(0).value<::avro::GenericArray>();
+    const auto& got = list_datum.value();
+
+    ASSERT_EQ(got.size(), want.size());
+    for (size_t pos = 0; pos < want.size(); ++pos) {
+      EXPECT_EQ(got[pos].value<int32_t>(), want[pos]);
+    }
+  };
+
+  for (size_t row = 0; row < expected.size(); ++row) {
+    check_row(row, expected[row]);
+  }
+}
+
+// Extracts rows holding map<string, int> values and checks keys and values in order.
+TEST(ExtractDatumFromArrayTest, MapStringToInt) {
+  Schema iceberg_schema({
+      SchemaField::MakeRequired(
+          1, "scores",
+          std::make_shared<MapType>(SchemaField::MakeRequired(2, "key", string()),
+                                    SchemaField::MakeRequired(3, "value", int32()))),
+  });
+
+  ::avro::NodePtr root_node;
+  ASSERT_THAT(ToAvroNodeVisitor{}.Visit(iceberg_schema, &root_node), IsOk());
+
+  ArrowSchema exported_schema;
+  ASSERT_THAT(ToArrowSchema(iceberg_schema, &exported_schema), IsOk());
+  auto imported_schema = ::arrow::ImportSchema(&exported_schema).ValueOrDie();
+  auto struct_type = std::make_shared<::arrow::StructType>(imported_schema->fields());
+
+  const std::string arrow_json = R"([
+    {"scores": [["alice", 95], ["bob", 87]]},
+    {"scores": [["charlie", 92], ["diana", 98], ["eve", 89]]},
+    {"scores": [["frank", 91]]}
+  ])";
+  auto rows = ::arrow::json::ArrayFromJSONString(struct_type, arrow_json).ValueOrDie();
+
+  using MapEntry = std::pair<std::string, int32_t>;
+  std::vector<std::vector<MapEntry>> expected = {
+      {{"alice", 95}, {"bob", 87}},
+      {{"charlie", 92}, {"diana", 98}, {"eve", 89}},
+      {{"frank", 91}}};
+
+  // Extracts one row into a fresh datum and compares the map entry by entry.
+  auto check_row = [&](int64_t row, const std::vector<MapEntry>& want) {
+    ::avro::GenericDatum row_datum(root_node);
+    ASSERT_THAT(ExtractDatumFromArray(*rows, row, &row_datum), IsOk());
+    const auto& fields = row_datum.value<::avro::GenericRecord>();
+    const auto& map_datum = fields.fieldAt(0).value<::avro::GenericMap>();
+    const auto& got = map_datum.value();
+
+    ASSERT_EQ(got.size(), want.size());
+    for (size_t pos = 0; pos < want.size(); ++pos) {
+      EXPECT_EQ(got[pos].first, want[pos].first);
+      EXPECT_EQ(got[pos].second.value<int32_t>(), want[pos].second);
+    }
+  };
+
+  for (size_t row = 0; row < expected.size(); ++row) {
+    check_row(row, expected[row]);
+  }
+}
+
+// Out-of-range indices (negative or equal to the array length) must be rejected.
+TEST(ExtractDatumFromArrayTest, ErrorHandling) {
+  Schema iceberg_schema({SchemaField::MakeRequired(1, "a", int32())});
+  ::avro::NodePtr root_node;
+  ASSERT_THAT(ToAvroNodeVisitor{}.Visit(iceberg_schema, &root_node), IsOk());
+
+  ArrowSchema exported_schema;
+  ASSERT_THAT(ToArrowSchema(iceberg_schema, &exported_schema), IsOk());
+  auto imported_schema = ::arrow::ImportSchema(&exported_schema).ValueOrDie();
+  auto struct_type = std::make_shared<::arrow::StructType>(imported_schema->fields());
+
+  auto rows =
+      ::arrow::json::ArrayFromJSONString(struct_type, R"([{"a": 1}, {"a": 2}, {"a": 3}])")
+          .ValueOrDie();
+
+  ::avro::GenericDatum target(root_node);
+
+  // Index below zero.
+  EXPECT_THAT(ExtractDatumFromArray(*rows, -1, &target),
+              HasErrorMessage("Cannot extract datum from array at index -1"));
+
+  // Index equal to the array length (three rows, so 3 is one past the end).
+  EXPECT_THAT(ExtractDatumFromArray(*rows, 3, &target),
+              HasErrorMessage("Cannot extract datum from array at index 3"));
+}
+
+// An optional field maps to an Avro union: a present value selects branch 1 and a
+// null selects branch 0.
+TEST(ExtractDatumFromArrayTest, NullHandling) {
+  Schema iceberg_schema({SchemaField::MakeOptional(1, "a", int32())});
+  ::avro::NodePtr root_node;
+  ASSERT_THAT(ToAvroNodeVisitor{}.Visit(iceberg_schema, &root_node), IsOk());
+
+  ArrowSchema exported_schema;
+  ASSERT_THAT(ToArrowSchema(iceberg_schema, &exported_schema), IsOk());
+  auto imported_schema = ::arrow::ImportSchema(&exported_schema).ValueOrDie();
+  auto struct_type = std::make_shared<::arrow::StructType>(imported_schema->fields());
+
+  auto rows =
+      ::arrow::json::ArrayFromJSONString(struct_type, R"([{"a": 42}, {"a": null}])")
+          .ValueOrDie();
+
+  // The same datum is reused for both rows.
+  ::avro::GenericDatum target(root_node);
+  ASSERT_THAT(ExtractDatumFromArray(*rows, 0, &target), IsOk());
+
+  const auto& present_row = target.value<::avro::GenericRecord>();
+  EXPECT_EQ(present_row.fieldAt(0).unionBranch(), 1);
+  EXPECT_EQ(present_row.fieldAt(0).type(), ::avro::AVRO_INT);
+  EXPECT_EQ(present_row.fieldAt(0).value<int32_t>(), 42);
+
+  ASSERT_THAT(ExtractDatumFromArray(*rows, 1, &target), IsOk());
+  const auto& null_row = target.value<::avro::GenericRecord>();
+  EXPECT_EQ(null_row.fieldAt(0).unionBranch(), 0);
+  EXPECT_EQ(null_row.fieldAt(0).type(), ::avro::AVRO_NULL);
+}
+
+struct RoundTripParam {  // One case consumed by VerifyRoundTripConversion.
+  std::string name;  // Case label; presumably used as the gtest parameter name - confirm.
+  Schema iceberg_schema;  // Schema converted to both the Avro node and the Arrow struct type.
+  std::string arrow_json;  // JSON rows used to build the original Arrow array.
+};
+
+// Arrow -> Avro -> Arrow round trip: every row of the JSON-built array is extracted
+// into an Avro datum, the datums are appended to a new Arrow builder, and the rebuilt
+// array must equal the original one.
+void VerifyRoundTripConversion(const RoundTripParam& test_case) {
+  ::avro::NodePtr root_node;
+  ASSERT_THAT(ToAvroNodeVisitor{}.Visit(test_case.iceberg_schema, &root_node), IsOk());
+
+  ArrowSchema exported_schema;
+  ASSERT_THAT(ToArrowSchema(test_case.iceberg_schema, &exported_schema), IsOk());
+  auto imported_schema = ::arrow::ImportSchema(&exported_schema).ValueOrDie();
+  auto struct_type = std::make_shared<::arrow::StructType>(imported_schema->fields());
+
+  auto source_rows =
+      ::arrow::json::ArrayFromJSONString(struct_type, test_case.arrow_json)
+          .ValueOrDie();
+
+  // Step 1: Arrow -> Avro, one fresh datum per row.
+  std::vector<::avro::GenericDatum> avro_rows;
+  for (int64_t row = 0; row < source_rows->length(); ++row) {
+    ::avro::GenericDatum row_datum(root_node);
+    ASSERT_THAT(ExtractDatumFromArray(*source_rows, row, &row_datum), IsOk())
+        << "Failed to extract datum at index " << row;
+    avro_rows.push_back(row_datum);
+  }
+
+  auto maybe_projection =
+      Project(test_case.iceberg_schema, root_node, /*prune_source=*/false);
+  ASSERT_THAT(maybe_projection, IsOk());
+  auto projection = std::move(maybe_projection.value());
+
+  // Step 2: Avro -> Arrow through the regular append path.
+  auto row_builder = ::arrow::MakeBuilder(struct_type).ValueOrDie();
+  for (const auto& row_datum : avro_rows) {
+    ASSERT_THAT(AppendDatumToBuilder(root_node, row_datum, projection,
+                                     test_case.iceberg_schema, row_builder.get()),
+                IsOk());
+  }
+
+  auto rebuilt_rows = row_builder->Finish().ValueOrDie();
+
+  ASSERT_TRUE(source_rows->Equals(*rebuilt_rows))
+      << "Round-trip consistency failed!\n"
+      << "Original array: " << source_rows->ToString() << "\n"
+      << "Rebuilt array:  " << rebuilt_rows->ToString();
+}
+
+// Fixture for the round-trip conversion tests; it adds no state of its own and
+// only binds the parameter type to RoundTripParam.
+class AvroRoundTripConversionTest : public 
::testing::TestWithParam<RoundTripParam> {};
+
+// Runs the round-trip check on the current parameter. ASSERT_NO_FATAL_FAILURE
+// stops this test body if the helper recorded a fatal failure.
+TEST_P(AvroRoundTripConversionTest, ConvertTypes) {
+  ASSERT_NO_FATAL_FAILURE(VerifyRoundTripConversion(GetParam()));
+}
+
+// Cases fed to AvroRoundTripConversionTest. Each entry pairs an Iceberg schema with
+// JSON rows for the matching Arrow struct array; map columns are written as lists
+// of [key, value] pairs. The entries cover required and optional primitives, nested
+// structs, lists, maps, and nulls at several nesting levels.
+const std::vector<RoundTripParam> kRoundTripTestCases = {
+    {
+        .name = "SimpleStruct",
+        .iceberg_schema = Schema({
+            SchemaField::MakeRequired(1, "id", int32()),
+            SchemaField::MakeRequired(2, "name", string()),
+            SchemaField::MakeOptional(3, "age", int32()),
+        }),
+        .arrow_json = R"([
+          {"id": 100, "name": "Alice", "age": 25},
+          {"id": 101, "name": "Bob", "age": null},
+          {"id": 102, "name": "Charlie", "age": 35}
+        ])",
+    },
+    {
+        .name = "PrimitiveTypes",
+        .iceberg_schema = Schema({
+            SchemaField::MakeRequired(1, "bool_field", boolean()),
+            SchemaField::MakeRequired(2, "int_field", int32()),
+            SchemaField::MakeRequired(3, "long_field", int64()),
+            SchemaField::MakeRequired(4, "float_field", float32()),
+            SchemaField::MakeRequired(5, "double_field", float64()),
+            SchemaField::MakeRequired(6, "string_field", string()),
+        }),
+        .arrow_json = R"([
+          {"bool_field": true, "int_field": 42, "long_field": 1000000, 
"float_field": 3.14, "double_field": 2.718281828, "string_field": "hello"},
+          {"bool_field": false, "int_field": -42, "long_field": -1000000, 
"float_field": -3.14, "double_field": -2.718281828, "string_field": "world"}
+        ])",
+    },
+    {
+        .name = "NestedStruct",
+        .iceberg_schema = Schema({
+            SchemaField::MakeRequired(1, "id", int32()),
+            SchemaField::MakeRequired(
+                2, "person",
+                std::make_shared<StructType>(std::vector<SchemaField>{
+                    SchemaField::MakeRequired(3, "name", string()),
+                    SchemaField::MakeRequired(4, "age", int32()),
+                })),
+        }),
+        .arrow_json = R"([
+          {"id": 1, "person": {"name": "Alice", "age": 30}},
+          {"id": 2, "person": {"name": "Bob", "age": 25}}
+        ])",
+    },
+    {
+        .name = "ListOfIntegers",
+        .iceberg_schema = Schema({
+            SchemaField::MakeRequired(
+                1, "numbers",
+                std::make_shared<ListType>(
+                    SchemaField::MakeRequired(2, "element", int32()))),
+        }),
+        .arrow_json = R"([
+          {"numbers": [1, 2, 3]},
+          {"numbers": [10, 20]},
+          {"numbers": []}
+        ])",
+    },
+    {
+        .name = "MapStringToInt",
+        .iceberg_schema = Schema({
+            SchemaField::MakeRequired(
+                1, "scores",
+                std::make_shared<MapType>(
+                    SchemaField::MakeRequired(2, "key", string()),
+                    SchemaField::MakeRequired(3, "value", int32()))),
+        }),
+        .arrow_json = R"([
+          {"scores": [["alice", 95], ["bob", 87]]},
+          {"scores": [["charlie", 92]]},
+          {"scores": []}
+        ])",
+    },
+    {
+        .name = "ComplexNested",
+        .iceberg_schema = Schema({
+            SchemaField::MakeRequired(
+                1, "data",
+                std::make_shared<StructType>(std::vector<SchemaField>{
+                    SchemaField::MakeRequired(2, "id", int32()),
+                    SchemaField::MakeRequired(
+                        3, "tags",
+                        std::make_shared<ListType>(
+                            SchemaField::MakeRequired(4, "element", 
string()))),
+                    SchemaField::MakeOptional(
+                        5, "metadata",
+                        std::make_shared<MapType>(
+                            SchemaField::MakeRequired(6, "key", string()),
+                            SchemaField::MakeRequired(7, "value", string()))),
+                })),
+        }),
+        .arrow_json = R"([
+           {"data": {"id": 1, "tags": ["tag1", "tag2"], "metadata": [["key1", 
"value1"]]}},
+           {"data": {"id": 2, "tags": [], "metadata": null}}
+         ])",
+    },
+    {
+        .name = "NullablePrimitives",
+        .iceberg_schema = Schema({
+            SchemaField::MakeOptional(1, "optional_bool", boolean()),
+            SchemaField::MakeOptional(2, "optional_int", int32()),
+            SchemaField::MakeOptional(3, "optional_long", int64()),
+            SchemaField::MakeOptional(4, "optional_string", string()),
+            SchemaField::MakeRequired(5, "required_id", int32()),
+        }),
+        .arrow_json = R"([
+           {"optional_bool": true, "optional_int": 42, "optional_long": 
1000000, "optional_string": "hello", "required_id": 1},
+           {"optional_bool": null, "optional_int": null, "optional_long": 
null, "optional_string": null, "required_id": 2},
+           {"optional_bool": false, "optional_int": null, "optional_long": 
2000000, "optional_string": null, "required_id": 3},
+           {"optional_bool": null, "optional_int": 123, "optional_long": null, 
"optional_string": "world", "required_id": 4}
+         ])",
+    },
+    {
+        .name = "NullableNestedStruct",
+        .iceberg_schema = Schema({
+            SchemaField::MakeRequired(1, "id", int32()),
+            SchemaField::MakeOptional(
+                2, "person",
+                std::make_shared<StructType>(std::vector<SchemaField>{
+                    SchemaField::MakeRequired(3, "name", string()),
+                    SchemaField::MakeOptional(4, "age", int32()),
+                    SchemaField::MakeOptional(5, "email", string()),
+                })),
+            SchemaField::MakeOptional(6, "department", string()),
+        }),
+        // NOTE(review): the first "email" value below reads "al...@example.com"; this
+        // looks like address obfuscation by the mail archive (presumably
+        // "alice@example.com") - confirm against the repository copy of this test.
+        .arrow_json = R"([
+           {"id": 1, "person": {"name": "Alice", "age": 30, "email": 
"al...@example.com"}, "department": "Engineering"},
+           {"id": 2, "person": null, "department": null},
+           {"id": 3, "person": {"name": "Bob", "age": null, "email": null}, 
"department": "Sales"},
+           {"id": 4, "person": {"name": "Charlie", "age": 25, "email": null}, 
"department": null}
+         ])",
+    },
+    {
+        .name = "NullableListElements",
+        .iceberg_schema = Schema({
+            SchemaField::MakeRequired(1, "id", int32()),
+            SchemaField::MakeOptional(
+                2, "numbers",
+                std::make_shared<ListType>(
+                    SchemaField::MakeOptional(3, "element", int32()))),
+            SchemaField::MakeRequired(
+                4, "tags",
+                std::make_shared<ListType>(
+                    SchemaField::MakeOptional(5, "element", string()))),
+        }),
+        .arrow_json = R"([
+           {"id": 1, "numbers": [1, null, 3], "tags": ["tag1", null, "tag3"]},
+           {"id": 2, "numbers": null, "tags": ["only_tag"]},
+           {"id": 3, "numbers": [null, null], "tags": [null, null, null]},
+           {"id": 4, "numbers": [], "tags": []}
+         ])",
+    },
+    {
+        .name = "NullableMapValues",
+        .iceberg_schema = Schema({
+            SchemaField::MakeRequired(1, "id", int32()),
+            SchemaField::MakeOptional(
+                2, "scores",
+                std::make_shared<MapType>(
+                    SchemaField::MakeRequired(3, "key", string()),
+                    SchemaField::MakeOptional(4, "value", int32()))),
+            SchemaField::MakeRequired(
+                5, "metadata",
+                std::make_shared<MapType>(
+                    SchemaField::MakeRequired(6, "key", string()),
+                    SchemaField::MakeOptional(7, "value", string()))),
+        }),
+        .arrow_json = R"([
+           {"id": 1, "scores": [["alice", 95], ["bob", null]], "metadata": 
[["key1", "value1"], ["key2", null]]},
+           {"id": 2, "scores": null, "metadata": [["key3", null]]},
+           {"id": 3, "scores": [["charlie", null], ["diana", 98]], "metadata": 
[]},
+           {"id": 4, "scores": [], "metadata": [["key4", null], ["key5", 
"value5"]]}
+         ])",
+    },
+    {
+        .name = "DeeplyNestedWithNulls",
+        .iceberg_schema = Schema({
+            SchemaField::MakeRequired(
+                1, "root",
+                std::make_shared<StructType>(std::vector<SchemaField>{
+                    SchemaField::MakeRequired(2, "id", int32()),
+                    SchemaField::MakeOptional(
+                        3, "nested",
+                        std::make_shared<StructType>(std::vector<SchemaField>{
+                            SchemaField::MakeOptional(4, "name", string()),
+                            SchemaField::MakeOptional(
+                                5, "values",
+                                std::make_shared<ListType>(
+                                    SchemaField::MakeOptional(6, "element", 
int32()))),
+                        })),
+                    SchemaField::MakeOptional(
+                        7, "tags",
+                        std::make_shared<ListType>(
+                            SchemaField::MakeOptional(8, "element", 
string()))),
+                })),
+        }),
+        .arrow_json = R"([
+           {"root": {"id": 1, "nested": {"name": "test", "values": [1, null, 
3]}, "tags": ["a", "b"]}},
+           {"root": {"id": 2, "nested": null, "tags": null}},
+           {"root": {"id": 3, "nested": {"name": null, "values": null}, 
"tags": [null, "c"]}},
+           {"root": {"id": 4, "nested": {"name": "empty", "values": []}, 
"tags": []}}
+         ])",
+    },
+    {
+        .name = "AllNullsVariations",
+        .iceberg_schema = Schema({
+            SchemaField::MakeOptional(1, "always_null", string()),
+            SchemaField::MakeOptional(2, "sometimes_null", int32()),
+            SchemaField::MakeOptional(
+                3, "nested_struct",
+                std::make_shared<StructType>(std::vector<SchemaField>{
+                    SchemaField::MakeOptional(4, "inner_null", string()),
+                    SchemaField::MakeRequired(5, "inner_required", boolean()),
+                })),
+            SchemaField::MakeRequired(6, "id", int32()),
+        }),
+        .arrow_json = R"([
+           {"always_null": null, "sometimes_null": 42, "nested_struct": 
{"inner_null": "value", "inner_required": true}, "id": 1},
+           {"always_null": null, "sometimes_null": null, "nested_struct": 
null, "id": 2},
+           {"always_null": null, "sometimes_null": 123, "nested_struct": 
{"inner_null": null, "inner_required": false}, "id": 3},
+           {"always_null": null, "sometimes_null": null, "nested_struct": 
{"inner_null": null, "inner_required": true}, "id": 4}
+         ])",
+    },
+};
+
+// Instantiates the round-trip suite once per entry of kRoundTripTestCases and
+// names each generated test after the entry's `name` field.
+INSTANTIATE_TEST_SUITE_P(
+    AllTypes, AvroRoundTripConversionTest, ::testing::ValuesIn(kRoundTripTestCases),
+    [](const ::testing::TestParamInfo<RoundTripParam>& param_info) -> std::string {
+      return param_info.param.name;
+    });
+
 }  // namespace iceberg::avro


Reply via email to