jorisvandenbossche commented on code in PR #8510: URL: https://github.com/apache/arrow/pull/8510#discussion_r1124132904
########## cpp/src/arrow/extension/fixed_shape_tensor.cc: ########## @@ -0,0 +1,267 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/extension/fixed_shape_tensor.h" + +#include "arrow/array/array_nested.h" +#include "arrow/array/array_primitive.h" +#include "arrow/json/rapidjson_defs.h" // IWYU pragma: keep +#include "arrow/tensor.h" +#include "arrow/util/logging.h" +#include "arrow/util/sort.h" + +#include <rapidjson/document.h> +#include <rapidjson/writer.h> + +namespace rj = arrow::rapidjson; + +namespace arrow { +namespace extension { + +bool FixedShapeTensorType::ExtensionEquals(const ExtensionType& other) const { + if (extension_name() != other.extension_name()) { + return false; + } + const auto& other_ext = static_cast<const FixedShapeTensorType&>(other); + bool equals = storage_type()->Equals(other_ext.storage_type()); + equals &= shape_ == other_ext.shape(); + equals &= permutation_ == other_ext.permutation(); + equals &= dim_names_ == other_ext.dim_names(); Review Comment: Maybe not that important, but if you do this in one expression, it will short-circuit? 
``` return storage_type()->Equals(other_ext.storage_type()) && shape_ == other_ext.shape() && permutation_ == other_ext.permutation() && dim_names_ == other_ext.dim_names(); ``` ########## cpp/src/arrow/extension/fixed_shape_tensor.h: ########## @@ -0,0 +1,117 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include <numeric> +#include <sstream> + +#include "arrow/extension_type.h" + +namespace arrow { +namespace extension { + +const std::shared_ptr<DataType> GetStorageType( + const std::shared_ptr<DataType>& value_type, const std::vector<int64_t>& shape); + +const std::vector<int64_t> ComputeStrides(const std::shared_ptr<DataType>& value_type, + const std::vector<int64_t>& shape, + const std::vector<int64_t>& permutation); + +class ARROW_EXPORT FixedShapeTensorArray : public ExtensionArray { + public: + using ExtensionArray::ExtensionArray; +}; + +/// \brief Concrete type class for constant-size Tensor data. 
+class ARROW_EXPORT FixedShapeTensorType : public ExtensionType { + public: + FixedShapeTensorType(const std::shared_ptr<DataType>& value_type, + const std::vector<int64_t>& shape, + const std::vector<int64_t>& permutation = {}, + const std::vector<std::string>& dim_names = {}) + : ExtensionType(GetStorageType(value_type, shape)), + value_type_(value_type), + shape_(shape), + strides_(ComputeStrides(value_type, shape, permutation)), + permutation_(permutation), + dim_names_(dim_names) {} + + std::string extension_name() const override { return "arrow.fixed_shape_tensor"; } + + /// Number of dimensions of tensor elements + size_t ndim() { return shape_.size(); } + + /// Shape of tensor elements + const std::vector<int64_t>& shape() const { return shape_; } + + /// Strides of tensor elements. Strides state offset in bytes between adjacent + /// elements along each dimension. + const std::vector<int64_t>& strides() const { return strides_; } + + /// Permutation mapping from logical to physical memory layout of tensor elements + const std::vector<int64_t>& permutation() const { return permutation_; } + + /// Dimension names of tensor elements. Dimensions are ordered logically. 
+ const std::vector<std::string>& dim_names() const { return dim_names_; } + + bool ExtensionEquals(const ExtensionType& other) const override; + + std::string Serialize() const override; + + Result<std::shared_ptr<DataType>> Deserialize( + std::shared_ptr<DataType> storage_type, + const std::string& serialized_data) const override; + + /// Create a FixedShapeTensorArray from ArrayData + std::shared_ptr<Array> MakeArray(std::shared_ptr<ArrayData> data) const override; + + /// \brief Create a FixedShapeTensorArray from a Tensor + /// + /// This function will create a FixedShapeTensorArray from a Tensor, taking it's Review Comment: ```suggestion /// This function will create a FixedShapeTensorArray from a Tensor, taking its ``` ########## cpp/src/arrow/extension/fixed_shape_tensor_test.cc: ########## @@ -0,0 +1,293 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include "arrow/extension/fixed_shape_tensor.h" + +#include "arrow/testing/matchers.h" + +#include "arrow/array/array_nested.h" +#include "arrow/array/array_primitive.h" +#include "arrow/io/memory.h" +#include "arrow/ipc/reader.h" +#include "arrow/ipc/writer.h" +#include "arrow/record_batch.h" +#include "arrow/tensor.h" +#include "arrow/testing/gtest_util.h" +#include "arrow/util/key_value_metadata.h" + +namespace arrow { + +using FixedShapeTensorType = extension::FixedShapeTensorType; +using extension::fixed_shape_tensor; + +class TestExtensionType : public ::testing::Test { + public: + void SetUp() override { + shape_ = {3, 3, 4}; + cell_shape_ = {3, 4}; + value_type_ = int64(); + cell_type_ = fixed_size_list(value_type_, 12); + dim_names_ = {"x", "y"}; + ext_type_ = fixed_shape_tensor(value_type_, cell_shape_, {}, dim_names_); + values_ = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, + 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35}; + values_partial_ = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}; + shape_partial_ = {2, 3, 4}; + tensor_strides_ = {96, 32, 8}; + cell_strides_ = {32, 8}; + serialized_ = R"({"shape":[3,4],"dim_names":["x","y"]})"; + } + + protected: + std::vector<int64_t> shape_; + std::vector<int64_t> shape_partial_; + std::vector<int64_t> cell_shape_; + std::shared_ptr<DataType> value_type_; + std::shared_ptr<DataType> cell_type_; + std::vector<std::string> dim_names_; + std::shared_ptr<ExtensionType> ext_type_; + std::vector<int64_t> values_; + std::vector<int64_t> values_partial_; + std::vector<int64_t> tensor_strides_; + std::vector<int64_t> cell_strides_; + std::string serialized_; +}; + +auto RoundtripBatch = [](const std::shared_ptr<RecordBatch>& batch, + std::shared_ptr<RecordBatch>* out) { + ASSERT_OK_AND_ASSIGN(auto out_stream, io::BufferOutputStream::Create()); + ASSERT_OK(ipc::WriteRecordBatchStream({batch}, ipc::IpcWriteOptions::Defaults(), + 
out_stream.get())); + + ASSERT_OK_AND_ASSIGN(auto complete_ipc_stream, out_stream->Finish()); + + io::BufferReader reader(complete_ipc_stream); + std::shared_ptr<RecordBatchReader> batch_reader; + ASSERT_OK_AND_ASSIGN(batch_reader, ipc::RecordBatchStreamReader::Open(&reader)); + ASSERT_OK(batch_reader->ReadNext(out)); +}; + +TEST_F(TestExtensionType, CheckDummyRegistration) { + // We need a dummy registration at runtime to allow for IPC deserialization + auto ext_type = fixed_shape_tensor(int64(), {}); + auto registered_type = GetExtensionType(ext_type->extension_name()); + ASSERT_TRUE(registered_type->Equals(*ext_type)); Review Comment: I am not sure if we should test that the registered type is exactly this dummy, as that shouldn't be considered "public" or something to rely upon. Just testing that the result is of the correct type should be sufficient? ########## cpp/src/arrow/extension/fixed_shape_tensor.h: ########## @@ -0,0 +1,117 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include <numeric> +#include <sstream> + +#include "arrow/extension_type.h" + +namespace arrow { +namespace extension { + +const std::shared_ptr<DataType> GetStorageType( + const std::shared_ptr<DataType>& value_type, const std::vector<int64_t>& shape); + +const std::vector<int64_t> ComputeStrides(const std::shared_ptr<DataType>& value_type, + const std::vector<int64_t>& shape, + const std::vector<int64_t>& permutation); + +class ARROW_EXPORT FixedShapeTensorArray : public ExtensionArray { + public: + using ExtensionArray::ExtensionArray; +}; + +/// \brief Concrete type class for constant-size Tensor data. +class ARROW_EXPORT FixedShapeTensorType : public ExtensionType { + public: + FixedShapeTensorType(const std::shared_ptr<DataType>& value_type, + const std::vector<int64_t>& shape, + const std::vector<int64_t>& permutation = {}, + const std::vector<std::string>& dim_names = {}) + : ExtensionType(GetStorageType(value_type, shape)), + value_type_(value_type), + shape_(shape), + strides_(ComputeStrides(value_type, shape, permutation)), Review Comment: Since the strides are only needed if someone would actually ask for them (and are not part of the actual parameters of this type), shouldn't we just calculate them on demand? ########## cpp/src/arrow/extension/fixed_shape_tensor.cc: ########## @@ -0,0 +1,267 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/extension/fixed_shape_tensor.h" + +#include "arrow/array/array_nested.h" +#include "arrow/array/array_primitive.h" +#include "arrow/json/rapidjson_defs.h" // IWYU pragma: keep +#include "arrow/tensor.h" +#include "arrow/util/logging.h" +#include "arrow/util/sort.h" + +#include <rapidjson/document.h> +#include <rapidjson/writer.h> + +namespace rj = arrow::rapidjson; + +namespace arrow { +namespace extension { + +bool FixedShapeTensorType::ExtensionEquals(const ExtensionType& other) const { + if (extension_name() != other.extension_name()) { + return false; + } + const auto& other_ext = static_cast<const FixedShapeTensorType&>(other); + bool equals = storage_type()->Equals(other_ext.storage_type()); + equals &= shape_ == other_ext.shape(); + equals &= permutation_ == other_ext.permutation(); + equals &= dim_names_ == other_ext.dim_names(); + return equals; +} + +std::string FixedShapeTensorType::Serialize() const { + rj::Document document; + document.SetObject(); + rj::Document::AllocatorType& allocator = document.GetAllocator(); + + rj::Value shape(rj::kArrayType); + for (auto v : shape_) { + shape.PushBack(v, allocator); + } + document.AddMember(rj::Value("shape", allocator), shape, allocator); + + if (!permutation_.empty()) { + rj::Value permutation(rj::kArrayType); + for (auto v : permutation_) { + permutation.PushBack(v, allocator); + } + document.AddMember(rj::Value("permutation", allocator), permutation, allocator); + } + + if (!dim_names_.empty()) { + rj::Value dim_names(rj::kArrayType); + for (std::string v : dim_names_) { 
+ dim_names.PushBack(rj::Value{}.SetString(v.c_str(), allocator), allocator); + } + document.AddMember(rj::Value("dim_names", allocator), dim_names, allocator); + } + + rj::StringBuffer buffer; + rj::Writer<rj::StringBuffer> writer(buffer); + document.Accept(writer); + return buffer.GetString(); +} + +Result<std::shared_ptr<DataType>> FixedShapeTensorType::Deserialize( + std::shared_ptr<DataType> storage_type, const std::string& serialized_data) const { + if (storage_type->id() != Type::FIXED_SIZE_LIST) { + return Status::Invalid("Expected FixedSizeList storage type, got ", + storage_type->ToString()); + } + auto value_type = + internal::checked_pointer_cast<FixedSizeListType>(storage_type)->value_type(); + rj::Document document; + if (document.Parse(serialized_data.data(), serialized_data.length()).HasParseError() || + !document.HasMember("shape") || !document["shape"].IsArray()) { + return Status::Invalid("Invalid serialized JSON data: ", serialized_data); + } + + std::vector<int64_t> shape; + for (auto& x : document["shape"].GetArray()) { + shape.emplace_back(x.GetInt64()); + } + std::vector<int64_t> permutation; + if (document.HasMember("permutation")) { + for (auto& x : document["permutation"].GetArray()) { + permutation.emplace_back(x.GetInt64()); + } + if (shape.size() != permutation.size()) { + return Status::Invalid("Invalid permutation"); + } + } + std::vector<std::string> dim_names; + if (document.HasMember("dim_names")) { + for (auto& x : document["dim_names"].GetArray()) { + dim_names.emplace_back(x.GetString()); + } + if (shape.size() != dim_names.size()) { + return Status::Invalid("Invalid dim_names"); + } + } + + return fixed_shape_tensor(value_type, shape, permutation, dim_names); +} + +std::shared_ptr<Array> FixedShapeTensorType::MakeArray( + std::shared_ptr<ArrayData> data) const { + return std::make_shared<ExtensionArray>(data); +} + +Result<std::shared_ptr<Array>> FixedShapeTensorType::MakeArray( + std::shared_ptr<Tensor> tensor) const { + auto 
permutation = internal::ArgSort(tensor->strides()); + std::reverse(permutation.begin(), permutation.end()); + if (permutation[0] != 0) { + return Status::Invalid( + "Only first-major tensors can be zero-copy converted to arrays"); + } + + auto cell_shape = tensor->shape(); + cell_shape.erase(cell_shape.begin()); + if (cell_shape != shape_) { + return Status::Invalid("Expected cell shape does not match input tensor shape"); + } + + permutation.erase(permutation.begin()); + for (auto& x : permutation) { + x--; + } + + auto ext_type = + fixed_shape_tensor(tensor->type(), cell_shape, permutation, tensor->dim_names()); + + std::shared_ptr<FixedSizeListArray> arr; + std::shared_ptr<Array> value_array; + switch (tensor->type_id()) { + case Type::UINT8: { + value_array = std::make_shared<UInt8Array>(tensor->size(), tensor->data()); + break; + } + case Type::INT8: { + value_array = std::make_shared<Int8Array>(tensor->size(), tensor->data()); + break; + } + case Type::UINT16: { + value_array = std::make_shared<UInt16Array>(tensor->size(), tensor->data()); + break; + } + case Type::INT16: { + value_array = std::make_shared<Int16Array>(tensor->size(), tensor->data()); + break; + } + case Type::UINT32: { + value_array = std::make_shared<UInt32Array>(tensor->size(), tensor->data()); + break; + } + case Type::INT32: { + value_array = std::make_shared<Int32Array>(tensor->size(), tensor->data()); + break; + } + case Type::UINT64: { + value_array = std::make_shared<Int64Array>(tensor->size(), tensor->data()); + break; + } + case Type::INT64: { + value_array = std::make_shared<Int64Array>(tensor->size(), tensor->data()); + break; + } + case Type::HALF_FLOAT: { + value_array = std::make_shared<HalfFloatArray>(tensor->size(), tensor->data()); + break; + } + case Type::FLOAT: { + value_array = std::make_shared<FloatArray>(tensor->size(), tensor->data()); + break; + } + case Type::DOUBLE: { + value_array = std::make_shared<DoubleArray>(tensor->size(), tensor->data()); + break; + } + 
default: { + return Status::NotImplemented("Unsupported tensor type: ", + tensor->type()->ToString()); + } + } + arr = std::make_shared<FixedSizeListArray>(ext_type->storage_type(), tensor->shape()[0], + value_array); + auto ext_data = arr->data(); + ext_data->type = ext_type; + return MakeArray(ext_data); +} + +Result<std::shared_ptr<Tensor>> FixedShapeTensorType::ToTensor( + std::shared_ptr<Array> arr) const { + // To convert an array of n dimensional tensors to a n+1 dimensional tensor we + // interpret the array's length as the first dimension the new tensor. Further, we + // define n+1 dimensional tensor's strides by front appending a new stride to the n + // dimensional tensor's strides. + + ARROW_DCHECK_EQ(arr->null_count(), 0) << "Null values not supported in tensors."; + auto ext_arr = internal::checked_pointer_cast<FixedSizeListArray>( + internal::checked_pointer_cast<ExtensionArray>(arr)->storage()); + + std::vector<int64_t> shape = shape_; + shape.insert(shape.begin(), 1, arr->length()); + + std::vector<int64_t> tensor_strides = strides(); + tensor_strides.insert(tensor_strides.begin(), 1, arr->length() * tensor_strides[0]); + + std::shared_ptr<Buffer> buffer = ext_arr->values()->data()->buffers[1]; + return *Tensor::Make(ext_arr->value_type(), buffer, shape, tensor_strides, dim_names()); +} + +std::shared_ptr<FixedShapeTensorType> fixed_shape_tensor( + const std::shared_ptr<DataType>& value_type, const std::vector<int64_t>& shape, + const std::vector<int64_t>& permutation, const std::vector<std::string>& dim_names) { + ARROW_CHECK(is_tensor_supported(value_type->id())); Review Comment: Is it needed to restrict creating this type to just numeric types? (that's certainly the most typical use case, but it's not necessarily needed to strictly require that? And if we want to require that, that should be part of the spec, I think?) 
########## cpp/src/arrow/extension/fixed_shape_tensor_test.cc: ########## @@ -0,0 +1,293 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/extension/fixed_shape_tensor.h" + +#include "arrow/testing/matchers.h" + +#include "arrow/array/array_nested.h" +#include "arrow/array/array_primitive.h" +#include "arrow/io/memory.h" +#include "arrow/ipc/reader.h" +#include "arrow/ipc/writer.h" +#include "arrow/record_batch.h" +#include "arrow/tensor.h" +#include "arrow/testing/gtest_util.h" +#include "arrow/util/key_value_metadata.h" + +namespace arrow { + +using FixedShapeTensorType = extension::FixedShapeTensorType; +using extension::fixed_shape_tensor; + +class TestExtensionType : public ::testing::Test { + public: + void SetUp() override { + shape_ = {3, 3, 4}; + cell_shape_ = {3, 4}; + value_type_ = int64(); + cell_type_ = fixed_size_list(value_type_, 12); + dim_names_ = {"x", "y"}; + ext_type_ = fixed_shape_tensor(value_type_, cell_shape_, {}, dim_names_); + values_ = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, + 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35}; + values_partial_ = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 
22, 23}; + shape_partial_ = {2, 3, 4}; + tensor_strides_ = {96, 32, 8}; + cell_strides_ = {32, 8}; + serialized_ = R"({"shape":[3,4],"dim_names":["x","y"]})"; + } + + protected: + std::vector<int64_t> shape_; + std::vector<int64_t> shape_partial_; + std::vector<int64_t> cell_shape_; + std::shared_ptr<DataType> value_type_; + std::shared_ptr<DataType> cell_type_; + std::vector<std::string> dim_names_; + std::shared_ptr<ExtensionType> ext_type_; + std::vector<int64_t> values_; + std::vector<int64_t> values_partial_; + std::vector<int64_t> tensor_strides_; + std::vector<int64_t> cell_strides_; + std::string serialized_; +}; + +auto RoundtripBatch = [](const std::shared_ptr<RecordBatch>& batch, + std::shared_ptr<RecordBatch>* out) { + ASSERT_OK_AND_ASSIGN(auto out_stream, io::BufferOutputStream::Create()); + ASSERT_OK(ipc::WriteRecordBatchStream({batch}, ipc::IpcWriteOptions::Defaults(), + out_stream.get())); + + ASSERT_OK_AND_ASSIGN(auto complete_ipc_stream, out_stream->Finish()); + + io::BufferReader reader(complete_ipc_stream); + std::shared_ptr<RecordBatchReader> batch_reader; + ASSERT_OK_AND_ASSIGN(batch_reader, ipc::RecordBatchStreamReader::Open(&reader)); + ASSERT_OK(batch_reader->ReadNext(out)); +}; + +TEST_F(TestExtensionType, CheckDummyRegistration) { + // We need a dummy registration at runtime to allow for IPC deserialization + auto ext_type = fixed_shape_tensor(int64(), {}); + auto registered_type = GetExtensionType(ext_type->extension_name()); + ASSERT_TRUE(registered_type->Equals(*ext_type)); +} + +TEST_F(TestExtensionType, CreateExtensionType) { + auto exact_ext_type = internal::checked_pointer_cast<FixedShapeTensorType>(ext_type_); + + // Test ExtensionType methods + ASSERT_EQ(ext_type_->extension_name(), "arrow.fixed_shape_tensor"); + ASSERT_TRUE(ext_type_->Equals(*exact_ext_type)); + ASSERT_TRUE(ext_type_->storage_type()->Equals(*cell_type_)); + ASSERT_EQ(ext_type_->Serialize(), serialized_); + ASSERT_OK_AND_ASSIGN(auto ds, + 
ext_type_->Deserialize(ext_type_->storage_type(), serialized_)); + auto deserialized = std::reinterpret_pointer_cast<ExtensionType>(ds); + ASSERT_TRUE(deserialized->Equals(*ext_type_)); + + // Test FixedShapeTensorType methods + ASSERT_EQ(exact_ext_type->id(), Type::EXTENSION); + ASSERT_EQ(exact_ext_type->ndim(), cell_shape_.size()); + ASSERT_EQ(exact_ext_type->shape(), cell_shape_); + ASSERT_EQ(exact_ext_type->strides(), cell_strides_); + ASSERT_EQ(exact_ext_type->dim_names(), dim_names_); +} + +TEST_F(TestExtensionType, CreateFromArray) { + auto exact_ext_type = internal::checked_pointer_cast<FixedShapeTensorType>(ext_type_); + + std::vector<std::shared_ptr<Buffer>> buffers = {nullptr, Buffer::Wrap(values_)}; + auto arr_data = std::make_shared<ArrayData>(value_type_, values_.size(), buffers, 0, 0); + auto arr = std::make_shared<Int64Array>(arr_data); + EXPECT_OK_AND_ASSIGN(auto fsla_arr, FixedSizeListArray::FromArrays(arr, cell_type_)); + auto data = fsla_arr->data(); + data->type = ext_type_; + auto ext_arr = exact_ext_type->MakeArray(data); + ASSERT_EQ(ext_arr->length(), shape_[0]); + ASSERT_EQ(ext_arr->null_count(), 0); +} + +TEST_F(TestExtensionType, CreateFromTensor) { + std::vector<int64_t> column_major_strides = {8, 24, 72}; + std::vector<int64_t> neither_major_strides = {96, 8, 32}; + + ASSERT_OK_AND_ASSIGN(auto tensor, + Tensor::Make(value_type_, Buffer::Wrap(values_), shape_)); + + auto exact_ext_type = internal::checked_pointer_cast<FixedShapeTensorType>(ext_type_); + EXPECT_OK_AND_ASSIGN(auto ext_arr, exact_ext_type->MakeArray(tensor)); + + ASSERT_OK(ext_arr->ValidateFull()); + ASSERT_TRUE(tensor->is_row_major()); + ASSERT_EQ(tensor->strides(), tensor_strides_); + ASSERT_EQ(ext_arr->length(), shape_[0]); + + auto ext_type_2 = fixed_shape_tensor(int64(), {3, 4}, {0, 1}); + EXPECT_OK_AND_ASSIGN(auto ext_arr_2, ext_type_2->MakeArray(tensor)); + + ASSERT_OK_AND_ASSIGN( + auto column_major_tensor, + Tensor::Make(value_type_, Buffer::Wrap(values_), shape_, 
column_major_strides)); + auto ext_type_3 = fixed_shape_tensor(int64(), {3, 4}, {0, 1}); + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, + testing::HasSubstr( + "Invalid: Only first-major tensors can be zero-copy converted to arrays"), + ext_type_3->MakeArray(column_major_tensor)); + ASSERT_THAT(ext_type_3->MakeArray(column_major_tensor), Raises(StatusCode::Invalid)); + + auto neither_major_tensor = std::make_shared<Tensor>(value_type_, Buffer::Wrap(values_), + shape_, neither_major_strides); + auto ext_type_4 = fixed_shape_tensor(int64(), {3, 4}, {1, 0}); + ASSERT_OK_AND_ASSIGN(auto ext_arr_4, ext_type_4->MakeArray(neither_major_tensor)); +} + +TEST_F(TestExtensionType, RoundtripTensor) { + ASSERT_OK_AND_ASSIGN(auto tensor, + Tensor::Make(value_type_, Buffer::Wrap(values_), shape_)); + auto exact_ext_type = internal::checked_pointer_cast<FixedShapeTensorType>(ext_type_); + EXPECT_OK_AND_ASSIGN(auto ext_arr, exact_ext_type->MakeArray(tensor)); + + EXPECT_OK_AND_ASSIGN(auto tensor_from_array, exact_ext_type->ToTensor(ext_arr)); + ASSERT_EQ(tensor_from_array->shape(), tensor->shape()); + ASSERT_EQ(tensor_from_array->strides(), tensor->strides()); + ASSERT_TRUE(tensor->Equals(*tensor_from_array)); +} + +TEST_F(TestExtensionType, SliceTensor) { + ASSERT_OK_AND_ASSIGN(auto tensor, + Tensor::Make(value_type_, Buffer::Wrap(values_), shape_)); + ASSERT_OK_AND_ASSIGN( + auto tensor_partial, + Tensor::Make(value_type_, Buffer::Wrap(values_partial_), shape_partial_)); + ASSERT_EQ(tensor->strides(), tensor_strides_); + ASSERT_EQ(tensor_partial->strides(), tensor_strides_); + auto ext_type = fixed_shape_tensor(value_type_, cell_shape_, {}, dim_names_); + auto exact_ext_type = internal::checked_pointer_cast<FixedShapeTensorType>(ext_type_); + + EXPECT_OK_AND_ASSIGN(auto ext_arr, exact_ext_type->MakeArray(tensor)); + EXPECT_OK_AND_ASSIGN(auto ext_arr_partial, exact_ext_type->MakeArray(tensor_partial)); + ASSERT_OK(ext_arr->ValidateFull()); + 
ASSERT_OK(ext_arr_partial->ValidateFull()); + + auto sliced = internal::checked_pointer_cast<ExtensionArray>(ext_arr->Slice(0, 2)); + auto partial = internal::checked_pointer_cast<ExtensionArray>(ext_arr_partial); + + ASSERT_TRUE(sliced->Equals(*partial)); + ASSERT_OK(sliced->ValidateFull()); + ASSERT_OK(partial->ValidateFull()); + ASSERT_TRUE(sliced->storage()->Equals(*partial->storage())); + ASSERT_EQ(sliced->length(), partial->length()); +} + +void CheckSerializationRoundtrip(const std::shared_ptr<ExtensionType>& ext_type) { + auto serialized = ext_type->Serialize(); + ASSERT_OK_AND_ASSIGN(auto deserialized, + ext_type->Deserialize(ext_type->storage_type(), serialized)); + ASSERT_TRUE(ext_type->Equals(*deserialized)); +} + +TEST_F(TestExtensionType, MetadataSerializationRoundtrip) { + CheckSerializationRoundtrip(fixed_shape_tensor(value_type_, {}, {}, {})); + CheckSerializationRoundtrip(fixed_shape_tensor(value_type_, {0}, {}, {})); + CheckSerializationRoundtrip(fixed_shape_tensor(value_type_, {1}, {0}, {"x"})); + CheckSerializationRoundtrip( + fixed_shape_tensor(value_type_, {256, 256, 3}, {0, 1, 2}, {"H", "W", "C"})); + CheckSerializationRoundtrip( + fixed_shape_tensor(value_type_, {256, 256, 3}, {2, 0, 1}, {"C", "H", "W"})); + + auto ext_type = fixed_shape_tensor(value_type_, cell_shape_, {0, 1}, dim_names_); + CheckSerializationRoundtrip(ext_type_); + + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, testing::HasSubstr("Invalid: Expected FixedSizeList storage type"), + ext_type->Deserialize(boolean(), serialized_)); +} + +TEST_F(TestExtensionType, RoudtripBatch) { + auto exact_ext_type = internal::checked_pointer_cast<FixedShapeTensorType>(ext_type_); + + std::vector<std::shared_ptr<Buffer>> buffers = {nullptr, Buffer::Wrap(values_)}; + auto arr_data = std::make_shared<ArrayData>(value_type_, values_.size(), buffers, 0, 0); + auto arr = std::make_shared<Int64Array>(arr_data); + EXPECT_OK_AND_ASSIGN(auto fsla_arr, FixedSizeListArray::FromArrays(arr, cell_type_)); 
+ auto data = fsla_arr->data(); + data->type = ext_type_; + auto ext_arr = exact_ext_type->MakeArray(data); + + ASSERT_OK(UnregisterExtensionType(ext_type_->extension_name())); + ASSERT_OK(RegisterExtensionType(ext_type_)); + auto ext_metadata = + key_value_metadata({{"ARROW:extension:name", exact_ext_type->extension_name()}, + {"ARROW:extension:metadata", serialized_}}); + auto ext_field = field("f0", exact_ext_type, true, ext_metadata); + auto batch = RecordBatch::Make(schema({ext_field}), ext_arr->length(), {ext_arr}); + std::shared_ptr<RecordBatch> read_batch; + RoundtripBatch(batch, &read_batch); + CompareBatch(*batch, *read_batch, /*compare_metadata=*/true); +} + +TEST_F(TestExtensionType, RoudtripBatchFromTensor) { + auto exact_ext_type = internal::checked_pointer_cast<FixedShapeTensorType>(ext_type_); + ASSERT_OK_AND_ASSIGN(auto tensor, Tensor::Make(value_type_, Buffer::Wrap(values_), + shape_, {}, dim_names_)); + EXPECT_OK_AND_ASSIGN(auto ext_arr, exact_ext_type->MakeArray(tensor)); + ext_arr->data()->type = exact_ext_type; + + ASSERT_OK(UnregisterExtensionType(ext_type_->extension_name())); + ASSERT_OK(RegisterExtensionType(ext_type_)); Review Comment: What's the reason for unregistering/registering again here? ########## cpp/src/arrow/extension/fixed_shape_tensor_test.cc: ########## @@ -0,0 +1,293 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/extension/fixed_shape_tensor.h" + +#include "arrow/testing/matchers.h" + +#include "arrow/array/array_nested.h" +#include "arrow/array/array_primitive.h" +#include "arrow/io/memory.h" +#include "arrow/ipc/reader.h" +#include "arrow/ipc/writer.h" +#include "arrow/record_batch.h" +#include "arrow/tensor.h" +#include "arrow/testing/gtest_util.h" +#include "arrow/util/key_value_metadata.h" + +namespace arrow { + +using FixedShapeTensorType = extension::FixedShapeTensorType; +using extension::fixed_shape_tensor; + +class TestExtensionType : public ::testing::Test { + public: + void SetUp() override { + shape_ = {3, 3, 4}; + cell_shape_ = {3, 4}; + value_type_ = int64(); + cell_type_ = fixed_size_list(value_type_, 12); + dim_names_ = {"x", "y"}; + ext_type_ = fixed_shape_tensor(value_type_, cell_shape_, {}, dim_names_); + values_ = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, + 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35}; + values_partial_ = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}; + shape_partial_ = {2, 3, 4}; + tensor_strides_ = {96, 32, 8}; + cell_strides_ = {32, 8}; + serialized_ = R"({"shape":[3,4],"dim_names":["x","y"]})"; + } + + protected: + std::vector<int64_t> shape_; + std::vector<int64_t> shape_partial_; + std::vector<int64_t> cell_shape_; + std::shared_ptr<DataType> value_type_; + std::shared_ptr<DataType> cell_type_; + std::vector<std::string> dim_names_; + std::shared_ptr<ExtensionType> ext_type_; + std::vector<int64_t> 
values_; + std::vector<int64_t> values_partial_; + std::vector<int64_t> tensor_strides_; + std::vector<int64_t> cell_strides_; + std::string serialized_; +}; + +auto RoundtripBatch = [](const std::shared_ptr<RecordBatch>& batch, + std::shared_ptr<RecordBatch>* out) { + ASSERT_OK_AND_ASSIGN(auto out_stream, io::BufferOutputStream::Create()); + ASSERT_OK(ipc::WriteRecordBatchStream({batch}, ipc::IpcWriteOptions::Defaults(), + out_stream.get())); + + ASSERT_OK_AND_ASSIGN(auto complete_ipc_stream, out_stream->Finish()); + + io::BufferReader reader(complete_ipc_stream); + std::shared_ptr<RecordBatchReader> batch_reader; + ASSERT_OK_AND_ASSIGN(batch_reader, ipc::RecordBatchStreamReader::Open(&reader)); + ASSERT_OK(batch_reader->ReadNext(out)); +}; + +TEST_F(TestExtensionType, CheckDummyRegistration) { + // We need a dummy registration at runtime to allow for IPC deserialization + auto ext_type = fixed_shape_tensor(int64(), {}); + auto registered_type = GetExtensionType(ext_type->extension_name()); + ASSERT_TRUE(registered_type->Equals(*ext_type)); +} + +TEST_F(TestExtensionType, CreateExtensionType) { + auto exact_ext_type = internal::checked_pointer_cast<FixedShapeTensorType>(ext_type_); + + // Test ExtensionType methods + ASSERT_EQ(ext_type_->extension_name(), "arrow.fixed_shape_tensor"); + ASSERT_TRUE(ext_type_->Equals(*exact_ext_type)); + ASSERT_TRUE(ext_type_->storage_type()->Equals(*cell_type_)); + ASSERT_EQ(ext_type_->Serialize(), serialized_); + ASSERT_OK_AND_ASSIGN(auto ds, + ext_type_->Deserialize(ext_type_->storage_type(), serialized_)); + auto deserialized = std::reinterpret_pointer_cast<ExtensionType>(ds); + ASSERT_TRUE(deserialized->Equals(*ext_type_)); + + // Test FixedShapeTensorType methods + ASSERT_EQ(exact_ext_type->id(), Type::EXTENSION); + ASSERT_EQ(exact_ext_type->ndim(), cell_shape_.size()); + ASSERT_EQ(exact_ext_type->shape(), cell_shape_); + ASSERT_EQ(exact_ext_type->strides(), cell_strides_); + ASSERT_EQ(exact_ext_type->dim_names(), dim_names_); 
+} + +TEST_F(TestExtensionType, CreateFromArray) { + auto exact_ext_type = internal::checked_pointer_cast<FixedShapeTensorType>(ext_type_); + + std::vector<std::shared_ptr<Buffer>> buffers = {nullptr, Buffer::Wrap(values_)}; + auto arr_data = std::make_shared<ArrayData>(value_type_, values_.size(), buffers, 0, 0); + auto arr = std::make_shared<Int64Array>(arr_data); + EXPECT_OK_AND_ASSIGN(auto fsla_arr, FixedSizeListArray::FromArrays(arr, cell_type_)); + auto data = fsla_arr->data(); + data->type = ext_type_; + auto ext_arr = exact_ext_type->MakeArray(data); + ASSERT_EQ(ext_arr->length(), shape_[0]); + ASSERT_EQ(ext_arr->null_count(), 0); +} + +TEST_F(TestExtensionType, CreateFromTensor) { + std::vector<int64_t> column_major_strides = {8, 24, 72}; + std::vector<int64_t> neither_major_strides = {96, 8, 32}; + + ASSERT_OK_AND_ASSIGN(auto tensor, + Tensor::Make(value_type_, Buffer::Wrap(values_), shape_)); + + auto exact_ext_type = internal::checked_pointer_cast<FixedShapeTensorType>(ext_type_); + EXPECT_OK_AND_ASSIGN(auto ext_arr, exact_ext_type->MakeArray(tensor)); + + ASSERT_OK(ext_arr->ValidateFull()); + ASSERT_TRUE(tensor->is_row_major()); + ASSERT_EQ(tensor->strides(), tensor_strides_); + ASSERT_EQ(ext_arr->length(), shape_[0]); + + auto ext_type_2 = fixed_shape_tensor(int64(), {3, 4}, {0, 1}); + EXPECT_OK_AND_ASSIGN(auto ext_arr_2, ext_type_2->MakeArray(tensor)); + + ASSERT_OK_AND_ASSIGN( + auto column_major_tensor, + Tensor::Make(value_type_, Buffer::Wrap(values_), shape_, column_major_strides)); + auto ext_type_3 = fixed_shape_tensor(int64(), {3, 4}, {0, 1}); + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, + testing::HasSubstr( + "Invalid: Only first-major tensors can be zero-copy converted to arrays"), + ext_type_3->MakeArray(column_major_tensor)); + ASSERT_THAT(ext_type_3->MakeArray(column_major_tensor), Raises(StatusCode::Invalid)); + + auto neither_major_tensor = std::make_shared<Tensor>(value_type_, Buffer::Wrap(values_), + shape_, neither_major_strides); 
+ auto ext_type_4 = fixed_shape_tensor(int64(), {3, 4}, {1, 0}); + ASSERT_OK_AND_ASSIGN(auto ext_arr_4, ext_type_4->MakeArray(neither_major_tensor)); +} + +TEST_F(TestExtensionType, RoundtripTensor) { + ASSERT_OK_AND_ASSIGN(auto tensor, + Tensor::Make(value_type_, Buffer::Wrap(values_), shape_)); + auto exact_ext_type = internal::checked_pointer_cast<FixedShapeTensorType>(ext_type_); + EXPECT_OK_AND_ASSIGN(auto ext_arr, exact_ext_type->MakeArray(tensor)); + + EXPECT_OK_AND_ASSIGN(auto tensor_from_array, exact_ext_type->ToTensor(ext_arr)); + ASSERT_EQ(tensor_from_array->shape(), tensor->shape()); + ASSERT_EQ(tensor_from_array->strides(), tensor->strides()); + ASSERT_TRUE(tensor->Equals(*tensor_from_array)); +} + +TEST_F(TestExtensionType, SliceTensor) { + ASSERT_OK_AND_ASSIGN(auto tensor, + Tensor::Make(value_type_, Buffer::Wrap(values_), shape_)); + ASSERT_OK_AND_ASSIGN( + auto tensor_partial, + Tensor::Make(value_type_, Buffer::Wrap(values_partial_), shape_partial_)); + ASSERT_EQ(tensor->strides(), tensor_strides_); + ASSERT_EQ(tensor_partial->strides(), tensor_strides_); + auto ext_type = fixed_shape_tensor(value_type_, cell_shape_, {}, dim_names_); + auto exact_ext_type = internal::checked_pointer_cast<FixedShapeTensorType>(ext_type_); + + EXPECT_OK_AND_ASSIGN(auto ext_arr, exact_ext_type->MakeArray(tensor)); + EXPECT_OK_AND_ASSIGN(auto ext_arr_partial, exact_ext_type->MakeArray(tensor_partial)); + ASSERT_OK(ext_arr->ValidateFull()); + ASSERT_OK(ext_arr_partial->ValidateFull()); + + auto sliced = internal::checked_pointer_cast<ExtensionArray>(ext_arr->Slice(0, 2)); + auto partial = internal::checked_pointer_cast<ExtensionArray>(ext_arr_partial); + + ASSERT_TRUE(sliced->Equals(*partial)); + ASSERT_OK(sliced->ValidateFull()); + ASSERT_OK(partial->ValidateFull()); + ASSERT_TRUE(sliced->storage()->Equals(*partial->storage())); + ASSERT_EQ(sliced->length(), partial->length()); +} + +void CheckSerializationRoundtrip(const std::shared_ptr<ExtensionType>& ext_type) { + 
auto serialized = ext_type->Serialize(); + ASSERT_OK_AND_ASSIGN(auto deserialized, + ext_type->Deserialize(ext_type->storage_type(), serialized)); + ASSERT_TRUE(ext_type->Equals(*deserialized)); +} + +TEST_F(TestExtensionType, MetadataSerializationRoundtrip) { + CheckSerializationRoundtrip(fixed_shape_tensor(value_type_, {}, {}, {})); + CheckSerializationRoundtrip(fixed_shape_tensor(value_type_, {0}, {}, {})); + CheckSerializationRoundtrip(fixed_shape_tensor(value_type_, {1}, {0}, {"x"})); + CheckSerializationRoundtrip( + fixed_shape_tensor(value_type_, {256, 256, 3}, {0, 1, 2}, {"H", "W", "C"})); + CheckSerializationRoundtrip( + fixed_shape_tensor(value_type_, {256, 256, 3}, {2, 0, 1}, {"C", "H", "W"})); + + auto ext_type = fixed_shape_tensor(value_type_, cell_shape_, {0, 1}, dim_names_); + CheckSerializationRoundtrip(ext_type_); + + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, testing::HasSubstr("Invalid: Expected FixedSizeList storage type"), + ext_type->Deserialize(boolean(), serialized_)); +} + +TEST_F(TestExtensionType, RoudtripBatch) { + auto exact_ext_type = internal::checked_pointer_cast<FixedShapeTensorType>(ext_type_); + + std::vector<std::shared_ptr<Buffer>> buffers = {nullptr, Buffer::Wrap(values_)}; + auto arr_data = std::make_shared<ArrayData>(value_type_, values_.size(), buffers, 0, 0); + auto arr = std::make_shared<Int64Array>(arr_data); + EXPECT_OK_AND_ASSIGN(auto fsla_arr, FixedSizeListArray::FromArrays(arr, cell_type_)); + auto data = fsla_arr->data(); + data->type = ext_type_; + auto ext_arr = exact_ext_type->MakeArray(data); + + ASSERT_OK(UnregisterExtensionType(ext_type_->extension_name())); + ASSERT_OK(RegisterExtensionType(ext_type_)); + auto ext_metadata = + key_value_metadata({{"ARROW:extension:name", exact_ext_type->extension_name()}, + {"ARROW:extension:metadata", serialized_}}); + auto ext_field = field("f0", exact_ext_type, true, ext_metadata); + auto batch = RecordBatch::Make(schema({ext_field}), ext_arr->length(), {ext_arr}); + 
std::shared_ptr<RecordBatch> read_batch; + RoundtripBatch(batch, &read_batch); + CompareBatch(*batch, *read_batch, /*compare_metadata=*/true); +} + +TEST_F(TestExtensionType, RoudtripBatchFromTensor) { + auto exact_ext_type = internal::checked_pointer_cast<FixedShapeTensorType>(ext_type_); + ASSERT_OK_AND_ASSIGN(auto tensor, Tensor::Make(value_type_, Buffer::Wrap(values_), + shape_, {}, dim_names_)); + EXPECT_OK_AND_ASSIGN(auto ext_arr, exact_ext_type->MakeArray(tensor)); + ext_arr->data()->type = exact_ext_type; + + ASSERT_OK(UnregisterExtensionType(ext_type_->extension_name())); + ASSERT_OK(RegisterExtensionType(ext_type_)); + auto ext_metadata = + key_value_metadata({{"ARROW:extension:name", ext_type_->extension_name()}, + {"ARROW:extension:metadata", serialized_}}); + auto ext_field = field("f0", ext_type_, true, ext_metadata); + auto batch = RecordBatch::Make(schema({ext_field}), ext_arr->length(), {ext_arr}); + std::shared_ptr<RecordBatch> read_batch; + RoundtripBatch(batch, &read_batch); + CompareBatch(*batch, *read_batch, /*compare_metadata=*/true); +} + +TEST_F(TestExtensionType, ComputeStrides) { + auto exact_ext_type = internal::checked_pointer_cast<FixedShapeTensorType>(ext_type_); + + auto ext_type_1 = fixed_shape_tensor(int64(), cell_shape_, {}, dim_names_); + auto ext_type_2 = fixed_shape_tensor(int64(), cell_shape_, {}, dim_names_); + auto ext_type_3 = fixed_shape_tensor(int32(), cell_shape_, {}, dim_names_); + ASSERT_TRUE(ext_type_1->Equals(*ext_type_2)); + ASSERT_FALSE(ext_type_1->Equals(*ext_type_3)); + + auto ext_type_4 = fixed_shape_tensor(int64(), {3, 4, 7}, {}, {"x", "y", "z"}); + ASSERT_EQ(ext_type_4->strides(), (std::vector<int64_t>{224, 56, 8})); + ext_type_4 = fixed_shape_tensor(int64(), {3, 4, 7}, {0, 1, 2}, {"x", "y", "z"}); + ASSERT_EQ(ext_type_4->strides(), (std::vector<int64_t>{224, 56, 8})); + + auto ext_type_5 = fixed_shape_tensor(int64(), {3, 4, 7}, {1, 0, 2}); + ASSERT_EQ(ext_type_5->strides(), (std::vector<int64_t>{56, 224, 
8})); + ASSERT_EQ(ext_type_5->Serialize(), R"({"shape":[3,4,7],"permutation":[1,0,2]})"); + + auto ext_type_6 = fixed_shape_tensor(int64(), {3, 4, 7}, {1, 2, 0}, {}); + ASSERT_EQ(ext_type_6->strides(), (std::vector<int64_t>{56, 8, 224})); + ASSERT_EQ(ext_type_6->Serialize(), R"({"shape":[3,4,7],"permutation":[1,2,0]})"); + + auto ext_type_7 = fixed_shape_tensor(int64(), {3, 4, 7}, {2, 0, 1}, {}); Review Comment: ```suggestion auto ext_type_7 = fixed_shape_tensor(int32(), {3, 4, 7}, {2, 0, 1}, {}); ``` Maybe use a different bitwidth in one of them, since the strides are documented to be in bytes and not elements ########## cpp/src/arrow/extension/fixed_shape_tensor.h: ########## @@ -0,0 +1,117 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include <numeric> +#include <sstream> + +#include "arrow/extension_type.h" + +namespace arrow { +namespace extension { + +const std::shared_ptr<DataType> GetStorageType( + const std::shared_ptr<DataType>& value_type, const std::vector<int64_t>& shape); + +const std::vector<int64_t> ComputeStrides(const std::shared_ptr<DataType>& value_type, + const std::vector<int64_t>& shape, + const std::vector<int64_t>& permutation); + +class ARROW_EXPORT FixedShapeTensorArray : public ExtensionArray { + public: + using ExtensionArray::ExtensionArray; +}; + +/// \brief Concrete type class for constant-size Tensor data. +class ARROW_EXPORT FixedShapeTensorType : public ExtensionType { + public: + FixedShapeTensorType(const std::shared_ptr<DataType>& value_type, + const std::vector<int64_t>& shape, + const std::vector<int64_t>& permutation = {}, + const std::vector<std::string>& dim_names = {}) + : ExtensionType(GetStorageType(value_type, shape)), + value_type_(value_type), + shape_(shape), + strides_(ComputeStrides(value_type, shape, permutation)), + permutation_(permutation), + dim_names_(dim_names) {} + + std::string extension_name() const override { return "arrow.fixed_shape_tensor"; } + + /// Number of dimensions of tensor elements + size_t ndim() { return shape_.size(); } + + /// Shape of tensor elements + const std::vector<int64_t>& shape() const { return shape_; } + + /// Strides of tensor elements. Strides state offset in bytes between adjacent + /// elements along each dimension. + const std::vector<int64_t>& strides() const { return strides_; } + + /// Permutation mapping from logical to physical memory layout of tensor elements + const std::vector<int64_t>& permutation() const { return permutation_; } + + /// Dimension names of tensor elements. Dimensions are ordered logically. Review Comment: ```suggestion /// Dimension names of tensor elements. Dimensions are ordered physically. 
``` That's what the spec says ########## cpp/src/arrow/extension/fixed_shape_tensor.h: ########## @@ -0,0 +1,117 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include <numeric> +#include <sstream> + +#include "arrow/extension_type.h" + +namespace arrow { +namespace extension { + +const std::shared_ptr<DataType> GetStorageType( + const std::shared_ptr<DataType>& value_type, const std::vector<int64_t>& shape); + +const std::vector<int64_t> ComputeStrides(const std::shared_ptr<DataType>& value_type, + const std::vector<int64_t>& shape, + const std::vector<int64_t>& permutation); + +class ARROW_EXPORT FixedShapeTensorArray : public ExtensionArray { + public: + using ExtensionArray::ExtensionArray; +}; + +/// \brief Concrete type class for constant-size Tensor data. 
+class ARROW_EXPORT FixedShapeTensorType : public ExtensionType { + public: + FixedShapeTensorType(const std::shared_ptr<DataType>& value_type, + const std::vector<int64_t>& shape, + const std::vector<int64_t>& permutation = {}, + const std::vector<std::string>& dim_names = {}) + : ExtensionType(GetStorageType(value_type, shape)), + value_type_(value_type), + shape_(shape), + strides_(ComputeStrides(value_type, shape, permutation)), + permutation_(permutation), + dim_names_(dim_names) {} + + std::string extension_name() const override { return "arrow.fixed_shape_tensor"; } + + /// Number of dimensions of tensor elements + size_t ndim() { return shape_.size(); } + + /// Shape of tensor elements + const std::vector<int64_t>& shape() const { return shape_; } + + /// Strides of tensor elements. Strides state offset in bytes between adjacent + /// elements along each dimension. + const std::vector<int64_t>& strides() const { return strides_; } + + /// Permutation mapping from logical to physical memory layout of tensor elements + const std::vector<int64_t>& permutation() const { return permutation_; } + + /// Dimension names of tensor elements. Dimensions are ordered logically. + const std::vector<std::string>& dim_names() const { return dim_names_; } + + bool ExtensionEquals(const ExtensionType& other) const override; + + std::string Serialize() const override; + + Result<std::shared_ptr<DataType>> Deserialize( + std::shared_ptr<DataType> storage_type, + const std::string& serialized_data) const override; + + /// Create a FixedShapeTensorArray from ArrayData + std::shared_ptr<Array> MakeArray(std::shared_ptr<ArrayData> data) const override; + + /// \brief Create a FixedShapeTensorArray from a Tensor + /// + /// This function will create a FixedShapeTensorArray from a Tensor, taking it's + /// first dimension as the "element dimension" and the remaining dimensions as the + /// "tensor dimensions". 
The tensor dimensions must match the FixedShapeTensorType's + /// element shape. This function assumes that the tensor's memory layout is + /// row-major. + /// + /// \param[in] tensor The Tensor to convert to a FixedShapeTensorArray + Result<std::shared_ptr<Array>> MakeArray(std::shared_ptr<Tensor> tensor) const; + + /// \brief Create a Tensor from FixedShapeTensorArray + /// + /// This function will create a Tensor from a FixedShapeTensorArray, setting it's Review Comment: ```suggestion /// This function will create a Tensor from a FixedShapeTensorArray, setting its ``` ########## cpp/src/arrow/extension/fixed_shape_tensor_test.cc: ########## @@ -0,0 +1,293 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include "arrow/extension/fixed_shape_tensor.h" + +#include "arrow/testing/matchers.h" + +#include "arrow/array/array_nested.h" +#include "arrow/array/array_primitive.h" +#include "arrow/io/memory.h" +#include "arrow/ipc/reader.h" +#include "arrow/ipc/writer.h" +#include "arrow/record_batch.h" +#include "arrow/tensor.h" +#include "arrow/testing/gtest_util.h" +#include "arrow/util/key_value_metadata.h" + +namespace arrow { + +using FixedShapeTensorType = extension::FixedShapeTensorType; +using extension::fixed_shape_tensor; + +class TestExtensionType : public ::testing::Test { + public: + void SetUp() override { + shape_ = {3, 3, 4}; + cell_shape_ = {3, 4}; + value_type_ = int64(); + cell_type_ = fixed_size_list(value_type_, 12); + dim_names_ = {"x", "y"}; + ext_type_ = fixed_shape_tensor(value_type_, cell_shape_, {}, dim_names_); + values_ = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, + 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35}; + values_partial_ = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}; + shape_partial_ = {2, 3, 4}; + tensor_strides_ = {96, 32, 8}; + cell_strides_ = {32, 8}; + serialized_ = R"({"shape":[3,4],"dim_names":["x","y"]})"; + } + + protected: + std::vector<int64_t> shape_; + std::vector<int64_t> shape_partial_; + std::vector<int64_t> cell_shape_; + std::shared_ptr<DataType> value_type_; + std::shared_ptr<DataType> cell_type_; + std::vector<std::string> dim_names_; + std::shared_ptr<ExtensionType> ext_type_; + std::vector<int64_t> values_; + std::vector<int64_t> values_partial_; + std::vector<int64_t> tensor_strides_; + std::vector<int64_t> cell_strides_; + std::string serialized_; +}; + +auto RoundtripBatch = [](const std::shared_ptr<RecordBatch>& batch, + std::shared_ptr<RecordBatch>* out) { + ASSERT_OK_AND_ASSIGN(auto out_stream, io::BufferOutputStream::Create()); + ASSERT_OK(ipc::WriteRecordBatchStream({batch}, ipc::IpcWriteOptions::Defaults(), + 
out_stream.get())); + + ASSERT_OK_AND_ASSIGN(auto complete_ipc_stream, out_stream->Finish()); + + io::BufferReader reader(complete_ipc_stream); + std::shared_ptr<RecordBatchReader> batch_reader; + ASSERT_OK_AND_ASSIGN(batch_reader, ipc::RecordBatchStreamReader::Open(&reader)); + ASSERT_OK(batch_reader->ReadNext(out)); +}; + +TEST_F(TestExtensionType, CheckDummyRegistration) { + // We need a dummy registration at runtime to allow for IPC deserialization + auto ext_type = fixed_shape_tensor(int64(), {}); + auto registered_type = GetExtensionType(ext_type->extension_name()); + ASSERT_TRUE(registered_type->Equals(*ext_type)); +} + +TEST_F(TestExtensionType, CreateExtensionType) { + auto exact_ext_type = internal::checked_pointer_cast<FixedShapeTensorType>(ext_type_); + + // Test ExtensionType methods + ASSERT_EQ(ext_type_->extension_name(), "arrow.fixed_shape_tensor"); + ASSERT_TRUE(ext_type_->Equals(*exact_ext_type)); + ASSERT_TRUE(ext_type_->storage_type()->Equals(*cell_type_)); + ASSERT_EQ(ext_type_->Serialize(), serialized_); + ASSERT_OK_AND_ASSIGN(auto ds, + ext_type_->Deserialize(ext_type_->storage_type(), serialized_)); + auto deserialized = std::reinterpret_pointer_cast<ExtensionType>(ds); + ASSERT_TRUE(deserialized->Equals(*ext_type_)); + + // Test FixedShapeTensorType methods + ASSERT_EQ(exact_ext_type->id(), Type::EXTENSION); + ASSERT_EQ(exact_ext_type->ndim(), cell_shape_.size()); + ASSERT_EQ(exact_ext_type->shape(), cell_shape_); + ASSERT_EQ(exact_ext_type->strides(), cell_strides_); + ASSERT_EQ(exact_ext_type->dim_names(), dim_names_); +} + +TEST_F(TestExtensionType, CreateFromArray) { + auto exact_ext_type = internal::checked_pointer_cast<FixedShapeTensorType>(ext_type_); + + std::vector<std::shared_ptr<Buffer>> buffers = {nullptr, Buffer::Wrap(values_)}; + auto arr_data = std::make_shared<ArrayData>(value_type_, values_.size(), buffers, 0, 0); + auto arr = std::make_shared<Int64Array>(arr_data); + EXPECT_OK_AND_ASSIGN(auto fsla_arr, 
FixedSizeListArray::FromArrays(arr, cell_type_)); + auto data = fsla_arr->data(); + data->type = ext_type_; + auto ext_arr = exact_ext_type->MakeArray(data); + ASSERT_EQ(ext_arr->length(), shape_[0]); + ASSERT_EQ(ext_arr->null_count(), 0); +} + +TEST_F(TestExtensionType, CreateFromTensor) { + std::vector<int64_t> column_major_strides = {8, 24, 72}; + std::vector<int64_t> neither_major_strides = {96, 8, 32}; + + ASSERT_OK_AND_ASSIGN(auto tensor, + Tensor::Make(value_type_, Buffer::Wrap(values_), shape_)); + + auto exact_ext_type = internal::checked_pointer_cast<FixedShapeTensorType>(ext_type_); + EXPECT_OK_AND_ASSIGN(auto ext_arr, exact_ext_type->MakeArray(tensor)); + + ASSERT_OK(ext_arr->ValidateFull()); + ASSERT_TRUE(tensor->is_row_major()); + ASSERT_EQ(tensor->strides(), tensor_strides_); + ASSERT_EQ(ext_arr->length(), shape_[0]); + + auto ext_type_2 = fixed_shape_tensor(int64(), {3, 4}, {0, 1}); + EXPECT_OK_AND_ASSIGN(auto ext_arr_2, ext_type_2->MakeArray(tensor)); + + ASSERT_OK_AND_ASSIGN( + auto column_major_tensor, + Tensor::Make(value_type_, Buffer::Wrap(values_), shape_, column_major_strides)); + auto ext_type_3 = fixed_shape_tensor(int64(), {3, 4}, {0, 1}); + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, + testing::HasSubstr( + "Invalid: Only first-major tensors can be zero-copy converted to arrays"), + ext_type_3->MakeArray(column_major_tensor)); + ASSERT_THAT(ext_type_3->MakeArray(column_major_tensor), Raises(StatusCode::Invalid)); + + auto neither_major_tensor = std::make_shared<Tensor>(value_type_, Buffer::Wrap(values_), + shape_, neither_major_strides); + auto ext_type_4 = fixed_shape_tensor(int64(), {3, 4}, {1, 0}); + ASSERT_OK_AND_ASSIGN(auto ext_arr_4, ext_type_4->MakeArray(neither_major_tensor)); +} + +TEST_F(TestExtensionType, RoundtripTensor) { + ASSERT_OK_AND_ASSIGN(auto tensor, + Tensor::Make(value_type_, Buffer::Wrap(values_), shape_)); + auto exact_ext_type = internal::checked_pointer_cast<FixedShapeTensorType>(ext_type_); + 
EXPECT_OK_AND_ASSIGN(auto ext_arr, exact_ext_type->MakeArray(tensor)); + + EXPECT_OK_AND_ASSIGN(auto tensor_from_array, exact_ext_type->ToTensor(ext_arr)); + ASSERT_EQ(tensor_from_array->shape(), tensor->shape()); + ASSERT_EQ(tensor_from_array->strides(), tensor->strides()); + ASSERT_TRUE(tensor->Equals(*tensor_from_array)); +} + +TEST_F(TestExtensionType, SliceTensor) { + ASSERT_OK_AND_ASSIGN(auto tensor, + Tensor::Make(value_type_, Buffer::Wrap(values_), shape_)); + ASSERT_OK_AND_ASSIGN( + auto tensor_partial, + Tensor::Make(value_type_, Buffer::Wrap(values_partial_), shape_partial_)); + ASSERT_EQ(tensor->strides(), tensor_strides_); + ASSERT_EQ(tensor_partial->strides(), tensor_strides_); + auto ext_type = fixed_shape_tensor(value_type_, cell_shape_, {}, dim_names_); + auto exact_ext_type = internal::checked_pointer_cast<FixedShapeTensorType>(ext_type_); + + EXPECT_OK_AND_ASSIGN(auto ext_arr, exact_ext_type->MakeArray(tensor)); + EXPECT_OK_AND_ASSIGN(auto ext_arr_partial, exact_ext_type->MakeArray(tensor_partial)); + ASSERT_OK(ext_arr->ValidateFull()); + ASSERT_OK(ext_arr_partial->ValidateFull()); + + auto sliced = internal::checked_pointer_cast<ExtensionArray>(ext_arr->Slice(0, 2)); + auto partial = internal::checked_pointer_cast<ExtensionArray>(ext_arr_partial); + + ASSERT_TRUE(sliced->Equals(*partial)); + ASSERT_OK(sliced->ValidateFull()); + ASSERT_OK(partial->ValidateFull()); + ASSERT_TRUE(sliced->storage()->Equals(*partial->storage())); + ASSERT_EQ(sliced->length(), partial->length()); +} + +void CheckSerializationRoundtrip(const std::shared_ptr<ExtensionType>& ext_type) { + auto serialized = ext_type->Serialize(); + ASSERT_OK_AND_ASSIGN(auto deserialized, + ext_type->Deserialize(ext_type->storage_type(), serialized)); + ASSERT_TRUE(ext_type->Equals(*deserialized)); +} + +TEST_F(TestExtensionType, MetadataSerializationRoundtrip) { + CheckSerializationRoundtrip(fixed_shape_tensor(value_type_, {}, {}, {})); + 
CheckSerializationRoundtrip(fixed_shape_tensor(value_type_, {0}, {}, {})); + CheckSerializationRoundtrip(fixed_shape_tensor(value_type_, {1}, {0}, {"x"})); + CheckSerializationRoundtrip( + fixed_shape_tensor(value_type_, {256, 256, 3}, {0, 1, 2}, {"H", "W", "C"})); + CheckSerializationRoundtrip( + fixed_shape_tensor(value_type_, {256, 256, 3}, {2, 0, 1}, {"C", "H", "W"})); + + auto ext_type = fixed_shape_tensor(value_type_, cell_shape_, {0, 1}, dim_names_); + CheckSerializationRoundtrip(ext_type_); + + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, testing::HasSubstr("Invalid: Expected FixedSizeList storage type"), + ext_type->Deserialize(boolean(), serialized_)); Review Comment: Can you also add some tests calling `Deserialize` with wrong metadata? (e.g. missing required key, mismatching length between keys, etc., i.e. the things that raise an error in the implementation) ########## cpp/src/arrow/extension/fixed_shape_tensor_test.cc: ########## @@ -0,0 +1,293 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include "arrow/extension/fixed_shape_tensor.h" + +#include "arrow/testing/matchers.h" + +#include "arrow/array/array_nested.h" +#include "arrow/array/array_primitive.h" +#include "arrow/io/memory.h" +#include "arrow/ipc/reader.h" +#include "arrow/ipc/writer.h" +#include "arrow/record_batch.h" +#include "arrow/tensor.h" +#include "arrow/testing/gtest_util.h" +#include "arrow/util/key_value_metadata.h" + +namespace arrow { + +using FixedShapeTensorType = extension::FixedShapeTensorType; +using extension::fixed_shape_tensor; + +class TestExtensionType : public ::testing::Test { + public: + void SetUp() override { + shape_ = {3, 3, 4}; + cell_shape_ = {3, 4}; + value_type_ = int64(); + cell_type_ = fixed_size_list(value_type_, 12); + dim_names_ = {"x", "y"}; + ext_type_ = fixed_shape_tensor(value_type_, cell_shape_, {}, dim_names_); + values_ = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, + 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35}; + values_partial_ = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}; + shape_partial_ = {2, 3, 4}; + tensor_strides_ = {96, 32, 8}; + cell_strides_ = {32, 8}; + serialized_ = R"({"shape":[3,4],"dim_names":["x","y"]})"; + } + + protected: + std::vector<int64_t> shape_; + std::vector<int64_t> shape_partial_; + std::vector<int64_t> cell_shape_; + std::shared_ptr<DataType> value_type_; + std::shared_ptr<DataType> cell_type_; + std::vector<std::string> dim_names_; + std::shared_ptr<ExtensionType> ext_type_; + std::vector<int64_t> values_; + std::vector<int64_t> values_partial_; + std::vector<int64_t> tensor_strides_; + std::vector<int64_t> cell_strides_; + std::string serialized_; +}; + +auto RoundtripBatch = [](const std::shared_ptr<RecordBatch>& batch, + std::shared_ptr<RecordBatch>* out) { + ASSERT_OK_AND_ASSIGN(auto out_stream, io::BufferOutputStream::Create()); + ASSERT_OK(ipc::WriteRecordBatchStream({batch}, ipc::IpcWriteOptions::Defaults(), + 
out_stream.get())); + + ASSERT_OK_AND_ASSIGN(auto complete_ipc_stream, out_stream->Finish()); + + io::BufferReader reader(complete_ipc_stream); + std::shared_ptr<RecordBatchReader> batch_reader; + ASSERT_OK_AND_ASSIGN(batch_reader, ipc::RecordBatchStreamReader::Open(&reader)); + ASSERT_OK(batch_reader->ReadNext(out)); +}; + +TEST_F(TestExtensionType, CheckDummyRegistration) { + // We need a dummy registration at runtime to allow for IPC deserialization + auto ext_type = fixed_shape_tensor(int64(), {}); + auto registered_type = GetExtensionType(ext_type->extension_name()); + ASSERT_TRUE(registered_type->Equals(*ext_type)); +} + +TEST_F(TestExtensionType, CreateExtensionType) { + auto exact_ext_type = internal::checked_pointer_cast<FixedShapeTensorType>(ext_type_); + + // Test ExtensionType methods + ASSERT_EQ(ext_type_->extension_name(), "arrow.fixed_shape_tensor"); + ASSERT_TRUE(ext_type_->Equals(*exact_ext_type)); Review Comment: ```suggestion ASSERT_TRUE(ext_type_->Equals(*exact_ext_type)); ASSERT_FALSE(ext_type_->Equals(*cell_type_)); ``` ########## cpp/src/arrow/extension/fixed_shape_tensor.h: ########## @@ -0,0 +1,117 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include <numeric> +#include <sstream> + +#include "arrow/extension_type.h" + +namespace arrow { +namespace extension { + +const std::shared_ptr<DataType> GetStorageType( + const std::shared_ptr<DataType>& value_type, const std::vector<int64_t>& shape); + +const std::vector<int64_t> ComputeStrides(const std::shared_ptr<DataType>& value_type, + const std::vector<int64_t>& shape, + const std::vector<int64_t>& permutation); + +class ARROW_EXPORT FixedShapeTensorArray : public ExtensionArray { + public: + using ExtensionArray::ExtensionArray; +}; + +/// \brief Concrete type class for constant-size Tensor data. +class ARROW_EXPORT FixedShapeTensorType : public ExtensionType { + public: + FixedShapeTensorType(const std::shared_ptr<DataType>& value_type, + const std::vector<int64_t>& shape, + const std::vector<int64_t>& permutation = {}, + const std::vector<std::string>& dim_names = {}) + : ExtensionType(GetStorageType(value_type, shape)), + value_type_(value_type), + shape_(shape), + strides_(ComputeStrides(value_type, shape, permutation)), + permutation_(permutation), + dim_names_(dim_names) {} + + std::string extension_name() const override { return "arrow.fixed_shape_tensor"; } + + /// Number of dimensions of tensor elements + size_t ndim() { return shape_.size(); } + + /// Shape of tensor elements + const std::vector<int64_t>& shape() const { return shape_; } + + /// Strides of tensor elements. Strides state offset in bytes between adjacent + /// elements along each dimension. + const std::vector<int64_t>& strides() const { return strides_; } + + /// Permutation mapping from logical to physical memory layout of tensor elements + const std::vector<int64_t>& permutation() const { return permutation_; } + + /// Dimension names of tensor elements. Dimensions are ordered logically. 
+ const std::vector<std::string>& dim_names() const { return dim_names_; } + + bool ExtensionEquals(const ExtensionType& other) const override; + + std::string Serialize() const override; + + Result<std::shared_ptr<DataType>> Deserialize( + std::shared_ptr<DataType> storage_type, + const std::string& serialized_data) const override; + + /// Create a FixedShapeTensorArray from ArrayData + std::shared_ptr<Array> MakeArray(std::shared_ptr<ArrayData> data) const override; + + /// \brief Create a FixedShapeTensorArray from a Tensor + /// + /// This function will create a FixedShapeTensorArray from a Tensor, taking its + /// first dimension as the "element dimension" and the remaining dimensions as the + /// "tensor dimensions". The tensor dimensions must match the FixedShapeTensorType's + /// element shape. This function assumes that the tensor's memory layout is + /// row-major. + /// + /// \param[in] tensor The Tensor to convert to a FixedShapeTensorArray + Result<std::shared_ptr<Array>> MakeArray(std::shared_ptr<Tensor> tensor) const; + + /// \brief Create a Tensor from FixedShapeTensorArray + /// + /// This function will create a Tensor from a FixedShapeTensorArray, setting its + /// first dimension as length equal to the FixedShapeTensorArray's length and the + /// remaining dimensions as the FixedShapeTensorType's element shape. + /// + /// \param[in] arr The FixedShapeTensorArray to convert to a Tensor + Result<std::shared_ptr<Tensor>> ToTensor(std::shared_ptr<Array> arr) const; + + private: + std::shared_ptr<DataType> storage_type_; + std::shared_ptr<DataType> value_type_; + std::vector<int64_t> shape_; + std::vector<int64_t> strides_; + std::vector<int64_t> permutation_; + std::vector<std::string> dim_names_; +}; + +/// \brief Return a FixedShapeTensorType instance. +ARROW_EXPORT std::shared_ptr<FixedShapeTensorType> fixed_shape_tensor( Review Comment: Should this return a DataType instead of FixedShapeTensorType? Our other type factory functions do so. 
########## cpp/src/arrow/extension/fixed_shape_tensor.cc: ########## @@ -0,0 +1,267 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/extension/fixed_shape_tensor.h" + +#include "arrow/array/array_nested.h" +#include "arrow/array/array_primitive.h" +#include "arrow/json/rapidjson_defs.h" // IWYU pragma: keep +#include "arrow/tensor.h" +#include "arrow/util/logging.h" +#include "arrow/util/sort.h" + +#include <rapidjson/document.h> +#include <rapidjson/writer.h> + +namespace rj = arrow::rapidjson; + +namespace arrow { +namespace extension { + +bool FixedShapeTensorType::ExtensionEquals(const ExtensionType& other) const { + if (extension_name() != other.extension_name()) { + return false; + } + const auto& other_ext = static_cast<const FixedShapeTensorType&>(other); + bool equals = storage_type()->Equals(other_ext.storage_type()); + equals &= shape_ == other_ext.shape(); + equals &= permutation_ == other_ext.permutation(); + equals &= dim_names_ == other_ext.dim_names(); + return equals; +} + +std::string FixedShapeTensorType::Serialize() const { + rj::Document document; + document.SetObject(); + rj::Document::AllocatorType& allocator = document.GetAllocator(); + + rj::Value shape(rj::kArrayType); + 
for (auto v : shape_) { + shape.PushBack(v, allocator); + } + document.AddMember(rj::Value("shape", allocator), shape, allocator); + + if (!permutation_.empty()) { + rj::Value permutation(rj::kArrayType); + for (auto v : permutation_) { + permutation.PushBack(v, allocator); + } + document.AddMember(rj::Value("permutation", allocator), permutation, allocator); + } + + if (!dim_names_.empty()) { + rj::Value dim_names(rj::kArrayType); + for (std::string v : dim_names_) { + dim_names.PushBack(rj::Value{}.SetString(v.c_str(), allocator), allocator); + } + document.AddMember(rj::Value("dim_names", allocator), dim_names, allocator); + } + + rj::StringBuffer buffer; + rj::Writer<rj::StringBuffer> writer(buffer); + document.Accept(writer); + return buffer.GetString(); +} + +Result<std::shared_ptr<DataType>> FixedShapeTensorType::Deserialize( + std::shared_ptr<DataType> storage_type, const std::string& serialized_data) const { + if (storage_type->id() != Type::FIXED_SIZE_LIST) { + return Status::Invalid("Expected FixedSizeList storage type, got ", + storage_type->ToString()); + } + auto value_type = + internal::checked_pointer_cast<FixedSizeListType>(storage_type)->value_type(); + rj::Document document; + if (document.Parse(serialized_data.data(), serialized_data.length()).HasParseError() || + !document.HasMember("shape") || !document["shape"].IsArray()) { + return Status::Invalid("Invalid serialized JSON data: ", serialized_data); + } + + std::vector<int64_t> shape; + for (auto& x : document["shape"].GetArray()) { + shape.emplace_back(x.GetInt64()); + } + std::vector<int64_t> permutation; + if (document.HasMember("permutation")) { + for (auto& x : document["permutation"].GetArray()) { + permutation.emplace_back(x.GetInt64()); + } + if (shape.size() != permutation.size()) { + return Status::Invalid("Invalid permutation"); + } + } + std::vector<std::string> dim_names; + if (document.HasMember("dim_names")) { + for (auto& x : document["dim_names"].GetArray()) { + 
dim_names.emplace_back(x.GetString()); + } + if (shape.size() != dim_names.size()) { + return Status::Invalid("Invalid dim_names"); + } + } + + return fixed_shape_tensor(value_type, shape, permutation, dim_names); +} + +std::shared_ptr<Array> FixedShapeTensorType::MakeArray( + std::shared_ptr<ArrayData> data) const { + return std::make_shared<ExtensionArray>(data); +} + +Result<std::shared_ptr<Array>> FixedShapeTensorType::MakeArray( + std::shared_ptr<Tensor> tensor) const { + auto permutation = internal::ArgSort(tensor->strides()); + std::reverse(permutation.begin(), permutation.end()); + if (permutation[0] != 0) { + return Status::Invalid( + "Only first-major tensors can be zero-copy converted to arrays"); + } + + auto cell_shape = tensor->shape(); + cell_shape.erase(cell_shape.begin()); + if (cell_shape != shape_) { + return Status::Invalid("Expected cell shape does not match input tensor shape"); + } + + permutation.erase(permutation.begin()); + for (auto& x : permutation) { + x--; + } + + auto ext_type = + fixed_shape_tensor(tensor->type(), cell_shape, permutation, tensor->dim_names()); + + std::shared_ptr<FixedSizeListArray> arr; + std::shared_ptr<Array> value_array; + switch (tensor->type_id()) { + case Type::UINT8: { + value_array = std::make_shared<UInt8Array>(tensor->size(), tensor->data()); + break; + } + case Type::INT8: { + value_array = std::make_shared<Int8Array>(tensor->size(), tensor->data()); + break; + } + case Type::UINT16: { + value_array = std::make_shared<UInt16Array>(tensor->size(), tensor->data()); + break; + } + case Type::INT16: { + value_array = std::make_shared<Int16Array>(tensor->size(), tensor->data()); + break; + } + case Type::UINT32: { + value_array = std::make_shared<UInt32Array>(tensor->size(), tensor->data()); + break; + } + case Type::INT32: { + value_array = std::make_shared<Int32Array>(tensor->size(), tensor->data()); + break; + } + case Type::UINT64: { + value_array = std::make_shared<Int64Array>(tensor->size(), 
tensor->data()); + break; + } + case Type::INT64: { + value_array = std::make_shared<Int64Array>(tensor->size(), tensor->data()); + break; + } + case Type::HALF_FLOAT: { + value_array = std::make_shared<HalfFloatArray>(tensor->size(), tensor->data()); + break; + } + case Type::FLOAT: { + value_array = std::make_shared<FloatArray>(tensor->size(), tensor->data()); + break; + } + case Type::DOUBLE: { + value_array = std::make_shared<DoubleArray>(tensor->size(), tensor->data()); + break; + } + default: { + return Status::NotImplemented("Unsupported tensor type: ", + tensor->type()->ToString()); + } + } + arr = std::make_shared<FixedSizeListArray>(ext_type->storage_type(), tensor->shape()[0], + value_array); + auto ext_data = arr->data(); + ext_data->type = ext_type; + return MakeArray(ext_data); +} + +Result<std::shared_ptr<Tensor>> FixedShapeTensorType::ToTensor( + std::shared_ptr<Array> arr) const { + // To convert an array of n dimensional tensors to a n+1 dimensional tensor we + // interpret the array's length as the first dimension of the new tensor. Further, we + // define n+1 dimensional tensor's strides by front appending a new stride to the n + // dimensional tensor's strides. 
+ + ARROW_DCHECK_EQ(arr->null_count(), 0) << "Null values not supported in tensors."; + auto ext_arr = internal::checked_pointer_cast<FixedSizeListArray>( + internal::checked_pointer_cast<ExtensionArray>(arr)->storage()); + + std::vector<int64_t> shape = shape_; + shape.insert(shape.begin(), 1, arr->length()); + + std::vector<int64_t> tensor_strides = strides(); + tensor_strides.insert(tensor_strides.begin(), 1, arr->length() * tensor_strides[0]); + + std::shared_ptr<Buffer> buffer = ext_arr->values()->data()->buffers[1]; + return *Tensor::Make(ext_arr->value_type(), buffer, shape, tensor_strides, dim_names()); +} + +std::shared_ptr<FixedShapeTensorType> fixed_shape_tensor( + const std::shared_ptr<DataType>& value_type, const std::vector<int64_t>& shape, + const std::vector<int64_t>& permutation, const std::vector<std::string>& dim_names) { + ARROW_CHECK(is_tensor_supported(value_type->id())); + + if (!permutation.empty()) { + ARROW_CHECK_EQ(shape.size(), permutation.size()) + << "permutation.size() == " << permutation.size() + << " must be empty or have the same length as shape.size() " << shape.size(); + } + if (!dim_names.empty()) { + ARROW_CHECK_EQ(shape.size(), dim_names.size()) Review Comment: Also check that all elements are strings? ########## cpp/src/arrow/extension/fixed_shape_tensor.cc: ########## @@ -0,0 +1,267 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/extension/fixed_shape_tensor.h" + +#include "arrow/array/array_nested.h" +#include "arrow/array/array_primitive.h" +#include "arrow/json/rapidjson_defs.h" // IWYU pragma: keep +#include "arrow/tensor.h" +#include "arrow/util/logging.h" +#include "arrow/util/sort.h" + +#include <rapidjson/document.h> +#include <rapidjson/writer.h> + +namespace rj = arrow::rapidjson; + +namespace arrow { +namespace extension { + +bool FixedShapeTensorType::ExtensionEquals(const ExtensionType& other) const { + if (extension_name() != other.extension_name()) { + return false; + } + const auto& other_ext = static_cast<const FixedShapeTensorType&>(other); + bool equals = storage_type()->Equals(other_ext.storage_type()); + equals &= shape_ == other_ext.shape(); + equals &= permutation_ == other_ext.permutation(); + equals &= dim_names_ == other_ext.dim_names(); + return equals; +} + +std::string FixedShapeTensorType::Serialize() const { + rj::Document document; + document.SetObject(); + rj::Document::AllocatorType& allocator = document.GetAllocator(); + + rj::Value shape(rj::kArrayType); + for (auto v : shape_) { + shape.PushBack(v, allocator); + } + document.AddMember(rj::Value("shape", allocator), shape, allocator); + + if (!permutation_.empty()) { + rj::Value permutation(rj::kArrayType); + for (auto v : permutation_) { + permutation.PushBack(v, allocator); + } + document.AddMember(rj::Value("permutation", allocator), permutation, allocator); + } + + if (!dim_names_.empty()) { + rj::Value dim_names(rj::kArrayType); + for (std::string v : dim_names_) { 
+ dim_names.PushBack(rj::Value{}.SetString(v.c_str(), allocator), allocator); + } + document.AddMember(rj::Value("dim_names", allocator), dim_names, allocator); + } + + rj::StringBuffer buffer; + rj::Writer<rj::StringBuffer> writer(buffer); + document.Accept(writer); + return buffer.GetString(); +} + +Result<std::shared_ptr<DataType>> FixedShapeTensorType::Deserialize( + std::shared_ptr<DataType> storage_type, const std::string& serialized_data) const { + if (storage_type->id() != Type::FIXED_SIZE_LIST) { + return Status::Invalid("Expected FixedSizeList storage type, got ", + storage_type->ToString()); + } + auto value_type = + internal::checked_pointer_cast<FixedSizeListType>(storage_type)->value_type(); + rj::Document document; + if (document.Parse(serialized_data.data(), serialized_data.length()).HasParseError() || + !document.HasMember("shape") || !document["shape"].IsArray()) { + return Status::Invalid("Invalid serialized JSON data: ", serialized_data); + } + + std::vector<int64_t> shape; + for (auto& x : document["shape"].GetArray()) { + shape.emplace_back(x.GetInt64()); + } + std::vector<int64_t> permutation; + if (document.HasMember("permutation")) { + for (auto& x : document["permutation"].GetArray()) { + permutation.emplace_back(x.GetInt64()); + } + if (shape.size() != permutation.size()) { + return Status::Invalid("Invalid permutation"); + } + } + std::vector<std::string> dim_names; + if (document.HasMember("dim_names")) { + for (auto& x : document["dim_names"].GetArray()) { + dim_names.emplace_back(x.GetString()); + } + if (shape.size() != dim_names.size()) { + return Status::Invalid("Invalid dim_names"); + } + } + + return fixed_shape_tensor(value_type, shape, permutation, dim_names); +} + +std::shared_ptr<Array> FixedShapeTensorType::MakeArray( + std::shared_ptr<ArrayData> data) const { + return std::make_shared<ExtensionArray>(data); +} + +Result<std::shared_ptr<Array>> FixedShapeTensorType::MakeArray( + std::shared_ptr<Tensor> tensor) const { + auto 
permutation = internal::ArgSort(tensor->strides()); + std::reverse(permutation.begin(), permutation.end()); + if (permutation[0] != 0) { + return Status::Invalid( + "Only first-major tensors can be zero-copy converted to arrays"); + } + + auto cell_shape = tensor->shape(); + cell_shape.erase(cell_shape.begin()); + if (cell_shape != shape_) { + return Status::Invalid("Expected cell shape does not match input tensor shape"); + } + + permutation.erase(permutation.begin()); + for (auto& x : permutation) { + x--; + } + + auto ext_type = + fixed_shape_tensor(tensor->type(), cell_shape, permutation, tensor->dim_names()); + + std::shared_ptr<FixedSizeListArray> arr; + std::shared_ptr<Array> value_array; + switch (tensor->type_id()) { + case Type::UINT8: { + value_array = std::make_shared<UInt8Array>(tensor->size(), tensor->data()); + break; + } + case Type::INT8: { + value_array = std::make_shared<Int8Array>(tensor->size(), tensor->data()); + break; + } + case Type::UINT16: { + value_array = std::make_shared<UInt16Array>(tensor->size(), tensor->data()); + break; + } + case Type::INT16: { + value_array = std::make_shared<Int16Array>(tensor->size(), tensor->data()); + break; + } + case Type::UINT32: { + value_array = std::make_shared<UInt32Array>(tensor->size(), tensor->data()); + break; + } + case Type::INT32: { + value_array = std::make_shared<Int32Array>(tensor->size(), tensor->data()); + break; + } + case Type::UINT64: { + value_array = std::make_shared<Int64Array>(tensor->size(), tensor->data()); + break; + } + case Type::INT64: { + value_array = std::make_shared<Int64Array>(tensor->size(), tensor->data()); + break; + } + case Type::HALF_FLOAT: { + value_array = std::make_shared<HalfFloatArray>(tensor->size(), tensor->data()); + break; + } + case Type::FLOAT: { + value_array = std::make_shared<FloatArray>(tensor->size(), tensor->data()); + break; + } + case Type::DOUBLE: { + value_array = std::make_shared<DoubleArray>(tensor->size(), tensor->data()); + break; + } + 
default: { + return Status::NotImplemented("Unsupported tensor type: ", + tensor->type()->ToString()); + } + } + arr = std::make_shared<FixedSizeListArray>(ext_type->storage_type(), tensor->shape()[0], + value_array); + auto ext_data = arr->data(); + ext_data->type = ext_type; + return MakeArray(ext_data); +} + +Result<std::shared_ptr<Tensor>> FixedShapeTensorType::ToTensor( + std::shared_ptr<Array> arr) const { + // To convert an array of n dimensional tensors to a n+1 dimensional tensor we + // interpret the array's length as the first dimension of the new tensor. Further, we + // define n+1 dimensional tensor's strides by front appending a new stride to the n + // dimensional tensor's strides. + + ARROW_DCHECK_EQ(arr->null_count(), 0) << "Null values not supported in tensors."; + auto ext_arr = internal::checked_pointer_cast<FixedSizeListArray>( + internal::checked_pointer_cast<ExtensionArray>(arr)->storage()); + + std::vector<int64_t> shape = shape_; + shape.insert(shape.begin(), 1, arr->length()); + + std::vector<int64_t> tensor_strides = strides(); + tensor_strides.insert(tensor_strides.begin(), 1, arr->length() * tensor_strides[0]); + + std::shared_ptr<Buffer> buffer = ext_arr->values()->data()->buffers[1]; + return *Tensor::Make(ext_arr->value_type(), buffer, shape, tensor_strides, dim_names()); +} + +std::shared_ptr<FixedShapeTensorType> fixed_shape_tensor( + const std::shared_ptr<DataType>& value_type, const std::vector<int64_t>& shape, + const std::vector<int64_t>& permutation, const std::vector<std::string>& dim_names) { + ARROW_CHECK(is_tensor_supported(value_type->id())); + + if (!permutation.empty()) { + ARROW_CHECK_EQ(shape.size(), permutation.size()) + << "permutation.size() == " << permutation.size() + << " must be empty or have the same length as shape.size() " << shape.size(); + } + if (!dim_names.empty()) { + ARROW_CHECK_EQ(shape.size(), dim_names.size()) Review Comment: And are those errors already tested? 
-- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
