This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 980d75fd08 [TIRx] Replace vars in buffer strides and elem_offset (#19871)
980d75fd08 is described below
commit 980d75fd0897c4bdd1328fc453ed1379e40219fa
Author: Guan-Ming Chiu <[email protected]>
AuthorDate: Fri Jun 26 05:50:02 2026 +0800
[TIRx] Replace vars in buffer strides and elem_offset (#19871)
## Why
`BufferReplacer` passed `strides` and `elem_offset` through unchanged, so vars
in them were never substituted.
## How
- Visit and rewrite each `strides` expression and the `elem_offset` when
rebuilding the buffer.
- Add a test covering the fix.
Signed-off-by: Guan-Ming (Wesley) Chiu <[email protected]>
---
python/tvm/tirx/transform/common.py | 9 ++++++---
tests/python/tirx/test_op.py | 15 +++++++++++++++
2 files changed, 21 insertions(+), 3 deletions(-)
diff --git a/python/tvm/tirx/transform/common.py b/python/tvm/tirx/transform/common.py
index d90903daf9..d7ebd557af 100644
--- a/python/tvm/tirx/transform/common.py
+++ b/python/tvm/tirx/transform/common.py
@@ -36,7 +36,6 @@ from tvm.tirx.layout import Iter, TileLayout
from tvm.tirx.stmt_functor import StmtExprMutator, StmtMutator
-# FIXME: this pass does not replace var in the shape/layout of a buffer
class BufferReplacer(StmtExprMutator):
"""
Replace buffer with another buffer.
@@ -63,6 +62,10 @@ class BufferReplacer(StmtExprMutator):
self.buffer_attr_var_mutated = False
new_data = self.visit_expr(buffer.data)
new_shape = [self.visit_expr(expr) for expr in buffer.shape]
+ new_strides = [self.visit_expr(expr) for expr in buffer.strides]
+ new_elem_offset = (
+ self.visit_expr(buffer.elem_offset) if buffer.elem_offset is not None else None
+ )
if isinstance(buffer.layout, TileLayout):
new_shard = []
new_replicate = []
@@ -90,8 +93,8 @@ class BufferReplacer(StmtExprMutator):
buffer.dtype,
buffer.name,
new_data,
- buffer.strides,
- buffer.elem_offset,
+ new_strides,
+ new_elem_offset,
buffer.scope(),
buffer.data_alignment,
buffer.offset_factor,
diff --git a/tests/python/tirx/test_op.py b/tests/python/tirx/test_op.py
index 480e6cd3dd..4c417033d6 100644
--- a/tests/python/tirx/test_op.py
+++ b/tests/python/tirx/test_op.py
@@ -57,6 +57,21 @@ def test_buffer_replacer_no_shared_default():
assert len(r2.buffer_map) == 0
+def test_buffer_replacer_replaces_strides_and_elem_offset():
+ """Vars in buffer strides/elem_offset must be replaced, not passed through."""
+ from tvm.tirx import BufferStore, Var
+ from tvm.tirx.transform.common import BufferReplacer
+
+ n = Var("n", "int32")
+ m = Var("m", "int32")
+ A = decl_buffer((64,), "float32", strides=[n], elem_offset=n)
+ store = BufferStore(A, 1.0, [0])
+
+ new = BufferReplacer(var_map={n: m})(store)
+ assert new.buffer.strides[0].same_as(m)
+ assert new.buffer.elem_offset.same_as(m)
+
+
def test_gemm_async_partial_scale_factor():
"""Regression test for F7: gemm_async must reject partial scale factors."""
from tvm.tirx.script.builder.tirx import gemm_async