This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new 8ca0719 [ABI] Introduce ShapeView Minimize TensorObj exposure (#67)
8ca0719 is described below
commit 8ca0719f74bef289d80c8704343ed7c1607db8f3
Author: Tianqi Chen <[email protected]>
AuthorDate: Sat Sep 27 17:16:33 2025 -0400
[ABI] Introduce ShapeView Minimize TensorObj exposure (#67)
This PR minimizes TensorObj ABI exposure so the C++ API only depends on
the behavior of the DLTensor fields.
We also introduce ShapeView to reduce managed copies of shapes. This
change will make future dependencies on the C++ side more stable.
We also add a few helper functions, such as data_ptr(), ndim(), and
numel(), to ffi::Tensor.
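
For illustration, a minimal sketch of the resulting API (assuming the
usual tvm/ffi includes and a CPU float32 tensor `t` created elsewhere,
e.g. via Empty as in the tests below):

    tvm::ffi::ShapeView s = t.shape();  // non-owning view, no heap allocation
    int64_t n = t.numel();              // same as s.Product()
    int32_t d = t.ndim();               // same as static_cast<int32_t>(s.size())
    float* p = static_cast<float*>(t.data_ptr());  // same as t->data
    for (int64_t i = 0; i < n; ++i) p[i] = 0.0f;   // fill with zeros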
---
include/tvm/ffi/container/shape.h | 130 +++++++++++++++++++++++----
include/tvm/ffi/container/tensor.h | 177 ++++++++++++++++++-------------------
pyproject.toml | 2 +-
python/tvm_ffi/__init__.py | 2 +-
tests/cpp/test_shape.cc | 24 +++++
tests/cpp/test_tensor.cc | 8 +-
6 files changed, 230 insertions(+), 113 deletions(-)
diff --git a/include/tvm/ffi/container/shape.h b/include/tvm/ffi/container/shape.h
index de24a44..690c51b 100644
--- a/include/tvm/ffi/container/shape.h
+++ b/include/tvm/ffi/container/shape.h
@@ -36,6 +36,69 @@
namespace tvm {
namespace ffi {
+/*!
+ * \brief Lightweight non-owning view class for a shape.
+ */
+class ShapeView {
+ public:
+ /*! \brief Default constructor. */
+ ShapeView() : cell_{nullptr, 0} {}
+ /*! \brief Copy constructor. */
+ ShapeView(const ShapeView& other) = default;
+ /*! \brief Copy assignment operator. */
+ ShapeView& operator=(const ShapeView& other) = default;
+ /*! \brief Move constructor. */
+ ShapeView(ShapeView&& other) = default;
+ /*! \brief Move assignment operator. */
+ ShapeView& operator=(ShapeView&& other) = default;
+ /*! \brief Constructor from data and size. */
+ ShapeView(const int64_t* data, size_t size) : cell_{data, size} {}
+ /*! \brief Constructor from initializer list. */
+  ShapeView(const std::initializer_list<int64_t>& other) : cell_{other.begin(), other.size()} {}
+ /*! \brief Get the data pointer. */
+ const int64_t* data() const { return cell_.data; }
+ /*! \brief Get the size of the shape. */
+ size_t size() const { return cell_.size; }
+
+ /*! \brief Get the product of the shape. */
+ int64_t Product() const {
+ int64_t product = 1;
+ for (size_t i = 0; i < cell_.size; ++i) {
+ product *= cell_.data[i];
+ }
+ return product;
+ }
+
+ /*! \brief Get the i-th element of the shape. */
+ int64_t operator[](size_t idx) const { return cell_.data[idx]; }
+
+ /*! \return begin iterator */
+ const int64_t* begin() const { return cell_.data; }
+
+ /*! \return end iterator */
+ const int64_t* end() const { return cell_.data + cell_.size; }
+
+ /*! \return Whether shape tuple is empty */
+ bool empty() const { return size() == 0; }
+
+ /*! \return The first element of the shape tuple */
+ int64_t front() const { return this->at(0); }
+
+ /*! \return The last element of the shape tuple */
+ int64_t back() const { return this->at(this->size() - 1); }
+
+ /*! \brief Get the i-th element of the shape. */
+ int64_t at(size_t idx) const {
+ if (idx >= this->size()) {
+      TVM_FFI_THROW(IndexError) << "indexing " << idx << " on a Shape of size " << this->size();
+ }
+ return cell_.data[idx];
+ }
+
+ private:
+ TVMFFIShapeCell cell_;
+};
+
/*! \brief An object representing a shape tuple. */
class ShapeObj : public Object, public TVMFFIShapeCell {
public:
@@ -93,21 +156,41 @@ TVM_FFI_INLINE ObjectPtr<ShapeObj> MakeInplaceShape(IterType begin, IterType end
return p;
}
-TVM_FFI_INLINE ObjectPtr<ShapeObj> MakeStridesFromShape(const int64_t* data, int64_t ndim) {
- int64_t* strides_data;
- ObjectPtr<ShapeObj> strides = details::MakeEmptyShape(ndim, &strides_data);
+/*!
+ * \brief Fill in strides computed from a shape, assuming a compact row-major layout.
+ * \param shape The input shape.
+ * \param out_strides The output strides, must have space for shape.size() elements.
+ */
+TVM_FFI_INLINE void FillStridesFromShape(ShapeView shape, int64_t* out_strides) {
int64_t stride = 1;
- for (int i = ndim - 1; i >= 0; --i) {
- strides_data[i] = stride;
- stride *= data[i];
+ for (int64_t i = static_cast<int64_t>(shape.size()) - 1; i >= 0; --i) {
+ out_strides[i] = stride;
+ stride *= shape[i];
}
+}
+
+/*!
+ * \brief Make a strides shape from a shape view.
+ * \param shape The input shape.
+ * \return The strides.
+ */
+TVM_FFI_INLINE ObjectPtr<ShapeObj> MakeStridesFromShape(ShapeView shape) {
+ int64_t* strides_data;
+  ObjectPtr<ShapeObj> strides = details::MakeEmptyShape(shape.size(), &strides_data);
+ FillStridesFromShape(shape, strides_data);
return strides;
}
} // namespace details
/*!
- * \brief Reference to shape object.
+ * \brief Managed reference to shape object.
+ *
+ * When possible, use ShapeView instead of Shape to reduce memory allocation.
+ * Use Shape when you need a managed reference to a heap-allocated shape.
+ *
+ * \sa ShapeView
*/
class Shape : public ObjectRef {
public:
@@ -149,16 +232,27 @@ class Shape : public ObjectRef {
Shape(std::vector<int64_t> other) // NOLINT(*)
: ObjectRef(make_object<details::ShapeObjStdImpl>(std::move(other))) {}
+ /*!
+ * \brief constructor from shape view.
+ * \param other The shape view.
+ */
+ Shape(ShapeView other) : Shape(other.begin(), other.end()) {} // NOLINT(*)
+
/*!
* \brief Create a strides from a shape.
- * \param data The shape data.
- * \param ndim The number of dimensions.
+ * \param shape The input shape.
* \return The strides.
*/
- static Shape StridesFromShape(const int64_t* data, int64_t ndim) {
- return Shape(details::MakeStridesFromShape(data, ndim));
+ static Shape StridesFromShape(ShapeView shape) {
+ return Shape(details::MakeStridesFromShape(shape));
}
+ /*!
+ * \brief Convert to shape view.
+ * \return The shape view.
+ */
+  operator ShapeView() const { return ShapeView(data(), size()); }  // NOLINT(*)
+
/*!
* \brief Return the data pointer
*
@@ -178,19 +272,19 @@ class Shape : public ObjectRef {
* \param idx The index
* \return the i-th element.
*/
- int64_t operator[](size_t idx) const {
- if (idx >= this->size()) {
-      TVM_FFI_THROW(IndexError) << "indexing " << idx << " on a Shape of size " << this->size();
- }
- return this->data()[idx];
- }
+ int64_t operator[](size_t idx) const { return this->data()[idx]; }
/*!
* \brief Immutably read i-th element from the shape tuple.
* \param idx The index
* \return the i-th element.
*/
- int64_t at(size_t idx) const { return this->operator[](idx); }
+ int64_t at(size_t idx) const {
+ if (idx >= this->size()) {
+      TVM_FFI_THROW(IndexError) << "indexing " << idx << " on a Shape of size " << this->size();
+ }
+ return this->operator[](idx);
+ }
/*! \return Whether shape tuple is empty */
bool empty() const { return size() == 0; }
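
As a usage note, a short illustrative sketch of the ShapeView introduced
above: the view is non-owning, so it must not outlive the Shape (or raw
buffer) that backs it.

    tvm::ffi::Shape owned({2, 3, 4});   // managed, heap-allocated shape
    tvm::ffi::ShapeView view = owned;   // implicit conversion, no copy
    int64_t total = view.Product();     // 24
    int64_t last = view.back();         // 4
    // `view` borrows `owned`'s storage; keep `owned` alive while using it.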
diff --git a/include/tvm/ffi/container/tensor.h b/include/tvm/ffi/container/tensor.h
index 9f698de..675de51 100644
--- a/include/tvm/ffi/container/tensor.h
+++ b/include/tvm/ffi/container/tensor.h
@@ -126,15 +126,7 @@ class TensorObj : public Object, public DLTensor {
static constexpr const uint32_t _type_index = TypeIndex::kTVMFFITensor;
  TVM_FFI_DECLARE_OBJECT_INFO_STATIC(StaticTypeKey::kTVMFFITensor, TensorObj, Object);
/// \endcond
- ~TensorObj() {
- // deleting the cached dl managed tensor versioned
- // need to acquire the value in case it is released by another thread
- DLManagedTensorVersioned* cached =
- cached_dl_managed_tensor_versioned_.load(std::memory_order_acquire);
- if (cached != nullptr) {
- delete cached;
- }
- }
+
/*!
* \brief Move a Tensor to a DLPack managed tensor.
* \return The converted DLPack managed tensor.
@@ -144,7 +136,7 @@ class TensorObj : public Object, public DLTensor {
DLManagedTensor* ret = new DLManagedTensor();
ret->dl_tensor = *static_cast<DLTensor*>(self);
ret->manager_ctx = self;
- ret->deleter = DLManagedTensorDeleter;
+ ret->deleter = DLManagedTensorDeleter<DLManagedTensor>;
details::ObjectUnsafe::IncRefObjectHandle(self);
return ret;
}
@@ -154,69 +146,29 @@ class TensorObj : public Object, public DLTensor {
* \return The converted DLPack managed tensor.
*/
DLManagedTensorVersioned* ToDLPackVersioned() const {
- TensorObj* from = const_cast<TensorObj*>(this);
- // if cache is set, directly return it
- // we need to use acquire to ensure that write to DLManagedTensorVersioned
- // from another thread is visible to this thread.
- DLManagedTensorVersioned* cached =
- cached_dl_managed_tensor_versioned_.load(std::memory_order_acquire);
- // if cache is not set, create a new one
- if (cached == nullptr) {
- DLManagedTensorVersioned* ret = new DLManagedTensorVersioned();
- ret->version.major = DLPACK_MAJOR_VERSION;
- ret->version.minor = DLPACK_MINOR_VERSION;
- ret->dl_tensor = *static_cast<DLTensor*>(from);
- ret->manager_ctx = from;
- ret->deleter = EmbeddedDLManagedTensorVersionedDeleter;
- ret->flags = 0;
- DLManagedTensorVersioned* expected = nullptr;
- // success set must release the new value to all other threads
- // failure set must acquire, since the expected value is now coming
- // from another thread that released this value
-      if (std::atomic_compare_exchange_strong_explicit(&cached_dl_managed_tensor_versioned_,
-                                                       &expected, ret, std::memory_order_release,
-                                                       std::memory_order_acquire)) {
- // set is succes
- cached = ret;
- } else {
- // delete the ret value as another thread raced to set this one first
- delete ret;
- cached = expected;
- }
- // at this point, cached is the value that officially set to the field
- }
- // inc the ref count of the from object
- details::ObjectUnsafe::IncRefObjectHandle(from);
- return cached;
+ TensorObj* self = const_cast<TensorObj*>(this);
+ DLManagedTensorVersioned* ret = new DLManagedTensorVersioned();
+ ret->version.major = DLPACK_MAJOR_VERSION;
+ ret->version.minor = DLPACK_MINOR_VERSION;
+ ret->dl_tensor = *static_cast<DLTensor*>(self);
+ ret->manager_ctx = self;
+ ret->deleter = DLManagedTensorDeleter<DLManagedTensorVersioned>;
+ details::ObjectUnsafe::IncRefObjectHandle(self);
+ return ret;
}
protected:
- /*! \brief Internal data to back returning shape. */
- Optional<Shape> shape_data_;
- /*! \brief Internal data to back returning strides. */
- Optional<Shape> strides_data_;
- /*! \brief cached data to back returning DLManagedTensorVersioned. */
-  mutable std::atomic<DLManagedTensorVersioned*> cached_dl_managed_tensor_versioned_ = nullptr;
-
/*!
* \brief Deleter for DLManagedTensor.
* \param tensor The DLManagedTensor to be deleted.
*/
- static void DLManagedTensorDeleter(DLManagedTensor* tensor) {
+ template <typename TDLManagedTensor>
+ static void DLManagedTensorDeleter(TDLManagedTensor* tensor) {
TensorObj* obj = static_cast<TensorObj*>(tensor->manager_ctx);
details::ObjectUnsafe::DecRefObjectHandle(obj);
delete tensor;
}
- /*!
- * \brief Deleter for DLManagedTensorVersioned.
- * \param tensor The DLManagedTensorVersioned to be deleted.
- */
-  static void EmbeddedDLManagedTensorVersionedDeleter(DLManagedTensorVersioned* tensor) {
- TensorObj* obj = static_cast<TensorObj*>(tensor->manager_ctx);
- details::ObjectUnsafe::DecRefObjectHandle(obj);
- }
-
friend class Tensor;
};
@@ -229,19 +181,22 @@ namespace details {
template <typename TNDAlloc>
class TensorObjFromNDAlloc : public TensorObj {
public:
+ using Self = TensorObjFromNDAlloc<TNDAlloc>;
+
template <typename... ExtraArgs>
-  TensorObjFromNDAlloc(TNDAlloc alloc, ffi::Shape shape, DLDataType dtype, DLDevice device,
+  TensorObjFromNDAlloc(TNDAlloc alloc, ffi::ShapeView shape, DLDataType dtype, DLDevice device,
ExtraArgs&&... extra_args)
: alloc_(alloc) {
this->device = device;
this->ndim = static_cast<int>(shape.size());
this->dtype = dtype;
- this->shape = const_cast<int64_t*>(shape.data());
- Shape strides = Shape::StridesFromShape(this->shape, this->ndim);
- this->strides = const_cast<int64_t*>(strides.data());
this->byte_offset = 0;
- this->shape_data_ = std::move(shape);
- this->strides_data_ = std::move(strides);
+    // allocate shape and strides inplace, right after this object
+    this->shape = reinterpret_cast<int64_t*>(reinterpret_cast<char*>(this) + sizeof(Self));
+ this->strides = this->shape + shape.size();
+ std::copy(shape.begin(), shape.end(), this->shape);
+ details::FillStridesFromShape(shape, this->strides);
+ // call allocator to alloc data
     alloc_.AllocData(static_cast<DLTensor*>(this), std::forward<ExtraArgs>(extra_args)...);
}
@@ -255,12 +210,15 @@ class TensorObjFromNDAlloc : public TensorObj {
template <typename TDLPackManagedTensor>
class TensorObjFromDLPack : public TensorObj {
public:
-  explicit TensorObjFromDLPack(TDLPackManagedTensor* tensor) : tensor_(tensor) {
+ using Self = TensorObjFromDLPack<TDLPackManagedTensor>;
+
+  explicit TensorObjFromDLPack(TDLPackManagedTensor* tensor, bool extra_strides_at_tail)
+ : tensor_(tensor) {
*static_cast<DLTensor*>(this) = tensor_->dl_tensor;
- if (tensor_->dl_tensor.strides == nullptr) {
-      Shape strides = Shape::StridesFromShape(tensor_->dl_tensor.shape, tensor_->dl_tensor.ndim);
- this->strides = const_cast<int64_t*>(strides.data());
- this->strides_data_ = std::move(strides);
+ if (extra_strides_at_tail) {
+      this->strides = reinterpret_cast<int64_t*>(reinterpret_cast<char*>(this) + sizeof(Self));
+      details::FillStridesFromShape(ShapeView(tensor_->dl_tensor.shape, tensor_->dl_tensor.ndim),
+ this->strides);
}
}
@@ -289,25 +247,38 @@ class Tensor : public ObjectRef {
* \brief Get the shape of the Tensor.
* \return The shape of the Tensor.
*/
- tvm::ffi::Shape shape() const {
- TensorObj* obj = get_mutable();
- if (!obj->shape_data_.has_value()) {
- obj->shape_data_ = tvm::ffi::Shape(obj->shape, obj->shape + obj->ndim);
- }
- return *(obj->shape_data_);
+ ShapeView shape() const {
+ const TensorObj* obj = get();
+ return tvm::ffi::ShapeView(obj->shape, obj->ndim);
}
/*!
* \brief Get the strides of the Tensor.
* \return The strides of the Tensor.
*/
- tvm::ffi::Shape strides() const {
- TensorObj* obj = get_mutable();
+ ShapeView strides() const {
+ const TensorObj* obj = get();
TVM_FFI_ICHECK(obj->strides != nullptr);
- if (!obj->strides_data_.has_value()) {
-      obj->strides_data_ = tvm::ffi::Shape(obj->strides, obj->strides + obj->ndim);
- }
- return *(obj->strides_data_);
+ return ShapeView(obj->strides, obj->ndim);
}
+
+ /*!
+ * \brief Get the data pointer of the Tensor.
+ * \return The data pointer of the Tensor.
+ */
+ void* data_ptr() const { return (*this)->data; }
+
+ /*!
+ * \brief Get the number of dimensions in the Tensor.
+ * \return The number of dimensions in the Tensor.
+ */
+ int32_t ndim() const { return (*this)->ndim; }
+
+ /*!
+ * \brief Get the number of elements in the Tensor.
+ * \return The number of elements in the Tensor.
+ */
+ int64_t numel() const { return this->shape().Product(); }
+
/*!
* \brief Get the data type of the Tensor.
* \return The data type of the Tensor.
@@ -336,10 +307,13 @@ class Tensor : public ObjectRef {
* \tparam ExtraArgs Extra arguments to be passed to Alloc.
*/
template <typename TNDAlloc, typename... ExtraArgs>
-  static Tensor FromNDAlloc(TNDAlloc alloc, ffi::Shape shape, DLDataType dtype, DLDevice device,
+  static Tensor FromNDAlloc(TNDAlloc alloc, ffi::ShapeView shape, DLDataType dtype, DLDevice device,
ExtraArgs&&... extra_args) {
- return Tensor(make_object<details::TensorObjFromNDAlloc<TNDAlloc>>(
- alloc, shape, dtype, device, std::forward<ExtraArgs>(extra_args)...));
+    // shape and strides are allocated inplace after the object, hence the factor of two
+ size_t num_extra_i64_at_tail = shape.size() * 2;
+    return Tensor(make_inplace_array_object<details::TensorObjFromNDAlloc<TNDAlloc>, int64_t>(
+ num_extra_i64_at_tail, alloc, shape, dtype, device,
+ std::forward<ExtraArgs>(extra_args)...));
}
/*!
* \brief Create a Tensor from a DLPackTensorAllocator
@@ -393,7 +367,15 @@ class Tensor : public ObjectRef {
      throw ffi::Error(error_context.kind, error_context.message, TVMFFIBacktrace(__FILE__, __LINE__, __func__, 0));
}
-    return Tensor(make_object<details::TensorObjFromDLPack<DLManagedTensorVersioned>>(tensor));
+ if (tensor->dl_tensor.strides != nullptr) {
+      return Tensor(make_object<details::TensorObjFromDLPack<DLManagedTensorVersioned>>(
+ tensor, /*extra_strides_at_tail=*/false));
+ } else {
+ return Tensor(
+          make_inplace_array_object<details::TensorObjFromDLPack<DLManagedTensorVersioned>,
+ int64_t>(tensor->dl_tensor.ndim, tensor,
+ /*extra_strides_at_tail=*/true));
+ }
}
/*!
* \brief Create a Tensor from a DLPack managed tensor, pre v1.0 API.
@@ -412,7 +394,14 @@ class Tensor : public ObjectRef {
if (require_contiguous && !ffi::IsContiguous(tensor->dl_tensor)) {
TVM_FFI_THROW(RuntimeError) << "FromDLPack: Tensor is not contiguous.";
}
-    return Tensor(make_object<details::TensorObjFromDLPack<DLManagedTensor>>(tensor));
+ if (tensor->dl_tensor.strides != nullptr) {
+ return Tensor(make_object<details::TensorObjFromDLPack<DLManagedTensor>>(
+ tensor, /*extra_strides_at_tail=*/false));
+ } else {
+ return Tensor(
+          make_inplace_array_object<details::TensorObjFromDLPack<DLManagedTensor>, int64_t>(
+ tensor->dl_tensor.ndim, tensor, /*extra_strides_at_tail=*/true));
+ }
}
/*!
@@ -434,7 +423,15 @@ class Tensor : public ObjectRef {
if (tensor->flags & DLPACK_FLAG_BITMASK_IS_SUBBYTE_TYPE_PADDED) {
+      TVM_FFI_THROW(RuntimeError) << "Subbyte type padded is not yet supported";
}
-    return Tensor(make_object<details::TensorObjFromDLPack<DLManagedTensorVersioned>>(tensor));
+ if (tensor->dl_tensor.strides != nullptr) {
+ return
Tensor(make_object<details::TensorObjFromDLPack<DLManagedTensorVersioned>>(
+ tensor, /*extra_strides_at_tail=*/false));
+ } else {
+ return Tensor(
+          make_inplace_array_object<details::TensorObjFromDLPack<DLManagedTensorVersioned>,
+ int64_t>(tensor->dl_tensor.ndim, tensor,
+ /*extra_strides_at_tail=*/true));
+ }
}
/*!
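
For reference, the strides produced by details::FillStridesFromShape
above follow a compact row-major (C-contiguous) layout; for example,
shape {2, 3, 4} yields strides {12, 4, 1}. A minimal sketch:

    int64_t shape_data[3] = {2, 3, 4};
    int64_t strides[3];
    tvm::ffi::details::FillStridesFromShape(
        tvm::ffi::ShapeView(shape_data, 3), strides);
    // strides is now {12, 4, 1}; the innermost dimension is contiguous.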
diff --git a/pyproject.toml b/pyproject.toml
index 66323e6..f734184 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -17,7 +17,7 @@
[project]
name = "apache-tvm-ffi"
-version = "0.1.0b9"
+version = "0.1.0b10"
description = "tvm ffi"
authors = [{ name = "TVM FFI team" }]
diff --git a/python/tvm_ffi/__init__.py b/python/tvm_ffi/__init__.py
index 31cda06..807f9a9 100644
--- a/python/tvm_ffi/__init__.py
+++ b/python/tvm_ffi/__init__.py
@@ -17,7 +17,7 @@
"""TVM FFI Python package."""
# version
-__version__ = "0.1.0b9"
+__version__ = "0.1.0b10"
# order matters here so we need to skip isort here
# isort: skip_file
diff --git a/tests/cpp/test_shape.cc b/tests/cpp/test_shape.cc
index 0ccba78..cd67464 100644
--- a/tests/cpp/test_shape.cc
+++ b/tests/cpp/test_shape.cc
@@ -69,4 +69,28 @@ TEST(Shape, AnyConvert) {
EXPECT_EQ(shape2[1], 2);
}
+TEST(Shape, ShapeView) {
+ Shape shape = Shape({1, 2, 3});
+ ShapeView shape_view = shape;
+ EXPECT_EQ(shape_view.size(), 3);
+ EXPECT_EQ(shape_view[0], 1);
+ EXPECT_EQ(shape_view[1], 2);
+ EXPECT_EQ(shape_view[2], 3);
+
+ std::vector<int64_t> data = {4, 5, 6};
+ ShapeView view_from_data(data.data(), data.size());
+ EXPECT_EQ(view_from_data.size(), 3);
+ EXPECT_EQ(view_from_data[0], 4);
+ EXPECT_EQ(view_from_data[1], 5);
+ EXPECT_EQ(view_from_data[2], 6);
+
+ std::initializer_list<int64_t> init = {7, 8, 9};
+ ShapeView view_from_init = init;
+ EXPECT_EQ(view_from_init.size(), 3);
+ EXPECT_EQ(view_from_init[0], 7);
+ EXPECT_EQ(view_from_init[1], 8);
+ EXPECT_EQ(view_from_init[2], 9);
+ EXPECT_EQ(view_from_init.Product(), 7 * 8 * 9);
+}
+
} // namespace
diff --git a/tests/cpp/test_tensor.cc b/tests/cpp/test_tensor.cc
index 7c696a3..bb9158a 100644
--- a/tests/cpp/test_tensor.cc
+++ b/tests/cpp/test_tensor.cc
@@ -50,7 +50,7 @@ int TestDLPackTensorAllocatorError(DLTensor* prototype, DLManagedTensorVersioned
}
TEST(Tensor, Basic) {
-  Tensor nd = Empty(Shape({1, 2, 3}), DLDataType({kDLFloat, 32, 1}), DLDevice({kDLCPU, 0}));
+  Tensor nd = Empty({1, 2, 3}, DLDataType({kDLFloat, 32, 1}), DLDevice({kDLCPU, 0}));
Shape shape = nd.shape();
Shape strides = nd.strides();
EXPECT_EQ(shape.size(), 3);
@@ -66,10 +66,12 @@ TEST(Tensor, Basic) {
reinterpret_cast<float*>(nd->data)[i] = static_cast<float>(i);
}
+ EXPECT_EQ(nd.numel(), 6);
+ EXPECT_EQ(nd.ndim(), 3);
+ EXPECT_EQ(nd.data_ptr(), nd->data);
+
Any any0 = nd;
Tensor nd2 = any0.as<Tensor>().value();
- EXPECT_EQ(nd2.shape(), shape);
- EXPECT_EQ(nd2.strides(), strides);
EXPECT_EQ(nd2.dtype(), DLDataType({kDLFloat, 32, 1}));
for (int64_t i = 0; i < shape.Product(); ++i) {
EXPECT_EQ(reinterpret_cast<float*>(nd2->data)[i], i);