This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 7a6281e [BugFix] Generate unique names for reduction blocks (#10726)
7a6281e is described below
commit 7a6281e600e9796ec41e89b59f3736d455ee8255
Author: Ruihang Lai <[email protected]>
AuthorDate: Thu Mar 24 17:08:48 2022 +0800
[BugFix] Generate unique names for reduction blocks (#10726)
---
src/te/operation/create_primfunc.cc | 2 +-
tests/python/unittest/test_te_create_primfunc.py | 17 +++++++++++++++--
2 files changed, 16 insertions(+), 3 deletions(-)
diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc
index d7503b8..5cf6e5c 100644
--- a/src/te/operation/create_primfunc.cc
+++ b/src/te/operation/create_primfunc.cc
@@ -136,7 +136,7 @@ BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op,
Stmt body;
if (const auto* reduce = expr_body.as<ReduceNode>()) {
// Case 1. Reduce compute
- block_name = compute_op->name;
+ block_name = info->GetUniqueName(compute_op->name);
int n_buffers = buffers.size();
Array<PrimExpr> lhs;
diff --git a/tests/python/unittest/test_te_create_primfunc.py b/tests/python/unittest/test_te_create_primfunc.py
index a65c5d8..23d264d 100644
--- a/tests/python/unittest/test_te_create_primfunc.py
+++ b/tests/python/unittest/test_te_create_primfunc.py
@@ -22,7 +22,7 @@ import numpy as np
import tvm.testing
-def test_unique_name():
+def test_unique_name_complete_block():
A = te.placeholder((16, 16), name="A")
B = te.compute((16, 16), lambda x, y: A[x, y] * 2, name="main")
C = te.compute((16, 16), lambda x, y: B[x, y] + 1, name="main")
@@ -32,6 +32,18 @@ def test_unique_name():
assert isinstance(s.get_sref(s.get_block("main_1")), tir.schedule.StmtSRef)
+def test_unique_name_reduction_block():
+ k1 = te.reduce_axis((0, 16), "k1")
+ k2 = te.reduce_axis((0, 16), "k2")
+ A = te.placeholder((16, 16), name="A")
+ B = te.compute((16,), lambda i: te.sum(A[i, k1], axis=k1), name="sum")
+ C = te.compute((), lambda: te.sum(B[k2], axis=k2), name="sum")
+ func = te.create_prim_func([A, C])
+ s = tir.Schedule(func, debug_mask="all")
+ assert isinstance(s.get_sref(s.get_block("sum")), tir.schedule.StmtSRef)
+ assert isinstance(s.get_sref(s.get_block("sum_1")), tir.schedule.StmtSRef)
+
+
def _check_workload(te_workload, tir_workload):
func = te.create_prim_func(te_workload())
tvm.ir.assert_structural_equal(func, tir_workload)
@@ -462,7 +474,8 @@ def test_argmax_val_idx():
if __name__ == "__main__":
- test_unique_name()
+ test_unique_name_complete_block()
+ test_unique_name_reduction_block()
test_matmul()
test_element_wise()
test_conv2d()