This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 5c1707d277 [FFI] Construct NDArray.strides by default (#18272)
5c1707d277 is described below

commit 5c1707d2779fa22070689824e826bb1a16a0841d
Author: Yaxing Cai <[email protected]>
AuthorDate: Sat Sep 6 04:39:33 2025 -0700

    [FFI] Construct NDArray.strides by default (#18272)
    
    This PR updates NDArray to construct and populate strides by default,
    rather than leaving them as nullptr for contiguous tensors. Call sites
    that previously tested strides == nullptr now use ffi::IsContiguous.
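    
    For context, the strides constructed by default are the compact
    row-major ones: the innermost dimension has stride 1 and each outer
    stride is the product of the inner extents, which is what
    details::MakeStridesFromShape in this patch computes. A minimal
    standalone sketch of the same computation (editorial; the helper name
    is illustrative and not part of the patch):
    
        #include <cstdint>
        #include <iostream>
        #include <vector>
    
        // Compact row-major strides: strides[i] = shape[i+1] * ... * shape[ndim-1].
        std::vector<int64_t> MakeRowMajorStrides(const std::vector<int64_t>& shape) {
          std::vector<int64_t> strides(shape.size());
          int64_t stride = 1;
          for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
            strides[i] = stride;
            stride *= shape[i];
          }
          return strides;
        }
    
        int main() {
          // Trailing extents (2, 3) give strides (6, 3, 1) for a 3-d tensor,
          // the values asserted in the updated test_ndarray.cc below.
          for (int64_t s : MakeRowMajorStrides({1, 2, 3})) std::cout << s << ' ';
          std::cout << '\n';
        }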
---
 ffi/include/tvm/ffi/container/ndarray.h        | 12 ++++++++----
 ffi/include/tvm/ffi/container/shape.h          | 11 +++++++++++
 ffi/tests/cpp/test_ndarray.cc                  |  6 ++++--
 include/tvm/runtime/ndarray.h                  |  2 +-
 src/relax/transform/fold_constant.cc           |  2 +-
 src/runtime/contrib/coreml/coreml_runtime.mm   |  2 +-
 src/runtime/contrib/dnnl/dnnl_json_runtime.cc  |  2 +-
 src/runtime/contrib/mps/conv.mm                |  6 +++---
 src/runtime/contrib/mps/gemm.mm                |  6 +++---
 src/runtime/contrib/random/mt_random_engine.cc |  4 ++--
 src/runtime/contrib/random/random.cc           |  2 +-
 src/runtime/contrib/rocblas/rocblas.cc         |  6 +++---
 src/runtime/contrib/tflite/tflite_runtime.cc   |  2 +-
 src/runtime/minrpc/rpc_reference.h             |  4 +++-
 src/runtime/vm/rnn_state.cc                    |  4 +++-
 15 files changed, 46 insertions(+), 25 deletions(-)

diff --git a/ffi/include/tvm/ffi/container/ndarray.h b/ffi/include/tvm/ffi/container/ndarray.h
index 6acdbc3a26..f65e386c06 100644
--- a/ffi/include/tvm/ffi/container/ndarray.h
+++ b/ffi/include/tvm/ffi/container/ndarray.h
@@ -151,6 +151,7 @@ class NDArrayObj : public Object, public DLTensor {
  protected:
   // backs up the shape of the NDArray
   Optional<Shape> shape_data_;
+  Optional<Shape> stride_data_;
 
   static void DLManagedTensorDeleter(DLManagedTensor* tensor) {
     NDArrayObj* obj = static_cast<NDArrayObj*>(tensor->manager_ctx);
@@ -184,9 +185,11 @@ class NDArrayObjFromNDAlloc : public NDArrayObj {
     this->ndim = static_cast<int>(shape.size());
     this->dtype = dtype;
     this->shape = const_cast<int64_t*>(shape.data());
-    this->strides = nullptr;
+    Shape strides = Shape(details::MakeStridesFromShape(this->ndim, this->shape));
+    this->strides = const_cast<int64_t*>(strides.data());
     this->byte_offset = 0;
     this->shape_data_ = std::move(shape);
+    this->stride_data_ = std::move(strides);
    alloc_.AllocData(static_cast<DLTensor*>(this), std::forward<ExtraArgs>(extra_args)...);
   }
 
@@ -202,9 +205,10 @@ class NDArrayObjFromDLPack : public NDArrayObj {
  public:
   explicit NDArrayObjFromDLPack(TDLPackManagedTensor* tensor) : tensor_(tensor) {
     *static_cast<DLTensor*>(this) = tensor_->dl_tensor;
-    // set strides to nullptr if the tensor is contiguous.
-    if (IsContiguous(tensor->dl_tensor)) {
-      this->strides = nullptr;
+    if (tensor_->dl_tensor.strides == nullptr) {
+      Shape strides = Shape(details::MakeStridesFromShape(ndim, shape));
+      this->strides = const_cast<int64_t*>(strides.data());
+      this->stride_data_ = std::move(strides);
     }
   }
 
diff --git a/ffi/include/tvm/ffi/container/shape.h b/ffi/include/tvm/ffi/container/shape.h
index 2fccc028a5..6360fcd1e3 100644
--- a/ffi/include/tvm/ffi/container/shape.h
+++ b/ffi/include/tvm/ffi/container/shape.h
@@ -91,6 +91,17 @@ TVM_FFI_INLINE ObjectPtr<ShapeObj> MakeInplaceShape(IterType begin, IterType end
   return p;
 }
 
+TVM_FFI_INLINE ObjectPtr<ShapeObj> MakeStridesFromShape(int64_t ndim, int64_t* shape) {
+  int64_t* strides_data;
+  ObjectPtr<ShapeObj> strides = details::MakeEmptyShape(ndim, &strides_data);
+  int64_t stride = 1;
+  for (int i = ndim - 1; i >= 0; --i) {
+    strides_data[i] = stride;
+    stride *= shape[i];
+  }
+  return strides;
+}
+
 }  // namespace details
 
 /*!
diff --git a/ffi/tests/cpp/test_ndarray.cc b/ffi/tests/cpp/test_ndarray.cc
index 3d7b00cd33..0196bfc4fb 100644
--- a/ffi/tests/cpp/test_ndarray.cc
+++ b/ffi/tests/cpp/test_ndarray.cc
@@ -69,7 +69,9 @@ TEST(NDArray, DLPack) {
   EXPECT_EQ(dlpack->dl_tensor.device.device_type, kDLCPU);
   EXPECT_EQ(dlpack->dl_tensor.device.device_id, 0);
   EXPECT_EQ(dlpack->dl_tensor.byte_offset, 0);
-  EXPECT_EQ(dlpack->dl_tensor.strides, nullptr);
+  EXPECT_EQ(dlpack->dl_tensor.strides[0], 6);
+  EXPECT_EQ(dlpack->dl_tensor.strides[1], 3);
+  EXPECT_EQ(dlpack->dl_tensor.strides[2], 1);
   EXPECT_EQ(nd.use_count(), 2);
   {
     NDArray nd2 = NDArray::FromDLPack(dlpack);
@@ -96,7 +98,7 @@ TEST(NDArray, DLPackVersioned) {
   EXPECT_EQ(dlpack->dl_tensor.device.device_type, kDLCPU);
   EXPECT_EQ(dlpack->dl_tensor.device.device_id, 0);
   EXPECT_EQ(dlpack->dl_tensor.byte_offset, 0);
-  EXPECT_EQ(dlpack->dl_tensor.strides, nullptr);
+  EXPECT_EQ(dlpack->dl_tensor.strides[0], 1);
 
   EXPECT_EQ(nd.use_count(), 2);
   {
diff --git a/include/tvm/runtime/ndarray.h b/include/tvm/runtime/ndarray.h
index 6eebe49ff1..9a295e491e 100644
--- a/include/tvm/runtime/ndarray.h
+++ b/include/tvm/runtime/ndarray.h
@@ -239,7 +239,7 @@ inline bool SaveDLTensor(dmlc::Stream* strm, const DLTensor* tensor) {
   strm->Write(data_byte_size);
 
   if (DMLC_IO_NO_ENDIAN_SWAP && tensor->device.device_type == kDLCPU &&
-      tensor->strides == nullptr && tensor->byte_offset == 0) {
+      ffi::IsContiguous(*tensor) && tensor->byte_offset == 0) {
     // quick path
     strm->Write(tensor->data, data_byte_size);
   } else {
diff --git a/src/relax/transform/fold_constant.cc b/src/relax/transform/fold_constant.cc
index c1aee73cc2..33e077d726 100644
--- a/src/relax/transform/fold_constant.cc
+++ b/src/relax/transform/fold_constant.cc
@@ -270,7 +270,7 @@ class ConstantFolder : public ExprMutator {
           Constant constant = Downcast<Constant>(arg);
           runtime::NDArray ndarray = constant->data;
           ICHECK_EQ(ndarray->device.device_type, kDLCPU);
-          ICHECK(ndarray->strides == nullptr);
+          ICHECK(ffi::IsContiguous(*ndarray.get()));
           ICHECK_EQ(ndarray->byte_offset, 0);
           ICHECK_EQ(ndarray->ndim, 1);
           const int64_t* data = static_cast<const int64_t*>(ndarray->data);
diff --git a/src/runtime/contrib/coreml/coreml_runtime.mm b/src/runtime/contrib/coreml/coreml_runtime.mm
index 8e0b2542b4..fb5faa8621 100644
--- a/src/runtime/contrib/coreml/coreml_runtime.mm
+++ b/src/runtime/contrib/coreml/coreml_runtime.mm
@@ -60,7 +60,7 @@ void CoreMLModel::SetInput(const std::string& key, DLTensor* data_in) {
 
   MLMultiArray* dest = [[MLMultiArray alloc] initWithShape:shape dataType:dataType error:nil];
 
-  ICHECK(data_in->strides == NULL);
+  ICHECK(ffi::IsContiguous(*data_in));
   memcpy(dest.dataPointer, data_in->data, size);
 
   NSString* nsKey = [NSString stringWithUTF8String:key.c_str()];
diff --git a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
index 686a8048c7..59b162e765 100644
--- a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
+++ b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
@@ -821,7 +821,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
     TensorRequisite res;
     if (const_dl_tensor) {
       ICHECK(const_dl_tensor->data);
-      ICHECK(const_dl_tensor->strides == nullptr);
+      ICHECK(ffi::IsContiguous(*const_dl_tensor));
       auto mem = dnnl::memory(desc, engine_, const_dl_tensor->data);
       res = TensorRequisite::AsIs(mem, eid);
     } else {
diff --git a/src/runtime/contrib/mps/conv.mm b/src/runtime/contrib/mps/conv.mm
index dfc98388d3..2bf38796fd 100644
--- a/src/runtime/contrib/mps/conv.mm
+++ b/src/runtime/contrib/mps/conv.mm
@@ -91,9 +91,9 @@ TVM_FFI_STATIC_INIT_BLOCK({
         ICHECK_EQ(data->ndim, 4);
         ICHECK_EQ(weight->ndim, 4);
         ICHECK_EQ(output->ndim, 4);
-        ICHECK(output->strides == nullptr);
-        ICHECK(weight->strides == nullptr);
-        ICHECK(data->strides == nullptr);
+        ICHECK(ffi::IsContiguous(*output));
+        ICHECK(ffi::IsContiguous(*weight));
+        ICHECK(ffi::IsContiguous(*data));
 
         ICHECK_EQ(data->shape[0], 1);
         ICHECK_EQ(output->shape[0], 1);
diff --git a/src/runtime/contrib/mps/gemm.mm b/src/runtime/contrib/mps/gemm.mm
index 9f5270f38f..7f386172f6 100644
--- a/src/runtime/contrib/mps/gemm.mm
+++ b/src/runtime/contrib/mps/gemm.mm
@@ -37,9 +37,9 @@ TVM_FFI_STATIC_INIT_BLOCK({
     ICHECK_EQ(A->ndim, 2);
     ICHECK_EQ(B->ndim, 2);
     ICHECK_EQ(C->ndim, 2);
-    ICHECK(C->strides == nullptr);
-    ICHECK(B->strides == nullptr);
-    ICHECK(A->strides == nullptr);
+    ICHECK(ffi::IsContiguous(*C));
+    ICHECK(ffi::IsContiguous(*B));
+    ICHECK(ffi::IsContiguous(*A));
     ICHECK(TypeMatch(A->dtype, kDLFloat, 32));
     ICHECK(TypeMatch(B->dtype, kDLFloat, 32));
     ICHECK(TypeMatch(C->dtype, kDLFloat, 32));
diff --git a/src/runtime/contrib/random/mt_random_engine.cc b/src/runtime/contrib/random/mt_random_engine.cc
index 04b53d74b4..3ab0309630 100644
--- a/src/runtime/contrib/random/mt_random_engine.cc
+++ b/src/runtime/contrib/random/mt_random_engine.cc
@@ -75,7 +75,7 @@ class RandomEngine {
    */
   void SampleUniform(DLTensor* data, float low, float high) {
     ICHECK_GT(high, low) << "high must be bigger than low";
-    ICHECK(data->strides == nullptr);
+    ICHECK(ffi::IsContiguous(*data));
 
     DLDataType dtype = data->dtype;
     int64_t size = 1;
@@ -99,7 +99,7 @@ class RandomEngine {
    */
   void SampleNormal(DLTensor* data, float loc, float scale) {
     ICHECK_GT(scale, 0) << "standard deviation must be positive";
-    ICHECK(data->strides == nullptr);
+    ICHECK(ffi::IsContiguous(*data));
 
     DLDataType dtype = data->dtype;
     int64_t size = 1;
diff --git a/src/runtime/contrib/random/random.cc b/src/runtime/contrib/random/random.cc
index 580ed1073a..b7ca1f8fd7 100644
--- a/src/runtime/contrib/random/random.cc
+++ b/src/runtime/contrib/random/random.cc
@@ -80,7 +80,7 @@ TVM_FFI_STATIC_INIT_BLOCK({
                     int64_t high = args[1].cast<int64_t>();
                     auto out = args[2].cast<DLTensor*>();
                     ICHECK_GT(high, low) << "high must be bigger than low";
-                    ICHECK(out->strides == nullptr);
+                    ICHECK(ffi::IsContiguous(*out));
 
                     DLDataType dtype = out->dtype;
                     int64_t size = 1;
diff --git a/src/runtime/contrib/rocblas/rocblas.cc b/src/runtime/contrib/rocblas/rocblas.cc
index 8fdce7e43b..be3c49e121 100644
--- a/src/runtime/contrib/rocblas/rocblas.cc
+++ b/src/runtime/contrib/rocblas/rocblas.cc
@@ -81,9 +81,9 @@ TVM_FFI_STATIC_INIT_BLOCK({
             ICHECK_EQ(A->ndim, 2);
             ICHECK_EQ(B->ndim, 2);
             ICHECK_EQ(C->ndim, 2);
-            ICHECK(C->strides == nullptr);
-            ICHECK(B->strides == nullptr);
-            ICHECK(A->strides == nullptr);
+            ICHECK(ffi::IsContiguous(*C));
+            ICHECK(ffi::IsContiguous(*B));
+            ICHECK(ffi::IsContiguous(*A));
             ICHECK(TypeMatch(A->dtype, kDLFloat, 32));
             ICHECK(TypeMatch(B->dtype, kDLFloat, 32));
             ICHECK(TypeMatch(C->dtype, kDLFloat, 32));
diff --git a/src/runtime/contrib/tflite/tflite_runtime.cc b/src/runtime/contrib/tflite/tflite_runtime.cc
index c35af35eae..d65f2ad65b 100644
--- a/src/runtime/contrib/tflite/tflite_runtime.cc
+++ b/src/runtime/contrib/tflite/tflite_runtime.cc
@@ -118,7 +118,7 @@ void TFLiteRuntime::SetInput(int index, DLTensor* data_in) {
   TVM_DTYPE_DISPATCH(dtype, DType, {
     DType* dest = interpreter_->typed_input_tensor<DType>(index);
     DType* src = static_cast<DType*>(data_in->data);
-    ICHECK(data_in->strides == NULL);
+    ICHECK(ffi::IsContiguous(*data_in));
     int64_t size = 1;
     for (int64_t i = 0; i < data_in->ndim; ++i) {
       size *= data_in->shape[i];
diff --git a/src/runtime/minrpc/rpc_reference.h b/src/runtime/minrpc/rpc_reference.h
index b5f1e6995f..dfca27c8c3 100644
--- a/src/runtime/minrpc/rpc_reference.h
+++ b/src/runtime/minrpc/rpc_reference.h
@@ -24,6 +24,8 @@
 #ifndef TVM_RUNTIME_MINRPC_RPC_REFERENCE_H_
 #define TVM_RUNTIME_MINRPC_RPC_REFERENCE_H_
 
+#include <tvm/ffi/container/ndarray.h>
+
 namespace tvm {
 namespace ffi {
 // Forward declare TVM Object to use `Object*` in RPC protocol.
@@ -255,7 +257,7 @@ struct RPCReference {
     channel->Write(arr->ndim);
     channel->Write(arr->dtype);
     channel->WriteArray(arr->shape, arr->ndim);
-    if (arr->strides != nullptr) {
+    if (!ffi::IsContiguous(*arr)) {
       channel->ThrowError(RPCServerStatus::kInvalidDLTensorFieldStride);
     }
     channel->Write(arr->byte_offset);
diff --git a/src/runtime/vm/rnn_state.cc b/src/runtime/vm/rnn_state.cc
index 8963df0652..085860348e 100644
--- a/src/runtime/vm/rnn_state.cc
+++ b/src/runtime/vm/rnn_state.cc
@@ -396,6 +396,7 @@ class RNNStateImpObj : public RNNStateObj {
     _state.byte_offset = elem_offset * state->dtype.bits / 8;
     _state.ndim = state->ndim - 2;
     _state.shape = const_cast<int64_t*>(_state.shape + 2);
+    _state.strides = const_cast<int64_t*>(_state.strides + 2);
     return _state;
   }
 
@@ -411,6 +412,7 @@ class RNNStateImpObj : public RNNStateObj {
     _state.byte_offset = elem_offset * state->dtype.bits / 8;
     _state.ndim = state->ndim - 1;
     _state.shape = const_cast<int64_t*>(_state.shape + 1);
+    _state.strides = const_cast<int64_t*>(_state.strides + 1);
     return _state;
   }
 
@@ -428,7 +430,7 @@ class RNNStateImpObj : public RNNStateObj {
       copy_src.ndim = 1;
       copy_src.dtype = array->dtype;
       copy_src.shape = array->shape;
-      copy_src.strides = nullptr;
+      copy_src.strides = array->strides;
       copy_src.byte_offset = 0;
       NDArray::CopyFromTo(&copy_src, &copy_dst);
     };
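
Editorial note: because strides are now materialized even for compact
tensors, a strides == nullptr test no longer identifies contiguous
layouts, which is why the call sites above switch to ffi::IsContiguous.
The definition of ffi::IsContiguous is not part of this patch; below is a
minimal sketch of what such a check computes (the name and the
null-strides handling are assumptions following DLPack conventions):

    #include <dlpack/dlpack.h>
    #include <cstdint>

    // A tensor is compact/row-major when every stride equals the product
    // of the extents inside it; DLPack treats null strides as contiguous.
    inline bool IsCompactRowMajor(const DLTensor& t) {
      if (t.strides == nullptr) return true;
      int64_t expected = 1;
      for (int i = t.ndim - 1; i >= 0; --i) {
        // Extent-1 dimensions may carry any stride without changing layout.
        if (t.shape[i] != 1 && t.strides[i] != expected) return false;
        expected *= t.shape[i];
      }
      return true;
    }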
