This is an automated email from the ASF dual-hosted git repository.

sanirudh pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 799e81036d [ARITH] Simplify nested if_then_else when constant is appearing in then_expr (#16227)
799e81036d is described below

commit 799e81036d43df552efdaecf75c680532ea85144
Author: rutkoor <[email protected]>
AuthorDate: Sun Dec 17 21:38:11 2023 +0530

    [ARITH] Simplify nested if_then_else when constant is appearing in then_expr (#16227)
    
    Simplify nested if_then_else when constant is appearing in then_expr
---
 src/arith/ir_mutator_with_analyzer.cc                     |  5 +++--
 tests/python/tir-transform/test_tir_transform_simplify.py | 12 ++++++++++++
 2 files changed, 15 insertions(+), 2 deletions(-)

diff --git a/src/arith/ir_mutator_with_analyzer.cc b/src/arith/ir_mutator_with_analyzer.cc
index 2ee427beb8..d26ac36676 100644
--- a/src/arith/ir_mutator_with_analyzer.cc
+++ b/src/arith/ir_mutator_with_analyzer.cc
@@ -173,8 +173,9 @@ PrimExpr IRMutatorWithAnalyzer::VisitExpr_(const CallNode* op) {
       WithRecordIterPredicate(cond, [&] { true_value = 
this->VisitExpr(op->args[1]); });
     }
     {
-      With<ConstraintContext> constraint(analyzer_, 
analyzer_->rewrite_simplify(Not(cond)));
-      false_value = this->VisitExpr(op->args[2]);
+      PrimExpr not_cond = Not(cond);
+      With<ConstraintContext> constraint(analyzer_, not_cond);
+      WithRecordIterPredicate(not_cond, [&] { false_value = 
this->VisitExpr(op->args[2]); });
     }
     if (is_zero(cond)) {
       return false_value;
diff --git a/tests/python/tir-transform/test_tir_transform_simplify.py b/tests/python/tir-transform/test_tir_transform_simplify.py
index c779d92f9c..6bad817c49 100644
--- a/tests/python/tir-transform/test_tir_transform_simplify.py
+++ b/tests/python/tir-transform/test_tir_transform_simplify.py
@@ -1757,5 +1757,17 @@ class TestBufferShapeConstraintWithOffset(BaseBeforeAfter):
         A[T.int64(1)] = T.float32(0)
 
 
+class TestNestedIfElimination(BaseBeforeAfter):
+    def before(a: T.Buffer((2, 8), "int32"), b: T.Buffer((2, 8), "int32")):
+        for i0, j0 in T.grid(2, 8):
+            b[i0, j0] = T.if_then_else(
+                i0 == 1 and 6 <= j0, 0, T.max(0, T.if_then_else(i0 == 1 and 6 
<= j0, 0, a[i0, j0]))
+            )
+
+    def expected(a: T.Buffer((2, 8), "int32"), b: T.Buffer((2, 8), "int32")):
+        for i0, j0 in T.grid(2, 8):
+            b[i0, j0] = T.if_then_else(i0 == 1 and 6 <= j0, 0, T.max(0, a[i0, 
j0]))
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to