This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 9877db5ff3 [TIR] Handle callees on same target, different codegen (#14988)
9877db5ff3 is described below

commit 9877db5ff3cd88f4320919eb0f07a3d02c2f4f6f
Author: Eric Lunderberg <[email protected]>
AuthorDate: Sun Jun 4 03:46:56 2023 -0500

    [TIR] Handle callees on same target, different codegen (#14988)
    
    * [TIR] Handle callees on same target, different codegen
    
    Prior to this commit, any call from a caller that uses a different
    `Target` than its callee was lowered to a device-kernel launch.
    However, if the caller and callee are on the same device despite
    using different targets (e.g. `Target("llvm")` and `Target("c")`
    both use `kDLCPU`), then the kernel launch is unnecessary.
    
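    This commit updates `LowerDeviceKernelLaunch` to produce a kernel
    launch only when the callee is on another device, and to produce
    `T.call_extern` for callees that use a different target on the same
    device.  Calls between functions that share the same target are
    left unchanged, to be handled at codegen time as internal
    subroutine calls.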
    
    * Rename "extern_method_call_" to "extern_function_call_"
---
 src/tir/transforms/lower_device_kernel_launch.cc   | 45 +++++++++++++++++---
 .../test_tir_transform_device_kernel_launch.py     | 49 ++++++++++++++++++++++
 2 files changed, 88 insertions(+), 6 deletions(-)

diff --git a/src/tir/transforms/lower_device_kernel_launch.cc 
b/src/tir/transforms/lower_device_kernel_launch.cc
index 5ffbf0d7a7..52f06ea45c 100644
--- a/src/tir/transforms/lower_device_kernel_launch.cc
+++ b/src/tir/transforms/lower_device_kernel_launch.cc
@@ -170,14 +170,27 @@ class DeviceKernelMutator : public StmtExprMutator {
   }
 
   PrimFunc UpdateKernelAttributes(const GlobalVar& gvar, PrimFunc func) const {
-    if (device_kernel_launch_.count(gvar.get())) {
+    bool is_kernel_launch = device_kernel_launch_.count(gvar.get());
+    bool is_call_extern = extern_function_call_.count(gvar.get());
+    CHECK(!is_kernel_launch || !is_call_extern)
+        << "Function " << gvar << " has multiple callees, "
+        << "and would need to be lowered into a call_extern at some call 
sites, "
+        << "and a device kernel launch at others.  "
+        << "This case is not yet supported.";
+
+    if (is_kernel_launch || is_call_extern) {
+      func = WithAttr(std::move(func), tvm::tir::attr::kIsGlobalFunc, 
Bool(true));
+    }
+
+    if (is_kernel_launch) {
       const auto& info = device_info_map_.at(gvar.get());
 
       func = WithAttrs(std::move(func),
                        {{tvm::attr::kCallingConv, 
Integer(tvm::CallingConv::kDeviceKernelLaunch)},
                         {tvm::tir::attr::kKernelLaunchParams, 
info.launch_params},
-                        {tvm::attr::kGlobalSymbol, info.global_symbol},
-                        {tvm::tir::attr::kIsGlobalFunc, Bool(true)}});
+                        {tvm::attr::kGlobalSymbol, info.global_symbol}});
+    } else if (is_call_extern && 
!func->GetAttr<String>(tvm::attr::kGlobalSymbol)) {
+      func = WithAttr(func, tvm::attr::kGlobalSymbol, gvar->name_hint);
     }
 
     return func;
@@ -196,12 +209,31 @@ class DeviceKernelMutator : public StmtExprMutator {
         << gvar->name_hint << " did not appear within the IRModule";
     const KernelInfo& dev_info = it->second;
 
-    auto caller_device_type = current_target_.value()->GetTargetDeviceType();
-    auto callee_device_type = dev_info.target->GetTargetDeviceType();
-    if (caller_device_type == callee_device_type) {
+    auto caller_target = current_target_.value();
+    auto callee_target = dev_info.target;
+
+    bool same_target = caller_target->str() == callee_target->str();
+    if (same_target) {
+      // Calls within the same target may be handled at codegen time
+      // as internal subroutine calls.
       return std::move(node);
     }
 
+    bool same_device_type =
+        caller_target->GetTargetDeviceType() == 
callee_target->GetTargetDeviceType();
+    if (same_device_type) {
+      // Calls to another target using the same device (e.g. LLVM
+      // calling a custom TIRToRuntime target) do not require a kernel
+      // launch, but need to be replaced with call_extern.
+      extern_function_call_.insert(gvar);
+      Array<PrimExpr> args;
+      args.push_back(StringImm(gvar->name_hint));
+      for (const auto& arg : node->args) {
+        args.push_back(arg);
+      }
+      return Call(node->dtype, builtin::call_extern(), args);
+    }
+
     ICHECK(dev_info.launch_params.defined())
         << "CallNode attempted kernel launch to " << gvar->name_hint << " on 
target "
         << dev_info.target << ", but subroutine " << gvar->name_hint
@@ -243,6 +275,7 @@ class DeviceKernelMutator : public StmtExprMutator {
   Optional<Target> current_target_;
   std::unordered_map<const GlobalVarNode*, KernelInfo> device_info_map_;
   std::unordered_set<const GlobalVarNode*> device_kernel_launch_;
+  std::unordered_set<const GlobalVarNode*> extern_function_call_;
 };
 
 namespace transform {
diff --git a/tests/python/unittest/test_tir_transform_device_kernel_launch.py 
b/tests/python/unittest/test_tir_transform_device_kernel_launch.py
index a0f77da376..34cde4e4b6 100644
--- a/tests/python/unittest/test_tir_transform_device_kernel_launch.py
+++ b/tests/python/unittest/test_tir_transform_device_kernel_launch.py
@@ -189,5 +189,54 @@ class TestCollectLaunchParameter(BaseCompare):
         return mod
 
 
+class TestSameDeviceDifferentTarget(BaseCompare):
+    """Handle subroutine calls to same device, different codegen
+
+    The device kernel launch is only required when the caller and
+    callee are on different devices.  However, if the caller and
+    callee use different codegen, then the call cannot be handled as
+    an internal call by a single codegen.  Instead, it should be
+    lowered to a `T.call_extern`.
+    """
+
+    def before(self):
+        @I.ir_module
+        class mod:
+            @T.prim_func
+            def main(A: T.Buffer(1, "float32")):
+                T.func_attr({"target": T.target("llvm")})
+                mod.kernel(A.data)
+
+            @T.prim_func
+            def kernel(A_data: T.handle("float32")):
+                T.func_attr({"target": T.target("c")})
+                A = T.decl_buffer(16, dtype="float32", data=A_data)
+                A[0] = 0.0
+
+        return mod
+
+    def expected(self):
+        @I.ir_module
+        class mod:
+            @T.prim_func
+            def main(A: T.Buffer(1, "float32")):
+                T.func_attr({"target": T.target("llvm")})
+                T.call_extern("kernel", A.data, dtype="void")
+
+            @T.prim_func
+            def kernel(A_data: T.handle("float32")):
+                T.func_attr(
+                    {
+                        "target": T.target("c"),
+                        "global_symbol": "kernel",
+                        "tir.is_global_func": True,
+                    }
+                )
+                A = T.decl_buffer(16, dtype="float32", data=A_data)
+                A[0] = 0.0
+
+        return mod
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to