This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new 53ffe5e  [Fix] Restrict stride normalization to 1D tensors on export (#23)
53ffe5e is described below

commit 53ffe5ecd8b4f36506ee04f90d95bc34e456af78
Author: Kathryn (Jinqi) Chen <[email protected]>
AuthorDate: Thu Sep 18 12:14:59 2025 -0700

    [Fix] Restrict stride normalization to 1D tensors on export (#23)
    
    See the related PyTorch [issue](https://github.com/pytorch/pytorch/issues/163274) for background on this change.
    
    Co-authored-by: Kathryn-cat <[email protected]>
---
 python/tvm_ffi/_optional_torch_c_dlpack.py | 21 ++-------------------
 1 file changed, 2 insertions(+), 19 deletions(-)

diff --git a/python/tvm_ffi/_optional_torch_c_dlpack.py b/python/tvm_ffi/_optional_torch_c_dlpack.py
index b227c46..2501607 100644
--- a/python/tvm_ffi/_optional_torch_c_dlpack.py
+++ b/python/tvm_ffi/_optional_torch_c_dlpack.py
@@ -210,30 +210,13 @@ template <class T>
 T* toDLPackImpl(const Tensor& src) {
   auto view = src;
 
-  bool need_normalize_strides = false;
-  int64_t expected_stride = 1;
-  for (int i = src.dim() - 1; i >= 0; i--) {
-    // detect if we do not meet continuous pattern
-    // and the size is 1, so there is opportunity to normalize
-    if (src.stride(i) != expected_stride && src.size(i) == 1) {
-      need_normalize_strides = true;
-      break;
-    }
-    expected_stride *= src.size(i);
-  }
-
+  bool need_normalize_strides = src.ndim() == 1 && src.size(0) == 1 && 
src.stride(0) != 1;
   // less common case, try normalizing the strides
   if (need_normalize_strides) {
     // create a new tensor with possibly normalized strides
     // gh-83069
     auto shape = src.sizes();
-    auto strides = src.strides().vec();
-    for (int i = 0; i < src.dim(); i++) {
-      if (shape[i] < 2) {
-        strides[i] = 1;
-      }
-    }
-    view = src.as_strided(shape, strides, src.storage_offset());
+    view = src.as_strided(shape, {1}, src.storage_offset());
   }
 
   ATenDLMTensor<T>* atDLMTensor(new ATenDLMTensor<T>);

Reply via email to