This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new e0dbc8773a [TIR] Handle DeclBuffer in InjectDoubleBuffer (#15045)
e0dbc8773a is described below

commit e0dbc8773a98bb540c90ef3fabe3645f014dd31d
Author: Eric Lunderberg <[email protected]>
AuthorDate: Fri Jun 16 17:19:17 2023 -0400

    [TIR] Handle DeclBuffer in InjectDoubleBuffer (#15045)
    
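    Preserve the DeclBuffer node when transforming with `InjectDoubleBuffer`.
    This also corrects the shape of the remapped buffer, which is now the
    original extent plus the double-buffer stride rather than their product.
    This is a subset of changes, split out from
    https://github.com/apache/tvm/pull/14778 into independent portions.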
---
 src/tir/transforms/inject_double_buffer.cc         | 24 ++++--
 .../test_tir_transform_inject_double_buffer.py     | 88 +++++++++++++++++++++-
 2 files changed, 104 insertions(+), 8 deletions(-)

diff --git a/src/tir/transforms/inject_double_buffer.cc b/src/tir/transforms/inject_double_buffer.cc
index c99264041e..88188425a9 100644
--- a/src/tir/transforms/inject_double_buffer.cc
+++ b/src/tir/transforms/inject_double_buffer.cc
@@ -106,21 +106,29 @@ class DoubleBufferInjector : public StmtExprMutator {
     const VarNode* buf = op->buffer_var.as<VarNode>();
     auto it = dbuffer_info_.find(buf);
     if (it != dbuffer_info_.end()) {
-      it->second.scope = GetPtrStorageScope(op->buffer_var);
+      StorageEntry& entry = it->second;
+      entry.scope = GetPtrStorageScope(op->buffer_var);
 
       ICHECK_EQ(op->extents.size(), 1) << "InjectDoubleBuffer expects flat 1-d buffers.  "
                                        << "Has StorageFlatten (TE-based schedules) or "
                                        << "FlattenBuffer (TIR-based schedules) been run?";
-      it->second.stride = op->extents[0];
+      entry.stride = op->extents[0];
       Stmt stmt = StmtExprMutator::VisitStmt_(op);
       op = stmt.as<AllocateNode>();
 
       Array<PrimExpr> new_extents = {op->extents[0] * make_const(op->extents[0].dtype(), 2)};
-      ICHECK(it->second.loop != nullptr);
-      auto& alloc_nest = loop_allocs_[it->second.loop];
+      ICHECK(entry.loop != nullptr);
+      auto& alloc_nest = loop_allocs_[entry.loop];
       alloc_nest.emplace_back(
           Allocate(op->buffer_var, op->dtype, new_extents, op->condition, Evaluate(0)));
-      return op->body;
+      Stmt body = op->body;
+      if (auto ptr = body.as<DeclBufferNode>()) {
+        auto new_buf = GetRemappedBuffer(ptr->buffer, entry.stride);
+        alloc_nest.emplace_back(DeclBuffer(new_buf, Evaluate(0)));
+        body = ptr->body;
+      }
+
+      return body;
     } else {
       return StmtExprMutator::VisitStmt_(op);
     }
@@ -226,8 +234,10 @@ class DoubleBufferInjector : public StmtExprMutator {
     ICHECK_EQ(buf->shape.size(), 1) << "InjectDoubleBuffer expects flat 1-d buffers.  "
                                     << "Has StorageFlatten (TE-based schedules) or "
                                     << "FlattenBuffer (TIR-based schedules) been run?";
-    auto writer = buf.CopyOnWrite();
-    writer->shape = {buf->shape[0] * stride};
+
+    // Stride gives the distance between the two halves of the
+    // double-buffer, not the stride of the buffer's index.
+    buf.CopyOnWrite()->shape = {buf->shape[0] + stride};
 
     buf_remap_[key] = buf;
     return buf;
diff --git a/tests/python/unittest/test_tir_transform_inject_double_buffer.py b/tests/python/unittest/test_tir_transform_inject_double_buffer.py
index 0f4cc00f07..4e7a0b5d80 100644
--- a/tests/python/unittest/test_tir_transform_inject_double_buffer.py
+++ b/tests/python/unittest/test_tir_transform_inject_double_buffer.py
@@ -14,7 +14,11 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+
 import tvm
+import tvm.testing
+
+from tvm.script import tir as T, ir as I
 from tvm import te
 
 
@@ -61,5 +65,87 @@ def test_double_buffer():
     assert count[0] == 4
 
 
+class TestDoubleBuffer(tvm.testing.CompareBeforeAfter):
+    transform = tvm.ir.transform.Sequential(
+        [
+            tvm.tir.transform.InjectDoubleBuffer(),
+            tvm.tir.transform.Simplify(),
+        ]
+    )
+
+    def before(A: T.Buffer([16, 32], "float32"), B: T.Buffer(16, "float32")):
+
+        for i in range(16):
+            cache_data = T.allocate([32], "float32")
+            cache = T.Buffer(32, "float32", data=cache_data)
+
+            T.attr(cache_data, "double_buffer_scope", 1)
+
+            for j in range(32):
+                cache[j] = A[i, j]
+
+            B[i] = 0.0
+            for j in range(32):
+                B[i] = B[i] + cache[j]
+
+    def expected(A: T.Buffer((16, 32), "float32"), B: T.Buffer((16,), "float32")):
+        cache_data = T.allocate([64], "float32", "global")
+        cache = T.Buffer(64, data=cache_data)
+        for j in range(32):
+            cache[j] = A[0, j]
+
+        B[0] = T.float32(0)
+        for j in range(32):
+            B[0] = B[0] + cache[j]
+
+        for i_outer in range(15):
+            T.attr(cache_data, "double_buffer_write", 1)
+            for j in range(32):
+                cache[(i_outer + 1) % 2 * 32 + j] = A[i_outer + 1, j]
+            B[i_outer + 1] = T.float32(0)
+            for j in range(32):
+                B[i_outer + 1] = B[i_outer + 1] + cache[(i_outer + 1) % 2 * 32 + j]
+
+
+class TestDoubleBufferWithDeclBuffer(tvm.testing.CompareBeforeAfter):
+    """Like TestDoubleBuffer, but with a declared buffer object"""
+
+    transform = tvm.ir.transform.Sequential(
+        [
+            tvm.tir.transform.InjectDoubleBuffer(),
+            tvm.tir.transform.Simplify(),
+        ]
+    )
+
+    def before(A: T.Buffer((16, 32), "float32"), B: T.Buffer(16, "float32")):
+        for i in range(16):
+            cache = T.decl_buffer(32, "float32")
+            T.attr(cache.data, "double_buffer_scope", 1)
+
+            for j in range(32):
+                cache[j] = A[i, j]
+
+            B[i] = 0.0
+            for j in range(32):
+                B[i] = B[i] + cache[j]
+
+    def expected(A: T.Buffer((16, 32), "float32"), B: T.Buffer(16, "float32")):
+        cache = T.decl_buffer(64, "float32")
+        for j in range(32):
+            cache[j] = A[0, j]
+
+        B[0] = T.float32(0)
+        for j in range(32):
+            B[0] = B[0] + cache[j]
+
+        for i_outer in range(15):
+            T.attr(cache.data, "double_buffer_write", 1)
+            for j in range(32):
+                cache[(i_outer + 1) % 2 * 32 + j] = A[i_outer + 1, j]
+            B[i_outer + 1] = T.float32(0)
+            for j in range(32):
+                B[i_outer + 1] = B[i_outer + 1] + cache[(i_outer + 1) % 2 * 32 + j]
+
+
 if __name__ == "__main__":
-    test_double_buffer()
+    tvm.testing.main()

Reply via email to