This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new ded35dd575 [Relax] Normalize negative concat axis in
ReorderPermuteDimsAfterConcat (#19588)
ded35dd575 is described below
commit ded35dd5758df2a266ea865d8eeb305ef59142c3
Author: Neo Chien <[email protected]>
AuthorDate: Sun May 24 11:52:36 2026 +0800
[Relax] Normalize negative concat axis in ReorderPermuteDimsAfterConcat
(#19588)
Hi Committers,
This PR fixes https://github.com/apache/tvm/issues/19575.
### Root Cause
`ReorderPermuteDimsAfterConcat` reads the `concat` axis and uses it as an
index into the permutation axes.
When `concat(axis=-1)` was used, the negative axis was converted directly
to `size_t` before indexing, which produced an out-of-range index and
crashed (e.g. `IndexError: Index -1 out of bounds 4`).
### Solution
In `src/relax/transform/reorder_permute_dims_after_concat.cc`:
1. Read the concat axis as a signed integer first.
2. Normalize a negative axis with `axis += ndim`.
3. Add explicit range checks after normalization.
4. Use the normalized axis for permutation-axis remapping.
This keeps behavior unchanged for non-negative axes and only fixes
negative-axis handling.
---------
Co-authored-by: cchung100m <[email protected]>
---
.../transform/reorder_permute_dims_after_concat.cc | 15 +++++++++--
..._transform_reorder_permute_dims_after_concat.py | 29 ++++++++++++++++++++++
2 files changed, 42 insertions(+), 2 deletions(-)
diff --git a/src/relax/transform/reorder_permute_dims_after_concat.cc
b/src/relax/transform/reorder_permute_dims_after_concat.cc
index bc542ccf91..01eadc37f3 100644
--- a/src/relax/transform/reorder_permute_dims_after_concat.cc
+++ b/src/relax/transform/reorder_permute_dims_after_concat.cc
@@ -151,8 +151,19 @@ std::tuple<DFPattern, ffi::TypedFunction<Expr(Expr,
ffi::Map<DFPattern, Expr>)>>
auto concat_attrs = concat_call->attrs.as<ConcatAttrs>();
TVM_FFI_ICHECK(concat_attrs);
- auto old_concat_axis = [&]() -> size_t { return
concat_attrs->axis.value_or(0); }();
- Integer new_concat_axis =
get_permute_dims_axes(all_permute_dims[0])[old_concat_axis];
+ auto permute_dims_axes = get_permute_dims_axes(all_permute_dims[0]);
+
+ int64_t old_concat_axis = concat_attrs->axis.value_or(0);
+ int64_t ndim = static_cast<int64_t>(permute_dims_axes.size());
+ if (old_concat_axis < 0) {
+ old_concat_axis += ndim;
+ }
+ TVM_FFI_ICHECK_GE(old_concat_axis, 0)
+ << "concat axis " << old_concat_axis << " out of range for " << ndim
<< "-D input";
+ TVM_FFI_ICHECK_LT(old_concat_axis, ndim)
+ << "concat axis " << old_concat_axis << " out of range for " << ndim
<< "-D input";
+
+ Integer new_concat_axis =
permute_dims_axes[static_cast<size_t>(old_concat_axis)];
auto new_concat = concat(Tuple(args), new_concat_axis->value);
auto new_permute_dims = permute_dims(new_concat, permute_axes);
diff --git
a/tests/python/relax/test_transform_reorder_permute_dims_after_concat.py
b/tests/python/relax/test_transform_reorder_permute_dims_after_concat.py
index f93daa4c1e..2da6cfcda9 100644
--- a/tests/python/relax/test_transform_reorder_permute_dims_after_concat.py
+++ b/tests/python/relax/test_transform_reorder_permute_dims_after_concat.py
@@ -261,5 +261,34 @@ class TestCheckForRewriteBeforeIncompatibleChange(Base):
return out
+class TestNegativeConcatAxis(Base):
+ @I.ir_module
+ class Before:
+ @R.function
+ def main(
+ x: R.Tensor([1, 4, 8, 8], "float32"),
+ y: R.Tensor([1, 4, 8, 8], "float32"),
+ ):
+ with R.dataflow():
+ xt = R.permute_dims(x, axes=[0, 2, 3, 1])
+ yt = R.permute_dims(y, axes=[0, 2, 3, 1])
+ out = R.concat([xt, yt], axis=-1)
+ R.output(out)
+ return out
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(
+ x: R.Tensor([1, 4, 8, 8], "float32"),
+ y: R.Tensor([1, 4, 8, 8], "float32"),
+ ):
+ with R.dataflow():
+ merged = R.concat([x, y], axis=1)
+ out = R.permute_dims(merged, axes=[0, 2, 3, 1])
+ R.output(out)
+ return out
+
+
if __name__ == "__main__":
tvm.testing.main()