This is an automated email from the ASF dual-hosted git repository.

raulcd pushed a commit to branch maint-15.0.x
in repository https://gitbox.apache.org/repos/asf/arrow.git

commit 91be098b56021b1f9569986b038bd46c3ed53701
Author: Antoine Pitrou <[email protected]>
AuthorDate: Mon Feb 5 17:15:44 2024 +0100

    GH-39865: [C++] Strip extension metadata when importing a registered extension (#39866)
    
    ### Rationale for this change
    
    When an extension type was imported from the C Data Interface and that
    extension type was registered, the extension-related metadata was still
    left on the storage type.
    
    ### What changes are included in this PR?
    
    Strip extension-related metadata from the storage type if we succeed in
    recreating the extension type.
    This matches the behavior of the IPC layer and allows for more exact
    roundtripping.
    
    ### Are these changes tested?
    
    Yes.
    
    ### Are there any user-facing changes?
    
    No, unless people mistakenly rely on the presence of said metadata.
    * Closes: #39865
    
    Authored-by: Antoine Pitrou <[email protected]>
    Signed-off-by: Antoine Pitrou <[email protected]>
---
 cpp/src/arrow/c/bridge.cc                |  6 ++++
 cpp/src/arrow/c/bridge_test.cc           | 48 +++++++++++++++++++++-----------
 cpp/src/arrow/util/key_value_metadata.cc | 18 ++++++------
 cpp/src/arrow/util/key_value_metadata.h  | 11 ++++----
 4 files changed, 52 insertions(+), 31 deletions(-)

diff --git a/cpp/src/arrow/c/bridge.cc b/cpp/src/arrow/c/bridge.cc
index 238afb0328..4751f65632 100644
--- a/cpp/src/arrow/c/bridge.cc
+++ b/cpp/src/arrow/c/bridge.cc
@@ -914,6 +914,8 @@ struct DecodedMetadata {
   std::shared_ptr<KeyValueMetadata> metadata;
   std::string extension_name;
   std::string extension_serialized;
+  int extension_name_index = -1;        // index of extension_name in metadata
+  int extension_serialized_index = -1;  // index of extension_serialized in 
metadata
 };
 
 Result<DecodedMetadata> DecodeMetadata(const char* metadata) {
@@ -956,8 +958,10 @@ Result<DecodedMetadata> DecodeMetadata(const char* 
metadata) {
     RETURN_NOT_OK(read_string(&values[i]));
     if (keys[i] == kExtensionTypeKeyName) {
       decoded.extension_name = values[i];
+      decoded.extension_name_index = i;
     } else if (keys[i] == kExtensionMetadataKeyName) {
       decoded.extension_serialized = values[i];
+      decoded.extension_serialized_index = i;
     }
   }
   decoded.metadata = key_value_metadata(std::move(keys), std::move(values));
@@ -1046,6 +1050,8 @@ struct SchemaImporter {
         ARROW_ASSIGN_OR_RAISE(
             type_, registered_ext_type->Deserialize(std::move(type_),
                                                     
metadata_.extension_serialized));
+        RETURN_NOT_OK(metadata_.metadata->DeleteMany(
+            {metadata_.extension_name_index, 
metadata_.extension_serialized_index}));
       }
     }
 
diff --git a/cpp/src/arrow/c/bridge_test.cc b/cpp/src/arrow/c/bridge_test.cc
index 58bbc9282c..5dcb38185f 100644
--- a/cpp/src/arrow/c/bridge_test.cc
+++ b/cpp/src/arrow/c/bridge_test.cc
@@ -1870,7 +1870,7 @@ class TestSchemaImport : public ::testing::Test, public 
SchemaStructBuilder {
     ASSERT_TRUE(ArrowSchemaIsReleased(&c_struct_));
     Reset();            // for further tests
     cb.AssertCalled();  // was released
-    AssertTypeEqual(*expected, *type);
+    AssertTypeEqual(*expected, *type, /*check_metadata=*/true);
   }
 
   void CheckImport(const std::shared_ptr<Field>& expected) {
@@ -1890,7 +1890,7 @@ class TestSchemaImport : public ::testing::Test, public 
SchemaStructBuilder {
     ASSERT_TRUE(ArrowSchemaIsReleased(&c_struct_));
     Reset();            // for further tests
     cb.AssertCalled();  // was released
-    AssertSchemaEqual(*expected, *schema);
+    AssertSchemaEqual(*expected, *schema, /*check_metadata=*/true);
   }
 
   void CheckImportError() {
@@ -3569,7 +3569,7 @@ class TestSchemaRoundtrip : public ::testing::Test {
     // Recreate the type
     ASSERT_OK_AND_ASSIGN(actual, ImportType(&c_schema));
     type = factory_expected();
-    AssertTypeEqual(*type, *actual);
+    AssertTypeEqual(*type, *actual, /*check_metadata=*/true);
     type.reset();
     actual.reset();
 
@@ -3600,7 +3600,7 @@ class TestSchemaRoundtrip : public ::testing::Test {
     // Recreate the schema
     ASSERT_OK_AND_ASSIGN(actual, ImportSchema(&c_schema));
     schema = factory();
-    AssertSchemaEqual(*schema, *actual);
+    AssertSchemaEqual(*schema, *actual, /*check_metadata=*/true);
     schema.reset();
     actual.reset();
 
@@ -3693,13 +3693,27 @@ TEST_F(TestSchemaRoundtrip, Dictionary) {
   }
 }
 
+// Given an extension type, return a field of its storage type + the
+// serialized extension metadata.
+std::shared_ptr<Field> GetStorageWithMetadata(const std::string& field_name,
+                                              const std::shared_ptr<DataType>& 
type) {
+  const auto& ext_type = checked_cast<const ExtensionType&>(*type);
+  auto storage_type = ext_type.storage_type();
+  auto md = KeyValueMetadata::Make({kExtensionTypeKeyName, 
kExtensionMetadataKeyName},
+                                   {ext_type.extension_name(), 
ext_type.Serialize()});
+  return field(field_name, storage_type, /*nullable=*/true, md);
+}
+
 TEST_F(TestSchemaRoundtrip, UnregisteredExtension) {
   TestWithTypeFactory(uuid, []() { return fixed_size_binary(16); });
   TestWithTypeFactory(dict_extension_type, []() { return dictionary(int8(), 
utf8()); });
 
-  // Inside nested type
-  TestWithTypeFactory([]() { return list(dict_extension_type()); },
-                      []() { return list(dictionary(int8(), utf8())); });
+  // Inside nested type.
+  // When an extension type is not known by the importer, it is imported
+  // as its storage type and the extension metadata is preserved on the field.
+  TestWithTypeFactory(
+      []() { return list(dict_extension_type()); },
+      []() { return list(GetStorageWithMetadata("item", 
dict_extension_type())); });
 }
 
 TEST_F(TestSchemaRoundtrip, RegisteredExtension) {
@@ -3708,7 +3722,9 @@ TEST_F(TestSchemaRoundtrip, RegisteredExtension) {
   TestWithTypeFactory(dict_extension_type);
   TestWithTypeFactory(complex128);
 
-  // Inside nested type
+  // Inside nested type.
+  // When the extension type is registered, the extension metadata is removed
+  // from the storage type's field to ensure roundtripping (GH-39865).
   TestWithTypeFactory([]() { return list(uuid()); });
   TestWithTypeFactory([]() { return list(dict_extension_type()); });
   TestWithTypeFactory([]() { return list(complex128()); });
@@ -3808,7 +3824,7 @@ class TestArrayRoundtrip : public ::testing::Test {
     {
       std::shared_ptr<Array> expected;
       ASSERT_OK_AND_ASSIGN(expected, ToResult(factory_expected()));
-      AssertTypeEqual(*expected->type(), *array->type());
+      AssertTypeEqual(*expected->type(), *array->type(), 
/*check_metadata=*/true);
       AssertArraysEqual(*expected, *array, true);
     }
     array.reset();
@@ -3848,7 +3864,7 @@ class TestArrayRoundtrip : public ::testing::Test {
     {
       std::shared_ptr<RecordBatch> expected;
       ASSERT_OK_AND_ASSIGN(expected, ToResult(factory()));
-      AssertSchemaEqual(*expected->schema(), *batch->schema());
+      AssertSchemaEqual(*expected->schema(), *batch->schema(), 
/*check_metadata=*/true);
       AssertBatchesEqual(*expected, *batch);
     }
     batch.reset();
@@ -4228,7 +4244,7 @@ class TestDeviceArrayRoundtrip : public ::testing::Test {
     {
       std::shared_ptr<Array> expected;
       ASSERT_OK_AND_ASSIGN(expected, ToResult(factory_expected()));
-      AssertTypeEqual(*expected->type(), *array->type());
+      AssertTypeEqual(*expected->type(), *array->type(), 
/*check_metadata=*/true);
       AssertArraysEqual(*expected, *array, true);
     }
     array.reset();
@@ -4274,7 +4290,7 @@ class TestDeviceArrayRoundtrip : public ::testing::Test {
     {
       std::shared_ptr<RecordBatch> expected;
       ASSERT_OK_AND_ASSIGN(expected, ToResult(factory()));
-      AssertSchemaEqual(*expected->schema(), *batch->schema());
+      AssertSchemaEqual(*expected->schema(), *batch->schema(), 
/*check_metadata=*/true);
       AssertBatchesEqual(*expected, *batch);
     }
     batch.reset();
@@ -4351,7 +4367,7 @@ class TestArrayStreamExport : public BaseArrayStreamTest {
     SchemaExportGuard schema_guard(&c_schema);
     ASSERT_FALSE(ArrowSchemaIsReleased(&c_schema));
     ASSERT_OK_AND_ASSIGN(auto schema, ImportSchema(&c_schema));
-    AssertSchemaEqual(expected, *schema);
+    AssertSchemaEqual(expected, *schema, /*check_metadata=*/true);
   }
 
   void AssertStreamEnd(struct ArrowArrayStream* c_stream) {
@@ -4435,7 +4451,7 @@ TEST_F(TestArrayStreamExport, ArrayLifetime) {
   {
     SchemaExportGuard schema_guard(&c_schema);
     ASSERT_OK_AND_ASSIGN(auto got_schema, ImportSchema(&c_schema));
-    AssertSchemaEqual(*schema, *got_schema);
+    AssertSchemaEqual(*schema, *got_schema, /*check_metadata=*/true);
   }
 
   ASSERT_GT(pool_->bytes_allocated(), orig_allocated_);
@@ -4460,7 +4476,7 @@ TEST_F(TestArrayStreamExport, Errors) {
   {
     SchemaExportGuard schema_guard(&c_schema);
     ASSERT_OK_AND_ASSIGN(auto schema, ImportSchema(&c_schema));
-    AssertSchemaEqual(schema, arrow::schema({}));
+    AssertSchemaEqual(schema, arrow::schema({}), /*check_metadata=*/true);
   }
 
   struct ArrowArray c_array;
@@ -4537,7 +4553,7 @@ TEST_F(TestArrayStreamRoundtrip, Simple) {
   ASSERT_OK_AND_ASSIGN(auto reader, RecordBatchReader::Make(batches, 
orig_schema));
 
   Roundtrip(std::move(reader), [&](const std::shared_ptr<RecordBatchReader>& 
reader) {
-    AssertSchemaEqual(*orig_schema, *reader->schema());
+    AssertSchemaEqual(*orig_schema, *reader->schema(), 
/*check_metadata=*/true);
     AssertReaderNext(reader, *batches[0]);
     AssertReaderNext(reader, *batches[1]);
     AssertReaderEnd(reader);
diff --git a/cpp/src/arrow/util/key_value_metadata.cc 
b/cpp/src/arrow/util/key_value_metadata.cc
index bc48ae76c2..002e8b0975 100644
--- a/cpp/src/arrow/util/key_value_metadata.cc
+++ b/cpp/src/arrow/util/key_value_metadata.cc
@@ -90,7 +90,7 @@ void KeyValueMetadata::Append(std::string key, std::string 
value) {
   values_.push_back(std::move(value));
 }
 
-Result<std::string> KeyValueMetadata::Get(const std::string& key) const {
+Result<std::string> KeyValueMetadata::Get(std::string_view key) const {
   auto index = FindKey(key);
   if (index < 0) {
     return Status::KeyError(key);
@@ -129,7 +129,7 @@ Status KeyValueMetadata::DeleteMany(std::vector<int64_t> 
indices) {
   return Status::OK();
 }
 
-Status KeyValueMetadata::Delete(const std::string& key) {
+Status KeyValueMetadata::Delete(std::string_view key) {
   auto index = FindKey(key);
   if (index < 0) {
     return Status::KeyError(key);
@@ -138,20 +138,18 @@ Status KeyValueMetadata::Delete(const std::string& key) {
   }
 }
 
-Status KeyValueMetadata::Set(const std::string& key, const std::string& value) 
{
+Status KeyValueMetadata::Set(std::string key, std::string value) {
   auto index = FindKey(key);
   if (index < 0) {
-    Append(key, value);
+    Append(std::move(key), std::move(value));
   } else {
-    keys_[index] = key;
-    values_[index] = value;
+    keys_[index] = std::move(key);
+    values_[index] = std::move(value);
   }
   return Status::OK();
 }
 
-bool KeyValueMetadata::Contains(const std::string& key) const {
-  return FindKey(key) >= 0;
-}
+bool KeyValueMetadata::Contains(std::string_view key) const { return 
FindKey(key) >= 0; }
 
 void KeyValueMetadata::reserve(int64_t n) {
   DCHECK_GE(n, 0);
@@ -188,7 +186,7 @@ std::vector<std::pair<std::string, std::string>> 
KeyValueMetadata::sorted_pairs(
   return pairs;
 }
 
-int KeyValueMetadata::FindKey(const std::string& key) const {
+int KeyValueMetadata::FindKey(std::string_view key) const {
   for (size_t i = 0; i < keys_.size(); ++i) {
     if (keys_[i] == key) {
       return static_cast<int>(i);
diff --git a/cpp/src/arrow/util/key_value_metadata.h 
b/cpp/src/arrow/util/key_value_metadata.h
index 8702ce73a6..57ade11e75 100644
--- a/cpp/src/arrow/util/key_value_metadata.h
+++ b/cpp/src/arrow/util/key_value_metadata.h
@@ -20,6 +20,7 @@
 #include <cstdint>
 #include <memory>
 #include <string>
+#include <string_view>
 #include <unordered_map>
 #include <utility>
 #include <vector>
@@ -44,13 +45,13 @@ class ARROW_EXPORT KeyValueMetadata {
   void ToUnorderedMap(std::unordered_map<std::string, std::string>* out) const;
   void Append(std::string key, std::string value);
 
-  Result<std::string> Get(const std::string& key) const;
-  bool Contains(const std::string& key) const;
+  Result<std::string> Get(std::string_view key) const;
+  bool Contains(std::string_view key) const;
   // Note that deleting may invalidate known indices
-  Status Delete(const std::string& key);
+  Status Delete(std::string_view key);
   Status Delete(int64_t index);
   Status DeleteMany(std::vector<int64_t> indices);
-  Status Set(const std::string& key, const std::string& value);
+  Status Set(std::string key, std::string value);
 
   void reserve(int64_t n);
 
@@ -63,7 +64,7 @@ class ARROW_EXPORT KeyValueMetadata {
   std::vector<std::pair<std::string, std::string>> sorted_pairs() const;
 
   /// \brief Perform linear search for key, returning -1 if not found
-  int FindKey(const std::string& key) const;
+  int FindKey(std::string_view key) const;
 
   std::shared_ptr<KeyValueMetadata> Copy() const;
 

Reply via email to