This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new a75f110c0e [TIR]Fix the crash of the pass RemoveNoOp (#13808)
a75f110c0e is described below
commit a75f110c0ea4234ff7144d18efff3d01bfe0c7d5
Author: lightzhan <[email protected]>
AuthorDate: Thu Jan 19 13:25:05 2023 +0800
[TIR]Fix the crash of the pass RemoveNoOp (#13808)
Fix a crash in the RemoveNoOp pass when the condition of an IfThenElse statement can be statically determined.
Co-authored-by: lightzhan-intellif <[email protected]>
---
src/tir/transforms/remove_no_op.cc | 5 +++++
tests/python/unittest/test_tir_transform_remove_no_op.py | 14 ++++++++++++++
2 files changed, 19 insertions(+)
diff --git a/src/tir/transforms/remove_no_op.cc
b/src/tir/transforms/remove_no_op.cc
index 430c1f41bf..d35cf8b8d6 100644
--- a/src/tir/transforms/remove_no_op.cc
+++ b/src/tir/transforms/remove_no_op.cc
@@ -119,6 +119,11 @@ class NoOpRemover : public arith::IRMutatorWithAnalyzer {
Stmt VisitStmt_(const IfThenElseNode* op) final {
Stmt stmt = Parent::VisitStmt_(op);
op = stmt.as<IfThenElseNode>();
+ // Sometimes the condition can be statically determined, in which
+ // case `stmt` is no longer an IfThenElseNode and `op` is null.
+ if (!op) {
+ return stmt;
+ }
if (op->else_case) {
bool no_op_else = is_no_op(op->else_case.value());
bool no_op_then = is_no_op(op->then_case);
diff --git a/tests/python/unittest/test_tir_transform_remove_no_op.py
b/tests/python/unittest/test_tir_transform_remove_no_op.py
index ce37329b7e..06d9289aa7 100644
--- a/tests/python/unittest/test_tir_transform_remove_no_op.py
+++ b/tests/python/unittest/test_tir_transform_remove_no_op.py
@@ -603,5 +603,19 @@ class TestRemoveWriteIntoTemporary(BaseBeforeAfter):
C[0] = C[0] + B[i]
+class TestCertainConditon(BaseBeforeAfter):
+ """The condition of the If-Else node is statically known (`if True`).
+ Before the fix, RemoveNoOp dereferenced a null IfThenElseNode here and crashed with a segmentation fault."""
+
+ def before():  # condition is the literal True, so the If-Else can be resolved statically
+ if True:
+ T.evaluate(0)
+ else:
+ T.evaluate(0)
+
+ def expected():  # the whole If-Else is expected to reduce to a single T.evaluate(0)
+ T.evaluate(0)
+
+
if __name__ == "__main__":
tvm.testing.main()