This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 22ba652  fix compute inline not to over write annotated opaque accesses (#9509)
22ba652 is described below

commit 22ba6523cbd14fc44a1b093c482c1d02f3bc4fa5
Author: wrongtest <[email protected]>
AuthorDate: Tue Nov 16 06:03:15 2021 +0800

    fix compute inline not to over write annotated opaque accesses (#9509)
---
 src/tir/schedule/primitive/compute_inline.cc       |  6 +-
 .../unittest/test_tir_schedule_compute_inline.py   | 65 ++++++++++++++++++++++
 2 files changed, 70 insertions(+), 1 deletion(-)

diff --git a/src/tir/schedule/primitive/compute_inline.cc b/src/tir/schedule/primitive/compute_inline.cc
index 539a82f..12ae021 100644
--- a/src/tir/schedule/primitive/compute_inline.cc
+++ b/src/tir/schedule/primitive/compute_inline.cc
@@ -313,7 +313,11 @@ class BaseInliner : public StmtExprMutator {
     // Step 2. Update `BlockNode::reads` and `BlockNode::writes`
     Array<BufferRegion> reads = std::move(block->reads);
     Array<BufferRegion> writes = std::move(block->writes);
-    if (!is_scope_root) {
+    auto f_access_inline_buffer = [this](const BufferRegion& access) {
+      return access->buffer.same_as(this->inlined_buffer_);
+    };
+    if (!is_scope_root && (std::any_of(reads.begin(), reads.end(), f_access_inline_buffer) ||
+                           std::any_of(writes.begin(), writes.end(), f_access_inline_buffer))) {
       Array<Array<BufferRegion>> inspected = GetBlockReadWriteRegion(block, buffer_var_map_);
       reads = std::move(inspected[0]);
       writes = std::move(inspected[1]);
diff --git a/tests/python/unittest/test_tir_schedule_compute_inline.py b/tests/python/unittest/test_tir_schedule_compute_inline.py
index 617c75b..a078c0e 100644
--- a/tests/python/unittest/test_tir_schedule_compute_inline.py
+++ b/tests/python/unittest/test_tir_schedule_compute_inline.py
@@ -272,6 +272,63 @@ def elementwise_multi_loads_inlined(a: T.handle, c: T.handle) -> None:
             C[vi, vj] = A[vi, vj] * 2.0 + A[vi, vj + 1] * 2.0 + A[vi, vj + 2] * 2.0
 
 
+@T.prim_func
+def access_opaque_ptr_then_elemwise(a: T.handle, b: T.handle) -> None:
+    A = T.match_buffer(a, [1024])
+    B = T.match_buffer(b, [1024])
+    A_cache = T.alloc_buffer([1024])
+    BB = T.alloc_buffer([1024])
+    with T.block("opaque"):
+        # annotated opaque partial access
+        T.reads(A[0:512])
+        T.writes(A_cache[0:512])
+        T.evaluate(
+            T.tvm_access_ptr(
+                T.type_annotation(dtype="float32"), A.data, 0, 512, "r", dtype="handle"
+            )
+        )
+        T.evaluate(
+            T.tvm_access_ptr(
+                T.type_annotation(dtype="float32"), A_cache.data, 0, 512, "w", dtype="handle"
+            )
+        )
+    for i in range(512):
+        with T.block("BB"):
+            vi = T.axis.remap("S", [i])
+            BB[vi] = A_cache[vi] * 2.0
+    for i in range(512):
+        with T.block("B"):
+            vi = T.axis.remap("S", [i])
+            B[vi] = BB[vi] + 1.0
+
+
+@T.prim_func
+def access_opaque_ptr_then_elemwise_inline(a: T.handle, b: T.handle) -> None:
+    A = T.match_buffer(a, [1024], dtype="float32")
+    B = T.match_buffer(b, [1024], dtype="float32")
+    A_cache = T.alloc_buffer([1024], dtype="float32")
+    with T.block("opaque"):
+        # annotated opaque partial access should be kept
+        T.reads(A[0:512])
+        T.writes([A_cache[0:512]])
+        T.evaluate(
+            T.tvm_access_ptr(
+                T.type_annotation(dtype="float32"), A.data, 0, 512, "r", dtype="handle"
+            )
+        )
+        T.evaluate(
+            T.tvm_access_ptr(
+                T.type_annotation(dtype="float32"), A_cache.data, 0, 512, "w", dtype="handle"
+            )
+        )
+    for i in T.serial(0, 512):
+        with T.block("B"):
+            vi = T.axis.spatial(512, i)
+            T.reads([A_cache[vi]])
+            T.writes([B[vi]])
+            B[vi] = A_cache[vi] * 2.0 + 1.0
+
+
 # pylint: enable=no-member,invalid-name,unused-variable
 
 
@@ -417,5 +474,13 @@ def test_compute_inline_multi_loads():
     verify_trace_roundtrip(sch=sch, mod=elementwise_multi_loads)
 
 
+def test_compute_inline_with_opaque_access():
+    """Test that opaque reads/writes are not rewritten after an irrelevant compute inline"""
+    sch = tir.Schedule(access_opaque_ptr_then_elemwise, debug_mask="all")
+    BB = sch.get_block("BB")
+    sch.compute_inline(BB)
+    tvm.ir.assert_structural_equal(access_opaque_ptr_then_elemwise_inline, sch.mod["main"])
+
+
 if __name__ == "__main__":
     sys.exit(pytest.main([__file__] + sys.argv[1:]))

Reply via email to