This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 6f232f91be [TIR] [Bugfix] Pass the correct block_sref_reuse to Replace (#14023)
6f232f91be is described below
commit 6f232f91bee4907d2de5820bd301db5b4c002d2d
Author: Anirudh Sundar Subramaniam <[email protected]>
AuthorDate: Sat Feb 18 09:49:20 2023 +0530
[TIR] [Bugfix] Pass the correct block_sref_reuse to Replace (#14023)
* [TIR] [Bugfix] Pass the correct block_sref_reuse to Replace
A mismatch between the blocks present in `result` and the blocks passed
in `block_sref_reuse` caused the bug reported in #13974.
This patch addresses that bug by collecting only the blocks that are part
of `result` and are also present in the block replacement map
`rewriter.new_block_to_old_`. Since the scope block is `result`, only that
block and its child blocks can be replaced, and every replaced block is
recorded in `rewriter.new_block_to_old_`. Collecting the replaced blocks
from `result` and its child blocks therefore guarantees that
`block_sref_reuse` contains all the replaced blocks and that each entry
points to the correct block in `result`, which avoids the missing-SRef
error.
---
.../schedule/primitive/layout_transformation.cc | 49 ++++++++++++++++-----
.../unittest/test_tir_schedule_transform_layout.py | 51 ++++++++++++++++++++++
2 files changed, 89 insertions(+), 11 deletions(-)
diff --git a/src/tir/schedule/primitive/layout_transformation.cc
b/src/tir/schedule/primitive/layout_transformation.cc
index 742384fc79..0e993d06dc 100644
--- a/src/tir/schedule/primitive/layout_transformation.cc
+++ b/src/tir/schedule/primitive/layout_transformation.cc
@@ -704,6 +704,42 @@ class TransformLayoutPlanner : private StmtExprVisitor {
Buffer old_buffer_;
};
+/*!
+ * \brief Collect blocks that are part of root block to be passed to
ScheduleState::Replace for SRef
+ * reuse
+ */
+class ReuseBlocksCollector : public tir::StmtVisitor {
+ public:
+ static Map<Block, Block> Collect(Block result, Map<Block, Block>
new_block_to_old) {
+ return ReuseBlocksCollector(new_block_to_old).Run(result);
+ }
+
+ private:
+ /*! \brief Entry point */
+ Map<Block, Block> Run(const Block result) {
+ VisitStmt(result);
+ return block_sref_reuse_;
+ }
+ /*! \brief Constructor */
+ explicit ReuseBlocksCollector(Map<Block, Block> new_block_to_old)
+ : new_block_to_old_(new_block_to_old) {}
+
+ /*! \brief Override the Stmt visiting behaviour */
+ void VisitStmt_(const tir::BlockNode* block) override {
+ Block block_ref = GetRef<Block>(block);
+ auto it = new_block_to_old_.find(block_ref);
+ if (it != new_block_to_old_.end()) {
+ block_sref_reuse_.Set((*it).second, (*it).first);
+ }
+ StmtVisitor::VisitStmt_(block);
+ }
+
+ /*! \brief New map to be filled with just blocks from scope block */
+ Map<Block, Block> block_sref_reuse_;
+ /*! \brief All block replacements collected so far */
+ Map<Block, Block> new_block_to_old_;
+};
+
class TransformLayoutRewriter : private arith::IRMutatorWithAnalyzer {
public:
/*!
@@ -730,17 +766,8 @@ class TransformLayoutRewriter : private
arith::IRMutatorWithAnalyzer {
write_ptr->body = SeqStmt({plan_ptr->prologue, write_ptr->body});
}
- Map<Block, Block> block_sref_reuse;
- for (auto [after, before] : rewriter.new_block_to_old_) {
- while (auto opt = rewriter.new_block_to_old_.Get(before)) {
- before = opt.value();
- }
- while (auto opt = block_sref_reuse.Get(after)) {
- after = opt.value();
- }
-
- block_sref_reuse.Set(before, after);
- }
+ Map<Block, Block> block_sref_reuse =
+ ReuseBlocksCollector::Collect(result, rewriter.new_block_to_old_);
return {result, block_sref_reuse};
}
diff --git a/tests/python/unittest/test_tir_schedule_transform_layout.py
b/tests/python/unittest/test_tir_schedule_transform_layout.py
index ace2b58acb..d866de33f1 100644
--- a/tests/python/unittest/test_tir_schedule_transform_layout.py
+++ b/tests/python/unittest/test_tir_schedule_transform_layout.py
@@ -173,6 +173,57 @@ def two_elementwise_unit_dim(A: T.Buffer((1, 128),
"float32"), C: T.Buffer((1, 1
vi, vj = T.axis.remap("SS", [i, j])
C[vi, vj] = B[vi, vj] + 1.0
+class
TestTransformLayoutWithCacheWriteAndAxisSeparators(tvm.testing.CompareBeforeAfter):
+ """
+ transform_layout with axis_separator on a buffer from cache_write should
work as expected
+ """
+
+ @pytest.fixture
+ def transform(self):
+ def transform(mod):
+
+ def transform_fn(x, y):
+ return [x // 32, y, tvm.te.AXIS_SEPARATOR, x % 32]
+
+ sch = tvm.tir.Schedule(mod, debug_mask="all")
+ block_rv = sch.get_block("T_add")
+ sch.cache_write(block_rv, 0, "global")
+ sch.transform_layout(block_rv, ("write", 0), transform_fn,
pad_value=0.0)
+ return sch.mod
+
+ return transform
+
+ def before(
+ p0: T.Buffer((T.int64(33), T.int64(128)), "float32"),
+ p1: T.Buffer((T.int64(33), T.int64(128)), "float32"),
+ T_add: T.Buffer((T.int64(33), T.int64(128)), "float32"),
+ ):
+ T.func_attr({"global_symbol": "main", "tir.noalias": True})
+ # with T.block("root"):
+ for ax0, ax1 in T.grid(T.int64(33), T.int64(128)):
+ with T.block("T_add"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ T.reads(p0[v_ax0, v_ax1], p1[v_ax0, v_ax1])
+ T.writes(T_add[v_ax0, v_ax1])
+ T_add[v_ax0, v_ax1] = p0[v_ax0, v_ax1] + p1[v_ax0, v_ax1]
+
+ def expected(p0: T.Buffer((T.int64(33), T.int64(128)), "float32"), p1:
T.Buffer((T.int64(33), T.int64(128)), "float32"), T_add: T.Buffer((T.int64(33),
T.int64(128)), "float32")):
+ T.func_attr({"global_symbol": "main", "tir.noalias": True})
+ # with T.block("root"):
+ T_add_global = T.alloc_buffer((T.int64(2), T.int64(128), T.int64(32)),
axis_separators=[2])
+ for axis0, axis1, axis2 in T.grid(T.int64(2), T.int64(128),
T.int64(32)):
+ with T.block("T_add"):
+ v_axis0, v_axis1, v_axis2 = T.axis.remap("SSS", [axis0, axis1,
axis2])
+ T.reads(p0[v_axis0 * T.int64(32) + v_axis2, v_axis1],
p1[v_axis0 * T.int64(32) + v_axis2, v_axis1])
+ T.writes(T_add_global[v_axis0, v_axis1, v_axis2])
+ T_add_global[v_axis0, v_axis1, v_axis2] =
T.if_then_else(v_axis0 == T.int64(1) and T.int64(1) <= v_axis2, T.float32(0),
p0[v_axis0 * T.int64(32) + v_axis2, v_axis1] + p1[v_axis0 * T.int64(32) +
v_axis2, v_axis1])
+ for ax0, ax1 in T.grid(T.int64(33), T.int64(128)):
+ with T.block("T_add_global"):
+ v0, v1 = T.axis.remap("SS", [ax0, ax1])
+ T.reads(T_add_global[v0 // T.int64(32), v1, v0 % T.int64(32)])
+ T.writes(T_add[v0, v1])
+ T_add[v0, v1] = T_add_global[v0 // T.int64(32), v1, v0 %
T.int64(32)]
+
# pylint:
enable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks
# fmt: on