This is an automated email from the ASF dual-hosted git repository.

csullivan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 6eb0779442 [TIR] SplitHostDevice, handle subroutines (#14918)
6eb0779442 is described below

commit 6eb077944295b96d70531db9b7048f2e87af1cfc
Author: Eric Lunderberg <[email protected]>
AuthorDate: Fri May 26 15:22:12 2023 -0500

    [TIR] SplitHostDevice, handle subroutines (#14918)
    
    This PR refactors SplitHostDevice into three separate transformations.
    Previously, SplitHostDevice replaced each extracted device region with
    a builtin::tvm_call_packed() node.  After this PR, this process is
    performed in three separate steps.
    
    AnnotateDeviceRegions: Annotate the regions that should be executed on
    another target.
    SplitHostDevice: Extract the annotated region into an independent
    PrimFunc, with a GlobalVar representing the call into the new
    subroutine.
    LowerDeviceKernelLaunch: For any subroutine call where the caller and
    callee are on different devices, replace the call with a device kernel
    launch.
    
    * PR#14915 [TVMScript] Allow T.target("device", host="host") in TVMScript
    
    Prior to this commit, the `TargetNode::host` could be specified in
    TVMScript as part of the config dictionary, under the key `"host"`.
    However, this required all other device parameters to be explicitly
    specified, rather than using any of the shorthand string
    representations.  This commit forwards the `host` argument from TVMScript's
    `T.target` method to `tvm.target.Target`, allowing both the device and
    host to be specified using the shorthand string representation.
    
    ```python
    @T.prim_func
    def before_this_commit():
        T.func_attr(
            {
                "target": T.target(
                    {
                        "arch": "sm_86",
                        "host": {"keys": ["cpu"], "kind": "llvm", "tag": ""},
                        "keys": ["cuda", "gpu"],
                        "kind": "cuda",
                        "max_num_threads": 1024,
                        "tag": "",
                        "thread_warp_size": 32,
                    }
                )
            }
        )
        T.evaluate(0)
    
    @T.prim_func
    def after_this_commit():
        T.func_attr({"target": T.target("cuda", host="llvm")})
        T.evaluate(0)
    ```
    
    * [Target] Added WithoutHost method
    
    * [TIR] SplitHostDevice, handle missing kGlobalSymbol
    
    Previously, the symbol name of the extracted compute kernel was
    defined based on the `kGlobalSymbol` attribute, which was required to
    be present.  This commit updates `SplitHostDevice` to generate the
    symbol name using `kGlobalSymbol` if present, and to fall back to the
    name of the `tvm::GlobalVar` for internal functions.
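    
    The fallback rule described above can be sketched in plain Python (a
    hypothetical helper for illustration, not the actual TVM implementation):

```python
def kernel_symbol_name(func_attrs, global_var_name):
    # Prefer the exported symbol ("global_symbol", i.e. kGlobalSymbol)
    # when present; otherwise fall back to the GlobalVar's name, so that
    # internal functions without an exported symbol can still be split.
    return func_attrs.get("global_symbol", global_var_name)

# Externally-exposed function keeps its exported symbol:
kernel_symbol_name({"global_symbol": "main"}, "main_impl")   # "main"
# Internal function falls back to the GlobalVar name:
kernel_symbol_name({}, "subroutine")                          # "subroutine"
```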
    
    * [TIR] Refactor SplitHostDevice into three separate passes
    
    First pass, `AnnotateDeviceRegions`.  This pass decides which portions
    of a PrimFunc should be run on the device, and annotates them with
    `kTarget` attribute, indicating which target should be used for later
    lowering steps.
    
    Second pass, `SplitHostDevice`.  This pass extracts the annotated
    region into an independent PrimFunc.  The `kTarget` attribute of the
    extracted kernel is defined by the `kTarget` annotation inserted by
    `AnnotateDeviceRegions`.  The host function is marked by the
    `tvm::tir::attr::kIsHostFunc` attribute, allowing it to be recognized
    by later host-only lowering passes.
    
    Third pass, `LowerDeviceKernelLaunch`.  This pass identifies
    subroutine calls that call into device kernels, and rewrites them into
    `T.tvm_call_packed`.
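    
    The three passes compose sequentially.  A toy plain-Python model of the
    hand-off between the stages (the dict-based "module" and all field names
    here are illustrative, not the TVM API):

```python
def annotate_device_regions(mod, device_target="cuda"):
    # Stage 1: mark regions that must run on the device (modeled here as
    # any function containing a thread_extent) with the device target.
    for func in mod:
        if func.get("has_thread_extent"):
            func["device_region_target"] = device_target
    return mod

def split_host_device(mod):
    # Stage 2: extract each annotated region into its own function; the
    # host keeps an ordinary subroutine call to the new function.
    kernels = []
    for func in mod:
        target = func.pop("device_region_target", None)
        if target is not None:
            kernels.append({"name": func["name"] + "_kernel", "target": target})
            func["calls"] = [func["name"] + "_kernel"]
    mod.extend(kernels)
    return mod

def lower_device_kernel_launch(mod):
    # Stage 3: rewrite calls whose caller and callee targets differ into
    # packed-call kernel launches; same-target calls are left alone.
    target_of = {f["name"]: f["target"] for f in mod}
    for func in mod:
        func["launches"] = [c for c in func.get("calls", [])
                            if target_of[c] != func["target"]]
    return mod

def run_pipeline(mod):
    for stage in (annotate_device_regions, split_host_device,
                  lower_device_kernel_launch):
        mod = stage(mod)
    return mod
```

    Running `run_pipeline` on a host function containing a thread_extent
    produces an extracted `*_kernel` function on the device target plus a
    kernel launch recorded on the host side.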
    
    * Add unit tests specifically for SplitHostDevice behavior
    
    * Added unit test specifically for AnnotateDeviceRegions
    
    * Added unit tests for LowerDeviceKernelLaunch
    
    * Minor cleanup, moved all kernel launch collection into one spot
    
    Previously, the SplitHostDevice pass added the
    `tir::attr::kKernelLaunchParams` attribute, and the
    LowerDeviceKernelLaunch pass filled in the values for it.  This
    cleanup makes the kernel launch params be the sole responsibility of
    LowerDeviceKernelLaunch.
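    
    A toy illustration of the responsibility shift (plain Python, hypothetical
    names): the split step no longer emits a launch-param attribute, and the
    launch-lowering step derives the full list itself, with the dynamic
    shared memory tag required to come last.

```python
def split_host_device_attrs():
    # After this cleanup, the split step emits no kernel-launch-param
    # attribute at all; that is now the lowering step's job.
    return {}

def lower_device_kernel_launch_attrs(thread_extents, uses_dyn_shmem):
    # Derive the launch params from the collected device-side info.
    # Dynamic shared memory, if used, must be the final launch parameter.
    params = list(thread_extents)
    if uses_dyn_shmem:
        params.append("tir.use_dyn_shared_memory")
    return {"kernel_launch_params": params}
```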
    
    * Updated unit tests for LowerWarpMemory
    
    * Updated unit tests for ThreadSync
    
    * Updated unit test for inject ptx async copy
    
    * [Bugfix] Avoid symbol conflicts in MakePackedAPI/MakeUnpackedAPI
    
    PRs https://github.com/apache/tvm/pull/14913 and
    https://github.com/apache/tvm/pull/14914 made analogous changes to
    `MakePackedAPI` and `MakeUnpackedAPI` to handle subroutine calls.
    Both PRs introduced the same symbol,
    `tvm::tir::SubroutineCallRewriter`, a local utility to update internal
    calls to a modified function.  While each PR passed CI individually,
    and was therefore able to merge, having both changes caused a
    duplicate symbol.
    
    This commit updates `MakePackedAPI` and `MakeUnpackedAPI` to place
    their local utilities into anonymous namespaces, avoiding the
    conflict.
    
    * Maintain "tir.is_global_func" attr in device-side entry point
    
    * SplitHostDevice, update the host-side target to be the host
    
    * [TIR] Update LowerDeviceKernelLaunch to avoid kIsHostFunc
    
    Update to use the `tvm::tir::IsHostFunc` utility function, rather than
    the `kIsHostFunc` attribute.  Per discussion on
    https://github.com/apache/tvm/pull/14020, the `kIsHostFunc` attribute
    should only be used in `BindTarget`, and should not be re-introduced
    in `SplitHostDevice`.
    
    * Remove is_host_func from SplitHostDevice tests
---
 include/tvm/tir/transform.h                        |  38 +++
 python/tvm/tir/op.py                               |   2 +-
 python/tvm/tir/transform/transform.py              |  38 +++
 src/driver/driver_api.cc                           |   3 +
 src/tir/transforms/annotate_device_regions.cc      |  81 ++++++
 src/tir/transforms/lower_device_kernel_launch.cc   | 305 +++++++++++++++++++++
 src/tir/transforms/split_host_device.cc            | 272 +++++-------------
 .../test_tir_transform_annotate_device_regions.py  |  58 ++++
 .../test_tir_transform_device_kernel_launch.py     | 193 +++++++++++++
 .../test_tir_transform_inject_ptx_async_copy.py    |   2 +-
 .../test_tir_transform_lower_warp_memory.py        |  37 +--
 .../test_tir_transform_split_host_device.py        | 113 +++++++-
 .../unittest/test_tir_transform_thread_sync.py     |   5 +-
 13 files changed, 908 insertions(+), 239 deletions(-)

diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h
index 8dee176277..d9d68e0a8b 100644
--- a/include/tvm/tir/transform.h
+++ b/include/tvm/tir/transform.h
@@ -263,13 +263,51 @@ TVM_DLL Pass LowerCustomDatatypes();
  */
 TVM_DLL Pass DecorateDeviceScope();
 
+/*!
+ * \brief Annotate locations that should be run on the device
+ *
+ * Insert `AttrStmt` nodes specifying a target on which regions within
+ * the PrimFunc should be executed.  Only modifies functions that have
+ * a `tvm::attr::kTarget` attribute, and where that target defines a
+ * host.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass AnnotateDeviceRegions();
+
 /*!
  * \brief Split the function into a host function and device functions.
  *
+ * The resulting host-side function will keep the same
+ * `tvm::attr::kTarget` attribute (e.g. `T.target("cuda",
+ * host=T.target("llvm"))`).  This ensures that `MakePackedAPI` knows
+ * which device type should be used for the input buffers.
+ *
+ * The resulting device-side function will
+ * have the host stripped from its target attribute
+ * (e.g. `T.target("cuda")`).
+ *
  * \return The pass.
  */
 TVM_DLL Pass SplitHostDevice();
 
+/*!
+ * \brief Lower cross-device function calls.
+ *
+ * Prior to this pass, host to device calls are represented as
+ * subroutine calls, with environment parameters (e.g. env_thread)
+ * specified internally.  The device function is an internal function,
+ * without a `tvm::attr::kGlobalSymbol` attribute.
+ *
+ * After this pass, host to device calls are represented as
+ * tvm_call_packed built-in.  The device function is an
+ * externally-exposed function, with a non-empty
+ * `tvm::attr::kGlobalSymbol` attribute.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass LowerDeviceKernelLaunch();
+
 /*!
  * \brief skip assert stmt.
  *
diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py
index 90e3db4cb9..098c13f04e 100644
--- a/python/tvm/tir/op.py
+++ b/python/tvm/tir/op.py
@@ -445,7 +445,7 @@ def call_tir(global_var: tvm.ir.GlobalVar, *args):
         The call expression.
     """
     assert isinstance(global_var, tvm.ir.GlobalVar)
-    return Call(dtype="handle", op=global_var, args=args)
+    return Call(dtype="void", op=global_var, args=args)
 
 
 def start_profile_intrinsic(id):
diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py
index f2ce437814..9e038f618b 100644
--- a/python/tvm/tir/transform/transform.py
+++ b/python/tvm/tir/transform/transform.py
@@ -435,6 +435,22 @@ def MakeUnpackedAPI():
     return _ffi_api.MakeUnpackedAPI()  # type: ignore
 
 
+def AnnotateDeviceRegions():
+    """Annotate locations that should be run on the device
+
+    Insert `AttrStmt` nodes specifying a target on which regions
+    within the PrimFunc should be executed.  Only modifies functions
+    that have a `tvm::attr::kTarget` attribute, and where that target
+    defines a host.
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The result pass
+    """
+    return _ffi_api.AnnotateDeviceRegions()  # type: ignore
+
+
 def SplitHostDevice():
     """Split the function into a host function and device functions.
 
@@ -446,6 +462,28 @@ def SplitHostDevice():
     return _ffi_api.SplitHostDevice()  # type: ignore
 
 
+def LowerDeviceKernelLaunch():
+    """Lower cross-device function calls.
+
+    Prior to this pass, host to device calls are represented as
+    subroutine calls, with environment parameters (e.g. env_thread)
+    specified internally.  The device function is an internal
+    function, without a `tvm::attr::kGlobalSymbol` attribute.
+
+    After this pass, host to device calls are represented as
+    tvm_call_packed built-in.  The device function is an
+    externally-exposed function, with a non-empty
+    `tvm::attr::kGlobalSymbol` attribute.
+
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The result pass
+    """
+    return _ffi_api.LowerDeviceKernelLaunch()  # type: ignore
+
+
 def DecorateDeviceScope():
     """Decorate all the function's body as device function.
 
diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc
index 91bc57ccbe..e5f71c3832 100644
--- a/src/driver/driver_api.cc
+++ b/src/driver/driver_api.cc
@@ -587,7 +587,10 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target)
     mixed_pass_list.push_back(tir::transform::MakePackedAPI());
   }
   mixed_pass_list.push_back(tir::transform::BF16StorageLegalize());
+
+  mixed_pass_list.push_back(tir::transform::AnnotateDeviceRegions());
   mixed_pass_list.push_back(tir::transform::SplitHostDevice());
+  mixed_pass_list.push_back(tir::transform::LowerDeviceKernelLaunch());
 
   return transform::Sequential(mixed_pass_list);
 }
diff --git a/src/tir/transforms/annotate_device_regions.cc b/src/tir/transforms/annotate_device_regions.cc
new file mode 100644
index 0000000000..a81af7d780
--- /dev/null
+++ b/src/tir/transforms/annotate_device_regions.cc
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file annotate_device_regions.cc
+ * \brief Split device function from host.
+ */
+#include <tvm/ir/transform.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/target/target.h>
+#include <tvm/tir/builtin.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+namespace tvm {
+namespace tir {
+
+class DeviceRegionAnnotater : public StmtMutator {
+ public:
+  explicit DeviceRegionAnnotater(Target device_target) : device_target_(device_target) {}
+
+  Stmt VisitStmt_(const AttrStmtNode* op) final {
+    if (op->attr_key == tvm::attr::kTarget) {
+      // If a target attribute already exists, use it as-is.
+      return GetRef<Stmt>(op);
+    } else if (op->attr_key == attr::thread_extent || op->attr_key == attr::pipeline_exec_scope ||
+               op->attr_key == attr::device_scope) {
+      // These attributes are only allowed in device-side code, so
+      // they should be annotated with the function's default target.
+      Stmt body = GetRef<Stmt>(op);
+      return AttrStmt(device_target_, tvm::attr::kTarget, 0, body);
+    } else {
+      // All other annotations are ignored
+      return StmtMutator::VisitStmt_(op);
+    }
+  }
+
+ private:
+  Target device_target_;
+};
+
+namespace transform {
+
+Pass AnnotateDeviceRegions() {
+  auto pass_func = [](PrimFunc func, IRModule mod, PassContext ctx) -> PrimFunc {
+    auto opt_target = func->GetAttr<Target>(tvm::attr::kTarget);
+    ICHECK(opt_target) << "AnnotateDeviceRegions: Require the target attribute";
+    Target target = opt_target.value();
+
+    if (target->GetHost()) {
+      DeviceRegionAnnotater mutator(target.WithoutHost());
+      func.CopyOnWrite()->body = mutator(func->body);
+    }
+    return func;
+  };
+
+  return CreatePrimFuncPass(pass_func, 0, "tir.AnnotateDeviceRegions", {});
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.AnnotateDeviceRegions").set_body_typed(AnnotateDeviceRegions);
+
+}  // namespace transform
+}  // namespace tir
+}  // namespace tvm
diff --git a/src/tir/transforms/lower_device_kernel_launch.cc b/src/tir/transforms/lower_device_kernel_launch.cc
new file mode 100644
index 0000000000..5ffbf0d7a7
--- /dev/null
+++ b/src/tir/transforms/lower_device_kernel_launch.cc
@@ -0,0 +1,305 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file lower_device_kernel_launch.cc
+ * \brief Split device function from host.
+ */
+#include <tvm/ir/transform.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/target/target.h>
+#include <tvm/tir/builtin.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+#include "../../runtime/thread_storage_scope.h"
+#include "ir_utils.h"
+
+namespace tvm {
+namespace tir {
+
+namespace {
+struct KernelInfo {
+  // The device on which the PrimFunc runs
+  Target target;
+
+  // The externally visible symbol which may refer to the PrimFunc
+  // when launching a device kernel.
+  String global_symbol;
+
+  // The parameters accepted by the PrimFunc.  Used to rewrite
+  // `launch_args` to be in terms of the calling scope.
+  Array<Var> params;
+
+  // The launch parameters that should annotate the PrimFunc, if the
+  // kernel is ever called from the host.
+  Array<String> launch_params;
+
+  // Additional arguments which must be provided to the host-side
+  // PackedFunc.  These may be in terms of the function's parameters
+  // (e.g. a function that computes the average of `N` elements, and
+  // which must be launched with `N` CUDA threads).
+  Array<PrimExpr> launch_args;
+};
+
+/*!
+ * \brief Visitor class to collect device-side program information.
+ */
+class DeviceInfoCollector : public StmtVisitor {
+ public:
+  static KernelInfo Collect(const GlobalVar& gvar, const PrimFunc& func) {
+    DeviceInfoCollector collector;
+    collector.info_.target = func->GetAttr<Target>(tvm::attr::kTarget).value().WithoutHost();
+    collector.info_.params = func->params;
+
+    collector(func->body);
+
+    // The dynamic shared memory is required to be the last of the
+    // kernel launch parameters
+    if (collector.dyn_shmem_size) {
+      collector.info_.launch_params.push_back(
+          tvm::runtime::launch_param::kUseDynamicSharedMemoryTag);
+    }
+
+    collector.info_.global_symbol =
+        func->GetAttr<String>(tvm::attr::kGlobalSymbol).value_or(gvar->name_hint);
+
+    collector.info_.launch_args = collector.info_.launch_params.Map(
+        [&](const auto& param) { return collector.GetArgument(param); });
+
+    return collector.info_;
+  }
+
+ private:
+  PrimExpr GetArgument(const String& launch_param) const {
+    if (launch_param == tvm::runtime::launch_param::kUseDynamicSharedMemoryTag) {
+      CHECK(dyn_shmem_size.defined())
+          << "Compute kernel requires launch parameter \"" << launch_param
+          << "\", but PrimFunc did not contain Allocate node with shared dynamic scope.";
+      return dyn_shmem_size.value();
+    }
+
+    auto extent = thread_extent.Get(launch_param);
+    CHECK(extent) << "Compute kernel requires launch parameter \"" << launch_param
+                  << "\", but PrimFunc does not contain AttrStmt \"" << attr::thread_extent
+                  << "\" defining this thread extent";
+    return extent.value();
+  }
+
+  void VisitStmt_(const AttrStmtNode* op) final {
+    if (op->attr_key == attr::thread_extent) {
+      IterVar iv = Downcast<IterVar>(op->node);
+      ICHECK_NE(iv->thread_tag.length(), 0U);
+      // thread_extent can appear multiple times
+      // use the first appearance as def.
+      if (!defined_thread.count(iv.get())) {
+        defined_thread.insert(iv.get());
+        info_.launch_params.push_back(iv->thread_tag);
+        thread_extent.Set(iv->thread_tag, op->value);
+      }
+    }
+
+    StmtVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const AllocateNode* op) final {
+    auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer_var));
+    if (storage_scope.rank == runtime::StorageRank::kShared && storage_scope.tag == ".dyn") {
+      ICHECK(!dyn_shmem_size.defined()) << "Only one dynamic shared memory allocation is allowed.";
+      ICHECK_GT(op->extents.size(), 0);
+
+      PrimExpr dyn_size = Integer(1);
+      for (const auto& extent : op->extents) {
+        dyn_size *= extent;
+      }
+      dyn_size *= op->dtype.bytes();
+
+      dyn_shmem_size = dyn_size;
+    }
+    StmtVisitor::VisitStmt_(op);
+  }
+
+  // The collected results
+  KernelInfo info_;
+  // recording what thread axis have been visited.
+  std::unordered_set<const IterVarNode*> defined_thread;
+  // The extent of each thread
+  Map<String, PrimExpr> thread_extent;
+  // The amount of dynamic shared memory used
+  Optional<PrimExpr> dyn_shmem_size{NullOpt};
+};
+}  // namespace
+
+class DeviceKernelMutator : public StmtExprMutator {
+ public:
+  using Parent = StmtExprMutator;
+
+  explicit DeviceKernelMutator(std::unordered_map<const GlobalVarNode*, KernelInfo> device_info_map)
+      : device_info_map_(std::move(device_info_map)) {}
+
+  PrimFunc RewriteKernelLaunchSite(const GlobalVar& gvar, PrimFunc func) {
+    ICHECK(!current_target_.defined());
+    auto it = device_info_map_.find(gvar.get());
+    ICHECK(it != device_info_map_.end());
+    current_target_ = it->second.target;
+
+    auto body = VisitStmt(func->body);
+    if (!body.same_as(func->body)) {
+      func.CopyOnWrite()->body = body;
+    }
+
+    current_target_ = NullOpt;
+    return func;
+  }
+
+  PrimFunc UpdateKernelAttributes(const GlobalVar& gvar, PrimFunc func) const {
+    if (device_kernel_launch_.count(gvar.get())) {
+      const auto& info = device_info_map_.at(gvar.get());
+
+      func = WithAttrs(std::move(func),
+                       {{tvm::attr::kCallingConv, Integer(tvm::CallingConv::kDeviceKernelLaunch)},
+                        {tvm::tir::attr::kKernelLaunchParams, info.launch_params},
+                        {tvm::attr::kGlobalSymbol, info.global_symbol},
+                        {tvm::tir::attr::kIsGlobalFunc, Bool(true)}});
+    }
+
+    return func;
+  }
+
+ private:
+  PrimExpr VisitExpr_(const CallNode* op) {
+    auto node = Downcast<Call>(Parent::VisitExpr_(op));
+
+    auto* gvar = op->op.as<GlobalVarNode>();
+    if (!gvar) return std::move(node);
+
+    auto it = device_info_map_.find(gvar);
+    ICHECK(it != device_info_map_.end())
+        << "CallNode attempted subroutine call to " << gvar->name_hint << ", but "
+        << gvar->name_hint << " did not appear within the IRModule";
+    const KernelInfo& dev_info = it->second;
+
+    auto caller_device_type = current_target_.value()->GetTargetDeviceType();
+    auto callee_device_type = dev_info.target->GetTargetDeviceType();
+    if (caller_device_type == callee_device_type) {
+      return std::move(node);
+    }
+
+    ICHECK(dev_info.launch_params.defined())
+        << "CallNode attempted kernel launch to " << gvar->name_hint << " on target "
+        << dev_info.target << ", but subroutine " << gvar->name_hint
+        << " did not have the tir::attr::kKernelLaunchParams attribute "
+        << "required for cross-target kernel launch";
+
+    // Collected kernel information may be in terms of the callee's
+    // arguments, but we need expressions for them in terms of the
+    // caller's parameters.  The param_map allows substitution of
+    // parameter values into the thread extents, to generate
+    // expressions that are valid within the caller.
+    Map<Var, PrimExpr> param_map = [&]() {
+      Map<Var, PrimExpr> param_map;
+      CHECK_EQ(node->args.size(), dev_info.params.size())
+          << "Function " << gvar->name_hint << " accepts " << dev_info.params.size()
+          << " arguments as input, but is called using " << node->args.size() << " arguments";
+      for (size_t i = 0; i < node->args.size(); i++) {
+        param_map.Set(dev_info.params[i], node->args[i]);
+      }
+      return param_map;
+    }();
+
+    device_kernel_launch_.insert(gvar);
+
+    Array<PrimExpr> call_args;
+    call_args.push_back(StringImm(dev_info.global_symbol));
+    for (PrimExpr arg : node->args) {
+      call_args.push_back(arg);
+    }
+    for (const auto& launch_arg : dev_info.launch_args) {
+      call_args.push_back(Substitute(launch_arg, param_map));
+    }
+
+    auto dtype = node->dtype.is_void() ? DataType::Int(32) : node->dtype;
+
+    return Call(dtype, builtin::tvm_call_packed(), call_args);
+  }
+
+  Optional<Target> current_target_;
+  std::unordered_map<const GlobalVarNode*, KernelInfo> device_info_map_;
+  std::unordered_set<const GlobalVarNode*> device_kernel_launch_;
+};
+
+namespace transform {
+
+Pass LowerDeviceKernelLaunch() {
+  auto pass_func = [](IRModule mod, PassContext ctx) -> IRModule {
+    auto mutator = [&mod]() {
+      std::unordered_map<const GlobalVarNode*, KernelInfo> device_info_map;
+      for (const auto& [gvar, base_func] : mod->functions) {
+        if (auto prim_func = base_func.as<PrimFunc>()) {
+          device_info_map[gvar.get()] = DeviceInfoCollector::Collect(gvar, prim_func.value());
+        }
+      }
+      return DeviceKernelMutator(std::move(device_info_map));
+    }();
+
+    {
+      IRModule updates;
+      for (const auto& [gvar, base_func] : mod->functions) {
+        if (auto* ptr = base_func.as<PrimFuncNode>()) {
+          auto prim_func = mutator.RewriteKernelLaunchSite(gvar, GetRef<PrimFunc>(ptr));
+          if (!prim_func.same_as(base_func)) {
+            updates->Add(gvar, prim_func);
+          }
+        }
+      }
+
+      if (updates->functions.size()) {
+        mod.CopyOnWrite()->Update(updates);
+      }
+    }
+
+    {
+      IRModule updates;
+      for (const auto& [gvar, base_func] : mod->functions) {
+        if (auto* ptr = base_func.as<PrimFuncNode>()) {
+          auto prim_func = mutator.UpdateKernelAttributes(gvar, GetRef<PrimFunc>(ptr));
+          if (!prim_func.same_as(base_func)) {
+            updates->Add(gvar, prim_func);
+          }
+        }
+      }
+
+      if (updates->functions.size()) {
+        mod.CopyOnWrite()->Update(updates);
+      }
+    }
+
+    return mod;
+  };
+
+  return tvm::transform::CreateModulePass(pass_func, 0, "tir.LowerDeviceKernelLaunch", {});
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.LowerDeviceKernelLaunch")
+    .set_body_typed(LowerDeviceKernelLaunch);
+
+}  // namespace transform
+}  // namespace tir
+}  // namespace tvm
diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc
index 4f47b8ce2b..9270b356ba 100644
--- a/src/tir/transforms/split_host_device.cc
+++ b/src/tir/transforms/split_host_device.cc
@@ -41,246 +41,102 @@
 namespace tvm {
 namespace tir {
 
-/*!
- * \brief Visitor class to collect device-side program information.
- */
-class DeviceInfoCollector : public StmtVisitor {
- public:
-  Array<IterVar> thread_axis_;
-  Array<PrimExpr> thread_extent_;
-  PrimExpr dyn_shmem_size_{0};
-  bool use_dyn_shmem_{false};
-
-  Array<String> GetLaunchParams() const {
-    Array<String> output;
-    for (const auto& axis : thread_axis_) {
-      output.push_back(axis->thread_tag);
-    }
-    if (use_dyn_shmem_) {
-      output.push_back(runtime::launch_param::kUseDynamicSharedMemoryTag);
-    }
-    return output;
-  }
-
- private:
-  void VisitStmt_(const AttrStmtNode* op) final {
-    if (op->attr_key == attr::thread_extent) {
-      IterVar iv = Downcast<IterVar>(op->node);
-      ICHECK_NE(iv->thread_tag.length(), 0U);
-      // thread_extent can appear multiple times
-      // use the first appearance as def.
-      if (!defined_thread.count(iv.get())) {
-        defined_thread.insert(iv.get());
-        thread_axis_.push_back(iv);
-        thread_extent_.push_back(op->value);
-      }
-
-      this->VisitExpr(op->value);
-      this->VisitStmt(op->body);
-    } else {
-      StmtVisitor::VisitStmt_(op);
-    }
-  }
-
-  void VisitStmt_(const AllocateNode* op) final {
-    auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer_var));
-    if (storage_scope.rank == runtime::StorageRank::kShared && storage_scope.tag == ".dyn") {
-      ICHECK_EQ(use_dyn_shmem_, false) << "Only one dynamic shared memory allocation is allowed.";
-      ICHECK_GT(op->extents.size(), 0);
-      dyn_shmem_size_ = op->extents[0];
-      for (size_t i = 1; i < op->extents.size(); ++i) {
-        dyn_shmem_size_ *= op->extents[i];
-      }
-      dyn_shmem_size_ = dyn_shmem_size_ * (op->dtype.bytes());
-      use_dyn_shmem_ = true;
-    }
-    StmtVisitor::VisitStmt_(op);
-  }
-
-  // recording what thread axis have been visited.
-  std::unordered_set<const IterVarNode*> defined_thread;
-};
-
-/*!
- * \brief Mutator class to remove unrefenced let stmt/expressions.
- * \param use_count The pre-computed variable to use count map.
- */
-class UnreferencedLetRemover : public StmtExprMutator {
- public:
-  explicit UnreferencedLetRemover(const std::unordered_map<const VarNode*, int>& use_count)
-      : use_count_(use_count) {}
-
- private:
-  Stmt VisitStmt_(const LetStmtNode* op) final {
-    Stmt body = this->VisitStmt(op->body);
-    // eliminate unreferenced let
-    if (use_count_.at(op->var.get()) == 0 && SideEffect(op->value) <= CallEffectKind::kReadState) {
-      return body;
-    } else {
-      PrimExpr value = this->VisitExpr(op->value);
-      if (body.same_as(op->body) && value.same_as(op->value)) {
-        return GetRef<Stmt>(op);
-      } else {
-        return LetStmt(op->var, value, body);
-      }
-    }
-  }
-
-  PrimExpr VisitExpr_(const LetNode* op) final {
-    PrimExpr body = this->VisitExpr(op->body);
-    PrimExpr value = this->VisitExpr(op->value);
-    if (use_count_.at(op->var.get()) == 0 && SideEffect(op->value) <= CallEffectKind::kReadState) {
-      return body;
-    } else {
-      if (body.same_as(op->body) && value.same_as(op->value)) {
-        return GetRef<PrimExpr>(op);
-      } else {
-        return Let(op->var, value, body);
-      }
-    }
-  }
-
-  // pre-computed variable to use count map.
-  const std::unordered_map<const VarNode*, int>& use_count_;
-};
-
 class HostDeviceSplitter : public StmtMutator {
  public:
-  explicit HostDeviceSplitter(IRModule* device_mod, Target device_target, std::string name_prefix)
-      : device_mod_(device_mod), device_target_(device_target), name_prefix_(name_prefix) {}
-
-  Stmt VisitStmt_(const AllocateNode* op) final {
-    handle_data_type_[op->buffer_var.get()] = make_const(op->dtype, 0);
-    return StmtMutator::VisitStmt_(op);
-  }
+  explicit HostDeviceSplitter(IRModule* device_mod, std::string name_prefix)
+      : device_mod_(device_mod), name_prefix_(name_prefix) {}
 
   Stmt VisitStmt_(const AttrStmtNode* op) final {
-    if (op->attr_key == attr::thread_extent || op->attr_key == attr::pipeline_exec_scope ||
-        op->attr_key == attr::device_scope) {
-      return SplitDeviceFunc(GetRef<Stmt>(op));
+    if (op->attr_key == tvm::attr::kTarget) {
+      auto device_target = op->node.as<Target>().value().WithoutHost();
+      return SplitDeviceFunc(op->body, device_target);
     }
     return StmtMutator::VisitStmt_(op);
   }
 
  private:
-  Stmt SplitDeviceFunc(Stmt body) {
-    std::ostringstream os;
-    os << name_prefix_ << "_kernel" << device_func_counter_++;
-    std::string kernel_symbol = os.str();
-    // isolate the device function.
-    VarUseDefAnalyzer use_def(/*defined_vars=*/{}, /*visit_thread_extent=*/false);
-    use_def(body);
-    DeviceInfoCollector dev_info;
-    dev_info(body);
-    UnreferencedLetRemover let_remover(use_def.use_count_);
-    body = let_remover(std::move(body));
-
-    Array<Var> params;
-    Array<PrimExpr> arguments;
-    Map<tir::Var, tir::Var> remap_vars;
-
-    // Strictly order the arguments: Var pointers, positional arguments.
-    for (Var var : use_def.undefined_) {
-      if (var.dtype().is_handle()) {
-        // Create a new version of v.
-        auto it = handle_data_type_.find(var.get());
-        if (it != handle_data_type_.end()) {
-          String storage_scope;
-          if (auto* ptr_type = var->type_annotation.as<PointerTypeNode>()) {
-            storage_scope = ptr_type->storage_scope;
-          }
-          tir::Var new_var(var->name_hint,
-                           PointerType(PrimType((*it).second->dtype), storage_scope));
-          params.push_back(new_var);
-          remap_vars.Set(var, new_var);
-        } else {
-          params.push_back(var);
-        }
-        arguments.push_back(var);
-      }
-    }
-    // positional arguments
-    for (Var var : use_def.undefined_) {
-      if (!var.dtype().is_handle()) {
-        params.push_back(var);
-        arguments.push_back(var);
-      }
-    }
-    GlobalVarSupply global_var_supply = GlobalVarSupply(*device_mod_);
-    GlobalVar kernel_symbol_global = global_var_supply->FreshGlobal(kernel_symbol, false);
-
-    PrimFunc device_func(params, Substitute(body, remap_vars));
-    device_func = WithAttr(std::move(device_func), tir::attr::kKernelLaunchParams,
-                           dev_info.GetLaunchParams());
-
-    device_func = WithAttr(std::move(device_func), tvm::attr::kCallingConv,
-                           Integer(CallingConv::kDeviceKernelLaunch));
-    device_func = WithAttr(std::move(device_func), tvm::attr::kGlobalSymbol,
-                           runtime::String(kernel_symbol_global->name_hint));
-    device_func = WithAttr(std::move(device_func), tir::attr::kNoAlias, Integer(1));
-    device_func = WithAttr(std::move(device_func), tvm::attr::kTarget, device_target_);
-    device_func = WithAttr(std::move(device_func), tir::attr::kIsGlobalFunc, Integer(1));
+  Stmt SplitDeviceFunc(Stmt body, Target device_target) {
+    Array<Var> params = [&]() {
+      VarUseDefAnalyzer use_def(/*defined_vars=*/{}, /*visit_thread_extent=*/false);
+      use_def(body);
+
+      // Sort first by variable type, then by variable name
+      std::vector<Var> params{use_def.undefined_.begin(), use_def.undefined_.end()};
+      std::sort(params.begin(), params.end(), [](const Var& a, const Var& b) {
+        auto sort_key = [](const Var& var) {
+          return std::tuple{
+              !var->dtype.is_handle(),
+              var->name_hint,
+          };
+        };
+        return sort_key(a) < sort_key(b);
+      });
+      return params;
+    }();
+
+    GlobalVar kernel_symbol_global = [&]() {
+      std::stringstream name;
+      name << name_prefix_ << "_kernel";
+      GlobalVarSupply global_var_supply = GlobalVarSupply(*device_mod_);
+      return global_var_supply->FreshGlobal(name.str(), false);
+    }();
+
+    PrimFunc device_func(params, body);
+    device_func = WithAttrs(std::move(device_func), {{tvm::attr::kTarget, device_target},
+                                                     {tir::attr::kNoAlias, Bool(true)},
+                                                     {tir::attr::kIsGlobalFunc, Bool(true)}});
 
     (*device_mod_)->Add(kernel_symbol_global, device_func);
+    Array<PrimExpr> args = params.Map([](const Var& var) -> PrimExpr { return var; });
 
-    // generate calls to the device function
-    Array<PrimExpr> call_args;
-    call_args.push_back(StringImm(kernel_symbol_global->name_hint));
-    for (PrimExpr arg : arguments) {
-      call_args.push_back(arg);
-    }
-    for (PrimExpr ext : dev_info.thread_extent_) {
-      call_args.push_back(ext);
-    }
-    if (dev_info.use_dyn_shmem_) {
-      call_args.push_back(dev_info.dyn_shmem_size_);
-    }
-    return Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(), call_args));
+    return Evaluate(Call(DataType::Void(), kernel_symbol_global, args));
   }
 
   // target ir module
   IRModule* device_mod_;
-  // Device target
-  Target device_target_;
   // function name hint
   std::string name_prefix_;
-  // Number of device functions.
-  int device_func_counter_{0};
-  std::unordered_map<const VarNode*, PrimExpr> handle_data_type_;
 };
 
-PrimFunc SplitHostDevice(PrimFunc&& func, IRModule* device_mod) {
-  auto target = func->GetAttr<Target>(tvm::attr::kTarget);
-  ICHECK(target.defined()) << "SplitHostDevice: Require the target attribute";
+PrimFunc SplitHostDevice(PrimFunc func, IRModule* device_mod, const GlobalVar& gvar) {
+  auto opt_target = func->GetAttr<Target>(tvm::attr::kTarget);
+  ICHECK(opt_target) << "SplitHostDevice: Require the target attribute";
+  Target target = opt_target.value();
+
   auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
-  ICHECK(global_symbol.defined())
-      << "SplitHostDevice: Expect PrimFunc to have the global_symbol attribute";
+  auto name_prefix = global_symbol.value_or(gvar->name_hint);
+
+  HostDeviceSplitter splitter(device_mod, name_prefix);
 
-  HostDeviceSplitter splitter(device_mod, target.value(),
-                              static_cast<std::string>(global_symbol.value()));
+  auto body = splitter(func->body);
+
+  if (!body.same_as(func->body)) {
+    func.CopyOnWrite()->body = body;
+    auto target_host = target->GetHost().value_or(Target("llvm"));
+    func = WithAttr(std::move(func), tvm::attr::kTarget, target_host);
+  }
 
-  auto* n = func.CopyOnWrite();
-  n->body = splitter(std::move(n->body));
-  // set the host target to None.
-  func = WithAttr(std::move(func), tvm::attr::kTarget, Target(nullptr));
-  return std::move(func);
+  return func;
 }
 
 namespace transform {
 
 Pass SplitHostDevice() {
   auto pass_func = [](IRModule mod, PassContext ctx) {
-    IRModuleNode* mod_ptr = mod.CopyOnWrite();
-    auto* func_dict = mod_ptr->functions.CopyOnWrite();
     IRModule device_mod = IRModule(Map<GlobalVar, BaseFunc>({}));
-
-    for (auto& kv : *func_dict) {
-      if (kv.second->IsInstance<PrimFuncNode>()) {
-        PrimFunc func = Downcast<PrimFunc>(std::move(kv.second));
-        ICHECK(device_mod.defined()) << "The device module must be defined.";
-        kv.second = SplitHostDevice(std::move(func), &device_mod);
+    IRModule updates = IRModule(Map<GlobalVar, BaseFunc>({}));
+
+    for (const auto& [gvar, base_func] : mod->functions) {
+      if (auto opt = base_func.as<PrimFunc>()) {
+        PrimFunc func = opt.value();
+        func = SplitHostDevice(std::move(func), &device_mod, gvar);
+        if (!func.same_as(base_func)) {
+          updates->Add(gvar, func);
+        }
       }
     }
+
+    mod->Update(updates);
     mod->Update(device_mod);
     return ConvertSSA()(mod);
   };
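The refactored `SplitDeviceFunc` above orders extracted kernel parameters deterministically: handle (pointer) parameters first, then scalars, each group sorted by name. A minimal plain-Python sketch of that sort key (the `Var` class and `order_kernel_params` helper are illustrative stand-ins, not TVM API):

```python
# Toy sketch (not TVM code): mimics the C++ sort key
# std::tuple{!var->dtype.is_handle(), var->name_hint}.
from dataclasses import dataclass


@dataclass(frozen=True)
class Var:
    name: str
    is_handle: bool  # True for pointer/buffer parameters


def order_kernel_params(undefined_vars):
    """Handles (pointer params) first, then scalars; each group by name."""
    return sorted(undefined_vars, key=lambda v: (not v.is_handle, v.name))


params = order_kernel_params(
    [Var("n", False), Var("A", True), Var("i", False), Var("B", True)]
)
print([v.name for v in params])  # handles A, B first, then scalars i, n
```

This replaces the old counter-based ordering, so the extracted kernel signature no longer depends on traversal order.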
diff --git a/tests/python/unittest/test_tir_transform_annotate_device_regions.py b/tests/python/unittest/test_tir_transform_annotate_device_regions.py
new file mode 100644
index 0000000000..efa43027e9
--- /dev/null
+++ b/tests/python/unittest/test_tir_transform_annotate_device_regions.py
@@ -0,0 +1,58 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+import tvm.testing
+from tvm.script import tir as T, ir as I
+
+
+class BaseCompare(tvm.testing.CompareBeforeAfter):
+    transform = tvm.tir.transform.AnnotateDeviceRegions()
+
+
+class TestAnnotateThreadExtent(BaseCompare):
+    """Annotation inserted at the "thread_extent" attribute"""
+
+    def before(A: T.Buffer(16, "float32")):
+        T.func_attr({"target": T.target("cuda", host="llvm")})
+        i = T.launch_thread("threadIdx.x", 16)
+        A[i] = 0.0
+
+    def expected(A: T.Buffer(16, "float32")):
+        T.func_attr({"target": T.target("cuda", host="llvm")})
+        T.attr(T.target("cuda"), "target", 0)
+        i = T.launch_thread("threadIdx.x", 16)
+        A[i] = 0.0
+
+
+class TestAnnotateDeviceScope(BaseCompare):
+    """Annotation inserted at the "device_scope" attribute"""
+
+    def before(A: T.Buffer(1, "float32")):
+        T.func_attr({"target": T.target("cuda", host="llvm")})
+        T.attr(0, "device_scope", 0)
+        A[0] = 0.0
+
+    def expected(A: T.Buffer(1, "float32")):
+        T.func_attr({"target": T.target("cuda", host="llvm")})
+        T.attr(T.target("cuda"), "target", 0)
+        T.attr(0, "device_scope", 0)
+        A[0] = 0.0
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
diff --git a/tests/python/unittest/test_tir_transform_device_kernel_launch.py b/tests/python/unittest/test_tir_transform_device_kernel_launch.py
new file mode 100644
index 0000000000..a0f77da376
--- /dev/null
+++ b/tests/python/unittest/test_tir_transform_device_kernel_launch.py
@@ -0,0 +1,193 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+import tvm.testing
+from tvm.script import tir as T, ir as I
+
+
+class BaseCompare(tvm.testing.CompareBeforeAfter):
+    transform = tvm.tir.transform.LowerDeviceKernelLaunch()
+
+
+class TestLowerDeviceKernelLaunch(BaseCompare):
+    """Kernel launch parameters are added at the call site
+
+    The "tir.kernel_launch_params" attribute determines which parameters
+    belong to the runtime, and which belong to the device-side PrimFunc.
+    Parameters that are required prior to launching a kernel (e.g. the
+    number of CUDA threads to use) are stored in the
+    `"tir.kernel_launch_params"` attribute, and are used by the
+    runtime in order to launch the generated kernel.
+    """
+
+    def before(self):
+        @I.ir_module
+        class mod:
+            @T.prim_func
+            def main(A: T.Buffer(1, "float32")):
+                T.func_attr({"target": T.target("llvm")})
+                mod.kernel(A.data)
+
+            @T.prim_func
+            def kernel(A_data: T.handle("float32")):
+                T.func_attr({"target": T.target("cuda")})
+                A = T.decl_buffer(1, dtype="float32", data=A_data)
+                A[0] = 0.0
+
+        return mod
+
+    def expected(self):
+        @I.ir_module
+        class mod:
+            @T.prim_func
+            def main(A: T.Buffer(1, "float32")):
+                T.func_attr({"target": T.target("llvm")})
+                T.call_packed("kernel", A.data)
+
+            @T.prim_func
+            def kernel(A_data: T.handle("float32")):
+                T.func_attr(
+                    {
+                        "target": T.target("cuda"),
+                        "calling_conv": 2,
+                        "tir.kernel_launch_params": [],
+                        "global_symbol": "kernel",
+                        "tir.is_global_func": True,
+                    }
+                )
+                A = T.decl_buffer(1, dtype="float32", data=A_data)
+                A[0] = 0.0
+
+        return mod
+
+
+class TestExternallyVisibleKernelLaunch(BaseCompare):
+    """Like TestLowerDeviceKernelLaunch, with pre-defined global_symbol
+
+    Because the host and kernel will be handled by different code
+    generators, the device-side kernel must be externally exposed for
+    use by the host-side wrapper, even if the host-side wrapper does
+    not directly expose the kernel.  Therefore, a "global_symbol"
+    attribute must be added for the kernel if not already present.
+
+    If the kernel already has a specific name, that name should be
+    preserved.
+    """
+
+    def before(self):
+        @I.ir_module
+        class mod:
+            @T.prim_func
+            def main(A: T.Buffer(1, "float32")):
+                T.func_attr({"target": T.target("llvm")})
+                mod.kernel(A.data)
+
+            @T.prim_func
+            def kernel(A_data: T.handle("float32")):
+                T.func_attr({"target": T.target("cuda"), "global_symbol": "kernel_by_another_name"})
+                A = T.decl_buffer(1, dtype="float32", data=A_data)
+                A[0] = 0.0
+
+        return mod
+
+    def expected(self):
+        @I.ir_module
+        class mod:
+            @T.prim_func
+            def main(A: T.Buffer(1, "float32")):
+                T.func_attr({"target": T.target("llvm")})
+                T.call_packed("kernel_by_another_name", A.data)
+
+            @T.prim_func
+            def kernel(A_data: T.handle("float32")):
+                T.func_attr(
+                    {
+                        "target": T.target("cuda"),
+                        "calling_conv": 2,
+                        "tir.kernel_launch_params": [],
+                        "global_symbol": "kernel_by_another_name",
+                        "tir.is_global_func": True,
+                    }
+                )
+                A = T.decl_buffer(1, dtype="float32", data=A_data)
+                A[0] = 0.0
+
+        return mod
+
+
+class TestCollectLaunchParameter(BaseCompare):
+    """Kernel launch parameters are added at the call site
+
+    The "tir.kernel_launch_params" attribute determines which parameters
+    belong to the runtime, and which belong to the device-side PrimFunc.
+    Parameters that are required prior to launching a kernel (e.g. the
+    number of CUDA threads to use) are stored in the
+    `"tir.kernel_launch_params"` attribute, and are used by the
+    runtime in order to launch the generated kernel.
+    """
+
+    def before(self):
+        @I.ir_module
+        class mod:
+            @T.prim_func
+            def main(A: T.Buffer(16, "float32")):
+                T.func_attr({"target": T.target("llvm")})
+                mod.kernel(A.data)
+
+            @T.prim_func
+            def kernel(A_data: T.handle("float32")):
+                T.func_attr(
+                    {
+                        "target": T.target("cuda"),
+                        "global_symbol": "kernel",
+                    }
+                )
+                A = T.decl_buffer(16, dtype="float32", data=A_data)
+                i = T.launch_thread("threadIdx.x", 16)
+                A[i] = 0.0
+
+        return mod
+
+    def expected(self):
+        @I.ir_module
+        class mod:
+            @T.prim_func
+            def main(A: T.Buffer(16, "float32")):
+                T.func_attr({"target": T.target("llvm")})
+                T.call_packed("kernel", A.data, 16)
+
+            @T.prim_func
+            def kernel(A_data: T.handle("float32")):
+                T.func_attr(
+                    {
+                        "target": T.target("cuda"),
+                        "calling_conv": 2,
+                        "tir.kernel_launch_params": ["threadIdx.x"],
+                        "global_symbol": "kernel",
+                        "tir.is_global_func": True,
+                    }
+                )
+                A = T.decl_buffer(16, dtype="float32", data=A_data)
+                i = T.launch_thread("threadIdx.x", 16)
+                A[i] = 0.0
+
+        return mod
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
diff --git a/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py b/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py
index 5db33a1e05..1e1ef410b4 100644
--- a/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py
+++ b/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py
@@ -201,7 +201,7 @@ expected_cuda_script = r"""
   #define int64_t long long
   #define uint64_t unsigned long long
 #endif
-extern "C" __global__ void __launch_bounds__(16) main_kernel0(float* __restrict__ A, float* __restrict__ B, float* __restrict__ C) {
+extern "C" __global__ void __launch_bounds__(16) main_kernel(float* __restrict__ A, float* __restrict__ B, float* __restrict__ C) {
   __shared__ float A_shared[64];
   __shared__ float B_shared[64];
   A_shared[((int)threadIdx.x)] = 0.000000e+00f;
diff --git a/tests/python/unittest/test_tir_transform_lower_warp_memory.py b/tests/python/unittest/test_tir_transform_lower_warp_memory.py
index d4abc26bb2..c7e90d4e7d 100644
--- a/tests/python/unittest/test_tir_transform_lower_warp_memory.py
+++ b/tests/python/unittest/test_tir_transform_lower_warp_memory.py
@@ -22,6 +22,16 @@ from tvm import te
 from tvm.contrib.nvcc import have_fp16
 
 
+def _run_passes(mod):
+    cuda_target = tvm.target.Target("cuda", host="llvm")
+    assert cuda_target.thread_warp_size == 32
+    mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", cuda_target))(mod)
+    mod = tvm.tir.transform.AnnotateDeviceRegions()(mod)
+    mod = tvm.tir.transform.SplitHostDevice()(mod)
+    mod = tvm.tir.transform.LowerWarpMemory()(mod)
+    return mod
+
+
 @tvm.testing.requires_cuda
 def test_lower_warp_memory_local_scope():
     m = 128
@@ -39,16 +49,12 @@ def test_lower_warp_memory_local_scope():
     xo, xi = s[AA].split(s[AA].op.axis[0], 32)
     s[AA].bind(xi, tx)
 
-    cuda_target = tvm.target.Target("cuda")
-    assert cuda_target.thread_warp_size == 32
     # lowering with the CSE pass disabled as otherwise it would do some commoning
     with tvm.transform.PassContext(opt_level=3, disabled_pass=["tir.CommonSubexprElimTIR"]):
         mod = tvm.lower(s, [A, B], name="f")
 
-    mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", cuda_target))(mod)
-    fdevice = tvm.tir.transform.SplitHostDevice()(mod)["f_kernel0"]
-    mod = tvm.IRModule.from_expr(fdevice)
-    fdevice = tvm.tir.transform.LowerWarpMemory()(mod)["f_kernel0"]
+    mod = _run_passes(mod)
+    fdevice = mod["f_kernel"]
     allocate = fdevice.body.body
     assert allocate.buffer_var.type_annotation.storage_scope == "local"
     assert fdevice.body.body.extents[0].value == 2
@@ -103,7 +109,7 @@ def test_lower_warp_memory_cuda_end_to_end():
         A = te.placeholder((m,), name="A", dtype=dtype)
         B = te.compute((m,), lambda i: A[i // 32 * 32 + (i + 1) % 32], name="B")
 
-        cuda_target = tvm.target.Target("cuda")
+        cuda_target = tvm.target.Target("cuda", host="llvm")
         assert cuda_target.thread_warp_size == 32
         with cuda_target:
             s = te.create_schedule(B.op)
@@ -168,7 +174,7 @@ def test_lower_warp_memory_cuda_half_a_warp():
             name="B",
         )
 
-        cuda_target = tvm.target.Target("cuda")
+        cuda_target = tvm.target.Target("cuda", host="llvm")
         assert cuda_target.thread_warp_size == 2 * m
         with cuda_target:
             s = te.create_schedule(B.op)
@@ -214,7 +220,7 @@ def test_lower_warp_memory_cuda_2_buffers():
         B = te.placeholder((m,), name="B", dtype=dtype)
         C = te.compute((m,), lambda i: A[(i + 1) % m] + B[(i + 1) % m], name="C")
 
-        cuda_target = tvm.target.Target("cuda")
+        cuda_target = tvm.target.Target("cuda", host="llvm")
         assert m <= cuda_target.thread_warp_size
         with cuda_target:
             s = te.create_schedule(C.op)
@@ -310,15 +316,12 @@ def test_lower_warp_memory_same_thread():
     xo, xi = s[BB].split(s[BB].op.axis[0], factor=32)
     s[BB].bind(xi, tx)
 
-    cuda_target = tvm.target.Target("cuda")
-    assert cuda_target.thread_warp_size == 32
     # lowering with the CSE pass disabled as otherwise it would do some commoning
     with tvm.transform.PassContext(opt_level=3, disabled_pass=["tir.CommonSubexprElimTIR"]):
         mod = tvm.lower(s, [A, B], name="f")
-    mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", cuda_target))(mod)
-    fdevice = tvm.tir.transform.SplitHostDevice()(mod)["f_kernel0"]
-    mod = tvm.IRModule.from_expr(fdevice)
-    fdevice = tvm.tir.transform.LowerWarpMemory()(mod)["f_kernel0"]
+
+    mod = _run_passes(mod)
+    fdevice = mod["f_kernel"]
     assert "tvm_warp_shuffle" not in fdevice.script()
 
 
@@ -338,13 +341,11 @@ def test_lower_warp_memory_divide_by_factor():
     stmt = ib.get()
     func = tvm.tir.PrimFunc([], stmt)
     func = func.with_attr("from_legacy_te_schedule", True)
-    cuda_target = tvm.target.Target("cuda")
     # lowering with the CSE pass disabled as otherwise it would do some commoning
     with tvm.transform.PassContext(opt_level=3, disabled_pass=["tir.CommonSubexprElimTIR"]):
         mod = tvm.lower(func, name="f")
-    mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", cuda_target))(mod)
     with pytest.raises(tvm.error.TVMError, match="Divide by zero") as cm:
-        tvm.tir.transform.LowerWarpMemory()(mod)["f_kernel0"]
+        _run_passes(mod)
 
 
 if __name__ == "__main__":
diff --git a/tests/python/unittest/test_tir_transform_split_host_device.py b/tests/python/unittest/test_tir_transform_split_host_device.py
index 680f23e07a..cf866ae005 100644
--- a/tests/python/unittest/test_tir_transform_split_host_device.py
+++ b/tests/python/unittest/test_tir_transform_split_host_device.py
@@ -35,17 +35,26 @@ def test_split_host_device_func_attr():
     s[A1].compute_at(s[A2], xo)
     s[A1].set_scope("shared")
 
-    mod = tvm.lower(s, [A, A2], name="f")
+    mod = tvm.lower(s, [A, A2])
 
-    cuda_target = tvm.target.Target("cuda")
+    cuda_target = tvm.target.Target("cuda", host="llvm")
     mod = tvm.tir.transform.Apply(
         lambda f: f.with_attr({"global_symbol": "test", "target": cuda_target})
     )(mod)
-    fdevice = tvm.tir.transform.SplitHostDevice()(mod)["test_kernel0"]
 
-    assert fdevice.attrs["global_symbol"] == "test_kernel0"
+    mod = tvm.ir.transform.Sequential(
+        [
+            tvm.tir.transform.AnnotateDeviceRegions(),
+            tvm.tir.transform.SplitHostDevice(),
+            tvm.tir.transform.LowerDeviceKernelLaunch(),
+        ]
+    )(mod)
+
+    fdevice = mod["test_kernel"]
+
+    assert fdevice.attrs["global_symbol"] == "test_kernel"
     assert fdevice.attrs["calling_conv"].value == 2
-    assert fdevice.attrs["target"] == cuda_target
+    assert str(fdevice.attrs["target"]) == str(tvm.target.Target("cuda"))
     assert fdevice.attrs["tir.is_global_func"].value
 
 
@@ -60,18 +69,104 @@ def test_ssa_across_entire_module():
     class before:
         @T.prim_func
         def main():
-            T.func_attr({"global_symbol": "main", "target": T.target("cuda")})
+            T.func_attr({"global_symbol": "main", "target": T.target("cuda", host="llvm")})
             for i in range(16):
                 T.attr(0, "device_scope", 0)
                 for j in range(16):
                     T.evaluate(i)
 
-    after = tvm.tir.transform.SplitHostDevice()(before)
+    after = tvm.ir.transform.Sequential(
+        [
+            tvm.tir.transform.AnnotateDeviceRegions(),
+            tvm.tir.transform.SplitHostDevice(),
+            tvm.tir.transform.LowerDeviceKernelLaunch(),
+        ]
+    )(before)
     loop_var = after["main"].body.loop_var
-    param_var = after["main_kernel0"].params[0]
+    param_var = after["main_kernel"].params[0]
 
     assert not loop_var.same_as(param_var)
 
 
+class BaseCompare(tvm.testing.CompareBeforeAfter):
+    transform = tvm.tir.transform.SplitHostDevice()
+
+
+class TestSplitHostDevice(BaseCompare):
+    """SplitHostDevice divides a function at the "target" attribute"""
+
+    def before(self):
+        @I.ir_module
+        class mod:
+            @T.prim_func
+            def main(n: T.int32):
+                T.func_attr({"target": T.target("cuda", host="llvm -opt-level=0")})
+                T.attr(T.target("cuda"), "target", 0)
+                T.evaluate(n)
+
+        return mod
+
+    def expected(self):
+        @I.ir_module
+        class mod:
+            @T.prim_func
+            def main(n: T.int32):
+                T.func_attr({"target": T.target("llvm -opt-level=0")})
+                mod.main_kernel(n)
+
+            @T.prim_func
+            def main_kernel(n: T.int32):
+                T.func_attr(
+                    {
+                        "target": T.target("cuda"),
+                        "tir.noalias": T.bool(True),
+                        "tir.is_global_func": True,
+                    }
+                )
+                T.evaluate(n)
+
+        return mod
+
+
+class TestSplitHostDeviceWithoutFuncHostAttribute(BaseCompare):
+    """Like TestSplitHostDevice, but with no host specified in the function's target
+
+    The `T.attr` specifying the device still requires splitting out
+    the kernel.
+    """
+
+    def before(self):
+        @I.ir_module
+        class mod:
+            @T.prim_func
+            def main(n: T.int32):
+                T.func_attr({"target": T.target("llvm")})
+                T.attr(T.target("cuda"), "target", 0)
+                T.evaluate(n)
+
+        return mod
+
+    def expected(self):
+        @I.ir_module
+        class mod:
+            @T.prim_func
+            def main(n: T.int32):
+                T.func_attr({"target": T.target("llvm")})
+                mod.main_kernel(n)
+
+            @T.prim_func
+            def main_kernel(n: T.int32):
+                T.func_attr(
+                    {
+                        "target": T.target("cuda"),
+                        "tir.noalias": T.bool(True),
+                        "tir.is_global_func": True,
+                    }
+                )
+                T.evaluate(n)
+
+        return mod
+
+
 if __name__ == "__main__":
-    test_split_host_device_func_attr()
+    tvm.testing.main()
diff --git a/tests/python/unittest/test_tir_transform_thread_sync.py b/tests/python/unittest/test_tir_transform_thread_sync.py
index eb578a8817..57ea223cf9 100644
--- a/tests/python/unittest/test_tir_transform_thread_sync.py
+++ b/tests/python/unittest/test_tir_transform_thread_sync.py
@@ -24,12 +24,13 @@ def run_passes(func: tvm.tir.PrimFunc):
     mod = tvm.IRModule.from_expr(func)
     mod = tvm.tir.transform.StorageFlatten(64)(mod)
 
-    cuda_target = tvm.target.Target("cuda")
+    cuda_target = tvm.target.Target("cuda", host="llvm")
 
     mod = tvm.tir.transform.Apply(
         lambda f: f.with_attr({"global_symbol": "test", "target": cuda_target})
     )(mod)
 
+    mod = tvm.tir.transform.AnnotateDeviceRegions()(mod)
     mod = tvm.tir.transform.SplitHostDevice()(mod)
     return tvm.tir.transform.ThreadSync("shared")(mod)
 
@@ -55,7 +56,7 @@ def test_thread_storage_sync():
 
     func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, A2], stmt, None)
     mod = run_passes(func)
-    f = mod["test_kernel0"]
+    f = mod["test_kernel"]
     body_list = tvm.tir.stmt_list(f.body.body.body)
     assert body_list[1].value.op.same_as(tvm.ir.Op.get("tir.tvm_storage_sync"))
 

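The tests above drive the new three-stage pipeline (AnnotateDeviceRegions, then SplitHostDevice, then LowerDeviceKernelLaunch). The core SplitHostDevice step can be modeled in plain Python, independent of TVM: an annotated device region is moved into a fresh subroutine and replaced by a direct call. This is only a toy sketch; the `split_host_device` function and the tuple-based "IR" below are illustrative, not TVM data structures.

```python
# Toy model (plain Python, not TVM): SplitHostDevice extracts an
# annotated device region into its own "<name>_kernel" function and
# leaves a call behind in the host body.

def split_host_device(module, func_name):
    """Move each ("device", body, free_vars) region of `func_name` into a
    fresh `<func_name>_kernel` entry, replacing it with ("call", ...)."""
    host_body = []
    for stmt in module[func_name]:
        if isinstance(stmt, tuple) and stmt[0] == "device":
            _, region, free_vars = stmt
            kernel_name = f"{func_name}_kernel"
            module[kernel_name] = region  # extracted device subroutine
            host_body.append(("call", kernel_name, free_vars))
        else:
            host_body.append(stmt)
    module[func_name] = host_body
    return module


mod = {"main": [("device", ["A[i] = 0.0"], ["A", "i"]), "return"]}
mod = split_host_device(mod, "main")
# "main" now calls "main_kernel"; "main_kernel" holds the extracted region
```

In the real pass, LowerDeviceKernelLaunch then rewrites such a cross-device call into a packed kernel launch, as the `T.call_packed` expectations in the tests show.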