This is an automated email from the ASF dual-hosted git repository.

laurawly pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new b305204  Only use thrust for cuda target (#6722)
b305204 is described below

commit b30520434e8b0f22720d9ee1ef0fe15f85581b8b
Author: Trevor Morris <[email protected]>
AuthorDate: Thu Oct 29 10:46:12 2020 -0700

    Only use thrust for cuda target (#6722)
---
 python/tvm/relay/op/strategy/cuda.py | 8 ++++++--
 python/tvm/topi/cuda/nms.py          | 7 ++++++-
 2 files changed, 12 insertions(+), 3 deletions(-)

diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py
index d77361d..187ea01 100644
--- a/python/tvm/relay/op/strategy/cuda.py
+++ b/python/tvm/relay/op/strategy/cuda.py
@@ -673,7 +673,9 @@ def argsort_strategy_cuda(attrs, inputs, out_type, target):
         wrap_topi_schedule(topi.cuda.schedule_argsort),
         name="argsort.cuda",
     )
-    if get_global_func("tvm.contrib.thrust.sort", allow_missing=True):
+    if target.kind.name == "cuda" and get_global_func(
+        "tvm.contrib.thrust.sort", allow_missing=True
+    ):
         strategy.add_implementation(
             wrap_compute_argsort(topi.cuda.argsort_thrust),
             wrap_topi_schedule(topi.cuda.schedule_argsort),
@@ -692,7 +694,9 @@ def topk_strategy_cuda(attrs, inputs, out_type, target):
         wrap_topi_schedule(topi.cuda.schedule_topk),
         name="topk.cuda",
     )
-    if get_global_func("tvm.contrib.thrust.sort", allow_missing=True):
+    if target.kind.name == "cuda" and get_global_func(
+        "tvm.contrib.thrust.sort", allow_missing=True
+    ):
         strategy.add_implementation(
             wrap_compute_topk(topi.cuda.topk_thrust),
             wrap_topi_schedule(topi.cuda.schedule_topk),
diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py
index 2041f4c..ed6e8f0 100644
--- a/python/tvm/topi/cuda/nms.py
+++ b/python/tvm/topi/cuda/nms.py
@@ -483,7 +483,12 @@ def non_max_suppression(
     score_axis = score_index
     score_shape = (batch_size, num_anchors)
     score_tensor = te.compute(score_shape, lambda i, j: data[i, j, score_axis], tag=tag.ELEMWISE)
-    if tvm.get_global_func("tvm.contrib.thrust.sort_nms", allow_missing=True):
+    target = tvm.target.Target.current()
+    if (
+        target
+        and target.kind.name == "cuda"
+        and tvm.get_global_func("tvm.contrib.thrust.sort_nms", allow_missing=True)
+    ):
         sort_tensor = argsort_thrust(
             score_tensor, valid_count=None, axis=1, is_ascend=False, dtype=valid_count_dtype
         )

Reply via email to