This is an automated email from the ASF dual-hosted git repository.
apitrou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/main by this push:
new 56951fee35 GH-39865: [C++] Strip extension metadata when importing a
registered extension (#39866)
56951fee35 is described below
commit 56951fee35c920ac898c2515896ff3bd752dde97
Author: Antoine Pitrou <[email protected]>
AuthorDate: Mon Feb 5 17:15:44 2024 +0100
GH-39865: [C++] Strip extension metadata when importing a registered
extension (#39866)
### Rationale for this change
When importing an extension type from the C Data Interface and the
extension type is registered, we would still leave the extension-related
metadata on the storage type.
### What changes are included in this PR?
Strip extension-related metadata from the storage type if we succeed in
recreating the extension type.
This matches the behavior of the IPC layer and allows for more exact
roundtripping.
### Are these changes tested?
Yes.
### Are there any user-facing changes?
No, unless people mistakenly rely on the presence of said metadata.
* Closes: #39865
Authored-by: Antoine Pitrou <[email protected]>
Signed-off-by: Antoine Pitrou <[email protected]>
---
cpp/src/arrow/c/bridge.cc | 13 +++++++++
cpp/src/arrow/c/bridge_test.cc | 48 +++++++++++++++++-----------
cpp/src/arrow/util/key_value_metadata.cc | 18 ++++++------
cpp/src/arrow/util/key_value_metadata.h | 11 ++++----
4 files changed, 59 insertions(+), 31 deletions(-)
diff --git a/cpp/src/arrow/c/bridge.cc b/cpp/src/arrow/c/bridge.cc
index 172ed8962c..9b165a10a6 100644
--- a/cpp/src/arrow/c/bridge.cc
+++ b/cpp/src/arrow/c/bridge.cc
@@ -914,6 +914,8 @@ struct DecodedMetadata {
std::shared_ptr<KeyValueMetadata> metadata;
std::string extension_name;
std::string extension_serialized;
+ int extension_name_index = -1; // index of extension_name in metadata
+ int extension_serialized_index = -1; // index of extension_serialized in metadata
};
Result<DecodedMetadata> DecodeMetadata(const char* metadata) {
@@ -956,8 +958,10 @@ Result<DecodedMetadata> DecodeMetadata(const char* metadata) {
RETURN_NOT_OK(read_string(&values[i]));
if (keys[i] == kExtensionTypeKeyName) {
decoded.extension_name = values[i];
+ decoded.extension_name_index = i;
} else if (keys[i] == kExtensionMetadataKeyName) {
decoded.extension_serialized = values[i];
+ decoded.extension_serialized_index = i;
}
}
decoded.metadata = key_value_metadata(std::move(keys), std::move(values));
@@ -1046,6 +1050,15 @@ struct SchemaImporter {
ARROW_ASSIGN_OR_RAISE(
type_, registered_ext_type->Deserialize(std::move(type_),
metadata_.extension_serialized));
+ // Strip the extension keys from the storage field's metadata, matching the
+ // IPC layer. The kExtensionMetadataKeyName entry may be absent, in which case
+ // its index keeps the -1 sentinel: never pass that sentinel to DeleteMany().
+ if (metadata_.extension_serialized_index >= 0) {
+   RETURN_NOT_OK(metadata_.metadata->DeleteMany(
+       {metadata_.extension_name_index, metadata_.extension_serialized_index}));
+ } else {
+   RETURN_NOT_OK(metadata_.metadata->Delete(metadata_.extension_name_index));
+ }
}
}
diff --git a/cpp/src/arrow/c/bridge_test.cc b/cpp/src/arrow/c/bridge_test.cc
index 321ec36c38..8b67027454 100644
--- a/cpp/src/arrow/c/bridge_test.cc
+++ b/cpp/src/arrow/c/bridge_test.cc
@@ -1872,7 +1872,7 @@ class TestSchemaImport : public ::testing::Test, public
SchemaStructBuilder {
ASSERT_TRUE(ArrowSchemaIsReleased(&c_struct_));
Reset(); // for further tests
cb.AssertCalled(); // was released
- AssertTypeEqual(*expected, *type);
+ AssertTypeEqual(*expected, *type, /*check_metadata=*/true);
}
void CheckImport(const std::shared_ptr<Field>& expected) {
@@ -1892,7 +1892,7 @@ class TestSchemaImport : public ::testing::Test, public
SchemaStructBuilder {
ASSERT_TRUE(ArrowSchemaIsReleased(&c_struct_));
Reset(); // for further tests
cb.AssertCalled(); // was released
- AssertSchemaEqual(*expected, *schema);
+ AssertSchemaEqual(*expected, *schema, /*check_metadata=*/true);
}
void CheckImportError() {
@@ -3571,7 +3571,7 @@ class TestSchemaRoundtrip : public ::testing::Test {
// Recreate the type
ASSERT_OK_AND_ASSIGN(actual, ImportType(&c_schema));
type = factory_expected();
- AssertTypeEqual(*type, *actual);
+ AssertTypeEqual(*type, *actual, /*check_metadata=*/true);
type.reset();
actual.reset();
@@ -3602,7 +3602,7 @@ class TestSchemaRoundtrip : public ::testing::Test {
// Recreate the schema
ASSERT_OK_AND_ASSIGN(actual, ImportSchema(&c_schema));
schema = factory();
- AssertSchemaEqual(*schema, *actual);
+ AssertSchemaEqual(*schema, *actual, /*check_metadata=*/true);
schema.reset();
actual.reset();
@@ -3695,13 +3695,27 @@ TEST_F(TestSchemaRoundtrip, Dictionary) {
}
}
+// Given an extension type, return a field of its storage type + the
+// serialized extension metadata.
+std::shared_ptr<Field> GetStorageWithMetadata(const std::string& field_name,
+ const std::shared_ptr<DataType>&
type) {
+ const auto& ext_type = checked_cast<const ExtensionType&>(*type);
+ auto storage_type = ext_type.storage_type();
+ auto md = KeyValueMetadata::Make({kExtensionTypeKeyName,
kExtensionMetadataKeyName},
+ {ext_type.extension_name(),
ext_type.Serialize()});
+ return field(field_name, storage_type, /*nullable=*/true, md);
+}
+
TEST_F(TestSchemaRoundtrip, UnregisteredExtension) {
TestWithTypeFactory(uuid, []() { return fixed_size_binary(16); });
TestWithTypeFactory(dict_extension_type, []() { return dictionary(int8(),
utf8()); });
- // Inside nested type
- TestWithTypeFactory([]() { return list(dict_extension_type()); },
- []() { return list(dictionary(int8(), utf8())); });
+ // Inside nested type.
+ // When an extension type is not known by the importer, it is imported
+ // as its storage type and the extension metadata is preserved on the field.
+ TestWithTypeFactory(
+ []() { return list(dict_extension_type()); },
+ []() { return list(GetStorageWithMetadata("item",
dict_extension_type())); });
}
TEST_F(TestSchemaRoundtrip, RegisteredExtension) {
@@ -3710,7 +3724,9 @@ TEST_F(TestSchemaRoundtrip, RegisteredExtension) {
TestWithTypeFactory(dict_extension_type);
TestWithTypeFactory(complex128);
- // Inside nested type
+ // Inside nested type.
+ // When the extension type is registered, the extension metadata is removed
+ // from the storage type's field to ensure roundtripping (GH-39865).
TestWithTypeFactory([]() { return list(uuid()); });
TestWithTypeFactory([]() { return list(dict_extension_type()); });
TestWithTypeFactory([]() { return list(complex128()); });
@@ -3810,7 +3826,7 @@ class TestArrayRoundtrip : public ::testing::Test {
{
std::shared_ptr<Array> expected;
ASSERT_OK_AND_ASSIGN(expected, ToResult(factory_expected()));
- AssertTypeEqual(*expected->type(), *array->type());
+ AssertTypeEqual(*expected->type(), *array->type(),
/*check_metadata=*/true);
AssertArraysEqual(*expected, *array, true);
}
array.reset();
@@ -3850,7 +3866,7 @@ class TestArrayRoundtrip : public ::testing::Test {
{
std::shared_ptr<RecordBatch> expected;
ASSERT_OK_AND_ASSIGN(expected, ToResult(factory()));
- AssertSchemaEqual(*expected->schema(), *batch->schema());
+ AssertSchemaEqual(*expected->schema(), *batch->schema(),
/*check_metadata=*/true);
AssertBatchesEqual(*expected, *batch);
}
batch.reset();
@@ -4230,7 +4246,7 @@ class TestDeviceArrayRoundtrip : public ::testing::Test {
{
std::shared_ptr<Array> expected;
ASSERT_OK_AND_ASSIGN(expected, ToResult(factory_expected()));
- AssertTypeEqual(*expected->type(), *array->type());
+ AssertTypeEqual(*expected->type(), *array->type(),
/*check_metadata=*/true);
AssertArraysEqual(*expected, *array, true);
}
array.reset();
@@ -4276,7 +4292,7 @@ class TestDeviceArrayRoundtrip : public ::testing::Test {
{
std::shared_ptr<RecordBatch> expected;
ASSERT_OK_AND_ASSIGN(expected, ToResult(factory()));
- AssertSchemaEqual(*expected->schema(), *batch->schema());
+ AssertSchemaEqual(*expected->schema(), *batch->schema(),
/*check_metadata=*/true);
AssertBatchesEqual(*expected, *batch);
}
batch.reset();
@@ -4353,7 +4369,7 @@ class TestArrayStreamExport : public BaseArrayStreamTest {
SchemaExportGuard schema_guard(&c_schema);
ASSERT_FALSE(ArrowSchemaIsReleased(&c_schema));
ASSERT_OK_AND_ASSIGN(auto schema, ImportSchema(&c_schema));
- AssertSchemaEqual(expected, *schema);
+ AssertSchemaEqual(expected, *schema, /*check_metadata=*/true);
}
void AssertStreamEnd(struct ArrowArrayStream* c_stream) {
@@ -4437,7 +4453,7 @@ TEST_F(TestArrayStreamExport, ArrayLifetime) {
{
SchemaExportGuard schema_guard(&c_schema);
ASSERT_OK_AND_ASSIGN(auto got_schema, ImportSchema(&c_schema));
- AssertSchemaEqual(*schema, *got_schema);
+ AssertSchemaEqual(*schema, *got_schema, /*check_metadata=*/true);
}
ASSERT_GT(pool_->bytes_allocated(), orig_allocated_);
@@ -4462,7 +4478,7 @@ TEST_F(TestArrayStreamExport, Errors) {
{
SchemaExportGuard schema_guard(&c_schema);
ASSERT_OK_AND_ASSIGN(auto schema, ImportSchema(&c_schema));
- AssertSchemaEqual(schema, arrow::schema({}));
+ AssertSchemaEqual(schema, arrow::schema({}), /*check_metadata=*/true);
}
struct ArrowArray c_array;
@@ -4539,7 +4555,7 @@ TEST_F(TestArrayStreamRoundtrip, Simple) {
ASSERT_OK_AND_ASSIGN(auto reader, RecordBatchReader::Make(batches,
orig_schema));
Roundtrip(std::move(reader), [&](const std::shared_ptr<RecordBatchReader>&
reader) {
- AssertSchemaEqual(*orig_schema, *reader->schema());
+ AssertSchemaEqual(*orig_schema, *reader->schema(),
/*check_metadata=*/true);
AssertReaderNext(reader, *batches[0]);
AssertReaderNext(reader, *batches[1]);
AssertReaderEnd(reader);
diff --git a/cpp/src/arrow/util/key_value_metadata.cc
b/cpp/src/arrow/util/key_value_metadata.cc
index bc48ae76c2..002e8b0975 100644
--- a/cpp/src/arrow/util/key_value_metadata.cc
+++ b/cpp/src/arrow/util/key_value_metadata.cc
@@ -90,7 +90,7 @@ void KeyValueMetadata::Append(std::string key, std::string
value) {
values_.push_back(std::move(value));
}
-Result<std::string> KeyValueMetadata::Get(const std::string& key) const {
+Result<std::string> KeyValueMetadata::Get(std::string_view key) const {
auto index = FindKey(key);
if (index < 0) {
return Status::KeyError(key);
@@ -129,7 +129,7 @@ Status KeyValueMetadata::DeleteMany(std::vector<int64_t>
indices) {
return Status::OK();
}
-Status KeyValueMetadata::Delete(const std::string& key) {
+Status KeyValueMetadata::Delete(std::string_view key) {
auto index = FindKey(key);
if (index < 0) {
return Status::KeyError(key);
@@ -138,20 +138,18 @@ Status KeyValueMetadata::Delete(const std::string& key) {
}
}
-Status KeyValueMetadata::Set(const std::string& key, const std::string& value)
{
+Status KeyValueMetadata::Set(std::string key, std::string value) {
auto index = FindKey(key);
if (index < 0) {
- Append(key, value);
+ Append(std::move(key), std::move(value));
} else {
- keys_[index] = key;
- values_[index] = value;
+ keys_[index] = std::move(key);
+ values_[index] = std::move(value);
}
return Status::OK();
}
-bool KeyValueMetadata::Contains(const std::string& key) const {
- return FindKey(key) >= 0;
-}
+bool KeyValueMetadata::Contains(std::string_view key) const { return
FindKey(key) >= 0; }
void KeyValueMetadata::reserve(int64_t n) {
DCHECK_GE(n, 0);
@@ -188,7 +186,7 @@ std::vector<std::pair<std::string, std::string>>
KeyValueMetadata::sorted_pairs(
return pairs;
}
-int KeyValueMetadata::FindKey(const std::string& key) const {
+int KeyValueMetadata::FindKey(std::string_view key) const {
for (size_t i = 0; i < keys_.size(); ++i) {
if (keys_[i] == key) {
return static_cast<int>(i);
diff --git a/cpp/src/arrow/util/key_value_metadata.h
b/cpp/src/arrow/util/key_value_metadata.h
index 8702ce73a6..57ade11e75 100644
--- a/cpp/src/arrow/util/key_value_metadata.h
+++ b/cpp/src/arrow/util/key_value_metadata.h
@@ -20,6 +20,7 @@
#include <cstdint>
#include <memory>
#include <string>
+#include <string_view>
#include <unordered_map>
#include <utility>
#include <vector>
@@ -44,13 +45,13 @@ class ARROW_EXPORT KeyValueMetadata {
void ToUnorderedMap(std::unordered_map<std::string, std::string>* out) const;
void Append(std::string key, std::string value);
- Result<std::string> Get(const std::string& key) const;
- bool Contains(const std::string& key) const;
+ Result<std::string> Get(std::string_view key) const;
+ bool Contains(std::string_view key) const;
// Note that deleting may invalidate known indices
- Status Delete(const std::string& key);
+ Status Delete(std::string_view key);
Status Delete(int64_t index);
Status DeleteMany(std::vector<int64_t> indices);
- Status Set(const std::string& key, const std::string& value);
+ Status Set(std::string key, std::string value);
void reserve(int64_t n);
@@ -63,7 +64,7 @@ class ARROW_EXPORT KeyValueMetadata {
std::vector<std::pair<std::string, std::string>> sorted_pairs() const;
/// \brief Perform linear search for key, returning -1 if not found
- int FindKey(const std::string& key) const;
+ int FindKey(std::string_view key) const;
std::shared_ptr<KeyValueMetadata> Copy() const;