This is an automated email from the ASF dual-hosted git repository. wuwei pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push: new 52df2e8414 [TIR] Additional Stmt/Expr simplication rules (#11373) 52df2e8414 is described below commit 52df2e84141b34cda2b1e723c22d38b22796d6a7 Author: Eric Lunderberg <lunderb...@users.noreply.github.com> AuthorDate: Thu May 26 11:26:29 2022 -0500 [TIR] Additional Stmt/Expr simplication rules (#11373) * [TIR] Additional Stmt/Expr simplication rules - Enabled simplification of `A[i] = A[i] + 0` into no-op. This was a bug introduced in https://github.com/apache/tvm/pull/9727, which applied this rewrite only to `A[i] = A[i]`, and not to statements which simplify to `A[i] = A[i]`. Regression test added to prevent reoccurrence of this bug. - Enabled simplification of `x - x` to zero for floating point types. Previously, this simplification was applied only for data types that could be used as buffer indices. * Updated to maintain separate int/float simplification paths * Updated to use tvm.testing.main * Remove duplicate rewrite rules --- src/arith/rewrite_simplify.cc | 9 +++++ src/tir/transforms/simplify.cc | 12 +++--- .../python/unittest/test_arith_rewrite_simplify.py | 8 ++++ .../python/unittest/test_tir_transform_simplify.py | 45 +++++++++++++++++++--- 4 files changed, 63 insertions(+), 11 deletions(-) diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index 4d8b6ff769..dab78c77a0 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -411,6 +411,15 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const SubNode* op) { TVM_TRY_RECURSIVE_REWRITE((x + c1) - y, (x - y) + c1); TVM_TRY_RECURSIVE_REWRITE(x - (y - z), (x + z) - y); TVM_TRY_RECURSIVE_REWRITE(x - y * c1, x + y * (0 - c1)); + } else if (op->dtype.is_float()) { + // Cancellation rules. Deliberately off of the integer path, to + // avoid introducing checks on the side effects for the fast path. 
+ TVM_TRY_REWRITE_IF(x - x, ZeroWithTypeLike(x), + SideEffect(x.Eval()) <= CallEffectKind::kReadState); + TVM_TRY_REWRITE_IF((x + y) - y, x, SideEffect(y.Eval()) <= CallEffectKind::kReadState); + TVM_TRY_REWRITE_IF((x + y) - x, y, SideEffect(x.Eval()) <= CallEffectKind::kReadState); + TVM_TRY_REWRITE_IF(x - (y + x), 0 - y, SideEffect(x.Eval()) <= CallEffectKind::kReadState); + TVM_TRY_REWRITE_IF(x - (x + y), 0 - y, SideEffect(x.Eval()) <= CallEffectKind::kReadState); } // condition rules. diff --git a/src/tir/transforms/simplify.cc b/src/tir/transforms/simplify.cc index 7d4fac8d7b..85f405be44 100644 --- a/src/tir/transforms/simplify.cc +++ b/src/tir/transforms/simplify.cc @@ -90,12 +90,12 @@ class StmtSimplifier : public IRMutatorWithAnalyzer { // eliminate useless stores Stmt VisitStmt_(const BufferStoreNode* op) final { BufferStore store = Downcast<BufferStore>(Parent::VisitStmt_(op)); - if (const BufferLoadNode* load = op->value.as<BufferLoadNode>()) { - if (load->buffer->data.same_as(op->buffer->data) && - ArrayDeepEqual(load->indices, op->indices) && - tir::ExprDeepEqual()(load->buffer->elem_offset, op->buffer->elem_offset) && - ArrayDeepEqual(load->buffer->shape, op->buffer->shape) && - ArrayDeepEqual(load->buffer->strides, op->buffer->strides)) { + if (const BufferLoadNode* load = store->value.as<BufferLoadNode>()) { + if (load->buffer->data.same_as(store->buffer->data) && + ArrayDeepEqual(load->indices, store->indices) && + tir::ExprDeepEqual()(load->buffer->elem_offset, store->buffer->elem_offset) && + ArrayDeepEqual(load->buffer->shape, store->buffer->shape) && + ArrayDeepEqual(load->buffer->strides, store->buffer->strides)) { return Evaluate(0); } } diff --git a/tests/python/unittest/test_arith_rewrite_simplify.py b/tests/python/unittest/test_arith_rewrite_simplify.py index 855635b3f9..8d26710f40 100644 --- a/tests/python/unittest/test_arith_rewrite_simplify.py +++ b/tests/python/unittest/test_arith_rewrite_simplify.py @@ -972,5 +972,13 @@ def 
test_div_zero_simplify(): assert "division by zero" in str(cm.execption) +def test_sub_bufferload(): + ck = RewriteChecker() + buf = tvm.tir.decl_buffer([1], dtype="float32") + load = tvm.tir.BufferLoad(buf, [0]) + expr = load - load + ck.verify(expr, 0.0) + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/python/unittest/test_tir_transform_simplify.py b/tests/python/unittest/test_tir_transform_simplify.py index 824bef4f32..01cc41c7ce 100644 --- a/tests/python/unittest/test_tir_transform_simplify.py +++ b/tests/python/unittest/test_tir_transform_simplify.py @@ -15,7 +15,10 @@ # specific language governing permissions and limitations # under the License. import tvm +import tvm.testing + from tvm import te +from tvm.script import tir as T def test_stmt_simplify(): @@ -133,9 +136,41 @@ def test_complex_likely_elimination(): assert "if" not in str(stmt) +def test_load_store_noop(): + """Store of a value that was just read from the same location is a no-op.""" + + @T.prim_func + def before(A: T.Buffer[(1,), "float32"]): + A[0] = A[0] + + @T.prim_func + def expected(A: T.Buffer[(1,), "float32"]): + T.evaluate(0) + + after = tvm.tir.transform.Simplify()(tvm.IRModule.from_expr(before))["main"] + tvm.ir.assert_structural_equal(after, expected) + + +def test_load_store_noop_after_simplify(): + """As test_load_store_noop, but requiring simplification to identify. + + Previously, a bug caused the self-assignment of a buffer to + be checked based on the pre-simplification assignment, not the + post-simplification. This test is to identify any similar + regression.
+ """ + + @T.prim_func + def before(A: T.Buffer[(1,), "float32"]): + A[0] = A[0] + (5.0 - 5.0) + + @T.prim_func + def expected(A: T.Buffer[(1,), "float32"]): + T.evaluate(0) + + after = tvm.tir.transform.Simplify()(tvm.IRModule.from_expr(before))["main"] + tvm.ir.assert_structural_equal(after, expected) + + if __name__ == "__main__": - test_stmt_simplify() - test_thread_extent_simplify() - test_if_likely() - test_basic_likely_elimination() - test_complex_likely_elimination() + tvm.testing.main()