This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 980d75fd08 [TIRx] Replace vars in buffer strides and elem_offset 
(#19871)
980d75fd08 is described below

commit 980d75fd0897c4bdd1328fc453ed1379e40219fa
Author: Guan-Ming Chiu <[email protected]>
AuthorDate: Fri Jun 26 05:50:02 2026 +0800

    [TIRx] Replace vars in buffer strides and elem_offset (#19871)
    
    ## Why
    
    BufferReplacer passed a buffer's strides and elem_offset through unchanged,
    so vars appearing in them were never substituted.
    
    ## How
    
    - Visit and rewrite each stride expression and the `elem_offset` when
    rebuilding the buffer.
    - Add a regression test covering the fix.
    
    Signed-off-by: Guan-Ming (Wesley) Chiu 
<[email protected]>
---
 python/tvm/tirx/transform/common.py |  9 ++++++---
 tests/python/tirx/test_op.py        | 15 +++++++++++++++
 2 files changed, 21 insertions(+), 3 deletions(-)

diff --git a/python/tvm/tirx/transform/common.py 
b/python/tvm/tirx/transform/common.py
index d90903daf9..d7ebd557af 100644
--- a/python/tvm/tirx/transform/common.py
+++ b/python/tvm/tirx/transform/common.py
@@ -36,7 +36,6 @@ from tvm.tirx.layout import Iter, TileLayout
 from tvm.tirx.stmt_functor import StmtExprMutator, StmtMutator
 
 
-# FIXME: this pass does not replace var in the shape/layout of a buffer
 class BufferReplacer(StmtExprMutator):
     """
     Replace buffer with another buffer.
@@ -63,6 +62,10 @@ class BufferReplacer(StmtExprMutator):
         self.buffer_attr_var_mutated = False
         new_data = self.visit_expr(buffer.data)
         new_shape = [self.visit_expr(expr) for expr in buffer.shape]
+        new_strides = [self.visit_expr(expr) for expr in buffer.strides]
+        new_elem_offset = (
+            self.visit_expr(buffer.elem_offset) if buffer.elem_offset is not 
None else None
+        )
         if isinstance(buffer.layout, TileLayout):
             new_shard = []
             new_replicate = []
@@ -90,8 +93,8 @@ class BufferReplacer(StmtExprMutator):
             buffer.dtype,
             buffer.name,
             new_data,
-            buffer.strides,
-            buffer.elem_offset,
+            new_strides,
+            new_elem_offset,
             buffer.scope(),
             buffer.data_alignment,
             buffer.offset_factor,
diff --git a/tests/python/tirx/test_op.py b/tests/python/tirx/test_op.py
index 480e6cd3dd..4c417033d6 100644
--- a/tests/python/tirx/test_op.py
+++ b/tests/python/tirx/test_op.py
@@ -57,6 +57,21 @@ def test_buffer_replacer_no_shared_default():
     assert len(r2.buffer_map) == 0
 
 
+def test_buffer_replacer_replaces_strides_and_elem_offset():
+    """Vars in buffer strides/elem_offset must be replaced, not passed 
through."""
+    from tvm.tirx import BufferStore, Var
+    from tvm.tirx.transform.common import BufferReplacer
+
+    n = Var("n", "int32")
+    m = Var("m", "int32")
+    A = decl_buffer((64,), "float32", strides=[n], elem_offset=n)
+    store = BufferStore(A, 1.0, [0])
+
+    new = BufferReplacer(var_map={n: m})(store)
+    assert new.buffer.strides[0].same_as(m)
+    assert new.buffer.elem_offset.same_as(m)
+
+
 def test_gemm_async_partial_scale_factor():
     """Regression test for F7: gemm_async must reject partial scale factors."""
     from tvm.tirx.script.builder.tirx import gemm_async

Reply via email to