This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 9877db5ff3 [TIR] Handle callees on same target, different codegen (#14988)
9877db5ff3 is described below
commit 9877db5ff3cd88f4320919eb0f07a3d02c2f4f6f
Author: Eric Lunderberg <[email protected]>
AuthorDate: Sun Jun 4 03:46:56 2023 -0500
[TIR] Handle callees on same target, different codegen (#14988)
* [TIR] Handle callees on same target, different codegen
Prior to this commit, any caller that uses a different `Target` than
its callee is lowered to a device-kernel launch. However, if the
caller and callee are on the same device, despite using a different
target (e.g. `Target("llvm")` and `Target("c")` both use `kDLCPU`),
then the kernel launch is unnecessary.
This commit updates `LowerDeviceKernelLaunch` to produce a kernel
launch only when the callee is on another device, and to produce
`T.call_extern` for callees on the same device.
* Rename "extern_method_call_" to "extern_function_call_"
---
src/tir/transforms/lower_device_kernel_launch.cc | 45 +++++++++++++++++---
.../test_tir_transform_device_kernel_launch.py | 49 ++++++++++++++++++++++
2 files changed, 88 insertions(+), 6 deletions(-)
diff --git a/src/tir/transforms/lower_device_kernel_launch.cc b/src/tir/transforms/lower_device_kernel_launch.cc
index 5ffbf0d7a7..52f06ea45c 100644
--- a/src/tir/transforms/lower_device_kernel_launch.cc
+++ b/src/tir/transforms/lower_device_kernel_launch.cc
@@ -170,14 +170,27 @@ class DeviceKernelMutator : public StmtExprMutator {
}
PrimFunc UpdateKernelAttributes(const GlobalVar& gvar, PrimFunc func) const {
- if (device_kernel_launch_.count(gvar.get())) {
+ bool is_kernel_launch = device_kernel_launch_.count(gvar.get());
+ bool is_call_extern = extern_function_call_.count(gvar.get());
+ CHECK(!is_kernel_launch || !is_call_extern)
+ << "Function " << gvar << " has multiple callees, "
+ << "and would need to be lowered into a call_extern at some call sites, "
+ << "and a device kernel launch at others. "
+ << "This case is not yet supported.";
+
+ if (is_kernel_launch || is_call_extern) {
+ func = WithAttr(std::move(func), tvm::tir::attr::kIsGlobalFunc, Bool(true));
+ }
+
+ if (is_kernel_launch) {
const auto& info = device_info_map_.at(gvar.get());
func = WithAttrs(std::move(func),
{{tvm::attr::kCallingConv,
Integer(tvm::CallingConv::kDeviceKernelLaunch)},
{tvm::tir::attr::kKernelLaunchParams,
info.launch_params},
- {tvm::attr::kGlobalSymbol, info.global_symbol},
- {tvm::tir::attr::kIsGlobalFunc, Bool(true)}});
+ {tvm::attr::kGlobalSymbol, info.global_symbol}});
+ } else if (is_call_extern && !func->GetAttr<String>(tvm::attr::kGlobalSymbol)) {
+ func = WithAttr(func, tvm::attr::kGlobalSymbol, gvar->name_hint);
}
return func;
@@ -196,12 +209,31 @@ class DeviceKernelMutator : public StmtExprMutator {
<< gvar->name_hint << " did not appear within the IRModule";
const KernelInfo& dev_info = it->second;
- auto caller_device_type = current_target_.value()->GetTargetDeviceType();
- auto callee_device_type = dev_info.target->GetTargetDeviceType();
- if (caller_device_type == callee_device_type) {
+ auto caller_target = current_target_.value();
+ auto callee_target = dev_info.target;
+
+ bool same_target = caller_target->str() == callee_target->str();
+ if (same_target) {
+ // Calls within the same target may be handled at codegen time
+ // as internal subroutine calls.
return std::move(node);
}
+ bool same_device_type = caller_target->GetTargetDeviceType() == callee_target->GetTargetDeviceType();
+ if (same_device_type) {
+ // Calls to another target using the same device (e.g. LLVM
+ // calling a custom TIRToRuntime target) do not require a kernel
+ // launch, but need to be replaced with call_extern.
+ extern_function_call_.insert(gvar);
+ Array<PrimExpr> args;
+ args.push_back(StringImm(gvar->name_hint));
+ for (const auto& arg : node->args) {
+ args.push_back(arg);
+ }
+ return Call(node->dtype, builtin::call_extern(), args);
+ }
+
ICHECK(dev_info.launch_params.defined())
<< "CallNode attempted kernel launch to " << gvar->name_hint << " on target "
<< dev_info.target << ", but subroutine " << gvar->name_hint
@@ -243,6 +275,7 @@ class DeviceKernelMutator : public StmtExprMutator {
Optional<Target> current_target_;
std::unordered_map<const GlobalVarNode*, KernelInfo> device_info_map_;
std::unordered_set<const GlobalVarNode*> device_kernel_launch_;
+ std::unordered_set<const GlobalVarNode*> extern_function_call_;
};
namespace transform {
diff --git a/tests/python/unittest/test_tir_transform_device_kernel_launch.py b/tests/python/unittest/test_tir_transform_device_kernel_launch.py
index a0f77da376..34cde4e4b6 100644
--- a/tests/python/unittest/test_tir_transform_device_kernel_launch.py
+++ b/tests/python/unittest/test_tir_transform_device_kernel_launch.py
@@ -189,5 +189,54 @@ class TestCollectLaunchParameter(BaseCompare):
return mod
+class TestSameDeviceDifferentTarget(BaseCompare):
+ """Handle subroutine calls to same device, different codegen
+
+ The device kernel launch is only required when the caller and
+ callee are on different devices. However, if the caller and
+ callee use different codegen, then the call cannot be handled as
+ an internal call by a single codegen. Instead, it should be
+ lowered to a `T.call_extern`.
+ """
+
+ def before(self):
+ @I.ir_module
+ class mod:
+ @T.prim_func
+ def main(A: T.Buffer(1, "float32")):
+ T.func_attr({"target": T.target("llvm")})
+ mod.kernel(A.data)
+
+ @T.prim_func
+ def kernel(A_data: T.handle("float32")):
+ T.func_attr({"target": T.target("c")})
+ A = T.decl_buffer(16, dtype="float32", data=A_data)
+ A[0] = 0.0
+
+ return mod
+
+ def expected(self):
+ @I.ir_module
+ class mod:
+ @T.prim_func
+ def main(A: T.Buffer(1, "float32")):
+ T.func_attr({"target": T.target("llvm")})
+ T.call_extern("kernel", A.data, dtype="void")
+
+ @T.prim_func
+ def kernel(A_data: T.handle("float32")):
+ T.func_attr(
+ {
+ "target": T.target("c"),
+ "global_symbol": "kernel",
+ "tir.is_global_func": True,
+ }
+ )
+ A = T.decl_buffer(16, dtype="float32", data=A_data)
+ A[0] = 0.0
+
+ return mod
+
+
if __name__ == "__main__":
tvm.testing.main()