This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 7a6281e  [BugFix] Generate unique names for reduction blocks (#10726)
7a6281e is described below

commit 7a6281e600e9796ec41e89b59f3736d455ee8255
Author: Ruihang Lai <[email protected]>
AuthorDate: Thu Mar 24 17:08:48 2022 +0800

    [BugFix] Generate unique names for reduction blocks (#10726)
---
 src/te/operation/create_primfunc.cc              |  2 +-
 tests/python/unittest/test_te_create_primfunc.py | 17 +++++++++++++++--
 2 files changed, 16 insertions(+), 3 deletions(-)

diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc
index d7503b8..5cf6e5c 100644
--- a/src/te/operation/create_primfunc.cc
+++ b/src/te/operation/create_primfunc.cc
@@ -136,7 +136,7 @@ BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op,
   Stmt body;
   if (const auto* reduce = expr_body.as<ReduceNode>()) {
     // Case 1. Reduce compute
-    block_name = compute_op->name;
+    block_name = info->GetUniqueName(compute_op->name);
     int n_buffers = buffers.size();
 
     Array<PrimExpr> lhs;
diff --git a/tests/python/unittest/test_te_create_primfunc.py b/tests/python/unittest/test_te_create_primfunc.py
index a65c5d8..23d264d 100644
--- a/tests/python/unittest/test_te_create_primfunc.py
+++ b/tests/python/unittest/test_te_create_primfunc.py
@@ -22,7 +22,7 @@ import numpy as np
 import tvm.testing
 
 
-def test_unique_name():
+def test_unique_name_complete_block():
     A = te.placeholder((16, 16), name="A")
     B = te.compute((16, 16), lambda x, y: A[x, y] * 2, name="main")
     C = te.compute((16, 16), lambda x, y: B[x, y] + 1, name="main")
@@ -32,6 +32,18 @@ def test_unique_name():
     assert isinstance(s.get_sref(s.get_block("main_1")), tir.schedule.StmtSRef)
 
 
+def test_unique_name_reduction_block():
+    k1 = te.reduce_axis((0, 16), "k1")
+    k2 = te.reduce_axis((0, 16), "k2")
+    A = te.placeholder((16, 16), name="A")
+    B = te.compute((16,), lambda i: te.sum(A[i, k1], axis=k1), name="sum")
+    C = te.compute((), lambda: te.sum(B[k2], axis=k2), name="sum")
+    func = te.create_prim_func([A, C])
+    s = tir.Schedule(func, debug_mask="all")
+    assert isinstance(s.get_sref(s.get_block("sum")), tir.schedule.StmtSRef)
+    assert isinstance(s.get_sref(s.get_block("sum_1")), tir.schedule.StmtSRef)
+
+
 def _check_workload(te_workload, tir_workload):
     func = te.create_prim_func(te_workload())
     tvm.ir.assert_structural_equal(func, tir_workload)
@@ -462,7 +474,8 @@ def test_argmax_val_idx():
 
 
 if __name__ == "__main__":
-    test_unique_name()
+    test_unique_name_complete_block()
+    test_unique_name_reduction_block()
     test_matmul()
     test_element_wise()
     test_conv2d()

Reply via email to