This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new 4c976a6 handle likely in IRMutatorWithAnalyzer (#5665)
4c976a6 is described below
commit 4c976a6acc8516067d340878f3905552d465f613
Author: Shizhi Tang <[email protected]>
AuthorDate: Tue May 26 00:39:33 2020 +0800
handle likely in IRMutatorWithAnalyzer (#5665)
---
src/arith/ir_mutator_with_analyzer.cc | 16 ++++++++++++----
src/arith/rewrite_simplify.cc | 4 +++-
.../python/unittest/test_tir_transform_simplify.py | 21 +++++++++++++++++++++
3 files changed, 36 insertions(+), 5 deletions(-)
diff --git a/src/arith/ir_mutator_with_analyzer.cc b/src/arith/ir_mutator_with_analyzer.cc
index e09ff1d..e6f37f4 100644
--- a/src/arith/ir_mutator_with_analyzer.cc
+++ b/src/arith/ir_mutator_with_analyzer.cc
@@ -55,17 +55,25 @@ Stmt IRMutatorWithAnalyzer::VisitStmt_(const LetStmtNode* op) {
Stmt IRMutatorWithAnalyzer::VisitStmt_(const IfThenElseNode* op) {
PrimExpr condition = this->VisitExpr(op->condition);
+ PrimExpr real_condition = condition;
+ if (auto call = condition.as<CallNode>()) {
+ if (call->is_intrinsic(CallNode::likely)) {
+ real_condition = call->args[0];
+ }
+ }
+
Stmt then_case, else_case;
{
- With<ConstraintContext> ctx(analyzer_, condition);
+ With<ConstraintContext> ctx(analyzer_, real_condition);
then_case = this->VisitStmt(op->then_case);
}
if (op->else_case.defined()) {
- With<ConstraintContext> ctx(analyzer_, analyzer_->rewrite_simplify(NotNode::make(condition)));
+ With<ConstraintContext> ctx(analyzer_,
+ analyzer_->rewrite_simplify(NotNode::make(real_condition)));
else_case = this->VisitStmt(op->else_case);
}
- if (is_one(condition)) return then_case;
- if (is_zero(condition)) {
+ if (is_one(real_condition)) return then_case;
+ if (is_zero(real_condition)) {
if (else_case.defined()) {
return else_case;
}
diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc
index 3b8ccfb..223b2e6 100644
--- a/src/arith/rewrite_simplify.cc
+++ b/src/arith/rewrite_simplify.cc
@@ -211,7 +211,9 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AddNode* op) {
std::function<void()> RewriteSimplifier::Impl::EnterConstraint(const PrimExpr& constraint) {
size_t old_literal_size = literal_constraints_.size();
- literal_constraints_.push_back(constraint);
+ // we will compare the already simplified result with the constraint,
+ // so simplify the constarint as well
+ literal_constraints_.push_back(operator()(constraint));
size_t new_literal_size = literal_constraints_.size();
auto frecover = [old_literal_size, new_literal_size, this]() {
CHECK_EQ(literal_constraints_.size(), new_literal_size);
diff --git a/tests/python/unittest/test_tir_transform_simplify.py b/tests/python/unittest/test_tir_transform_simplify.py
index bf53982..48d0849 100644
--- a/tests/python/unittest/test_tir_transform_simplify.py
+++ b/tests/python/unittest/test_tir_transform_simplify.py
@@ -52,6 +52,26 @@ def test_thread_extent_simplify():
assert isinstance(body.body.body.body, tvm.tir.Store)
+def test_if_likely():
+ ib = tvm.tir.ir_builder.create()
+ A = ib.pointer("float32", name="A")
+ C = ib.pointer("float32", name="C")
+ n = te.size_var("n")
+ tx = te.thread_axis("threadIdx.x")
+ ty = te.thread_axis("threadIdx.y")
+ ib.scope_attr(tx, "thread_extent", 32)
+ ib.scope_attr(ty, "thread_extent", 32)
+ with ib.if_scope(ib.likely(tx * 32 + ty < n)):
+ with ib.if_scope(ib.likely(tx * 32 + ty < n)):
+ A[tx] = C[tx * 32 + ty]
+ body = ib.get()
+ mod = tvm.IRModule.from_expr(
+ tvm.tir.PrimFunc([A, C, n], body))
+ body = tvm.tir.transform.Simplify()(mod)["main"].body
+ assert isinstance(body.body.body, tvm.tir.IfThenElse)
+ assert not isinstance(body.body.body.then_case, tvm.tir.IfThenElse)
+
+
def test_basic_likely_elimination():
n = te.size_var('n')
X = te.placeholder(shape=(n,), name="x")
@@ -110,5 +130,6 @@ def test_complex_likely_elimination():
if __name__ == "__main__":
test_stmt_simplify()
test_thread_extent_simplify()
+ test_if_likely()
test_basic_likely_elimination()
test_complex_likely_elimination()