This is an automated email from the ASF dual-hosted git repository.
sanirudh pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 799e81036d [ARITH] Simplify nested if_then_else when constant is appearing in then_expr (#16227)
799e81036d is described below
commit 799e81036d43df552efdaecf75c680532ea85144
Author: rutkoor <[email protected]>
AuthorDate: Sun Dec 17 21:38:11 2023 +0530
[ARITH] Simplify nested if_then_else when constant is appearing in then_expr (#16227)
Simplify nested if_then_else when constant is appearing in then_expr
---
src/arith/ir_mutator_with_analyzer.cc | 5 +++--
tests/python/tir-transform/test_tir_transform_simplify.py | 12 ++++++++++++
2 files changed, 15 insertions(+), 2 deletions(-)
diff --git a/src/arith/ir_mutator_with_analyzer.cc b/src/arith/ir_mutator_with_analyzer.cc
index 2ee427beb8..d26ac36676 100644
--- a/src/arith/ir_mutator_with_analyzer.cc
+++ b/src/arith/ir_mutator_with_analyzer.cc
@@ -173,8 +173,9 @@ PrimExpr IRMutatorWithAnalyzer::VisitExpr_(const CallNode* op) {
WithRecordIterPredicate(cond, [&] { true_value =
this->VisitExpr(op->args[1]); });
}
{
- With<ConstraintContext> constraint(analyzer_,
analyzer_->rewrite_simplify(Not(cond)));
- false_value = this->VisitExpr(op->args[2]);
+ PrimExpr not_cond = Not(cond);
+ With<ConstraintContext> constraint(analyzer_, not_cond);
+ WithRecordIterPredicate(not_cond, [&] { false_value =
this->VisitExpr(op->args[2]); });
}
if (is_zero(cond)) {
return false_value;
diff --git a/tests/python/tir-transform/test_tir_transform_simplify.py b/tests/python/tir-transform/test_tir_transform_simplify.py
index c779d92f9c..6bad817c49 100644
--- a/tests/python/tir-transform/test_tir_transform_simplify.py
+++ b/tests/python/tir-transform/test_tir_transform_simplify.py
@@ -1757,5 +1757,17 @@ class TestBufferShapeConstraintWithOffset(BaseBeforeAfter):
A[T.int64(1)] = T.float32(0)
+class TestNestedIfElimination(BaseBeforeAfter):
+ def before(a: T.Buffer((2, 8), "int32"), b: T.Buffer((2, 8), "int32")):
+ for i0, j0 in T.grid(2, 8):
+ b[i0, j0] = T.if_then_else(
+ i0 == 1 and 6 <= j0, 0, T.max(0, T.if_then_else(i0 == 1 and 6
<= j0, 0, a[i0, j0]))
+ )
+
+ def expected(a: T.Buffer((2, 8), "int32"), b: T.Buffer((2, 8), "int32")):
+ for i0, j0 in T.grid(2, 8):
+ b[i0, j0] = T.if_then_else(i0 == 1 and 6 <= j0, 0, T.max(0, a[i0,
j0]))
+
+
if __name__ == "__main__":
tvm.testing.main()