This is an automated email from the ASF dual-hosted git repository.
laurawly pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/main by this push:
new b305204 Only use thrust for cuda target (#6722)
b305204 is described below
commit b30520434e8b0f22720d9ee1ef0fe15f85581b8b
Author: Trevor Morris <[email protected]>
AuthorDate: Thu Oct 29 10:46:12 2020 -0700
Only use thrust for cuda target (#6722)
---
python/tvm/relay/op/strategy/cuda.py | 8 ++++++--
python/tvm/topi/cuda/nms.py | 7 ++++++-
2 files changed, 12 insertions(+), 3 deletions(-)
diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py
index d77361d..187ea01 100644
--- a/python/tvm/relay/op/strategy/cuda.py
+++ b/python/tvm/relay/op/strategy/cuda.py
@@ -673,7 +673,9 @@ def argsort_strategy_cuda(attrs, inputs, out_type, target):
wrap_topi_schedule(topi.cuda.schedule_argsort),
name="argsort.cuda",
)
- if get_global_func("tvm.contrib.thrust.sort", allow_missing=True):
+ if target.kind.name == "cuda" and get_global_func(
+ "tvm.contrib.thrust.sort", allow_missing=True
+ ):
strategy.add_implementation(
wrap_compute_argsort(topi.cuda.argsort_thrust),
wrap_topi_schedule(topi.cuda.schedule_argsort),
@@ -692,7 +694,9 @@ def topk_strategy_cuda(attrs, inputs, out_type, target):
wrap_topi_schedule(topi.cuda.schedule_topk),
name="topk.cuda",
)
- if get_global_func("tvm.contrib.thrust.sort", allow_missing=True):
+ if target.kind.name == "cuda" and get_global_func(
+ "tvm.contrib.thrust.sort", allow_missing=True
+ ):
strategy.add_implementation(
wrap_compute_topk(topi.cuda.topk_thrust),
wrap_topi_schedule(topi.cuda.schedule_topk),
diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py
index 2041f4c..ed6e8f0 100644
--- a/python/tvm/topi/cuda/nms.py
+++ b/python/tvm/topi/cuda/nms.py
@@ -483,7 +483,12 @@ def non_max_suppression(
score_axis = score_index
score_shape = (batch_size, num_anchors)
score_tensor = te.compute(score_shape, lambda i, j: data[i, j, score_axis], tag=tag.ELEMWISE)
- if tvm.get_global_func("tvm.contrib.thrust.sort_nms", allow_missing=True):
+ target = tvm.target.Target.current()
+ if (
+ target
+ and target.kind.name == "cuda"
+ and tvm.get_global_func("tvm.contrib.thrust.sort_nms", allow_missing=True)
+ ):
sort_tensor = argsort_thrust(
score_tensor, valid_count=None, axis=1, is_ascend=False, dtype=valid_count_dtype
)