This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 5e1414275f [Cherry-Pick][BugFix][TIR] ThreadSync with shared.dyn awareness (#15481)
5e1414275f is described below
commit 5e1414275f7cd51877cdc18334440d8cc7a580e3
Author: Junru Shao <[email protected]>
AuthorDate: Thu Aug 3 21:05:35 2023 -0700
[Cherry-Pick][BugFix][TIR] ThreadSync with shared.dyn awareness (#15481)
Cherry-picked from #15478.
This PR fixes an issue in the ThreadSync pass.

Prior to this PR, the pass was not aware of the `shared.dyn` scope, whose
buffers all live in the same dynamically allocated shared memory space.
That sharing is not necessarily revealed in the IR by the time ThreadSync
runs: each buffer of `shared.dyn` scope may still use its own data Var,
so ThreadSync cannot detect the conflicts between such buffers or insert
the sync instructions properly.

This PR makes ThreadSync explicitly aware of the `shared.dyn` scope and
redirects all access vars of `shared.dyn` memory to a common var, so that
the ThreadSync analysis can find the conflicts and insert the sync
instructions.
Co-authored-by: Ruihang Lai <[email protected]>
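
As a quick illustration (not part of the commit), the sketch below mirrors
the access pattern of the unit test added in this diff: two `shared.dyn`
buffers carry distinct data Vars yet alias in the single dynamic
shared-memory arena, so the pass must place a barrier between their write
phases. The function and variable names are invented for the example, and
`T.decl_buffer` stands in for the allocate-plus-buffer-decl pattern used in
the test; treat this as a sketch against a recent TVM unity build, not
authoritative code.

import tvm
from tvm.script import tir as T


@T.prim_func(private=True)
def kernel(A: T.Buffer((16,), "float32"), E: T.Buffer((16,), "float32")):
    bx = T.launch_thread("blockIdx.x", 1)
    B = T.decl_buffer((24,), "float32", scope="shared.dyn")  # padded staging buffer
    C = T.decl_buffer((1,), "float32", scope="local")
    D = T.decl_buffer((16,), "float32", scope="shared.dyn")
    tx = T.launch_thread("threadIdx.x", 16)
    B[tx // 4 * 6 + tx % 4] = A[tx]
    C[0] = B[tx // 4 * 6 + tx % 4]  # thread-local read-back: no barrier needed
    # B and D use distinct data Vars at this point, yet both live in the one
    # dynamic shared-memory arena, so this store may clobber cells of B that
    # other threads wrote; the fixed pass inserts a barrier before it.
    D[tx] = C[0]
    E[tx] = D[tx]


mod = tvm.tir.transform.ThreadSync("shared.dyn")(tvm.IRModule({"main": kernel}))
assert "T.tvm_storage_sync" in str(mod)

Before the fix, the redirect step in Summarize did not exist, the two
buffers were treated as disjoint memory, and the assertion above would fail
because no sync was inserted.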
---
src/tir/transforms/thread_storage_sync.cc | 18 +++++++++-
.../unittest/test_tir_transform_thread_sync.py | 42 ++++++++++++++++++++++
2 files changed, 59 insertions(+), 1 deletion(-)
diff --git a/src/tir/transforms/thread_storage_sync.cc b/src/tir/transforms/thread_storage_sync.cc
index c21afe400c..d92986e51a 100644
--- a/src/tir/transforms/thread_storage_sync.cc
+++ b/src/tir/transforms/thread_storage_sync.cc
@@ -50,11 +50,27 @@ class ThreadSyncPlanner : public StorageAccessVisitor {
   }
   // Plan the sync
   std::vector<AccessEntry> Summarize(std::vector<StmtEntry> seq, const ForNode* loop) final {
+    // Redirect all "shared.dyn" buffer access to the same buffer var
+    // so that the accesses can be planned together.
+    Var shared_dyn_buf;
+    for (StmtEntry& entry : seq) {
+      for (AccessEntry& access : entry.access) {
+        if (access.scope.rank == StorageRank::kShared && access.scope.tag == ".dyn" &&
+            access.buffer.defined()) {
+          if (!shared_dyn_buf.defined()) {
+            shared_dyn_buf = access.buffer;
+          } else {
+            access.buffer = shared_dyn_buf;
+          }
+        }
+      }
+    }
+
     // Unsynced reads and writes
     std::vector<AccessEntry> reads;
     std::vector<AccessEntry> writes;
     // if it is a loop, rotate two times to consider effect of loop.
-    // simulation based approach to find dependenceies
+    // simulation based approach to find dependencies
     for (size_t i = 0; i < seq.size(); ++i) {
       const StmtEntry& s = seq[i];
       // check if sync before statement is needed.
diff --git a/tests/python/unittest/test_tir_transform_thread_sync.py b/tests/python/unittest/test_tir_transform_thread_sync.py
index 57ea223cf9..571927dffe 100644
--- a/tests/python/unittest/test_tir_transform_thread_sync.py
+++ b/tests/python/unittest/test_tir_transform_thread_sync.py
@@ -119,7 +119,49 @@ def test_sync_read_thread_id_independent_location():
     assert "T.tvm_storage_sync" in str(mod)
 
 
+def test_sync_shared_dyn():
+    @T.prim_func(private=True)
+    def func(A: T.Buffer((4, 4), "float32"), E: T.Buffer((4, 4), "float32")):
+        blockIdx_x = T.launch_thread("blockIdx.x", 1)
+        B = T.allocate([24], "float32", "shared.dyn")
+        C = T.allocate([1], "float32", "local")
+        D = T.allocate([16], "float32", "shared.dyn")
+        threadIdx_x = T.launch_thread("threadIdx.x", 16)
+        B_1 = T.Buffer((24,), data=B, scope="shared.dyn")
+        A_1 = T.Buffer((16,), data=A.data)
+        B_1[threadIdx_x // 4 * 6 + threadIdx_x % 4] = A_1[threadIdx_x]
+        C_1 = T.Buffer((1,), data=C, scope="local")
+        C_1[0] = B_1[threadIdx_x // 4 * 6 + threadIdx_x % 4]
+        D_1 = T.Buffer((16,), data=D, scope="shared.dyn")
+        D_1[threadIdx_x] = C_1[0]
+        E_1 = T.Buffer((16,), data=E.data)
+        E_1[threadIdx_x] = D_1[threadIdx_x]
+
+    @T.prim_func(private=True)
+    def expected(A: T.Buffer((4, 4), "float32"), E: T.Buffer((4, 4), "float32")):
+        blockIdx_x = T.launch_thread("blockIdx.x", 1)
+        B_1 = T.allocate([24], "float32", "shared.dyn")
+        C_1 = T.allocate([1], "float32", "local")
+        D_1 = T.allocate([16], "float32", "shared.dyn")
+        threadIdx_x = T.launch_thread("threadIdx.x", 16)
+        B_1_1 = T.Buffer((24,), data=B_1, scope="shared.dyn")
+        A_1 = T.Buffer((16,), data=A.data)
+        B_1_1[threadIdx_x // 4 * 6 + threadIdx_x % 4] = A_1[threadIdx_x]
+        C_1_1 = T.Buffer((1,), data=C_1, scope="local")
+        C_1_1[0] = B_1_1[threadIdx_x // 4 * 6 + threadIdx_x % 4]
+        T.tvm_storage_sync("shared.dyn")
+        D_1_1 = T.Buffer((16,), data=D_1, scope="shared.dyn")
+        D_1_1[threadIdx_x] = C_1_1[0]
+        E_1 = T.Buffer((16,), data=E.data)
+        E_1[threadIdx_x] = D_1_1[threadIdx_x]
+
+    mod = tvm.IRModule({"main": func})
+    mod = tvm.tir.transform.ThreadSync("shared.dyn")(mod)
+    tvm.ir.assert_structural_equal(mod["main"], expected)
+
+
 if __name__ == "__main__":
     test_thread_storage_sync()
     test_sync_else_branch()
     test_sync_read_thread_id_independent_location()
+    test_sync_shared_dyn()