This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new 53ffe5e [Fix] Restrict stride normalization to 1D tensors on export (#23)
53ffe5e is described below
commit 53ffe5ecd8b4f36506ee04f90d95bc34e456af78
Author: Kathryn (Jinqi) Chen <[email protected]>
AuthorDate: Thu Sep 18 12:14:59 2025 -0700
[Fix] Restrict stride normalization to 1D tensors on export (#23)
See the related PyTorch [issue](https://github.com/pytorch/pytorch/issues/163274) for background on this fix.
Co-authored-by: Kathryn-cat <[email protected]>
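(Editorial note, not part of the commit message.) A minimal sketch of the tensor pattern this normalization targets: a 1-D, single-element torch tensor whose stride is not 1. It assumes only standard torch APIs (`as_strided`, `torch.utils.dlpack`) and is illustrative, not code from this repository.

```python
import torch
from torch.utils import dlpack

# A 1-D tensor with a single element can carry an arbitrary stride,
# e.g. stride (5,) from an as_strided view. Per this change, that is
# the only case where the exporter still rewrites the stride to 1.
base = torch.arange(10.0)
t = base.as_strided((1,), (5,))
print(t.shape, t.stride())          # torch.Size([1]) (5,)

# Round-tripping through DLPack works regardless; a size-1 dimension's
# stride does not affect which element is addressed.
restored = dlpack.from_dlpack(dlpack.to_dlpack(t))
print(restored.shape, restored.stride())
```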
---
python/tvm_ffi/_optional_torch_c_dlpack.py | 21 ++-------------------
1 file changed, 2 insertions(+), 19 deletions(-)
diff --git a/python/tvm_ffi/_optional_torch_c_dlpack.py b/python/tvm_ffi/_optional_torch_c_dlpack.py
index b227c46..2501607 100644
--- a/python/tvm_ffi/_optional_torch_c_dlpack.py
+++ b/python/tvm_ffi/_optional_torch_c_dlpack.py
@@ -210,30 +210,13 @@ template <class T>
 T* toDLPackImpl(const Tensor& src) {
   auto view = src;
-  bool need_normalize_strides = false;
-  int64_t expected_stride = 1;
-  for (int i = src.dim() - 1; i >= 0; i--) {
-    // detect if we do not meet continuous pattern
-    // and the size is 1, so there is opportunity to normalize
-    if (src.stride(i) != expected_stride && src.size(i) == 1) {
-      need_normalize_strides = true;
-      break;
-    }
-    expected_stride *= src.size(i);
-  }
-
+  bool need_normalize_strides = src.ndim() == 1 && src.size(0) == 1 && src.stride(0) != 1;
   // less common case, try normalizing the strides
   if (need_normalize_strides) {
     // create a new tensor with possibly normalized strides
     // gh-83069
     auto shape = src.sizes();
-    auto strides = src.strides().vec();
-    for (int i = 0; i < src.dim(); i++) {
-      if (shape[i] < 2) {
-        strides[i] = 1;
-      }
-    }
-    view = src.as_strided(shape, strides, src.storage_offset());
+    view = src.as_strided(shape, {1}, src.storage_offset());
   }
   ATenDLMTensor<T>* atDLMTensor(new ATenDLMTensor<T>);
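(Editorial note on the behavior change, my reading of the diff rather than text from the commit.) The removed loop flagged any tensor whose size-1 dimension had a stride that broke the contiguous pattern, so N-D tensors could have their strides rewritten on export; the new check fires only for 1-D single-element tensors and leaves N-D strides untouched. A hypothetical torch example of a tensor that is no longer rewritten:

```python
import torch

# Shape (3, 1) view with strides (1, 3): the size-1 dimension's stride (3)
# does not match the contiguous expectation (1), so the removed loop would
# have rewritten it to 1 on export. With the new ndim == 1 check, this
# 2-D tensor keeps its strides as-is.
x = torch.empty((1, 3)).transpose(0, 1)
print(x.shape, x.stride())   # torch.Size([3, 1]) (1, 3)
```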