This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 22ba652 fix compute inline not to overwrite annotated opaque
accesses (#9509)
22ba652 is described below
commit 22ba6523cbd14fc44a1b093c482c1d02f3bc4fa5
Author: wrongtest <[email protected]>
AuthorDate: Tue Nov 16 06:03:15 2021 +0800
fix compute inline not to overwrite annotated opaque accesses (#9509)
---
src/tir/schedule/primitive/compute_inline.cc | 6 +-
.../unittest/test_tir_schedule_compute_inline.py | 65 ++++++++++++++++++++++
2 files changed, 70 insertions(+), 1 deletion(-)
diff --git a/src/tir/schedule/primitive/compute_inline.cc
b/src/tir/schedule/primitive/compute_inline.cc
index 539a82f..12ae021 100644
--- a/src/tir/schedule/primitive/compute_inline.cc
+++ b/src/tir/schedule/primitive/compute_inline.cc
@@ -313,7 +313,11 @@ class BaseInliner : public StmtExprMutator {
// Step 2. Update `BlockNode::reads` and `BlockNode::writes`
Array<BufferRegion> reads = std::move(block->reads);
Array<BufferRegion> writes = std::move(block->writes);
- if (!is_scope_root) {
+ auto f_access_inline_buffer = [this](const BufferRegion& access) {
+ return access->buffer.same_as(this->inlined_buffer_);
+ };
+ if (!is_scope_root && (std::any_of(reads.begin(), reads.end(),
f_access_inline_buffer) ||
+ std::any_of(writes.begin(), writes.end(),
f_access_inline_buffer))) {
Array<Array<BufferRegion>> inspected = GetBlockReadWriteRegion(block,
buffer_var_map_);
reads = std::move(inspected[0]);
writes = std::move(inspected[1]);
diff --git a/tests/python/unittest/test_tir_schedule_compute_inline.py
b/tests/python/unittest/test_tir_schedule_compute_inline.py
index 617c75b..a078c0e 100644
--- a/tests/python/unittest/test_tir_schedule_compute_inline.py
+++ b/tests/python/unittest/test_tir_schedule_compute_inline.py
@@ -272,6 +272,63 @@ def elementwise_multi_loads_inlined(a: T.handle, c:
T.handle) -> None:
C[vi, vj] = A[vi, vj] * 2.0 + A[vi, vj + 1] * 2.0 + A[vi, vj + 2]
* 2.0
+@T.prim_func
+def access_opaque_ptr_then_elemwise(a: T.handle, b: T.handle) -> None:
+    A = T.match_buffer(a, [1024])
+    B = T.match_buffer(b, [1024])
+    A_cache = T.alloc_buffer([1024])
+    BB = T.alloc_buffer([1024])
+    with T.block("opaque"):
+        # annotated opaque partial access; does not touch BB, the buffer inlined later
+        T.reads(A[0:512])
+        T.writes(A_cache[0:512])
+        T.evaluate(
+            T.tvm_access_ptr(
+                T.type_annotation(dtype="float32"), A.data, 0, 512, "r", dtype="handle"
+            )
+        )
+        T.evaluate(
+            T.tvm_access_ptr(
+                T.type_annotation(dtype="float32"), A_cache.data, 0, 512, "w", dtype="handle"
+            )
+        )
+    for i in range(512):
+        with T.block("BB"):
+            vi = T.axis.remap("S", [i])
+            BB[vi] = A_cache[vi] * 2.0
+    for i in range(512):
+        with T.block("B"):
+            vi = T.axis.remap("S", [i])
+            B[vi] = BB[vi] + 1.0
+
+
+@T.prim_func
+def access_opaque_ptr_then_elemwise_inline(a: T.handle, b: T.handle) -> None:
+    A = T.match_buffer(a, [1024], dtype="float32")
+    B = T.match_buffer(b, [1024], dtype="float32")
+    A_cache = T.alloc_buffer([1024], dtype="float32")
+    with T.block("opaque"):
+        # annotated opaque partial access must be kept as-is after inlining BB
+        T.reads(A[0:512])
+        T.writes([A_cache[0:512]])
+        T.evaluate(
+            T.tvm_access_ptr(
+                T.type_annotation(dtype="float32"), A.data, 0, 512, "r", dtype="handle"
+            )
+        )
+        T.evaluate(
+            T.tvm_access_ptr(
+                T.type_annotation(dtype="float32"), A_cache.data, 0, 512, "w", dtype="handle"
+            )
+        )
+    for i in T.serial(0, 512):
+        with T.block("B"):
+            vi = T.axis.spatial(512, i)
+            T.reads([A_cache[vi]])
+            T.writes([B[vi]])
+            B[vi] = A_cache[vi] * 2.0 + 1.0
+
+
# pylint: enable=no-member,invalid-name,unused-variable
@@ -417,5 +474,13 @@ def test_compute_inline_multi_loads():
verify_trace_roundtrip(sch=sch, mod=elementwise_multi_loads)
+def test_compute_inline_with_opaque_access():
+    """Test compute_inline keeps annotated opaque reads/writes of unrelated blocks"""
+    sch = tir.Schedule(access_opaque_ptr_then_elemwise, debug_mask="all")
+    BB = sch.get_block("BB")
+    sch.compute_inline(BB)
+    tvm.ir.assert_structural_equal(access_opaque_ptr_then_elemwise_inline, sch.mod["main"])
+
+
if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))