This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch refactor-s2
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit 8bac6dd83d3e43aef0cc979599d281df497a71be
Author: tqchen <[email protected]>
AuthorDate: Tue Apr 22 13:29:00 2025 -0400

    [FFI] Introduce NDArray with DLPack support
---
 ffi/CMakeLists.txt                      |   1 +
 ffi/include/tvm/ffi/any.h               |   3 +-
 ffi/include/tvm/ffi/c_api.h             |  44 +++++
 ffi/include/tvm/ffi/container/ndarray.h | 328 ++++++++++++++++++++++++++++++++
 ffi/include/tvm/ffi/container/shape.h   |   4 +-
 ffi/include/tvm/ffi/object.h            |   3 +-
 ffi/include/tvm/ffi/string.h            |   3 +-
 ffi/src/ffi/ndarray.cc                  |  59 ++++++
 ffi/tests/cpp/test_ndarray.cc           | 111 +++++++++++
 9 files changed, 549 insertions(+), 7 deletions(-)

diff --git a/ffi/CMakeLists.txt b/ffi/CMakeLists.txt
index 7d07c3bd0b..9fed170014 100644
--- a/ffi/CMakeLists.txt
+++ b/ffi/CMakeLists.txt
@@ -61,6 +61,7 @@ add_library(tvm_ffi_objs OBJECT
   "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/traceback.cc"
   "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/object.cc"
   "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/function.cc"
+  "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/ndarray.cc"
 )
 set_target_properties(
   tvm_ffi_objs PROPERTIES
diff --git a/ffi/include/tvm/ffi/any.h b/ffi/include/tvm/ffi/any.h
index 518342b3ea..049c07066c 100644
--- a/ffi/include/tvm/ffi/any.h
+++ b/ffi/include/tvm/ffi/any.h
@@ -469,8 +469,7 @@ struct AnyEqual {
           details::AnyUnsafe::CopyFromAnyStorageAfterCheck<const 
BytesObjBase*>(lhs);
       const BytesObjBase* rhs_str =
           details::AnyUnsafe::CopyFromAnyStorageAfterCheck<const 
BytesObjBase*>(rhs);
-      return Bytes::memncmp(lhs_str->data, rhs_str->data, lhs_str->size,
-                            rhs_str->size) == 0;
+      return Bytes::memncmp(lhs_str->data, rhs_str->data, lhs_str->size, 
rhs_str->size) == 0;
     }
     return false;
   }
