This is an automated email from the ASF dual-hosted git repository.

yongwww pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 4617efac7b [Relax] Dispatch sort/scan for non-cuda gpu backends (#16867)
4617efac7b is described below

commit 4617efac7b815f367974244870ec3ec08cda2a72
Author: Wuwei Lin <[email protected]>
AuthorDate: Wed Apr 10 18:25:14 2024 -0700

    [Relax] Dispatch sort/scan for non-cuda gpu backends (#16867)
---
 python/tvm/relax/backend/dispatch_sort_scan.py | 19 ++++++++++++-------
 1 file changed, 12 insertions(+), 7 deletions(-)

diff --git a/python/tvm/relax/backend/dispatch_sort_scan.py b/python/tvm/relax/backend/dispatch_sort_scan.py
index 480420c313..064d3abf25 100644
--- a/python/tvm/relax/backend/dispatch_sort_scan.py
+++ b/python/tvm/relax/backend/dispatch_sort_scan.py
@@ -29,6 +29,11 @@ from tvm.relax import PyExprMutator, expr_functor
 from tvm.target import Target
 
 
+def is_gpu_target(target: Target) -> bool:
+    """Check if the target is a GPU target."""
+    return "gpu" in target.keys
+
+
 @expr_functor.mutator
 class SortScanDispatcher(PyExprMutator):
     """
@@ -88,7 +93,7 @@ class SortScanDispatcher(PyExprMutator):
                 if can_use_thrust(tgt, "tvm.contrib.thrust.sort"):
                     te_func = topi.cuda.sort_thrust
                     kwargs["workspace"] = self.allocate_workspace(call)
-                elif tgt.kind.name == "cuda":
+                elif is_gpu_target(tgt):
                     te_func = topi.cuda.sort
             return self.builder_.call_te(
                 te_func, call.args[0], call.attrs.axis, not call.attrs.descending, **kwargs
@@ -101,7 +106,7 @@ class SortScanDispatcher(PyExprMutator):
                 if can_use_thrust(tgt, "tvm.contrib.thrust.sort"):
                     te_func = topi.cuda.argsort_thrust
                     kwargs["workspace"] = self.allocate_workspace(call)
-                elif tgt.kind.name == "cuda":
+                elif is_gpu_target(tgt):
                     te_func = topi.cuda.argsort
             return self.builder_.call_te(
                 te_func,
@@ -118,7 +123,7 @@ class SortScanDispatcher(PyExprMutator):
             if can_use_thrust(tgt, "tvm.contrib.thrust.sort"):
                 te_func = topi.cuda.topk_thrust
                 kwargs["workspace"] = self.allocate_workspace(call)
-            elif tgt.kind.name == "cuda":
+            elif is_gpu_target(tgt):
                 te_func = topi.cuda.topk
             tir_call = self.builder_.call_te(
                 te_func,
@@ -130,7 +135,7 @@ class SortScanDispatcher(PyExprMutator):
                 dtype=call.attrs.dtype,
                 **kwargs,
             )
-            if tgt.kind.name != "cuda":
+            if not is_gpu_target(tgt):
                 return tir_call
             # apply dlight gpu fallback
             self._apply_dlight_gpu_fallback(tgt, tir_call)
@@ -141,11 +146,11 @@ class SortScanDispatcher(PyExprMutator):
             kwargs = {}
             with tgt:
                 if call.op.name == "relax.cumsum":
-                    te_func = topi.cuda.cumsum if tgt.kind.name == "cuda" else topi.cumsum
+                    te_func = topi.cuda.cumsum if is_gpu_target(tgt) else topi.cumsum
                     if can_use_thrust(tgt, "tvm.contrib.thrust.sum_scan"):
                         kwargs["workspace"] = self.allocate_workspace(call)
                 elif call.op.name == "relax.cumprod":
-                    te_func = topi.cuda.cumprod if tgt.kind.name == "cuda" else topi.cumprod
+                    te_func = topi.cuda.cumprod if is_gpu_target(tgt) else topi.cumprod
                 else:
                     raise ValueError(f"Unsupported op: {call.op.name}")
                 tir_call = self.builder_.call_te(
@@ -156,7 +161,7 @@ class SortScanDispatcher(PyExprMutator):
                     call.attrs.exclusive,
                     **kwargs,
                 )
-            if tgt.kind.name != "cuda":
+            if not is_gpu_target(tgt):
                 return tir_call
             # apply dlight gpu fallback
             self._apply_dlight_gpu_fallback(tgt, tir_call)

Reply via email to