This is an automated email from the ASF dual-hosted git repository.
rok pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/main by this push:
new 2fcc3ec9f0 GH-38007: [C++] Add VariableShapeTensor implementation
(#38008)
2fcc3ec9f0 is described below
commit 2fcc3ec9f04cfe73facef5155e831dac19ead853
Author: Rok Mihevc <[email protected]>
AuthorDate: Tue Feb 24 19:42:36 2026 +0100
GH-38007: [C++] Add VariableShapeTensor implementation (#38008)
### Rationale for this change
We want to add VariableShapeTensor extension type definition for arrays
containing tensors with variable shapes.
### What changes are included in this PR?
This adds a C++ implementation.
### Are these changes tested?
Yes.
### Are there any user-facing changes?
This adds a new extension type to C++.
* Closes: #38007
* GitHub Issue: #38007
Lead-authored-by: Rok Mihevc <[email protected]>
Co-authored-by: Joris Van den Bossche <[email protected]>
Co-authored-by: Antoine Pitrou <[email protected]>
Signed-off-by: Rok Mihevc <[email protected]>
---
cpp/src/arrow/CMakeLists.txt | 2 +
cpp/src/arrow/extension/CMakeLists.txt | 2 +-
cpp/src/arrow/extension/fixed_shape_tensor.cc | 102 ++----
cpp/src/arrow/extension/fixed_shape_tensor.h | 7 +-
...nsor_test.cc => tensor_extension_array_test.cc} | 349 +++++++++++++++++++--
cpp/src/arrow/extension/tensor_internal.cc | 131 ++++++++
cpp/src/arrow/extension/tensor_internal.h | 35 ++-
cpp/src/arrow/extension/variable_shape_tensor.cc | 325 +++++++++++++++++++
cpp/src/arrow/extension/variable_shape_tensor.h | 111 +++++++
cpp/src/arrow/extension_type.cc | 2 +
cpp/src/arrow/extension_type_test.cc | 4 +-
docs/source/format/CanonicalExtensions.rst | 4 +-
python/pyarrow/tests/test_extension_type.py | 11 +
13 files changed, 944 insertions(+), 141 deletions(-)
diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt
index 6e9d76a61e..eee63b11ca 100644
--- a/cpp/src/arrow/CMakeLists.txt
+++ b/cpp/src/arrow/CMakeLists.txt
@@ -995,6 +995,8 @@ if(ARROW_JSON)
arrow_add_object_library(ARROW_JSON
extension/fixed_shape_tensor.cc
extension/opaque.cc
+ extension/tensor_internal.cc
+ extension/variable_shape_tensor.cc
json/options.cc
json/chunked_builder.cc
json/chunker.cc
diff --git a/cpp/src/arrow/extension/CMakeLists.txt
b/cpp/src/arrow/extension/CMakeLists.txt
index 4ab6a35b52..ae52bc32a9 100644
--- a/cpp/src/arrow/extension/CMakeLists.txt
+++ b/cpp/src/arrow/extension/CMakeLists.txt
@@ -18,7 +18,7 @@
set(CANONICAL_EXTENSION_TESTS bool8_test.cc json_test.cc uuid_test.cc)
if(ARROW_JSON)
- list(APPEND CANONICAL_EXTENSION_TESTS fixed_shape_tensor_test.cc
opaque_test.cc)
+ list(APPEND CANONICAL_EXTENSION_TESTS tensor_extension_array_test.cc
opaque_test.cc)
endif()
add_arrow_test(test
diff --git a/cpp/src/arrow/extension/fixed_shape_tensor.cc
b/cpp/src/arrow/extension/fixed_shape_tensor.cc
index bb7082e697..5be855ffcb 100644
--- a/cpp/src/arrow/extension/fixed_shape_tensor.cc
+++ b/cpp/src/arrow/extension/fixed_shape_tensor.cc
@@ -26,7 +26,6 @@
#include "arrow/array/array_primitive.h"
#include "arrow/json/rapidjson_defs.h" // IWYU pragma: keep
#include "arrow/tensor.h"
-#include "arrow/util/int_util_overflow.h"
#include "arrow/util/logging_internal.h"
#include "arrow/util/print_internal.h"
#include "arrow/util/sort_internal.h"
@@ -37,52 +36,7 @@
namespace rj = arrow::rapidjson;
-namespace arrow {
-
-namespace extension {
-
-namespace {
-
-Status ComputeStrides(const FixedWidthType& type, const std::vector<int64_t>&
shape,
- const std::vector<int64_t>& permutation,
- std::vector<int64_t>* strides) {
- if (permutation.empty()) {
- return internal::ComputeRowMajorStrides(type, shape, strides);
- }
-
- const int byte_width = type.byte_width();
-
- int64_t remaining = 0;
- if (!shape.empty() && shape.front() > 0) {
- remaining = byte_width;
- for (auto i : permutation) {
- if (i > 0) {
- if (internal::MultiplyWithOverflow(remaining, shape[i], &remaining)) {
- return Status::Invalid(
- "Strides computed from shape would not fit in 64-bit integer");
- }
- }
- }
- }
-
- if (remaining == 0) {
- strides->assign(shape.size(), byte_width);
- return Status::OK();
- }
-
- strides->push_back(remaining);
- for (auto i : permutation) {
- if (i > 0) {
- remaining /= shape[i];
- strides->push_back(remaining);
- }
- }
- internal::Permute(permutation, strides);
-
- return Status::OK();
-}
-
-} // namespace
+namespace arrow::extension {
bool FixedShapeTensorType::ExtensionEquals(const ExtensionType& other) const {
if (extension_name() != other.extension_name()) {
@@ -90,18 +44,10 @@ bool FixedShapeTensorType::ExtensionEquals(const
ExtensionType& other) const {
}
const auto& other_ext = internal::checked_cast<const
FixedShapeTensorType&>(other);
- auto is_permutation_trivial = [](const std::vector<int64_t>& permutation) {
- for (size_t i = 1; i < permutation.size(); ++i) {
- if (permutation[i - 1] + 1 != permutation[i]) {
- return false;
- }
- }
- return true;
- };
const bool permutation_equivalent =
- ((permutation_ == other_ext.permutation()) ||
- (permutation_.empty() &&
is_permutation_trivial(other_ext.permutation())) ||
- (is_permutation_trivial(permutation_) &&
other_ext.permutation().empty()));
+ (permutation_ == other_ext.permutation()) ||
+ (internal::IsPermutationTrivial(permutation_) &&
+ internal::IsPermutationTrivial(other_ext.permutation()));
return (storage_type()->Equals(other_ext.storage_type())) &&
(this->shape() == other_ext.shape()) && (dim_names_ ==
other_ext.dim_names()) &&
@@ -167,7 +113,8 @@ Result<std::shared_ptr<DataType>>
FixedShapeTensorType::Deserialize(
internal::checked_pointer_cast<FixedSizeListType>(storage_type)->value_type();
rj::Document document;
if (document.Parse(serialized_data.data(),
serialized_data.length()).HasParseError() ||
- !document.HasMember("shape") || !document["shape"].IsArray()) {
+ !document.IsObject() || !document.HasMember("shape") ||
+ !document["shape"].IsArray()) {
return Status::Invalid("Invalid serialized JSON data: ", serialized_data);
}
@@ -218,10 +165,6 @@ Result<std::shared_ptr<Tensor>>
FixedShapeTensorType::MakeTensor(
if (array->null_count() > 0) {
return Status::Invalid("Cannot convert data with nulls to Tensor.");
}
- const auto& value_type =
- internal::checked_cast<const FixedWidthType&>(*ext_type.value_type());
- const auto byte_width = value_type.byte_width();
-
std::vector<int64_t> permutation = ext_type.permutation();
if (permutation.empty()) {
permutation.resize(ext_type.ndim());
@@ -236,13 +179,10 @@ Result<std::shared_ptr<Tensor>>
FixedShapeTensorType::MakeTensor(
internal::Permute<std::string>(permutation, &dim_names);
}
- std::vector<int64_t> strides;
- RETURN_NOT_OK(ComputeStrides(value_type, shape, permutation, &strides));
- const auto start_position = array->offset() * byte_width;
- const auto size = std::accumulate(shape.begin(), shape.end(),
static_cast<int64_t>(1),
- std::multiplies<>());
- const auto buffer =
- SliceBuffer(array->data()->buffers[1], start_position, size *
byte_width);
+ ARROW_ASSIGN_OR_RAISE(
+ auto strides, internal::ComputeStrides(ext_type.value_type(), shape,
permutation));
+ ARROW_ASSIGN_OR_RAISE(const auto buffer, internal::SliceTensorBuffer(
+ *array, *ext_type.value_type(),
shape));
return Tensor::Make(ext_type.value_type(), buffer, shape, strides,
dim_names);
}
@@ -304,7 +244,7 @@ Result<std::shared_ptr<FixedShapeTensorArray>>
FixedShapeTensorArray::FromTensor
break;
}
case Type::UINT64: {
- value_array = std::make_shared<Int64Array>(tensor->size(),
tensor->data());
+ value_array = std::make_shared<UInt64Array>(tensor->size(),
tensor->data());
break;
}
case Type::INT64: {
@@ -375,10 +315,8 @@ const Result<std::shared_ptr<Tensor>>
FixedShapeTensorArray::ToTensor() const {
shape.insert(shape.begin(), 1, this->length());
internal::Permute<int64_t>(permutation, &shape);
- std::vector<int64_t> tensor_strides;
- const auto* fw_value_type =
internal::checked_cast<FixedWidthType*>(value_type.get());
- ARROW_RETURN_NOT_OK(
- ComputeStrides(*fw_value_type, shape, permutation, &tensor_strides));
+ ARROW_ASSIGN_OR_RAISE(auto tensor_strides,
+ internal::ComputeStrides(value_type, shape,
permutation));
const auto& raw_buffer = this->storage()->data()->child_data[0]->buffers[1];
ARROW_ASSIGN_OR_RAISE(
@@ -412,11 +350,10 @@ Result<std::shared_ptr<DataType>>
FixedShapeTensorType::Make(
const std::vector<int64_t>& FixedShapeTensorType::strides() {
if (strides_.empty()) {
- auto value_type =
internal::checked_cast<FixedWidthType*>(this->value_type_.get());
- std::vector<int64_t> tensor_strides;
- ARROW_CHECK_OK(
- ComputeStrides(*value_type, this->shape(), this->permutation(),
&tensor_strides));
- strides_ = tensor_strides;
+ auto maybe_strides =
+ internal::ComputeStrides(this->value_type_, this->shape(),
this->permutation());
+ ARROW_CHECK_OK(maybe_strides.status());
+ strides_ = std::move(maybe_strides).MoveValueUnsafe();
}
return strides_;
}
@@ -426,9 +363,8 @@ std::shared_ptr<DataType> fixed_shape_tensor(const
std::shared_ptr<DataType>& va
const std::vector<int64_t>&
permutation,
const std::vector<std::string>&
dim_names) {
auto maybe_type = FixedShapeTensorType::Make(value_type, shape, permutation,
dim_names);
- ARROW_DCHECK_OK(maybe_type.status());
+ ARROW_CHECK_OK(maybe_type.status());
return maybe_type.MoveValueUnsafe();
}
-} // namespace extension
-} // namespace arrow
+} // namespace arrow::extension
diff --git a/cpp/src/arrow/extension/fixed_shape_tensor.h
b/cpp/src/arrow/extension/fixed_shape_tensor.h
index 80a602021c..eee44e1c81 100644
--- a/cpp/src/arrow/extension/fixed_shape_tensor.h
+++ b/cpp/src/arrow/extension/fixed_shape_tensor.h
@@ -19,8 +19,7 @@
#include "arrow/extension_type.h"
-namespace arrow {
-namespace extension {
+namespace arrow::extension {
class ARROW_EXPORT FixedShapeTensorArray : public ExtensionArray {
public:
@@ -112,7 +111,6 @@ class ARROW_EXPORT FixedShapeTensorType : public
ExtensionType {
const std::vector<std::string>& dim_names = {});
private:
- std::shared_ptr<DataType> storage_type_;
std::shared_ptr<DataType> value_type_;
std::vector<int64_t> shape_;
std::vector<int64_t> strides_;
@@ -126,5 +124,4 @@ ARROW_EXPORT std::shared_ptr<DataType> fixed_shape_tensor(
const std::vector<int64_t>& permutation = {},
const std::vector<std::string>& dim_names = {});
-} // namespace extension
-} // namespace arrow
+} // namespace arrow::extension
diff --git a/cpp/src/arrow/extension/fixed_shape_tensor_test.cc
b/cpp/src/arrow/extension/tensor_extension_array_test.cc
similarity index 66%
rename from cpp/src/arrow/extension/fixed_shape_tensor_test.cc
rename to cpp/src/arrow/extension/tensor_extension_array_test.cc
index 6d4d2de326..5c6dbe2162 100644
--- a/cpp/src/arrow/extension/fixed_shape_tensor_test.cc
+++ b/cpp/src/arrow/extension/tensor_extension_array_test.cc
@@ -16,6 +16,7 @@
// under the License.
#include "arrow/extension/fixed_shape_tensor.h"
+#include "arrow/extension/variable_shape_tensor.h"
#include "arrow/testing/matchers.h"
@@ -37,7 +38,11 @@ using arrow::ipc::test::RoundtripBatch;
using extension::fixed_shape_tensor;
using extension::FixedShapeTensorArray;
-class TestExtensionType : public ::testing::Test {
+using VariableShapeTensorType = extension::VariableShapeTensorType;
+using extension::variable_shape_tensor;
+using extension::VariableShapeTensorArray;
+
+class TestFixedShapeTensorType : public ::testing::Test {
public:
void SetUp() override {
shape_ = {3, 3, 4};
@@ -72,13 +77,13 @@ class TestExtensionType : public ::testing::Test {
std::string serialized_;
};
-TEST_F(TestExtensionType, CheckDummyRegistration) {
+TEST_F(TestFixedShapeTensorType, CheckDummyRegistration) {
// We need a registered dummy type at runtime to allow for IPC
deserialization
auto registered_type = GetExtensionType("arrow.fixed_shape_tensor");
- ASSERT_TRUE(registered_type->type_id == Type::EXTENSION);
+ ASSERT_EQ(registered_type->id(), Type::EXTENSION);
}
-TEST_F(TestExtensionType, CreateExtensionType) {
+TEST_F(TestFixedShapeTensorType, CreateExtensionType) {
auto exact_ext_type =
internal::checked_pointer_cast<FixedShapeTensorType>(ext_type_);
// Test ExtensionType methods
@@ -118,7 +123,7 @@ TEST_F(TestExtensionType, CreateExtensionType) {
FixedShapeTensorType::Make(value_type_, {1, 2, 3}, {0, 1, 1}));
}
-TEST_F(TestExtensionType, EqualsCases) {
+TEST_F(TestFixedShapeTensorType, EqualsCases) {
auto ext_type_permutation_1 = fixed_shape_tensor(int64(), {3, 4}, {0, 1},
{"x", "y"});
auto ext_type_permutation_2 = fixed_shape_tensor(int64(), {3, 4}, {1, 0},
{"x", "y"});
auto ext_type_no_permutation = fixed_shape_tensor(int64(), {3, 4}, {}, {"x",
"y"});
@@ -140,7 +145,7 @@ TEST_F(TestExtensionType, EqualsCases) {
ASSERT_FALSE(ext_type_permutation_2->Equals(ext_type_permutation_1));
}
-TEST_F(TestExtensionType, CreateFromArray) {
+TEST_F(TestFixedShapeTensorType, CreateFromArray) {
auto exact_ext_type =
internal::checked_pointer_cast<FixedShapeTensorType>(ext_type_);
std::vector<std::shared_ptr<Buffer>> buffers = {nullptr,
Buffer::Wrap(values_)};
@@ -152,7 +157,7 @@ TEST_F(TestExtensionType, CreateFromArray) {
ASSERT_EQ(ext_arr->null_count(), 0);
}
-TEST_F(TestExtensionType, MakeArrayCanGetCorrectScalarType) {
+TEST_F(TestFixedShapeTensorType, MakeArrayCanGetCorrectScalarType) {
ASSERT_OK_AND_ASSIGN(auto tensor,
Tensor::Make(value_type_, Buffer::Wrap(values_),
shape_));
@@ -175,23 +180,23 @@ TEST_F(TestExtensionType,
MakeArrayCanGetCorrectScalarType) {
}
void CheckSerializationRoundtrip(const std::shared_ptr<DataType>& ext_type) {
- auto fst_type =
internal::checked_pointer_cast<FixedShapeTensorType>(ext_type);
- auto serialized = fst_type->Serialize();
+ auto type = internal::checked_pointer_cast<ExtensionType>(ext_type);
+ auto serialized = type->Serialize();
ASSERT_OK_AND_ASSIGN(auto deserialized,
- fst_type->Deserialize(fst_type->storage_type(),
serialized));
- ASSERT_TRUE(fst_type->Equals(*deserialized));
+ type->Deserialize(type->storage_type(), serialized));
+ ASSERT_TRUE(type->Equals(*deserialized));
}
-void CheckDeserializationRaises(const std::shared_ptr<DataType>& storage_type,
+void CheckDeserializationRaises(const std::shared_ptr<DataType>&
extension_type,
+ const std::shared_ptr<DataType>& storage_type,
const std::string& serialized,
const std::string& expected_message) {
- auto fst_type = internal::checked_pointer_cast<FixedShapeTensorType>(
- fixed_shape_tensor(int64(), {3, 4}));
+ auto ext_type =
internal::checked_pointer_cast<ExtensionType>(extension_type);
EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid,
testing::HasSubstr(expected_message),
- fst_type->Deserialize(storage_type,
serialized));
+ ext_type->Deserialize(storage_type,
serialized));
}
-TEST_F(TestExtensionType, MetadataSerializationRoundtrip) {
+TEST_F(TestFixedShapeTensorType, MetadataSerializationRoundtrip) {
CheckSerializationRoundtrip(ext_type_);
CheckSerializationRoundtrip(fixed_shape_tensor(value_type_, {}, {}, {}));
CheckSerializationRoundtrip(fixed_shape_tensor(value_type_, {0}, {}, {}));
@@ -202,19 +207,21 @@ TEST_F(TestExtensionType, MetadataSerializationRoundtrip)
{
fixed_shape_tensor(value_type_, {256, 256, 3}, {2, 0, 1}, {"C", "H",
"W"}));
auto storage_type = fixed_size_list(int64(), 12);
- CheckDeserializationRaises(boolean(), R"({"shape":[3,4]})",
+ CheckDeserializationRaises(ext_type_, boolean(), R"({"shape":[3,4]})",
"Expected FixedSizeList storage type, got bool");
- CheckDeserializationRaises(storage_type, R"({"dim_names":["x","y"]})",
+ CheckDeserializationRaises(ext_type_, storage_type,
R"({"dim_names":["x","y"]})",
"Invalid serialized JSON data");
- CheckDeserializationRaises(storage_type, R"({"shape":(3,4)})",
+ CheckDeserializationRaises(ext_type_, storage_type, R"({"shape":(3,4)})",
"Invalid serialized JSON data");
- CheckDeserializationRaises(storage_type,
R"({"shape":[3,4],"permutation":[1,0,2]})",
+ CheckDeserializationRaises(ext_type_, storage_type,
+ R"({"shape":[3,4],"permutation":[1,0,2]})",
"Invalid permutation");
- CheckDeserializationRaises(storage_type,
R"({"shape":[3],"dim_names":["x","y"]})",
+ CheckDeserializationRaises(ext_type_, storage_type,
+ R"({"shape":[3],"dim_names":["x","y"]})",
"Invalid dim_names");
}
-TEST_F(TestExtensionType, RoundtripBatch) {
+TEST_F(TestFixedShapeTensorType, RoundtripBatch) {
auto exact_ext_type =
internal::checked_pointer_cast<FixedShapeTensorType>(ext_type_);
std::vector<std::shared_ptr<Buffer>> buffers = {nullptr,
Buffer::Wrap(values_)};
@@ -242,7 +249,7 @@ TEST_F(TestExtensionType, RoundtripBatch) {
CompareBatch(*batch, *read_batch2, /*compare_metadata=*/true);
}
-TEST_F(TestExtensionType, CreateFromTensor) {
+TEST_F(TestFixedShapeTensorType, CreateFromTensor) {
std::vector<int64_t> column_major_strides = {8, 24, 72};
std::vector<int64_t> neither_major_strides = {96, 8, 32};
@@ -320,7 +327,7 @@ void CheckFromTensorType(const std::shared_ptr<Tensor>&
tensor,
ASSERT_TRUE(generated_ext_type->Equals(ext_type));
}
-TEST_F(TestExtensionType, TestFromTensorType) {
+TEST_F(TestFixedShapeTensorType, TestFromTensorType) {
auto values = Buffer::Wrap(values_);
auto shapes =
std::vector<std::vector<int64_t>>{{3, 3, 4}, {3, 3, 4}, {3, 4, 3}, {3,
4, 3}};
@@ -379,7 +386,7 @@ void CheckToTensor(const std::vector<T>& values, const
std::shared_ptr<DataType>
ASSERT_TRUE(actual_tensor->Equals(*expected_tensor));
}
-TEST_F(TestExtensionType, ToTensor) {
+TEST_F(TestFixedShapeTensorType, ToTensor) {
std::vector<float_t> float_values = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
10, 11,
12, 13, 14, 15, 16, 17, 18, 19, 20, 21,
22, 23,
24, 25, 26, 27, 28, 29, 30, 31, 32, 33,
34, 35};
@@ -430,7 +437,7 @@ void CheckTensorRoundtrip(const std::shared_ptr<Tensor>&
tensor) {
ASSERT_TRUE(tensor->Equals(*tensor_from_array));
}
-TEST_F(TestExtensionType, RoundtripTensor) {
+TEST_F(TestFixedShapeTensorType, RoundtripTensor) {
auto values = Buffer::Wrap(values_);
auto shapes = std::vector<std::vector<int64_t>>{
@@ -451,7 +458,7 @@ TEST_F(TestExtensionType, RoundtripTensor) {
}
}
-TEST_F(TestExtensionType, SliceTensor) {
+TEST_F(TestFixedShapeTensorType, SliceTensor) {
ASSERT_OK_AND_ASSIGN(auto tensor,
Tensor::Make(value_type_, Buffer::Wrap(values_),
shape_));
ASSERT_OK_AND_ASSIGN(
@@ -478,7 +485,7 @@ TEST_F(TestExtensionType, SliceTensor) {
ASSERT_EQ(sliced->length(), partial->length());
}
-TEST_F(TestExtensionType, RoundtripBatchFromTensor) {
+TEST_F(TestFixedShapeTensorType, RoundtripBatchFromTensor) {
auto exact_ext_type =
internal::checked_pointer_cast<FixedShapeTensorType>(ext_type_);
ASSERT_OK_AND_ASSIGN(auto tensor, Tensor::Make(value_type_,
Buffer::Wrap(values_),
shape_, {}, {"n", "x", "y"}));
@@ -495,7 +502,7 @@ TEST_F(TestExtensionType, RoundtripBatchFromTensor) {
CompareBatch(*batch, *read_batch, /*compare_metadata=*/true);
}
-TEST_F(TestExtensionType, ComputeStrides) {
+TEST_F(TestFixedShapeTensorType, ComputeStrides) {
auto exact_ext_type =
internal::checked_pointer_cast<FixedShapeTensorType>(ext_type_);
auto ext_type_1 = internal::checked_pointer_cast<FixedShapeTensorType>(
@@ -529,7 +536,7 @@ TEST_F(TestExtensionType, ComputeStrides) {
ASSERT_EQ(ext_type_7->Serialize(),
R"({"shape":[3,4,7],"permutation":[2,0,1]})");
}
-TEST_F(TestExtensionType, ToString) {
+TEST_F(TestFixedShapeTensorType, FixedShapeTensorToString) {
auto exact_ext_type =
internal::checked_pointer_cast<FixedShapeTensorType>(ext_type_);
auto ext_type_1 = internal::checked_pointer_cast<FixedShapeTensorType>(
@@ -557,7 +564,7 @@ TEST_F(TestExtensionType, ToString) {
ASSERT_EQ(expected_3, result_3);
}
-TEST_F(TestExtensionType, GetTensor) {
+TEST_F(TestFixedShapeTensorType, GetTensor) {
auto arr = ArrayFromJSON(element_type_,
"[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],"
"[12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22,
23]]");
@@ -649,4 +656,284 @@ TEST_F(TestExtensionType, GetTensor) {
exact_ext_type->MakeTensor(ext_scalar));
}
+class TestVariableShapeTensorType : public ::testing::Test {
+ public:
+ void SetUp() override {
+ ndim_ = 3;
+ value_type_ = int64();
+ data_type_ = list(value_type_);
+ shape_type_ = fixed_size_list(int32(), ndim_);
+ permutation_ = {0, 1, 2};
+ dim_names_ = {"x", "y", "z"};
+ uniform_shape_ = {std::nullopt, std::optional<int64_t>(1), std::nullopt};
+ ext_type_ =
internal::checked_pointer_cast<ExtensionType>(variable_shape_tensor(
+ value_type_, ndim_, permutation_, dim_names_, uniform_shape_));
+ values_ = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
16, 17,
+ 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33,
34, 35};
+ shapes_ = ArrayFromJSON(fixed_size_list(int32(), ndim_),
"[[2,1,3],[2,1,2],[3,1,3]]");
+ data_ = ArrayFromJSON(list(value_type_),
+
"[[0,1,2,3,4,5],[6,7,8,9],[10,11,12,13,14,15,16,17,18]]");
+ serialized_ =
+
R"({"permutation":[0,1,2],"dim_names":["x","y","z"],"uniform_shape":[null,1,null]})";
+ storage_arr_ = ArrayFromJSON(
+ ext_type_->storage_type(),
+
R"([[[0,1,2,3,4,5],[2,3,1]],[[6,7,8,9],[1,2,2]],[[10,11,12,13,14,15,16,17,18],[3,1,3]]])");
+ ext_arr_ = internal::checked_pointer_cast<ExtensionArray>(
+ ExtensionType::WrapArray(ext_type_, storage_arr_));
+ }
+
+ protected:
+ int32_t ndim_;
+ std::shared_ptr<DataType> value_type_;
+ std::shared_ptr<DataType> data_type_;
+ std::shared_ptr<DataType> shape_type_;
+ std::vector<int64_t> permutation_;
+ std::vector<std::optional<int64_t>> uniform_shape_;
+ std::vector<std::string> dim_names_;
+ std::shared_ptr<ExtensionType> ext_type_;
+ std::vector<int64_t> values_;
+ std::shared_ptr<Array> shapes_;
+ std::shared_ptr<Array> data_;
+ std::string serialized_;
+ std::shared_ptr<Array> storage_arr_;
+ std::shared_ptr<ExtensionArray> ext_arr_;
+};
+
+TEST_F(TestVariableShapeTensorType, CheckDummyRegistration) {
+ // We need a registered dummy type at runtime to allow for IPC
deserialization
+ auto registered_type = GetExtensionType("arrow.variable_shape_tensor");
+ ASSERT_EQ(registered_type->id(), Type::EXTENSION);
+}
+
+TEST_F(TestVariableShapeTensorType, CreateExtensionType) {
+ auto exact_ext_type =
+ internal::checked_pointer_cast<VariableShapeTensorType>(ext_type_);
+
+ // Test ExtensionType methods
+ ASSERT_EQ(ext_type_->extension_name(), "arrow.variable_shape_tensor");
+ ASSERT_TRUE(ext_type_->Equals(*exact_ext_type));
+ auto expected_type =
+ struct_({::arrow::field("data", list(value_type_)),
+ ::arrow::field("shape", fixed_size_list(int32(), ndim_))});
+
+ ASSERT_TRUE(ext_type_->storage_type()->Equals(*expected_type));
+ ASSERT_EQ(ext_type_->Serialize(), serialized_);
+ ASSERT_OK_AND_ASSIGN(auto ds,
+ ext_type_->Deserialize(ext_type_->storage_type(),
serialized_));
+ auto deserialized = internal::checked_pointer_cast<ExtensionType>(ds);
+ ASSERT_TRUE(deserialized->Equals(*exact_ext_type));
+ ASSERT_TRUE(deserialized->Equals(*ext_type_));
+
+ // Test VariableShapeTensorType methods
+ ASSERT_EQ(exact_ext_type->id(), Type::EXTENSION);
+ ASSERT_EQ(exact_ext_type->ndim(), ndim_);
+ ASSERT_EQ(exact_ext_type->value_type(), value_type_);
+ ASSERT_EQ(exact_ext_type->permutation(), permutation_);
+ ASSERT_EQ(exact_ext_type->dim_names(), dim_names_);
+
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid,
+ testing::HasSubstr("Invalid: permutation size must match ndim. Expected:
3 Got: 1"),
+ VariableShapeTensorType::Make(value_type_, ndim_, {0}));
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid, testing::HasSubstr("Invalid: dim_names size must match ndim."),
+ VariableShapeTensorType::Make(value_type_, ndim_, {}, {"x"}));
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid,
+ testing::HasSubstr("Invalid: Permutation indices for 3 dimensional
tensors must be "
+ "unique and within [0, 2] range. Got: [2,0,0]"),
+ VariableShapeTensorType::Make(value_type_, 3, {2, 0, 0}, {"C", "H",
"W"}));
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid,
+ testing::HasSubstr("Invalid: Permutation indices for 3 dimensional
tensors must be "
+ "unique and within [0, 2] range. Got: [1,2,3]"),
+ VariableShapeTensorType::Make(value_type_, 3, {1, 2, 3}, {"C", "H",
"W"}));
+}
+
+TEST_F(TestVariableShapeTensorType, EqualsCases) {
+ auto ext_type_permutation_1 = variable_shape_tensor(int64(), 2, {0, 1},
{"x", "y"});
+ auto ext_type_permutation_2 = variable_shape_tensor(int64(), 2, {1, 0},
{"x", "y"});
+ auto ext_type_no_permutation = variable_shape_tensor(int64(), 2, {}, {"x",
"y"});
+
+ ASSERT_TRUE(ext_type_permutation_1->Equals(ext_type_permutation_1));
+
+ ASSERT_FALSE(
+ variable_shape_tensor(int32(), 2, {}, {"x",
"y"})->Equals(ext_type_no_permutation));
+ ASSERT_FALSE(variable_shape_tensor(int64(), 2, {}, {})
+ ->Equals(variable_shape_tensor(int64(), 3, {}, {})));
+ ASSERT_FALSE(
+ variable_shape_tensor(int64(), 2, {}, {"H",
"W"})->Equals(ext_type_no_permutation));
+
+ ASSERT_TRUE(ext_type_no_permutation->Equals(ext_type_permutation_1));
+ ASSERT_TRUE(ext_type_permutation_1->Equals(ext_type_no_permutation));
+ ASSERT_FALSE(ext_type_no_permutation->Equals(ext_type_permutation_2));
+ ASSERT_FALSE(ext_type_permutation_2->Equals(ext_type_no_permutation));
+ ASSERT_FALSE(ext_type_permutation_1->Equals(ext_type_permutation_2));
+ ASSERT_FALSE(ext_type_permutation_2->Equals(ext_type_permutation_1));
+}
+
+TEST_F(TestVariableShapeTensorType, MetadataSerializationRoundtrip) {
+ CheckSerializationRoundtrip(ext_type_);
+ CheckSerializationRoundtrip(
+ variable_shape_tensor(value_type_, 3, {1, 2, 0}, {"x", "y", "z"}));
+ CheckSerializationRoundtrip(variable_shape_tensor(value_type_, 0, {}, {}));
+ CheckSerializationRoundtrip(variable_shape_tensor(value_type_, 1, {0},
{"x"}));
+ CheckSerializationRoundtrip(
+ variable_shape_tensor(value_type_, 3, {0, 1, 2}, {"H", "W", "C"}));
+ CheckSerializationRoundtrip(
+ variable_shape_tensor(value_type_, 3, {2, 0, 1}, {"C", "H", "W"}));
+ CheckSerializationRoundtrip(
+ variable_shape_tensor(value_type_, 3, {2, 0, 1}, {"C", "H", "W"}, {0, 1,
2}));
+
+ auto storage_type = ext_type_->storage_type();
+ CheckDeserializationRaises(ext_type_, boolean(), R"({"shape":[3,4]})",
+ "Expected Struct storage type, got bool");
+ CheckDeserializationRaises(ext_type_, storage_type, R"({"shape":(3,4)})",
+ "Invalid serialized JSON data");
+ CheckDeserializationRaises(ext_type_, storage_type,
R"({"permutation":[1,0]})",
+ "Invalid: permutation");
+ CheckDeserializationRaises(ext_type_, storage_type,
R"({"dim_names":["x","y"]})",
+ "Invalid: dim_names");
+}
+
+TEST_F(TestVariableShapeTensorType, RoundtripBatch) {
+ auto exact_ext_type =
+ internal::checked_pointer_cast<VariableShapeTensorType>(ext_type_);
+
+ // Pass extension array, expect getting back extension array
+ std::shared_ptr<RecordBatch> read_batch;
+ auto ext_field = field(/*name=*/"f0", /*type=*/ext_type_);
+ auto batch = RecordBatch::Make(schema({ext_field}), ext_arr_->length(),
{ext_arr_});
+ ASSERT_OK(RoundtripBatch(batch, &read_batch));
+ CompareBatch(*batch, *read_batch, /*compare_metadata=*/true);
+
+ // Pass extension metadata and storage array, expect getting back extension
array
+ std::shared_ptr<RecordBatch> read_batch2;
+ auto ext_metadata =
+ key_value_metadata({{"ARROW:extension:name",
exact_ext_type->extension_name()},
+ {"ARROW:extension:metadata", serialized_}});
+ ext_field = field(/*name=*/"f0", /*type=*/ext_type_->storage_type(),
/*nullable=*/true,
+ /*metadata=*/ext_metadata);
+ auto batch2 = RecordBatch::Make(schema({ext_field}), ext_arr_->length(),
{ext_arr_});
+ ASSERT_OK(RoundtripBatch(batch2, &read_batch2));
+ CompareBatch(*batch, *read_batch2, /*compare_metadata=*/true);
+}
+
+TEST_F(TestVariableShapeTensorType, ComputeStrides) {
+ auto shape = ArrayFromJSON(shape_type_, "[[2,3,1],[2,1,2],[3,1,3],null]");
+ auto data = ArrayFromJSON(
+ data_type_,
"[[1,1,2,3,4,5],[2,7,8,9],[10,11,12,13,14,15,16,17,18],null]");
+ std::vector<std::shared_ptr<Field>> fields = {field("data", data_type_),
+ field("shape", shape_type_)};
+ ASSERT_OK_AND_ASSIGN(auto storage_arr, StructArray::Make({data, shape},
fields));
+ auto ext_arr = ExtensionType::WrapArray(ext_type_, storage_arr);
+ auto exact_ext_type =
+ internal::checked_pointer_cast<VariableShapeTensorType>(ext_type_);
+ auto ext_array = std::static_pointer_cast<VariableShapeTensorArray>(ext_arr);
+
+ std::shared_ptr<Tensor> t, tensor;
+
+ ASSERT_OK_AND_ASSIGN(auto scalar, ext_array->GetScalar(0));
+ auto ext_scalar = internal::checked_pointer_cast<ExtensionScalar>(scalar);
+ ASSERT_OK_AND_ASSIGN(t, exact_ext_type->MakeTensor(ext_scalar));
+ ASSERT_EQ(t->shape(), (std::vector<int64_t>{2, 3, 1}));
+ ASSERT_EQ(t->strides(), (std::vector<int64_t>{24, 8, 8}));
+
+ std::vector<int64_t> strides = {sizeof(int64_t) * 3, sizeof(int64_t) * 1,
+ sizeof(int64_t) * 1};
+ tensor = TensorFromJSON(int64(), R"([1,1,2,3,4,5])", {2, 3, 1}, strides,
dim_names_);
+
+ ASSERT_TRUE(tensor->Equals(*t));
+
+ ASSERT_OK_AND_ASSIGN(scalar, ext_array->GetScalar(1));
+ ext_scalar = internal::checked_pointer_cast<ExtensionScalar>(scalar);
+ ASSERT_OK_AND_ASSIGN(t, exact_ext_type->MakeTensor(ext_scalar));
+ ASSERT_EQ(t->shape(), (std::vector<int64_t>{2, 1, 2}));
+ ASSERT_EQ(t->strides(), (std::vector<int64_t>{16, 16, 8}));
+
+ ASSERT_OK_AND_ASSIGN(scalar, ext_array->GetScalar(2));
+ ext_scalar = internal::checked_pointer_cast<ExtensionScalar>(scalar);
+ ASSERT_OK_AND_ASSIGN(t, exact_ext_type->MakeTensor(ext_scalar));
+ ASSERT_EQ(t->shape(), (std::vector<int64_t>{3, 1, 3}));
+ ASSERT_EQ(t->strides(), (std::vector<int64_t>{24, 24, 8}));
+
+ strides = {sizeof(int64_t) * 3, sizeof(int64_t) * 3, sizeof(int64_t) * 1};
+ tensor = TensorFromJSON(int64(), R"([10,11,12,13,14,15,16,17,18])", {3, 1,
3}, strides,
+ dim_names_);
+
+ ASSERT_EQ(tensor->strides(), t->strides());
+ ASSERT_EQ(tensor->shape(), t->shape());
+ ASSERT_EQ(tensor->dim_names(), t->dim_names());
+ ASSERT_EQ(tensor->type(), t->type());
+ ASSERT_EQ(tensor->is_contiguous(), t->is_contiguous());
+ ASSERT_EQ(tensor->is_column_major(), t->is_column_major());
+ ASSERT_TRUE(tensor->Equals(*t));
+
+ ASSERT_OK_AND_ASSIGN(auto sc, ext_arr->GetScalar(2));
+ auto s = internal::checked_pointer_cast<ExtensionScalar>(sc);
+ ASSERT_OK_AND_ASSIGN(t, exact_ext_type->MakeTensor(s));
+ ASSERT_EQ(tensor->strides(), t->strides());
+ ASSERT_EQ(tensor->shape(), t->shape());
+ ASSERT_EQ(tensor->dim_names(), t->dim_names());
+ ASSERT_EQ(tensor->type(), t->type());
+ ASSERT_EQ(tensor->is_contiguous(), t->is_contiguous());
+ ASSERT_EQ(tensor->is_column_major(), t->is_column_major());
+ ASSERT_TRUE(tensor->Equals(*t));
+
+ // Null value in VariableShapeTensorArray produces a tensor with shape {0,
0, 0}
+ strides = {sizeof(int64_t), sizeof(int64_t), sizeof(int64_t)};
+ tensor = TensorFromJSON(int64(), R"([10,11,12,13,14,15,16,17,18])", {0, 0,
0}, strides,
+ dim_names_);
+
+ ASSERT_OK_AND_ASSIGN(sc, ext_arr->GetScalar(3));
+ ASSERT_OK_AND_ASSIGN(
+ t,
exact_ext_type->MakeTensor(internal::checked_pointer_cast<ExtensionScalar>(sc)));
+ ASSERT_EQ(tensor->strides(), t->strides());
+ ASSERT_EQ(tensor->shape(), t->shape());
+ ASSERT_EQ(tensor->dim_names(), t->dim_names());
+ ASSERT_EQ(tensor->type(), t->type());
+ ASSERT_EQ(tensor->is_contiguous(), t->is_contiguous());
+ ASSERT_EQ(tensor->is_column_major(), t->is_column_major());
+ ASSERT_TRUE(tensor->Equals(*t));
+}
+
+TEST_F(TestVariableShapeTensorType, ToString) {
+ auto exact_ext_type =
+ internal::checked_pointer_cast<VariableShapeTensorType>(ext_type_);
+
+ auto uniform_shape = std::vector<std::optional<int64_t>>{
+ std::nullopt, std::optional<int64_t>(1), std::nullopt};
+ auto ext_type_1 = internal::checked_pointer_cast<VariableShapeTensorType>(
+ variable_shape_tensor(int16(), 3));
+ auto ext_type_2 = internal::checked_pointer_cast<VariableShapeTensorType>(
+ variable_shape_tensor(int32(), 3, {1, 0, 2}));
+ auto ext_type_3 = internal::checked_pointer_cast<VariableShapeTensorType>(
+ variable_shape_tensor(int64(), 3, {}, {"C", "H", "W"}));
+ auto ext_type_4 = internal::checked_pointer_cast<VariableShapeTensorType>(
+ variable_shape_tensor(int64(), 3, {}, {}, uniform_shape));
+
+ std::string result_1 = ext_type_1->ToString();
+ std::string expected_1 =
+ "extension<arrow.variable_shape_tensor[value_type=int16, ndim=3]>";
+ ASSERT_EQ(expected_1, result_1);
+
+ std::string result_2 = ext_type_2->ToString();
+ std::string expected_2 =
+ "extension<arrow.variable_shape_tensor[value_type=int32, ndim=3, "
+ "permutation=[1,0,2]]>";
+ ASSERT_EQ(expected_2, result_2);
+
+ std::string result_3 = ext_type_3->ToString();
+ std::string expected_3 =
+ "extension<arrow.variable_shape_tensor[value_type=int64, ndim=3, "
+ "dim_names=[C,H,W]]>";
+ ASSERT_EQ(expected_3, result_3);
+
+ std::string result_4 = ext_type_4->ToString();
+ std::string expected_4 =
+ "extension<arrow.variable_shape_tensor[value_type=int64, ndim=3, "
+ "uniform_shape=[null,1,null]]>";
+ ASSERT_EQ(expected_4, result_4);
+}
+
} // namespace arrow
diff --git a/cpp/src/arrow/extension/tensor_internal.cc
b/cpp/src/arrow/extension/tensor_internal.cc
new file mode 100644
index 0000000000..37862b7689
--- /dev/null
+++ b/cpp/src/arrow/extension/tensor_internal.cc
@@ -0,0 +1,131 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/extension/tensor_internal.h"
+
+#include <numeric>
+
+#include "arrow/array/array_base.h"
+#include "arrow/buffer.h"
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/type.h"
+#include "arrow/util/int_util_overflow.h"
+#include "arrow/util/print_internal.h"
+#include "arrow/util/sort_internal.h"
+
+namespace arrow::internal {
+
+bool IsPermutationTrivial(std::span<const int64_t> permutation) {
+ for (size_t i = 1; i < permutation.size(); ++i) {
+ if (permutation[i - 1] + 1 != permutation[i]) {
+ return false;
+ }
+ }
+ return true;
+}
+
+Status IsPermutationValid(std::span<const int64_t> permutation) {
+ const auto size = static_cast<int64_t>(permutation.size());
+ std::vector<uint8_t> dim_seen(size, 0);
+
+ for (const auto p : permutation) {
+ if (p < 0 || p >= size || dim_seen[p] != 0) {
+ return Status::Invalid(
+ "Permutation indices for ", size,
+ " dimensional tensors must be unique and within [0, ", size - 1,
+ "] range. Got: ", ::arrow::internal::PrintVector{permutation, ","});
+ }
+ dim_seen[p] = 1;
+ }
+ return Status::OK();
+}
+
// Compute per-dimension byte strides for a tensor of `shape`, laid out
// according to `permutation` (logical-to-physical dimension mapping; an
// empty permutation means identity / C-contiguous order).
// Returns Invalid if the stride product overflows int64.
Result<std::vector<int64_t>> ComputeStrides(const std::shared_ptr<DataType>& value_type,
                                            std::span<const int64_t> shape,
                                            std::span<const int64_t> permutation) {
  const auto ndim = shape.size();
  const int byte_width = value_type->byte_width();

  // Use identity permutation if none provided
  std::vector<int64_t> perm;
  if (permutation.empty()) {
    perm.resize(ndim);
    std::iota(perm.begin(), perm.end(), 0);
  } else {
    perm.assign(permutation.begin(), permutation.end());
  }

  // Accumulate the total byte size of all dimensions except the outermost
  // permuted one (indices i > 0 only); this becomes the largest stride.
  int64_t remaining = 0;
  if (!shape.empty() && shape[0] > 0) {
    remaining = byte_width;
    for (auto i : perm) {
      if (i > 0) {
        if (MultiplyWithOverflow(remaining, shape[i], &remaining)) {
          return Status::Invalid(
              "Strides computed from shape would not fit in 64-bit integer");
        }
      }
    }
  }

  // remaining == 0 means the shape was empty or contained a zero extent:
  // fall back to uniform byte_width strides for the degenerate tensor.
  std::vector<int64_t> strides;
  if (remaining == 0) {
    strides.assign(ndim, byte_width);
    return strides;
  }

  // Divide out each extent in permuted order to derive successive strides,
  // then reorder the result according to `perm`.
  strides.push_back(remaining);
  for (auto i : perm) {
    if (i > 0) {
      remaining /= shape[i];
      strides.push_back(remaining);
    }
  }
  Permute(perm, &strides);

  return strides;
}
+
+Result<std::shared_ptr<Buffer>> SliceTensorBuffer(const Array& data_array,
+ const DataType& value_type,
+ std::span<const int64_t>
shape) {
+ const int64_t byte_width = value_type.byte_width();
+ int64_t size = 1;
+ for (const auto dim : shape) {
+ if (MultiplyWithOverflow(size, dim, &size)) {
+ return Status::Invalid("Tensor size would not fit in 64-bit integer");
+ }
+ }
+ if (size != data_array.length()) {
+ return Status::Invalid("Expected data array of length ", size, ", got ",
+ data_array.length());
+ }
+
+ int64_t start_position = 0;
+ if (MultiplyWithOverflow(data_array.offset(), byte_width, &start_position)) {
+ return Status::Invalid("Data offset in bytes would not fit in 64-bit
integer");
+ }
+ int64_t size_bytes = 0;
+ if (MultiplyWithOverflow(size, byte_width, &size_bytes)) {
+ return Status::Invalid("Tensor byte size would not fit in 64-bit integer");
+ }
+
+ return SliceBufferSafe(data_array.data()->buffers[1], start_position,
size_bytes);
+}
+
+} // namespace arrow::internal
diff --git a/cpp/src/arrow/extension/tensor_internal.h
b/cpp/src/arrow/extension/tensor_internal.h
index 62b1dba614..b5ed5ebe11 100644
--- a/cpp/src/arrow/extension/tensor_internal.h
+++ b/cpp/src/arrow/extension/tensor_internal.h
@@ -18,27 +18,28 @@
#pragma once
#include <cstdint>
+#include <span>
#include <vector>
-#include "arrow/status.h"
-#include "arrow/util/print_internal.h"
+#include "arrow/result.h"
+#include "arrow/type_fwd.h"
namespace arrow::internal {
-inline Status IsPermutationValid(const std::vector<int64_t>& permutation) {
- const auto size = static_cast<int64_t>(permutation.size());
- std::vector<uint8_t> dim_seen(size, 0);
-
- for (const auto p : permutation) {
- if (p < 0 || p >= size || dim_seen[p] != 0) {
- return Status::Invalid(
- "Permutation indices for ", size,
- " dimensional tensors must be unique and within [0, ", size - 1,
- "] range. Got: ", ::arrow::internal::PrintVector{permutation, ","});
- }
- dim_seen[p] = 1;
- }
- return Status::OK();
-}
+ARROW_EXPORT
+bool IsPermutationTrivial(std::span<const int64_t> permutation);
+
+ARROW_EXPORT
+Status IsPermutationValid(std::span<const int64_t> permutation);
+
+ARROW_EXPORT
+Result<std::vector<int64_t>> ComputeStrides(const std::shared_ptr<DataType>&
value_type,
+ std::span<const int64_t> shape,
+ std::span<const int64_t>
permutation);
+
+ARROW_EXPORT
+Result<std::shared_ptr<Buffer>> SliceTensorBuffer(const Array& data_array,
+ const DataType& value_type,
+ std::span<const int64_t>
shape);
} // namespace arrow::internal
diff --git a/cpp/src/arrow/extension/variable_shape_tensor.cc
b/cpp/src/arrow/extension/variable_shape_tensor.cc
new file mode 100644
index 0000000000..7e27bbdb74
--- /dev/null
+++ b/cpp/src/arrow/extension/variable_shape_tensor.cc
@@ -0,0 +1,325 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <sstream>
+
+#include "arrow/extension/tensor_internal.h"
+#include "arrow/extension/variable_shape_tensor.h"
+
+#include "arrow/array/array_primitive.h"
+#include "arrow/json/rapidjson_defs.h" // IWYU pragma: keep
+#include "arrow/scalar.h"
+#include "arrow/tensor.h"
+#include "arrow/util/logging_internal.h"
+#include "arrow/util/print_internal.h"
+#include "arrow/util/sort_internal.h"
+#include "arrow/util/string.h"
+
+#include <rapidjson/document.h>
+#include <rapidjson/writer.h>
+
+namespace rj = arrow::rapidjson;
+
+namespace arrow::extension {
+
+bool VariableShapeTensorType::ExtensionEquals(const ExtensionType& other)
const {
+ if (extension_name() != other.extension_name()) {
+ return false;
+ }
+ const auto& other_ext = internal::checked_cast<const
VariableShapeTensorType&>(other);
+ if (this->ndim() != other_ext.ndim()) {
+ return false;
+ }
+
+ const bool permutation_equivalent =
+ (permutation_ == other_ext.permutation()) ||
+ (internal::IsPermutationTrivial(permutation_) &&
+ internal::IsPermutationTrivial(other_ext.permutation()));
+
+ return (storage_type()->Equals(other_ext.storage_type())) &&
+ (dim_names_ == other_ext.dim_names()) &&
+ (uniform_shape_ == other_ext.uniform_shape()) &&
permutation_equivalent;
+}
+
+std::string VariableShapeTensorType::ToString(bool show_metadata) const {
+ std::stringstream ss;
+ ss << "extension<" << this->extension_name()
+ << "[value_type=" << value_type_->ToString(show_metadata) << ", ndim=" <<
ndim_;
+
+ if (!permutation_.empty()) {
+ ss << ", permutation=" << ::arrow::internal::PrintVector{permutation_,
","};
+ }
+ if (!dim_names_.empty()) {
+ ss << ", dim_names=[" << internal::JoinStrings(dim_names_, ",") << "]";
+ }
+ if (!uniform_shape_.empty()) {
+ std::vector<std::string> uniform_shape;
+ for (const auto& v : uniform_shape_) {
+ if (v.has_value()) {
+ uniform_shape.emplace_back(std::to_string(v.value()));
+ } else {
+ uniform_shape.emplace_back("null");
+ }
+ }
+ ss << ", uniform_shape=[" << internal::JoinStrings(uniform_shape, ",") <<
"]";
+ }
+ ss << "]>";
+ return ss.str();
+}
+
+std::string VariableShapeTensorType::Serialize() const {
+ rj::Document document;
+ document.SetObject();
+ rj::Document::AllocatorType& allocator = document.GetAllocator();
+
+ if (!permutation_.empty()) {
+ rj::Value permutation(rj::kArrayType);
+ for (auto v : permutation_) {
+ permutation.PushBack(v, allocator);
+ }
+ document.AddMember(rj::Value("permutation", allocator), permutation,
allocator);
+ }
+
+ if (!dim_names_.empty()) {
+ rj::Value dim_names(rj::kArrayType);
+ for (const std::string& v : dim_names_) {
+ dim_names.PushBack(rj::Value{}.SetString(v.c_str(), allocator),
allocator);
+ }
+ document.AddMember(rj::Value("dim_names", allocator), dim_names,
allocator);
+ }
+
+ if (!uniform_shape_.empty()) {
+ rj::Value uniform_shape(rj::kArrayType);
+ for (auto v : uniform_shape_) {
+ if (v.has_value()) {
+ uniform_shape.PushBack(v.value(), allocator);
+ } else {
+ uniform_shape.PushBack(rj::Value{}.SetNull(), allocator);
+ }
+ }
+ document.AddMember(rj::Value("uniform_shape", allocator), uniform_shape,
allocator);
+ }
+
+ rj::StringBuffer buffer;
+ rj::Writer<rj::StringBuffer> writer(buffer);
+ document.Accept(writer);
+ return buffer.GetString();
+}
+
// Reconstruct a VariableShapeTensorType from its storage type and the JSON
// metadata produced by Serialize(). value_type and ndim are recovered from
// the storage layout; permutation/dim_names/uniform_shape from the JSON.
Result<std::shared_ptr<DataType>> VariableShapeTensorType::Deserialize(
    std::shared_ptr<DataType> storage_type, const std::string& serialized_data) const {
  // Storage must be struct<data: list<value_type>, shape: fixed_size_list<int32>>.
  if (storage_type->id() != Type::STRUCT) {
    return Status::Invalid("Expected Struct storage type, got ",
                           storage_type->ToString());
  }
  if (storage_type->num_fields() != 2) {
    return Status::Invalid("Expected Struct storage type with 2 fields, got ",
                           storage_type->num_fields());
  }
  if (storage_type->field(0)->type()->id() != Type::LIST) {
    return Status::Invalid("Expected List storage type, got ",
                           storage_type->field(0)->type()->ToString());
  }
  if (storage_type->field(1)->type()->id() != Type::FIXED_SIZE_LIST) {
    return Status::Invalid("Expected FixedSizeList storage type, got ",
                           storage_type->field(1)->type()->ToString());
  }
  if (internal::checked_cast<const FixedSizeListType&>(*storage_type->field(1)->type())
          .value_type() != int32()) {
    return Status::Invalid("Expected FixedSizeList value type int32, got ",
                           storage_type->field(1)->type()->ToString());
  }

  // value_type comes from the list's child field; ndim from the
  // fixed-size-list length of the "shape" field.
  const auto value_type = storage_type->field(0)->type()->field(0)->type();
  const int32_t ndim =
      internal::checked_cast<const FixedSizeListType&>(*storage_type->field(1)->type())
          .list_size();

  rj::Document document;
  if (document.Parse(serialized_data.data(), serialized_data.length()).HasParseError() ||
      !document.IsObject()) {
    return Status::Invalid("Invalid serialized JSON data: ", serialized_data);
  }

  // "permutation": optional array of integers.
  std::vector<int64_t> permutation;
  if (document.HasMember("permutation")) {
    const auto& json_permutation = document["permutation"];
    if (!json_permutation.IsArray()) {
      return Status::Invalid("permutation must be an array");
    }
    permutation.reserve(ndim);
    for (const auto& x : json_permutation.GetArray()) {
      if (!x.IsInt64()) {
        return Status::Invalid("permutation must contain integers");
      }
      permutation.emplace_back(x.GetInt64());
    }
  }
  // "dim_names": optional array of strings.
  std::vector<std::string> dim_names;
  if (document.HasMember("dim_names")) {
    const auto& json_dim_names = document["dim_names"];
    if (!json_dim_names.IsArray()) {
      return Status::Invalid("dim_names must be an array");
    }
    dim_names.reserve(ndim);
    for (const auto& x : json_dim_names.GetArray()) {
      if (!x.IsString()) {
        return Status::Invalid("dim_names must contain strings");
      }
      dim_names.emplace_back(x.GetString());
    }
  }

  // "uniform_shape": optional array of integers, with null marking
  // non-uniform dimensions.
  std::vector<std::optional<int64_t>> uniform_shape;
  if (document.HasMember("uniform_shape")) {
    const auto& json_uniform_shape = document["uniform_shape"];
    if (!json_uniform_shape.IsArray()) {
      return Status::Invalid("uniform_shape must be an array");
    }
    uniform_shape.reserve(ndim);
    for (const auto& x : json_uniform_shape.GetArray()) {
      if (x.IsNull()) {
        uniform_shape.emplace_back(std::nullopt);
      } else if (x.IsInt64()) {
        uniform_shape.emplace_back(x.GetInt64());
      } else {
        return Status::Invalid("uniform_shape must contain integers or nulls");
      }
    }
  }

  // Make() re-validates sizes and permutation consistency against ndim.
  return VariableShapeTensorType::Make(value_type, ndim, permutation, dim_names,
                                       uniform_shape);
}
+
+std::shared_ptr<Array> VariableShapeTensorType::MakeArray(
+ std::shared_ptr<ArrayData> data) const {
+ DCHECK_EQ(data->type->id(), Type::EXTENSION);
+ DCHECK_EQ("arrow.variable_shape_tensor",
+ internal::checked_cast<const
ExtensionType&>(*data->type).extension_name());
+ return std::make_shared<VariableShapeTensorArray>(data);
+}
+
// Convert a single extension scalar (one struct<data, shape> element) into
// a Tensor. Strides are derived from the stored shape and the type's
// permutation; shape and dim_names are permuted before constructing the
// Tensor (see the declaration comment in variable_shape_tensor.h).
Result<std::shared_ptr<Tensor>> VariableShapeTensorType::MakeTensor(
    const std::shared_ptr<ExtensionScalar>& scalar) {
  const auto& tensor_scalar = internal::checked_cast<const StructScalar&>(*scalar->value);
  const auto& ext_type =
      internal::checked_cast<const VariableShapeTensorType&>(*scalar->type);

  if (!tensor_scalar.is_valid) {
    return Status::Invalid("Cannot convert null scalar to Tensor.");
  }
  // Field 0 holds the flattened values, field 1 the per-element shape.
  ARROW_ASSIGN_OR_RAISE(const auto data_scalar, tensor_scalar.field(0));
  ARROW_ASSIGN_OR_RAISE(const auto shape_scalar, tensor_scalar.field(1));
  const auto data_array =
      internal::checked_pointer_cast<BaseListScalar>(data_scalar)->value;
  const auto shape_array = internal::checked_pointer_cast<Int32Array>(
      internal::checked_pointer_cast<FixedSizeListScalar>(shape_scalar)->value);

  const auto& value_type =
      internal::checked_cast<const FixedWidthType&>(*ext_type.value_type());

  if (data_array->null_count() > 0) {
    return Status::Invalid("Cannot convert data with nulls to Tensor.");
  }

  // Substitute the identity permutation when none is stored on the type.
  auto permutation = ext_type.permutation();
  if (permutation.empty()) {
    permutation.resize(ext_type.ndim());
    std::iota(permutation.begin(), permutation.end(), 0);
  }

  if (shape_array->length() != ext_type.ndim()) {
    return Status::Invalid("Expected shape array of length ", ext_type.ndim(), ", got ",
                           shape_array->length());
  }
  // Widen the int32 shape entries to int64, rejecting negative extents.
  std::vector<int64_t> shape;
  shape.reserve(ext_type.ndim());
  for (int64_t j = 0; j < static_cast<int64_t>(ext_type.ndim()); ++j) {
    const auto size_value = shape_array->Value(j);
    if (size_value < 0) {
      return Status::Invalid("shape must have non-negative values");
    }
    shape.push_back(size_value);
  }

  std::vector<std::string> dim_names = ext_type.dim_names();
  if (!dim_names.empty()) {
    internal::Permute<std::string>(permutation, &dim_names);
  }

  // Strides are computed from the unpermuted shape; the shape itself is
  // permuted afterwards so both agree with the output Tensor's layout.
  ARROW_ASSIGN_OR_RAISE(
      auto strides, internal::ComputeStrides(ext_type.value_type(), shape, permutation));
  internal::Permute<int64_t>(permutation, &shape);

  ARROW_ASSIGN_OR_RAISE(const auto buffer,
                        internal::SliceTensorBuffer(*data_array, value_type, shape));

  return Tensor::Make(ext_type.value_type(), buffer, shape, strides, dim_names);
}
+
+Result<std::shared_ptr<DataType>> VariableShapeTensorType::Make(
+ const std::shared_ptr<DataType>& value_type, int32_t ndim,
+ std::vector<int64_t> permutation, std::vector<std::string> dim_names,
+ std::vector<std::optional<int64_t>> uniform_shape) {
+ if (!is_fixed_width(*value_type)) {
+ return Status::Invalid("Cannot convert non-fixed-width values to Tensor.");
+ }
+ if (ndim < 0) {
+ return Status::Invalid("ndim must be non-negative. Got: ", ndim);
+ }
+
+ if (!dim_names.empty() && dim_names.size() != static_cast<size_t>(ndim)) {
+ return Status::Invalid("dim_names size must match ndim. Expected: ", ndim,
+ " Got: ", dim_names.size());
+ }
+ if (!uniform_shape.empty() && uniform_shape.size() !=
static_cast<size_t>(ndim)) {
+ return Status::Invalid("uniform_shape size must match ndim. Expected: ",
ndim,
+ " Got: ", uniform_shape.size());
+ }
+ if (!uniform_shape.empty()) {
+ for (const auto& v : uniform_shape) {
+ if (v.has_value() && v.value() < 0) {
+ return Status::Invalid("uniform_shape must have non-negative values");
+ }
+ }
+ }
+ if (!permutation.empty()) {
+ if (permutation.size() != static_cast<size_t>(ndim)) {
+ return Status::Invalid("permutation size must match ndim. Expected: ",
ndim,
+ " Got: ", permutation.size());
+ }
+ RETURN_NOT_OK(internal::IsPermutationValid(permutation));
+ }
+
+ return std::make_shared<VariableShapeTensorType>(
+ value_type, ndim, std::move(permutation), std::move(dim_names),
+ std::move(uniform_shape));
+}
+
+std::shared_ptr<DataType> variable_shape_tensor(
+ const std::shared_ptr<DataType>& value_type, int32_t ndim,
+ std::vector<int64_t> permutation, std::vector<std::string> dim_names,
+ std::vector<std::optional<int64_t>> uniform_shape) {
+ auto maybe_type =
+ VariableShapeTensorType::Make(value_type, ndim, std::move(permutation),
+ std::move(dim_names),
std::move(uniform_shape));
+ ARROW_CHECK_OK(maybe_type.status());
+ return maybe_type.MoveValueUnsafe();
+}
+
+} // namespace arrow::extension
diff --git a/cpp/src/arrow/extension/variable_shape_tensor.h
b/cpp/src/arrow/extension/variable_shape_tensor.h
new file mode 100644
index 0000000000..eb76a9e27a
--- /dev/null
+++ b/cpp/src/arrow/extension/variable_shape_tensor.h
@@ -0,0 +1,111 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <memory>
+#include <optional>
+#include <string>
+#include <vector>
+
+#include "arrow/extension_type.h"
+
+namespace arrow::extension {
+
/// \brief Extension array class for variable-shape tensor data.
///
/// Inherits all behavior from ExtensionArray; element semantics are defined
/// by the associated VariableShapeTensorType.
class ARROW_EXPORT VariableShapeTensorArray : public ExtensionArray {
 public:
  using ExtensionArray::ExtensionArray;
};
+
/// \brief Concrete type class for variable-shape Tensor data.
/// This is a canonical arrow extension type.
/// See: https://arrow.apache.org/docs/format/CanonicalExtensions.html
class ARROW_EXPORT VariableShapeTensorType : public ExtensionType {
 public:
  /// Storage layout is struct<data: list<value_type>,
  /// shape: fixed_size_list<int32>[ndim]>; optional metadata vectors must be
  /// empty or ndim long (validated by Make(), not by this constructor).
  VariableShapeTensorType(const std::shared_ptr<DataType>& value_type, int32_t ndim,
                          std::vector<int64_t> permutation = {},
                          std::vector<std::string> dim_names = {},
                          std::vector<std::optional<int64_t>> uniform_shape = {})
      : ExtensionType(struct_({::arrow::field("data", list(value_type)),
                               ::arrow::field("shape", fixed_size_list(int32(), ndim))})),
        value_type_(value_type),
        ndim_(ndim),
        permutation_(std::move(permutation)),
        dim_names_(std::move(dim_names)),
        uniform_shape_(std::move(uniform_shape)) {}

  std::string extension_name() const override { return "arrow.variable_shape_tensor"; }
  std::string ToString(bool show_metadata = false) const override;

  /// Number of dimensions of tensor elements
  int32_t ndim() const { return ndim_; }

  /// Value type of tensor elements
  const std::shared_ptr<DataType>& value_type() const { return value_type_; }

  /// Permutation mapping from logical to physical memory layout of tensor
  /// elements; empty means identity (row-major) layout.
  const std::vector<int64_t>& permutation() const { return permutation_; }

  /// Dimension names of tensor elements. Dimensions are ordered physically.
  const std::vector<std::string>& dim_names() const { return dim_names_; }

  /// Shape of uniform dimensions; std::nullopt marks a non-uniform dimension.
  const std::vector<std::optional<int64_t>>& uniform_shape() const {
    return uniform_shape_;
  }

  bool ExtensionEquals(const ExtensionType& other) const override;

  /// Serialize permutation/dim_names/uniform_shape metadata to JSON.
  std::string Serialize() const override;

  /// Reconstruct the type from storage type and serialized JSON metadata.
  Result<std::shared_ptr<DataType>> Deserialize(
      std::shared_ptr<DataType> storage_type,
      const std::string& serialized_data) const override;

  /// Create a VariableShapeTensorArray from ArrayData
  std::shared_ptr<Array> MakeArray(std::shared_ptr<ArrayData> data) const override;

  /// \brief Convert an ExtensionScalar to a Tensor
  ///
  /// This method will return a Tensor from ExtensionScalar with strides derived
  /// from shape and permutation stored. Shape and dim_names will be permuted
  /// according to permutation stored in the VariableShapeTensorType.
  static Result<std::shared_ptr<Tensor>> MakeTensor(
      const std::shared_ptr<ExtensionScalar>&);

  /// \brief Create a VariableShapeTensorType instance
  static Result<std::shared_ptr<DataType>> Make(
      const std::shared_ptr<DataType>& value_type, int32_t ndim,
      std::vector<int64_t> permutation = {}, std::vector<std::string> dim_names = {},
      std::vector<std::optional<int64_t>> uniform_shape = {});

 private:
  std::shared_ptr<DataType> value_type_;
  int32_t ndim_;
  std::vector<int64_t> permutation_;
  std::vector<std::string> dim_names_;
  std::vector<std::optional<int64_t>> uniform_shape_;
};
+
+/// \brief Return a VariableShapeTensorType instance.
+ARROW_EXPORT std::shared_ptr<DataType> variable_shape_tensor(
+ const std::shared_ptr<DataType>& value_type, int32_t ndim,
+ std::vector<int64_t> permutation = {}, std::vector<std::string> dim_names
= {},
+ std::vector<std::optional<int64_t>> uniform_shape = {});
+
+} // namespace arrow::extension
diff --git a/cpp/src/arrow/extension_type.cc b/cpp/src/arrow/extension_type.cc
index 555ffe0156..ce88c95174 100644
--- a/cpp/src/arrow/extension_type.cc
+++ b/cpp/src/arrow/extension_type.cc
@@ -31,6 +31,7 @@
#ifdef ARROW_JSON
# include "arrow/extension/fixed_shape_tensor.h"
# include "arrow/extension/opaque.h"
+# include "arrow/extension/variable_shape_tensor.h"
#endif
#include "arrow/extension/json.h"
#include "arrow/extension/uuid.h"
@@ -155,6 +156,7 @@ static void CreateGlobalRegistry() {
#ifdef ARROW_JSON
ext_types.push_back(extension::fixed_shape_tensor(int64(), {}));
ext_types.push_back(extension::opaque(null(), "", ""));
+ ext_types.push_back(extension::variable_shape_tensor(int64(), 0));
#endif
for (const auto& ext_type : ext_types) {
diff --git a/cpp/src/arrow/extension_type_test.cc
b/cpp/src/arrow/extension_type_test.cc
index 23c1ff731d..0b256f1b45 100644
--- a/cpp/src/arrow/extension_type_test.cc
+++ b/cpp/src/arrow/extension_type_test.cc
@@ -40,10 +40,10 @@
#include "arrow/util/key_value_metadata.h"
#include "arrow/util/logging_internal.h"
-namespace arrow {
-
using arrow::ipc::test::RoundtripBatch;
+namespace arrow {
+
class Parametric1Array : public ExtensionArray {
public:
using ExtensionArray::ExtensionArray;
diff --git a/docs/source/format/CanonicalExtensions.rst
b/docs/source/format/CanonicalExtensions.rst
index 41b94aa0a8..5de0da8354 100644
--- a/docs/source/format/CanonicalExtensions.rst
+++ b/docs/source/format/CanonicalExtensions.rst
@@ -248,8 +248,8 @@ Variable shape tensor
This means the logical tensor has names [z, x, y] and shape [30, 10, 20].
.. note::
- Values inside each **data** tensor element are stored in
row-major/C-contiguous
- order according to the corresponding **shape**.
+ Elements in a variable shape tensor extension array are stored
+ in row-major/C-contiguous order.
.. _json_extension:
diff --git a/python/pyarrow/tests/test_extension_type.py
b/python/pyarrow/tests/test_extension_type.py
index ebac37e862..c947b06e0e 100644
--- a/python/pyarrow/tests/test_extension_type.py
+++ b/python/pyarrow/tests/test_extension_type.py
@@ -1482,6 +1482,17 @@ def test_tensor_class_methods(np_type_str):
assert result.to_tensor().shape == (1, 3, 2, 2)
assert result.to_tensor().strides == (12 * bw, 1 * bw, 6 * bw, 2 * bw)
+ tensor_type = pa.fixed_shape_tensor(arrow_type, [2, 2, 3], permutation=[2,
1, 0])
+ result = pa.ExtensionArray.from_storage(tensor_type, storage)
+ expected = as_strided(flat_arr, shape=(1, 3, 2, 2),
+ strides=(bw * 12, bw, bw * 3, bw * 6))
+ np.testing.assert_array_equal(result.to_numpy_ndarray(), expected)
+
+ assert result.type.permutation == [2, 1, 0]
+ assert result.type.shape == [2, 2, 3]
+ assert result.to_tensor().shape == (1, 3, 2, 2)
+ assert result.to_tensor().strides == (12 * bw, 1 * bw, 3 * bw, 6 * bw)
+
@pytest.mark.numpy
@pytest.mark.parametrize("np_type_str", ("int8", "int64", "float32"))