This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 7c4c913599 [Ethos-U][TIR] Handle DeclBuffer in Ethos-U inputs (#15098)
7c4c913599 is described below

commit 7c4c913599af13250c38bfc20ba27860be985712
Author: Eric Lunderberg <[email protected]>
AuthorDate: Fri Jun 16 18:17:44 2023 -0400

    [Ethos-U][TIR] Handle DeclBuffer in Ethos-U inputs (#15098)
    
    This is a subset of changes, being split out from
    https://github.com/apache/tvm/pull/14778 into independent portions.
---
 .../tvm/relay/backend/contrib/ethosu/tir/passes.py | 12 +++-
 src/tir/contrib/ethosu/passes.cc                   | 67 +++++++++++++++++-----
 2 files changed, 65 insertions(+), 14 deletions(-)

diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py
index aa80c89d38..9636f20447 100644
--- a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py
+++ b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py
@@ -193,13 +193,23 @@ def ReplaceOperators():
                 )
         return None
 
+    def _remove_buffer_decl(stmt):
+        if isinstance(stmt, tvm.tir.DeclBuffer):
+            if stmt.buffer.data in replace_output_pointer:
+                return stmt.body  # Drop the DeclBuffer, keep its body; else returns None
+
     def _post_transform(stmt):
         # Replace operators with call_externs
         result = _replace_operator(stmt)
         # Remove operators that don't need compiling
         result = result or _remove_no_compile(stmt)
         # Replace necessary pointers that were removed in the previous step
-        return result or _replace_pointers(stmt)
+        result = result or _replace_pointers(stmt)
+        # Remove DeclBuffer nodes, since only the tir.Var data pointer is
+        # still used, and not the tir.Buffer
+        result = result or _remove_buffer_decl(stmt)
+
+        return result
 
     def _ftransform(f, mod, ctx):
         tvm.tir.stmt_functor.post_order_visit(f.body, _find_pointer_to_extent)
diff --git a/src/tir/contrib/ethosu/passes.cc b/src/tir/contrib/ethosu/passes.cc
index 5968febe9a..fba506fba1 100644
--- a/src/tir/contrib/ethosu/passes.cc
+++ b/src/tir/contrib/ethosu/passes.cc
@@ -32,6 +32,8 @@
 #include <unordered_map>
 #include <unordered_set>
 
+#include "../../transforms/ir_utils.h"
+
 namespace tvm {
 
 /*!
@@ -54,8 +56,47 @@ namespace ethosu {
 
 namespace {
 
+struct FlattenUnwrapResult {
+  std::vector<Stmt> seq;          // flattened statements, in traversal order
+  std::vector<Stmt> rewrap_nest;  // stripped DeclBuffer nodes, with Evaluate(0) bodies
+};
+
+/*! \brief Utility function to flatten SeqStmt, unwrapping DeclBuffer
+ *
+ * A DeclBuffer may internally contain SeqStmt nodes that we want to
+ * flatten.  Unlike SeqStmt::Flatten, this unwraps DeclBuffer nodes and
+ * drops no-op Evaluate(IntImm) statements.  AttrStmt is not unwrapped.
+ *
+ * \param stmt The tir::Stmt to be flattened.
+ * \return The flattened statements, and the stripped DeclBuffer nest
+ */
+FlattenUnwrapResult FlattenUnwrap(const Stmt& stmt) {
+  FlattenUnwrapResult result;
+  std::function<void(const Stmt&)> flatten_unwrap = [&](const Stmt& node) {
+    if (auto* ptr = node.as<DeclBufferNode>()) {
+      // Keep the declaration with a placeholder body, for re-wrapping via MergeNest.
+      result.rewrap_nest.push_back(DeclBuffer(ptr->buffer, Evaluate(0)));
+      flatten_unwrap(ptr->body);
+    } else if (auto* ptr = node.as<SeqStmtNode>()) {
+      for (const auto& sub_stmt : ptr->seq) {
+        flatten_unwrap(sub_stmt);
+      }
+    } else if (auto* ptr = node.as<EvaluateNode>(); ptr && ptr->value.as<IntImmNode>()) {
+      // Skip no-op Evaluate(<integer constant>) statements
+    } else {
+      result.seq.push_back(node);
+    }
+  };
+  flatten_unwrap(stmt);
+  return result;
+}
+
 /*! Returns the arguments of the given statement */
-Array<PrimExpr> GetStmtArgs(const Stmt& stmt) {
+Array<PrimExpr> GetStmtArgs(Stmt stmt) {
+  while (auto* ptr = stmt.as<DeclBufferNode>()) {
+    stmt = ptr->body;
+  }
+
   auto attr{stmt.as<AttrStmtNode>()};
   Stmt eval_stmt{attr ? attr->body : stmt};
   auto eval{eval_stmt.as<EvaluateNode>()};
@@ -215,13 +256,13 @@ class CopyComputeReorderingMutator : public StmtExprMutator {
   };
 
   Stmt VisitStmt_(const SeqStmtNode* op) override {
-    if (op->size() <= 1) {
+    auto [seq, rewrap_nest] = FlattenUnwrap(GetRef<Stmt>(op));
+
+    if (seq.size() <= 1) {
       return StmtExprMutator::VisitStmt_(op);
     }
 
-    auto seq_stmt{GetRef<SeqStmt>(op)};
-    std::vector<Stmt> new_seq(seq_stmt->size());
-    std::copy(seq_stmt->seq.begin(), seq_stmt->seq.end(), new_seq.begin());
+    std::vector<Stmt> new_seq(seq.begin(), seq.end());
 
     // Reorder the copies and computes based on the cycle count
     if (_reorder_by_cycles) {
@@ -324,9 +365,7 @@ class CopyComputeReorderingMutator : public StmtExprMutator {
       }
     }
 
-    auto seq_stmt_node{CopyOnWrite(op)};
-    seq_stmt_node->seq = std::move(new_seq);
-    return Stmt{seq_stmt_node};
+    return MergeNest(rewrap_nest, SeqStmt::Flatten(new_seq));
   }
 
   bool stmt_is_global_copy(const Stmt& stmt) { return GetStmtType(stmt) == StmtType::global_copy; }
@@ -433,12 +472,13 @@ class MergeConstantsInfoExtractor : public StmtExprVisitor {
   }
 
   void VisitStmt_(const SeqStmtNode* op) override {
-    if (op->size() <= 1) {
+    std::vector<Stmt> seq_stmt = FlattenUnwrap(GetRef<Stmt>(op)).seq;
+
+    if (seq_stmt.size() <= 1) {
       StmtExprVisitor::VisitStmt_(op);
       return;
     }
 
-    auto seq_stmt{GetRef<SeqStmt>(op)};
     for (size_t i = 0; i < seq_stmt.size(); ++i) {
       Stmt stmt{seq_stmt[i]};
       switch (GetStmtType(stmt)) {
@@ -593,12 +633,13 @@ class MergeConstantsMutator : public StmtExprMutator {
   }
 
   Stmt VisitStmt_(const SeqStmtNode* op) override {
-    if (op->size() <= 1) {
+    std::vector<Stmt> seq_stmt = FlattenUnwrap(GetRef<Stmt>(op)).seq;
+
+    if (seq_stmt.size() <= 1) {
       return StmtExprMutator::VisitStmt_(op);
     }
 
     Array<Stmt> new_seq{};
-    SeqStmt seq_stmt{GetRef<SeqStmt>(op)};
     for (size_t i{0}; i < seq_stmt.size(); ++i) {
       Stmt stmt{seq_stmt[i]};
 
@@ -628,7 +669,7 @@ class MergeConstantsMutator : public StmtExprMutator {
         }
       }
     }
-    return SeqStmt(new_seq, op->span);
+    return SeqStmt::Flatten(new_seq);
   }
 
   /*! Returns the variables of the buffers written by copies */

Reply via email to