This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 6f232f91be [TIR] [Bugfix] Pass the correct block_sref_reuse to Replace (#14023)
6f232f91be is described below

commit 6f232f91bee4907d2de5820bd301db5b4c002d2d
Author: Anirudh Sundar Subramaniam <[email protected]>
AuthorDate: Sat Feb 18 09:49:20 2023 +0530

    [TIR] [Bugfix] Pass the correct block_sref_reuse to Replace (#14023)
    
    * [TIR] [Bugfix] Pass the correct block_sref_reuse to Replace
    
    A mismatch between the blocks present in the `result` vs the blocks
    passed in `block_sref_to_reuse` caused the bug mentioned in #13974.
    
    This patch tries to fix that bug by collecting only the blocks that are
    part of result and also present in the block replacement map
    `new_block_to_old_`. Since the scope block is `result`, only that block
    and its child blocks would be replaced, and any replaced block would be
    present in `rewriter.new_block_to_old_`. Thus, collecting the replaced
    blocks from among child blocks of `result` guarantees that the
    `block_sref_reuse` would contain all the replaced blocks and that
    they'll point to the correct block in `result` thus avoiding the missing
    SRef error.
---
 .../schedule/primitive/layout_transformation.cc    | 49 ++++++++++++++++-----
 .../unittest/test_tir_schedule_transform_layout.py | 51 ++++++++++++++++++++++
 2 files changed, 89 insertions(+), 11 deletions(-)

diff --git a/src/tir/schedule/primitive/layout_transformation.cc b/src/tir/schedule/primitive/layout_transformation.cc
index 742384fc79..0e993d06dc 100644
--- a/src/tir/schedule/primitive/layout_transformation.cc
+++ b/src/tir/schedule/primitive/layout_transformation.cc
@@ -704,6 +704,42 @@ class TransformLayoutPlanner : private StmtExprVisitor {
   Buffer old_buffer_;
 };
 
+/*!
+ * \brief Collect blocks that are part of root block to be passed to ScheduleState::Replace for SRef
+ * reuse
+ */
+class ReuseBlocksCollector : public tir::StmtVisitor {
+ public:
+  static Map<Block, Block> Collect(Block result, Map<Block, Block> new_block_to_old) {
+    return ReuseBlocksCollector(new_block_to_old).Run(result);
+  }
+
+ private:
+  /*! \brief Entry point */
+  Map<Block, Block> Run(const Block result) {
+    VisitStmt(result);
+    return block_sref_reuse_;
+  }
+  /*! \brief Constructor */
+  explicit ReuseBlocksCollector(Map<Block, Block> new_block_to_old)
+      : new_block_to_old_(new_block_to_old) {}
+
+  /*! \brief Override the Stmt visiting behaviour */
+  void VisitStmt_(const tir::BlockNode* block) override {
+    Block block_ref = GetRef<Block>(block);
+    auto it = new_block_to_old_.find(block_ref);
+    if (it != new_block_to_old_.end()) {
+      block_sref_reuse_.Set((*it).second, (*it).first);
+    }
+    StmtVisitor::VisitStmt_(block);
+  }
+
+  /*! \brief New map to be filled with just blocks from scope block */
+  Map<Block, Block> block_sref_reuse_;
+  /*! \brief All block replacements collected so far */
+  Map<Block, Block> new_block_to_old_;
+};
+
 class TransformLayoutRewriter : private arith::IRMutatorWithAnalyzer {
  public:
   /*!
@@ -730,17 +766,8 @@ class TransformLayoutRewriter : private arith::IRMutatorWithAnalyzer {
       write_ptr->body = SeqStmt({plan_ptr->prologue, write_ptr->body});
     }
 
-    Map<Block, Block> block_sref_reuse;
-    for (auto [after, before] : rewriter.new_block_to_old_) {
-      while (auto opt = rewriter.new_block_to_old_.Get(before)) {
-        before = opt.value();
-      }
-      while (auto opt = block_sref_reuse.Get(after)) {
-        after = opt.value();
-      }
-
-      block_sref_reuse.Set(before, after);
-    }
+    Map<Block, Block> block_sref_reuse =
+        ReuseBlocksCollector::Collect(result, rewriter.new_block_to_old_);
 
     return {result, block_sref_reuse};
   }
diff --git a/tests/python/unittest/test_tir_schedule_transform_layout.py b/tests/python/unittest/test_tir_schedule_transform_layout.py
index ace2b58acb..d866de33f1 100644
--- a/tests/python/unittest/test_tir_schedule_transform_layout.py
+++ b/tests/python/unittest/test_tir_schedule_transform_layout.py
@@ -173,6 +173,57 @@ def two_elementwise_unit_dim(A: T.Buffer((1, 128), "float32"), C: T.Buffer((1, 1
             vi, vj = T.axis.remap("SS", [i, j])
             C[vi, vj] = B[vi, vj] + 1.0
 
+class TestTransformLayoutWithCacheWriteAndAxisSeparators(tvm.testing.CompareBeforeAfter):
+    """
+    transform_layout with axis_separator on a buffer from cache_write should work as expected
+    """
+
+    @pytest.fixture
+    def transform(self):
+        def transform(mod):
+
+            def transform_fn(x, y):
+                return [x // 32, y, tvm.te.AXIS_SEPARATOR, x % 32]
+
+            sch = tvm.tir.Schedule(mod, debug_mask="all")
+            block_rv = sch.get_block("T_add")
+            sch.cache_write(block_rv, 0, "global")
+            sch.transform_layout(block_rv, ("write", 0), transform_fn, pad_value=0.0)
+            return sch.mod
+
+        return transform
+
+    def before(
+        p0: T.Buffer((T.int64(33), T.int64(128)), "float32"),
+        p1: T.Buffer((T.int64(33), T.int64(128)), "float32"),
+        T_add: T.Buffer((T.int64(33), T.int64(128)), "float32"),
+    ):
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        # with T.block("root"):
+        for ax0, ax1 in T.grid(T.int64(33), T.int64(128)):
+            with T.block("T_add"):
+                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                T.reads(p0[v_ax0, v_ax1], p1[v_ax0, v_ax1])
+                T.writes(T_add[v_ax0, v_ax1])
+                T_add[v_ax0, v_ax1] = p0[v_ax0, v_ax1] + p1[v_ax0, v_ax1]
+
+    def expected(p0: T.Buffer((T.int64(33), T.int64(128)), "float32"), p1: T.Buffer((T.int64(33), T.int64(128)), "float32"), T_add: T.Buffer((T.int64(33), T.int64(128)), "float32")):
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        # with T.block("root"):
+        T_add_global = T.alloc_buffer((T.int64(2), T.int64(128), T.int64(32)), axis_separators=[2])
+        for axis0, axis1, axis2 in T.grid(T.int64(2), T.int64(128), T.int64(32)):
+            with T.block("T_add"):
+                v_axis0, v_axis1, v_axis2 = T.axis.remap("SSS", [axis0, axis1, axis2])
+                T.reads(p0[v_axis0 * T.int64(32) + v_axis2, v_axis1], p1[v_axis0 * T.int64(32) + v_axis2, v_axis1])
+                T.writes(T_add_global[v_axis0, v_axis1, v_axis2])
+                T_add_global[v_axis0, v_axis1, v_axis2] = T.if_then_else(v_axis0 == T.int64(1) and T.int64(1) <= v_axis2, T.float32(0), p0[v_axis0 * T.int64(32) + v_axis2, v_axis1] + p1[v_axis0 * T.int64(32) + v_axis2, v_axis1])
+        for ax0, ax1 in T.grid(T.int64(33), T.int64(128)):
+            with T.block("T_add_global"):
+                v0, v1 = T.axis.remap("SS", [ax0, ax1])
+                T.reads(T_add_global[v0 // T.int64(32), v1, v0 % T.int64(32)])
+                T.writes(T_add[v0, v1])
+                T_add[v0, v1] = T_add_global[v0 // T.int64(32), v1, v0 % T.int64(32)]
+
 # pylint: enable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks
 # fmt: on
 

Reply via email to