This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 56ddd37d78 [TIR] Enhance loop unroll with unroll local access (#14224)
56ddd37d78 is described below
commit 56ddd37d78fa7266dc226c7071044b04701ec687
Author: Tianqi Chen <[email protected]>
AuthorDate: Tue Mar 7 15:06:19 2023 -0500
[TIR] Enhance loop unroll with unroll local access (#14224)
This PR enhances the loop unroller with an `unroll_local_access` option.
When enabled, the pass detects loop variables that are used to index
local (or warp) memory and unrolls those loops independently of the other options.
A test case is added. The option is turned off by default; it can be
useful in certain cases to improve the unroller, because these
local memory accesses have to be unrolled at some point in order to be
lifted into registers.
---
src/tir/transforms/unroll_loop.cc | 64 ++++++++++++++++++++--
.../unittest/test_tir_transform_unroll_loop.py | 42 ++++++++++++++
2 files changed, 102 insertions(+), 4 deletions(-)
diff --git a/src/tir/transforms/unroll_loop.cc
b/src/tir/transforms/unroll_loop.cc
index 1e55cb22ee..dc14e4512f 100644
--- a/src/tir/transforms/unroll_loop.cc
+++ b/src/tir/transforms/unroll_loop.cc
@@ -33,6 +33,7 @@
#include <unordered_set>
#include <vector>
+#include "../../runtime/thread_storage_scope.h"
#include "ir_utils.h"
namespace tvm {
@@ -43,6 +44,7 @@ struct UnrollLoopConfigNode : public
tvm::AttrsNode<UnrollLoopConfigNode> {
int auto_max_depth;
int auto_max_extent;
int explicit_unroll;
+ int unroll_local_access;
TVM_DECLARE_ATTRS(UnrollLoopConfigNode, "tir.transform.UnrollLoopConfig") {
TVM_ATTR_FIELD(auto_max_step)
@@ -57,6 +59,9 @@ struct UnrollLoopConfigNode : public
tvm::AttrsNode<UnrollLoopConfigNode> {
TVM_ATTR_FIELD(explicit_unroll)
.describe("Whether to explicitly unroll the loop instead of setting a
pragma")
.set_default(true);
+ TVM_ATTR_FIELD(unroll_local_access)
+ .describe("Whether to always unroll local access")
+ .set_default(false);
}
};
@@ -68,14 +73,30 @@ class UnrollLoopConfig : public Attrs {
TVM_REGISTER_NODE_TYPE(UnrollLoopConfigNode);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.UnrollLoop", UnrollLoopConfig);
+class VarLocalAccessMarker : public ExprVisitor {
+ public:
+ explicit VarLocalAccessMarker(
+ std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual>*
var_touched_local)
+ : var_touched_local_(var_touched_local) {}
+
+ void VisitExpr_(const VarNode* op) final {
var_touched_local_->insert(GetRef<Var>(op)); }
+
+ private:
+ std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual>* var_touched_local_;
+};
+
+// The unroller tracks whether a var is used as an index into local memory.
+// If a loop var is used as an index into local memory, the loop must be unrolled so
+// the local memory access can be turned into register access.
class LoopUnroller : public StmtExprMutator {
public:
explicit LoopUnroller(int auto_max_step, int auto_max_depth, int
auto_max_extent,
- bool explicit_unroll)
+ bool explicit_unroll, bool unroll_local_access)
: auto_max_step_(auto_max_step),
auto_max_depth_(auto_max_depth),
auto_max_extent_(auto_max_extent),
- explicit_unroll_(explicit_unroll) {}
+ explicit_unroll_(explicit_unroll),
+ unroll_local_access_(unroll_local_access) {}
Stmt VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == "pragma_auto_unroll_max_step") {
@@ -96,6 +117,7 @@ class LoopUnroller : public StmtExprMutator {
}
Stmt VisitStmt_(const ForNode* op) {
+ // Post order so we can collect more information
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<ForNode>();
int value = GetExtent(op);
@@ -111,6 +133,12 @@ class LoopUnroller : public StmtExprMutator {
auto_unroll = true;
}
+ // If a loop var is used as an index into local memory, the loop must be unrolled so
+ // the local memory access can be turned into register access.
+ if (this->var_touched_local_.count(op->loop_var) && value > 0 &&
unroll_local_access_) {
+ auto_unroll = true;
+ }
+
if (auto_unroll) {
step_count_ *= value;
unroll_depth_ += 1;
@@ -137,8 +165,32 @@ class LoopUnroller : public StmtExprMutator {
LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use
BufferStoreNode instead.";
}
+ PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+ if (unroll_local_access_) {
+ auto storage_scope =
runtime::StorageScope::Create(GetPtrStorageScope(op->buffer->data));
+ if (storage_scope.rank == runtime::StorageRank::kLocal ||
+ storage_scope.rank == runtime::StorageRank::kWarp) {
+ VarLocalAccessMarker marker(&var_touched_local_);
+ for (PrimExpr e : op->indices) {
+ marker(e);
+ }
+ }
+ }
+ return GetRef<PrimExpr>(op);
+ }
+
Stmt VisitStmt_(const BufferStoreNode* op) final {
++step_count_;
+ if (unroll_local_access_) {
+ auto storage_scope =
runtime::StorageScope::Create(GetPtrStorageScope(op->buffer->data));
+ if (storage_scope.rank == runtime::StorageRank::kLocal ||
+ storage_scope.rank == runtime::StorageRank::kWarp) {
+ VarLocalAccessMarker marker(&var_touched_local_);
+ for (PrimExpr e : op->indices) {
+ marker(e);
+ }
+ }
+ }
return StmtExprMutator::VisitStmt_(op);
}
@@ -161,7 +213,7 @@ class LoopUnroller : public StmtExprMutator {
unroll_depth_ = std::max(unroll_depth_, unroll_depth);
return ret;
};
- return StmtMutator::VisitSeqStmt_(op, false, fmutate);
+ return StmtExprMutator::VisitSeqStmt_(op, false, fmutate);
}
Stmt Unroll(const ForNode* op) {
@@ -202,19 +254,23 @@ class LoopUnroller : public StmtExprMutator {
// this does not count the total steps, it only counts the number of loops
int auto_max_extent_;
bool explicit_unroll_;
+ // Whether to always unroll loops whose loop var indexes local memory.
+ bool unroll_local_access_{false};
// Number of normal loops in scope
int normal_loop_depth_{0};
// number of unrolled cases in current scope.
int unroll_depth_{0};
// Number of total steps unrolled
int step_count_{0};
+ // Set of vars that appear in the indices of local memory accesses visited so far
+ std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> var_touched_local_;
// analyzer
arith::Analyzer analyzer_;
};
Stmt UnrollLoop(Stmt stmt, UnrollLoopConfig cfg) {
Stmt ret = LoopUnroller(cfg->auto_max_step, cfg->auto_max_depth,
cfg->auto_max_extent,
- cfg->explicit_unroll)(stmt);
+ cfg->explicit_unroll,
cfg->unroll_local_access)(stmt);
if (!ret.same_as(stmt)) {
return ConvertSSA(ret);
} else {
diff --git a/tests/python/unittest/test_tir_transform_unroll_loop.py
b/tests/python/unittest/test_tir_transform_unroll_loop.py
index a76e6135b3..a05a085eeb 100644
--- a/tests/python/unittest/test_tir_transform_unroll_loop.py
+++ b/tests/python/unittest/test_tir_transform_unroll_loop.py
@@ -134,7 +134,49 @@ def test_unroll_allocations():
tvm.ir.assert_structural_equal(after, expected)
+def test_unroll_local_access():
+ @tvm.script.ir_module
+ class Before:
+ @T.prim_func
+ def main(B: T.Buffer((64,), "float32")):
+ for bx in T.thread_binding(4, thread="blockIdx.x"):
+ for tx in T.thread_binding(4, thread="threadIdx.x"):
+ A_local_data = T.allocate([4], dtype="float32",
scope="local")
+ A_local = T.Buffer([4], dtype="float32", data=A_local_data)
+ for i in T.serial(4):
+ A_local[i] = T.float32(i)
+
+ @tvm.script.ir_module
+ class Expected:
+ @T.prim_func
+ def main(B: T.Buffer((64,), "float32")):
+ for bx in T.thread_binding(4, thread="blockIdx.x"):
+ for tx in T.thread_binding(4, thread="threadIdx.x"):
+ A_local_data = T.allocate([4], dtype="float32",
scope="local")
+ A_local = T.Buffer([4], dtype="float32", data=A_local_data)
+ A_local[0] = T.float32(0)
+ A_local[1] = T.float32(1)
+ A_local[2] = T.float32(2)
+ A_local[3] = T.float32(3)
+
+ with tvm.transform.PassContext(
+ config={
+ "tir.UnrollLoop": {
+ "auto_max_depth": 0,
+ "auto_max_extent": 1,
+ "explicit_unroll": True,
+ "unroll_local_access": True,
+ }
+ }
+ ):
+ after = tvm.tir.transform.UnrollLoop()(Before)
+ after = tvm.tir.transform.Simplify()(after)
+
+ tvm.ir.assert_structural_equal(after, Expected)
+
+
if __name__ == "__main__":
+ test_unroll_local_access()
test_unroll_loop()
test_unroll_fake_loop()
test_unroll_single_count_loops()