jorisvandenbossche commented on code in PR #8510:
URL: https://github.com/apache/arrow/pull/8510#discussion_r1130680231
##########
cpp/src/arrow/extension/fixed_shape_tensor.cc:
##########
@@ -49,11 +49,23 @@ bool FixedShapeTensorType::ExtensionEquals(const
ExtensionType& other) const {
return false;
}
const auto& other_ext = static_cast<const FixedShapeTensorType&>(other);
- bool equals = storage_type()->Equals(other_ext.storage_type());
- equals &= shape_ == other_ext.shape();
- equals &= permutation_ == other_ext.permutation();
- equals &= dim_names_ == other_ext.dim_names();
- return equals;
+
+ auto is_permutation_trivial = [](const std::vector<int64_t>& permutation) {
+ for (size_t i = 1; i < permutation.size(); ++i) {
+ if (permutation[i - 1] + 1 != permutation[i]) {
+ return false;
+ }
+ }
+ return true;
+ };
+ const bool permutation_equivalent =
+ (permutation_ == other_ext.permutation()) ||
+ ((permutation_.empty() &&
is_permutation_trivial(other_ext.permutation())) &&
+ (is_permutation_trivial(permutation_) ||
other_ext.permutation().empty()));
Review Comment:
Did you add tests for this? (I don't see any in the commit that added this.)
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]