This is an automated email from the ASF dual-hosted git repository.
yongwww pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 4617efac7b [Relax] Dispatch sort/scan for non-cuda gpu backends (#16867)
4617efac7b is described below
commit 4617efac7b815f367974244870ec3ec08cda2a72
Author: Wuwei Lin <[email protected]>
AuthorDate: Wed Apr 10 18:25:14 2024 -0700
[Relax] Dispatch sort/scan for non-cuda gpu backends (#16867)
---
python/tvm/relax/backend/dispatch_sort_scan.py | 19 ++++++++++++-------
1 file changed, 12 insertions(+), 7 deletions(-)
diff --git a/python/tvm/relax/backend/dispatch_sort_scan.py b/python/tvm/relax/backend/dispatch_sort_scan.py
index 480420c313..064d3abf25 100644
--- a/python/tvm/relax/backend/dispatch_sort_scan.py
+++ b/python/tvm/relax/backend/dispatch_sort_scan.py
@@ -29,6 +29,11 @@ from tvm.relax import PyExprMutator, expr_functor
from tvm.target import Target
+def is_gpu_target(target: Target) -> bool:
+ """Check if the target is a GPU target."""
+ return "gpu" in target.keys
+
+
@expr_functor.mutator
class SortScanDispatcher(PyExprMutator):
"""
@@ -88,7 +93,7 @@ class SortScanDispatcher(PyExprMutator):
if can_use_thrust(tgt, "tvm.contrib.thrust.sort"):
te_func = topi.cuda.sort_thrust
kwargs["workspace"] = self.allocate_workspace(call)
- elif tgt.kind.name == "cuda":
+ elif is_gpu_target(tgt):
te_func = topi.cuda.sort
return self.builder_.call_te(
te_func, call.args[0], call.attrs.axis, not call.attrs.descending, **kwargs
@@ -101,7 +106,7 @@ class SortScanDispatcher(PyExprMutator):
if can_use_thrust(tgt, "tvm.contrib.thrust.sort"):
te_func = topi.cuda.argsort_thrust
kwargs["workspace"] = self.allocate_workspace(call)
- elif tgt.kind.name == "cuda":
+ elif is_gpu_target(tgt):
te_func = topi.cuda.argsort
return self.builder_.call_te(
te_func,
@@ -118,7 +123,7 @@ class SortScanDispatcher(PyExprMutator):
if can_use_thrust(tgt, "tvm.contrib.thrust.sort"):
te_func = topi.cuda.topk_thrust
kwargs["workspace"] = self.allocate_workspace(call)
- elif tgt.kind.name == "cuda":
+ elif is_gpu_target(tgt):
te_func = topi.cuda.topk
tir_call = self.builder_.call_te(
te_func,
@@ -130,7 +135,7 @@ class SortScanDispatcher(PyExprMutator):
dtype=call.attrs.dtype,
**kwargs,
)
- if tgt.kind.name != "cuda":
+ if not is_gpu_target(tgt):
return tir_call
# apply dlight gpu fallback
self._apply_dlight_gpu_fallback(tgt, tir_call)
@@ -141,11 +146,11 @@ class SortScanDispatcher(PyExprMutator):
kwargs = {}
with tgt:
if call.op.name == "relax.cumsum":
- te_func = topi.cuda.cumsum if tgt.kind.name == "cuda" else topi.cumsum
+ te_func = topi.cuda.cumsum if is_gpu_target(tgt) else topi.cumsum
if can_use_thrust(tgt, "tvm.contrib.thrust.sum_scan"):
kwargs["workspace"] = self.allocate_workspace(call)
elif call.op.name == "relax.cumprod":
- te_func = topi.cuda.cumprod if tgt.kind.name == "cuda" else topi.cumprod
+ te_func = topi.cuda.cumprod if is_gpu_target(tgt) else topi.cumprod
else:
raise ValueError(f"Unsupported op: {call.op.name}")
tir_call = self.builder_.call_te(
@@ -156,7 +161,7 @@ class SortScanDispatcher(PyExprMutator):
call.attrs.exclusive,
**kwargs,
)
- if tgt.kind.name != "cuda":
+ if not is_gpu_target(tgt):
return tir_call
# apply dlight gpu fallback
self._apply_dlight_gpu_fallback(tgt, tir_call)