This is an automated email from the ASF dual-hosted git repository.

jorisvandenbossche pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/main by this push:
     new aff876a572 GH-34796: [C++] Add FromTensor, ToTensor and strides 
methods to FixedShapeTensorArray (#34797)
aff876a572 is described below

commit aff876a572db7a732fbafb0dbbf53f078bc79403
Author: Rok Mihevc <[email protected]>
AuthorDate: Tue Apr 11 18:15:20 2023 +0200

    GH-34796: [C++] Add FromTensor, ToTensor and strides methods to 
FixedShapeTensorArray (#34797)
    
    ### Rationale for this change
    
    We want to enable converting Tensors to FixedShapeTensorArrays and the 
other way around.
    
    ### What changes are included in this PR?
    
    This adds FromTensor and ToTensor methods to FixedShapeTensorArray and a
    strides method to FixedShapeTensorType.
    
    ### Are these changes tested?
    
    Yes.
    
    ### Are there any user-facing changes?
    
    Yes: FromTensor, ToTensor and strides are new user-facing methods.
    * Closes: #34796
    
    Authored-by: Rok Mihevc <[email protected]>
    Signed-off-by: Joris Van den Bossche <[email protected]>
---
 cpp/src/arrow/extension/fixed_shape_tensor.cc      | 182 +++++++++++++++++
 cpp/src/arrow/extension/fixed_shape_tensor.h       |  26 +++
 cpp/src/arrow/extension/fixed_shape_tensor_test.cc | 222 +++++++++++++++++++++
 3 files changed, 430 insertions(+)

diff --git a/cpp/src/arrow/extension/fixed_shape_tensor.cc 
b/cpp/src/arrow/extension/fixed_shape_tensor.cc
index 8b0ed43df5..1debac0e70 100644
--- a/cpp/src/arrow/extension/fixed_shape_tensor.cc
+++ b/cpp/src/arrow/extension/fixed_shape_tensor.cc
@@ -23,6 +23,7 @@
 #include "arrow/array/array_nested.h"
 #include "arrow/array/array_primitive.h"
 #include "arrow/json/rapidjson_defs.h"  // IWYU pragma: keep
+#include "arrow/tensor.h"
 #include "arrow/util/int_util_overflow.h"
 #include "arrow/util/logging.h"
 #include "arrow/util/sort.h"
@@ -33,8 +34,52 @@
 namespace rj = arrow::rapidjson;
 
 namespace arrow {
+
 namespace extension {
 
+namespace {
+
+/// Compute byte strides for a tensor with the given `shape` and element `type`.
+///
+/// With an empty `permutation` this delegates to internal::ComputeRowMajorStrides.
+/// Otherwise the strides are built by walking `permutation` (outermost dimension
+/// first) and are then reordered with internal::Permute.
+///
+/// NOTE(review): index 0 is skipped in both loops, and the first stride pushed is
+/// the product of all other extents. This assumes dimension 0 is outermost and
+/// permutation[0] == 0, as ToTensor arranges. Callers passing a permutation that
+/// does not start with 0 (e.g. FixedShapeTensorType::strides() for a permuted
+/// type) presumably get strides that do not describe a dense layout. Confirm
+/// against the fixed_shape_tensor spec.
+/// NOTE(review): results are push_back'ed, so `*strides` is expected to be empty
+/// on entry.
+Status ComputeStrides(const FixedWidthType& type, const std::vector<int64_t>& shape,
+                      const std::vector<int64_t>& permutation,
+                      std::vector<int64_t>* strides) {
+  if (permutation.empty()) {
+    return internal::ComputeRowMajorStrides(type, shape, strides);
+  }
+
+  const int byte_width = type.byte_width();
+
+  // Byte size of one step along dimension 0: byte_width times every other
+  // extent, checked for int64 overflow. Stays 0 when the shape is empty or
+  // dimension 0 has no elements.
+  int64_t remaining = 0;
+  if (!shape.empty() && shape.front() > 0) {
+    remaining = byte_width;
+    for (auto i : permutation) {
+      if (i > 0) {
+        if (internal::MultiplyWithOverflow(remaining, shape[i], &remaining)) {
+          return Status::Invalid(
+              "Strides computed from shape would not fit in 64-bit integer");
+        }
+      }
+    }
+  }
+
+  // Zero-size tensor (or empty shape): every stride is byte_width. This also
+  // avoids dividing by a zero extent in the loop below.
+  if (remaining == 0) {
+    strides->assign(shape.size(), byte_width);
+    return Status::OK();
+  }
+
+  // Emit strides in permutation order: dimension 0 first, then each further
+  // stride is the previous one divided by that dimension's extent.
+  strides->push_back(remaining);
+  for (auto i : permutation) {
+    if (i > 0) {
+      remaining /= shape[i];
+      strides->push_back(remaining);
+    }
+  }
+  // NOTE(review): presumably maps the permutation-ordered strides back to
+  // logical dimension order; see internal::Permute in arrow/util/sort.h.
+  internal::Permute(permutation, strides);
+
+  return Status::OK();
+}
+
+}  // namespace
+
 bool FixedShapeTensorType::ExtensionEquals(const ExtensionType& other) const {
   if (extension_name() != other.extension_name()) {
     return false;
@@ -140,6 +185,132 @@ std::shared_ptr<Array> FixedShapeTensorType::MakeArray(
   return std::make_shared<ExtensionArray>(data);
 }
 
+/// Zero-copy conversion of a Tensor into a FixedShapeTensorArray.
+///
+/// Dimension 0 must have the largest stride; it becomes the array length. The
+/// remaining dimensions, ordered by decreasing stride, become the cell shape,
+/// and that order is recorded as the type's permutation. The tensor's data
+/// buffer is reused as the values buffer of the storage array.
+Result<std::shared_ptr<FixedShapeTensorArray>> FixedShapeTensorArray::FromTensor(
+    const std::shared_ptr<Tensor>& tensor) {
+  // permutation[0] is read below, so a 0-dimensional tensor must be rejected.
+  if (tensor->ndim() == 0) {
+    return Status::Invalid("Cannot convert a 0-dimensional tensor to an array");
+  }
+
+  // Dimension indices sorted by decreasing stride (outermost first).
+  auto permutation = internal::ArgSort(tensor->strides(), std::greater<>());
+  if (permutation[0] != 0) {
+    return Status::Invalid(
+        "Only first-major tensors can be zero-copy converted to arrays");
+  }
+  permutation.erase(permutation.begin());
+
+  std::vector<int64_t> cell_shape;
+  cell_shape.reserve(permutation.size());
+  for (auto i : permutation) {
+    cell_shape.emplace_back(tensor->shape()[i]);
+  }
+
+  std::vector<std::string> dim_names;
+  if (!tensor->dim_names().empty()) {
+    dim_names.reserve(permutation.size());
+    for (auto i : permutation) {
+      dim_names.emplace_back(tensor->dim_names()[i]);
+    }
+  }
+
+  // Dimension 0 became the array length, so shift the remaining indices down
+  // into the cell's own [0, ndim - 1) range.
+  for (int64_t& i : permutation) {
+    --i;
+  }
+
+  // Make() reports invalid parameters as a Status instead of failing silently.
+  ARROW_ASSIGN_OR_RAISE(
+      auto tensor_type,
+      FixedShapeTensorType::Make(tensor->type(), cell_shape, permutation, dim_names));
+  auto ext_type = internal::checked_pointer_cast<ExtensionType>(tensor_type);
+
+  std::shared_ptr<Array> value_array;
+  switch (tensor->type_id()) {
+    case Type::UINT8: {
+      value_array = std::make_shared<UInt8Array>(tensor->size(), tensor->data());
+      break;
+    }
+    case Type::INT8: {
+      value_array = std::make_shared<Int8Array>(tensor->size(), tensor->data());
+      break;
+    }
+    case Type::UINT16: {
+      value_array = std::make_shared<UInt16Array>(tensor->size(), tensor->data());
+      break;
+    }
+    case Type::INT16: {
+      value_array = std::make_shared<Int16Array>(tensor->size(), tensor->data());
+      break;
+    }
+    case Type::UINT32: {
+      value_array = std::make_shared<UInt32Array>(tensor->size(), tensor->data());
+      break;
+    }
+    case Type::INT32: {
+      value_array = std::make_shared<Int32Array>(tensor->size(), tensor->data());
+      break;
+    }
+    case Type::UINT64: {
+      value_array = std::make_shared<UInt64Array>(tensor->size(), tensor->data());
+      break;
+    }
+    case Type::INT64: {
+      value_array = std::make_shared<Int64Array>(tensor->size(), tensor->data());
+      break;
+    }
+    case Type::HALF_FLOAT: {
+      value_array = std::make_shared<HalfFloatArray>(tensor->size(), tensor->data());
+      break;
+    }
+    case Type::FLOAT: {
+      value_array = std::make_shared<FloatArray>(tensor->size(), tensor->data());
+      break;
+    }
+    case Type::DOUBLE: {
+      value_array = std::make_shared<DoubleArray>(tensor->size(), tensor->data());
+      break;
+    }
+    default: {
+      return Status::NotImplemented("Unsupported tensor type: ",
+                                    tensor->type()->ToString());
+    }
+  }
+
+  // Number of values per cell. Taken from the cell shape rather than
+  // size() / shape()[0], which divides by zero when the tensor has no rows.
+  int64_t cell_size = 1;
+  for (auto dim : cell_shape) {
+    if (internal::MultiplyWithOverflow(cell_size, dim, &cell_size)) {
+      return Status::Invalid("Tensor cell size would not fit in 64-bit integer");
+    }
+  }
+  // The list size of a FixedSizeListArray is 32-bit; refuse to truncate it.
+  if (static_cast<int32_t>(cell_size) != cell_size) {
+    return Status::Invalid("Tensor cell size ", cell_size,
+                           " does not fit in a 32-bit list size");
+  }
+  ARROW_ASSIGN_OR_RAISE(
+      std::shared_ptr<Array> arr,
+      FixedSizeListArray::FromArrays(value_array, static_cast<int32_t>(cell_size)));
+  std::shared_ptr<Array> ext_arr = ExtensionType::WrapArray(ext_type, arr);
+  // NOTE(review): FixedShapeTensorType::MakeArray builds a plain ExtensionArray,
+  // so this cast presumably relies on FixedShapeTensorArray adding no members.
+  return std::reinterpret_pointer_cast<FixedShapeTensorArray>(ext_arr);
+}
+
+const Result<std::shared_ptr<Tensor>> FixedShapeTensorArray::ToTensor() const {
+  // To convert an array of n dimensional tensors to a n+1 dimensional tensor we
+  // interpret the array's length as the first dimension of the new tensor.
+
+  auto ext_arr = internal::checked_pointer_cast<FixedSizeListArray>(this->storage());
+  auto ext_type = internal::checked_pointer_cast<FixedShapeTensorType>(this->type());
+  ARROW_RETURN_IF(!is_fixed_width(*ext_arr->value_type()),
+                  Status::Invalid(ext_arr->value_type()->ToString(),
+                                  " is not valid data type for a tensor"));
+  // A Tensor has no validity bitmap, and Flatten() (used below) skips null
+  // cells, which would shift every following cell.
+  ARROW_RETURN_IF(ext_arr->null_count() > 0,
+                  Status::Invalid("Cannot convert data with nulls to Tensor"));
+
+  // An empty permutation means the cells are stored row-major. Make it explicit
+  // so that shape and dim_names below cover every cell dimension.
+  auto permutation = ext_type->permutation();
+  if (permutation.empty()) {
+    const auto ndim = static_cast<int64_t>(ext_type->shape().size());
+    for (int64_t i = 0; i < ndim; ++i) {
+      permutation.push_back(i);
+    }
+  }
+
+  std::vector<std::string> dim_names;
+  if (!ext_type->dim_names().empty()) {
+    for (auto i : permutation) {
+      dim_names.emplace_back(ext_type->dim_names()[i]);
+    }
+    // The new leading (row) dimension has no name.
+    dim_names.insert(dim_names.begin(), 1, "");
+  }
+
+  // Cell shape in permuted order. Shift the permutation up by one to make room
+  // for the new leading dimension 0.
+  std::vector<int64_t> shape;
+  for (int64_t& i : permutation) {
+    shape.emplace_back(ext_type->shape()[i]);
+    ++i;
+  }
+  shape.insert(shape.begin(), 1, this->length());
+  permutation.insert(permutation.begin(), 1, 0);
+
+  std::vector<int64_t> tensor_strides;
+  auto value_type = internal::checked_pointer_cast<FixedWidthType>(ext_arr->value_type());
+  ARROW_RETURN_NOT_OK(
+      ComputeStrides(*value_type.get(), shape, permutation, &tensor_strides));
+
+  ARROW_ASSIGN_OR_RAISE(auto flat, ext_arr->Flatten());
+  ARROW_RETURN_IF(flat->null_count() > 0,
+                  Status::Invalid("Cannot convert data with nulls to Tensor"));
+  // Flatten() may return a slice of the child values (e.g. when this array is
+  // itself sliced). Start the tensor at the slice's offset rather than at the
+  // beginning of the underlying values buffer.
+  const int64_t byte_width = value_type->byte_width();
+  auto data = SliceBuffer(flat->data()->buffers[1], flat->offset() * byte_width,
+                          flat->length() * byte_width);
+  return Tensor::Make(ext_arr->value_type(), data, shape, tensor_strides, dim_names);
+}
+
 Result<std::shared_ptr<DataType>> FixedShapeTensorType::Make(
     const std::shared_ptr<DataType>& value_type, const std::vector<int64_t>& 
shape,
     const std::vector<int64_t>& permutation, const std::vector<std::string>& 
dim_names) {
@@ -157,6 +328,17 @@ Result<std::shared_ptr<DataType>> 
FixedShapeTensorType::Make(
                                                 shape, permutation, dim_names);
 }
 
+/// Byte strides of a single cell, computed on first use and cached in strides_.
+/// Aborts via ARROW_CHECK_OK if the strides cannot be computed.
+/// NOTE(review): the cache is filled from a non-const method without any
+/// locking, so concurrent first calls on a shared type instance would race.
+const std::vector<int64_t>& FixedShapeTensorType::strides() {
+  if (!strides_.empty()) {
+    return strides_;
+  }
+  const auto fixed_width_type =
+      internal::checked_pointer_cast<FixedWidthType>(this->value_type_);
+  std::vector<int64_t> computed;
+  ARROW_CHECK_OK(
+      ComputeStrides(*fixed_width_type, this->shape(), this->permutation(), &computed));
+  strides_ = computed;
+  return strides_;
+}
+
 std::shared_ptr<DataType> fixed_shape_tensor(const std::shared_ptr<DataType>& 
value_type,
                                              const std::vector<int64_t>& shape,
                                              const std::vector<int64_t>& 
permutation,
diff --git a/cpp/src/arrow/extension/fixed_shape_tensor.h 
b/cpp/src/arrow/extension/fixed_shape_tensor.h
index 4ee2b894ee..93837f1300 100644
--- a/cpp/src/arrow/extension/fixed_shape_tensor.h
+++ b/cpp/src/arrow/extension/fixed_shape_tensor.h
@@ -23,6 +23,26 @@ namespace extension {
 class ARROW_EXPORT FixedShapeTensorArray : public ExtensionArray {
  public:
   using ExtensionArray::ExtensionArray;
+
+  /// \brief Create a FixedShapeTensorArray from a Tensor
+  ///
+  /// This method will create a FixedShapeTensorArray from a Tensor, taking its first
+  /// dimension as the number of elements in the resulting array and the remaining
+  /// dimensions as the shape of the individual tensors. The dimension order implied
+  /// by the tensor's strides is recorded as the type's permutation. The tensor's
+  /// data buffer is reused without copying.
+  ///
+  /// \param[in] tensor The Tensor to convert to a FixedShapeTensorArray
+  /// \return the new array, or Status::Invalid if dimension 0 does not come first
+  ///     when the dimensions are sorted by decreasing stride (e.g. a column-major
+  ///     tensor), or Status::NotImplemented if the value type is not supported
+  static Result<std::shared_ptr<FixedShapeTensorArray>> FromTensor(
+      const std::shared_ptr<Tensor>& tensor);
+
+  /// \brief Create a Tensor from FixedShapeTensorArray
+  ///
+  /// This method will create a Tensor from a FixedShapeTensorArray, setting its first
+  /// dimension as length equal to the FixedShapeTensorArray's length and the remaining
+  /// dimensions as the FixedShapeTensorType's shape. Shape and dim_names will be
+  /// permuted according to permutation stored in the FixedShapeTensorType metadata.
+  /// When dim_names are present, the new first dimension is given an empty name.
+  ///
+  /// \return the new Tensor, or Status::Invalid if the value type is not fixed-width
+  const Result<std::shared_ptr<Tensor>> ToTensor() const;
 };
 
 /// \brief Concrete type class for constant-size Tensor data.
@@ -51,6 +71,11 @@ class ARROW_EXPORT FixedShapeTensorType : public 
ExtensionType {
   /// Value type of tensor elements
   const std::shared_ptr<DataType> value_type() const { return value_type_; }
 
+  /// Strides of a single tensor cell: the offset in bytes between adjacent
+  /// elements along each dimension. When permutation is non-empty the strides
+  /// are computed from the permuted cell shape.
+  /// Computed on first call and cached in the type, hence non-const.
+  /// NOTE(review): for permutations whose first entry is not 0 the returned
+  /// strides do not appear to describe a dense layout; confirm against the spec.
+  const std::vector<int64_t>& strides();
+
   /// Permutation mapping from logical to physical memory layout of tensor 
elements
   const std::vector<int64_t>& permutation() const { return permutation_; }
 
@@ -78,6 +103,7 @@ class ARROW_EXPORT FixedShapeTensorType : public 
ExtensionType {
   std::shared_ptr<DataType> storage_type_;
   std::shared_ptr<DataType> value_type_;
   std::vector<int64_t> shape_;
+  // Cache for strides(); empty until first computed.
+  std::vector<int64_t> strides_;
   std::vector<int64_t> permutation_;
   std::vector<std::string> dim_names_;
 };
diff --git a/cpp/src/arrow/extension/fixed_shape_tensor_test.cc 
b/cpp/src/arrow/extension/fixed_shape_tensor_test.cc
index 16ba9d2014..50132e25fb 100644
--- a/cpp/src/arrow/extension/fixed_shape_tensor_test.cc
+++ b/cpp/src/arrow/extension/fixed_shape_tensor_test.cc
@@ -47,17 +47,26 @@ class TestExtensionType : public ::testing::Test {
         fixed_shape_tensor(value_type_, cell_shape_, {}, dim_names_));
     values_ = {0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, 13, 14, 15, 
16, 17,
                18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 
34, 35};
+    values_partial_ = {0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11,
+                       12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23};
+    shape_partial_ = {2, 3, 4};
+    tensor_strides_ = {96, 32, 8};
+    cell_strides_ = {32, 8};
     serialized_ = R"({"shape":[3,4],"dim_names":["x","y"]})";
   }
 
  protected:
   std::vector<int64_t> shape_;
+  std::vector<int64_t> shape_partial_;
   std::vector<int64_t> cell_shape_;
   std::shared_ptr<DataType> value_type_;
   std::shared_ptr<DataType> cell_type_;
   std::vector<std::string> dim_names_;
   std::shared_ptr<ExtensionType> ext_type_;
   std::vector<int64_t> values_;
+  std::vector<int64_t> values_partial_;
+  std::vector<int64_t> tensor_strides_;
+  std::vector<int64_t> cell_strides_;
   std::string serialized_;
 };
 
@@ -100,6 +109,7 @@ TEST_F(TestExtensionType, CreateExtensionType) {
   ASSERT_EQ(exact_ext_type->ndim(), cell_shape_.size());
   ASSERT_EQ(exact_ext_type->shape(), cell_shape_);
   ASSERT_EQ(exact_ext_type->value_type(), value_type_);
+  ASSERT_EQ(exact_ext_type->strides(), cell_strides_);
   ASSERT_EQ(exact_ext_type->dim_names(), dim_names_);
 
   EXPECT_RAISES_WITH_MESSAGE_THAT(
@@ -212,4 +222,216 @@ TEST_F(TestExtensionType, RoudtripBatch) {
   CompareBatch(*batch, *read_batch2, /*compare_metadata=*/true);
 }
 
+TEST_F(TestExtensionType, CreateFromTensor) {
+  std::vector<int64_t> column_major_strides = {8, 24, 72};
+  // Dense layout with dimension 0 outermost and dimensions 1 and 2 swapped.
+  std::vector<int64_t> neither_major_strides = {96, 8, 24};
+
+  ASSERT_OK_AND_ASSIGN(auto tensor,
+                       Tensor::Make(value_type_, Buffer::Wrap(values_), shape_));
+  ASSERT_OK_AND_ASSIGN(auto ext_arr, FixedShapeTensorArray::FromTensor(tensor));
+
+  ASSERT_OK(ext_arr->ValidateFull());
+  ASSERT_TRUE(tensor->is_row_major());
+  ASSERT_EQ(tensor->strides(), tensor_strides_);
+  ASSERT_EQ(ext_arr->length(), shape_[0]);
+
+  // A column-major tensor cannot be converted without copying.
+  ASSERT_OK_AND_ASSIGN(
+      auto column_major_tensor,
+      Tensor::Make(value_type_, Buffer::Wrap(values_), shape_, column_major_strides));
+  EXPECT_RAISES_WITH_MESSAGE_THAT(
+      Invalid,
+      testing::HasSubstr(
+          "Invalid: Only first-major tensors can be zero-copy converted to arrays"),
+      FixedShapeTensorArray::FromTensor(column_major_tensor));
+
+  // A tensor that is neither row- nor column-major, but keeps dimension 0
+  // outermost, can be converted. Tensor::Make validates the strides.
+  ASSERT_OK_AND_ASSIGN(auto neither_major_tensor,
+                       Tensor::Make(value_type_, Buffer::Wrap(values_), shape_,
+                                    neither_major_strides));
+  ASSERT_OK_AND_ASSIGN(auto ext_arr_4,
+                       FixedShapeTensorArray::FromTensor(neither_major_tensor));
+  ASSERT_OK(ext_arr_4->ValidateFull());
+
+  // ToTensor rejects value types that are not fixed-width. The type's cell shape
+  // {1, 2} matches the storage list size of 2.
+  auto ext_type_5 = internal::checked_pointer_cast<FixedShapeTensorType>(
+      fixed_shape_tensor(binary(), {1, 2}));
+  auto arr = ArrayFromJSON(binary(), R"(["abc", "def"])");
+  ASSERT_OK_AND_ASSIGN(auto fsla_arr,
+                       FixedSizeListArray::FromArrays(arr, fixed_size_list(binary(), 2)));
+  auto ext_arr_5 = std::reinterpret_pointer_cast<FixedShapeTensorArray>(
+      ExtensionType::WrapArray(ext_type_5, fsla_arr));
+  EXPECT_RAISES_WITH_MESSAGE_THAT(
+      Invalid, testing::HasSubstr("binary is not valid data type for a tensor"),
+      ext_arr_5->ToTensor());
+}
+
+// Convert `tensor` with FromTensor and check that the generated extension type
+// matches `expected_ext_type` field by field and as a whole.
+void CheckFromTensorType(const std::shared_ptr<Tensor>& tensor,
+                         const std::shared_ptr<DataType>& expected_ext_type) {
+  auto ext_type = internal::checked_pointer_cast<FixedShapeTensorType>(expected_ext_type);
+  ASSERT_OK_AND_ASSIGN(auto ext_arr, FixedShapeTensorArray::FromTensor(tensor));
+  auto generated_ext_type =
+      internal::checked_cast<const FixedShapeTensorType*>(ext_arr->extension_type());
+
+  // Check that generated type is equal to the expected type
+  ASSERT_EQ(generated_ext_type->type_name(), ext_type->type_name());
+  ASSERT_EQ(generated_ext_type->shape(), ext_type->shape());
+  ASSERT_EQ(generated_ext_type->dim_names(), ext_type->dim_names());
+  ASSERT_EQ(generated_ext_type->permutation(), ext_type->permutation());
+  ASSERT_TRUE(generated_ext_type->storage_type()->Equals(*ext_type->storage_type()));
+  ASSERT_TRUE(generated_ext_type->Equals(ext_type));
+}
+
+TEST_F(TestExtensionType, TestFromTensorType) {
+  auto values = Buffer::Wrap(values_);
+  auto shapes =
+      std::vector<std::vector<int64_t>>{{3, 3, 4}, {3, 3, 4}, {3, 4, 3}, {3, 4, 3}};
+  auto strides = std::vector<std::vector<int64_t>>{
+      {96, 32, 8}, {96, 8, 24}, {96, 24, 8}, {96, 8, 32}};
+  auto tensor_dim_names = std::vector<std::vector<std::string>>{
+      {"x", "y", "z"}, {"x", "y", "z"}, {"x", "y", "z"}, {"x", "y", "z"}};
+  // Expected cell dim_names, shape and permutation: the non-leading dimensions
+  // are taken in order of decreasing stride.
+  auto dim_names = std::vector<std::vector<std::string>>{
+      {"y", "z"}, {"z", "y"}, {"y", "z"}, {"z", "y"}};
+  auto cell_shapes = std::vector<std::vector<int64_t>>{{3, 4}, {4, 3}, {4, 3}, {3, 4}};
+  auto permutations = std::vector<std::vector<int64_t>>{{0, 1}, {1, 0}, {0, 1}, {1, 0}};
+
+  for (size_t i = 0; i < shapes.size(); i++) {
+    ASSERT_OK_AND_ASSIGN(auto tensor, Tensor::Make(value_type_, values, shapes[i],
+                                                   strides[i], tensor_dim_names[i]));
+    auto ext_type =
+        fixed_shape_tensor(value_type_, cell_shapes[i], permutations[i], dim_names[i]);
+    CheckFromTensorType(tensor, ext_type);
+  }
+}
+
+// Convert `tensor` to a FixedShapeTensorArray and back, and check that the
+// resulting tensor matches the original.
+void CheckTensorRoundtrip(const std::shared_ptr<Tensor>& tensor) {
+  ASSERT_OK_AND_ASSIGN(auto ext_arr, FixedShapeTensorArray::FromTensor(tensor));
+  ASSERT_OK_AND_ASSIGN(auto tensor_from_array, ext_arr->ToTensor());
+
+  // Compare types by value rather than by pointer identity.
+  ASSERT_TRUE(tensor->type()->Equals(*tensor_from_array->type()));
+  ASSERT_EQ(tensor->shape(), tensor_from_array->shape());
+  // ToTensor names the leading (row) dimension "", so only the remaining names
+  // are compared; check the sizes first so the indexing below stays in bounds.
+  ASSERT_EQ(tensor->dim_names().size(), tensor_from_array->dim_names().size());
+  for (size_t i = 1; i < tensor->dim_names().size(); i++) {
+    ASSERT_EQ(tensor->dim_names()[i], tensor_from_array->dim_names()[i]);
+  }
+  ASSERT_EQ(tensor->strides(), tensor_from_array->strides());
+  ASSERT_TRUE(tensor->data()->Equals(*tensor_from_array->data()));
+  ASSERT_TRUE(tensor->Equals(*tensor_from_array));
+}
+
+TEST_F(TestExtensionType, RoundtripTensor) {
+  // Each case is a tensor shape, its byte strides and its dimension names. Every
+  // layout keeps dimension 0 outermost (largest stride).
+  struct Case {
+    std::vector<int64_t> shape;
+    std::vector<int64_t> strides;
+    std::vector<std::string> dim_names;
+  };
+  const std::vector<std::string> xyz = {"x", "y", "z"};
+  const std::vector<std::string> nhwc = {"N", "H", "W", "C"};
+  const std::vector<Case> cases = {
+      {{3, 3, 4}, {96, 32, 8}, xyz},
+      {{3, 4, 3}, {96, 8, 32}, xyz},
+      {{3, 4, 3}, {96, 24, 8}, xyz},
+      {{3, 3, 4}, {96, 8, 24}, xyz},
+      {{6, 2, 3}, {48, 24, 8}, xyz},
+      {{6, 3, 2}, {48, 8, 24}, xyz},
+      {{2, 3, 6}, {144, 48, 8}, xyz},
+      {{2, 6, 3}, {144, 8, 48}, xyz},
+      {{2, 3, 2, 3}, {144, 48, 24, 8}, nhwc},
+      {{2, 3, 2, 3}, {144, 8, 24, 48}, nhwc}};
+
+  const auto values = Buffer::Wrap(values_);
+  for (const auto& c : cases) {
+    ASSERT_OK_AND_ASSIGN(
+        auto tensor, Tensor::Make(value_type_, values, c.shape, c.strides, c.dim_names));
+    CheckTensorRoundtrip(tensor);
+  }
+}
+
+TEST_F(TestExtensionType, SliceTensor) {
+  ASSERT_OK_AND_ASSIGN(auto tensor,
+                       Tensor::Make(value_type_, Buffer::Wrap(values_), shape_));
+  ASSERT_OK_AND_ASSIGN(
+      auto tensor_partial,
+      Tensor::Make(value_type_, Buffer::Wrap(values_partial_), shape_partial_));
+  ASSERT_EQ(tensor->strides(), tensor_strides_);
+  ASSERT_EQ(tensor_partial->strides(), tensor_strides_);
+
+  ASSERT_OK_AND_ASSIGN(auto ext_arr, FixedShapeTensorArray::FromTensor(tensor));
+  ASSERT_OK_AND_ASSIGN(auto ext_arr_partial,
+                       FixedShapeTensorArray::FromTensor(tensor_partial));
+  ASSERT_OK(ext_arr->ValidateFull());
+  ASSERT_OK(ext_arr_partial->ValidateFull());
+
+  // The first two rows of the full array must equal an array built directly
+  // from just those two rows.
+  auto sliced = internal::checked_pointer_cast<ExtensionArray>(ext_arr->Slice(0, 2));
+  auto partial = internal::checked_pointer_cast<ExtensionArray>(ext_arr_partial);
+
+  ASSERT_TRUE(sliced->Equals(*partial));
+  ASSERT_OK(sliced->ValidateFull());
+  ASSERT_OK(partial->ValidateFull());
+  ASSERT_TRUE(sliced->storage()->Equals(*partial->storage()));
+  ASSERT_EQ(sliced->length(), partial->length());
+}
+
+TEST_F(TestExtensionType, RoudtripBatchFromTensor) {
+  // Build the array from a row-major tensor whose trailing dim_names match the
+  // fixture type, then round-trip it inside a record batch.
+  ASSERT_OK_AND_ASSIGN(auto tensor, Tensor::Make(value_type_, Buffer::Wrap(values_),
+                                                 shape_, {}, {"n", "x", "y"}));
+  ASSERT_OK_AND_ASSIGN(auto ext_arr, FixedShapeTensorArray::FromTensor(tensor));
+  // FromTensor records an explicit permutation. Replace the array's type with
+  // the fixture type (empty permutation), presumably so that it agrees with the
+  // serialized_ metadata attached to the field below.
+  ext_arr->data()->type = internal::checked_pointer_cast<FixedShapeTensorType>(ext_type_);
+
+  auto ext_field =
+      field("f0", ext_type_, true,
+            key_value_metadata({{"ARROW:extension:name", ext_type_->extension_name()},
+                                {"ARROW:extension:metadata", serialized_}}));
+  auto batch = RecordBatch::Make(schema({ext_field}), ext_arr->length(), {ext_arr});
+
+  std::shared_ptr<RecordBatch> read_batch;
+  RoundtripBatch(batch, &read_batch);
+  CompareBatch(*batch, *read_batch, /*compare_metadata=*/true);
+}
+
+TEST_F(TestExtensionType, ComputeStrides) {
+  auto ext_type_1 = internal::checked_pointer_cast<FixedShapeTensorType>(
+      fixed_shape_tensor(int64(), cell_shape_, {}, dim_names_));
+  auto ext_type_2 = internal::checked_pointer_cast<FixedShapeTensorType>(
+      fixed_shape_tensor(int64(), cell_shape_, {}, dim_names_));
+  auto ext_type_3 = internal::checked_pointer_cast<FixedShapeTensorType>(
+      fixed_shape_tensor(int32(), cell_shape_, {}, dim_names_));
+  ASSERT_TRUE(ext_type_1->Equals(*ext_type_2));
+  ASSERT_FALSE(ext_type_1->Equals(*ext_type_3));
+
+  // Row-major: an empty and an identity permutation give the same strides.
+  auto ext_type_4 = internal::checked_pointer_cast<FixedShapeTensorType>(
+      fixed_shape_tensor(int64(), {3, 4, 7}, {}, {"x", "y", "z"}));
+  ASSERT_EQ(ext_type_4->strides(), (std::vector<int64_t>{224, 56, 8}));
+  ext_type_4 = internal::checked_pointer_cast<FixedShapeTensorType>(
+      fixed_shape_tensor(int64(), {3, 4, 7}, {0, 1, 2}, {"x", "y", "z"}));
+  ASSERT_EQ(ext_type_4->strides(), (std::vector<int64_t>{224, 56, 8}));
+
+  // NOTE(review): for non-identity permutations the expected strides below pin
+  // the current ComputeStrides output rather than a dense layout. For example,
+  // {56, 224, 8} for shape {3, 4, 7} addresses up to 840 bytes, while 84 int64
+  // values occupy 672. Confirm against the fixed_shape_tensor spec.
+  auto ext_type_5 = internal::checked_pointer_cast<FixedShapeTensorType>(
+      fixed_shape_tensor(int64(), {3, 4, 7}, {1, 0, 2}));
+  ASSERT_EQ(ext_type_5->strides(), (std::vector<int64_t>{56, 224, 8}));
+  ASSERT_EQ(ext_type_5->Serialize(), R"({"shape":[3,4,7],"permutation":[1,0,2]})");
+
+  auto ext_type_6 = internal::checked_pointer_cast<FixedShapeTensorType>(
+      fixed_shape_tensor(int64(), {3, 4, 7}, {1, 2, 0}, {}));
+  ASSERT_EQ(ext_type_6->strides(), (std::vector<int64_t>{56, 8, 224}));
+  ASSERT_EQ(ext_type_6->Serialize(), R"({"shape":[3,4,7],"permutation":[1,2,0]})");
+  auto ext_type_7 = internal::checked_pointer_cast<FixedShapeTensorType>(
+      fixed_shape_tensor(int32(), {3, 4, 7}, {2, 0, 1}, {}));
+  ASSERT_EQ(ext_type_7->strides(), (std::vector<int64_t>{4, 112, 16}));
+  ASSERT_EQ(ext_type_7->Serialize(), R"({"shape":[3,4,7],"permutation":[2,0,1]})");
+}
+
 }  // namespace arrow

Reply via email to