This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 0ed6fd6e8c [TIR] Handle DeclBuffer in RemoveNoOp (#15096)
0ed6fd6e8c is described below
commit 0ed6fd6e8cc6b476e301ccbf0e9f5d7b1873690e
Author: Eric Lunderberg <[email protected]>
AuthorDate: Fri Jun 16 16:56:49 2023 -0400
[TIR] Handle DeclBuffer in RemoveNoOp (#15096)
When `RemoveNoOp` removes an unused allocation, it should also remove
all `DeclBuffer` nodes that refer to that allocation.  This is a
subset of the changes in https://github.com/apache/tvm/pull/14778,
which are being split out into independent portions.
---
src/tir/analysis/var_use_def_analysis.cc | 63 +++++++++++++++++-----
src/tir/analysis/var_use_def_analysis.h | 13 +++--
src/tir/transforms/remove_no_op.cc | 14 +++++
.../unittest/test_tir_transform_remove_no_op.py | 16 ++++++
4 files changed, 89 insertions(+), 17 deletions(-)
diff --git a/src/tir/analysis/var_use_def_analysis.cc b/src/tir/analysis/var_use_def_analysis.cc
index 9d1105cd15..0d5a4be8ed 100644
--- a/src/tir/analysis/var_use_def_analysis.cc
+++ b/src/tir/analysis/var_use_def_analysis.cc
@@ -39,7 +39,7 @@ void VarUseDefAnalyzer::VisitStmt_(const AttrStmtNode* op) {
// thread_extent can appear multiple times
// use the first appearance as def.
if (!use_count_.count(iv->var.get())) {
- this->HandleDef(iv->var.get());
+ this->HandleDef(iv->var);
}
if (visit_thread_extent_) {
@@ -53,27 +53,32 @@ void VarUseDefAnalyzer::VisitStmt_(const AttrStmtNode* op) {
}
void VarUseDefAnalyzer::VisitStmt_(const LetStmtNode* op) {
- this->HandleDef(op->var.get());
+ this->HandleDef(op->var);
StmtExprVisitor::VisitStmt_(op);
}
void VarUseDefAnalyzer::VisitStmt_(const ForNode* op) {
- this->HandleDef(op->loop_var.get());
+ this->HandleDef(op->loop_var);
+ StmtExprVisitor::VisitStmt_(op);
+}
+
+void VarUseDefAnalyzer::VisitStmt_(const DeclBufferNode* op) {
+ this->HandleDef(op->buffer);
StmtExprVisitor::VisitStmt_(op);
}
void VarUseDefAnalyzer::VisitStmt_(const AllocateNode* op) {
- this->HandleDef(op->buffer_var.get());
+ this->HandleDef(op->buffer_var);
StmtExprVisitor::VisitStmt_(op);
}
void VarUseDefAnalyzer::VisitStmt_(const AllocateConstNode* op) {
- this->HandleDef(op->buffer_var.get());
+ this->HandleDef(op->buffer_var);
StmtExprVisitor::VisitStmt_(op);
}
void VarUseDefAnalyzer::VisitStmt_(const BufferStoreNode* op) {
- VisitBuffer(op->buffer);
+ HandleUse(op->buffer);
StmtExprVisitor::VisitStmt_(op);
}
@@ -90,31 +95,32 @@ void VarUseDefAnalyzer::VisitExpr_(const LetNode* op) {
ICHECK(deep_equal_(it->second->value, op->value))
<< "Let cannot bind the same var to two different values";
} else {
- this->HandleDef(op->var.get());
+ this->HandleDef(op->var);
let_binding_[op->var.get()] = op;
}
this->VisitExpr(op->body);
}
void VarUseDefAnalyzer::VisitExpr_(const VarNode* op) {
- this->HandleUse(op);
+ this->HandleUse(GetRef<Var>(op));
StmtExprVisitor::VisitExpr_(op);
}
void VarUseDefAnalyzer::VisitExpr_(const ReduceNode* op) {
for (const auto& iv : op->axis) {
- this->HandleDef(iv->var.get());
+ this->HandleDef(iv->var);
}
StmtExprVisitor::VisitExpr_(op);
}
void VarUseDefAnalyzer::VisitExpr_(const BufferLoadNode* op) {
- VisitBuffer(op->buffer);
+ HandleUse(op->buffer);
StmtExprVisitor::VisitExpr_(op);
}
-void VarUseDefAnalyzer::VisitBuffer(Buffer buffer) {
- this->HandleUse(buffer->data.get());
+void VarUseDefAnalyzer::VisitBuffer(const Buffer& buffer) {
+ this->HandleUse(buffer->data);
+
auto visit_arr = [&](Array<PrimExpr> arr) {
for (const auto& element : arr) {
this->VisitExpr(element);
@@ -125,7 +131,8 @@ void VarUseDefAnalyzer::VisitBuffer(Buffer buffer) {
visit_arr(buffer->strides);
}
-void VarUseDefAnalyzer::HandleDef(const VarNode* v) {
+void VarUseDefAnalyzer::HandleDef(const Var& var) {
+ auto v = var.get();
ICHECK(!def_count_.count(v)) << "variable " << v->name_hint
<< " has already been defined, the Stmt is not
SSA";
ICHECK(!use_count_.count(v)) << "variable " << v->name_hint
@@ -134,7 +141,8 @@ void VarUseDefAnalyzer::HandleDef(const VarNode* v) {
def_count_[v] = 1;
}
-void VarUseDefAnalyzer::HandleUse(const VarNode* v) {
+void VarUseDefAnalyzer::HandleUse(const Var& var) {
+ auto v = var.get();
auto it = use_count_.find(v);
if (it != use_count_.end()) {
if (it->second >= 0) {
@@ -146,6 +154,33 @@ void VarUseDefAnalyzer::HandleUse(const VarNode* v) {
}
}
+void VarUseDefAnalyzer::HandleDef(const Buffer& buf) {  // mark `buf` defined (SSA-checked)
+  auto ptr = buf.get();
+  ICHECK(!buffer_def_count_.count(ptr))
+      << "buffer " << ptr->name << " has already been defined, the Stmt is not SSA";
+  ICHECK(!buffer_use_count_.count(ptr))
+      << "buffer " << ptr->name << " has been used before definition!";
+  buffer_use_count_[ptr] = 0;  // defined, no uses counted yet
+  buffer_def_count_[ptr] = 1;
+
+  VisitBuffer(buf);  // buffer fields (data var, strides, ...) are recorded as uses
+}
+
+void VarUseDefAnalyzer::HandleUse(const Buffer& buf) {  // count one use of `buf`
+  auto ptr = buf.get();
+  auto it = buffer_use_count_.find(ptr);
+  if (it != buffer_use_count_.end()) {
+    if (it->second >= 0) {
+      ++it->second;
+    }
+  } else {
+    undefined_buffers_.push_back(buf);  // first use with no prior definition
+    buffer_use_count_[ptr] = -1;  // -1 marks used-but-undefined; never incremented
+  }
+
+  VisitBuffer(buf);
+}
+
Array<Var> UndefinedVars(const Stmt& stmt, const Array<Var>& args) {
VarUseDefAnalyzer m(args);
m(stmt);
diff --git a/src/tir/analysis/var_use_def_analysis.h b/src/tir/analysis/var_use_def_analysis.h
index 5d0ceed13c..64985b11a9 100644
--- a/src/tir/analysis/var_use_def_analysis.h
+++ b/src/tir/analysis/var_use_def_analysis.h
@@ -45,9 +45,12 @@ class VarUseDefAnalyzer : public StmtExprVisitor {
// be accessible to the users.
bool visit_thread_extent_{true};
Array<Var> undefined_;
+ Array<Buffer> undefined_buffers_;
std::unordered_map<const VarNode*, int> use_count_;
std::unordered_map<const VarNode*, int> def_count_;
+ std::unordered_map<const BufferNode*, int> buffer_use_count_;
+ std::unordered_map<const BufferNode*, int> buffer_def_count_;
private:
ExprDeepEqual deep_equal_;
@@ -58,6 +61,8 @@ class VarUseDefAnalyzer : public StmtExprVisitor {
void VisitStmt_(const ForNode* op) final;
+ void VisitStmt_(const DeclBufferNode* op) final;
+
void VisitStmt_(const AllocateNode* op) final;
void VisitStmt_(const AllocateConstNode* op) final;
@@ -72,11 +77,13 @@ class VarUseDefAnalyzer : public StmtExprVisitor {
void VisitExpr_(const BufferLoadNode* op) final;
- void HandleDef(const VarNode* v);
+ void HandleDef(const Var& v);
+ void HandleUse(const Var& v);
- void HandleUse(const VarNode* v);
+ void HandleDef(const Buffer& buf);
+ void HandleUse(const Buffer& buf);
- void VisitBuffer(Buffer buffer);
+ void VisitBuffer(const Buffer& buffer);
};
} // namespace tir
diff --git a/src/tir/transforms/remove_no_op.cc b/src/tir/transforms/remove_no_op.cc
index 7951a2befa..bc606aa0b7 100644
--- a/src/tir/transforms/remove_no_op.cc
+++ b/src/tir/transforms/remove_no_op.cc
@@ -35,6 +35,7 @@
#include "../../arith/const_fold.h"
#include "../../arith/ir_mutator_with_analyzer.h"
#include "../analysis/control_flow_graph.h"
+#include "../analysis/var_use_def_analysis.h"
#include "ir_utils.h"
namespace tvm {
@@ -239,6 +240,19 @@ class NoOpRemover : public arith::IRMutatorWithAnalyzer {
return std::move(store);
}
+  Stmt VisitStmt_(const DeclBufferNode* op) final {
+    DeclBuffer decl = Downcast<DeclBuffer>(Parent::VisitStmt_(op));
+
+    // Keep the declaration only if the (already simplified) body still refers to the buffer.
+    VarUseDefAnalyzer usage({});
+    usage(decl->body);
+    bool buffer_is_used = usage.buffer_use_count_.count(decl->buffer.get()) != 0;
+    if (!buffer_is_used) {
+      return decl->body;
+    }
+    return std::move(decl);
+  }
+
private:
bool ArrayValueEqual(const Array<PrimExpr>& a, const Array<PrimExpr>& b) {
if (a.size() != b.size()) {
diff --git a/tests/python/unittest/test_tir_transform_remove_no_op.py b/tests/python/unittest/test_tir_transform_remove_no_op.py
index 00452bb5bd..4cf746a7c2 100644
--- a/tests/python/unittest/test_tir_transform_remove_no_op.py
+++ b/tests/python/unittest/test_tir_transform_remove_no_op.py
@@ -552,6 +552,22 @@ class TestRemoveEmptyTemporary(BaseBeforeAfter):
T.evaluate(0)
+class TestRemoveEmptyTemporaryWithDeclBuffer(BaseBeforeAfter):
+    """Unused DeclBuffer nodes are removed together with their Allocate
+
+    When an allocation that is never read or written is dropped, every
+    DeclBuffer that aliases that allocation must be dropped as well.
+    """
+
+    def before():
+        buf = T.decl_buffer([4, 4], "int32", scope="local")
+        buf_flat = T.decl_buffer(16, "int32", scope="local", data=buf.data)
+        T.evaluate(0)
+
+    def expected():
+        T.evaluate(0)
+
+
@pytest.mark.xfail(reason="Not implemented yet")
class TestRemoveUnusedTemporary(BaseBeforeAfter):
"""An unused allocation is a no-op."""