This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new c921781c46 [TIR] Output DeclBuffer in SplitHostDevice (#15493)
c921781c46 is described below

commit c921781c46d017e11c57fba8ccc55d95b3c393c7
Author: Eric Lunderberg <[email protected]>
AuthorDate: Mon Aug 28 19:26:57 2023 -0400

    [TIR] Output DeclBuffer in SplitHostDevice (#15493)
    
    * [TIR] Output DeclBuffer in SplitHostDevice
    
    If the generated device function uses a buffer, generate a DeclBuffer
    for the buffer at the top of the device function.
    
    This is a subset of the changes made in
    https://github.com/apache/tvm/pull/14778, broken out for ease of
    testing and review.
    
    * Updated thread sync test to account for DeclBuffer
    
    * Updated LowerWarp unit tests to find Allocate in PrimFunc
---
 src/tir/transforms/split_host_device.cc                       |  9 ++++++---
 tests/python/unittest/test_tir_transform_lower_warp_memory.py | 10 +++++++---
 tests/python/unittest/test_tir_transform_thread_sync.py       |  2 +-
 3 files changed, 14 insertions(+), 7 deletions(-)

diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc
index 9b1dbf1a66..b9fc056f19 100644
--- a/src/tir/transforms/split_host_device.cc
+++ b/src/tir/transforms/split_host_device.cc
@@ -56,7 +56,7 @@ class HostDeviceSplitter : public StmtMutator {
 
  private:
   Stmt SplitDeviceFunc(Stmt body, Target device_target) {
-    Array<Var> params = [&]() {
+    auto [params, buffers_to_declare] = [&]() -> std::tuple<Array<Var>, Array<Buffer>> {
       VarUseDefAnalyzer use_def(/*defined_vars=*/{}, /*visit_thread_extent=*/false);
       use_def(body);
 
@@ -71,7 +71,7 @@ class HostDeviceSplitter : public StmtMutator {
         };
         return sort_key(a) < sort_key(b);
       });
-      return params;
+      return {params, use_def.undefined_buffers_};
     }();
 
     // CodeGenCPU is used for some device-side targets, such as
@@ -91,12 +91,15 @@ class HostDeviceSplitter : public StmtMutator {
       kernel_ret_type = VoidType();
     }
 
-    GlobalVar kernel_symbol_global = var_supply_();
+    for (Buffer buf : buffers_to_declare) {
+      body = DeclBuffer(buf, std::move(body));
+    }
     PrimFunc device_func(params, body, kernel_ret_type);
     device_func = WithAttrs(std::move(device_func), {{tvm::attr::kTarget, device_target},
                                                      {tir::attr::kNoAlias, Bool(true)},
                                                      {tir::attr::kIsGlobalFunc, Bool(true)}});
 
+    GlobalVar kernel_symbol_global = var_supply_();
     (*device_mod_)->Add(kernel_symbol_global, device_func);
     Array<PrimExpr> args = params.Map([](const Var& var) -> PrimExpr { return var; });
 
diff --git a/tests/python/unittest/test_tir_transform_lower_warp_memory.py b/tests/python/unittest/test_tir_transform_lower_warp_memory.py
index c7e90d4e7d..99ccc55565 100644
--- a/tests/python/unittest/test_tir_transform_lower_warp_memory.py
+++ b/tests/python/unittest/test_tir_transform_lower_warp_memory.py
@@ -18,7 +18,7 @@ import numpy as np
 import pytest
 import tvm
 import tvm.testing
-from tvm import te
+from tvm import te, tir
 from tvm.contrib.nvcc import have_fp16
 
 
@@ -55,9 +55,13 @@ def test_lower_warp_memory_local_scope():
 
     mod = _run_passes(mod)
     fdevice = mod["f_kernel"]
-    allocate = fdevice.body.body
+
+    allocate = fdevice
+    while not isinstance(allocate, tir.Allocate):
+        allocate = allocate.body
+
     assert allocate.buffer_var.type_annotation.storage_scope == "local"
-    assert fdevice.body.body.extents[0].value == 2
+    assert allocate.extents[0].value == 2
 
 
 @tvm.testing.requires_cuda
diff --git a/tests/python/unittest/test_tir_transform_thread_sync.py b/tests/python/unittest/test_tir_transform_thread_sync.py
index 571927dffe..2cfc65aae0 100644
--- a/tests/python/unittest/test_tir_transform_thread_sync.py
+++ b/tests/python/unittest/test_tir_transform_thread_sync.py
@@ -57,7 +57,7 @@ def test_thread_storage_sync():
     func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, A2], stmt, None)
     mod = run_passes(func)
     f = mod["test_kernel"]
-    body_list = tvm.tir.stmt_list(f.body.body.body)
+    body_list = tvm.tir.stmt_list(f.body.body.body.body.body.body)
     assert body_list[1].value.op.same_as(tvm.ir.Op.get("tir.tvm_storage_sync"))
 
 

Reply via email to