This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch refactor-s2 in repository https://gitbox.apache.org/repos/asf/tvm.git
commit 8bac6dd83d3e43aef0cc979599d281df497a71be Author: tqchen <[email protected]> AuthorDate: Tue Apr 22 13:29:00 2025 -0400 [FFI] Introduce NDArray with DLPack support --- ffi/CMakeLists.txt | 1 + ffi/include/tvm/ffi/any.h | 3 +- ffi/include/tvm/ffi/c_api.h | 44 +++++ ffi/include/tvm/ffi/container/ndarray.h | 328 ++++++++++++++++++++++++++++++++ ffi/include/tvm/ffi/container/shape.h | 4 +- ffi/include/tvm/ffi/object.h | 3 +- ffi/include/tvm/ffi/string.h | 3 +- ffi/src/ffi/ndarray.cc | 59 ++++++ ffi/tests/cpp/test_ndarray.cc | 111 +++++++++++ 9 files changed, 549 insertions(+), 7 deletions(-) diff --git a/ffi/CMakeLists.txt b/ffi/CMakeLists.txt index 7d07c3bd0b..9fed170014 100644 --- a/ffi/CMakeLists.txt +++ b/ffi/CMakeLists.txt @@ -61,6 +61,7 @@ add_library(tvm_ffi_objs OBJECT "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/traceback.cc" "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/object.cc" "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/function.cc" + "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/ndarray.cc" ) set_target_properties( tvm_ffi_objs PROPERTIES diff --git a/ffi/include/tvm/ffi/any.h b/ffi/include/tvm/ffi/any.h index 518342b3ea..049c07066c 100644 --- a/ffi/include/tvm/ffi/any.h +++ b/ffi/include/tvm/ffi/any.h @@ -469,8 +469,7 @@ struct AnyEqual { details::AnyUnsafe::CopyFromAnyStorageAfterCheck<const BytesObjBase*>(lhs); const BytesObjBase* rhs_str = details::AnyUnsafe::CopyFromAnyStorageAfterCheck<const BytesObjBase*>(rhs); - return Bytes::memncmp(lhs_str->data, rhs_str->data, lhs_str->size, - rhs_str->size) == 0; + return Bytes::memncmp(lhs_str->data, rhs_str->data, lhs_str->size, rhs_str->size) == 0; } return false; } diff --git a/ffi/include/tvm/ffi/c_api.h b/ffi/include/tvm/ffi/c_api.h index f0f609e4d1..6ac704070a 100644 --- a/ffi/include/tvm/ffi/c_api.h +++ b/ffi/include/tvm/ffi/c_api.h @@ -437,6 +437,50 @@ TVM_FFI_DLL int TVMFFIRegisterTypeField(int32_t type_index, const TVMFFIFieldInf */ TVM_FFI_DLL int TVMFFIRegisterTypeMethod(int32_t type_index, const TVMFFIMethodInfo* info); 
+//------------------------------------------------------------ +// Section: DLPack support APIs +//------------------------------------------------------------ +/*! + * \brief Produce a managed NDArray from a DLPack tensor. + * \param from The source DLPack tensor. + * \param require_alignment The minimum alignment required of the data + byte_offset (0 = no check). + * \param require_contiguous Boolean flag indicating if we need to check for contiguity. + * \param out The output NDArray handle. + * \return 0 when success, nonzero when failure happens + */ +TVM_FFI_DLL int TVMFFINDArrayFromDLPack(DLManagedTensor* from, int32_t require_alignment, + int32_t require_contiguous, TVMFFIObjectHandle* out); + +/*! + * \brief Produce a DLManagedTensor from the array that shares data memory with the array. + * \param from The source array. + * \param out The DLManagedTensor handle. + * \return 0 when success, nonzero when failure happens + */ +TVM_FFI_DLL int TVMFFINDArrayToDLPack(TVMFFIObjectHandle from, DLManagedTensor** out); + +/*! + * \brief Produce a managed NDArray from a versioned DLPack tensor (DLPack >= 1.0). + * \param from The source versioned DLPack tensor. + * \param require_alignment The minimum alignment required of the data + byte_offset (0 = no check). + * \param require_contiguous Boolean flag indicating if we need to check for contiguity. + * \param out The output NDArray handle. + * \return 0 when success, nonzero when failure happens + */ +TVM_FFI_DLL int TVMFFINDArrayFromDLPackVersioned(DLManagedTensorVersioned* from, + int32_t require_alignment, + int32_t require_contiguous, + TVMFFIObjectHandle* out); + +/*! + * \brief Produce a DLManagedTensorVersioned from the array that shares data memory with the array. + * \param from The source array. + * \param out The DLManagedTensorVersioned handle.
+ * \return 0 when success, nonzero when failure happens + */ +TVM_FFI_DLL int TVMFFINDArrayToDLPackVersioned(TVMFFIObjectHandle from, + DLManagedTensorVersioned** out); + //------------------------------------------------------------ // Section: Backend noexcept functions for internal use // diff --git a/ffi/include/tvm/ffi/container/ndarray.h b/ffi/include/tvm/ffi/container/ndarray.h new file mode 100644 index 0000000000..642825655f --- /dev/null +++ b/ffi/include/tvm/ffi/container/ndarray.h @@ -0,0 +1,328 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/ffi/container/ndarray.h + * \brief Container to store an NDArray. + */ +#ifndef TVM_FFI_CONTAINER_NDARRAY_H_ +#define TVM_FFI_CONTAINER_NDARRAY_H_ + +#include <tvm/ffi/container/shape.h> +#include <tvm/ffi/dtype.h> +#include <tvm/ffi/error.h> +#include <tvm/ffi/type_traits.h> + +namespace tvm { +namespace ffi { + +/*! + * \brief check if a DLTensor is contiguous. + * \param arr The input DLTensor. + * \return True if strides is nullptr or matches compact row-major order (size-1 dims ignored).
+ */ +inline bool IsContiguous(const DLTensor& arr) { + if (arr.strides == nullptr) return true; + int64_t expected_stride = 1; + for (int32_t i = arr.ndim; i != 0; --i) { + int32_t k = i - 1; + if (arr.shape[k] == 1) { + // Skip stride check if shape[k] is 1, where the dimension is contiguous + // regardless of the value of stride. + // + // For example, PyTorch will normalize stride to 1 if shape is 1 when exporting + // to DLPack. + // More context: https://github.com/pytorch/pytorch/pull/83158 + continue; + } + if (arr.strides[k] != expected_stride) return false; + expected_stride *= arr.shape[k]; + } + return true; +} + +/** + * \brief Check if the data in the DLTensor is aligned to the given alignment. + * \param arr The input DLTensor. + * \param alignment The alignment to check; must be nonzero. + * \return True if the data is aligned to the given alignment, false otherwise. + */ +inline bool IsAligned(const DLTensor& arr, size_t alignment) { + // whether the device uses direct address mapping instead of indirect buffer + bool direct_address = arr.device.device_type <= kDLCUDAHost || + arr.device.device_type == kDLCUDAManaged || + arr.device.device_type == kDLROCM || arr.device.device_type == kDLROCMHost; + if (direct_address) { + return (reinterpret_cast<size_t>(static_cast<char*>(arr.data) + arr.byte_offset) % alignment == + 0); + } else { + return arr.byte_offset % alignment == 0; + } +} + +/*! + * \brief return the total number of bytes needed to store packed data + * + * \param numel the number of elements in the array + * \param dtype the data type of the array + * \return the total number of bytes needed to store packed data + */ +inline size_t GetPackedDataSize(int64_t numel, DLDataType dtype) { + return (numel * dtype.bits * dtype.lanes + 7) / 8; +} + +/*! + * \brief return the size of data the DLTensor holds, in terms of number of bytes + * + * \param arr the input DLTensor + * \return number of bytes of data in the DLTensor.
+ */ +inline size_t GetDataSize(const DLTensor& arr) { + size_t size = 1; + for (int i = 0; i < arr.ndim; ++i) { + size *= static_cast<size_t>(arr.shape[i]); + } + return GetPackedDataSize(size, arr.dtype); +} + +/*! \brief An object representing an NDArray. */ +class NDArrayObj : public Object, public DLTensor { + public: + static constexpr const uint32_t _type_index = TypeIndex::kTVMFFINDArray; + static constexpr const char* _type_key = StaticTypeKey::kTVMFFINDArray; + TVM_FFI_DECLARE_STATIC_OBJECT_INFO(NDArrayObj, Object); + + /*! + * \brief Export the NDArray as a DLPack managed tensor that shares its data memory. + * \return The managed tensor; it holds a reference to this object released by its deleter. + */ + DLManagedTensor* ToDLPack() const { + DLManagedTensor* ret = new DLManagedTensor(); + NDArrayObj* from = const_cast<NDArrayObj*>(this); + ret->dl_tensor = *static_cast<DLTensor*>(from); + ret->manager_ctx = from; + ret->deleter = DLManagedTensorDeleter; + details::ObjectUnsafe::IncRefObjectHandle(from); + return ret; + } + + /*! + * \brief Export the NDArray as a versioned DLPack managed tensor that shares its data memory. + * \return The managed tensor; it holds a reference to this object released by its deleter.
+ */ + DLManagedTensorVersioned* ToDLPackVersioned() const { + DLManagedTensorVersioned* ret = new DLManagedTensorVersioned(); + NDArrayObj* from = const_cast<NDArrayObj*>(this); + ret->version.major = DLPACK_MAJOR_VERSION; + ret->version.minor = DLPACK_MINOR_VERSION; + ret->dl_tensor = *static_cast<DLTensor*>(from); + ret->manager_ctx = from; + ret->deleter = DLManagedTensorVersionedDeleter; + ret->flags = 0; + details::ObjectUnsafe::IncRefObjectHandle(from); + return ret; + } + + protected: + // backs up the shape of the NDArray + Optional<Shape> shape_data_; + + static void DLManagedTensorDeleter(DLManagedTensor* tensor) { + NDArrayObj* obj = static_cast<NDArrayObj*>(tensor->manager_ctx); + details::ObjectUnsafe::DecRefObjectHandle(obj); + delete tensor; + } + + static void DLManagedTensorVersionedDeleter(DLManagedTensorVersioned* tensor) { + NDArrayObj* obj = static_cast<NDArrayObj*>(tensor->manager_ctx); + details::ObjectUnsafe::DecRefObjectHandle(obj); + delete tensor; + } + + friend class NDArray; +}; + +namespace details { +/*! + *\brief Helper class to create an NDArrayObj from an NDAllocator + * + * The underlying allocator needs to be implemented by user. + */ +template <typename TNDAlloc> +class NDArrayObjFromNDAlloc : public NDArrayObj { + public: + template <typename... ExtraArgs> + NDArrayObjFromNDAlloc(TNDAlloc alloc, ffi::Shape shape, DLDataType dtype, DLDevice device, + ExtraArgs&&... extra_args) + : alloc_(alloc) { + this->device = device; + this->ndim = shape.size(); + this->dtype = dtype; + this->shape = const_cast<int64_t*>(shape.data()); + this->strides = nullptr; + this->byte_offset = 0; + this->shape_data_ = std::move(shape); + alloc_.AllocData(static_cast<DLTensor*>(this), std::forward<ExtraArgs>(extra_args)...); + } + + ~NDArrayObjFromNDAlloc() { alloc_.FreeData(static_cast<DLTensor*>(this)); } + + private: + TNDAlloc alloc_; +}; + +/*! 
\brief helper class to import from a DLPack (legacy or versioned) managed tensor */ +template <typename TDLPackManagedTensor> +class NDArrayObjFromDLPack : public NDArrayObj { + public: + NDArrayObjFromDLPack(TDLPackManagedTensor* tensor) : tensor_(tensor) { + *static_cast<DLTensor*>(this) = tensor_->dl_tensor; + // set strides to nullptr if the tensor is contiguous. + if (IsContiguous(tensor->dl_tensor)) { + this->strides = nullptr; + } + } + + ~NDArrayObjFromDLPack() { + // run DLPack deleter if needed. + if (tensor_->deleter != nullptr) { + (*tensor_->deleter)(tensor_); + } + } + + private: + TDLPackManagedTensor* tensor_; +}; +} // namespace details + +/*! + * \brief Managed NDArray. + * The array is backed by reference counted blocks. + * + * \note This class can be subclassed to implement downstream customized + * NDArray types that are backed by the same NDArrayObj storage type. + */ +class NDArray : public ObjectRef { + public: + /*! + * \brief Get the shape of the NDArray. + * \return The shape of the NDArray. + */ + tvm::ffi::Shape shape() const { + NDArrayObj* obj = get_mutable(); + if (!obj->shape_data_.has_value()) { + obj->shape_data_ = tvm::ffi::Shape(obj->shape, obj->shape + obj->ndim); + } + return *(obj->shape_data_); + } + /*! + * \brief Get the data type of the NDArray. + * \return The data type of the NDArray. + */ + DLDataType dtype() const { return (*this)->dtype; } + /*! + * \brief Check if the NDArray is contiguous. + * \return True if the NDArray is contiguous, false otherwise. + */ + bool IsContiguous() const { return tvm::ffi::IsContiguous(*get()); } + /*! + * \brief Create an NDArray from an NDAllocator. + * \param alloc The NDAllocator. + * \param shape The shape of the NDArray. + * \param dtype The data type of the NDArray. + * \param device The device of the NDArray. + * \return The created NDArray. + * \tparam TNDAlloc The type of the NDAllocator, implements AllocData and FreeData. + * \tparam ExtraArgs Extra arguments to be passed to AllocData.
+ */ + template <typename TNDAlloc, typename... ExtraArgs> + static NDArray FromNDAlloc(TNDAlloc alloc, ffi::Shape shape, DLDataType dtype, DLDevice device, + ExtraArgs&&... extra_args) { + return NDArray(make_object<details::NDArrayObjFromNDAlloc<TNDAlloc>>( + alloc, shape, dtype, device, std::forward<ExtraArgs>(extra_args)...)); + } + + /*! + * \brief Create an NDArray from a DLPack managed tensor, pre v1.0 API. + * \param tensor The input DLPack managed tensor; ownership is taken only on success. + * \param require_alignment The minimum alignment required of the data + byte_offset (0 = no check). + * \param require_contiguous Boolean flag indicating if we need to check for contiguity. + * \note This function will not run any checks on flags. + * \return The created NDArray. + */ + static NDArray FromDLPack(DLManagedTensor* tensor, size_t require_alignment = 0, + bool require_contiguous = false) { + if (require_alignment != 0 && !ffi::IsAligned(tensor->dl_tensor, require_alignment)) { + TVM_FFI_THROW(RuntimeError) << "FromDLPack: Data is not aligned to " << require_alignment + << " bytes."; + } + if (require_contiguous && !ffi::IsContiguous(tensor->dl_tensor)) { + TVM_FFI_THROW(RuntimeError) << "FromDLPack: Tensor is not contiguous."; + } + return NDArray(make_object<details::NDArrayObjFromDLPack<DLManagedTensor>>(tensor)); + } + + /*! + * \brief Create an NDArray from a versioned DLPack managed tensor, post v1.0 API. + * \param tensor The input DLPack managed tensor; ownership is taken only on success. + * \param require_alignment The minimum alignment required of the data + byte_offset (0 = no check). + * \param require_contiguous Boolean flag indicating if we need to check for contiguity. + * \return The created NDArray.
+ */ + static NDArray FromDLPackVersioned(DLManagedTensorVersioned* tensor, size_t require_alignment = 0, + bool require_contiguous = false) { + if (require_alignment != 0 && !ffi::IsAligned(tensor->dl_tensor, require_alignment)) { + TVM_FFI_THROW(RuntimeError) << "FromDLPack: Data is not aligned to " << require_alignment + << " bytes."; + } + if (require_contiguous && !ffi::IsContiguous(tensor->dl_tensor)) { + TVM_FFI_THROW(RuntimeError) << "FromDLPack: Tensor is not contiguous."; + } + if (tensor->flags & DLPACK_FLAG_BITMASK_IS_SUBBYTE_TYPE_PADDED) { + TVM_FFI_THROW(RuntimeError) << "Subbyte type padded is not yet supported"; + } + return NDArray(make_object<details::NDArrayObjFromDLPack<DLManagedTensorVersioned>>(tensor)); + } + + /*! + * \brief Convert the NDArray to a DLPack managed tensor. + * \return The converted DLPack managed tensor. + */ + DLManagedTensor* ToDLPack() const { return get_mutable()->ToDLPack(); } + + /*! + * \brief Convert the NDArray to a DLPack managed tensor. + * \return The converted DLPack managed tensor. + */ + DLManagedTensorVersioned* ToDLPackVersioned() const { return get_mutable()->ToDLPackVersioned(); } + + TVM_FFI_DEFINE_OBJECT_REF_METHODS(NDArray, ObjectRef, NDArrayObj); + + protected: + /*! + * \brief Get mutable internal container pointer. + * \return a mutable container pointer. 
+ */ + NDArrayObj* get_mutable() const { return const_cast<NDArrayObj*>(get()); } +}; + +} // namespace ffi +} // namespace tvm + +#endif // TVM_FFI_CONTAINER_NDARRAY_H_ diff --git a/ffi/include/tvm/ffi/container/shape.h b/ffi/include/tvm/ffi/container/shape.h index 5fbb1b5ccb..25db67c2ea 100644 --- a/ffi/include/tvm/ffi/container/shape.h +++ b/ffi/include/tvm/ffi/container/shape.h @@ -28,10 +28,10 @@ #include <tvm/ffi/error.h> #include <tvm/ffi/type_traits.h> -#include <vector> -#include <ostream> #include <algorithm> +#include <ostream> #include <utility> +#include <vector> namespace tvm { namespace ffi { diff --git a/ffi/include/tvm/ffi/object.h b/ffi/include/tvm/ffi/object.h index e3cabaef3e..9056b90da6 100644 --- a/ffi/include/tvm/ffi/object.h +++ b/ffi/include/tvm/ffi/object.h @@ -52,9 +52,10 @@ struct StaticTypeKey { static constexpr const char* kTVMFFIRawStr = "const char*"; static constexpr const char* kTVMFFIByteArrayPtr = "TVMFFIByteArray*"; static constexpr const char* kTVMFFIObjectRValueRef = "ObjectRValueRef"; - static constexpr const char* kTVMFFIShape = "object.Shape"; static constexpr const char* kTVMFFIBytes = "object.Bytes"; static constexpr const char* kTVMFFIStr = "object.String"; + static constexpr const char* kTVMFFIShape = "object.Shape"; + static constexpr const char* kTVMFFINDArray = "object.NDArray"; }; /*! diff --git a/ffi/include/tvm/ffi/string.h b/ffi/include/tvm/ffi/string.h index da9142a292..2fbbb551a1 100644 --- a/ffi/include/tvm/ffi/string.h +++ b/ffi/include/tvm/ffi/string.h @@ -48,8 +48,7 @@ namespace tvm { namespace ffi { /*! \brief Base class for bytes and string. */ -class BytesObjBase : public Object, public TVMFFIByteArray { -}; +class BytesObjBase : public Object, public TVMFFIByteArray {}; /*! * \brief An object representing bytes. 
diff --git a/ffi/src/ffi/ndarray.cc b/ffi/src/ffi/ndarray.cc new file mode 100644 index 0000000000..93656a6919 --- /dev/null +++ b/ffi/src/ffi/ndarray.cc @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/* + * \file src/ffi/ndarray.cc + * \brief NDArray C API implementation; signatures must match tvm/ffi/c_api.h exactly. + */ +#include <tvm/ffi/c_api.h> +#include <tvm/ffi/container/ndarray.h> +#include <tvm/ffi/function.h> + +int TVMFFINDArrayFromDLPack(DLManagedTensor* from, int32_t require_alignment, + int32_t require_contiguous, TVMFFIObjectHandle* out) { + TVM_FFI_SAFE_CALL_BEGIN(); + tvm::ffi::NDArray nd = tvm::ffi::NDArray::FromDLPack( + from, static_cast<size_t>(require_alignment), require_contiguous != 0); + *out = tvm::ffi::details::ObjectUnsafe::MoveObjectRefToTVMFFIObjectPtr(std::move(nd)); + TVM_FFI_SAFE_CALL_END(); +} + +int TVMFFINDArrayFromDLPackVersioned(DLManagedTensorVersioned* from, int32_t require_alignment, + int32_t require_contiguous, TVMFFIObjectHandle* out) { + TVM_FFI_SAFE_CALL_BEGIN(); + tvm::ffi::NDArray nd = tvm::ffi::NDArray::FromDLPackVersioned( + from, static_cast<size_t>(require_alignment), require_contiguous != 0); + *out = tvm::ffi::details::ObjectUnsafe::MoveObjectRefToTVMFFIObjectPtr(std::move(nd)); + TVM_FFI_SAFE_CALL_END(); +}
+ +int TVMFFINDArrayToDLPack(TVMFFIObjectHandle from, DLManagedTensor** out) { + TVM_FFI_SAFE_CALL_BEGIN(); + *out = tvm::ffi::details::ObjectUnsafe::RawObjectPtrFromUnowned<tvm::ffi::NDArrayObj>( + static_cast<TVMFFIObject*>(from)) + ->ToDLPack(); + TVM_FFI_SAFE_CALL_END(); +} + +int TVMFFINDArrayToDLPackVersioned(TVMFFIObjectHandle from, DLManagedTensorVersioned** out) { + TVM_FFI_SAFE_CALL_BEGIN(); + *out = tvm::ffi::details::ObjectUnsafe::RawObjectPtrFromUnowned<tvm::ffi::NDArrayObj>( + static_cast<TVMFFIObject*>(from)) + ->ToDLPackVersioned(); + TVM_FFI_SAFE_CALL_END(); +} diff --git a/ffi/tests/cpp/test_ndarray.cc b/ffi/tests/cpp/test_ndarray.cc new file mode 100644 index 0000000000..0284ceb818 --- /dev/null +++ b/ffi/tests/cpp/test_ndarray.cc @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +#include <gtest/gtest.h> +#include <tvm/ffi/container/ndarray.h> + +namespace { + +using namespace tvm::ffi; + +struct CPUNDAlloc { + void AllocData(DLTensor* tensor) { tensor->data = malloc(GetDataSize(*tensor)); } + void FreeData(DLTensor* tensor) { free(tensor->data); } +}; + +inline NDArray Empty(Shape shape, DLDataType dtype, DLDevice device) { + return NDArray::FromNDAlloc(CPUNDAlloc(), shape, dtype, device); +} + +TEST(NDArray, Basic) { + NDArray nd = Empty(Shape({1, 2, 3}), DLDataType({kDLFloat, 32, 1}), DLDevice({kDLCPU, 0})); + Shape shape = nd.shape(); + EXPECT_EQ(shape.size(), 3); + EXPECT_EQ(shape[0], 1); + EXPECT_EQ(shape[1], 2); + EXPECT_EQ(shape[2], 3); + EXPECT_EQ(nd.dtype(), DLDataType({kDLFloat, 32, 1})); + for (int64_t i = 0; i < shape.Product(); ++i) { + reinterpret_cast<float*>(nd->data)[i] = i; + } + + Any any0 = nd; + NDArray nd2 = any0.as<NDArray>().value(); + EXPECT_EQ(nd2.shape(), shape); + EXPECT_EQ(nd2.dtype(), DLDataType({kDLFloat, 32, 1})); + for (int64_t i = 0; i < shape.Product(); ++i) { + EXPECT_EQ(reinterpret_cast<float*>(nd2->data)[i], i); + } + + EXPECT_EQ(nd.IsContiguous(), true); + EXPECT_EQ(nd2.use_count(), 3); +} + +TEST(NDArray, DLPack) { + NDArray nd = Empty({1, 2, 3}, DLDataType({kDLInt, 16, 1}), DLDevice({kDLCPU, 0})); + DLManagedTensor* dlpack = nd.ToDLPack(); + EXPECT_EQ(dlpack->dl_tensor.ndim, 3); + EXPECT_EQ(dlpack->dl_tensor.shape[0], 1); + EXPECT_EQ(dlpack->dl_tensor.shape[1], 2); + EXPECT_EQ(dlpack->dl_tensor.shape[2], 3); + EXPECT_EQ(dlpack->dl_tensor.dtype.code, kDLInt); + EXPECT_EQ(dlpack->dl_tensor.dtype.bits, 16); + EXPECT_EQ(dlpack->dl_tensor.dtype.lanes, 1); + EXPECT_EQ(dlpack->dl_tensor.device.device_type, kDLCPU); + EXPECT_EQ(dlpack->dl_tensor.device.device_id, 0); + EXPECT_EQ(dlpack->dl_tensor.byte_offset, 0); + EXPECT_EQ(dlpack->dl_tensor.strides, nullptr); + EXPECT_EQ(nd.use_count(), 2); + { + NDArray nd2 = NDArray::FromDLPack(dlpack); + EXPECT_EQ(nd2.use_count(), 1); + EXPECT_EQ(nd2->data, 
nd->data); + EXPECT_EQ(nd.use_count(), 2); + EXPECT_EQ(nd2.use_count(), 1); + } + EXPECT_EQ(nd.use_count(), 1); +} + +TEST(NDArray, DLPackVersioned) { + DLDataType dtype = DLDataType({kDLFloat4_e2m1fn, 4, 1}); + EXPECT_EQ(GetPackedDataSize(2, dtype), 2 * 4 / 8); + NDArray nd = Empty({2}, dtype, DLDevice({kDLCPU, 0})); + DLManagedTensorVersioned* dlpack = nd.ToDLPackVersioned(); + EXPECT_EQ(dlpack->version.major, DLPACK_MAJOR_VERSION); + EXPECT_EQ(dlpack->version.minor, DLPACK_MINOR_VERSION); + EXPECT_EQ(dlpack->dl_tensor.ndim, 1); + EXPECT_EQ(dlpack->dl_tensor.shape[0], 2); + EXPECT_EQ(dlpack->dl_tensor.dtype.code, kDLFloat4_e2m1fn); + EXPECT_EQ(dlpack->dl_tensor.dtype.bits, 4); + EXPECT_EQ(dlpack->dl_tensor.dtype.lanes, 1); + EXPECT_EQ(dlpack->dl_tensor.device.device_type, kDLCPU); + EXPECT_EQ(dlpack->dl_tensor.device.device_id, 0); + EXPECT_EQ(dlpack->dl_tensor.byte_offset, 0); + EXPECT_EQ(dlpack->dl_tensor.strides, nullptr); + + EXPECT_EQ(nd.use_count(), 2); + { + NDArray nd2 = NDArray::FromDLPackVersioned(dlpack); + EXPECT_EQ(nd2.use_count(), 1); + EXPECT_EQ(nd2->data, nd->data); + EXPECT_EQ(nd.use_count(), 2); + EXPECT_EQ(nd2.use_count(), 1); + } + EXPECT_EQ(nd.use_count(), 1); +} +} // namespace
