This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new dcf1b1a90e [TIR] Handle DeclBuffer in CacheReadWrite schedule primitive (#15037)
dcf1b1a90e is described below
commit dcf1b1a90e3d96d00368b19a99adefe7c0ef5f68
Author: Eric Lunderberg <[email protected]>
AuthorDate: Fri Jun 16 18:44:04 2023 -0400
[TIR] Handle DeclBuffer in CacheReadWrite schedule primitive (#15037)
* [Util] Handle AllocateConst in MergeNest
* [TIR] Handle DeclBuffer in CacheReadWrite schedule primitive
Part of the changes being split out from
https://github.com/apache/tvm/pull/14778 into independent portions.
This commit allows TIR `cache_read` and `cache_write` schedule
primitives to preserve `DeclBuffer` nodes.
---
src/tir/schedule/primitive/cache_read_write.cc | 59 ++++++---
.../unittest/test_tir_schedule_cache_read_write.py | 134 +++++++++++----------
2 files changed, 113 insertions(+), 80 deletions(-)
diff --git a/src/tir/schedule/primitive/cache_read_write.cc
b/src/tir/schedule/primitive/cache_read_write.cc
index 74a960eefb..0a4cf2329e 100644
--- a/src/tir/schedule/primitive/cache_read_write.cc
+++ b/src/tir/schedule/primitive/cache_read_write.cc
@@ -20,6 +20,7 @@
#include <unordered_set>
#include "../../analysis/var_use_def_analysis.h"
+#include "../../transforms/ir_utils.h"
#include "../utils.h"
namespace tvm {
@@ -425,21 +426,43 @@ bool CalculateAffineFlag(const ScheduleState& self, const StmtSRef& block_sref)
- * \return A SeqStmt, the result after insertion
+ * \return The result after insertion: a SeqStmt, re-wrapped in any AllocateConst/DeclBuffer of `stmt`
*/
Stmt InsertCacheStage(const Stmt& stmt, int pos, const Stmt& stage) {
- if (const auto* alloc = stmt.as<AllocateConstNode>()) {
- auto seq_stmt = InsertCacheStage(alloc->body, pos, stage);
- return AllocateConst(alloc->buffer_var, alloc->dtype, alloc->extents, alloc->data, seq_stmt,
- alloc->annotations, alloc->span);
- }
- if (const auto* seq_stmt = stmt.as<SeqStmtNode>()) {
- ObjectPtr<SeqStmtNode> result = make_object<SeqStmtNode>(*seq_stmt);
- result->seq.insert(result->seq.begin() + pos, stage);
- return SeqStmt(result);
+ std::vector<Stmt> nest;
+ Stmt body = stmt;
+ while (true) {
+ if (auto opt = body.as<AllocateConst>()) {
+ auto alloc = opt.value();
+ body = alloc->body;
+ alloc.CopyOnWrite()->body = Evaluate(0);
+ nest.push_back(alloc);
+ } else if (auto opt = body.as<DeclBuffer>()) {
+ auto decl_buffer = opt.value();
+ body = decl_buffer->body;
+ decl_buffer.CopyOnWrite()->body = Evaluate(0);
+ nest.push_back(decl_buffer);
+ } else {
+ break;
+ }
}
- if (pos == 0) {
- return SeqStmt({stage, stmt});
+
+ if (const auto* seq_stmt = body.as<SeqStmtNode>()) {
+ Array<Stmt> seq = seq_stmt->seq;
+ ICHECK_LE(pos, seq.size()) << "Cannot insert at position " << pos << " into sequence of length "
+ << seq.size();
+ seq.insert(seq.begin() + pos, stage);
+ body = SeqStmt(seq);
+ } else if (pos == 0) {
+ body = SeqStmt({stage, body});
+ } else if (pos == 1) {
+ body = SeqStmt({body, stage});
+ } else {
+ LOG(FATAL) << "Cannot insert at position " << pos
+ << ". When inserting adjacent to non-SeqStmt, "
+ << "only positions 0 and 1 are valid.";
}
- ICHECK_EQ(pos, 1);
- return SeqStmt({stmt, stage});
+
+ body = MergeNest(nest, body);
+
+ return body;
}
/*!
@@ -550,8 +573,14 @@ class CacheLocDetector : public StmtVisitor {
auto block_body = scope_sref->StmtAs<BlockNode>()->body;
- // Find the SeqStmtNode within (potentially nested) AllocateConstNodes
+ // Find the SeqStmtNode within (potentially nested) AllocateConstNodes and DeclBufferNodes
- while (block_body->IsInstance<AllocateConstNode>()) {
- block_body = block_body.as<AllocateConstNode>()->body;
+ while (true) {
+ if (auto* ptr = block_body.as<AllocateConstNode>()) {
+ block_body = ptr->body;
+ } else if (auto* ptr = block_body.as<DeclBufferNode>()) {
+ block_body = ptr->body;
+ } else {
+ break;
+ }
}
const auto* body = block_body.as<SeqStmtNode>();
info->loc_pos = body == nullptr ? 1 : body->size();
diff --git a/tests/python/unittest/test_tir_schedule_cache_read_write.py
b/tests/python/unittest/test_tir_schedule_cache_read_write.py
index 454557a2bd..95955646c6 100644
--- a/tests/python/unittest/test_tir_schedule_cache_read_write.py
+++ b/tests/python/unittest/test_tir_schedule_cache_read_write.py
@@ -1172,67 +1172,6 @@ def block_predicate_cache_write_output_buf() -> None:
use_block_name = tvm.testing.parameter(by_dict={"block_obj": False, "block_name": True})
-@T.prim_func
-def cache_write_allocate_const(
- A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float16")
-):
- B = T.alloc_buffer([128, 128], dtype="float32")
- const = T.allocate_const([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7], "float32", [8])
- const_1 = T.Buffer([8], dtype="float32", data=const)
- const2 = T.allocate_const([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7], "float32", [8])
- const_2 = T.Buffer([8], dtype="float32", data=const)
- for i, j in T.grid(128, 128):
- for x in range(8):
- with T.block("B"):
- vi, vj, vx = T.axis.remap("SSS", [i, j, x])
- T.reads(A[vi, vj], const_1[vx], const_2[vx])
- T.writes(B[vi, vj])
- B[vi, vj] = A[vi, vj] * const_1[vx] + const_2[vx]
- for i, j in T.grid(128, 128):
- with T.block("C"):
- vi, vj = T.axis.remap("SS", [i, j])
- T.reads(B[vi, vj])
- T.writes(C[vi, vj])
- C[vi, vj] = B[vi, vj] + 1.0
-
-
-@T.prim_func
-def cache_write_allocate_const_output(
- A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float16")
-):
- B = T.alloc_buffer([128, 128], dtype="float32")
- A_global = T.alloc_buffer([128, 128], dtype="float32")
- C_global = T.alloc_buffer([128, 128], dtype="float16")
- const_2 = T.allocate_const([0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7], "float32", [8])
- const_1 = T.Buffer([8], dtype="float32", data=const_2)
- const_2_1 = T.Buffer([8], dtype="float32", data=const_2)
- const2 = T.allocate_const([0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7], "float32", [8])
- for ax0, ax1 in T.grid(128, 128):
- with T.block("A_global"):
- v0, v1 = T.axis.remap("SS", [ax0, ax1])
- T.reads(A[v0, v1])
- T.writes(A_global[v0, v1])
- A_global[v0, v1] = A[v0, v1]
- for i, j, x in T.grid(128, 128, 8):
- with T.block("B"):
- vi, vj, vx = T.axis.remap("SSS", [i, j, x])
- T.reads(A_global[vi, vj], const_1[vx], const_2_1[vx])
- T.writes(B[vi, vj])
- B[vi, vj] = A_global[vi, vj] * const_1[vx] + const_2_1[vx]
- for i, j in T.grid(128, 128):
- with T.block("C"):
- vi, vj = T.axis.remap("SS", [i, j])
- T.reads(B[vi, vj])
- T.writes(C_global[vi, vj])
- C_global[vi, vj] = B[vi, vj] + T.float32(1)
- for ax0, ax1 in T.grid(128, 128):
- with T.block("C_global"):
- v0, v1 = T.axis.remap("SS", [ax0, ax1])
- T.reads(C_global[v0, v1])
- T.writes(C[v0, v1])
- C[v0, v1] = C_global[v0, v1]
-
-
def test_cache_read_elementwise(use_block_name):
sch = tir.Schedule(elementwise, debug_mask="all")
block_b = sch.get_block("B")
@@ -1493,14 +1432,79 @@ def test_cache_write_fail_invalid_storage_scope(use_block_name):
sch.cache_write(block_b, 0, "test_scope")
-def test_cache_write_allocate_const():
- sch = tir.Schedule(cache_write_allocate_const)
+@pytest.mark.parametrize("use_decl_buffer", [True, False])
+def test_cache_write_allocate_const(use_decl_buffer):
+ def apply_decl_buffer(*args, **kwargs):
+ if use_decl_buffer:
+ return T.decl_buffer(*args, **kwargs)
+ else:
+ return T.Buffer(*args, **kwargs)
+
+ @T.prim_func
+ def before(A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float16")):
+ B = T.alloc_buffer([128, 128], dtype="float32")
+ const1 = T.allocate_const([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7], "float32", [8])
+ const1_buf = apply_decl_buffer([8], dtype="float32", data=const1)
+ const2 = T.allocate_const([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7], "float32", [8])
+ const2_buf = apply_decl_buffer([8], dtype="float32", data=const2)
+ for i, j in T.grid(128, 128):
+ for x in range(8):
+ with T.block("B"):
+ vi, vj, vx = T.axis.remap("SSS", [i, j, x])
+ T.reads(A[vi, vj], const1_buf[vx], const2_buf[vx])
+ T.writes(B[vi, vj])
+ B[vi, vj] = A[vi, vj] * const1_buf[vx] + const2_buf[vx]
+ for i, j in T.grid(128, 128):
+ with T.block("C"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ T.reads(B[vi, vj])
+ T.writes(C[vi, vj])
+ C[vi, vj] = B[vi, vj] + 1.0
+
+ @T.prim_func
+ def expected(A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float16")):
+ B = T.alloc_buffer([128, 128], dtype="float32")
+ A_global = T.alloc_buffer([128, 128], dtype="float32")
+ C_global = T.alloc_buffer([128, 128], dtype="float16")
+ const1 = T.allocate_const([0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7], "float32", [8])
+ const1_buf = apply_decl_buffer([8], dtype="float32", data=const1)
+ const2 = T.allocate_const([0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7], "float32", [8])
+ const2_buf = apply_decl_buffer([8], dtype="float32", data=const2)
+ for ax0, ax1 in T.grid(128, 128):
+ with T.block("A_global"):
+ v0, v1 = T.axis.remap("SS", [ax0, ax1])
+ T.reads(A[v0, v1])
+ T.writes(A_global[v0, v1])
+ A_global[v0, v1] = A[v0, v1]
+ for i, j, x in T.grid(128, 128, 8):
+ with T.block("B"):
+ vi, vj, vx = T.axis.remap("SSS", [i, j, x])
+ T.reads(A_global[vi, vj], const1_buf[vx], const2_buf[vx])
+ T.writes(B[vi, vj])
+ B[vi, vj] = A_global[vi, vj] * const1_buf[vx] + const2_buf[vx]
+ for i, j in T.grid(128, 128):
+ with T.block("C"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ T.reads(B[vi, vj])
+ T.writes(C_global[vi, vj])
+ C_global[vi, vj] = B[vi, vj] + T.float32(1)
+ for ax0, ax1 in T.grid(128, 128):
+ with T.block("C_global"):
+ v0, v1 = T.axis.remap("SS", [ax0, ax1])
+ T.reads(C_global[v0, v1])
+ T.writes(C[v0, v1])
+ C[v0, v1] = C_global[v0, v1]
+
+ sch = tir.Schedule(before)
block_b = sch.get_block("B")
block_c = sch.get_block("C")
sch.cache_read(block_b, 0, "global")
sch.cache_write(block_c, 0, "global")
- tvm.ir.assert_structural_equal(cache_write_allocate_const_output, sch.mod["main"])
- verify_trace_roundtrip(sch=sch, mod=cache_write_allocate_const)
+
+ after = sch.mod["main"]
+
+ tvm.ir.assert_structural_equal(expected, after)
+ verify_trace_roundtrip(sch=sch, mod=before)
def test_reindex_cache_read():