This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 0ed6fd6e8c [TIR] Handle DeclBuffer in RemoveNoOp (#15096)
0ed6fd6e8c is described below

commit 0ed6fd6e8cc6b476e301ccbf0e9f5d7b1873690e
Author: Eric Lunderberg <[email protected]>
AuthorDate: Fri Jun 16 16:56:49 2023 -0400

    [TIR] Handle DeclBuffer in RemoveNoOp (#15096)
    
    When `RemoveNoOp` removes an unused allocation, it should also remove
    all `DeclBuffer` nodes that refer to that allocation.  This change is
    one independent portion split out from
    https://github.com/apache/tvm/pull/14778.
---
 src/tir/analysis/var_use_def_analysis.cc           | 63 +++++++++++++++++-----
 src/tir/analysis/var_use_def_analysis.h            | 13 +++--
 src/tir/transforms/remove_no_op.cc                 | 14 +++++
 .../unittest/test_tir_transform_remove_no_op.py    | 16 ++++++
 4 files changed, 89 insertions(+), 17 deletions(-)

diff --git a/src/tir/analysis/var_use_def_analysis.cc b/src/tir/analysis/var_use_def_analysis.cc
index 9d1105cd15..0d5a4be8ed 100644
--- a/src/tir/analysis/var_use_def_analysis.cc
+++ b/src/tir/analysis/var_use_def_analysis.cc
@@ -39,7 +39,7 @@ void VarUseDefAnalyzer::VisitStmt_(const AttrStmtNode* op) {
     // thread_extent can appear multiple times
     // use the first appearance as def.
     if (!use_count_.count(iv->var.get())) {
-      this->HandleDef(iv->var.get());
+      this->HandleDef(iv->var);
     }
 
     if (visit_thread_extent_) {
@@ -53,27 +53,32 @@ void VarUseDefAnalyzer::VisitStmt_(const AttrStmtNode* op) {
 }
 
 void VarUseDefAnalyzer::VisitStmt_(const LetStmtNode* op) {
-  this->HandleDef(op->var.get());
+  this->HandleDef(op->var);
   StmtExprVisitor::VisitStmt_(op);
 }
 
 void VarUseDefAnalyzer::VisitStmt_(const ForNode* op) {
-  this->HandleDef(op->loop_var.get());
+  this->HandleDef(op->loop_var);
+  StmtExprVisitor::VisitStmt_(op);
+}
+
+void VarUseDefAnalyzer::VisitStmt_(const DeclBufferNode* op) {
+  this->HandleDef(op->buffer);
   StmtExprVisitor::VisitStmt_(op);
 }
 
 void VarUseDefAnalyzer::VisitStmt_(const AllocateNode* op) {
-  this->HandleDef(op->buffer_var.get());
+  this->HandleDef(op->buffer_var);
   StmtExprVisitor::VisitStmt_(op);
 }
 
 void VarUseDefAnalyzer::VisitStmt_(const AllocateConstNode* op) {
-  this->HandleDef(op->buffer_var.get());
+  this->HandleDef(op->buffer_var);
   StmtExprVisitor::VisitStmt_(op);
 }
 
 void VarUseDefAnalyzer::VisitStmt_(const BufferStoreNode* op) {
-  VisitBuffer(op->buffer);
+  HandleUse(op->buffer);
   StmtExprVisitor::VisitStmt_(op);
 }
 
@@ -90,31 +95,32 @@ void VarUseDefAnalyzer::VisitExpr_(const LetNode* op) {
     ICHECK(deep_equal_(it->second->value, op->value))
         << "Let cannot bind the same var to two different values";
   } else {
-    this->HandleDef(op->var.get());
+    this->HandleDef(op->var);
     let_binding_[op->var.get()] = op;
   }
   this->VisitExpr(op->body);
 }
 
 void VarUseDefAnalyzer::VisitExpr_(const VarNode* op) {
-  this->HandleUse(op);
+  this->HandleUse(GetRef<Var>(op));
   StmtExprVisitor::VisitExpr_(op);
 }
 
 void VarUseDefAnalyzer::VisitExpr_(const ReduceNode* op) {
   for (const auto& iv : op->axis) {
-    this->HandleDef(iv->var.get());
+    this->HandleDef(iv->var);
   }
   StmtExprVisitor::VisitExpr_(op);
 }
 
 void VarUseDefAnalyzer::VisitExpr_(const BufferLoadNode* op) {
-  VisitBuffer(op->buffer);
+  HandleUse(op->buffer);
   StmtExprVisitor::VisitExpr_(op);
 }
 
-void VarUseDefAnalyzer::VisitBuffer(Buffer buffer) {
-  this->HandleUse(buffer->data.get());
+void VarUseDefAnalyzer::VisitBuffer(const Buffer& buffer) {
+  this->HandleUse(buffer->data);
+
   auto visit_arr = [&](Array<PrimExpr> arr) {
     for (const auto& element : arr) {
       this->VisitExpr(element);
@@ -125,7 +131,8 @@ void VarUseDefAnalyzer::VisitBuffer(Buffer buffer) {
   visit_arr(buffer->strides);
 }
 
-void VarUseDefAnalyzer::HandleDef(const VarNode* v) {
+void VarUseDefAnalyzer::HandleDef(const Var& var) {
+  auto v = var.get();
   ICHECK(!def_count_.count(v)) << "variable " << v->name_hint
                                << " has already been defined, the Stmt is not SSA";
   ICHECK(!use_count_.count(v)) << "variable " << v->name_hint
@@ -134,7 +141,8 @@ void VarUseDefAnalyzer::HandleDef(const VarNode* v) {
   def_count_[v] = 1;
 }
 
-void VarUseDefAnalyzer::HandleUse(const VarNode* v) {
+void VarUseDefAnalyzer::HandleUse(const Var& var) {
+  auto v = var.get();
   auto it = use_count_.find(v);
   if (it != use_count_.end()) {
     if (it->second >= 0) {
@@ -146,6 +154,33 @@ void VarUseDefAnalyzer::HandleUse(const VarNode* v) {
   }
 }
 
+void VarUseDefAnalyzer::HandleDef(const Buffer& buf) {
+  auto ptr = buf.get();
+  ICHECK(!buffer_def_count_.count(ptr))
+      << "buffer " << ptr->name << " has already been defined, the Stmt is not 
SSA";
+  ICHECK(!buffer_use_count_.count(ptr))
+      << "buffer " << ptr->name << " has been used before definition!";
+  buffer_use_count_[ptr] = 0;
+  buffer_def_count_[ptr] = 1;
+
+  VisitBuffer(buf);
+}
+
+void VarUseDefAnalyzer::HandleUse(const Buffer& buf) {
+  auto ptr = buf.get();
+  auto it = buffer_use_count_.find(ptr);
+  if (it != buffer_use_count_.end()) {
+    if (it->second >= 0) {
+      ++it->second;
+    }
+  } else {
+    undefined_buffers_.push_back(GetRef<Buffer>(ptr));
+    buffer_use_count_[ptr] = -1;
+  }
+
+  VisitBuffer(buf);
+}
+
 Array<Var> UndefinedVars(const Stmt& stmt, const Array<Var>& args) {
   VarUseDefAnalyzer m(args);
   m(stmt);
diff --git a/src/tir/analysis/var_use_def_analysis.h b/src/tir/analysis/var_use_def_analysis.h
index 5d0ceed13c..64985b11a9 100644
--- a/src/tir/analysis/var_use_def_analysis.h
+++ b/src/tir/analysis/var_use_def_analysis.h
@@ -45,9 +45,12 @@ class VarUseDefAnalyzer : public StmtExprVisitor {
   // be accessible to the users.
   bool visit_thread_extent_{true};
   Array<Var> undefined_;
+  Array<Buffer> undefined_buffers_;
 
   std::unordered_map<const VarNode*, int> use_count_;
   std::unordered_map<const VarNode*, int> def_count_;
+  std::unordered_map<const BufferNode*, int> buffer_use_count_;
+  std::unordered_map<const BufferNode*, int> buffer_def_count_;
 
  private:
   ExprDeepEqual deep_equal_;
@@ -58,6 +61,8 @@ class VarUseDefAnalyzer : public StmtExprVisitor {
 
   void VisitStmt_(const ForNode* op) final;
 
+  void VisitStmt_(const DeclBufferNode* op) final;
+
   void VisitStmt_(const AllocateNode* op) final;
 
   void VisitStmt_(const AllocateConstNode* op) final;
@@ -72,11 +77,13 @@ class VarUseDefAnalyzer : public StmtExprVisitor {
 
   void VisitExpr_(const BufferLoadNode* op) final;
 
-  void HandleDef(const VarNode* v);
+  void HandleDef(const Var& v);
+  void HandleUse(const Var& v);
 
-  void HandleUse(const VarNode* v);
+  void HandleDef(const Buffer& buf);
+  void HandleUse(const Buffer& buf);
 
-  void VisitBuffer(Buffer buffer);
+  void VisitBuffer(const Buffer& buffer);
 };
 
 }  // namespace tir
diff --git a/src/tir/transforms/remove_no_op.cc b/src/tir/transforms/remove_no_op.cc
index 7951a2befa..bc606aa0b7 100644
--- a/src/tir/transforms/remove_no_op.cc
+++ b/src/tir/transforms/remove_no_op.cc
@@ -35,6 +35,7 @@
 #include "../../arith/const_fold.h"
 #include "../../arith/ir_mutator_with_analyzer.h"
 #include "../analysis/control_flow_graph.h"
+#include "../analysis/var_use_def_analysis.h"
 #include "ir_utils.h"
 
 namespace tvm {
@@ -239,6 +240,19 @@ class NoOpRemover : public arith::IRMutatorWithAnalyzer {
     return std::move(store);
   }
 
+  Stmt VisitStmt_(const DeclBufferNode* op) final {
+    auto node = Downcast<DeclBuffer>(Parent::VisitStmt_(op));
+
+    VarUseDefAnalyzer var_use({});
+    var_use(node->body);
+
+    if (var_use.buffer_use_count_.count(node->buffer.get())) {
+      return std::move(node);
+    } else {
+      return node->body;
+    }
+  }
+
  private:
   bool ArrayValueEqual(const Array<PrimExpr>& a, const Array<PrimExpr>& b) {
     if (a.size() != b.size()) {
diff --git a/tests/python/unittest/test_tir_transform_remove_no_op.py b/tests/python/unittest/test_tir_transform_remove_no_op.py
index 00452bb5bd..4cf746a7c2 100644
--- a/tests/python/unittest/test_tir_transform_remove_no_op.py
+++ b/tests/python/unittest/test_tir_transform_remove_no_op.py
@@ -552,6 +552,22 @@ class TestRemoveEmptyTemporary(BaseBeforeAfter):
         T.evaluate(0)
 
 
+class TestRemoveEmptyTemporaryWithDeclBuffer(BaseBeforeAfter):
+    """Remove DeclBuffer alongside Allocate
+
+    If an unused allocation is removed, any DeclBuffer instances that
+    refer to it should also be removed.
+    """
+
+    def before():
+        A = T.decl_buffer([4, 4], "int32", scope="local")
+        A_flat = T.decl_buffer(16, "int32", scope="local", data=A.data)
+        T.evaluate(0)
+
+    def expected():
+        T.evaluate(0)
+
+
 @pytest.mark.xfail(reason="Not implemented yet")
 class TestRemoveUnusedTemporary(BaseBeforeAfter):
     """An unused allocation is a no-op."""

Reply via email to