This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new a75f110c0e [TIR]Fix the crash of the pass RemoveNoOp (#13808)
a75f110c0e is described below

commit a75f110c0ea4234ff7144d18efff3d01bfe0c7d5
Author: lightzhan <[email protected]>
AuthorDate: Thu Jan 19 13:25:05 2023 +0800

    [TIR]Fix the crash of the pass RemoveNoOp (#13808)
    
    Fix the crash of the pass RemoveNoOp.
    
    Co-authored-by: lightzhan-intellif <[email protected]>
---
 src/tir/transforms/remove_no_op.cc                       |  5 +++++
 tests/python/unittest/test_tir_transform_remove_no_op.py | 14 ++++++++++++++
 2 files changed, 19 insertions(+)

diff --git a/src/tir/transforms/remove_no_op.cc b/src/tir/transforms/remove_no_op.cc
index 430c1f41bf..d35cf8b8d6 100644
--- a/src/tir/transforms/remove_no_op.cc
+++ b/src/tir/transforms/remove_no_op.cc
@@ -119,6 +119,11 @@ class NoOpRemover : public arith::IRMutatorWithAnalyzer {
   Stmt VisitStmt_(const IfThenElseNode* op) final {
     Stmt stmt = Parent::VisitStmt_(op);
     op = stmt.as<IfThenElseNode>();
+    // Sometimes the condition can be statically determined,
+    // in which the type of the `stmt` will not be IfThenElseNode.
+    if (!op) {
+      return stmt;
+    }
     if (op->else_case) {
       bool no_op_else = is_no_op(op->else_case.value());
       bool no_op_then = is_no_op(op->then_case);
diff --git a/tests/python/unittest/test_tir_transform_remove_no_op.py b/tests/python/unittest/test_tir_transform_remove_no_op.py
index ce37329b7e..06d9289aa7 100644
--- a/tests/python/unittest/test_tir_transform_remove_no_op.py
+++ b/tests/python/unittest/test_tir_transform_remove_no_op.py
@@ -603,5 +603,19 @@ class TestRemoveWriteIntoTemporary(BaseBeforeAfter):
             C[0] = C[0] + B[i]
 
 
+class TestCertainConditon(BaseBeforeAfter):
+    """The condition of the If-Else node is statically known (`True`).
+    RemoveNoOp used to crash with a `Segmentation fault` on this input."""
+
+    def before():
+        if True:
+            T.evaluate(0)
+        else:
+            T.evaluate(0)
+
+    def expected():
+        T.evaluate(0)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to