This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 7c4c913599 [Ethos-U][TIR] Handle DeclBuffer in Ethos-U inputs (#15098)
7c4c913599 is described below
commit 7c4c913599af13250c38bfc20ba27860be985712
Author: Eric Lunderberg <[email protected]>
AuthorDate: Fri Jun 16 18:17:44 2023 -0400
[Ethos-U][TIR] Handle DeclBuffer in Ethos-U inputs (#15098)
This is a subset of changes, being split out from
https://github.com/apache/tvm/pull/14778 into independent portions.
---
.../tvm/relay/backend/contrib/ethosu/tir/passes.py | 12 +++-
src/tir/contrib/ethosu/passes.cc | 67 +++++++++++++++++-----
2 files changed, 65 insertions(+), 14 deletions(-)
diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py
index aa80c89d38..9636f20447 100644
--- a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py
+++ b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py
@@ -193,13 +193,23 @@ def ReplaceOperators():
)
return None
+ def _remove_buffer_decl(stmt):
+ if isinstance(stmt, tvm.tir.DeclBuffer):
+ if stmt.buffer.data in replace_output_pointer:
+ return stmt.body
+
def _post_transform(stmt):
# Replace operators with call_externs
result = _replace_operator(stmt)
# Remove operators that don't need compiling
result = result or _remove_no_compile(stmt)
# Replace necessary pointers that were removed in the previous step
- return result or _replace_pointers(stmt)
+ result = result or _replace_pointers(stmt)
+ # Replace BufferDecl, since only the tir.Var data pointer is
+ # still used, and not the tir.Buffer
+ result = result or _remove_buffer_decl(stmt)
+
+ return result
def _ftransform(f, mod, ctx):
tvm.tir.stmt_functor.post_order_visit(f.body, _find_pointer_to_extent)
diff --git a/src/tir/contrib/ethosu/passes.cc b/src/tir/contrib/ethosu/passes.cc
index 5968febe9a..fba506fba1 100644
--- a/src/tir/contrib/ethosu/passes.cc
+++ b/src/tir/contrib/ethosu/passes.cc
@@ -32,6 +32,8 @@
#include <unordered_map>
#include <unordered_set>
+#include "../../transforms/ir_utils.h"
+
namespace tvm {
/*!
@@ -54,8 +56,47 @@ namespace ethosu {
namespace {
+struct FlattenUnwrapResult {  // Value returned by FlattenUnwrap
+ std::vector<Stmt> seq;  // Statements kept after flattening, in the order they were visited
+ std::vector<Stmt> rewrap_nest;  // One DeclBuffer(buffer, Evaluate(0)) per DeclBuffer that was unwrapped, in the order encountered
+};
+
+/*! \brief Utility function to flatten a SeqStmt, looking through DeclBuffer
+ *
+ * A DeclBuffer may internally contain SeqStmt nodes that we want to
+ * flatten. Unlike SeqStmt::Flatten, this function unwraps DeclBuffer
+ * nodes when encountered: the DeclBuffer's body is flattened in place,
+ * and a copy of the DeclBuffer with an Evaluate(0) body is recorded so
+ * that the caller can re-apply the declarations afterwards.
+ *
+ * Nested SeqStmt nodes are expanded recursively, and any Evaluate
+ * whose value is an integer constant is dropped from the output.
+ * Every other statement is kept as-is, as a single entry.
+ *
+ * NOTE(review): only DeclBuffer and SeqStmt are looked through here.
+ * An AttrStmt is not unwrapped by this function; it is kept whole as
+ * a single entry of the result.
+ *
+ * \param stmt The tir::Stmt to be flattened.
+ * \return A FlattenUnwrapResult whose `seq` holds the flattened
+ * statements and whose `rewrap_nest` holds the body-less DeclBuffer
+ * wrappers that were removed while flattening.
+ */
+FlattenUnwrapResult FlattenUnwrap(const Stmt& stmt) {
+ std::vector<Stmt> seq_stmt;
+ std::vector<Stmt> rewrap_nest;
+ std::function<void(const Stmt&)> flatten_unwrap = [&](const Stmt& stmt) {
+ if (auto* ptr = stmt.as<DeclBufferNode>()) {
+ rewrap_nest.push_back(DeclBuffer(ptr->buffer, Evaluate(0)));  // Keep the declaration, with a placeholder body
+ flatten_unwrap(ptr->body);  // Continue flattening inside the declaration
+ } else if (auto* ptr = stmt.as<SeqStmtNode>()) {
+ for (const auto& sub_stmt : ptr->seq) {
+ flatten_unwrap(sub_stmt);
+ }
+ } else if (auto* ptr = stmt.as<EvaluateNode>(); ptr &&
ptr->value.as<IntImmNode>()) {
+ // Skip: an Evaluate of an integer constant is not added to the output
+ } else {
+ seq_stmt.push_back(stmt);  // Any other statement is kept as a single entry
+ }
+ };
+ flatten_unwrap(stmt);
+ return FlattenUnwrapResult{seq_stmt, rewrap_nest};
+}
+
/*! Returns the arguments of the given statement */
-Array<PrimExpr> GetStmtArgs(const Stmt& stmt) {
+Array<PrimExpr> GetStmtArgs(Stmt stmt) {
+ while (auto* ptr = stmt.as<DeclBufferNode>()) {
+ stmt = ptr->body;
+ }
+
auto attr{stmt.as<AttrStmtNode>()};
Stmt eval_stmt{attr ? attr->body : stmt};
auto eval{eval_stmt.as<EvaluateNode>()};
@@ -215,13 +256,13 @@ class CopyComputeReorderingMutator : public StmtExprMutator {
};
Stmt VisitStmt_(const SeqStmtNode* op) override {
- if (op->size() <= 1) {
+ auto [seq, rewrap_nest] = FlattenUnwrap(GetRef<Stmt>(op));
+
+ if (seq.size() <= 1) {
return StmtExprMutator::VisitStmt_(op);
}
- auto seq_stmt{GetRef<SeqStmt>(op)};
- std::vector<Stmt> new_seq(seq_stmt->size());
- std::copy(seq_stmt->seq.begin(), seq_stmt->seq.end(), new_seq.begin());
+ std::vector<Stmt> new_seq(seq.begin(), seq.end());
// Reorder the copies and computes based on the cycle count
if (_reorder_by_cycles) {
@@ -324,9 +365,7 @@ class CopyComputeReorderingMutator : public StmtExprMutator {
}
}
- auto seq_stmt_node{CopyOnWrite(op)};
- seq_stmt_node->seq = std::move(new_seq);
- return Stmt{seq_stmt_node};
+ return MergeNest(rewrap_nest, SeqStmt::Flatten(new_seq));
}
bool stmt_is_global_copy(const Stmt& stmt) { return GetStmtType(stmt) == StmtType::global_copy; }
@@ -433,12 +472,13 @@ class MergeConstantsInfoExtractor : public StmtExprVisitor {
}
void VisitStmt_(const SeqStmtNode* op) override {
- if (op->size() <= 1) {
+ std::vector<Stmt> seq_stmt = FlattenUnwrap(GetRef<Stmt>(op)).seq;
+
+ if (seq_stmt.size() <= 1) {
StmtExprVisitor::VisitStmt_(op);
return;
}
- auto seq_stmt{GetRef<SeqStmt>(op)};
for (size_t i = 0; i < seq_stmt.size(); ++i) {
Stmt stmt{seq_stmt[i]};
switch (GetStmtType(stmt)) {
@@ -593,12 +633,13 @@ class MergeConstantsMutator : public StmtExprMutator {
}
Stmt VisitStmt_(const SeqStmtNode* op) override {
- if (op->size() <= 1) {
+ std::vector<Stmt> seq_stmt = FlattenUnwrap(GetRef<Stmt>(op)).seq;
+
+ if (seq_stmt.size() <= 1) {
return StmtExprMutator::VisitStmt_(op);
}
Array<Stmt> new_seq{};
- SeqStmt seq_stmt{GetRef<SeqStmt>(op)};
for (size_t i{0}; i < seq_stmt.size(); ++i) {
Stmt stmt{seq_stmt[i]};
@@ -628,7 +669,7 @@ class MergeConstantsMutator : public StmtExprMutator {
}
}
}
- return SeqStmt(new_seq, op->span);
+ return SeqStmt::Flatten(new_seq);
}
/*! Returns the variables of the buffers written by copies */