diff --git a/ffi/include/tvm/ffi/c_api.h b/ffi/include/tvm/ffi/c_api.h
index f0f609e4d1..6ac704070a 100644
--- a/ffi/include/tvm/ffi/c_api.h
+++ b/ffi/include/tvm/ffi/c_api.h
@@ -437,6 +437,50 @@ TVM_FFI_DLL int TVMFFIRegisterTypeField(int32_t 
type_index, const TVMFFIFieldInf
  */
 TVM_FFI_DLL int TVMFFIRegisterTypeMethod(int32_t type_index, const 
TVMFFIMethodInfo* info);
 
+//------------------------------------------------------------
+// Section: DLPack support APIs
+//------------------------------------------------------------
+/*!
+ * \brief Produce a managed NDArray from a DLPack tensor.
+ * \param from The source DLPack tensor.
+ * \param require_alignment The minimum alignment required of the data + 
byte_offset.
+ * \param require_contiguous Boolean flag indicating if we need to check for 
contiguity.
+ * \param out The output NDArray handle.
+ * \return 0 when success, nonzero when failure happens
+ */
+TVM_FFI_DLL int TVMFFINDArrayFromDLPack(DLManagedTensor* from, int32_t 
require_alignment,
+                                        int32_t require_contiguous, 
TVMFFIObjectHandle* out);
+
+/*!
+ * \brief Produce a DLManagedTensor from the array that shares data memory with 
the array.
+ * \param from The source array.
+ * \param out The DLManagedTensor handle.
+ * \return 0 when success, nonzero when failure happens
+ */
+TVM_FFI_DLL int TVMFFINDArrayToDLPack(TVMFFIObjectHandle from, 
DLManagedTensor** out);
+
+/*!
+ * \brief Produce a managed NDArray from a DLPack tensor.
+ * \param from The source DLPack tensor.
+ * \param require_alignment The minimum alignment required of the data + 
byte_offset.
+ * \param require_contiguous Boolean flag indicating if we need to check for 
contiguity.
+ * \param out The output NDArray handle.
+ * \return 0 when success, nonzero when failure happens
+ */
+TVM_FFI_DLL int TVMFFINDArrayFromDLPackVersioned(DLManagedTensorVersioned* 
from,
+                                                 int32_t require_alignment,
+                                                 int32_t require_contiguous,
+                                                 TVMFFIObjectHandle* out);
+
+/*!
+ * \brief Produce a DLManagedTensorVersioned from the array that shares data memory with 
the array.
+ * \param from The source array.
+ * \param out The DLManagedTensor handle.
+ * \return 0 when success, nonzero when failure happens
+ */
+TVM_FFI_DLL int TVMFFINDArrayToDLPackVersioned(TVMFFIObjectHandle from,
+                                               DLManagedTensorVersioned** out);
+
 //------------------------------------------------------------
 // Section: Backend noexcept functions for internal use
 //
diff --git a/ffi/include/tvm/ffi/container/ndarray.h 
b/ffi/include/tvm/ffi/container/ndarray.h
new file mode 100644
index 0000000000..642825655f
--- /dev/null
+++ b/ffi/include/tvm/ffi/container/ndarray.h
@@ -0,0 +1,328 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/ffi/ndarray.h
+ * \brief Container to store an NDArray.
+ */
+#ifndef TVM_FFI_CONTAINER_NDARRAY_H_
+#define TVM_FFI_CONTAINER_NDARRAY_H_
+
+#include <tvm/ffi/container/shape.h>
+#include <tvm/ffi/dtype.h>
+#include <tvm/ffi/error.h>
+#include <tvm/ffi/type_traits.h>
+
+namespace tvm {
+namespace ffi {
+
+/*!
+ * \brief check if a DLTensor is contiguous.
+ * \param arr The input DLTensor.
+ * \return The check result.
+ */
+inline bool IsContiguous(const DLTensor& arr) {
+  if (arr.strides == nullptr) return true;
+  int64_t expected_stride = 1;
+  for (int32_t i = arr.ndim; i != 0; --i) {
+    int32_t k = i - 1;
+    if (arr.shape[k] == 1) {
+      // Skip stride check if shape[k] is 1, where the dimension is contiguous
+      // regardless of the value of stride.
+      //
+      // For example, PyTorch will normalize stride to 1 if shape is 1 when 
exporting
+      // to DLPack.
+      // More context: https://github.com/pytorch/pytorch/pull/83158
+      continue;
+    }
+    if (arr.strides[k] != expected_stride) return false;
+    expected_stride *= arr.shape[k];
+  }
+  return true;
+}
+
+/**
+ * \brief Check if the data in the DLTensor is aligned to the given alignment.
+ * \param arr The input DLTensor.
+ * \param alignment The alignment to check.
+ * \return True if the data is aligned to the given alignment, false otherwise.
+ */
+inline bool IsAligned(const DLTensor& arr, size_t alignment) {
+  // whether the device uses direct address mapping instead of indirect buffer
+  bool direct_address = arr.device.device_type <= kDLCUDAHost ||
+                        arr.device.device_type == kDLCUDAManaged ||
+                        arr.device.device_type == kDLROCM || 
arr.device.device_type == kDLROCMHost;
+  if (direct_address) {
+    return (reinterpret_cast<size_t>(static_cast<char*>(arr.data) + 
arr.byte_offset) % alignment ==
+            0);
+  } else {
+    return arr.byte_offset % alignment == 0;
+  }
+}
+
+/*!
+ * \brief return the total number bytes needs to store packed data
+ *
+ * \param numel the number of elements in the array
+ * \param dtype the data type of the array
+ * \return the total number bytes needs to store packed data
+ */
+inline size_t GetPackedDataSize(int64_t numel, DLDataType dtype) {
+  return (numel * dtype.bits * dtype.lanes + 7) / 8;
+}
+
+/*!
+ * \brief return the size of data the DLTensor hold, in term of number of bytes
+ *
+ *  \param arr the input DLTensor
+ *  \return number of bytes of data in the DLTensor.
+ */
+inline size_t GetDataSize(const DLTensor& arr) {
+  size_t size = 1;
+  for (int i = 0; i < arr.ndim; ++i) {
+    size *= static_cast<size_t>(arr.shape[i]);
+  }
+  return GetPackedDataSize(size, arr.dtype);
+}
+
+/*! \brief An object representing an NDArray. */
+class NDArrayObj : public Object, public DLTensor {
+ public:
+  static constexpr const uint32_t _type_index = TypeIndex::kTVMFFINDArray;
+  static constexpr const char* _type_key = StaticTypeKey::kTVMFFINDArray;
+  TVM_FFI_DECLARE_STATIC_OBJECT_INFO(NDArrayObj, Object);
+
+  /*!
+   * \brief Move NDArray to a DLPack managed tensor.
+   * \return The converted DLPack managed tensor.
+   */
+  DLManagedTensor* ToDLPack() const {
+    DLManagedTensor* ret = new DLManagedTensor();
+    NDArrayObj* from = const_cast<NDArrayObj*>(this);
+    ret->dl_tensor = *static_cast<DLTensor*>(from);
+    ret->manager_ctx = from;
+    ret->deleter = DLManagedTensorDeleter;
+    details::ObjectUnsafe::IncRefObjectHandle(from);
+    return ret;
+  }
+
+  /*!
+   * \brief Move  NDArray to a DLPack managed tensor.
+   * \return The converted DLPack managed tensor.
+   */
+  DLManagedTensorVersioned* ToDLPackVersioned() const {
+    DLManagedTensorVersioned* ret = new DLManagedTensorVersioned();
+    NDArrayObj* from = const_cast<NDArrayObj*>(this);
+    ret->version.major = DLPACK_MAJOR_VERSION;
+    ret->version.minor = DLPACK_MINOR_VERSION;
+    ret->dl_tensor = *static_cast<DLTensor*>(from);
+    ret->manager_ctx = from;
+    ret->deleter = DLManagedTensorVersionedDeleter;
+    ret->flags = 0;
+    details::ObjectUnsafe::IncRefObjectHandle(from);
+    return ret;
+  }
+
+ protected:
+  // backs up the shape of the NDArray
+  Optional<Shape> shape_data_;
+
+  static void DLManagedTensorDeleter(DLManagedTensor* tensor) {
+    NDArrayObj* obj = static_cast<NDArrayObj*>(tensor->manager_ctx);
+    details::ObjectUnsafe::DecRefObjectHandle(obj);
+    delete tensor;
+  }
+
+  static void DLManagedTensorVersionedDeleter(DLManagedTensorVersioned* 
tensor) {
+    NDArrayObj* obj = static_cast<NDArrayObj*>(tensor->manager_ctx);
+    details::ObjectUnsafe::DecRefObjectHandle(obj);
+    delete tensor;
+  }
+
+  friend class NDArray;
+};
+
+namespace details {
+/*!
+ * \brief Helper class to create an NDArrayObj from an NDAllocator
+ *
+ * The underlying allocator needs to be implemented by user.
+ */
+template <typename TNDAlloc>
+class NDArrayObjFromNDAlloc : public NDArrayObj {
+ public:
+  template <typename... ExtraArgs>
+  NDArrayObjFromNDAlloc(TNDAlloc alloc, ffi::Shape shape, DLDataType dtype, 
DLDevice device,
+                        ExtraArgs&&... extra_args)
+      : alloc_(alloc) {
+    this->device = device;
+    this->ndim = shape.size();
+    this->dtype = dtype;
+    this->shape = const_cast<int64_t*>(shape.data());
+    this->strides = nullptr;
+    this->byte_offset = 0;
+    this->shape_data_ = std::move(shape);
+    alloc_.AllocData(static_cast<DLTensor*>(this), 
std::forward<ExtraArgs>(extra_args)...);
+  }
+
+  ~NDArrayObjFromNDAlloc() { alloc_.FreeData(static_cast<DLTensor*>(this)); }
+
+ private:
+  TNDAlloc alloc_;
+};
+
+/*! \brief helper class to import from DLPack legacy DLManagedTensor */
+template <typename TDLPackManagedTensor>
+class NDArrayObjFromDLPack : public NDArrayObj {
+ public:
+  NDArrayObjFromDLPack(TDLPackManagedTensor* tensor) : tensor_(tensor) {
+    *static_cast<DLTensor*>(this) = tensor_->dl_tensor;
+    // set strides to nullptr if the tensor is contiguous.
+    if (IsContiguous(tensor->dl_tensor)) {
+      this->strides = nullptr;
+    }
+  }
+
+  ~NDArrayObjFromDLPack() {
+    // run DLPack deleter if needed.
+    if (tensor_->deleter != nullptr) {
+      (*tensor_->deleter)(tensor_);
+    }
+  }
+
+ private:
+  TDLPackManagedTensor* tensor_;
+};
+}  // namespace details
+
+/*!
+ * \brief Managed NDArray.
+ *  The array is backed by reference counted blocks.
+ *
+ * \note This class can be subclassed to implement downstream customized
+ *       NDArray types that are backed by the same NDArrayObj storage type.
+ */
+class NDArray : public ObjectRef {
+ public:
+  /*!
+   * \brief Get the shape of the NDArray.
+   * \return The shape of the NDArray.
+   */
+  tvm::ffi::Shape shape() const {
+    NDArrayObj* obj = get_mutable();
+    if (!obj->shape_data_.has_value()) {
+      obj->shape_data_ = tvm::ffi::Shape(obj->shape, obj->shape + obj->ndim);
+    }
+    return *(obj->shape_data_);
+  }
+  /*!
+   * \brief Get the data type of the NDArray.
+   * \return The data type of the NDArray.
+   */
+  DLDataType dtype() const { return (*this)->dtype; }
+  /*!
+   * \brief Check if the NDArray is contiguous.
+   * \return True if the NDArray is contiguous, false otherwise.
+   */
+  bool IsContiguous() const { return tvm::ffi::IsContiguous(*get()); }
+  /*!
+   * \brief Create a NDArray from a NDAllocator.
+   * \param alloc The NDAllocator.
+   * \param shape The shape of the NDArray.
+   * \param dtype The data type of the NDArray.
+   * \param device The device of the NDArray.
+   * \return The created NDArray.
+   * \tparam TNDAlloc The type of the NDAllocator, implements Alloc and Free.
+   * \tparam ExtraArgs Extra arguments to be passed to Alloc.
+   */
+  template <typename TNDAlloc, typename... ExtraArgs>
+  static NDArray FromNDAlloc(TNDAlloc alloc, ffi::Shape shape, DLDataType 
dtype, DLDevice device,
+                             ExtraArgs&&... extra_args) {
+    return NDArray(make_object<details::NDArrayObjFromNDAlloc<TNDAlloc>>(
+        alloc, shape, dtype, device, std::forward<ExtraArgs>(extra_args)...));
+  }
+
+  /*!
+   * \brief Create a NDArray from a DLPack managed tensor, pre v1.0 API.
+   * \param tensor The input DLPack managed tensor.
+   * \param require_alignment The minimum alignment required of the data + 
byte_offset.
+   * \param require_contiguous Boolean flag indicating if we need to check for 
contiguity.
+   * \note This function will not run any checks on flags.
+   * \return The created NDArray.
+   */
+  static NDArray FromDLPack(DLManagedTensor* tensor, size_t require_alignment 
= 0,
+                            bool require_contiguous = false) {
+    if (require_alignment != 0 && !ffi::IsAligned(tensor->dl_tensor, 
require_alignment)) {
+      TVM_FFI_THROW(RuntimeError) << "FromDLPack: Data is not aligned to " << 
require_alignment
+                                  << " bytes.";
+    }
+    if (require_contiguous && !ffi::IsContiguous(tensor->dl_tensor)) {
+      TVM_FFI_THROW(RuntimeError) << "FromDLPack: Tensor is not contiguous.";
+    }
+    return 
NDArray(make_object<details::NDArrayObjFromDLPack<DLManagedTensor>>(tensor));
+  }
+
+  /*!
+   * \brief Create a NDArray from a DLPack managed tensor, post v1.0 API.
+   * \param tensor The input DLPack managed tensor.
+   * \param require_alignment The minimum alignment required of the data + 
byte_offset.
+   * \param require_contiguous Boolean flag indicating if we need to check for 
contiguity.
+   * \return The created NDArray.
+   */
+  static NDArray FromDLPackVersioned(DLManagedTensorVersioned* tensor, size_t 
require_alignment = 0,
+                                     bool require_contiguous = false) {
+    if (require_alignment != 0 && !ffi::IsAligned(tensor->dl_tensor, 
require_alignment)) {
+      TVM_FFI_THROW(RuntimeError) << "FromDLPack: Data is not aligned to " << 
require_alignment
+                                  << " bytes.";
+    }
+    if (require_contiguous && !ffi::IsContiguous(tensor->dl_tensor)) {
+      TVM_FFI_THROW(RuntimeError) << "FromDLPack: Tensor is not contiguous.";
+    }
+    if (tensor->flags & DLPACK_FLAG_BITMASK_IS_SUBBYTE_TYPE_PADDED) {
+      TVM_FFI_THROW(RuntimeError) << "Subbyte type padded is not yet 
supported";
+    }
+    return 
NDArray(make_object<details::NDArrayObjFromDLPack<DLManagedTensorVersioned>>(tensor));
+  }
+
+  /*!
+   * \brief Convert the NDArray to a DLPack managed tensor.
+   * \return The converted DLPack managed tensor.
+   */
+  DLManagedTensor* ToDLPack() const { return get_mutable()->ToDLPack(); }
+
+  /*!
+   * \brief Convert the NDArray to a DLPack managed tensor.
+   * \return The converted DLPack managed tensor.
+   */
+  DLManagedTensorVersioned* ToDLPackVersioned() const { return 
get_mutable()->ToDLPackVersioned(); }
+
+  TVM_FFI_DEFINE_OBJECT_REF_METHODS(NDArray, ObjectRef, NDArrayObj);
+
+ protected:
+  /*!
+   * \brief Get mutable internal container pointer.
+   * \return a mutable container pointer.
+   */
+  NDArrayObj* get_mutable() const { return const_cast<NDArrayObj*>(get()); }
+};
+
+}  // namespace ffi
+}  // namespace tvm
+
+#endif  // TVM_FFI_CONTAINER_NDARRAY_H_
diff --git a/ffi/include/tvm/ffi/container/shape.h 
b/ffi/include/tvm/ffi/container/shape.h
index 5fbb1b5ccb..25db67c2ea 100644
--- a/ffi/include/tvm/ffi/container/shape.h
+++ b/ffi/include/tvm/ffi/container/shape.h
@@ -28,10 +28,10 @@
 #include <tvm/ffi/error.h>
 #include <tvm/ffi/type_traits.h>
 
-#include <vector>
-#include <ostream>
 #include <algorithm>
+#include <ostream>
 #include <utility>
+#include <vector>
 
 namespace tvm {
 namespace ffi {
diff --git a/ffi/include/tvm/ffi/object.h b/ffi/include/tvm/ffi/object.h
index e3cabaef3e..9056b90da6 100644
--- a/ffi/include/tvm/ffi/object.h
+++ b/ffi/include/tvm/ffi/object.h
@@ -52,9 +52,10 @@ struct StaticTypeKey {
   static constexpr const char* kTVMFFIRawStr = "const char*";
   static constexpr const char* kTVMFFIByteArrayPtr = "TVMFFIByteArray*";
   static constexpr const char* kTVMFFIObjectRValueRef = "ObjectRValueRef";
-  static constexpr const char* kTVMFFIShape = "object.Shape";
   static constexpr const char* kTVMFFIBytes = "object.Bytes";
   static constexpr const char* kTVMFFIStr = "object.String";
+  static constexpr const char* kTVMFFIShape = "object.Shape";
+  static constexpr const char* kTVMFFINDArray = "object.NDArray";
 };
 
 /*!
diff --git a/ffi/include/tvm/ffi/string.h b/ffi/include/tvm/ffi/string.h
index da9142a292..2fbbb551a1 100644
--- a/ffi/include/tvm/ffi/string.h
+++ b/ffi/include/tvm/ffi/string.h
@@ -48,8 +48,7 @@ namespace tvm {
 namespace ffi {
 
 /*! \brief Base class for bytes and string. */
-class BytesObjBase : public Object, public TVMFFIByteArray {
-};
+class BytesObjBase : public Object, public TVMFFIByteArray {};
 
 /*!
  * \brief An object representing bytes.
diff --git a/ffi/src/ffi/ndarray.cc b/ffi/src/ffi/ndarray.cc
new file mode 100644
index 0000000000..93656a6919
--- /dev/null
+++ b/ffi/src/ffi/ndarray.cc
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*
+ * \file src/ffi/ndarray.cc
+ * \brief NDArray C API implementation
+ */
+#include <tvm/ffi/c_api.h>
+#include <tvm/ffi/container/ndarray.h>
+#include <tvm/ffi/function.h>
+
+int TVMFFINDArrayFromDLPack(DLManagedTensor* from, int32_t min_alignment,
+                            int32_t require_contiguous, TVMFFIObjectHandle* 
out) {
+  TVM_FFI_SAFE_CALL_BEGIN();
+  tvm::ffi::NDArray nd =
+      tvm::ffi::NDArray::FromDLPack(from, static_cast<size_t>(min_alignment), 
require_contiguous);
+  *out = 
tvm::ffi::details::ObjectUnsafe::MoveObjectRefToTVMFFIObjectPtr(std::move(nd));
+  TVM_FFI_SAFE_CALL_END();
+}
+
+int TVMFFINDArrayFromDLPackVersioned(DLManagedTensorVersioned* from, int32_t 
min_alignment,
+                                     int32_t require_contiguous, 
TVMFFIObjectHandle* out) {
+  TVM_FFI_SAFE_CALL_BEGIN();
+  tvm::ffi::NDArray nd = tvm::ffi::NDArray::FromDLPackVersioned(
+      from, static_cast<size_t>(min_alignment), require_contiguous);
+  *out = 
tvm::ffi::details::ObjectUnsafe::MoveObjectRefToTVMFFIObjectPtr(std::move(nd));
+  TVM_FFI_SAFE_CALL_END();
+}
+
+int TVMFFINDArrayToDLPack(TVMFFIObjectHandle from, DLManagedTensor** out) {
+  TVM_FFI_SAFE_CALL_BEGIN();
+  *out = 
tvm::ffi::details::ObjectUnsafe::RawObjectPtrFromUnowned<tvm::ffi::NDArrayObj>(
+             static_cast<TVMFFIObject*>(from))
+             ->ToDLPack();
+  TVM_FFI_SAFE_CALL_END();
+}
+
+int TVMFFINDArrayToDLPackVersioned(TVMFFIObjectHandle from, 
DLManagedTensorVersioned** out) {
+  TVM_FFI_SAFE_CALL_BEGIN();
+  *out = 
tvm::ffi::details::ObjectUnsafe::RawObjectPtrFromUnowned<tvm::ffi::NDArrayObj>(
+             static_cast<TVMFFIObject*>(from))
+             ->ToDLPackVersioned();
+  TVM_FFI_SAFE_CALL_END();
+}
diff --git a/ffi/tests/cpp/test_ndarray.cc b/ffi/tests/cpp/test_ndarray.cc
new file mode 100644
index 0000000000..0284ceb818
--- /dev/null
+++ b/ffi/tests/cpp/test_ndarray.cc
@@ -0,0 +1,111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <gtest/gtest.h>
+#include <tvm/ffi/container/ndarray.h>
+
+namespace {
+
+using namespace tvm::ffi;
+
+struct CPUNDAlloc {
+  void AllocData(DLTensor* tensor) { tensor->data = 
malloc(GetDataSize(*tensor)); }
+  void FreeData(DLTensor* tensor) { free(tensor->data); }
+};
+
+inline NDArray Empty(Shape shape, DLDataType dtype, DLDevice device) {
+  return NDArray::FromNDAlloc(CPUNDAlloc(), shape, dtype, device);
+}
+
+TEST(NDArray, Basic) {
+  NDArray nd = Empty(Shape({1, 2, 3}), DLDataType({kDLFloat, 32, 1}), 
DLDevice({kDLCPU, 0}));
+  Shape shape = nd.shape();
+  EXPECT_EQ(shape.size(), 3);
+  EXPECT_EQ(shape[0], 1);
+  EXPECT_EQ(shape[1], 2);
+  EXPECT_EQ(shape[2], 3);
+  EXPECT_EQ(nd.dtype(), DLDataType({kDLFloat, 32, 1}));
+  for (int64_t i = 0; i < shape.Product(); ++i) {
+    reinterpret_cast<float*>(nd->data)[i] = i;
+  }
+
+  Any any0 = nd;
+  NDArray nd2 = any0.as<NDArray>().value();
+  EXPECT_EQ(nd2.shape(), shape);
+  EXPECT_EQ(nd2.dtype(), DLDataType({kDLFloat, 32, 1}));
+  for (int64_t i = 0; i < shape.Product(); ++i) {
+    EXPECT_EQ(reinterpret_cast<float*>(nd2->data)[i], i);
+  }
+
+  EXPECT_EQ(nd.IsContiguous(), true);
+  EXPECT_EQ(nd2.use_count(), 3);
+}
+
+TEST(NDArray, DLPack) {
+  NDArray nd = Empty({1, 2, 3}, DLDataType({kDLInt, 16, 1}), DLDevice({kDLCPU, 
0}));
+  DLManagedTensor* dlpack = nd.ToDLPack();
+  EXPECT_EQ(dlpack->dl_tensor.ndim, 3);
+  EXPECT_EQ(dlpack->dl_tensor.shape[0], 1);
+  EXPECT_EQ(dlpack->dl_tensor.shape[1], 2);
+  EXPECT_EQ(dlpack->dl_tensor.shape[2], 3);
+  EXPECT_EQ(dlpack->dl_tensor.dtype.code, kDLInt);
+  EXPECT_EQ(dlpack->dl_tensor.dtype.bits, 16);
+  EXPECT_EQ(dlpack->dl_tensor.dtype.lanes, 1);
+  EXPECT_EQ(dlpack->dl_tensor.device.device_type, kDLCPU);
+  EXPECT_EQ(dlpack->dl_tensor.device.device_id, 0);
+  EXPECT_EQ(dlpack->dl_tensor.byte_offset, 0);
+  EXPECT_EQ(dlpack->dl_tensor.strides, nullptr);
+  EXPECT_EQ(nd.use_count(), 2);
+  {
+    NDArray nd2 = NDArray::FromDLPack(dlpack);
+    EXPECT_EQ(nd2.use_count(), 1);
+    EXPECT_EQ(nd2->data, nd->data);
+    EXPECT_EQ(nd.use_count(), 2);
+    EXPECT_EQ(nd2.use_count(), 1);
+  }
+  EXPECT_EQ(nd.use_count(), 1);
+}
+
+TEST(NDArray, DLPackVersioned) {
+  DLDataType dtype = DLDataType({kDLFloat4_e2m1fn, 4, 1});
+  EXPECT_EQ(GetPackedDataSize(2, dtype), 2 * 4 / 8);
+  NDArray nd = Empty({2}, dtype, DLDevice({kDLCPU, 0}));
+  DLManagedTensorVersioned* dlpack = nd.ToDLPackVersioned();
+  EXPECT_EQ(dlpack->version.major, DLPACK_MAJOR_VERSION);
+  EXPECT_EQ(dlpack->version.minor, DLPACK_MINOR_VERSION);
+  EXPECT_EQ(dlpack->dl_tensor.ndim, 1);
+  EXPECT_EQ(dlpack->dl_tensor.shape[0], 2);
+  EXPECT_EQ(dlpack->dl_tensor.dtype.code, kDLFloat4_e2m1fn);
+  EXPECT_EQ(dlpack->dl_tensor.dtype.bits, 4);
+  EXPECT_EQ(dlpack->dl_tensor.dtype.lanes, 1);
+  EXPECT_EQ(dlpack->dl_tensor.device.device_type, kDLCPU);
+  EXPECT_EQ(dlpack->dl_tensor.device.device_id, 0);
+  EXPECT_EQ(dlpack->dl_tensor.byte_offset, 0);
+  EXPECT_EQ(dlpack->dl_tensor.strides, nullptr);
+
+  EXPECT_EQ(nd.use_count(), 2);
+  {
+    NDArray nd2 = NDArray::FromDLPackVersioned(dlpack);
+    EXPECT_EQ(nd2.use_count(), 1);
+    EXPECT_EQ(nd2->data, nd->data);
+    EXPECT_EQ(nd.use_count(), 2);
+    EXPECT_EQ(nd2.use_count(), 1);
+  }
+  EXPECT_EQ(nd.use_count(), 1);
+}
+}  // namespace

Reply via email to