This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new c921781c46 [TIR] Output DeclBuffer in SplitHostDevice (#15493)
c921781c46 is described below
commit c921781c46d017e11c57fba8ccc55d95b3c393c7
Author: Eric Lunderberg <[email protected]>
AuthorDate: Mon Aug 28 19:26:57 2023 -0400
[TIR] Output DeclBuffer in SplitHostDevice (#15493)
* [TIR] Output DeclBuffer in SplitHostDevice
If the generated device function uses a buffer, generate a DeclBuffer
for the buffer at the top of the device function, as sketched below.
This is a subset of the changes made in
https://github.com/apache/tvm/pull/14778, broken out for ease of
testing and review.
* Updated thread sync test to account for DeclBuffer
* Updated LowerWarp unit tests to find Allocate in PrimFunc
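For illustration, a minimal Python sketch of the wrapping described above.
The buffer name, dtype, shape, and placeholder body are hypothetical
stand-ins, not taken from this commit:

    import tvm
    from tvm import tir

    # Hypothetical device-side buffer; SplitHostDevice collects such
    # undefined buffers and declares them at the top of the kernel body.
    data = tir.Var("A_data", tvm.ir.PointerType(tvm.ir.PrimType("float32"), "global"))
    buf = tir.decl_buffer((128,), "float32", name="A", data=data)

    body = tir.Evaluate(0)            # stand-in for the real kernel body
    body = tir.DeclBuffer(buf, body)  # the same wrapping the new loop performs
    print(body)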
---
src/tir/transforms/split_host_device.cc | 9 ++++++---
tests/python/unittest/test_tir_transform_lower_warp_memory.py | 10 +++++++---
tests/python/unittest/test_tir_transform_thread_sync.py | 2 +-
3 files changed, 14 insertions(+), 7 deletions(-)
diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc
index 9b1dbf1a66..b9fc056f19 100644
--- a/src/tir/transforms/split_host_device.cc
+++ b/src/tir/transforms/split_host_device.cc
@@ -56,7 +56,7 @@ class HostDeviceSplitter : public StmtMutator {
private:
Stmt SplitDeviceFunc(Stmt body, Target device_target) {
- Array<Var> params = [&]() {
+ auto [params, buffers_to_declare] = [&]() -> std::tuple<Array<Var>, Array<Buffer>> {
VarUseDefAnalyzer use_def(/*defined_vars=*/{},
/*visit_thread_extent=*/false);
use_def(body);
@@ -71,7 +71,7 @@ class HostDeviceSplitter : public StmtMutator {
};
return sort_key(a) < sort_key(b);
});
- return params;
+ return {params, use_def.undefined_buffers_};
}();
// CodeGenCPU is used for some device-side targets, such as
@@ -91,12 +91,15 @@ class HostDeviceSplitter : public StmtMutator {
kernel_ret_type = VoidType();
}
- GlobalVar kernel_symbol_global = var_supply_();
+ for (Buffer buf : buffers_to_declare) {
+ body = DeclBuffer(buf, std::move(body));
+ }
PrimFunc device_func(params, body, kernel_ret_type);
device_func = WithAttrs(std::move(device_func), {{tvm::attr::kTarget, device_target},
                                                 {tir::attr::kNoAlias, Bool(true)},
                                                 {tir::attr::kIsGlobalFunc, Bool(true)}});
+ GlobalVar kernel_symbol_global = var_supply_();
(*device_mod_)->Add(kernel_symbol_global, device_func);
Array<PrimExpr> args = params.Map([](const Var& var) -> PrimExpr { return var; });
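For context, a hedged sketch of driving this pass from Python. The schedule
and the BindTarget / AnnotateDeviceRegions / SplitHostDevice sequence are
modeled on the repository's unit tests rather than anything added in this
commit, and the sizes and target are arbitrary:

    import tvm
    from tvm import te

    # Tiny 1-D kernel bound to GPU thread axes so there is something to split.
    A = te.placeholder((128,), name="A")
    B = te.compute((128,), lambda i: A[i] + 1.0, name="B")
    s = te.create_schedule(B.op)
    bx, tx = s[B].split(B.op.axis[0], factor=32)
    s[B].bind(bx, te.thread_axis("blockIdx.x"))
    s[B].bind(tx, te.thread_axis("threadIdx.x"))

    mod = tvm.lower(s, [A, B])
    mod = tvm.tir.transform.BindTarget(tvm.target.Target("cuda", host="llvm"))(mod)
    mod = tvm.tir.transform.AnnotateDeviceRegions()(mod)
    mod = tvm.tir.transform.SplitHostDevice()(mod)
    print(mod)  # the split-out kernel should now open with DeclBuffer nodes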
diff --git a/tests/python/unittest/test_tir_transform_lower_warp_memory.py b/tests/python/unittest/test_tir_transform_lower_warp_memory.py
index c7e90d4e7d..99ccc55565 100644
--- a/tests/python/unittest/test_tir_transform_lower_warp_memory.py
+++ b/tests/python/unittest/test_tir_transform_lower_warp_memory.py
@@ -18,7 +18,7 @@ import numpy as np
import pytest
import tvm
import tvm.testing
-from tvm import te
+from tvm import te, tir
from tvm.contrib.nvcc import have_fp16
@@ -55,9 +55,13 @@ def test_lower_warp_memory_local_scope():
mod = _run_passes(mod)
fdevice = mod["f_kernel"]
- allocate = fdevice.body.body
+
+ allocate = fdevice
+ while not isinstance(allocate, tir.Allocate):
+     allocate = allocate.body
+
assert allocate.buffer_var.type_annotation.storage_scope == "local"
- assert fdevice.body.body.extents[0].value == 2
+ assert allocate.extents[0].value == 2
@tvm.testing.requires_cuda
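The updated traversal above can be read as a small reusable helper;
`find_allocate` is a hypothetical name, not part of this commit:

    from tvm import tir

    def find_allocate(func: tir.PrimFunc) -> tir.Allocate:
        # Descend through wrapper nodes (such as the newly emitted
        # DeclBuffer) until the first Allocate, so the assertion no
        # longer depends on the exact nesting depth.
        node = func.body
        while not isinstance(node, tir.Allocate):
            node = node.body
        return node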
diff --git a/tests/python/unittest/test_tir_transform_thread_sync.py b/tests/python/unittest/test_tir_transform_thread_sync.py
index 571927dffe..2cfc65aae0 100644
--- a/tests/python/unittest/test_tir_transform_thread_sync.py
+++ b/tests/python/unittest/test_tir_transform_thread_sync.py
@@ -57,7 +57,7 @@ def test_thread_storage_sync():
func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, A2], stmt, None)
mod = run_passes(func)
f = mod["test_kernel"]
- body_list = tvm.tir.stmt_list(f.body.body.body)
+ body_list = tvm.tir.stmt_list(f.body.body.body.body.body.body)
assert body_list[1].value.op.same_as(tvm.ir.Op.get("tir.tvm_storage_sync"))
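The extra `.body` hops account for the DeclBuffer nodes that now wrap the
kernel body. If that chain ever becomes brittle again, a visitor-based
check is a hedged alternative (a sketch, not something this commit adds):

    import tvm

    def has_storage_sync(func: tvm.tir.PrimFunc) -> bool:
        # Finds the storage-sync call regardless of how many DeclBuffer
        # or AttrStmt wrappers surround the kernel body.
        found = []

        def visit(node):
            if isinstance(node, tvm.tir.Call) and node.op.same_as(
                tvm.ir.Op.get("tir.tvm_storage_sync")
            ):
                found.append(node)

        tvm.tir.stmt_functor.post_order_visit(func.body, visit)
        return bool(found)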