This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 1d07f1a  [THRUST] Faster multi dimensional argsort by segmented sort (#7195)
1d07f1a is described below

commit 1d07f1a0f4e70872c2a52531b6bd8580d64c7538
Author: masahi <[email protected]>
AuthorDate: Wed Jan 13 15:42:09 2021 +0900

    [THRUST] Faster multi dimensional argsort by segmented sort (#7195)
    
    * remove sort nms
    
    * add segmented sort by key impl
    
    * bug fix, test pass
    
    * updated fast path condition to work for all dims
---
 python/tvm/topi/cuda/nms.py          |   6 +-
 python/tvm/topi/cuda/sort.py         |  73 +---------------------
 src/runtime/contrib/thrust/thrust.cu | 117 ++++++++++++++++++++---------------
 3 files changed, 72 insertions(+), 124 deletions(-)

diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py
index 8946446..a4080e5 100644
--- a/python/tvm/topi/cuda/nms.py
+++ b/python/tvm/topi/cuda/nms.py
@@ -819,11 +819,9 @@ def non_max_suppression(
     if (
         target
         and target.kind.name == "cuda"
-        and tvm.get_global_func("tvm.contrib.thrust.sort_nms", allow_missing=True)
+        and tvm.get_global_func("tvm.contrib.thrust.sort", allow_missing=True)
     ):
-        sort_tensor = argsort_thrust(
-            score_tensor, valid_count=None, axis=1, is_ascend=False, dtype=valid_count_dtype
-        )
+        sort_tensor = argsort_thrust(score_tensor, axis=1, is_ascend=False, dtype=valid_count_dtype)
     else:
         sort_tensor = argsort(score_tensor, axis=1, is_ascend=False, dtype=valid_count_dtype)
 
diff --git a/python/tvm/topi/cuda/sort.py b/python/tvm/topi/cuda/sort.py
index 18872a2..9b6a18a 100644
--- a/python/tvm/topi/cuda/sort.py
+++ b/python/tvm/topi/cuda/sort.py
@@ -409,68 +409,6 @@ def sort_by_key_ir(
     )
 
 
-def argsort_nms_thrust(data, valid_count, axis=-1, is_ascend=1, dtype="float32"):
-    """Performs sorting along the given axis and returns an array of indicies
-    having same shape as an input array that index data in sorted order.
-
-    Parameters
-    ----------
-    data: tvm.te.Tensor
-        The input array.
-
-    valid_count : tvm.te.Tensor, optional
-        The number of valid elements to be sorted.
-
-    axis : int, optional
-        Axis long which to sort the input tensor.
-
-    is_ascend : boolean, optional
-        Whether to sort in ascending or descending order.
-
-    dtype : string, optional
-        DType of the output indices.
-
-    Returns
-    -------
-    out : tvm.te.Tensor
-        The output of this function.
-    """
-    ndim = len(data.shape)
-    if axis < 0:
-        axis = ndim + axis
-    if axis != ndim - 1:
-        # Prepare for sorting along axis -1.
-        axes = swap(list(range(ndim)), axis)
-        data = transpose(data, axes)
-
-    data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
-    valid_count_buf = tvm.tir.decl_buffer(
-        valid_count.shape, valid_count.dtype, "valid_count_buf", data_alignment=4
-    )
-    out_bufs = [
-        tvm.tir.decl_buffer(data.shape, data.dtype, "value_buf", data_alignment=8),
-        tvm.tir.decl_buffer(data.shape, "int32", "indices_buf", data_alignment=8),
-    ]
-    out = te.extern(
-        [data.shape, data.shape],
-        [data, valid_count],
-        lambda ins, outs: tvm.tir.call_packed(
-            "tvm.contrib.thrust.sort_nms", ins[0], ins[1], outs[0], outs[1], is_ascend
-        ),
-        in_buffers=[data_buf, valid_count_buf],
-        out_buffers=out_bufs,
-        dtype=[data.dtype, "int32"],
-        name="nms_argsort_gpu",
-        tag="nms_argsort_gpu",
-    )
-
-    if axis != ndim - 1:
-        axes = swap(list(range(ndim)), axis)
-        out = [transpose(o, axes) for o in out]
-
-    return out[1]
-
-
 def sort(data, axis=-1, is_ascend=1):
     """Performs sorting along the given axis and returns an array of
     sorted values with the same shape as the input data.
@@ -602,7 +540,7 @@ def argsort(data, axis=-1, is_ascend=1, dtype="float32"):
     return out
 
 
-def argsort_thrust(data, valid_count=None, axis=-1, is_ascend=1, dtype="float32"):
+def argsort_thrust(data, axis=-1, is_ascend=1, dtype="float32"):
     """Performs sorting along the given axis and returns an array of indicies
     having same shape as an input array that index data in sorted order.
 
@@ -611,9 +549,6 @@ def argsort_thrust(data, valid_count=None, axis=-1, is_ascend=1, dtype="float32"
     data: tvm.te.Tensor
         The input array.
 
-    valid_count : tvm.te.Tensor, optional
-        The number of valid elements to be sorted.
-
     axis : int, optional
         Axis long which to sort the input tensor.
 
@@ -628,11 +563,7 @@ def argsort_thrust(data, valid_count=None, axis=-1, is_ascend=1, dtype="float32"
     out : tvm.te.Tensor
         The output of this function.
     """
-    if valid_count is not None:
-        out = argsort_nms_thrust(data, valid_count, axis, is_ascend, dtype)
-    else:
-        out = topk_thrust(data, 0, axis, "indices", is_ascend, dtype)
-    return out
+    return topk_thrust(data, 0, axis, "indices", is_ascend, dtype)
 
 
 def schedule_sort(outs):
diff --git a/src/runtime/contrib/thrust/thrust.cu b/src/runtime/contrib/thrust/thrust.cu
index dddbb04..6a48f1a 100644
--- a/src/runtime/contrib/thrust/thrust.cu
+++ b/src/runtime/contrib/thrust/thrust.cu
@@ -22,7 +22,9 @@
  */
 
 #include <thrust/device_ptr.h>
+#include <thrust/device_vector.h>
 #include <thrust/sort.h>
+#include <thrust/gather.h>
 
 #include <tvm/runtime/registry.h>
 #include <dlpack/dlpack.h>
@@ -41,21 +43,19 @@ void thrust_sort(DLTensor* input,
                  DLTensor* out_values,
                  DLTensor* out_indices,
                  bool is_ascend,
-                 const std::function<int(int)> &get_sort_len) {
+                 int n_values) {
   thrust::device_ptr<DataType> data_ptr(static_cast<DataType *>(input->data));
   thrust::device_ptr<DataType> values_ptr(static_cast<DataType *>(out_values->data));
   thrust::device_ptr<IndicesType> indices_ptr(static_cast<IndicesType *>(out_indices->data));
 
-  int n_values = input->shape[input->ndim - 1];
-  int n_iter = 1;
-  for (int i = 0; i < input->ndim - 1; ++i) {
-    n_iter *= input->shape[i];
+  size_t size = 1;
+  for (int i = 0; i < input->ndim; ++i) {
+    size *= input->shape[i];
   }
+  thrust::copy(data_ptr, data_ptr + size, values_ptr);
 
-  thrust::copy(data_ptr, data_ptr + n_iter * n_values, values_ptr);
-
-  for (int i = 0 ; i < n_iter; ++i) {
-    n_values = get_sort_len(i);
+  if (size == static_cast<size_t>(input->shape[input->ndim - 1])) {
+    // A fast path for single segment case
     thrust::sequence(indices_ptr, indices_ptr + n_values);
     if (is_ascend) {
       thrust::sort_by_key(values_ptr, values_ptr + n_values, indices_ptr);
@@ -63,8 +63,47 @@ void thrust_sort(DLTensor* input,
       thrust::sort_by_key(values_ptr, values_ptr + n_values, indices_ptr,
                           thrust::greater<DataType>());
     }
-    values_ptr += n_values;
-    indices_ptr += n_values;
+  } else {
+    // segmented sort by key
+    // Follow the back-to-back stable_sort_by_key strategy explained below
+    // https://groups.google.com/g/thrust-users/c/BoLsxO6b4FY
+    thrust::device_vector<int64_t> argsort_order(size);
+    thrust::sequence(argsort_order.begin(), argsort_order.end());
+
+    // First, sort values and store the sorted order in argsort_order.
+    if (is_ascend) {
+      thrust::stable_sort_by_key(values_ptr, values_ptr + size, argsort_order.begin());
+    } else {
+      thrust::stable_sort_by_key(values_ptr, values_ptr + size, argsort_order.begin(),
+                                 thrust::greater<DataType>());
+    }
+
+    // The following is to create the indices array 0, 1, 2, 0, 1, 2 ... 0, 1, 2
+    // without materializing it
+    auto counting_iter = thrust::counting_iterator<int64_t>(0);
+    auto linear_index_to_sort_axis_index = [n_values] __host__ __device__(int64_t i) {
+      return i % n_values;
+    }; // NOLINT(*)
+    auto init_indices_iter = thrust::make_transform_iterator(counting_iter,
+                                                             linear_index_to_sort_axis_index);
+
+    // This will reorder indices 0, 1, 2 ... in the sorted order of values_ptr
+    thrust::gather(argsort_order.begin(), argsort_order.end(), init_indices_iter, indices_ptr);
+
+    thrust::device_vector<int> segment_ids(size);
+    auto linear_index_to_segment_id = [n_values] __host__ __device__(int64_t i) {
+      return i / n_values;
+    }; // NOLINT(*)
+    // We also reorder segment indices 0, 0, 0, 1, 1, 1 ... in the order of values_ptr
+    thrust::transform(argsort_order.begin(), argsort_order.end(), segment_ids.begin(),
+                      linear_index_to_segment_id);
+
+    // The second sort key-ed by segment_ids would bring segment_ids back to 0, 0, 0, 1, 1, 1 ...
+    // values_ptr and indices_ptr will also be sorted in the order of segmend_ids above
+    // Since sorting has been done in a stable way, relative orderings of values and indices
+    // in the segment do not change and hence they remain sorted.
+    auto key_val_zip = thrust::make_zip_iterator(thrust::make_tuple(values_ptr, indices_ptr));
+    thrust::stable_sort_by_key(segment_ids.begin(), segment_ids.end(), key_val_zip);
   }
 }
 
@@ -72,54 +111,54 @@ void thrust_sort_common(DLTensor* input,
                         DLTensor* values_out,
                         DLTensor* indices_out,
                         bool is_ascend,
-                        const std::function<int(int)> &get_sort_len,
+                        int sort_len,
                         std::string data_dtype,
                         std::string out_dtype) {
   if (data_dtype == "float32") {
     if (out_dtype == "int32") {
-      thrust_sort<float, int32_t>(input, values_out, indices_out, is_ascend, get_sort_len);
+      thrust_sort<float, int32_t>(input, values_out, indices_out, is_ascend, sort_len);
     } else if (out_dtype == "int64") {
-      thrust_sort<float, int64_t>(input, values_out, indices_out, is_ascend, get_sort_len);
+      thrust_sort<float, int64_t>(input, values_out, indices_out, is_ascend, sort_len);
     } else if (out_dtype == "float32") {
-      thrust_sort<float, float>(input, values_out, indices_out, is_ascend, get_sort_len);
+      thrust_sort<float, float>(input, values_out, indices_out, is_ascend, sort_len);
     } else if (out_dtype == "float64") {
-      thrust_sort<float, double>(input, values_out, indices_out, is_ascend, get_sort_len);
+      thrust_sort<float, double>(input, values_out, indices_out, is_ascend, sort_len);
     } else {
       LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
     }
   } else if (data_dtype == "float64") {
     if (out_dtype == "int32") {
-      thrust_sort<double, int32_t>(input, values_out, indices_out, is_ascend, get_sort_len);
+      thrust_sort<double, int32_t>(input, values_out, indices_out, is_ascend, sort_len);
     } else if (out_dtype == "int64") {
-      thrust_sort<double, int64_t>(input, values_out, indices_out, is_ascend, get_sort_len);
+      thrust_sort<double, int64_t>(input, values_out, indices_out, is_ascend, sort_len);
     } else if (out_dtype == "float32") {
-      thrust_sort<double, float>(input, values_out, indices_out, is_ascend, get_sort_len);
+      thrust_sort<double, float>(input, values_out, indices_out, is_ascend, sort_len);
     } else if (out_dtype == "float64") {
-      thrust_sort<double, double>(input, values_out, indices_out, is_ascend, get_sort_len);
+      thrust_sort<double, double>(input, values_out, indices_out, is_ascend, sort_len);
     } else {
       LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
     }
   } else if (data_dtype == "int32") {
     if (out_dtype == "int32") {
-      thrust_sort<int32_t, int32_t>(input, values_out, indices_out, is_ascend, get_sort_len);
+      thrust_sort<int32_t, int32_t>(input, values_out, indices_out, is_ascend, sort_len);
     } else if (out_dtype == "int64") {
-      thrust_sort<int32_t, int64_t>(input, values_out, indices_out, is_ascend, get_sort_len);
+      thrust_sort<int32_t, int64_t>(input, values_out, indices_out, is_ascend, sort_len);
     } else if (out_dtype == "float32") {
-      thrust_sort<int32_t, float>(input, values_out, indices_out, is_ascend, get_sort_len);
+      thrust_sort<int32_t, float>(input, values_out, indices_out, is_ascend, sort_len);
     } else if (out_dtype == "float64") {
-      thrust_sort<int32_t, double>(input, values_out, indices_out, is_ascend, get_sort_len);
+      thrust_sort<int32_t, double>(input, values_out, indices_out, is_ascend, sort_len);
     } else {
       LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
     }
   }  else if (data_dtype == "int64") {
     if (out_dtype == "int32") {
-      thrust_sort<int64_t, int32_t>(input, values_out, indices_out, is_ascend, get_sort_len);
+      thrust_sort<int64_t, int32_t>(input, values_out, indices_out, is_ascend, sort_len);
     } else if (out_dtype == "int64") {
-      thrust_sort<int64_t, int64_t>(input, values_out, indices_out, is_ascend, get_sort_len);
+      thrust_sort<int64_t, int64_t>(input, values_out, indices_out, is_ascend, sort_len);
     } else if (out_dtype == "float32") {
-      thrust_sort<int64_t, float>(input, values_out, indices_out, is_ascend, get_sort_len);
+      thrust_sort<int64_t, float>(input, values_out, indices_out, is_ascend, sort_len);
     } else if (out_dtype == "float64") {
-      thrust_sort<int64_t, double>(input, values_out, indices_out, is_ascend, get_sort_len);
+      thrust_sort<int64_t, double>(input, values_out, indices_out, is_ascend, sort_len);
     } else {
       LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
     }
@@ -128,25 +167,6 @@ void thrust_sort_common(DLTensor* input,
   }
 }
 
-TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sort_nms")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-  ICHECK_GE(args.num_args, 5);
-  DLTensor* input = args[0];
-  DLTensor* valid_count = args[1];
-  DLTensor* values_out = args[2];
-  DLTensor* indices_out = args[3];
-  bool is_ascend = args[4];
-
-  auto data_dtype = DLDataType2String(input->dtype);
-  auto out_dtype = DLDataType2String(indices_out->dtype);
-
-  thrust::device_ptr<int> valid_count_ptr(static_cast<int *>(valid_count->data));
-  auto get_sort_len = [&valid_count_ptr](int i) { return valid_count_ptr[i]; };
-  thrust_sort_common(input, values_out, indices_out, is_ascend, get_sort_len,
-                     data_dtype, out_dtype);
-});
-
-
 TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sort")
 .set_body([](TVMArgs args, TVMRetValue* ret) {
   ICHECK_GE(args.num_args, 4);
@@ -159,8 +179,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sort")
   auto out_dtype = DLDataType2String(indices_out->dtype);
 
   int n_values = input->shape[input->ndim - 1];
-  auto get_sort_len = [=](int i) { return n_values; };
-  thrust_sort_common(input, values_out, indices_out, is_ascend, get_sort_len,
+  thrust_sort_common(input, values_out, indices_out, is_ascend, n_values,
                      data_dtype, out_dtype);
 });
 

Reply via email to