This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new e0dbc8773a [TIR] Handle DeclBuffer in InjectDoubleBuffer (#15045)
e0dbc8773a is described below
commit e0dbc8773a98bb540c90ef3fabe3645f014dd31d
Author: Eric Lunderberg <[email protected]>
AuthorDate: Fri Jun 16 17:19:17 2023 -0400
[TIR] Handle DeclBuffer in InjectDoubleBuffer (#15045)
Preserve DeclBuffer node when transforming with `InjectDoubleBuffer`
This is a subset of changes, being split out from
https://github.com/apache/tvm/pull/14778 into independent portions.
---
src/tir/transforms/inject_double_buffer.cc | 24 ++++--
.../test_tir_transform_inject_double_buffer.py | 88 +++++++++++++++++++++-
2 files changed, 104 insertions(+), 8 deletions(-)
diff --git a/src/tir/transforms/inject_double_buffer.cc
b/src/tir/transforms/inject_double_buffer.cc
index c99264041e..88188425a9 100644
--- a/src/tir/transforms/inject_double_buffer.cc
+++ b/src/tir/transforms/inject_double_buffer.cc
@@ -106,21 +106,29 @@ class DoubleBufferInjector : public StmtExprMutator {
const VarNode* buf = op->buffer_var.as<VarNode>();
auto it = dbuffer_info_.find(buf);
if (it != dbuffer_info_.end()) {
- it->second.scope = GetPtrStorageScope(op->buffer_var);
+ StorageEntry& entry = it->second;
+ entry.scope = GetPtrStorageScope(op->buffer_var);
ICHECK_EQ(op->extents.size(), 1) << "InjectDoubleBuffer expects flat 1-d
buffers. "
<< "Has StorageFlatten (TE-based
schedules) or "
<< "FlattenBuffer (TIR-based schedules)
been run?";
- it->second.stride = op->extents[0];
+ entry.stride = op->extents[0];
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<AllocateNode>();
Array<PrimExpr> new_extents = {op->extents[0] *
make_const(op->extents[0].dtype(), 2)};
- ICHECK(it->second.loop != nullptr);
- auto& alloc_nest = loop_allocs_[it->second.loop];
+ ICHECK(entry.loop != nullptr);
+ auto& alloc_nest = loop_allocs_[entry.loop];
alloc_nest.emplace_back(
Allocate(op->buffer_var, op->dtype, new_extents, op->condition,
Evaluate(0)));
- return op->body;
+ Stmt body = op->body;
+ if (auto ptr = body.as<DeclBufferNode>()) {
+ auto new_buf = GetRemappedBuffer(ptr->buffer, entry.stride);
+ alloc_nest.emplace_back(DeclBuffer(new_buf, Evaluate(0)));
+ body = ptr->body;
+ }
+
+ return body;
} else {
return StmtExprMutator::VisitStmt_(op);
}
@@ -226,8 +234,10 @@ class DoubleBufferInjector : public StmtExprMutator {
ICHECK_EQ(buf->shape.size(), 1) << "InjectDoubleBuffer expects flat 1-d
buffers. "
<< "Has StorageFlatten (TE-based
schedules) or "
<< "FlattenBuffer (TIR-based schedules)
been run?";
- auto writer = buf.CopyOnWrite();
- writer->shape = {buf->shape[0] * stride};
+
+ // Stride gives the distance between the two halves of the
+ // double-buffer, not the stride of the buffer's index.
+ buf.CopyOnWrite()->shape = {buf->shape[0] + stride};
buf_remap_[key] = buf;
return buf;
diff --git a/tests/python/unittest/test_tir_transform_inject_double_buffer.py
b/tests/python/unittest/test_tir_transform_inject_double_buffer.py
index 0f4cc00f07..4e7a0b5d80 100644
--- a/tests/python/unittest/test_tir_transform_inject_double_buffer.py
+++ b/tests/python/unittest/test_tir_transform_inject_double_buffer.py
@@ -14,7 +14,11 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+
import tvm
+import tvm.testing
+
+from tvm.script import tir as T, ir as I
from tvm import te
@@ -61,5 +65,87 @@ def test_double_buffer():
assert count[0] == 4
+class TestDoubleBuffer(tvm.testing.CompareBeforeAfter):
+    """Double-buffer a cache that is allocated with T.allocate and viewed via T.Buffer"""
+    transform = tvm.ir.transform.Sequential(
+        [
+            tvm.tir.transform.InjectDoubleBuffer(),
+            tvm.tir.transform.Simplify(),
+        ]
+    )
+
+    def before(A: T.Buffer([16, 32], "float32"), B: T.Buffer(16, "float32")):
+        for i in range(16):
+            cache_data = T.allocate([32], "float32")
+            cache = T.Buffer(32, "float32", data=cache_data)
+
+            T.attr(cache_data, "double_buffer_scope", 1)
+
+            for j in range(32):
+                cache[j] = A[i, j]
+
+            B[i] = 0.0
+            for j in range(32):
+                B[i] = B[i] + cache[j]
+
+    def expected(A: T.Buffer((16, 32), "float32"), B: T.Buffer((16,), "float32")):
+        cache_data = T.allocate([64], "float32", "global")
+        cache = T.Buffer(64, data=cache_data)
+        for j in range(32):
+            cache[j] = A[0, j]
+
+        B[0] = T.float32(0)
+        for j in range(32):
+            B[0] = B[0] + cache[j]
+
+        for i_outer in range(15):
+            T.attr(cache_data, "double_buffer_write", 1)
+            for j in range(32):
+                cache[(i_outer + 1) % 2 * 32 + j] = A[i_outer + 1, j]
+            B[i_outer + 1] = T.float32(0)
+            for j in range(32):
+                B[i_outer + 1] = B[i_outer + 1] + cache[(i_outer + 1) % 2 * 32 + j]
+
+
+class TestDoubleBufferWithDeclBuffer(tvm.testing.CompareBeforeAfter):
+    """Like TestDoubleBuffer, but the cache is a DeclBuffer, which must be preserved"""
+
+    transform = tvm.ir.transform.Sequential(
+        [
+            tvm.tir.transform.InjectDoubleBuffer(),
+            tvm.tir.transform.Simplify(),
+        ]
+    )
+
+    def before(A: T.Buffer((16, 32), "float32"), B: T.Buffer(16, "float32")):
+        for i in range(16):
+            cache = T.decl_buffer(32, "float32")
+            T.attr(cache.data, "double_buffer_scope", 1)
+
+            for j in range(32):
+                cache[j] = A[i, j]
+
+            B[i] = 0.0
+            for j in range(32):
+                B[i] = B[i] + cache[j]
+
+    def expected(A: T.Buffer((16, 32), "float32"), B: T.Buffer(16, "float32")):
+        cache = T.decl_buffer(64, "float32")
+        for j in range(32):
+            cache[j] = A[0, j]
+
+        B[0] = T.float32(0)
+        for j in range(32):
+            B[0] = B[0] + cache[j]
+
+        for i_outer in range(15):
+            T.attr(cache.data, "double_buffer_write", 1)
+            for j in range(32):
+                cache[(i_outer + 1) % 2 * 32 + j] = A[i_outer + 1, j]
+            B[i_outer + 1] = T.float32(0)
+            for j in range(32):
+                B[i_outer + 1] = B[i_outer + 1] + cache[(i_outer + 1) % 2 * 32 + j]
+
+
if __name__ == "__main__":
- test_double_buffer()
+ tvm.testing.main()