This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new 8ca0719  [ABI] Introduce ShapeView, minimize TensorObj exposure (#67)
8ca0719 is described below

commit 8ca0719f74bef289d80c8704343ed7c1607db8f3
Author: Tianqi Chen <[email protected]>
AuthorDate: Sat Sep 27 17:16:33 2025 -0400

    [ABI] Introduce ShapeView, minimize TensorObj exposure (#67)
    
    This PR minimizes TensorObj ABI exposure so that the C++ API only
    depends on the behavior of the DLTensor fields.
    We also introduce ShapeView to reduce managed copies of shapes. This
    change will make future dependencies on the C++ side more stable.

    We also add a few helper functions, such as data_ptr(), ndim(), and
    numel(), to ffi::Tensor.
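
    For illustration, a minimal sketch of the new surface (this assumes
    the headers changed below; Empty is the helper used in
    tests/cpp/test_tensor.cc):

        std::vector<int64_t> dims = {2, 3, 4};
        // Non-owning view over existing data: no heap allocation.
        tvm::ffi::ShapeView view(dims.data(), dims.size());
        int64_t n = view.Product();  // 24

        tvm::ffi::Tensor t = Empty({2, 3, 4}, DLDataType({kDLFloat, 32, 1}),
                                   DLDevice({kDLCPU, 0}));
        void* p = t.data_ptr();      // same pointer as t->data
        int32_t nd = t.ndim();       // 3
        int64_t numel = t.numel();   // 24 == t.shape().Product()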
---
 include/tvm/ffi/container/shape.h  | 130 +++++++++++++++++++++++----
 include/tvm/ffi/container/tensor.h | 177 ++++++++++++++++++-------------------
 pyproject.toml                     |   2 +-
 python/tvm_ffi/__init__.py         |   2 +-
 tests/cpp/test_shape.cc            |  24 +++++
 tests/cpp/test_tensor.cc           |   8 +-
 6 files changed, 230 insertions(+), 113 deletions(-)

diff --git a/include/tvm/ffi/container/shape.h b/include/tvm/ffi/container/shape.h
index de24a44..690c51b 100644
--- a/include/tvm/ffi/container/shape.h
+++ b/include/tvm/ffi/container/shape.h
@@ -36,6 +36,69 @@
 namespace tvm {
 namespace ffi {
 
+/*!
+ * \brief Lightweight non-owning view of a shape.
+ */
+class ShapeView {
+ public:
+  /*! \brief Default constructor. */
+  ShapeView() : cell_{nullptr, 0} {}
+  /*! \brief Copy constructor. */
+  ShapeView(const ShapeView& other) = default;
+  /*! \brief Copy assignment operator. */
+  ShapeView& operator=(const ShapeView& other) = default;
+  /*! \brief Move constructor. */
+  ShapeView(ShapeView&& other) = default;
+  /*! \brief Move assignment operator. */
+  ShapeView& operator=(ShapeView&& other) = default;
+  /*! \brief Constructor from data and size. */
+  ShapeView(const int64_t* data, size_t size) : cell_{data, size} {}
+  /*! \brief Constructor from initializer list. */
+  ShapeView(const std::initializer_list<int64_t>& other) : cell_{other.begin(), other.size()} {}
+  /*! \brief Get the data pointer. */
+  const int64_t* data() const { return cell_.data; }
+  /*! \brief Get the size of the shape. */
+  size_t size() const { return cell_.size; }
+
+  /*! \brief Get the product of the shape. */
+  int64_t Product() const {
+    int64_t product = 1;
+    for (size_t i = 0; i < cell_.size; ++i) {
+      product *= cell_.data[i];
+    }
+    return product;
+  }
+
+  /*! \brief Get the i-th element of the shape. */
+  int64_t operator[](size_t idx) const { return cell_.data[idx]; }
+
+  /*! \return begin iterator */
+  const int64_t* begin() const { return cell_.data; }
+
+  /*! \return end iterator */
+  const int64_t* end() const { return cell_.data + cell_.size; }
+
+  /*! \return Whether shape tuple is empty */
+  bool empty() const { return size() == 0; }
+
+  /*! \return The first element of the shape tuple */
+  int64_t front() const { return this->at(0); }
+
+  /*! \return The last element of the shape tuple */
+  int64_t back() const { return this->at(this->size() - 1); }
+
+  /*! \brief Get the i-th element of the shape. */
+  int64_t at(size_t idx) const {
+    if (idx >= this->size()) {
+      TVM_FFI_THROW(IndexError) << "indexing " << idx << " on a Shape of size " << this->size();
+    }
+    return cell_.data[idx];
+  }
+
+ private:
+  TVMFFIShapeCell cell_;
+};
+
 /*! \brief An object representing a shape tuple. */
 class ShapeObj : public Object, public TVMFFIShapeCell {
  public:
@@ -93,21 +156,41 @@ TVM_FFI_INLINE ObjectPtr<ShapeObj> MakeInplaceShape(IterType begin, IterType end
   return p;
 }
 
-TVM_FFI_INLINE ObjectPtr<ShapeObj> MakeStridesFromShape(const int64_t* data, int64_t ndim) {
-  int64_t* strides_data;
-  ObjectPtr<ShapeObj> strides = details::MakeEmptyShape(ndim, &strides_data);
+/*!
+ * \brief Fill the strides of a compact row-major layout derived from a shape.
+ * \param shape The input shape.
+ * \param out_strides The output strides; must have room for shape.size() elements.
+ */
+TVM_FFI_INLINE void FillStridesFromShape(ShapeView shape, int64_t* out_strides) {
   int64_t stride = 1;
-  for (int i = ndim - 1; i >= 0; --i) {
-    strides_data[i] = stride;
-    stride *= data[i];
+  for (int64_t i = static_cast<int64_t>(shape.size()) - 1; i >= 0; --i) {
+    out_strides[i] = stride;
+    stride *= shape[i];
   }
+}
+
+/*!
+ * \brief Make a strides shape from a shape view.
+ * \param shape The input shape.
+ * \return The strides.
+ */
+TVM_FFI_INLINE ObjectPtr<ShapeObj> MakeStridesFromShape(ShapeView shape) {
+  int64_t* strides_data;
+  ObjectPtr<ShapeObj> strides = details::MakeEmptyShape(shape.size(), &strides_data);
+  FillStridesFromShape(shape, strides_data);
   return strides;
 }
 
 }  // namespace details
 
 /*!
- * \brief Reference to shape object.
+ * \brief Managed reference to shape object.
+ *
+ * When possible, use ShapeView instead of Shape to reduce memory allocation.
+ * Use Shape when you need a managed reference to an on-heap allocated shape.
+ *
+ * \sa ShapeView
  */
 class Shape : public ObjectRef {
  public:
@@ -149,16 +232,27 @@ class Shape : public ObjectRef {
   Shape(std::vector<int64_t> other)  // NOLINT(*)
       : ObjectRef(make_object<details::ShapeObjStdImpl>(std::move(other))) {}
 
+  /*!
+   * \brief Constructor from a shape view.
+   * \param other The shape view.
+   */
+  Shape(ShapeView other) : Shape(other.begin(), other.end()) {}  // NOLINT(*)
+
   /*!
   * \brief Create strides from a shape.
-   * \param data The shape data.
-   * \param ndim The number of dimensions.
+   * \param shape The input shape.
    * \return The strides.
    */
-  static Shape StridesFromShape(const int64_t* data, int64_t ndim) {
-    return Shape(details::MakeStridesFromShape(data, ndim));
+  static Shape StridesFromShape(ShapeView shape) {
+    return Shape(details::MakeStridesFromShape(shape));
   }
 
+  /*!
+   * \brief Convert to shape view.
+   * \return The shape view.
+   */
+  operator ShapeView() const { return ShapeView(data(), size()); }  // NOLINT(*)
+
   /*!
    * \brief Return the data pointer
    *
@@ -178,19 +272,19 @@ class Shape : public ObjectRef {
    * \param idx The index
    * \return the i-th element.
    */
-  int64_t operator[](size_t idx) const {
-    if (idx >= this->size()) {
-      TVM_FFI_THROW(IndexError) << "indexing " << idx << " on a Shape of size " << this->size();
-    }
-    return this->data()[idx];
-  }
+  int64_t operator[](size_t idx) const { return this->data()[idx]; }
 
   /*!
    * \brief Immutably read i-th element from the shape tuple.
    * \param idx The index
    * \return the i-th element.
    */
-  int64_t at(size_t idx) const { return this->operator[](idx); }
+  int64_t at(size_t idx) const {
+    if (idx >= this->size()) {
+      TVM_FFI_THROW(IndexError) << "indexing " << idx << " on a Shape of size " << this->size();
+    }
+    return this->operator[](idx);
+  }
 
   /*! \return Whether shape tuple is empty */
   bool empty() const { return size() == 0; }
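
A sketch of how the shape.h pieces above compose; this assumes only
<tvm/ffi/container/shape.h> and the row-major rule implemented by
details::FillStridesFromShape:

#include <tvm/ffi/container/shape.h>

void ShapeViewSketch() {
  using namespace tvm::ffi;
  Shape owned({2, 3, 4});                         // managed, heap-backed
  ShapeView view = owned;                         // implicit operator ShapeView(), no copy
  Shape strides = Shape::StridesFromShape(view);  // {12, 4, 1}: stride[i] = prod(shape[i+1:])
  Shape copy(view);                               // Shape(ShapeView): owning copy of the data
}
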
diff --git a/include/tvm/ffi/container/tensor.h b/include/tvm/ffi/container/tensor.h
index 9f698de..675de51 100644
--- a/include/tvm/ffi/container/tensor.h
+++ b/include/tvm/ffi/container/tensor.h
@@ -126,15 +126,7 @@ class TensorObj : public Object, public DLTensor {
   static constexpr const uint32_t _type_index = TypeIndex::kTVMFFITensor;
  TVM_FFI_DECLARE_OBJECT_INFO_STATIC(StaticTypeKey::kTVMFFITensor, TensorObj, Object);
   /// \endcond
-  ~TensorObj() {
-    // deleting the cached dl managed tensor versioned
-    // need to acquire the value in case it is released by another thread
-    DLManagedTensorVersioned* cached =
-        cached_dl_managed_tensor_versioned_.load(std::memory_order_acquire);
-    if (cached != nullptr) {
-      delete cached;
-    }
-  }
+
   /*!
    * \brief Move a Tensor to a DLPack managed tensor.
    * \return The converted DLPack managed tensor.
@@ -144,7 +136,7 @@ class TensorObj : public Object, public DLTensor {
     DLManagedTensor* ret = new DLManagedTensor();
     ret->dl_tensor = *static_cast<DLTensor*>(self);
     ret->manager_ctx = self;
-    ret->deleter = DLManagedTensorDeleter;
+    ret->deleter = DLManagedTensorDeleter<DLManagedTensor>;
     details::ObjectUnsafe::IncRefObjectHandle(self);
     return ret;
   }
@@ -154,69 +146,29 @@ class TensorObj : public Object, public DLTensor {
    * \return The converted DLPack managed tensor.
    */
   DLManagedTensorVersioned* ToDLPackVersioned() const {
-    TensorObj* from = const_cast<TensorObj*>(this);
-    // if cache is set, directly return it
-    // we need to use acquire to ensure that write to DLManagedTensorVersioned
-    // from another thread is visible to this thread.
-    DLManagedTensorVersioned* cached =
-        cached_dl_managed_tensor_versioned_.load(std::memory_order_acquire);
-    // if cache is not set, create a new one
-    if (cached == nullptr) {
-      DLManagedTensorVersioned* ret = new DLManagedTensorVersioned();
-      ret->version.major = DLPACK_MAJOR_VERSION;
-      ret->version.minor = DLPACK_MINOR_VERSION;
-      ret->dl_tensor = *static_cast<DLTensor*>(from);
-      ret->manager_ctx = from;
-      ret->deleter = EmbeddedDLManagedTensorVersionedDeleter;
-      ret->flags = 0;
-      DLManagedTensorVersioned* expected = nullptr;
-      // success set must release the new value to all other threads
-      // failure set must acquire, since the expected value is now coming
-      // from another thread that released this value
-      if (std::atomic_compare_exchange_strong_explicit(&cached_dl_managed_tensor_versioned_,
-                                                       &expected, ret, std::memory_order_release,
-                                                       std::memory_order_acquire)) {
-        // set is succes
-        cached = ret;
-      } else {
-        // delete the ret value as another thread raced to set this one first
-        delete ret;
-        cached = expected;
-      }
-      // at this point, cached is the value that officially set to the field
-    }
-    // inc the ref count of the from object
-    details::ObjectUnsafe::IncRefObjectHandle(from);
-    return cached;
+    TensorObj* self = const_cast<TensorObj*>(this);
+    DLManagedTensorVersioned* ret = new DLManagedTensorVersioned();
+    ret->version.major = DLPACK_MAJOR_VERSION;
+    ret->version.minor = DLPACK_MINOR_VERSION;
+    ret->dl_tensor = *static_cast<DLTensor*>(self);
+    ret->manager_ctx = self;
+    ret->deleter = DLManagedTensorDeleter<DLManagedTensorVersioned>;
+    details::ObjectUnsafe::IncRefObjectHandle(self);
+    return ret;
   }
 
  protected:
-  /*! \brief Internal data to back returning shape. */
-  Optional<Shape> shape_data_;
-  /*! \brief Internal data to back returning strides. */
-  Optional<Shape> strides_data_;
-  /*! \brief cached data to back returning DLManagedTensorVersioned. */
-  mutable std::atomic<DLManagedTensorVersioned*> cached_dl_managed_tensor_versioned_ = nullptr;
-
   /*!
    * \brief Deleter for DLManagedTensor.
    * \param tensor The DLManagedTensor to be deleted.
    */
-  static void DLManagedTensorDeleter(DLManagedTensor* tensor) {
+  template <typename TDLManagedTensor>
+  static void DLManagedTensorDeleter(TDLManagedTensor* tensor) {
     TensorObj* obj = static_cast<TensorObj*>(tensor->manager_ctx);
     details::ObjectUnsafe::DecRefObjectHandle(obj);
     delete tensor;
   }
 
-  /*!
-   * \brief Deleter for DLManagedTensorVersioned.
-   * \param tensor The DLManagedTensorVersioned to be deleted.
-   */
-  static void EmbeddedDLManagedTensorVersionedDeleter(DLManagedTensorVersioned* tensor) {
-    TensorObj* obj = static_cast<TensorObj*>(tensor->manager_ctx);
-    details::ObjectUnsafe::DecRefObjectHandle(obj);
-  }
-
   friend class Tensor;
 };
 
@@ -229,19 +181,22 @@ namespace details {
 template <typename TNDAlloc>
 class TensorObjFromNDAlloc : public TensorObj {
  public:
+  using Self = TensorObjFromNDAlloc<TNDAlloc>;
+
   template <typename... ExtraArgs>
-  TensorObjFromNDAlloc(TNDAlloc alloc, ffi::Shape shape, DLDataType dtype, DLDevice device,
+  TensorObjFromNDAlloc(TNDAlloc alloc, ffi::ShapeView shape, DLDataType dtype, DLDevice device,
                        ExtraArgs&&... extra_args)
       : alloc_(alloc) {
     this->device = device;
     this->ndim = static_cast<int>(shape.size());
     this->dtype = dtype;
-    this->shape = const_cast<int64_t*>(shape.data());
-    Shape strides = Shape::StridesFromShape(this->shape, this->ndim);
-    this->strides = const_cast<int64_t*>(strides.data());
     this->byte_offset = 0;
-    this->shape_data_ = std::move(shape);
-    this->strides_data_ = std::move(strides);
+    // inplace alloc shape and strides after data structure
+    this->shape = reinterpret_cast<int64_t*>(reinterpret_cast<char*>(this) + sizeof(Self));
+    this->strides = this->shape + shape.size();
+    std::copy(shape.begin(), shape.end(), this->shape);
+    details::FillStridesFromShape(shape, this->strides);
+    // call allocator to alloc data
     alloc_.AllocData(static_cast<DLTensor*>(this), std::forward<ExtraArgs>(extra_args)...);
   }
 
@@ -255,12 +210,15 @@ class TensorObjFromNDAlloc : public TensorObj {
 template <typename TDLPackManagedTensor>
 class TensorObjFromDLPack : public TensorObj {
  public:
-  explicit TensorObjFromDLPack(TDLPackManagedTensor* tensor) : tensor_(tensor) {
+  using Self = TensorObjFromDLPack<TDLPackManagedTensor>;
+
+  explicit TensorObjFromDLPack(TDLPackManagedTensor* tensor, bool extra_strides_at_tail)
+      : tensor_(tensor) {
     *static_cast<DLTensor*>(this) = tensor_->dl_tensor;
-    if (tensor_->dl_tensor.strides == nullptr) {
-      Shape strides = Shape::StridesFromShape(tensor_->dl_tensor.shape, tensor_->dl_tensor.ndim);
-      this->strides = const_cast<int64_t*>(strides.data());
-      this->strides_data_ = std::move(strides);
+    if (extra_strides_at_tail) {
+      this->strides = reinterpret_cast<int64_t*>(reinterpret_cast<char*>(this) + sizeof(Self));
+      details::FillStridesFromShape(ShapeView(tensor_->dl_tensor.shape, tensor_->dl_tensor.ndim),
+                                    this->strides);
     }
   }
 
@@ -289,25 +247,38 @@ class Tensor : public ObjectRef {
    * \brief Get the shape of the Tensor.
    * \return The shape of the Tensor.
    */
-  tvm::ffi::Shape shape() const {
-    TensorObj* obj = get_mutable();
-    if (!obj->shape_data_.has_value()) {
-      obj->shape_data_ = tvm::ffi::Shape(obj->shape, obj->shape + obj->ndim);
-    }
-    return *(obj->shape_data_);
+  ShapeView shape() const {
+    const TensorObj* obj = get();
+    return tvm::ffi::ShapeView(obj->shape, obj->ndim);
   }
   /*!
    * \brief Get the strides of the Tensor.
    * \return The strides of the Tensor.
    */
-  tvm::ffi::Shape strides() const {
-    TensorObj* obj = get_mutable();
+  ShapeView strides() const {
+    const TensorObj* obj = get();
     TVM_FFI_ICHECK(obj->strides != nullptr);
-    if (!obj->strides_data_.has_value()) {
-      obj->strides_data_ = tvm::ffi::Shape(obj->strides, obj->strides + obj->ndim);
-    }
-    return *(obj->strides_data_);
+    return ShapeView(obj->strides, obj->ndim);
   }
+
+  /*!
+   * \brief Get the data pointer of the Tensor.
+   * \return The data pointer of the Tensor.
+   */
+  void* data_ptr() const { return (*this)->data; }
+
+  /*!
+   * \brief Get the number of dimensions in the Tensor.
+   * \return The number of dimensions in the Tensor.
+   */
+  int32_t ndim() const { return (*this)->ndim; }
+
+  /*!
+   * \brief Get the number of elements in the Tensor.
+   * \return The number of elements in the Tensor.
+   */
+  int64_t numel() const { return this->shape().Product(); }
+
   /*!
    * \brief Get the data type of the Tensor.
    * \return The data type of the Tensor.
@@ -336,10 +307,13 @@ class Tensor : public ObjectRef {
    * \tparam ExtraArgs Extra arguments to be passed to Alloc.
    */
   template <typename TNDAlloc, typename... ExtraArgs>
-  static Tensor FromNDAlloc(TNDAlloc alloc, ffi::Shape shape, DLDataType dtype, DLDevice device,
+  static Tensor FromNDAlloc(TNDAlloc alloc, ffi::ShapeView shape, DLDataType dtype, DLDevice device,
                             ExtraArgs&&... extra_args) {
-    return Tensor(make_object<details::TensorObjFromNDAlloc<TNDAlloc>>(
-        alloc, shape, dtype, device, std::forward<ExtraArgs>(extra_args)...));
+    // shape and strides are allocated inplace after this object, hence the factor of 2
+    size_t num_extra_i64_at_tail = shape.size() * 2;
+    return Tensor(make_inplace_array_object<details::TensorObjFromNDAlloc<TNDAlloc>, int64_t>(
+        num_extra_i64_at_tail, alloc, shape, dtype, device,
+        std::forward<ExtraArgs>(extra_args)...));
   }
   /*!
    * \brief Create a Tensor from a DLPackTensorAllocator
@@ -393,7 +367,15 @@ class Tensor : public ObjectRef {
       throw ffi::Error(error_context.kind, error_context.message,
                        TVMFFIBacktrace(__FILE__, __LINE__, __func__, 0));
     }
-    return Tensor(make_object<details::TensorObjFromDLPack<DLManagedTensorVersioned>>(tensor));
+    if (tensor->dl_tensor.strides != nullptr) {
+      return Tensor(make_object<details::TensorObjFromDLPack<DLManagedTensorVersioned>>(
+          tensor, /*extra_strides_at_tail=*/false));
+    } else {
+      return Tensor(
+          make_inplace_array_object<details::TensorObjFromDLPack<DLManagedTensorVersioned>,
+                                    int64_t>(tensor->dl_tensor.ndim, tensor,
+                                             /*extra_strides_at_tail=*/true));
+    }
   }
   /*!
    * \brief Create a Tensor from a DLPack managed tensor, pre v1.0 API.
@@ -412,7 +394,14 @@ class Tensor : public ObjectRef {
     if (require_contiguous && !ffi::IsContiguous(tensor->dl_tensor)) {
       TVM_FFI_THROW(RuntimeError) << "FromDLPack: Tensor is not contiguous.";
     }
-    return Tensor(make_object<details::TensorObjFromDLPack<DLManagedTensor>>(tensor));
+    if (tensor->dl_tensor.strides != nullptr) {
+      return Tensor(make_object<details::TensorObjFromDLPack<DLManagedTensor>>(
+          tensor, /*extra_strides_at_tail=*/false));
+    } else {
+      return Tensor(
+          make_inplace_array_object<details::TensorObjFromDLPack<DLManagedTensor>, int64_t>(
+              tensor->dl_tensor.ndim, tensor, /*extra_strides_at_tail=*/true));
+    }
   }
 
   /*!
@@ -434,7 +423,15 @@ class Tensor : public ObjectRef {
     if (tensor->flags & DLPACK_FLAG_BITMASK_IS_SUBBYTE_TYPE_PADDED) {
-      TVM_FFI_THROW(RuntimeError) << "Subbyte type padded is not yet supported";
     }
-    return Tensor(make_object<details::TensorObjFromDLPack<DLManagedTensorVersioned>>(tensor));
+    if (tensor->dl_tensor.strides != nullptr) {
+      return Tensor(make_object<details::TensorObjFromDLPack<DLManagedTensorVersioned>>(
+          tensor, /*extra_strides_at_tail=*/false));
+    } else {
+      return Tensor(
+          make_inplace_array_object<details::TensorObjFromDLPack<DLManagedTensorVersioned>,
+                                    int64_t>(tensor->dl_tensor.ndim, tensor,
+                                             /*extra_strides_at_tail=*/true));
+    }
   }
 
   /*!
diff --git a/pyproject.toml b/pyproject.toml
index 66323e6..f734184 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -17,7 +17,7 @@
 
 [project]
 name = "apache-tvm-ffi"
-version = "0.1.0b9"
+version = "0.1.0b10"
 description = "tvm ffi"
 
 authors = [{ name = "TVM FFI team" }]
diff --git a/python/tvm_ffi/__init__.py b/python/tvm_ffi/__init__.py
index 31cda06..807f9a9 100644
--- a/python/tvm_ffi/__init__.py
+++ b/python/tvm_ffi/__init__.py
@@ -17,7 +17,7 @@
 """TVM FFI Python package."""
 
 # version
-__version__ = "0.1.0b9"
+__version__ = "0.1.0b10"
 
 # order matters here so we need to skip isort here
 # isort: skip_file
diff --git a/tests/cpp/test_shape.cc b/tests/cpp/test_shape.cc
index 0ccba78..cd67464 100644
--- a/tests/cpp/test_shape.cc
+++ b/tests/cpp/test_shape.cc
@@ -69,4 +69,28 @@ TEST(Shape, AnyConvert) {
   EXPECT_EQ(shape2[1], 2);
 }
 
+TEST(Shape, ShapeView) {
+  Shape shape = Shape({1, 2, 3});
+  ShapeView shape_view = shape;
+  EXPECT_EQ(shape_view.size(), 3);
+  EXPECT_EQ(shape_view[0], 1);
+  EXPECT_EQ(shape_view[1], 2);
+  EXPECT_EQ(shape_view[2], 3);
+
+  std::vector<int64_t> data = {4, 5, 6};
+  ShapeView view_from_data(data.data(), data.size());
+  EXPECT_EQ(view_from_data.size(), 3);
+  EXPECT_EQ(view_from_data[0], 4);
+  EXPECT_EQ(view_from_data[1], 5);
+  EXPECT_EQ(view_from_data[2], 6);
+
+  std::initializer_list<int64_t> init = {7, 8, 9};
+  ShapeView view_from_init = init;
+  EXPECT_EQ(view_from_init.size(), 3);
+  EXPECT_EQ(view_from_init[0], 7);
+  EXPECT_EQ(view_from_init[1], 8);
+  EXPECT_EQ(view_from_init[2], 9);
+  EXPECT_EQ(view_from_init.Product(), 7 * 8 * 9);
+}
+
 }  // namespace
diff --git a/tests/cpp/test_tensor.cc b/tests/cpp/test_tensor.cc
index 7c696a3..bb9158a 100644
--- a/tests/cpp/test_tensor.cc
+++ b/tests/cpp/test_tensor.cc
@@ -50,7 +50,7 @@ int TestDLPackTensorAllocatorError(DLTensor* prototype, DLManagedTensorVersioned
 }
 
 TEST(Tensor, Basic) {
-  Tensor nd = Empty(Shape({1, 2, 3}), DLDataType({kDLFloat, 32, 1}), DLDevice({kDLCPU, 0}));
+  Tensor nd = Empty({1, 2, 3}, DLDataType({kDLFloat, 32, 1}), DLDevice({kDLCPU, 0}));
   Shape shape = nd.shape();
   Shape strides = nd.strides();
   EXPECT_EQ(shape.size(), 3);
@@ -66,10 +66,12 @@ TEST(Tensor, Basic) {
     reinterpret_cast<float*>(nd->data)[i] = static_cast<float>(i);
   }
 
+  EXPECT_EQ(nd.numel(), 6);
+  EXPECT_EQ(nd.ndim(), 3);
+  EXPECT_EQ(nd.data_ptr(), nd->data);
+
   Any any0 = nd;
   Tensor nd2 = any0.as<Tensor>().value();
-  EXPECT_EQ(nd2.shape(), shape);
-  EXPECT_EQ(nd2.strides(), strides);
   EXPECT_EQ(nd2.dtype(), DLDataType({kDLFloat, 32, 1}));
   for (int64_t i = 0; i < shape.Product(); ++i) {
     EXPECT_EQ(reinterpret_cast<float*>(nd2->data)[i], i);
