This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new bca7ebf1b5 [TIR] Fix RenewDef for symbolic input shapes (#15163)
bca7ebf1b5 is described below
commit bca7ebf1b5321e4d76905164e4c6a468fbaca2a6
Author: Junru Shao <[email protected]>
AuthorDate: Tue Jun 27 07:09:42 2023 -0700
[TIR] Fix RenewDef for symbolic input shapes (#15163)
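There are cases where the shapes of input buffers are symbolic, but the
first occurrence of a symbolic variable is inside a composite PrimExpr
(e.g. `m * 2`) rather than as a bare TIR Var, which the original
implementation does not take into account. This change pre-defines every
Var that appears directly as a buffer dimension before the buffers are
redefined, so that every use of the variable is remapped to the same new Var.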
Example:
```python
@T.prim_func
def main(a: T.handle, b: T.handle):
    m = T.int64()
    A = T.match_buffer(a, (m * 2,))  # `m` first appears inside a composite expr
    B = T.match_buffer(b, (m, 2))  # `m` later appears as a bare Var
```
---
src/tir/transforms/renew_defs.cc | 12 ++++++++++++
tests/python/unittest/test_tir_renew_defs.py | 20 +++++++++++++++++++-
2 files changed, 31 insertions(+), 1 deletion(-)
diff --git a/src/tir/transforms/renew_defs.cc b/src/tir/transforms/renew_defs.cc
index 8cb01dfe6d..fd2c27dcd1 100644
--- a/src/tir/transforms/renew_defs.cc
+++ b/src/tir/transforms/renew_defs.cc
@@ -50,6 +50,18 @@ class RenewDefMutator : public StmtExprMutator {
for (const auto& param : func->params) {
params.push_back(generator.ReDefineVar(param));
}
+    // Pre-define every Var that appears directly as a dimension of a parameter
+    // buffer's shape, so that composite dims of earlier buffers (e.g. `m * 2`
+    // before `m` appears alone) are remapped to the same new Var.
+    for (const auto& param : func->params) {
+      if (param->dtype.is_handle()) {
+        // A handle param is not guaranteed to have a buffer_map entry (e.g. an
+        // opaque pointer argument); `Map::at` fails on a missing key, so look
+        // the buffer up with `Get` and skip params without one.
+        if (Optional<Buffer> opt_buffer = func->buffer_map.Get(param)) {
+          for (const PrimExpr& e : opt_buffer.value()->shape) {
+            if (const auto* v = e.as<VarNode>()) {
+              if (generator.remap_.count(GetRef<Var>(v)) == 0) {
+                generator.ReDefineVar(GetRef<Var>(v));
+              }
+            }
+          }
+        }
+      }
+    }
// Redefine buffers in order
// TODO(Siyuan Feng): checking var is used after define
Map<tir::Var, Buffer> buffer_map;
diff --git a/tests/python/unittest/test_tir_renew_defs.py b/tests/python/unittest/test_tir_renew_defs.py
index e01f5ecb12..3f286a241c 100644
--- a/tests/python/unittest/test_tir_renew_defs.py
+++ b/tests/python/unittest/test_tir_renew_defs.py
@@ -15,9 +15,9 @@
# specific language governing permissions and limitations
# under the License.
-import pytest
import sys
+import pytest
import tvm
import tvm.testing
from tvm.script import tir as T
@@ -76,6 +76,7 @@ def test_simple():
assert f1.body.block.body.loop_var != f2.body.block.body.loop_var
# check remap of j
assert f1.body.block.body.body.loop_var != f2.body.block.body.body.loop_var
+
# check inner block
def _get_block(f):
return f.body.block.body.body.body.block
@@ -169,5 +170,22 @@ def test_symbolic_func():
tvm.ir.assert_structural_equal(f1, f2)
+def test_buffer_map():
+    """Regression test: a symbolic var first used inside a composite buffer
+    dimension (``m * 2``) must be renewed consistently with its later bare use."""
+
+    @T.prim_func
+    def main(a: T.handle, b: T.handle):
+        m = T.int64()
+        A = T.match_buffer(a, (m * 2,))
+        B = T.match_buffer(b, (m, 2))
+        for i, j in T.grid(m, 2):
+            with T.block("B"):
+                vi, vj = T.axis.remap("SS", [i, j])
+                B[vi, vj] = A[vi * 2 + vj]
+
+    f1 = main
+    f2 = tvm.tir.stmt_functor.renew_defs(main)
+    tvm.ir.assert_structural_equal(f1, f2)
+    old_m = f1.buffer_map[f1.params[1]].shape[0]
+    new_m = f2.buffer_map[f2.params[1]].shape[0]
+    # `m` must be redefined as a fresh var ...
+    assert old_m != new_m
+    # ... and A's composite dim `m * 2` must refer to that same renewed var.
+    assert f2.buffer_map[f2.params[0]].shape[0].a.same_as(new_m)
+
+
if __name__ == "__main__":
tvm.testing.main()