This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new bca7ebf1b5 [TIR] Fix RenewDef for symbolic input shapes (#15163)
bca7ebf1b5 is described below

commit bca7ebf1b5321e4d76905164e4c6a468fbaca2a6
Author: Junru Shao <[email protected]>
AuthorDate: Tue Jun 27 07:09:42 2023 -0700

    [TIR] Fix RenewDef for symbolic input shapes (#15163)
    
    There are cases where the shapes of input buffers are symbolic, but a
    symbol first appears inside a composite PrimExpr rather than as a bare
    TIR Var, which the original implementation does not take into account.
    
    Example:
    
    ```python
    @T.prim_func
    def main(a: T.handle, b: T.handle):
        m = T.int64()
        A = T.match_buffer(a, (m * 2,))  # `m` first appears in a composite expression
        B = T.match_buffer(b, (m, 2))
    ```
---
 src/tir/transforms/renew_defs.cc             | 12 ++++++++++++
 tests/python/unittest/test_tir_renew_defs.py | 20 +++++++++++++++++++-
 2 files changed, 31 insertions(+), 1 deletion(-)

diff --git a/src/tir/transforms/renew_defs.cc b/src/tir/transforms/renew_defs.cc
index 8cb01dfe6d..fd2c27dcd1 100644
--- a/src/tir/transforms/renew_defs.cc
+++ b/src/tir/transforms/renew_defs.cc
@@ -50,6 +50,18 @@ class RenewDefMutator : public StmtExprMutator {
     for (const auto& param : func->params) {
       params.push_back(generator.ReDefineVar(param));
     }
+    for (const auto& param : func->params) {
+      if (param->dtype.is_handle()) {
+        const Buffer& buffer = func->buffer_map.at(param);
+        for (const PrimExpr& e : buffer->shape) {
+          if (const auto* v = e.as<VarNode>()) {
+            if (generator.remap_.count(GetRef<Var>(v)) == 0) {
+              generator.ReDefineVar(GetRef<Var>(v));
+            }
+          }
+        }
+      }
+    }
     // Redefine buffers in order
     // TODO(Siyuan Feng): checking var is used after define
     Map<tir::Var, Buffer> buffer_map;
diff --git a/tests/python/unittest/test_tir_renew_defs.py b/tests/python/unittest/test_tir_renew_defs.py
index e01f5ecb12..3f286a241c 100644
--- a/tests/python/unittest/test_tir_renew_defs.py
+++ b/tests/python/unittest/test_tir_renew_defs.py
@@ -15,9 +15,9 @@
 # specific language governing permissions and limitations
 # under the License.
 
-import pytest
 import sys
 
+import pytest
 import tvm
 import tvm.testing
 from tvm.script import tir as T
@@ -76,6 +76,7 @@ def test_simple():
     assert f1.body.block.body.loop_var != f2.body.block.body.loop_var
     # check remap of j
     assert f1.body.block.body.body.loop_var != f2.body.block.body.body.loop_var
+
     # check inner block
     def _get_block(f):
         return f.body.block.body.body.body.block
@@ -169,5 +170,22 @@ def test_symbolic_func():
     tvm.ir.assert_structural_equal(f1, f2)
 
 
+def test_buffer_map():
+    @T.prim_func
+    def main(a: T.handle, b: T.handle):
+        m = T.int64()
+        A = T.match_buffer(a, (m * 2,))
+        B = T.match_buffer(b, (m, 2))
+        for i, j in T.grid(m, 2):
+            with T.block("B"):
+                vi, vj = T.axis.remap("SS", [i, j])
+                B[vi, vj] = A[vi * 2 + vj]
+
+    f1 = main
+    f2 = tvm.tir.stmt_functor.renew_defs(main)
+    tvm.ir.assert_structural_equal(f1, f2)
+    assert f1.buffer_map[f1.params[1]].shape[0] != f2.buffer_map[f2.params[1]].shape[0]
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to