This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 4c976a6  handle likely in IRMutatorWithAnalyzer (#5665)
4c976a6 is described below

commit 4c976a6acc8516067d340878f3905552d465f613
Author: Shizhi Tang <[email protected]>
AuthorDate: Tue May 26 00:39:33 2020 +0800

    handle likely in IRMutatorWithAnalyzer (#5665)
---
 src/arith/ir_mutator_with_analyzer.cc               | 16 ++++++++++++----
 src/arith/rewrite_simplify.cc                       |  4 +++-
 .../python/unittest/test_tir_transform_simplify.py  | 21 +++++++++++++++++++++
 3 files changed, 36 insertions(+), 5 deletions(-)

diff --git a/src/arith/ir_mutator_with_analyzer.cc b/src/arith/ir_mutator_with_analyzer.cc
index e09ff1d..e6f37f4 100644
--- a/src/arith/ir_mutator_with_analyzer.cc
+++ b/src/arith/ir_mutator_with_analyzer.cc
@@ -55,17 +55,25 @@ Stmt IRMutatorWithAnalyzer::VisitStmt_(const LetStmtNode* op) {
 
 Stmt IRMutatorWithAnalyzer::VisitStmt_(const IfThenElseNode* op) {
   PrimExpr condition = this->VisitExpr(op->condition);
+  PrimExpr real_condition = condition;
+  if (auto call = condition.as<CallNode>()) {
+    if (call->is_intrinsic(CallNode::likely)) {
+      real_condition = call->args[0];
+    }
+  }
+
   Stmt then_case, else_case;
   {
-    With<ConstraintContext> ctx(analyzer_, condition);
+    With<ConstraintContext> ctx(analyzer_, real_condition);
     then_case = this->VisitStmt(op->then_case);
   }
   if (op->else_case.defined()) {
-    With<ConstraintContext> ctx(analyzer_, analyzer_->rewrite_simplify(NotNode::make(condition)));
+    With<ConstraintContext> ctx(analyzer_,
+                                analyzer_->rewrite_simplify(NotNode::make(real_condition)));
     else_case = this->VisitStmt(op->else_case);
   }
-  if (is_one(condition)) return then_case;
-  if (is_zero(condition)) {
+  if (is_one(real_condition)) return then_case;
+  if (is_zero(real_condition)) {
     if (else_case.defined()) {
       return else_case;
     }
diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc
index 3b8ccfb..223b2e6 100644
--- a/src/arith/rewrite_simplify.cc
+++ b/src/arith/rewrite_simplify.cc
@@ -211,7 +211,9 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AddNode* op) {
 
 std::function<void()> RewriteSimplifier::Impl::EnterConstraint(const PrimExpr& constraint) {
   size_t old_literal_size = literal_constraints_.size();
-  literal_constraints_.push_back(constraint);
+  // we will compare the already simplified result with the constraint,
+  // so simplify the constraint as well
+  literal_constraints_.push_back(operator()(constraint));
   size_t new_literal_size = literal_constraints_.size();
   auto frecover = [old_literal_size, new_literal_size, this]() {
     CHECK_EQ(literal_constraints_.size(), new_literal_size);
diff --git a/tests/python/unittest/test_tir_transform_simplify.py b/tests/python/unittest/test_tir_transform_simplify.py
index bf53982..48d0849 100644
--- a/tests/python/unittest/test_tir_transform_simplify.py
+++ b/tests/python/unittest/test_tir_transform_simplify.py
@@ -52,6 +52,26 @@ def test_thread_extent_simplify():
     assert isinstance(body.body.body.body, tvm.tir.Store)
 
 
+def test_if_likely():  # nested identical likely() guards; Simplify must drop the inner one
+    ib = tvm.tir.ir_builder.create()
+    A = ib.pointer("float32", name="A")
+    C = ib.pointer("float32", name="C")
+    n = te.size_var("n")
+    tx = te.thread_axis("threadIdx.x")
+    ty = te.thread_axis("threadIdx.y")
+    ib.scope_attr(tx, "thread_extent", 32)
+    ib.scope_attr(ty, "thread_extent", 32)
+    with ib.if_scope(ib.likely(tx * 32 + ty < n)):  # outer guard, kept after Simplify
+        with ib.if_scope(ib.likely(tx * 32 + ty < n)):  # redundant: same as outer guard
+            A[tx] = C[tx * 32 + ty]
+    body = ib.get()
+    mod = tvm.IRModule.from_expr(
+        tvm.tir.PrimFunc([A, C, n], body))
+    body = tvm.tir.transform.Simplify()(mod)["main"].body
+    assert isinstance(body.body.body, tvm.tir.IfThenElse)  # outer if remains
+    assert not isinstance(body.body.body.then_case, tvm.tir.IfThenElse)  # inner if was removed
+
+
 def test_basic_likely_elimination():
     n = te.size_var('n')
     X = te.placeholder(shape=(n,), name="x")
@@ -110,5 +130,6 @@ def test_complex_likely_elimination():
 if __name__ == "__main__":
     test_stmt_simplify()
     test_thread_extent_simplify()
+    test_if_likely()  # nested likely() guard simplification case
     test_basic_likely_elimination()
     test_complex_likely_elimination()

Reply via email to