This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 69336ac4a5 [TOPI] Fix strided_slice_with_axes to handle negative axis values (#18917)
69336ac4a5 is described below

commit 69336ac4a5f9e244deaaad94a031417c64114b2c
Author: Masahiro Hiramori <[email protected]>
AuthorDate: Fri Mar 20 15:17:36 2026 +0900

    [TOPI] Fix strided_slice_with_axes to handle negative axis values (#18917)
    
    Negative axis values (e.g., `axes=[-1]`) in `strided_slice_with_axes`
    were used directly as array indices without normalization, causing an
    `IndexError` during `LegalizeOps`.
    
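    This PR normalizes each negative axis to its non-negative equivalent
    (axis + ndim), checks that the result lies in [0, ndim), and passes the
    normalized axes to `StridedSliceCanonicalizeBegin`,
    `StridedSliceOutputShape`, and the compute lambda. A regression test,
    `test_strided_slice_negative_axes`, covers the `axes=[-1]` case.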
---
 include/tvm/topi/transform.h                       | 29 ++++++++++++++------
 ..._transform_legalize_ops_index_linear_algebra.py | 31 ++++++++++++++++++++++
 2 files changed, 52 insertions(+), 8 deletions(-)

diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h
index 93938c601d..5c3ec5986c 100644
--- a/include/tvm/topi/transform.h
+++ b/include/tvm/topi/transform.h
@@ -904,28 +904,41 @@ inline Tensor strided_slice_with_axes(const Tensor& x, 
const ffi::Array<Integer>
                                       std::string slice_mode = "end",
                                       std::string name = 
"T_strided_slice_with_axes",
                                       std::string tag = kInjective) {
-  const size_t src_tensor_dim = x->shape.size();
-  TVM_FFI_ICHECK(axes.size() <= src_tensor_dim);
+  const int64_t src_tensor_dim = static_cast<int64_t>(x->shape.size());
+  TVM_FFI_ICHECK(static_cast<int64_t>(axes.size()) <= src_tensor_dim);
   TVM_FFI_ICHECK(axes.size() == begin.size() && axes.size() == end.size() &&
                  axes.size() == strides.size());
 
+  // Normalize negative axes
+  ffi::Array<Integer> normalized_axes;
+  for (size_t i = 0; i < axes.size(); ++i) {
+    int64_t axis = axes[i].IntValue();
+    if (axis < 0) {
+      axis += src_tensor_dim;
+    }
+    TVM_FFI_ICHECK(axis >= 0 && axis < src_tensor_dim)
+        << "Axis " << axes[i].IntValue() << " is out of bounds for tensor with 
" << src_tensor_dim
+        << " dimensions";
+    normalized_axes.push_back(Integer(axis));
+  }
+
   std::vector<int64_t> begin_vec, end_vec, strides_vec;
   std::tie(begin_vec, end_vec, strides_vec) = ConvertToVec(begin, end, 
strides, slice_mode);
 
-  auto begin_expr = StridedSliceCanonicalizeBegin(x->shape, begin_vec, 
strides_vec, axes,
+  auto begin_expr = StridedSliceCanonicalizeBegin(x->shape, begin_vec, 
strides_vec, normalized_axes,
                                                   begin[0]->dtype, slice_mode);
-  auto out_shape = StridedSliceOutputShape(x->shape, begin_vec, end_vec, 
strides_vec, axes,
-                                           slice_mode, begin_expr);
+  auto out_shape = StridedSliceOutputShape(x->shape, begin_vec, end_vec, 
strides_vec,
+                                           normalized_axes, slice_mode, 
begin_expr);
 
   return te::compute(
       out_shape,
       [&](const ffi::Array<tirx::Var>& indices) {
         ffi::Array<PrimExpr> real_indices;
         for (size_t i = 0; i < out_shape.size(); ++i) 
real_indices.push_back(indices[i]);
-        for (size_t i = 0; i < axes.size(); ++i) {
+        for (size_t i = 0; i < normalized_axes.size(); ++i) {
           auto stride = make_const(strides[i].dtype(), strides_vec[i]);
-          PrimExpr ind = indices[axes[i].IntValue()] * stride + begin_expr[i];
-          real_indices.Set(axes[i].IntValue(), ind);
+          PrimExpr ind = indices[normalized_axes[i].IntValue()] * stride + 
begin_expr[i];
+          real_indices.Set(normalized_axes[i].IntValue(), ind);
         }
         return x(real_indices);
       },
diff --git 
a/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py 
b/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py
index 1be7a39781..b8dbe1934b 100644
--- a/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py
+++ b/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py
@@ -259,6 +259,37 @@ def test_strided_slice_no_strides():
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_strided_slice_negative_axes():
+    # fmt: off
+    @tvm.script.ir_module
+    class StridedSlice:
+        @R.function
+        def main(x: R.Tensor((8, 9, 10), "float32")) -> R.Tensor((8, 9, 3), 
"float32"):
+            gv: R.Tensor((8, 9, 3), "float32") = R.strided_slice(x, axes=[-1], 
begin=[2], end=[5])
+            return gv
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor((8, 9, 10), dtype="float32")) -> R.Tensor((8, 9, 
3), dtype="float32"):
+            gv = R.call_tir(Expected.strided_slice, (x,), 
out_sinfo=R.Tensor((8, 9, 3), dtype="float32"))
+            return gv
+
+        @T.prim_func(private=True)
+        def strided_slice(rxplaceholder: T.Buffer((T.int64(8), T.int64(9), 
T.int64(10)), "float32"), T_strided_slice_with_axes: T.Buffer((T.int64(8), 
T.int64(9), T.int64(3)), "float32")):
+            T.func_attr({"tir.noalias": True})
+            for ax0, ax1, ax2 in T.grid(T.int64(8), T.int64(9), T.int64(3)):
+                with T.sblock("T_strided_slice_with_axes"):
+                    v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+                    T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2 + T.int64(2)])
+                    T.writes(T_strided_slice_with_axes[v_ax0, v_ax1, v_ax2])
+                    T_strided_slice_with_axes[v_ax0, v_ax1, v_ax2] = 
rxplaceholder[v_ax0, v_ax1, v_ax2 + T.int64(2)]
+    # fmt: on
+
+    mod = LegalizeOps()(StridedSlice)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
 def test_strided_slice_symbolic_sliced_axis():
     # fmt: off
     @tvm.script.ir_module

Reply via email to