This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 5c1707d277 [FFI] Construct NDArray.strides by default (#18272)
5c1707d277 is described below
commit 5c1707d2779fa22070689824e826bb1a16a0841d
Author: Yaxing Cai <[email protected]>
AuthorDate: Sat Sep 6 04:39:33 2025 -0700
[FFI] Construct NDArray.strides by default (#18272)
This PR updates NDArray to construct the strides field by default: strides are now populated from the shape in row-major order instead of being left as nullptr for contiguous tensors, and contiguity checks are updated to use ffi::IsContiguous rather than testing strides == nullptr.
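For context, the default stride construction follows the usual row-major (C-contiguous) rule. A minimal standalone sketch of that rule, mirroring the new details::MakeStridesFromShape helper (illustrative only, not the FFI code itself):

    #include <cstdint>
    #include <vector>

    // Compute C-contiguous (row-major) strides, in elements, from a shape:
    // the innermost dimension has stride 1, and each outer dimension's
    // stride is the product of all inner extents.
    std::vector<int64_t> RowMajorStrides(const std::vector<int64_t>& shape) {
      std::vector<int64_t> strides(shape.size());
      int64_t stride = 1;
      for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
        strides[i] = stride;
        stride *= shape[i];
      }
      return strides;
    }

    // e.g. RowMajorStrides({2, 2, 3}) == {6, 3, 1}, matching the updated
    // expectations in ffi/tests/cpp/test_ndarray.cc.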
---
ffi/include/tvm/ffi/container/ndarray.h | 12 ++++++++----
ffi/include/tvm/ffi/container/shape.h | 11 +++++++++++
ffi/tests/cpp/test_ndarray.cc | 6 ++++--
include/tvm/runtime/ndarray.h | 2 +-
src/relax/transform/fold_constant.cc | 2 +-
src/runtime/contrib/coreml/coreml_runtime.mm | 2 +-
src/runtime/contrib/dnnl/dnnl_json_runtime.cc | 2 +-
src/runtime/contrib/mps/conv.mm | 6 +++---
src/runtime/contrib/mps/gemm.mm | 6 +++---
src/runtime/contrib/random/mt_random_engine.cc | 4 ++--
src/runtime/contrib/random/random.cc | 2 +-
src/runtime/contrib/rocblas/rocblas.cc | 6 +++---
src/runtime/contrib/tflite/tflite_runtime.cc | 2 +-
src/runtime/minrpc/rpc_reference.h | 4 +++-
src/runtime/vm/rnn_state.cc | 4 +++-
15 files changed, 46 insertions(+), 25 deletions(-)
diff --git a/ffi/include/tvm/ffi/container/ndarray.h b/ffi/include/tvm/ffi/container/ndarray.h
index 6acdbc3a26..f65e386c06 100644
--- a/ffi/include/tvm/ffi/container/ndarray.h
+++ b/ffi/include/tvm/ffi/container/ndarray.h
@@ -151,6 +151,7 @@ class NDArrayObj : public Object, public DLTensor {
protected:
// backs up the shape of the NDArray
Optional<Shape> shape_data_;
+ Optional<Shape> stride_data_;
static void DLManagedTensorDeleter(DLManagedTensor* tensor) {
NDArrayObj* obj = static_cast<NDArrayObj*>(tensor->manager_ctx);
@@ -184,9 +185,11 @@ class NDArrayObjFromNDAlloc : public NDArrayObj {
this->ndim = static_cast<int>(shape.size());
this->dtype = dtype;
this->shape = const_cast<int64_t*>(shape.data());
- this->strides = nullptr;
+ Shape strides = Shape(details::MakeStridesFromShape(this->ndim, this->shape));
+ this->strides = const_cast<int64_t*>(strides.data());
this->byte_offset = 0;
this->shape_data_ = std::move(shape);
+ this->stride_data_ = std::move(strides);
alloc_.AllocData(static_cast<DLTensor*>(this), std::forward<ExtraArgs>(extra_args)...);
}
@@ -202,9 +205,10 @@ class NDArrayObjFromDLPack : public NDArrayObj {
public:
explicit NDArrayObjFromDLPack(TDLPackManagedTensor* tensor) : tensor_(tensor) {
*static_cast<DLTensor*>(this) = tensor_->dl_tensor;
- // set strides to nullptr if the tensor is contiguous.
- if (IsContiguous(tensor->dl_tensor)) {
- this->strides = nullptr;
+ if (tensor_->dl_tensor.strides == nullptr) {
+ Shape strides = Shape(details::MakeStridesFromShape(ndim, shape));
+ this->strides = const_cast<int64_t*>(strides.data());
+ this->stride_data_ = std::move(strides);
}
}
diff --git a/ffi/include/tvm/ffi/container/shape.h b/ffi/include/tvm/ffi/container/shape.h
index 2fccc028a5..6360fcd1e3 100644
--- a/ffi/include/tvm/ffi/container/shape.h
+++ b/ffi/include/tvm/ffi/container/shape.h
@@ -91,6 +91,17 @@ TVM_FFI_INLINE ObjectPtr<ShapeObj> MakeInplaceShape(IterType begin, IterType end
return p;
}
+TVM_FFI_INLINE ObjectPtr<ShapeObj> MakeStridesFromShape(int64_t ndim, int64_t* shape) {
+ int64_t* strides_data;
+ ObjectPtr<ShapeObj> strides = details::MakeEmptyShape(ndim, &strides_data);
+ int64_t stride = 1;
+ for (int i = ndim - 1; i >= 0; --i) {
+ strides_data[i] = stride;
+ stride *= shape[i];
+ }
+ return strides;
+}
+
} // namespace details
/*!
diff --git a/ffi/tests/cpp/test_ndarray.cc b/ffi/tests/cpp/test_ndarray.cc
index 3d7b00cd33..0196bfc4fb 100644
--- a/ffi/tests/cpp/test_ndarray.cc
+++ b/ffi/tests/cpp/test_ndarray.cc
@@ -69,7 +69,9 @@ TEST(NDArray, DLPack) {
EXPECT_EQ(dlpack->dl_tensor.device.device_type, kDLCPU);
EXPECT_EQ(dlpack->dl_tensor.device.device_id, 0);
EXPECT_EQ(dlpack->dl_tensor.byte_offset, 0);
- EXPECT_EQ(dlpack->dl_tensor.strides, nullptr);
+ EXPECT_EQ(dlpack->dl_tensor.strides[0], 6);
+ EXPECT_EQ(dlpack->dl_tensor.strides[1], 3);
+ EXPECT_EQ(dlpack->dl_tensor.strides[2], 1);
EXPECT_EQ(nd.use_count(), 2);
{
NDArray nd2 = NDArray::FromDLPack(dlpack);
@@ -96,7 +98,7 @@ TEST(NDArray, DLPackVersioned) {
EXPECT_EQ(dlpack->dl_tensor.device.device_type, kDLCPU);
EXPECT_EQ(dlpack->dl_tensor.device.device_id, 0);
EXPECT_EQ(dlpack->dl_tensor.byte_offset, 0);
- EXPECT_EQ(dlpack->dl_tensor.strides, nullptr);
+ EXPECT_EQ(dlpack->dl_tensor.strides[0], 1);
EXPECT_EQ(nd.use_count(), 2);
{
diff --git a/include/tvm/runtime/ndarray.h b/include/tvm/runtime/ndarray.h
index 6eebe49ff1..9a295e491e 100644
--- a/include/tvm/runtime/ndarray.h
+++ b/include/tvm/runtime/ndarray.h
@@ -239,7 +239,7 @@ inline bool SaveDLTensor(dmlc::Stream* strm, const DLTensor* tensor) {
strm->Write(data_byte_size);
if (DMLC_IO_NO_ENDIAN_SWAP && tensor->device.device_type == kDLCPU &&
- tensor->strides == nullptr && tensor->byte_offset == 0) {
+ ffi::IsContiguous(*tensor) && tensor->byte_offset == 0) {
// quick path
strm->Write(tensor->data, data_byte_size);
} else {
diff --git a/src/relax/transform/fold_constant.cc b/src/relax/transform/fold_constant.cc
index c1aee73cc2..33e077d726 100644
--- a/src/relax/transform/fold_constant.cc
+++ b/src/relax/transform/fold_constant.cc
@@ -270,7 +270,7 @@ class ConstantFolder : public ExprMutator {
Constant constant = Downcast<Constant>(arg);
runtime::NDArray ndarray = constant->data;
ICHECK_EQ(ndarray->device.device_type, kDLCPU);
- ICHECK(ndarray->strides == nullptr);
+ ICHECK(ffi::IsContiguous(*ndarray.get()));
ICHECK_EQ(ndarray->byte_offset, 0);
ICHECK_EQ(ndarray->ndim, 1);
const int64_t* data = static_cast<const int64_t*>(ndarray->data);
diff --git a/src/runtime/contrib/coreml/coreml_runtime.mm b/src/runtime/contrib/coreml/coreml_runtime.mm
index 8e0b2542b4..fb5faa8621 100644
--- a/src/runtime/contrib/coreml/coreml_runtime.mm
+++ b/src/runtime/contrib/coreml/coreml_runtime.mm
@@ -60,7 +60,7 @@ void CoreMLModel::SetInput(const std::string& key, DLTensor* data_in) {
MLMultiArray* dest = [[MLMultiArray alloc] initWithShape:shape dataType:dataType error:nil];
- ICHECK(data_in->strides == NULL);
+ ICHECK(ffi::IsContiguous(*data_in));
memcpy(dest.dataPointer, data_in->data, size);
NSString* nsKey = [NSString stringWithUTF8String:key.c_str()];
diff --git a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
index 686a8048c7..59b162e765 100644
--- a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
+++ b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
@@ -821,7 +821,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
TensorRequisite res;
if (const_dl_tensor) {
ICHECK(const_dl_tensor->data);
- ICHECK(const_dl_tensor->strides == nullptr);
+ ICHECK(ffi::IsContiguous(*const_dl_tensor));
auto mem = dnnl::memory(desc, engine_, const_dl_tensor->data);
res = TensorRequisite::AsIs(mem, eid);
} else {
diff --git a/src/runtime/contrib/mps/conv.mm b/src/runtime/contrib/mps/conv.mm
index dfc98388d3..2bf38796fd 100644
--- a/src/runtime/contrib/mps/conv.mm
+++ b/src/runtime/contrib/mps/conv.mm
@@ -91,9 +91,9 @@ TVM_FFI_STATIC_INIT_BLOCK({
ICHECK_EQ(data->ndim, 4);
ICHECK_EQ(weight->ndim, 4);
ICHECK_EQ(output->ndim, 4);
- ICHECK(output->strides == nullptr);
- ICHECK(weight->strides == nullptr);
- ICHECK(data->strides == nullptr);
+ ICHECK(ffi::IsContiguous(*output));
+ ICHECK(ffi::IsContiguous(*weight));
+ ICHECK(ffi::IsContiguous(*data));
ICHECK_EQ(data->shape[0], 1);
ICHECK_EQ(output->shape[0], 1);
diff --git a/src/runtime/contrib/mps/gemm.mm b/src/runtime/contrib/mps/gemm.mm
index 9f5270f38f..7f386172f6 100644
--- a/src/runtime/contrib/mps/gemm.mm
+++ b/src/runtime/contrib/mps/gemm.mm
@@ -37,9 +37,9 @@ TVM_FFI_STATIC_INIT_BLOCK({
ICHECK_EQ(A->ndim, 2);
ICHECK_EQ(B->ndim, 2);
ICHECK_EQ(C->ndim, 2);
- ICHECK(C->strides == nullptr);
- ICHECK(B->strides == nullptr);
- ICHECK(A->strides == nullptr);
+ ICHECK(ffi::IsContiguous(*C));
+ ICHECK(ffi::IsContiguous(*B));
+ ICHECK(ffi::IsContiguous(*A));
ICHECK(TypeMatch(A->dtype, kDLFloat, 32));
ICHECK(TypeMatch(B->dtype, kDLFloat, 32));
ICHECK(TypeMatch(C->dtype, kDLFloat, 32));
diff --git a/src/runtime/contrib/random/mt_random_engine.cc b/src/runtime/contrib/random/mt_random_engine.cc
index 04b53d74b4..3ab0309630 100644
--- a/src/runtime/contrib/random/mt_random_engine.cc
+++ b/src/runtime/contrib/random/mt_random_engine.cc
@@ -75,7 +75,7 @@ class RandomEngine {
*/
void SampleUniform(DLTensor* data, float low, float high) {
ICHECK_GT(high, low) << "high must be bigger than low";
- ICHECK(data->strides == nullptr);
+ ICHECK(ffi::IsContiguous(*data));
DLDataType dtype = data->dtype;
int64_t size = 1;
@@ -99,7 +99,7 @@ class RandomEngine {
*/
void SampleNormal(DLTensor* data, float loc, float scale) {
ICHECK_GT(scale, 0) << "standard deviation must be positive";
- ICHECK(data->strides == nullptr);
+ ICHECK(ffi::IsContiguous(*data));
DLDataType dtype = data->dtype;
int64_t size = 1;
diff --git a/src/runtime/contrib/random/random.cc b/src/runtime/contrib/random/random.cc
index 580ed1073a..b7ca1f8fd7 100644
--- a/src/runtime/contrib/random/random.cc
+++ b/src/runtime/contrib/random/random.cc
@@ -80,7 +80,7 @@ TVM_FFI_STATIC_INIT_BLOCK({
int64_t high = args[1].cast<int64_t>();
auto out = args[2].cast<DLTensor*>();
ICHECK_GT(high, low) << "high must be bigger than low";
- ICHECK(out->strides == nullptr);
+ ICHECK(ffi::IsContiguous(*out));
DLDataType dtype = out->dtype;
int64_t size = 1;
diff --git a/src/runtime/contrib/rocblas/rocblas.cc b/src/runtime/contrib/rocblas/rocblas.cc
index 8fdce7e43b..be3c49e121 100644
--- a/src/runtime/contrib/rocblas/rocblas.cc
+++ b/src/runtime/contrib/rocblas/rocblas.cc
@@ -81,9 +81,9 @@ TVM_FFI_STATIC_INIT_BLOCK({
ICHECK_EQ(A->ndim, 2);
ICHECK_EQ(B->ndim, 2);
ICHECK_EQ(C->ndim, 2);
- ICHECK(C->strides == nullptr);
- ICHECK(B->strides == nullptr);
- ICHECK(A->strides == nullptr);
+ ICHECK(ffi::IsContiguous(*C));
+ ICHECK(ffi::IsContiguous(*B));
+ ICHECK(ffi::IsContiguous(*A));
ICHECK(TypeMatch(A->dtype, kDLFloat, 32));
ICHECK(TypeMatch(B->dtype, kDLFloat, 32));
ICHECK(TypeMatch(C->dtype, kDLFloat, 32));
diff --git a/src/runtime/contrib/tflite/tflite_runtime.cc b/src/runtime/contrib/tflite/tflite_runtime.cc
index c35af35eae..d65f2ad65b 100644
--- a/src/runtime/contrib/tflite/tflite_runtime.cc
+++ b/src/runtime/contrib/tflite/tflite_runtime.cc
@@ -118,7 +118,7 @@ void TFLiteRuntime::SetInput(int index, DLTensor* data_in) {
TVM_DTYPE_DISPATCH(dtype, DType, {
DType* dest = interpreter_->typed_input_tensor<DType>(index);
DType* src = static_cast<DType*>(data_in->data);
- ICHECK(data_in->strides == NULL);
+ ICHECK(ffi::IsContiguous(*data_in));
int64_t size = 1;
for (int64_t i = 0; i < data_in->ndim; ++i) {
size *= data_in->shape[i];
diff --git a/src/runtime/minrpc/rpc_reference.h b/src/runtime/minrpc/rpc_reference.h
index b5f1e6995f..dfca27c8c3 100644
--- a/src/runtime/minrpc/rpc_reference.h
+++ b/src/runtime/minrpc/rpc_reference.h
@@ -24,6 +24,8 @@
#ifndef TVM_RUNTIME_MINRPC_RPC_REFERENCE_H_
#define TVM_RUNTIME_MINRPC_RPC_REFERENCE_H_
+#include <tvm/ffi/container/ndarray.h>
+
namespace tvm {
namespace ffi {
// Forward declare TVM Object to use `Object*` in RPC protocol.
@@ -255,7 +257,7 @@ struct RPCReference {
channel->Write(arr->ndim);
channel->Write(arr->dtype);
channel->WriteArray(arr->shape, arr->ndim);
- if (arr->strides != nullptr) {
+ if (!ffi::IsContiguous(*arr)) {
channel->ThrowError(RPCServerStatus::kInvalidDLTensorFieldStride);
}
channel->Write(arr->byte_offset);
diff --git a/src/runtime/vm/rnn_state.cc b/src/runtime/vm/rnn_state.cc
index 8963df0652..085860348e 100644
--- a/src/runtime/vm/rnn_state.cc
+++ b/src/runtime/vm/rnn_state.cc
@@ -396,6 +396,7 @@ class RNNStateImpObj : public RNNStateObj {
_state.byte_offset = elem_offset * state->dtype.bits / 8;
_state.ndim = state->ndim - 2;
_state.shape = const_cast<int64_t*>(_state.shape + 2);
+ _state.strides = const_cast<int64_t*>(_state.strides + 2);
return _state;
}
@@ -411,6 +412,7 @@ class RNNStateImpObj : public RNNStateObj {
_state.byte_offset = elem_offset * state->dtype.bits / 8;
_state.ndim = state->ndim - 1;
_state.shape = const_cast<int64_t*>(_state.shape + 1);
+ _state.strides = const_cast<int64_t*>(_state.strides + 1);
return _state;
}
@@ -428,7 +430,7 @@ class RNNStateImpObj : public RNNStateObj {
copy_src.ndim = 1;
copy_src.dtype = array->dtype;
copy_src.shape = array->shape;
- copy_src.strides = nullptr;
+ copy_src.strides = array->strides;
copy_src.byte_offset = 0;
NDArray::CopyFromTo(&copy_src, &copy_dst);
};