This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new ded35dd575 [Relax] Normalize negative concat axis in ReorderPermuteDimsAfterConcat (#19588)
ded35dd575 is described below

commit ded35dd5758df2a266ea865d8eeb305ef59142c3
Author: Neo Chien <[email protected]>
AuthorDate: Sun May 24 11:52:36 2026 +0800

    [Relax] Normalize negative concat axis in ReorderPermuteDimsAfterConcat (#19588)
    
    Hi Committers,
    
    This PR fixes https://github.com/apache/tvm/issues/19575.
    
    ### Root Cause
    `ReorderPermuteDimsAfterConcat` reads the `concat` axis and uses it as an
    index into the axes of the `permute_dims` that feeds the concat.
    
    When `concat(axis=-1)` is used, the negative axis was implicitly converted
    to `size_t` (through the helper lambda's `-> size_t` return type) without
    first being normalized. Indexing with that value goes out of range and
    crashes the pass (e.g. `IndexError: Index -1 out of bounds 4`).
    
    ### Solution
    In `src/relax/transform/reorder_permute_dims_after_concat.cc`:
    
    1. Read the concat axis as a signed integer (`int64_t`) first.
    2. Normalize a negative axis with `axis += ndim`.
    3. Add explicit range checks (`0 <= axis < ndim`) after normalization.
    4. Use the normalized axis for the permutation-axis remapping.
    
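    This keeps behavior unchanged for valid non-negative axes and only fixes
    negative-axis handling. Any axis that is still out of range after
    normalization now fails an explicit check instead of indexing out of bounds.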
    
    ---------
    
    Co-authored-by: cchung100m <[email protected]>
---
 .../transform/reorder_permute_dims_after_concat.cc | 15 +++++++++--
 ..._transform_reorder_permute_dims_after_concat.py | 29 ++++++++++++++++++++++
 2 files changed, 42 insertions(+), 2 deletions(-)

diff --git a/src/relax/transform/reorder_permute_dims_after_concat.cc 
b/src/relax/transform/reorder_permute_dims_after_concat.cc
index bc542ccf91..01eadc37f3 100644
--- a/src/relax/transform/reorder_permute_dims_after_concat.cc
+++ b/src/relax/transform/reorder_permute_dims_after_concat.cc
@@ -151,8 +151,19 @@ std::tuple<DFPattern, ffi::TypedFunction<Expr(Expr, 
ffi::Map<DFPattern, Expr>)>>
     auto concat_attrs = concat_call->attrs.as<ConcatAttrs>();
     TVM_FFI_ICHECK(concat_attrs);
 
-    auto old_concat_axis = [&]() -> size_t { return 
concat_attrs->axis.value_or(0); }();
-    Integer new_concat_axis = 
get_permute_dims_axes(all_permute_dims[0])[old_concat_axis];
+    auto permute_dims_axes = get_permute_dims_axes(all_permute_dims[0]);
+
+    int64_t old_concat_axis = concat_attrs->axis.value_or(0);
+    int64_t ndim = static_cast<int64_t>(permute_dims_axes.size());
+    if (old_concat_axis < 0) {
+      old_concat_axis += ndim;
+    }
+    TVM_FFI_ICHECK_GE(old_concat_axis, 0)
+        << "concat axis " << old_concat_axis << " out of range for " << ndim 
<< "-D input";
+    TVM_FFI_ICHECK_LT(old_concat_axis, ndim)
+        << "concat axis " << old_concat_axis << " out of range for " << ndim 
<< "-D input";
+
+    Integer new_concat_axis = 
permute_dims_axes[static_cast<size_t>(old_concat_axis)];
 
     auto new_concat = concat(Tuple(args), new_concat_axis->value);
     auto new_permute_dims = permute_dims(new_concat, permute_axes);
diff --git 
a/tests/python/relax/test_transform_reorder_permute_dims_after_concat.py 
b/tests/python/relax/test_transform_reorder_permute_dims_after_concat.py
index f93daa4c1e..2da6cfcda9 100644
--- a/tests/python/relax/test_transform_reorder_permute_dims_after_concat.py
+++ b/tests/python/relax/test_transform_reorder_permute_dims_after_concat.py
@@ -261,5 +261,34 @@ class TestCheckForRewriteBeforeIncompatibleChange(Base):
             return out
 
 
+class TestNegativeConcatAxis(Base):
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            x: R.Tensor([1, 4, 8, 8], "float32"),
+            y: R.Tensor([1, 4, 8, 8], "float32"),
+        ):
+            with R.dataflow():
+                xt = R.permute_dims(x, axes=[0, 2, 3, 1])
+                yt = R.permute_dims(y, axes=[0, 2, 3, 1])
+                out = R.concat([xt, yt], axis=-1)
+                R.output(out)
+            return out
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor([1, 4, 8, 8], "float32"),
+            y: R.Tensor([1, 4, 8, 8], "float32"),
+        ):
+            with R.dataflow():
+                merged = R.concat([x, y], axis=1)
+                out = R.permute_dims(merged, axes=[0, 2, 3, 1])
+                R.output(out)
+            return out
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to