This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new 25d3542 Add thrust support for nms (#5116)
25d3542 is described below
commit 25d354218a72c0df6b26796f3574f45b5be99ede
Author: Leyuan Wang <[email protected]>
AuthorDate: Mon Mar 23 16:52:33 2020 -0700
Add thrust support for nms (#5116)
* add argsort_nms_thrust
* consider valid count in thrust nms sort
* make thrust optional
* typo
* typo
* fix pylint
* address some of the comments
* address more comments
* fix lint
* address more comments
* address more comments
---
cmake/config.cmake | 2 +-
src/runtime/contrib/thrust/thrust.cu | 90 +++++++++++++++++++++++++-----------
topi/python/topi/cuda/nms.py | 10 ++--
topi/python/topi/cuda/sort.py | 73 +++++++++++++++++++++++++----
4 files changed, 135 insertions(+), 40 deletions(-)
diff --git a/cmake/config.cmake b/cmake/config.cmake
index fd295aa..6ab362c 100644
--- a/cmake/config.cmake
+++ b/cmake/config.cmake
@@ -148,7 +148,7 @@ set(USE_NNPACK OFF)
# Possible values:
# - ON: enable tflite with cmake's find search
# - OFF: disable tflite
-# - /path/to/libtensorflow-lite.a: use specific path to tensorflow lite
library
+# - /path/to/libtensorflow-lite.a: use specific path to tensorflow lite library
set(USE_TFLITE OFF)
# /path/to/tensorflow: tensorflow root path when use tflite library
diff --git a/src/runtime/contrib/thrust/thrust.cu
b/src/runtime/contrib/thrust/thrust.cu
index fc9deac..c40235d 100644
--- a/src/runtime/contrib/thrust/thrust.cu
+++ b/src/runtime/contrib/thrust/thrust.cu
@@ -28,6 +28,7 @@
#include <dlpack/dlpack.h>
#include <algorithm>
#include <vector>
+#include <functional>
namespace tvm {
namespace contrib {
@@ -39,7 +40,8 @@ template<typename DataType, typename IndicesType>
void thrust_sort(DLTensor* input,
DLTensor* out_values,
DLTensor* out_indices,
- bool is_ascend) {
+ bool is_ascend,
+ const std::function<int(int)> &get_sort_len) {
thrust::device_ptr<DataType> data_ptr(static_cast<DataType *>(input->data));
thrust::device_ptr<DataType> values_ptr(static_cast<DataType
*>(out_values->data));
thrust::device_ptr<IndicesType> indices_ptr(static_cast<IndicesType
*>(out_indices->data));
@@ -53,6 +55,7 @@ void thrust_sort(DLTensor* input,
thrust::copy(data_ptr, data_ptr + n_iter * n_values, values_ptr);
for (int i = 0 ; i < n_iter; ++i) {
+ n_values = get_sort_len(i);
thrust::sequence(indices_ptr, indices_ptr + n_values);
if (is_ascend) {
thrust::sort_by_key(values_ptr, values_ptr + n_values, indices_ptr);
@@ -65,69 +68,100 @@ void thrust_sort(DLTensor* input,
}
}
-TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sort")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- CHECK_GE(args.num_args, 4);
- DLTensor* input = args[0];
- DLTensor* values_out = args[1];
- DLTensor* indices_out = args[2];
- bool is_ascend = args[3];
-
- auto data_dtype = DLDataType2String(input->dtype);
- auto out_dtype = DLDataType2String(indices_out->dtype);
-
+void thrust_sort_common(DLTensor* input,
+ DLTensor* values_out,
+ DLTensor* indices_out,
+ bool is_ascend,
+ const std::function<int(int)> &get_sort_len,
+ std::string data_dtype,
+ std::string out_dtype) {
if (data_dtype == "float32") {
if (out_dtype == "int32") {
- thrust_sort<float, int32_t>(input, values_out, indices_out, is_ascend);
+ thrust_sort<float, int32_t>(input, values_out, indices_out, is_ascend,
get_sort_len);
} else if (out_dtype == "int64") {
- thrust_sort<float, int64_t>(input, values_out, indices_out, is_ascend);
+ thrust_sort<float, int64_t>(input, values_out, indices_out, is_ascend,
get_sort_len);
} else if (out_dtype == "float32") {
- thrust_sort<float, float>(input, values_out, indices_out, is_ascend);
+ thrust_sort<float, float>(input, values_out, indices_out, is_ascend,
get_sort_len);
} else if (out_dtype == "float64") {
- thrust_sort<float, double>(input, values_out, indices_out, is_ascend);
+ thrust_sort<float, double>(input, values_out, indices_out, is_ascend,
get_sort_len);
} else {
LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
}
} else if (data_dtype == "float64") {
if (out_dtype == "int32") {
- thrust_sort<double, int32_t>(input, values_out, indices_out, is_ascend);
+ thrust_sort<double, int32_t>(input, values_out, indices_out, is_ascend,
get_sort_len);
} else if (out_dtype == "int64") {
- thrust_sort<double, int64_t>(input, values_out, indices_out, is_ascend);
+ thrust_sort<double, int64_t>(input, values_out, indices_out, is_ascend,
get_sort_len);
} else if (out_dtype == "float32") {
- thrust_sort<double, float>(input, values_out, indices_out, is_ascend);
+ thrust_sort<double, float>(input, values_out, indices_out, is_ascend,
get_sort_len);
} else if (out_dtype == "float64") {
- thrust_sort<double, double>(input, values_out, indices_out, is_ascend);
+ thrust_sort<double, double>(input, values_out, indices_out, is_ascend,
get_sort_len);
} else {
LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
}
} else if (data_dtype == "int32") {
if (out_dtype == "int32") {
- thrust_sort<int32_t, int32_t>(input, values_out, indices_out, is_ascend);
+ thrust_sort<int32_t, int32_t>(input, values_out, indices_out, is_ascend,
get_sort_len);
} else if (out_dtype == "int64") {
- thrust_sort<int32_t, int64_t>(input, values_out, indices_out, is_ascend);
+ thrust_sort<int32_t, int64_t>(input, values_out, indices_out, is_ascend,
get_sort_len);
} else if (out_dtype == "float32") {
- thrust_sort<int32_t, float>(input, values_out, indices_out, is_ascend);
+ thrust_sort<int32_t, float>(input, values_out, indices_out, is_ascend,
get_sort_len);
} else if (out_dtype == "float64") {
- thrust_sort<int32_t, double>(input, values_out, indices_out, is_ascend);
+ thrust_sort<int32_t, double>(input, values_out, indices_out, is_ascend,
get_sort_len);
} else {
LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
}
} else if (data_dtype == "int64") {
if (out_dtype == "int32") {
- thrust_sort<int64_t, int32_t>(input, values_out, indices_out, is_ascend);
+ thrust_sort<int64_t, int32_t>(input, values_out, indices_out, is_ascend,
get_sort_len);
} else if (out_dtype == "int64") {
- thrust_sort<int64_t, int64_t>(input, values_out, indices_out, is_ascend);
+ thrust_sort<int64_t, int64_t>(input, values_out, indices_out, is_ascend,
get_sort_len);
} else if (out_dtype == "float32") {
- thrust_sort<int64_t, float>(input, values_out, indices_out, is_ascend);
+ thrust_sort<int64_t, float>(input, values_out, indices_out, is_ascend,
get_sort_len);
} else if (out_dtype == "float64") {
- thrust_sort<int64_t, double>(input, values_out, indices_out, is_ascend);
+ thrust_sort<int64_t, double>(input, values_out, indices_out, is_ascend,
get_sort_len);
} else {
LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
}
} else {
LOG(FATAL) << "Unsupported input dtype: " << data_dtype;
}
+}
+
+TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sort_nms")
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+ CHECK_GE(args.num_args, 5);
+ DLTensor* input = args[0];
+ DLTensor* valid_count = args[1];
+ DLTensor* values_out = args[2];
+ DLTensor* indices_out = args[3];
+ bool is_ascend = args[4];
+
+ auto data_dtype = DLDataType2String(input->dtype);
+ auto out_dtype = DLDataType2String(indices_out->dtype);
+
+ thrust::device_ptr<int> valid_count_ptr(static_cast<int
*>(valid_count->data));
+ auto get_sort_len = [&valid_count_ptr](int i) { return valid_count_ptr[i]; };
+ thrust_sort_common(input, values_out, indices_out, is_ascend, get_sort_len,
+ data_dtype, out_dtype);
});
+
+TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sort")
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+ CHECK_GE(args.num_args, 4);
+ DLTensor* input = args[0];
+ DLTensor* values_out = args[1];
+ DLTensor* indices_out = args[2];
+ bool is_ascend = args[3];
+
+ auto data_dtype = DLDataType2String(input->dtype);
+ auto out_dtype = DLDataType2String(indices_out->dtype);
+
+ int n_values = input->shape[input->ndim - 1];
+ auto get_sort_len = [=](int i) { return n_values; };
+ thrust_sort_common(input, values_out, indices_out, is_ascend, get_sort_len,
+ data_dtype, out_dtype);
+});
} // namespace contrib
} // namespace tvm
diff --git a/topi/python/topi/cuda/nms.py b/topi/python/topi/cuda/nms.py
index e008dcd..d295116 100644
--- a/topi/python/topi/cuda/nms.py
+++ b/topi/python/topi/cuda/nms.py
@@ -22,7 +22,7 @@ import tvm
from tvm import te
from tvm.tir import if_then_else
-from .sort import argsort
+from .sort import argsort, argsort_thrust
from .. import tag
@@ -668,8 +668,12 @@ def non_max_suppression(data, valid_count,
max_output_size=-1,
score_shape = (batch_size, num_anchors)
score_tensor = te.compute(
score_shape, lambda i, j: data[i, j, score_axis], tag=tag.ELEMWISE)
- sort_tensor = argsort(
- score_tensor, valid_count=valid_count, axis=1, is_ascend=False)
+ if tvm.get_global_func("tvm.contrib.thrust.sort_nms", allow_missing=True):
+ sort_tensor = argsort_thrust(
+ score_tensor, valid_count=valid_count, axis=1, is_ascend=False)
+ else:
+ sort_tensor = argsort(
+ score_tensor, valid_count=valid_count, axis=1, is_ascend=False)
sort_tensor_buf = tvm.tir.decl_buffer(sort_tensor.shape, sort_tensor.dtype,
"sort_tensor_buf", data_alignment=8)
diff --git a/topi/python/topi/cuda/sort.py b/topi/python/topi/cuda/sort.py
index 5499683..a1c70c4 100644
--- a/topi/python/topi/cuda/sort.py
+++ b/topi/python/topi/cuda/sort.py
@@ -24,6 +24,10 @@ from ..math import identity
from ..transform import strided_slice, transpose
from .. import tag
+def swap(arr, axis):
+ """ swap arr[axis] and arr[-1] """
+ return arr[:axis] + [arr[-1]] + arr[axis+1:-1] + [arr[axis]]
+
def _schedule_sort(outs):
"""Schedule for argsort operator.
@@ -237,6 +241,64 @@ def sort_nms_ir(data, valid_count, output, axis,
is_ascend):
return ib.get()
+def argsort_nms_thrust(data, valid_count, axis=-1, is_ascend=1,
dtype="float32"):
+  """Performs sorting along the given axis and returns an array of indices
+ having same shape as an input array that index data in sorted order.
+
+ Parameters
+ ----------
+ data: tvm.te.Tensor
+ The input array.
+
+ valid_count : tvm.te.Tensor, optional
+ The number of valid elements to be sorted.
+
+ axis : int, optional
+ Axis along which to sort the input tensor.
+
+ is_ascend : boolean, optional
+ Whether to sort in ascending or descending order.
+
+ dtype : string, optional
+ DType of the output indices.
+
+ Returns
+ -------
+ out : tvm.te.Tensor
+ The output of this function.
+ """
+ ndim = len(data.shape)
+ if axis < 0:
+ axis = ndim + axis
+ if axis != ndim - 1:
+ # Prepare for sorting along axis -1.
+ axes = swap(list(range(ndim)), axis)
+ data = transpose(data, axes)
+
+ data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf",
+ data_alignment=8)
+ valid_count_buf = tvm.tir.decl_buffer(valid_count.shape, valid_count.dtype,
+ "valid_count_buf", data_alignment=4)
+ out_bufs = [
+ tvm.tir.decl_buffer(data.shape, data.dtype, "value_buf",
data_alignment=8),
+ tvm.tir.decl_buffer(data.shape, "int32", "indices_buf",
data_alignment=8)
+ ]
+ out = te.extern([data.shape, data.shape],
+ [data, valid_count],
+ lambda ins, outs: tvm.tir.call_packed(
+ "tvm.contrib.thrust.sort_nms", ins[0], ins[1],
outs[0], outs[1], is_ascend),
+ in_buffers=[data_buf, valid_count_buf],
+ out_buffers=out_bufs,
+ dtype=[data.dtype, "int32"],
+ name="nms_argsort_gpu",
+ tag="nms_argsort_gpu")
+
+ if axis != ndim - 1:
+ axes = swap(list(range(ndim)), axis)
+ out = [transpose(o, axes) for o in out]
+
+ return out[1]
+
def argsort(data, valid_count=None, axis=-1, is_ascend=1, dtype="float32"):
"""Performs sorting along the given axis and returns an array of indices
having same shape as an input array that index data in sorted order.
@@ -318,8 +380,7 @@ def argsort_thrust(data, valid_count=None, axis=-1,
is_ascend=1, dtype="float32"
The output of this function.
"""
if valid_count is not None:
- # TODO: implement argsort_nms with Thrust
- out = argsort(data, valid_count, axis, is_ascend, dtype)
+ out = argsort_nms_thrust(data, valid_count, axis, is_ascend, dtype)
else:
out = topk_thrust(data, 0, axis, "indices", is_ascend, dtype)
return out
@@ -453,13 +514,9 @@ def topk_thrust(data, k=1, axis=-1, ret_type="both",
is_ascend=False, dtype="int
ndim = len(data.shape)
axis = ndim + axis if axis < 0 else axis
- def swap(arr):
- """ swap arr[axis] and arr[-1] """
- return arr[:axis] + [arr[-1]] + arr[axis+1:-1] + [arr[axis]]
-
if axis != ndim - 1:
# Prepare for sorting along axis -1.
- axes = swap(list(range(ndim)))
+ axes = swap(list(range(ndim)), axis)
data = transpose(data, axes)
data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf",
data_alignment=8)
@@ -483,7 +540,7 @@ def topk_thrust(data, k=1, axis=-1, ret_type="both",
is_ascend=False, dtype="int
out = [strided_slice(o, beg, end) for o in out]
if axis != ndim - 1:
- axes = swap(list(range(ndim)))
+ axes = swap(list(range(ndim)), axis)
out = [transpose(o, axes) for o in out]
if ret_type == "values":