This is an automated email from the ASF dual-hosted git repository.
csullivan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 6eb0779442 [TIR] SplitHostDevice, handle subroutines (#14918)
6eb0779442 is described below
commit 6eb077944295b96d70531db9b7048f2e87af1cfc
Author: Eric Lunderberg <[email protected]>
AuthorDate: Fri May 26 15:22:12 2023 -0500
[TIR] SplitHostDevice, handle subroutines (#14918)
This PR refactors SplitHostDevice into three separate transformations.
Previously, SplitHostDevice replaced each device region with a
builtin::tvm_call_packed() node that called into the extracted region. After
this PR, this process is performed in three separate steps.
AnnotateDeviceRegion: Annotate the regions that should be executed on
another target.
SplitHostDevice: Extract the annotated region into an independent PrimFunc,
with a GlobalVar representing the call into the new subroutine.
LowerDeviceKernelLaunch: For any subroutine call where the caller and
callee are on different devices, replace the call with a device kernel launch.
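Taken together, the three steps can be sketched in plain Python. This is a toy model only: dicts stand in for the IRModule and its PrimFuncs, and names like `uses_threads` are illustrative placeholders, not TVM APIs.

```python
# Toy model of the three-stage lowering. Dicts stand in for PrimFuncs and
# the IRModule; "uses_threads" marks a region that must run on the device.

def annotate_device_regions(func, device_target):
    """Step 1: tag device-only regions with a target annotation."""
    for region in func["regions"]:
        if region.get("uses_threads"):
            region["target"] = device_target
    return func

def split_host_device(module, func_name):
    """Step 2: extract each tagged region into its own kernel function,
    leaving an ordinary subroutine call at the original site.
    (A real pass would number multiple kernels uniquely.)"""
    func = module[func_name]
    for i, region in enumerate(func["regions"]):
        if "target" in region:
            kernel_name = f"{func_name}_kernel"
            module[kernel_name] = {"regions": [region],
                                   "target": region["target"]}
            func["regions"][i] = {"call": kernel_name}
    return module

def lower_device_kernel_launch(module, host_target):
    """Step 3: rewrite calls that cross a device boundary into
    packed-call kernel launches."""
    for func in module.values():
        for region in func["regions"]:
            callee = region.get("call")
            if callee and module[callee]["target"] != host_target:
                region["call"] = ("tvm_call_packed", callee)
    return module
```

For example, a module whose `main` contains one threaded region ends up with a separate `main_kernel` function, and the call site in `main` becomes a packed call.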
* PR#14915 [TVMScript] Allow T.target("device", host="host") in TVMScript
Prior to this commit, the `TargetNode::host` could be specified in
TVMScript as part of the config dictionary, under the key `"host"`.
However, this required all other device parameters to be explicitly
specified, rather than using any of the shorthand string
representations. This commit forwards the `host` argument from TVMScript's
`T.target` method to `tvm.target.Target`, allowing both the device and
host to be specified using the shorthand string representation.
```python
@T.prim_func
def before_this_commit():
    T.func_attr(
        {
            "target": T.target(
                {
                    "arch": "sm_86",
                    "host": {"keys": ["cpu"], "kind": "llvm", "tag": ""},
                    "keys": ["cuda", "gpu"],
                    "kind": "cuda",
                    "max_num_threads": 1024,
                    "tag": "",
                    "thread_warp_size": 32,
                }
            )
        }
    )
    T.evaluate(0)


@T.prim_func
def after_this_commit():
    T.func_attr({"target": T.target("cuda", host="llvm")})
    T.evaluate(0)
```
* [Target] Added WithoutHost method
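In terms of the JSON-style target dictionary shown in the example above, `WithoutHost` can be pictured as follows. This is an illustrative sketch over a plain dict, not the actual C++ `TargetNode` implementation:

```python
def without_host(target_config: dict) -> dict:
    """Return the device portion of a target config, dropping the
    "host" entry (sketch of Target::WithoutHost)."""
    return {k: v for k, v in target_config.items() if k != "host"}

# A cut-down config in the style of the TVMScript example above.
cuda_with_host = {"kind": "cuda", "arch": "sm_86",
                  "host": {"kind": "llvm", "keys": ["cpu"]}}
```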
* [TIR] SplitHostDevice, handle missing kGlobalSymbol
Previously, the symbol name of the extracted compute kernel was
defined based on the `kGlobalSymbol` attribute, which was required to
be present. This commit updates `SplitHostDevice` to generate the
symbol name using `kGlobalSymbol` if present, and to fall back to the
name of the `tvm::GlobalVar` for internal functions.
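The fallback rule corresponds roughly to the following sketch, where `attrs` stands in for the PrimFunc's attribute dictionary:

```python
def kernel_name_prefix(attrs: dict, global_var_name: str) -> str:
    # Prefer the kGlobalSymbol attribute ("global_symbol") when present;
    # otherwise fall back to the GlobalVar's name for internal functions.
    return attrs.get("global_symbol", global_var_name)
```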
* [TIR] Refactor SplitHostDevice into three separate passes
First pass, `AnnotateDeviceRegions`. This pass decides which portions
of a PrimFunc should be run on the device, and annotates them with the
`kTarget` attribute, indicating which target should be used for later
lowering steps.
Second pass, `SplitHostDevice`. This pass extracts the annotated
region into an independent PrimFunc. The `kTarget` attribute of the
extracted kernel is defined by the `kTarget` annotation inserted by
`AnnotateDeviceRegions`. The host function is marked by the
`tvm::tir::attr::kIsHostFunc` attribute, allowing it to be recognized
by later host-only lowering passes.
Third pass, `LowerDeviceKernelLaunch`. This pass identifies
subroutine calls that call into device kernels, and rewrites them into
`T.tvm_call_packed`.
* Add unit tests specifically for SplitHostDevice behavior
* Added unit test specifically for AnnotateDeviceRegions
* Added unit tests for LowerDeviceKernelLaunch
* Minor cleanup, moved all kernel launch collection into one spot
Previously, the SplitHostDevice pass added the
`tir::attr::kKernelLaunchParams` attribute, and the
LowerDeviceKernelLaunch pass filled in the values for it. This
cleanup makes the kernel launch params be the sole responsibility of
LowerDeviceKernelLaunch.
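The launch parameters that LowerDeviceKernelLaunch now attaches can be sketched as below. This is illustrative only: the real collector walks `thread_extent` AttrStmts in the kernel body, and the dynamic shared-memory tag (assumed here to be the string `"tir.use_dyn_shared_memory"`) is required to come last.

```python
def collect_launch_params(thread_tags, uses_dyn_shmem):
    """Order the kernel launch parameters: thread axes first, then the
    dynamic shared-memory tag last, if dynamic shared memory is used."""
    params = list(thread_tags)  # e.g. ["blockIdx.x", "threadIdx.x"]
    if uses_dyn_shmem:
        # Stand-in for kUseDynamicSharedMemoryTag.
        params.append("tir.use_dyn_shared_memory")
    return params
```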
* Updated unit tests for LowerWarpMemory
* Updated unit tests for ThreadSync
* Updated unit test for inject ptx async copy
* [Bugfix] Avoid symbol conflicts in MakePackedAPI/MakeUnpackedAPI
PRs https://github.com/apache/tvm/pull/14913 and
https://github.com/apache/tvm/pull/14914 made analogous changes to
`MakePackedAPI` and `MakeUnpackedAPI` to handle subroutine calls.
Both PRs introduced the same symbol,
`tvm::tir::SubroutineCallRewriter`, a local utility to update internal
calls to a modified function. While each PR passed CI individually,
and was therefore able to merge, having both changes caused a
duplicate symbol.
This commit updates `MakePackedAPI` and `MakeUnpackedAPI` to place
their local utilities into anonymous namespaces, avoiding the
conflict.
* Maintain "tir.is_global_func" attr in device-side entry point
* SplitHostDevice, update the host-side target to be the host
* [TIR] Update LowerDeviceKernelLaunch to avoid kIsHostFunc
Update to use the `tvm::tir::IsHostFunc` utility function, rather than
the `kIsHostFunc` attribute. Per discussion on
https://github.com/apache/tvm/pull/14020, the `kIsHostFunc` attribute
should only be used in `BindTarget`, and should not be re-introduced
in `SplitHostDevice`.
* Remove is_host_func from SplitHostDevice tests
---
include/tvm/tir/transform.h | 38 +++
python/tvm/tir/op.py | 2 +-
python/tvm/tir/transform/transform.py | 38 +++
src/driver/driver_api.cc | 3 +
src/tir/transforms/annotate_device_regions.cc | 81 ++++++
src/tir/transforms/lower_device_kernel_launch.cc | 305 +++++++++++++++++++++
src/tir/transforms/split_host_device.cc | 272 +++++-------------
.../test_tir_transform_annotate_device_regions.py | 58 ++++
.../test_tir_transform_device_kernel_launch.py | 193 +++++++++++++
.../test_tir_transform_inject_ptx_async_copy.py | 2 +-
.../test_tir_transform_lower_warp_memory.py | 37 +--
.../test_tir_transform_split_host_device.py | 113 +++++++-
.../unittest/test_tir_transform_thread_sync.py | 5 +-
13 files changed, 908 insertions(+), 239 deletions(-)
diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h
index 8dee176277..d9d68e0a8b 100644
--- a/include/tvm/tir/transform.h
+++ b/include/tvm/tir/transform.h
@@ -263,13 +263,51 @@ TVM_DLL Pass LowerCustomDatatypes();
*/
TVM_DLL Pass DecorateDeviceScope();
+/*!
+ * \brief Annotate locations that should be run on the device
+ *
+ * Insert `AttrStmt` nodes specifying a target on which regions within
+ * the PrimFunc should be executed. Only modifies functions that have
+ * a `tvm::attr::kTarget` attribute, and where that target defines a
+ * host.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass AnnotateDeviceRegions();
+
/*!
* \brief Split the function into a host function and device functions.
*
+ * The resulting host-side function will keep the same
+ * `tvm::attr::kTarget` attribute (e.g. `T.target("cuda",
+ * host=T.target("llvm"))`). This ensures that `MakePackedAPI` knows
+ * which device type should be used for the input buffers.
+ *
+ * The resulting device-side function will
+ * have the host stripped from its target attribute
+ * (e.g. `T.target("cuda")`).
+ *
* \return The pass.
*/
TVM_DLL Pass SplitHostDevice();
+/*!
+ * \brief Lower cross-device function calls.
+ *
+ * Prior to this pass, host to device calls are represented as
+ * subroutine calls, with environment parameters (e.g. env_thread)
+ * specified internally. The device function is an internal function,
+ * without a `tvm::attr::kGlobalSymbol` attribute.
+ *
+ * After this pass, host to device calls are represented as
+ * tvm_call_packed built-in. The device function is an
+ * externally-exposed function, with a non-empty
+ * `tvm::attr::kGlobalSymbol` attribute.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass LowerDeviceKernelLaunch();
+
/*!
* \brief skip assert stmt.
*
diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py
index 90e3db4cb9..098c13f04e 100644
--- a/python/tvm/tir/op.py
+++ b/python/tvm/tir/op.py
@@ -445,7 +445,7 @@ def call_tir(global_var: tvm.ir.GlobalVar, *args):
The call expression.
"""
assert isinstance(global_var, tvm.ir.GlobalVar)
- return Call(dtype="handle", op=global_var, args=args)
+ return Call(dtype="void", op=global_var, args=args)
def start_profile_intrinsic(id):
diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py
index f2ce437814..9e038f618b 100644
--- a/python/tvm/tir/transform/transform.py
+++ b/python/tvm/tir/transform/transform.py
@@ -435,6 +435,22 @@ def MakeUnpackedAPI():
return _ffi_api.MakeUnpackedAPI() # type: ignore
+def AnnotateDeviceRegions():
+ """Annotate locations that should be run on the device
+
+ Insert `AttrStmt` nodes specifying a target on which regions
+ within the PrimFunc should be executed. Only modifies functions
+ that have a `tvm::attr::kTarget` attribute, and where that target
+ defines a host.
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The result pass
+ """
+ return _ffi_api.AnnotateDeviceRegions() # type: ignore
+
+
def SplitHostDevice():
"""Split the function into a host function and device functions.
@@ -446,6 +462,28 @@ def SplitHostDevice():
return _ffi_api.SplitHostDevice() # type: ignore
+def LowerDeviceKernelLaunch():
+ """Lower cross-device function calls.
+
+ Prior to this pass, host to device calls are represented as
+ subroutine calls, with environment parameters (e.g. env_thread)
+ specified internally. The device function is an internal
+ function, without a `tvm::attr::kGlobalSymbol` attribute.
+
+ After this pass, host to device calls are represented as
+ tvm_call_packed built-in. The device function is an
+ externally-exposed function, with a non-empty
+ `tvm::attr::kGlobalSymbol` attribute.
+
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The result pass
+ """
+ return _ffi_api.LowerDeviceKernelLaunch() # type: ignore
+
+
def DecorateDeviceScope():
"""Decorate all the function's body as device function.
diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc
index 91bc57ccbe..e5f71c3832 100644
--- a/src/driver/driver_api.cc
+++ b/src/driver/driver_api.cc
@@ -587,7 +587,10 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target)
mixed_pass_list.push_back(tir::transform::MakePackedAPI());
}
mixed_pass_list.push_back(tir::transform::BF16StorageLegalize());
+
+ mixed_pass_list.push_back(tir::transform::AnnotateDeviceRegions());
mixed_pass_list.push_back(tir::transform::SplitHostDevice());
+ mixed_pass_list.push_back(tir::transform::LowerDeviceKernelLaunch());
return transform::Sequential(mixed_pass_list);
}
diff --git a/src/tir/transforms/annotate_device_regions.cc b/src/tir/transforms/annotate_device_regions.cc
new file mode 100644
index 0000000000..a81af7d780
--- /dev/null
+++ b/src/tir/transforms/annotate_device_regions.cc
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file annotate_device_regions.cc
+ * \brief Split device function from host.
+ */
+#include <tvm/ir/transform.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/target/target.h>
+#include <tvm/tir/builtin.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+namespace tvm {
+namespace tir {
+
+class DeviceRegionAnnotater : public StmtMutator {
+ public:
+ explicit DeviceRegionAnnotater(Target device_target) :
device_target_(device_target) {}
+
+ Stmt VisitStmt_(const AttrStmtNode* op) final {
+ if (op->attr_key == tvm::attr::kTarget) {
+ // If a target attribute already exists, use it as-is.
+ return GetRef<Stmt>(op);
+ } else if (op->attr_key == attr::thread_extent || op->attr_key ==
attr::pipeline_exec_scope ||
+ op->attr_key == attr::device_scope) {
+ // These attributes are only allowed in device-side code, so
+ // they should be annotated with the function's default target.
+ Stmt body = GetRef<Stmt>(op);
+ return AttrStmt(device_target_, tvm::attr::kTarget, 0, body);
+ } else {
+ // All other annotations are ignored
+ return StmtMutator::VisitStmt_(op);
+ }
+ }
+
+ private:
+ Target device_target_;
+};
+
+namespace transform {
+
+Pass AnnotateDeviceRegions() {
+ auto pass_func = [](PrimFunc func, IRModule mod, PassContext ctx) ->
PrimFunc {
+ auto opt_target = func->GetAttr<Target>(tvm::attr::kTarget);
+ ICHECK(opt_target) << "AnnotateDeviceRegions: Require the target
attribute";
+ Target target = opt_target.value();
+
+ if (target->GetHost()) {
+ DeviceRegionAnnotater mutator(target.WithoutHost());
+ func.CopyOnWrite()->body = mutator(func->body);
+ }
+ return func;
+ };
+
+ return CreatePrimFuncPass(pass_func, 0, "tir.AnnotateDeviceRegions", {});
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.AnnotateDeviceRegions").set_body_typed(AnnotateDeviceRegions);
+
+} // namespace transform
+} // namespace tir
+} // namespace tvm
diff --git a/src/tir/transforms/lower_device_kernel_launch.cc b/src/tir/transforms/lower_device_kernel_launch.cc
new file mode 100644
index 0000000000..5ffbf0d7a7
--- /dev/null
+++ b/src/tir/transforms/lower_device_kernel_launch.cc
@@ -0,0 +1,305 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file lower_device_kernel_launch.cc
+ * \brief Split device function from host.
+ */
+#include <tvm/ir/transform.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/target/target.h>
+#include <tvm/tir/builtin.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+#include "../../runtime/thread_storage_scope.h"
+#include "ir_utils.h"
+
+namespace tvm {
+namespace tir {
+
+namespace {
+struct KernelInfo {
+ // The device on which the PrimFunc runs
+ Target target;
+
+ // The externally visible symbol which may refer to the PrimFunc
+ // when launching a device kernel.
+ String global_symbol;
+
+ // The parameters accepted by the PrimFunc. Used to rewrite
+ // `launch_args` to be in terms of the calling scope.
+ Array<Var> params;
+
+ // The launch parameters that should annotate the PrimFunc, if the
+ // kernel is ever called from the host.
+ Array<String> launch_params;
+
+ // Additional arguments which must be provided to the host-side
+ // PackedFunc. These may be in terms of the function's parameters
+ // (e.g. a function that computes the average of `N` elements, and
+ // which must be launched with `N` CUDA threads).
+ Array<PrimExpr> launch_args;
+};
+
+/*!
+ * \brief Visitor class to collect device-side program information.
+ */
+class DeviceInfoCollector : public StmtVisitor {
+ public:
+ static KernelInfo Collect(const GlobalVar& gvar, const PrimFunc& func) {
+ DeviceInfoCollector collector;
+ collector.info_.target =
func->GetAttr<Target>(tvm::attr::kTarget).value().WithoutHost();
+ collector.info_.params = func->params;
+
+ collector(func->body);
+
+ // The dynamic shared memory is required to be the last of the
+ // kernel launch parameters
+ if (collector.dyn_shmem_size) {
+ collector.info_.launch_params.push_back(
+ tvm::runtime::launch_param::kUseDynamicSharedMemoryTag);
+ }
+
+ collector.info_.global_symbol =
+
func->GetAttr<String>(tvm::attr::kGlobalSymbol).value_or(gvar->name_hint);
+
+ collector.info_.launch_args = collector.info_.launch_params.Map(
+ [&](const auto& param) { return collector.GetArgument(param); });
+
+ return collector.info_;
+ }
+
+ private:
+ PrimExpr GetArgument(const String& launch_param) const {
+ if (launch_param ==
tvm::runtime::launch_param::kUseDynamicSharedMemoryTag) {
+ CHECK(dyn_shmem_size.defined())
+ << "Compute kernel requires launch parameter \"" << launch_param
+ << "\", but PrimFunc did not contain Allocate node with shared
dynamic scope.";
+ return dyn_shmem_size.value();
+ }
+
+ auto extent = thread_extent.Get(launch_param);
+ CHECK(extent) << "Compute kernel requires launch parameter \"" <<
launch_param
+ << "\", but PrimFunc does not contain AttrStmt \"" <<
attr::thread_extent
+ << "\" defining this thread extent";
+ return extent.value();
+ }
+
+ void VisitStmt_(const AttrStmtNode* op) final {
+ if (op->attr_key == attr::thread_extent) {
+ IterVar iv = Downcast<IterVar>(op->node);
+ ICHECK_NE(iv->thread_tag.length(), 0U);
+ // thread_extent can appear multiple times
+ // use the first appearance as def.
+ if (!defined_thread.count(iv.get())) {
+ defined_thread.insert(iv.get());
+ info_.launch_params.push_back(iv->thread_tag);
+ thread_extent.Set(iv->thread_tag, op->value);
+ }
+ }
+
+ StmtVisitor::VisitStmt_(op);
+ }
+
+ void VisitStmt_(const AllocateNode* op) final {
+ auto storage_scope =
runtime::StorageScope::Create(GetPtrStorageScope(op->buffer_var));
+ if (storage_scope.rank == runtime::StorageRank::kShared &&
storage_scope.tag == ".dyn") {
+ ICHECK(!dyn_shmem_size.defined()) << "Only one dynamic shared memory
allocation is allowed.";
+ ICHECK_GT(op->extents.size(), 0);
+
+ PrimExpr dyn_size = Integer(1);
+ for (const auto& extent : op->extents) {
+ dyn_size *= extent;
+ }
+ dyn_size *= op->dtype.bytes();
+
+ dyn_shmem_size = dyn_size;
+ }
+ StmtVisitor::VisitStmt_(op);
+ }
+
+ // The collected results
+ KernelInfo info_;
+ // recording what thread axis have been visited.
+ std::unordered_set<const IterVarNode*> defined_thread;
+ // The extent of each thread
+ Map<String, PrimExpr> thread_extent;
+ // The amount of dynamic shared memory used
+ Optional<PrimExpr> dyn_shmem_size{NullOpt};
+};
+} // namespace
+
+class DeviceKernelMutator : public StmtExprMutator {
+ public:
+ using Parent = StmtExprMutator;
+
+ explicit DeviceKernelMutator(std::unordered_map<const GlobalVarNode*,
KernelInfo> device_info_map)
+ : device_info_map_(std::move(device_info_map)) {}
+
+ PrimFunc RewriteKernelLaunchSite(const GlobalVar& gvar, PrimFunc func) {
+ ICHECK(!current_target_.defined());
+ auto it = device_info_map_.find(gvar.get());
+ ICHECK(it != device_info_map_.end());
+ current_target_ = it->second.target;
+
+ auto body = VisitStmt(func->body);
+ if (!body.same_as(func->body)) {
+ func.CopyOnWrite()->body = body;
+ }
+
+ current_target_ = NullOpt;
+ return func;
+ }
+
+ PrimFunc UpdateKernelAttributes(const GlobalVar& gvar, PrimFunc func) const {
+ if (device_kernel_launch_.count(gvar.get())) {
+ const auto& info = device_info_map_.at(gvar.get());
+
+ func = WithAttrs(std::move(func),
+ {{tvm::attr::kCallingConv,
Integer(tvm::CallingConv::kDeviceKernelLaunch)},
+ {tvm::tir::attr::kKernelLaunchParams,
info.launch_params},
+ {tvm::attr::kGlobalSymbol, info.global_symbol},
+ {tvm::tir::attr::kIsGlobalFunc, Bool(true)}});
+ }
+
+ return func;
+ }
+
+ private:
+ PrimExpr VisitExpr_(const CallNode* op) {
+ auto node = Downcast<Call>(Parent::VisitExpr_(op));
+
+ auto* gvar = op->op.as<GlobalVarNode>();
+ if (!gvar) return std::move(node);
+
+ auto it = device_info_map_.find(gvar);
+ ICHECK(it != device_info_map_.end())
+ << "CallNode attempted subroutine call to " << gvar->name_hint << ",
but "
+ << gvar->name_hint << " did not appear within the IRModule";
+ const KernelInfo& dev_info = it->second;
+
+ auto caller_device_type = current_target_.value()->GetTargetDeviceType();
+ auto callee_device_type = dev_info.target->GetTargetDeviceType();
+ if (caller_device_type == callee_device_type) {
+ return std::move(node);
+ }
+
+ ICHECK(dev_info.launch_params.defined())
+ << "CallNode attempted kernel launch to " << gvar->name_hint << " on
target "
+ << dev_info.target << ", but subroutine " << gvar->name_hint
+ << " did not have the tir::attr::kKernelLaunchParams attribute "
+ << "required for cross-target kernel launch";
+
+ // Collected kernel information may be in terms of the callee's
+ // arguments, but we need expressions for them in terms of the
+ // caller's parameters. The param_map allows substitution of
+ // parameter values into the thread extents, to generate
+ // expressions that are valid within the caller.
+ Map<Var, PrimExpr> param_map = [&]() {
+ Map<Var, PrimExpr> param_map;
+ CHECK_EQ(node->args.size(), dev_info.params.size())
+ << "Function " << gvar->name_hint << " accepts " <<
dev_info.params.size()
+ << " arguments as input, but is called using " << node->args.size()
<< " arguments";
+ for (size_t i = 0; i < node->args.size(); i++) {
+ param_map.Set(dev_info.params[i], node->args[i]);
+ }
+ return param_map;
+ }();
+
+ device_kernel_launch_.insert(gvar);
+
+ Array<PrimExpr> call_args;
+ call_args.push_back(StringImm(dev_info.global_symbol));
+ for (PrimExpr arg : node->args) {
+ call_args.push_back(arg);
+ }
+ for (const auto& launch_arg : dev_info.launch_args) {
+ call_args.push_back(Substitute(launch_arg, param_map));
+ }
+
+ auto dtype = node->dtype.is_void() ? DataType::Int(32) : node->dtype;
+
+ return Call(dtype, builtin::tvm_call_packed(), call_args);
+ }
+
+ Optional<Target> current_target_;
+ std::unordered_map<const GlobalVarNode*, KernelInfo> device_info_map_;
+ std::unordered_set<const GlobalVarNode*> device_kernel_launch_;
+};
+
+namespace transform {
+
+Pass LowerDeviceKernelLaunch() {
+ auto pass_func = [](IRModule mod, PassContext ctx) -> IRModule {
+ auto mutator = [&mod]() {
+ std::unordered_map<const GlobalVarNode*, KernelInfo> device_info_map;
+ for (const auto& [gvar, base_func] : mod->functions) {
+ if (auto prim_func = base_func.as<PrimFunc>()) {
+ device_info_map[gvar.get()] = DeviceInfoCollector::Collect(gvar,
prim_func.value());
+ }
+ }
+ return DeviceKernelMutator(std::move(device_info_map));
+ }();
+
+ {
+ IRModule updates;
+ for (const auto& [gvar, base_func] : mod->functions) {
+ if (auto* ptr = base_func.as<PrimFuncNode>()) {
+ auto prim_func = mutator.RewriteKernelLaunchSite(gvar,
GetRef<PrimFunc>(ptr));
+ if (!prim_func.same_as(base_func)) {
+ updates->Add(gvar, prim_func);
+ }
+ }
+ }
+
+ if (updates->functions.size()) {
+ mod.CopyOnWrite()->Update(updates);
+ }
+ }
+
+ {
+ IRModule updates;
+ for (const auto& [gvar, base_func] : mod->functions) {
+ if (auto* ptr = base_func.as<PrimFuncNode>()) {
+ auto prim_func = mutator.UpdateKernelAttributes(gvar,
GetRef<PrimFunc>(ptr));
+ if (!prim_func.same_as(base_func)) {
+ updates->Add(gvar, prim_func);
+ }
+ }
+ }
+
+ if (updates->functions.size()) {
+ mod.CopyOnWrite()->Update(updates);
+ }
+ }
+
+ return mod;
+ };
+
+ return tvm::transform::CreateModulePass(pass_func, 0,
"tir.LowerDeviceKernelLaunch", {});
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.LowerDeviceKernelLaunch")
+ .set_body_typed(LowerDeviceKernelLaunch);
+
+} // namespace transform
+} // namespace tir
+} // namespace tvm
diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc
index 4f47b8ce2b..9270b356ba 100644
--- a/src/tir/transforms/split_host_device.cc
+++ b/src/tir/transforms/split_host_device.cc
@@ -41,246 +41,102 @@
namespace tvm {
namespace tir {
-/*!
- * \brief Visitor class to collect device-side program information.
- */
-class DeviceInfoCollector : public StmtVisitor {
- public:
- Array<IterVar> thread_axis_;
- Array<PrimExpr> thread_extent_;
- PrimExpr dyn_shmem_size_{0};
- bool use_dyn_shmem_{false};
-
- Array<String> GetLaunchParams() const {
- Array<String> output;
- for (const auto& axis : thread_axis_) {
- output.push_back(axis->thread_tag);
- }
- if (use_dyn_shmem_) {
- output.push_back(runtime::launch_param::kUseDynamicSharedMemoryTag);
- }
- return output;
- }
-
- private:
- void VisitStmt_(const AttrStmtNode* op) final {
- if (op->attr_key == attr::thread_extent) {
- IterVar iv = Downcast<IterVar>(op->node);
- ICHECK_NE(iv->thread_tag.length(), 0U);
- // thread_extent can appear multiple times
- // use the first appearance as def.
- if (!defined_thread.count(iv.get())) {
- defined_thread.insert(iv.get());
- thread_axis_.push_back(iv);
- thread_extent_.push_back(op->value);
- }
-
- this->VisitExpr(op->value);
- this->VisitStmt(op->body);
- } else {
- StmtVisitor::VisitStmt_(op);
- }
- }
-
- void VisitStmt_(const AllocateNode* op) final {
- auto storage_scope =
runtime::StorageScope::Create(GetPtrStorageScope(op->buffer_var));
- if (storage_scope.rank == runtime::StorageRank::kShared &&
storage_scope.tag == ".dyn") {
- ICHECK_EQ(use_dyn_shmem_, false) << "Only one dynamic shared memory
allocation is allowed.";
- ICHECK_GT(op->extents.size(), 0);
- dyn_shmem_size_ = op->extents[0];
- for (size_t i = 1; i < op->extents.size(); ++i) {
- dyn_shmem_size_ *= op->extents[i];
- }
- dyn_shmem_size_ = dyn_shmem_size_ * (op->dtype.bytes());
- use_dyn_shmem_ = true;
- }
- StmtVisitor::VisitStmt_(op);
- }
-
- // recording what thread axis have been visited.
- std::unordered_set<const IterVarNode*> defined_thread;
-};
-
-/*!
- * \brief Mutator class to remove unrefenced let stmt/expressions.
- * \param use_count The pre-computed variable to use count map.
- */
-class UnreferencedLetRemover : public StmtExprMutator {
- public:
- explicit UnreferencedLetRemover(const std::unordered_map<const VarNode*,
int>& use_count)
- : use_count_(use_count) {}
-
- private:
- Stmt VisitStmt_(const LetStmtNode* op) final {
- Stmt body = this->VisitStmt(op->body);
- // eliminate unreferenced let
- if (use_count_.at(op->var.get()) == 0 && SideEffect(op->value) <=
CallEffectKind::kReadState) {
- return body;
- } else {
- PrimExpr value = this->VisitExpr(op->value);
- if (body.same_as(op->body) && value.same_as(op->value)) {
- return GetRef<Stmt>(op);
- } else {
- return LetStmt(op->var, value, body);
- }
- }
- }
-
- PrimExpr VisitExpr_(const LetNode* op) final {
- PrimExpr body = this->VisitExpr(op->body);
- PrimExpr value = this->VisitExpr(op->value);
- if (use_count_.at(op->var.get()) == 0 && SideEffect(op->value) <=
CallEffectKind::kReadState) {
- return body;
- } else {
- if (body.same_as(op->body) && value.same_as(op->value)) {
- return GetRef<PrimExpr>(op);
- } else {
- return Let(op->var, value, body);
- }
- }
- }
-
- // pre-computed variable to use count map.
- const std::unordered_map<const VarNode*, int>& use_count_;
-};
-
class HostDeviceSplitter : public StmtMutator {
public:
- explicit HostDeviceSplitter(IRModule* device_mod, Target device_target,
std::string name_prefix)
- : device_mod_(device_mod), device_target_(device_target),
name_prefix_(name_prefix) {}
-
- Stmt VisitStmt_(const AllocateNode* op) final {
- handle_data_type_[op->buffer_var.get()] = make_const(op->dtype, 0);
- return StmtMutator::VisitStmt_(op);
- }
+ explicit HostDeviceSplitter(IRModule* device_mod, std::string name_prefix)
+ : device_mod_(device_mod), name_prefix_(name_prefix) {}
Stmt VisitStmt_(const AttrStmtNode* op) final {
- if (op->attr_key == attr::thread_extent || op->attr_key ==
attr::pipeline_exec_scope ||
- op->attr_key == attr::device_scope) {
- return SplitDeviceFunc(GetRef<Stmt>(op));
+ if (op->attr_key == tvm::attr::kTarget) {
+ auto device_target = op->node.as<Target>().value().WithoutHost();
+ return SplitDeviceFunc(op->body, device_target);
}
return StmtMutator::VisitStmt_(op);
}
private:
- Stmt SplitDeviceFunc(Stmt body) {
- std::ostringstream os;
- os << name_prefix_ << "_kernel" << device_func_counter_++;
- std::string kernel_symbol = os.str();
- // isolate the device function.
- VarUseDefAnalyzer use_def(/*defined_vars=*/{},
/*visit_thread_extent=*/false);
- use_def(body);
- DeviceInfoCollector dev_info;
- dev_info(body);
- UnreferencedLetRemover let_remover(use_def.use_count_);
- body = let_remover(std::move(body));
-
- Array<Var> params;
- Array<PrimExpr> arguments;
- Map<tir::Var, tir::Var> remap_vars;
-
- // Strictly order the arguments: Var pointers, positional arguments.
- for (Var var : use_def.undefined_) {
- if (var.dtype().is_handle()) {
- // Create a new version of v.
- auto it = handle_data_type_.find(var.get());
- if (it != handle_data_type_.end()) {
- String storage_scope;
- if (auto* ptr_type = var->type_annotation.as<PointerTypeNode>()) {
- storage_scope = ptr_type->storage_scope;
- }
- tir::Var new_var(var->name_hint,
- PointerType(PrimType((*it).second->dtype),
storage_scope));
- params.push_back(new_var);
- remap_vars.Set(var, new_var);
- } else {
- params.push_back(var);
- }
- arguments.push_back(var);
- }
- }
- // positional arguments
- for (Var var : use_def.undefined_) {
- if (!var.dtype().is_handle()) {
- params.push_back(var);
- arguments.push_back(var);
- }
- }
- GlobalVarSupply global_var_supply = GlobalVarSupply(*device_mod_);
- GlobalVar kernel_symbol_global =
global_var_supply->FreshGlobal(kernel_symbol, false);
-
- PrimFunc device_func(params, Substitute(body, remap_vars));
- device_func = WithAttr(std::move(device_func),
tir::attr::kKernelLaunchParams,
- dev_info.GetLaunchParams());
-
- device_func = WithAttr(std::move(device_func), tvm::attr::kCallingConv,
- Integer(CallingConv::kDeviceKernelLaunch));
- device_func = WithAttr(std::move(device_func), tvm::attr::kGlobalSymbol,
- runtime::String(kernel_symbol_global->name_hint));
- device_func = WithAttr(std::move(device_func), tir::attr::kNoAlias,
Integer(1));
- device_func = WithAttr(std::move(device_func), tvm::attr::kTarget,
device_target_);
- device_func = WithAttr(std::move(device_func), tir::attr::kIsGlobalFunc,
Integer(1));
+ Stmt SplitDeviceFunc(Stmt body, Target device_target) {
+ Array<Var> params = [&]() {
+ VarUseDefAnalyzer use_def(/*defined_vars=*/{},
/*visit_thread_extent=*/false);
+ use_def(body);
+
+ // Sort first by variable type, then by variable name
+ std::vector<Var> params{use_def.undefined_.begin(),
use_def.undefined_.end()};
+ std::sort(params.begin(), params.end(), [](const Var& a, const Var& b) {
+ auto sort_key = [](const Var& var) {
+ return std::tuple{
+ !var->dtype.is_handle(),
+ var->name_hint,
+ };
+ };
+ return sort_key(a) < sort_key(b);
+ });
+ return params;
+ }();
+
+ GlobalVar kernel_symbol_global = [&]() {
+ std::stringstream name;
+ name << name_prefix_ << "_kernel";
+ GlobalVarSupply global_var_supply = GlobalVarSupply(*device_mod_);
+ return global_var_supply->FreshGlobal(name.str(), false);
+ }();
+
+ PrimFunc device_func(params, body);
+ device_func = WithAttrs(std::move(device_func), {{tvm::attr::kTarget,
device_target},
+ {tir::attr::kNoAlias,
Bool(true)},
+
{tir::attr::kIsGlobalFunc, Bool(true)}});
(*device_mod_)->Add(kernel_symbol_global, device_func);
+ Array<PrimExpr> args = params.Map([](const Var& var) -> PrimExpr { return
var; });
- // generate calls to the device function
- Array<PrimExpr> call_args;
- call_args.push_back(StringImm(kernel_symbol_global->name_hint));
- for (PrimExpr arg : arguments) {
- call_args.push_back(arg);
- }
- for (PrimExpr ext : dev_info.thread_extent_) {
- call_args.push_back(ext);
- }
- if (dev_info.use_dyn_shmem_) {
- call_args.push_back(dev_info.dyn_shmem_size_);
- }
- return Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(),
call_args));
+ return Evaluate(Call(DataType::Void(), kernel_symbol_global, args));
}
// target ir module
IRModule* device_mod_;
- // Device target
- Target device_target_;
// function name hint
std::string name_prefix_;
- // Number of device functions.
- int device_func_counter_{0};
- std::unordered_map<const VarNode*, PrimExpr> handle_data_type_;
};
-PrimFunc SplitHostDevice(PrimFunc&& func, IRModule* device_mod) {
- auto target = func->GetAttr<Target>(tvm::attr::kTarget);
- ICHECK(target.defined()) << "SplitHostDevice: Require the target attribute";
+PrimFunc SplitHostDevice(PrimFunc func, IRModule* device_mod, const GlobalVar& gvar) {
+ auto opt_target = func->GetAttr<Target>(tvm::attr::kTarget);
+ ICHECK(opt_target) << "SplitHostDevice: Require the target attribute";
+ Target target = opt_target.value();
+
auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
-  ICHECK(global_symbol.defined())
-      << "SplitHostDevice: Expect PrimFunc to have the global_symbol attribute";
+ auto name_prefix = global_symbol.value_or(gvar->name_hint);
+
+ HostDeviceSplitter splitter(device_mod, name_prefix);
- HostDeviceSplitter splitter(device_mod, target.value(),
- static_cast<std::string>(global_symbol.value()));
+ auto body = splitter(func->body);
+
+ if (!body.same_as(func->body)) {
+ func.CopyOnWrite()->body = body;
+ auto target_host = target->GetHost().value_or(Target("llvm"));
+ func = WithAttr(std::move(func), tvm::attr::kTarget, target_host);
+ }
- auto* n = func.CopyOnWrite();
- n->body = splitter(std::move(n->body));
- // set the host target to None.
- func = WithAttr(std::move(func), tvm::attr::kTarget, Target(nullptr));
- return std::move(func);
+ return func;
}
namespace transform {
Pass SplitHostDevice() {
auto pass_func = [](IRModule mod, PassContext ctx) {
- IRModuleNode* mod_ptr = mod.CopyOnWrite();
- auto* func_dict = mod_ptr->functions.CopyOnWrite();
IRModule device_mod = IRModule(Map<GlobalVar, BaseFunc>({}));
-
- for (auto& kv : *func_dict) {
- if (kv.second->IsInstance<PrimFuncNode>()) {
- PrimFunc func = Downcast<PrimFunc>(std::move(kv.second));
- ICHECK(device_mod.defined()) << "The device module must be defined.";
- kv.second = SplitHostDevice(std::move(func), &device_mod);
+ IRModule updates = IRModule(Map<GlobalVar, BaseFunc>({}));
+
+ for (const auto& [gvar, base_func] : mod->functions) {
+ if (auto opt = base_func.as<PrimFunc>()) {
+ PrimFunc func = opt.value();
+ func = SplitHostDevice(std::move(func), &device_mod, gvar);
+ if (!func.same_as(base_func)) {
+ updates->Add(gvar, func);
+ }
}
}
+
+ mod->Update(updates);
mod->Update(device_mod);
return ConvertSSA()(mod);
};
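As an aside on the `SplitDeviceFunc` change above: the extracted kernel's parameter list is ordered by sorting the undefined variables, placing handle (pointer) parameters first and alphabetizing within each group. A hypothetical stand-alone Python analogue of that sort key (the `Var` class here is illustrative, not TVM's API):

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Var:
    """Illustrative stand-in for a TIR variable: a name and a dtype string."""

    name: str
    dtype: str  # e.g. "handle", "int32", "float32"


def sort_kernel_params(params):
    # Mirrors the C++ sort key {!var->dtype.is_handle(), var->name_hint}:
    # handle-typed variables sort first (False < True), then by name.
    return sorted(params, key=lambda v: (v.dtype != "handle", v.name))


params = [Var("n", "int32"), Var("A_data", "handle"), Var("i", "int32")]
print([v.name for v in sort_kernel_params(params)])  # → ['A_data', 'i', 'n']
```

This gives the kernel a deterministic signature regardless of the order in which the variables appear in the extracted body.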
diff --git a/tests/python/unittest/test_tir_transform_annotate_device_regions.py b/tests/python/unittest/test_tir_transform_annotate_device_regions.py
new file mode 100644
index 0000000000..efa43027e9
--- /dev/null
+++ b/tests/python/unittest/test_tir_transform_annotate_device_regions.py
@@ -0,0 +1,58 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+import tvm.testing
+from tvm.script import tir as T, ir as I
+
+
+class BaseCompare(tvm.testing.CompareBeforeAfter):
+ transform = tvm.tir.transform.AnnotateDeviceRegions()
+
+
+class TestAnnotateThreadExtent(BaseCompare):
+ """Annotation inserted at the "thread_extent" attribute"""
+
+ def before(A: T.Buffer(16, "float32")):
+ T.func_attr({"target": T.target("cuda", host="llvm")})
+ i = T.launch_thread("threadIdx.x", 16)
+ A[i] = 0.0
+
+ def expected(A: T.Buffer(16, "float32")):
+ T.func_attr({"target": T.target("cuda", host="llvm")})
+ T.attr(T.target("cuda"), "target", 0)
+ i = T.launch_thread("threadIdx.x", 16)
+ A[i] = 0.0
+
+
+class TestAnnotateDeviceScope(BaseCompare):
+ """Annotation inserted at the "device_scope" attribute"""
+
+ def before(A: T.Buffer(1, "float32")):
+ T.func_attr({"target": T.target("cuda", host="llvm")})
+ T.attr(0, "device_scope", 0)
+ A[0] = 0.0
+
+ def expected(A: T.Buffer(1, "float32")):
+ T.func_attr({"target": T.target("cuda", host="llvm")})
+ T.attr(T.target("cuda"), "target", 0)
+ T.attr(0, "device_scope", 0)
+ A[0] = 0.0
+
+
+if __name__ == "__main__":
+ tvm.testing.main()
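The kernel names checked in the tests below (e.g. `main_kernel`) come from `SplitHostDevice` requesting a fresh `GlobalVar` from a `GlobalVarSupply` seeded with the device module. A minimal Python sketch of that fresh-name behavior, assuming only that collisions are resolved with a numeric suffix (the function name here is hypothetical):

```python
def fresh_global(existing_names, name_hint):
    """Return name_hint if unused, otherwise append the first free suffix.

    existing_names is mutated to record the chosen name, mimicking how a
    GlobalVarSupply tracks the symbols already present in the module.
    """
    if name_hint not in existing_names:
        existing_names.add(name_hint)
        return name_hint
    i = 1
    while f"{name_hint}_{i}" in existing_names:
        i += 1
    name = f"{name_hint}_{i}"
    existing_names.add(name)
    return name


names = set()
print(fresh_global(names, "main_kernel"))  # first use keeps the plain name
print(fresh_global(names, "main_kernel"))  # a collision gets a suffix
```

This is also why the updated tests look up `"main_kernel"` rather than the old counter-based `"main_kernel0"` symbol.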
diff --git a/tests/python/unittest/test_tir_transform_device_kernel_launch.py b/tests/python/unittest/test_tir_transform_device_kernel_launch.py
new file mode 100644
index 0000000000..a0f77da376
--- /dev/null
+++ b/tests/python/unittest/test_tir_transform_device_kernel_launch.py
@@ -0,0 +1,193 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+import tvm.testing
+from tvm.script import tir as T, ir as I
+
+
+class BaseCompare(tvm.testing.CompareBeforeAfter):
+ transform = tvm.tir.transform.LowerDeviceKernelLaunch()
+
+
+class TestLowerDeviceKernelLaunch(BaseCompare):
+ """Kernel launch parameters are added at the call site
+
+    The "tir.kernel_launch_params" attribute determines which
+    parameters belong to the runtime, and which belong to the
+    device-side PrimFunc.  Parameters that are required prior to
+    launching a kernel (e.g. the number of CUDA threads to use) are
+    stored in the `"tir.kernel_launch_params"` attribute, and are
+    used by the runtime in order to launch the generated kernel.
+ """
+
+ def before(self):
+ @I.ir_module
+ class mod:
+ @T.prim_func
+ def main(A: T.Buffer(1, "float32")):
+ T.func_attr({"target": T.target("llvm")})
+ mod.kernel(A.data)
+
+ @T.prim_func
+ def kernel(A_data: T.handle("float32")):
+ T.func_attr({"target": T.target("cuda")})
+ A = T.decl_buffer(1, dtype="float32", data=A_data)
+ A[0] = 0.0
+
+ return mod
+
+ def expected(self):
+ @I.ir_module
+ class mod:
+ @T.prim_func
+ def main(A: T.Buffer(1, "float32")):
+ T.func_attr({"target": T.target("llvm")})
+ T.call_packed("kernel", A.data)
+
+ @T.prim_func
+ def kernel(A_data: T.handle("float32")):
+ T.func_attr(
+ {
+ "target": T.target("cuda"),
+ "calling_conv": 2,
+ "tir.kernel_launch_params": [],
+ "global_symbol": "kernel",
+ "tir.is_global_func": True,
+ }
+ )
+ A = T.decl_buffer(1, dtype="float32", data=A_data)
+ A[0] = 0.0
+
+ return mod
+
+
+class TestExternallyVisibleKernelLaunch(BaseCompare):
+ """Like TestLowerDeviceKernelLaunch, with pre-defined global_symbol
+
+ Because the host and kernel will be handled by different code
+ generators, the device-side kernel must be externally exposed for
+ use by the host-side wrapper, even if the host-side wrapper does
+ not directly expose the kernel. Therefore, a "global_symbol"
+ attribute must be added for the kernel if not already present.
+
+ If the kernel already has a specific name, that name should be
+ preserved.
+ """
+
+ def before(self):
+ @I.ir_module
+ class mod:
+ @T.prim_func
+ def main(A: T.Buffer(1, "float32")):
+ T.func_attr({"target": T.target("llvm")})
+ mod.kernel(A.data)
+
+ @T.prim_func
+ def kernel(A_data: T.handle("float32")):
+                T.func_attr({"target": T.target("cuda"), "global_symbol": "kernel_by_another_name"})
+ A = T.decl_buffer(1, dtype="float32", data=A_data)
+ A[0] = 0.0
+
+ return mod
+
+ def expected(self):
+ @I.ir_module
+ class mod:
+ @T.prim_func
+ def main(A: T.Buffer(1, "float32")):
+ T.func_attr({"target": T.target("llvm")})
+ T.call_packed("kernel_by_another_name", A.data)
+
+ @T.prim_func
+ def kernel(A_data: T.handle("float32")):
+ T.func_attr(
+ {
+ "target": T.target("cuda"),
+ "calling_conv": 2,
+ "tir.kernel_launch_params": [],
+ "global_symbol": "kernel_by_another_name",
+ "tir.is_global_func": True,
+ }
+ )
+ A = T.decl_buffer(1, dtype="float32", data=A_data)
+ A[0] = 0.0
+
+ return mod
+
+
+class TestCollectLaunchParameter(BaseCompare):
+ """Kernel launch parameters are added at the call site
+
+    The "tir.kernel_launch_params" attribute determines which
+    parameters belong to the runtime, and which belong to the
+    device-side PrimFunc.  Parameters that are required prior to
+    launching a kernel (e.g. the number of CUDA threads to use) are
+    stored in the `"tir.kernel_launch_params"` attribute, and are
+    used by the runtime in order to launch the generated kernel.
+ """
+
+ def before(self):
+ @I.ir_module
+ class mod:
+ @T.prim_func
+ def main(A: T.Buffer(16, "float32")):
+ T.func_attr({"target": T.target("llvm")})
+ mod.kernel(A.data)
+
+ @T.prim_func
+ def kernel(A_data: T.handle("float32")):
+ T.func_attr(
+ {
+ "target": T.target("cuda"),
+ "global_symbol": "kernel",
+ }
+ )
+ A = T.decl_buffer(16, dtype="float32", data=A_data)
+ i = T.launch_thread("threadIdx.x", 16)
+ A[i] = 0.0
+
+ return mod
+
+ def expected(self):
+ @I.ir_module
+ class mod:
+ @T.prim_func
+ def main(A: T.Buffer(16, "float32")):
+ T.func_attr({"target": T.target("llvm")})
+ T.call_packed("kernel", A.data, 16)
+
+ @T.prim_func
+ def kernel(A_data: T.handle("float32")):
+ T.func_attr(
+ {
+ "target": T.target("cuda"),
+ "calling_conv": 2,
+ "tir.kernel_launch_params": ["threadIdx.x"],
+ "global_symbol": "kernel",
+ "tir.is_global_func": True,
+ }
+ )
+ A = T.decl_buffer(16, dtype="float32", data=A_data)
+ i = T.launch_thread("threadIdx.x", 16)
+ A[i] = 0.0
+
+ return mod
+
+
+if __name__ == "__main__":
+ tvm.testing.main()
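The `T.call_packed(...)` expected in the tests above illustrates the argument order the lowered launch uses: the kernel symbol, then the kernel arguments, then the launch parameters (thread extents, and the dynamic shared-memory size when used). A hypothetical stand-alone sketch of that packing (not TVM's API; function and parameter names are illustrative):

```python
def make_packed_call_args(symbol, args, thread_extents, dyn_shmem_size=None):
    # Order mirrors the lowering: symbol first, then kernel arguments,
    # then per-launch parameters such as thread extents and the
    # optional dynamic shared-memory size.
    call_args = [symbol]
    call_args += list(args)
    call_args += list(thread_extents)
    if dyn_shmem_size is not None:
        call_args.append(dyn_shmem_size)
    return call_args


# Matches TestCollectLaunchParameter: T.call_packed("kernel", A.data, 16)
print(make_packed_call_args("kernel", ["A.data"], [16]))  # → ['kernel', 'A.data', 16]
```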
diff --git a/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py b/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py
index 5db33a1e05..1e1ef410b4 100644
--- a/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py
+++ b/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py
@@ -201,7 +201,7 @@ expected_cuda_script = r"""
#define int64_t long long
#define uint64_t unsigned long long
#endif
-extern "C" __global__ void __launch_bounds__(16) main_kernel0(float* __restrict__ A, float* __restrict__ B, float* __restrict__ C) {
+extern "C" __global__ void __launch_bounds__(16) main_kernel(float* __restrict__ A, float* __restrict__ B, float* __restrict__ C) {
__shared__ float A_shared[64];
__shared__ float B_shared[64];
A_shared[((int)threadIdx.x)] = 0.000000e+00f;
diff --git a/tests/python/unittest/test_tir_transform_lower_warp_memory.py b/tests/python/unittest/test_tir_transform_lower_warp_memory.py
index d4abc26bb2..c7e90d4e7d 100644
--- a/tests/python/unittest/test_tir_transform_lower_warp_memory.py
+++ b/tests/python/unittest/test_tir_transform_lower_warp_memory.py
@@ -22,6 +22,16 @@ from tvm import te
from tvm.contrib.nvcc import have_fp16
+def _run_passes(mod):
+ cuda_target = tvm.target.Target("cuda", host="llvm")
+ assert cuda_target.thread_warp_size == 32
+    mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", cuda_target))(mod)
+ mod = tvm.tir.transform.AnnotateDeviceRegions()(mod)
+ mod = tvm.tir.transform.SplitHostDevice()(mod)
+ mod = tvm.tir.transform.LowerWarpMemory()(mod)
+ return mod
+
+
@tvm.testing.requires_cuda
def test_lower_warp_memory_local_scope():
m = 128
@@ -39,16 +49,12 @@ def test_lower_warp_memory_local_scope():
xo, xi = s[AA].split(s[AA].op.axis[0], 32)
s[AA].bind(xi, tx)
- cuda_target = tvm.target.Target("cuda")
- assert cuda_target.thread_warp_size == 32
     # lowering with the CSE pass disabled as otherwise it would do some commoning
     with tvm.transform.PassContext(opt_level=3, disabled_pass=["tir.CommonSubexprElimTIR"]):
         mod = tvm.lower(s, [A, B], name="f")
-        mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", cuda_target))(mod)
- fdevice = tvm.tir.transform.SplitHostDevice()(mod)["f_kernel0"]
- mod = tvm.IRModule.from_expr(fdevice)
- fdevice = tvm.tir.transform.LowerWarpMemory()(mod)["f_kernel0"]
+ mod = _run_passes(mod)
+ fdevice = mod["f_kernel"]
allocate = fdevice.body.body
assert allocate.buffer_var.type_annotation.storage_scope == "local"
assert fdevice.body.body.extents[0].value == 2
@@ -103,7 +109,7 @@ def test_lower_warp_memory_cuda_end_to_end():
A = te.placeholder((m,), name="A", dtype=dtype)
     B = te.compute((m,), lambda i: A[i // 32 * 32 + (i + 1) % 32], name="B")
- cuda_target = tvm.target.Target("cuda")
+ cuda_target = tvm.target.Target("cuda", host="llvm")
assert cuda_target.thread_warp_size == 32
with cuda_target:
s = te.create_schedule(B.op)
@@ -168,7 +174,7 @@ def test_lower_warp_memory_cuda_half_a_warp():
name="B",
)
- cuda_target = tvm.target.Target("cuda")
+ cuda_target = tvm.target.Target("cuda", host="llvm")
assert cuda_target.thread_warp_size == 2 * m
with cuda_target:
s = te.create_schedule(B.op)
@@ -214,7 +220,7 @@ def test_lower_warp_memory_cuda_2_buffers():
B = te.placeholder((m,), name="B", dtype=dtype)
     C = te.compute((m,), lambda i: A[(i + 1) % m] + B[(i + 1) % m], name="C")
- cuda_target = tvm.target.Target("cuda")
+ cuda_target = tvm.target.Target("cuda", host="llvm")
assert m <= cuda_target.thread_warp_size
with cuda_target:
s = te.create_schedule(C.op)
@@ -310,15 +316,12 @@ def test_lower_warp_memory_same_thread():
xo, xi = s[BB].split(s[BB].op.axis[0], factor=32)
s[BB].bind(xi, tx)
- cuda_target = tvm.target.Target("cuda")
- assert cuda_target.thread_warp_size == 32
     # lowering with the CSE pass disabled as otherwise it would do some commoning
     with tvm.transform.PassContext(opt_level=3, disabled_pass=["tir.CommonSubexprElimTIR"]):
         mod = tvm.lower(s, [A, B], name="f")
-        mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", cuda_target))(mod)
- fdevice = tvm.tir.transform.SplitHostDevice()(mod)["f_kernel0"]
- mod = tvm.IRModule.from_expr(fdevice)
- fdevice = tvm.tir.transform.LowerWarpMemory()(mod)["f_kernel0"]
+
+ mod = _run_passes(mod)
+ fdevice = mod["f_kernel"]
assert "tvm_warp_shuffle" not in fdevice.script()
@@ -338,13 +341,11 @@ def test_lower_warp_memory_divide_by_factor():
stmt = ib.get()
func = tvm.tir.PrimFunc([], stmt)
func = func.with_attr("from_legacy_te_schedule", True)
- cuda_target = tvm.target.Target("cuda")
     # lowering with the CSE pass disabled as otherwise it would do some commoning
     with tvm.transform.PassContext(opt_level=3, disabled_pass=["tir.CommonSubexprElimTIR"]):
         mod = tvm.lower(func, name="f")
-        mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", cuda_target))(mod)
with pytest.raises(tvm.error.TVMError, match="Divide by zero") as cm:
- tvm.tir.transform.LowerWarpMemory()(mod)["f_kernel0"]
+ _run_passes(mod)
if __name__ == "__main__":
diff --git a/tests/python/unittest/test_tir_transform_split_host_device.py b/tests/python/unittest/test_tir_transform_split_host_device.py
index 680f23e07a..cf866ae005 100644
--- a/tests/python/unittest/test_tir_transform_split_host_device.py
+++ b/tests/python/unittest/test_tir_transform_split_host_device.py
@@ -35,17 +35,26 @@ def test_split_host_device_func_attr():
s[A1].compute_at(s[A2], xo)
s[A1].set_scope("shared")
- mod = tvm.lower(s, [A, A2], name="f")
+ mod = tvm.lower(s, [A, A2])
- cuda_target = tvm.target.Target("cuda")
+ cuda_target = tvm.target.Target("cuda", host="llvm")
mod = tvm.tir.transform.Apply(
lambda f: f.with_attr({"global_symbol": "test", "target": cuda_target})
)(mod)
- fdevice = tvm.tir.transform.SplitHostDevice()(mod)["test_kernel0"]
- assert fdevice.attrs["global_symbol"] == "test_kernel0"
+ mod = tvm.ir.transform.Sequential(
+ [
+ tvm.tir.transform.AnnotateDeviceRegions(),
+ tvm.tir.transform.SplitHostDevice(),
+ tvm.tir.transform.LowerDeviceKernelLaunch(),
+ ]
+ )(mod)
+
+ fdevice = mod["test_kernel"]
+
+ assert fdevice.attrs["global_symbol"] == "test_kernel"
assert fdevice.attrs["calling_conv"].value == 2
- assert fdevice.attrs["target"] == cuda_target
+ assert str(fdevice.attrs["target"]) == str(tvm.target.Target("cuda"))
assert fdevice.attrs["tir.is_global_func"].value
@@ -60,18 +69,104 @@ def test_ssa_across_entire_module():
class before:
@T.prim_func
def main():
- T.func_attr({"global_symbol": "main", "target": T.target("cuda")})
+            T.func_attr({"global_symbol": "main", "target": T.target("cuda", host="llvm")})
for i in range(16):
T.attr(0, "device_scope", 0)
for j in range(16):
T.evaluate(i)
- after = tvm.tir.transform.SplitHostDevice()(before)
+ after = tvm.ir.transform.Sequential(
+ [
+ tvm.tir.transform.AnnotateDeviceRegions(),
+ tvm.tir.transform.SplitHostDevice(),
+ tvm.tir.transform.LowerDeviceKernelLaunch(),
+ ]
+ )(before)
loop_var = after["main"].body.loop_var
- param_var = after["main_kernel0"].params[0]
+ param_var = after["main_kernel"].params[0]
assert not loop_var.same_as(param_var)
+class BaseCompare(tvm.testing.CompareBeforeAfter):
+ transform = tvm.tir.transform.SplitHostDevice()
+
+
+class TestSplitHostDevice(BaseCompare):
+ """SplitHostDevice divides a function at the "target" attribute"""
+
+ def before(self):
+ @I.ir_module
+ class mod:
+ @T.prim_func
+ def main(n: T.int32):
+                T.func_attr({"target": T.target("cuda", host="llvm -opt-level=0")})
+ T.attr(T.target("cuda"), "target", 0)
+ T.evaluate(n)
+
+ return mod
+
+ def expected(self):
+ @I.ir_module
+ class mod:
+ @T.prim_func
+ def main(n: T.int32):
+ T.func_attr({"target": T.target("llvm -opt-level=0")})
+ mod.main_kernel(n)
+
+ @T.prim_func
+ def main_kernel(n: T.int32):
+ T.func_attr(
+ {
+ "target": T.target("cuda"),
+ "tir.noalias": T.bool(True),
+ "tir.is_global_func": True,
+ }
+ )
+ T.evaluate(n)
+
+ return mod
+
+
+class TestSplitHostDeviceWithoutFuncHostAttribute(BaseCompare):
+ """Like TestSplitHostDevice, but no host specified in the host's target
+
+ The `T.attr` specifying the device still requires splitting out
+ the kernel.
+ """
+
+ def before(self):
+ @I.ir_module
+ class mod:
+ @T.prim_func
+ def main(n: T.int32):
+ T.func_attr({"target": T.target("llvm")})
+ T.attr(T.target("cuda"), "target", 0)
+ T.evaluate(n)
+
+ return mod
+
+ def expected(self):
+ @I.ir_module
+ class mod:
+ @T.prim_func
+ def main(n: T.int32):
+ T.func_attr({"target": T.target("llvm")})
+ mod.main_kernel(n)
+
+ @T.prim_func
+ def main_kernel(n: T.int32):
+ T.func_attr(
+ {
+ "target": T.target("cuda"),
+ "tir.noalias": T.bool(True),
+ "tir.is_global_func": True,
+ }
+ )
+ T.evaluate(n)
+
+ return mod
+
+
if __name__ == "__main__":
- test_split_host_device_func_attr()
+ tvm.testing.main()
diff --git a/tests/python/unittest/test_tir_transform_thread_sync.py b/tests/python/unittest/test_tir_transform_thread_sync.py
index eb578a8817..57ea223cf9 100644
--- a/tests/python/unittest/test_tir_transform_thread_sync.py
+++ b/tests/python/unittest/test_tir_transform_thread_sync.py
@@ -24,12 +24,13 @@ def run_passes(func: tvm.tir.PrimFunc):
mod = tvm.IRModule.from_expr(func)
mod = tvm.tir.transform.StorageFlatten(64)(mod)
- cuda_target = tvm.target.Target("cuda")
+ cuda_target = tvm.target.Target("cuda", host="llvm")
mod = tvm.tir.transform.Apply(
lambda f: f.with_attr({"global_symbol": "test", "target": cuda_target})
)(mod)
+ mod = tvm.tir.transform.AnnotateDeviceRegions()(mod)
mod = tvm.tir.transform.SplitHostDevice()(mod)
return tvm.tir.transform.ThreadSync("shared")(mod)
@@ -55,7 +56,7 @@ def test_thread_storage_sync():
func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, A2], stmt, None)
mod = run_passes(func)
- f = mod["test_kernel0"]
+ f = mod["test_kernel"]
body_list = tvm.tir.stmt_list(f.body.body.body)
assert body_list[1].value.op.same_as(tvm.ir.Op.get("tir.tvm_storage_sync"))