This is an automated email from the ASF dual-hosted git repository.

zhic pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 9956b5b  [TOPI] GPU sort IR refactor to enable sort by keys (#7157)
9956b5b is described below

commit 9956b5b859a24c8f4dec1abf308723b1257ffc66
Author: masahi <[email protected]>
AuthorDate: Thu Dec 24 07:33:47 2020 +0900

    [TOPI] GPU sort IR refactor to enable sort by keys (#7157)
    
    * sort refactor initial import
    
    * sort test working
    
    * scatter 1d with positive indices working
    
    * remove negative indices, using extern for now
    
    * minor fix
    
    * minor fix
    
    * add sort by key test
    
    * revert scatter change
    
    * add document
    
    * fix py format
    
    Co-authored-by: masa <[email protected]>
---
 python/tvm/topi/cuda/nms.py       |   4 +-
 python/tvm/topi/cuda/scatter.py   |  12 +-
 python/tvm/topi/cuda/sort.py      | 561 +++++++++++++++++++++-----------------
 tests/python/contrib/test_sort.py |  35 ++-
 4 files changed, 343 insertions(+), 269 deletions(-)

diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py
index cea287e..020cf9b 100644
--- a/python/tvm/topi/cuda/nms.py
+++ b/python/tvm/topi/cuda/nms.py
@@ -737,9 +737,7 @@ def non_max_suppression(
             score_tensor, valid_count=None, axis=1, is_ascend=False, 
dtype=valid_count_dtype
         )
     else:
-        sort_tensor = argsort(
-            score_tensor, valid_count=None, axis=1, is_ascend=False, 
dtype=valid_count_dtype
-        )
+        sort_tensor = argsort(score_tensor, axis=1, is_ascend=False, 
dtype=valid_count_dtype)
 
     sort_tensor_buf = tvm.tir.decl_buffer(
         sort_tensor.shape, sort_tensor.dtype, "sort_tensor_buf", 
data_alignment=8
diff --git a/python/tvm/topi/cuda/scatter.py b/python/tvm/topi/cuda/scatter.py
index 9916e2a..be602c8 100644
--- a/python/tvm/topi/cuda/scatter.py
+++ b/python/tvm/topi/cuda/scatter.py
@@ -424,6 +424,8 @@ def gen_scatter_1d_thrust(data, indices_sorted, 
updates_sorted, axis, out, _):
     Sorting of indices, and sorting of updates with respect to indices, can be 
done
     at the same time by thrust's sort_by_key function. It is important that 
sorting
     be done in a "stable" way via stable_sort, to guarantee deterministic 
output.
+    Negative indices are assumed to have been converted to corresponding 
positive
+    indices.
 
     Parameters
     ----------
@@ -473,12 +475,6 @@ def gen_scatter_1d_thrust(data, indices_sorted, 
updates_sorted, axis, out, _):
 
     ni = indices_sorted.shape[0]
 
-    def do_update(ib, index, update):
-        with ib.if_scope(index < 0):
-            out_ptr[index + n] = update
-        with ib.else_scope():
-            out_ptr[index] = update
-
     with ib.new_scope():
         nthread_bx = ceil_div(ni, nthread_tx)
         tx = te.thread_axis("threadIdx.x")
@@ -491,7 +487,7 @@ def gen_scatter_1d_thrust(data, indices_sorted, 
updates_sorted, axis, out, _):
             # The last element can always update.
             index = indices_ptr[tid]
             update = updates_ptr[tid]
-            do_update(ib, index, update)
+            out_ptr[index] = update
 
         with ib.else_scope():
             with ib.if_scope(tid < ni - 1):
@@ -503,7 +499,7 @@ def gen_scatter_1d_thrust(data, indices_sorted, 
updates_sorted, axis, out, _):
                 # This thread can update the output.
                 with ib.if_scope(index != index_next):
                     update = updates_ptr[tid]
-                    do_update(ib, index, update)
+                    out_ptr[index] = update
 
     return ib.get()
 
diff --git a/python/tvm/topi/cuda/sort.py b/python/tvm/topi/cuda/sort.py
index 039ebe3..18872a2 100644
--- a/python/tvm/topi/cuda/sort.py
+++ b/python/tvm/topi/cuda/sort.py
@@ -21,7 +21,6 @@ from tvm import te
 from tvm._ffi import get_global_func
 
 from .injective import schedule_injective_from_existing
-from ..math import identity
 from ..transform import strided_slice, transpose
 from .. import tag
 
@@ -62,46 +61,14 @@ def _schedule_sort(outs):
     return s
 
 
-def sort_ir(
-    data, values_out, values_out_swap, axis, is_ascend, indices_out=None, 
indices_out_swap=None
-):
-    """Low level IR to do nms sorting on the GPU, same usage as 
tvm.contrib.sort.argsort on the CPU.
-
-    Parameters
-    ----------
-    data: Buffer
-        Buffer of input data. Data will be sorted in place.
-
-    values_out : Buffer
-        Output buffer of values of sorted tensor with same shape as data.
-
-    values_out_swap : Buffer
-        Output buffer of values with same shape as data to use as swap.
-
-    axis : Int
-        Axis long which to sort the input tensor.
-
-    is_ascend : Boolean
-        Whether to sort in ascending or descending order.
-
-    indicess_out : Buffer
-        Output buffer of indices of sorted tensor with same shape as data.
-
-    indices_out_swap : Buffer
-        Output buffer of indices with same shape as data to use as swap.
-
-    Returns
-    -------
-    stmt : Stmt
-        The result IR statement.
-    """
+def ceil_div(a, b):
+    return tvm.tir.indexdiv(a + b - 1, b)
 
-    def ceil_div(a, b):
-        return tvm.tir.indexdiv(a + b - 1, b)
 
+def _sort_init(ib, shape, axis, keys_in, keys_out, values_out=None, 
value_init_func=None):
+    """Initialize the output buffers by copying from inputs"""
     axis_mul_before = 1
     axis_mul_after = 1
-    shape = data.shape
     if axis < 0:
         axis = len(shape) + axis
     for i, value in enumerate(shape, 0):
@@ -110,16 +77,6 @@ def sort_ir(
         elif i > axis:
             axis_mul_after *= value
 
-    ib = tvm.tir.ir_builder.create()
-
-    data = ib.buffer_ptr(data)
-    values_out = ib.buffer_ptr(values_out)
-    values_out_swap = ib.buffer_ptr(values_out_swap)
-    if indices_out is not None:
-        indices_out = ib.buffer_ptr(indices_out)
-        assert indices_out_swap is not None
-        indices_out_swap = ib.buffer_ptr(indices_out_swap)
-
     # Set up threading
     max_threads = 
int(tvm.target.Target.current(allow_none=False).max_num_threads)
     nthread_tx = max_threads
@@ -127,7 +84,7 @@ def sort_ir(
     nthread_by = axis_mul_before
     nthread_bz = axis_mul_after
 
-    # Copy the data to initial output
+    # Copy the keys_in to initial output
     with ib.new_scope():
         tx = te.thread_axis("threadIdx.x")
         bx = te.thread_axis("blockIdx.x")
@@ -141,9 +98,25 @@ def sort_ir(
         ib.scope_attr(bz, "thread_extent", nthread_bz)
         idx = (by * shape[axis] + tid) * axis_mul_after + bz
         with ib.if_scope(tid < shape[axis]):
-            values_out[idx] = data[idx]
-            if indices_out is not None:
-                indices_out[idx] = tvm.tir.generic.cast(tid, indices_out.dtype)
+            keys_out[idx] = keys_in[idx]
+            if values_out is not None:
+                values_out[idx] = value_init_func(idx, tid)
+
+    return axis_mul_before, axis_mul_after
+
+
+def _sort_common(
+    ib,
+    size,
+    axis_mul_before,
+    axis_mul_after,
+    is_ascend,
+    keys,
+    keys_swap,
+    values=None,
+    values_swap=None,
+):
+    """Either sort only values or sort values by keys."""
 
     ## we are looping over the array doing mergesort from the bottom up.
     ## The outer loop runs on the host and launches a cuda kernel for each 
iteration
@@ -155,8 +128,85 @@ def sort_ir(
     ## to deal with 8 total elements. On iteration 3, each thread deals with 
16 elements, etc
     ## On the final iteration of the algorithm, one thread will merge two 
sorted lists
     ## to sort the entire array
+
+    max_threads = 
int(tvm.target.Target.current(allow_none=False).max_num_threads)
+    nthread_tx = max_threads
+    nthread_bx = ceil_div(size, max_threads)
+    nthread_by = axis_mul_before
+    nthread_bz = axis_mul_after
+
+    def compare(a, b):
+        """
+        Compare a and b in proper ascending or descending order
+        """
+        if is_ascend:
+            out = a <= b
+        else:
+            out = b <= a
+        return out
+
+    def bottom_up_merge(source, dest, source_idx, dest_idx, start, middle, 
end, even):
+        """
+        Merge the two sections of the array assigned to this thread
+        """
+        # pylint: disable=arguments-out-of-order
+        # initialize iterators
+        i[0] = start
+        j[0] = middle
+        # set up indexes
+        base_idx = by * size * axis_mul_after + bz
+        # iterate over the output loop
+        with ib.for_range(0, end - start) as k:
+            i_idx = base_idx + i[0] * axis_mul_after
+            j_idx = base_idx + j[0] * axis_mul_after
+            k_idx = base_idx + (k + start) * axis_mul_after
+
+            def swap_values(source, dest, source_idx, dest_idx):
+                def assign_i():
+                    """assign i value to current output"""
+                    dest[k_idx] = source[i_idx]
+                    if values is not None:
+                        dest_idx[k_idx] = source_idx[i_idx]
+                    i[0] += 1
+
+                def assign_j():
+                    """assign j value to current output"""
+                    dest[k_idx] = source[j_idx]
+                    if values is not None:
+                        dest_idx[k_idx] = source_idx[j_idx]
+                    j[0] += 1
+
+                ## if both of the iterators are in range
+                with ib.if_scope(tvm.tir.all(i[0] < middle, j[0] < end)):
+                    # compare them and insert whichever is next into the output
+                    with ib.if_scope(compare(source[i_idx], source[j_idx])):
+                        assign_i()
+                    with ib.else_scope():
+                        assign_j()
+                # otherwise, simply copy the remainder of the valid iterator 
to the output
+                with ib.else_scope():
+                    with ib.if_scope(i[0] < middle):
+                        assign_i()
+                    with ib.else_scope():
+                        assign_j()
+
+            # Switch which input is the source and which is the destination 
each iteration
+            with ib.if_scope(even):
+                swap_values(source, dest, source_idx, dest_idx)
+            with ib.else_scope():
+                swap_values(dest, source, dest_idx, source_idx)
+
+    def mergesort(source, dest, source_idx, dest_idx, size, width, even):
+        # calculate the start, mid, and end points of this section
+        start[0] = width * tid
+        with ib.if_scope(start[0] < size):
+            middle[0] = tvm.te.min(start[0] + tvm.tir.indexdiv(width, 2), size)
+            end[0] = tvm.te.min(start[0] + width, size)
+            ## merge the start->middle and middle->end arrays
+            bottom_up_merge(source, dest, source_idx, dest_idx, start[0], 
middle[0], end[0], even)
+
     lim = tvm.tir.generic.cast(
-        tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(shape[axis], 
"float64"))), "int64"
+        tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(size, "float64"))), 
"int64"
     )
     with ib.for_range(0, lim, dtype="int64") as l2_width:
         width = 2 << l2_width
@@ -174,7 +224,7 @@ def sort_ir(
             ib.scope_attr(
                 bx,
                 "thread_extent",
-                tvm.tir.generic.cast(ceil_div(shape[axis], width * 
max_threads), "int32"),
+                tvm.tir.generic.cast(ceil_div(size, width * max_threads), 
"int32"),
             )
             tid = bx * nthread_tx + tx
 
@@ -183,85 +233,13 @@ def sort_ir(
             ib.scope_attr(by, "thread_extent", nthread_by)
             ib.scope_attr(bz, "thread_extent", nthread_bz)
 
-            def compare(a, b):
-                """
-                Compare a and b in proper ascending or descending order
-                """
-                if is_ascend:
-                    out = a <= b
-                else:
-                    out = b <= a
-                return out
-
-            def bottom_up_merge(source, dest, source_idx, dest_idx, start, 
middle, end, even):
-                """
-                Merge the two sections of the array assigned to this thread
-                """
-                # pylint: disable=arguments-out-of-order
-                # initialize iterators
-                i[0] = start
-                j[0] = middle
-                # set up indexes
-                base_idx = by * shape[axis] * axis_mul_after + bz
-                # iterate over the output loop
-                with ib.for_range(0, end - start) as k:
-                    i_idx = base_idx + i[0] * axis_mul_after
-                    j_idx = base_idx + j[0] * axis_mul_after
-                    k_idx = base_idx + (k + start) * axis_mul_after
-
-                    def swap_values(source, dest, source_idx, dest_idx):
-                        def assign_i():
-                            """assign i value to current output"""
-                            dest[k_idx] = source[i_idx]
-                            if indices_out is not None:
-                                dest_idx[k_idx] = source_idx[i_idx]
-                            i[0] += 1
-
-                        def assign_j():
-                            """assign j value to current output"""
-                            dest[k_idx] = source[j_idx]
-                            if indices_out is not None:
-                                dest_idx[k_idx] = source_idx[j_idx]
-                            j[0] += 1
-
-                        ## if both of the iterators are in range
-                        with ib.if_scope(tvm.tir.all(i[0] < middle, j[0] < 
end)):
-                            # compare them and insert whichever is next into 
the output
-                            with ib.if_scope(compare(source[i_idx], 
source[j_idx])):
-                                assign_i()
-                            with ib.else_scope():
-                                assign_j()
-                        # otherwise, simply copy the remainder of the valid 
iterator to the output
-                        with ib.else_scope():
-                            with ib.if_scope(i[0] < middle):
-                                assign_i()
-                            with ib.else_scope():
-                                assign_j()
-
-                    # Switch which input is the source and which is the 
destination each iteration
-                    with ib.if_scope(even):
-                        swap_values(source, dest, source_idx, dest_idx)
-                    with ib.else_scope():
-                        swap_values(dest, source, dest_idx, source_idx)
-
-            def mergesort(source, dest, source_idx, dest_idx, size, width, 
even):
-                # calculate the start, mid, and end points of this section
-                start[0] = width * tid
-                with ib.if_scope(start[0] < size):
-                    middle[0] = tvm.te.min(start[0] + tvm.tir.indexdiv(width, 
2), size)
-                    end[0] = tvm.te.min(start[0] + width, size)
-                    ## merge the start->middle and middle->end arrays
-                    bottom_up_merge(
-                        source, dest, source_idx, dest_idx, start[0], 
middle[0], end[0], even
-                    )
-
             # Call the kernel
             mergesort(
-                values_out,
-                values_out_swap,
-                indices_out,
-                indices_out_swap,
-                shape[axis],
+                keys,
+                keys_swap,
+                values,
+                values_swap,
+                size,
                 width,
                 tvm.tir.indexmod(l2_width, 2) == 0,
             )
@@ -279,29 +257,31 @@ def sort_ir(
             bz = te.thread_axis("blockIdx.z")
             ib.scope_attr(by, "thread_extent", nthread_by)
             ib.scope_attr(bz, "thread_extent", nthread_bz)
-            idx = (by * shape[axis] + tid) * axis_mul_after + bz
-            with ib.if_scope(tid < shape[axis]):
-                idx = (by * shape[axis] + tid) * axis_mul_after + bz
-                values_out[idx] = values_out_swap[idx]
-                if indices_out is not None:
-                    indices_out[idx] = indices_out_swap[idx]
+            idx = (by * size + tid) * axis_mul_after + bz
+            with ib.if_scope(tid < size):
+                idx = (by * size + tid) * axis_mul_after + bz
+                keys[idx] = keys_swap[idx]
+                if values is not None:
+                    values[idx] = values_swap[idx]
 
     return ib.get()
 
 
-def sort_nms_ir(data, valid_count, output, axis, is_ascend):
-    """Low level IR to do nms sorting on the GPU, same usage as 
tvm.contrib.sort.argsort on the CPU.
+def sort_ir(
+    data, values_out, values_out_swap, axis, is_ascend, indices_out=None, 
indices_out_swap=None
+):
+    """Low level IR to do sorting on the GPU, same usage as 
tvm.contrib.sort.argsort on the CPU.
 
     Parameters
     ----------
     data: Buffer
-        Buffer of input data.
+        Buffer of input data. Data will be sorted in place.
 
-    valid_count : Buffer
-        1D Buffer of number of valid number of boxes.
+    values_out : Buffer
+        Output buffer of values of sorted tensor with same shape as data.
 
-    output : Buffer
-        Output buffer of indicies of sorted tensor with same shape as data.
+    values_out_swap : Buffer
+        Output buffer of values with same shape as data to use as swap.
 
     axis : Int
         Axis long which to sort the input tensor.
@@ -309,82 +289,124 @@ def sort_nms_ir(data, valid_count, output, axis, 
is_ascend):
     is_ascend : Boolean
         Whether to sort in ascending or descending order.
 
+    indices_out : Buffer
+        Output buffer of indices of sorted tensor with same shape as data.
+
+    indices_out_swap : Buffer
+        Output buffer of indices with same shape as data to use as swap.
+
     Returns
     -------
     stmt : Stmt
         The result IR statement.
     """
-
-    size = 1
-    axis_mul_before = 1
-    axis_mul_after = 1
-    shape = data.shape
-    if axis < 0:
-        axis = len(shape) + axis
-    for i, value in enumerate(shape, 0):
-        size *= value
-        if i < axis:
-            axis_mul_before *= value
-        elif i > axis:
-            axis_mul_after *= value
-    max_threads = 
int(tvm.target.Target.current(allow_none=False).max_num_threads)
     ib = tvm.tir.ir_builder.create()
+    shape = data.shape
+
     data = ib.buffer_ptr(data)
-    valid_count = ib.buffer_ptr(valid_count)
-    output = ib.buffer_ptr(output)
-    nthread_tx = max_threads
-    nthread_bx = size // max_threads + 1
-    tx = te.thread_axis("threadIdx.x")
-    bx = te.thread_axis("blockIdx.x")
-    ib.scope_attr(tx, "thread_extent", nthread_tx)
-    ib.scope_attr(bx, "thread_extent", nthread_bx)
-    tid = bx * nthread_tx + tx
-    temp_data = ib.allocate("float32", (1,), name="temp_data", scope="local")
-    temp_index = ib.allocate("int32", (1,), name="temp_index", scope="local")
-    is_ascend = tvm.tir.IntImm("int32", is_ascend)
-
-    idxd = tvm.tir.indexdiv
-    idxm = tvm.tir.indexmod
-
-    with ib.for_range(0, axis_mul_before) as i:
-        with ib.for_range(0, axis_mul_after) as j:
-            current_sort_num = valid_count[i * axis_mul_after + j]
-            base_idx = i * shape[axis] * axis_mul_after + j
-            with ib.if_scope(tid < shape[axis]):
-                output[base_idx + tid * axis_mul_after] = tid
-            # OddEvenTransposeSort
-            with ib.for_range(0, current_sort_num) as k:
-                with ib.if_scope(tid < idxd(current_sort_num + 1, 2)):
-                    offset = base_idx + (2 * tid + idxm(k, 2)) * axis_mul_after
-                    with ib.if_scope(
-                        tvm.tir.all(
-                            is_ascend == 1,
-                            2 * tid + idxm(k, 2) + 1 < current_sort_num,
-                            data[offset] > data[offset + axis_mul_after],
-                        )
-                    ):
-                        temp_data[0] = data[offset]
-                        data[offset] = data[offset + axis_mul_after]
-                        data[offset + axis_mul_after] = temp_data[0]
-                        temp_index[0] = output[offset]
-                        output[offset] = output[offset + axis_mul_after]
-                        output[offset + axis_mul_after] = temp_index[0]
-                    with ib.if_scope(
-                        tvm.tir.all(
-                            is_ascend == 0,
-                            2 * tid + idxm(k, 2) + 1 < current_sort_num,
-                            data[offset] < data[offset + axis_mul_after],
-                        )
-                    ):
-                        temp_data[0] = data[offset]
-                        data[offset] = data[offset + axis_mul_after]
-                        data[offset + axis_mul_after] = temp_data[0]
-                        temp_index[0] = output[offset]
-                        output[offset] = output[offset + axis_mul_after]
-                        output[offset + axis_mul_after] = temp_index[0]
-                ib.emit(tvm.tir.Call(None, "tir.tvm_storage_sync", 
tvm.runtime.convert(["shared"])))
+    values_out = ib.buffer_ptr(values_out)
+    values_out_swap = ib.buffer_ptr(values_out_swap)
+    if indices_out is not None:
+        indices_out = ib.buffer_ptr(indices_out)
+        assert indices_out_swap is not None
+        indices_out_swap = ib.buffer_ptr(indices_out_swap)
 
-    return ib.get()
+    axis_mul_before, axis_mul_after = _sort_init(
+        ib,
+        shape,
+        axis,
+        data,
+        values_out,
+        indices_out,
+        value_init_func=lambda _, tid: tvm.tir.generic.cast(tid, 
indices_out.dtype),
+    )
+
+    return _sort_common(
+        ib,
+        shape[axis],
+        axis_mul_before,
+        axis_mul_after,
+        is_ascend,
+        values_out,
+        values_out_swap,
+        values=indices_out,
+        values_swap=indices_out_swap,
+    )
+
+
+def sort_by_key_ir(
+    keys_in, values_in, keys_out, values_out, keys_out_swap, values_out_swap, 
axis, is_ascend
+):
+    """Low level IR to do sort by key on the GPU.
+
+    Parameters
+    ----------
+    keys_in: Buffer
+        Buffer of input keys.
+
+    values_in: Buffer
+        Buffer of input values.
+
+    keys_out : Buffer
+        Buffer of output sorted keys.
+
+    values_out : Buffer
+        Buffer of output sorted values.
+
+    keys_out_swap : Buffer
+        Output buffer of keys with same shape as keys_in to use as swap.
+
+    values_out_swap : Buffer
+        Output buffer of values with same shape as values_in to use as swap.
+
+    axis : Int
+        Axis along which to sort the input tensor.
+
+    is_ascend : Boolean
+        Whether to sort in ascending or descending order.
+
+    indicess_out : Buffer
+        Output buffer of indices of sorted tensor with same shape as keys_in.
+
+    values_out_swap : Buffer
+        Output buffer of indices with same shape as keys_in to use as swap.
+
+    Returns
+    -------
+    stmt : Stmt
+        The result IR statement.
+    """
+    ib = tvm.tir.ir_builder.create()
+    shape = keys_in.shape
+
+    keys_in = ib.buffer_ptr(keys_in)
+    values_in = ib.buffer_ptr(values_in)
+    keys_out = ib.buffer_ptr(keys_out)
+    keys_out_swap = ib.buffer_ptr(keys_out_swap)
+    values_out = ib.buffer_ptr(values_out)
+    values_out_swap = ib.buffer_ptr(values_out_swap)
+
+    axis_mul_before, axis_mul_after = _sort_init(
+        ib,
+        shape,
+        axis,
+        keys_in,
+        keys_out,
+        values_out,
+        value_init_func=lambda idx, _: values_in[idx],
+    )
+
+    return _sort_common(
+        ib,
+        shape[axis],
+        axis_mul_before,
+        axis_mul_after,
+        is_ascend,
+        keys_out,
+        keys_out_swap,
+        values=values_out,
+        values_swap=values_out_swap,
+    )
 
 
 def argsort_nms_thrust(data, valid_count, axis=-1, is_ascend=1, 
dtype="float32"):
@@ -534,7 +556,7 @@ def sort_thrust(data, axis=-1, is_ascend=1):
     return out
 
 
-def argsort(data, valid_count=None, axis=-1, is_ascend=1, dtype="float32"):
+def argsort(data, axis=-1, is_ascend=1, dtype="float32"):
     """Performs sorting along the given axis and returns an array of indicies
     having same shape as an input array that index data in sorted order.
 
@@ -543,9 +565,6 @@ def argsort(data, valid_count=None, axis=-1, is_ascend=1, 
dtype="float32"):
     data: tvm.te.Tensor
         The input array.
 
-    valid_count : tvm.te.Tensor, optional
-        The number of valid elements to be sorted.
-
     axis : int, optional
         Axis long which to sort the input tensor.
 
@@ -560,48 +579,26 @@ def argsort(data, valid_count=None, axis=-1, is_ascend=1, 
dtype="float32"):
     out : tvm.te.Tensor
         The output of this function.
     """
-    if valid_count is not None:
-        sorted_data = identity(data)
-        sorted_data_buf = tvm.tir.decl_buffer(
-            data.shape, data.dtype, "sorted_data_buf", data_alignment=8
-        )
-        valid_count_buf = tvm.tir.decl_buffer(
-            valid_count.shape, valid_count.dtype, "valid_count_buf", 
data_alignment=4
-        )
-        out_buf = tvm.tir.decl_buffer(data.shape, "int32", "out_buf", 
data_alignment=4)
-        out = te.extern(
-            [data.shape],
-            [sorted_data, valid_count],
-            lambda ins, outs: sort_nms_ir(ins[0], ins[1], outs[0], axis, 
is_ascend),
-            dtype="int32",
-            in_buffers=[sorted_data_buf, valid_count_buf],
-            out_buffers=[out_buf],
-            name="argsort_nms_gpu",
-            tag="argsort_nms_gpu",
-        )
-    else:
-        value_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "value_buf", 
data_alignment=8)
-        value_swap_buf = tvm.tir.decl_buffer(
-            data.shape, data.dtype, "value_swap_buf", data_alignment=8
-        )
-        indices_buf = tvm.tir.decl_buffer(data.shape, dtype, "out_buf", 
data_alignment=8)
-        indices_swap_buf = tvm.tir.decl_buffer(data.shape, dtype, 
"out_swap_buf", data_alignment=8)
-        out = te.extern(
-            [data.shape, data.shape, data.shape, data.shape],
-            [data],
-            lambda ins, outs: sort_ir(
-                ins[0],
-                outs[0],
-                outs[2],
-                axis,
-                is_ascend,
-                indices_out=outs[1],
-                indices_out_swap=outs[3],
-            ),
-            out_buffers=[value_buf, indices_buf, value_swap_buf, 
indices_swap_buf],
-            name="argsort_gpu",
-            tag="argsort_gpu",
-        )[1]
+    value_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "value_buf", 
data_alignment=8)
+    value_swap_buf = tvm.tir.decl_buffer(data.shape, data.dtype, 
"value_swap_buf", data_alignment=8)
+    indices_buf = tvm.tir.decl_buffer(data.shape, dtype, "out_buf", 
data_alignment=8)
+    indices_swap_buf = tvm.tir.decl_buffer(data.shape, dtype, "out_swap_buf", 
data_alignment=8)
+    out = te.extern(
+        [data.shape, data.shape, data.shape, data.shape],
+        [data],
+        lambda ins, outs: sort_ir(
+            ins[0],
+            outs[0],
+            outs[2],
+            axis,
+            is_ascend,
+            indices_out=outs[1],
+            indices_out_swap=outs[3],
+        ),
+        out_buffers=[value_buf, indices_buf, value_swap_buf, indices_swap_buf],
+        name="argsort_gpu",
+        tag="argsort_gpu",
+    )[1]
     return out
 
 
@@ -862,6 +859,56 @@ def schedule_topk(outs):
     return _schedule_sort(outs)
 
 
+def sort_by_key(keys, values, axis=-1, is_ascend=1):
+    """Sort values with respect to keys. Both keys and values will
+     be sorted and returned.
+
+    Parameters
+    ----------
+    keys: tvm.te.Tensor
+        The input keys.
+
+    values : tvm.te.Tensor,
+        The input values.
+
+    axis : int, optional
+        Axis along which to sort the input tensor.
+
+    is_ascend : boolean, optional
+        Whether to sort in ascending or descending order.
+
+    Returns
+    -------
+    keys_sorted : tvm.te.Tensor
+        The sorted keys
+
+    values_sorted : tvm.te.Tensor
+        The values sorted with respect to the keys
+    """
+    keys_buf = tvm.tir.decl_buffer(keys.shape, keys.dtype, "keys_buf", 
data_alignment=8)
+    values_buf = tvm.tir.decl_buffer(values.shape, values.dtype, "values_buf", 
data_alignment=8)
+
+    out_bufs = [
+        tvm.tir.decl_buffer(keys.shape, keys.dtype, "keys_buf", 
data_alignment=8),
+        tvm.tir.decl_buffer(values.shape, values.dtype, "values_buf", 
data_alignment=8),
+        tvm.tir.decl_buffer(keys.shape, keys.dtype, "keys_swap_buf", 
data_alignment=8),
+        tvm.tir.decl_buffer(values.shape, values.dtype, "values_swap_buf", 
data_alignment=8),
+    ]
+    out = te.extern(
+        [keys.shape, values.shape, keys.shape, values.shape],
+        [keys, values],
+        lambda ins, outs: sort_by_key_ir(
+            ins[0], ins[1], outs[0], outs[1], outs[2], outs[3], axis, is_ascend
+        ),
+        in_buffers=[keys_buf, values_buf],
+        out_buffers=out_bufs,
+        dtype=[keys.dtype, values.dtype],
+        name="sort_by_key",
+        tag="sort_by_key",
+    )
+    return out[0], out[1]
+
+
 def stable_sort_by_key_thrust(keys, values, for_scatter=False):
     """Sort values with respect to keys using thrust.
     Both keys and values will be sorted and returned.
diff --git a/tests/python/contrib/test_sort.py 
b/tests/python/contrib/test_sort.py
index 9d6eb7c..f338276 100644
--- a/tests/python/contrib/test_sort.py
+++ b/tests/python/contrib/test_sort.py
@@ -17,7 +17,7 @@
 import tvm
 import tvm.testing
 from tvm import te
-from tvm.topi.cuda import stable_sort_by_key_thrust, is_thrust_available
+from tvm.topi.cuda import stable_sort_by_key_thrust, is_thrust_available, 
sort_by_key
 import numpy as np
 
 
@@ -123,7 +123,40 @@ def test_thrust_stable_sort_by_key():
     tvm.testing.assert_allclose(values_out.asnumpy(), ref_values_out, 
rtol=1e-5)
 
 
+def test_sort_by_key_gpu():
+    size = 6
+    keys = te.placeholder((size,), name="keys", dtype="int32")
+    values = te.placeholder((size,), name="values", dtype="int32")
+
+    for target in ["cuda", "nvptx", "opencl", "rocm"]:
+        if not tvm.testing.device_enabled(target):
+            print("Skip because %s is not enabled" % target)
+            continue
+
+        with tvm.target.Target(target):
+            keys_out, values_out = sort_by_key(keys, values)
+            ctx = tvm.context(target)
+            s = te.create_schedule([keys_out.op, values_out.op])
+            f = tvm.build(s, [keys, values, keys_out, values_out], target)
+
+            keys_np = np.array([1, 4, 2, 8, 2, 7], np.int32)
+            values_np = np.random.randint(0, 10, size=(size,)).astype(np.int32)
+            keys_np_out = np.zeros(keys_np.shape, np.int32)
+            values_np_out = np.zeros(values_np.shape, np.int32)
+            keys_in = tvm.nd.array(keys_np, ctx)
+            values_in = tvm.nd.array(values_np, ctx)
+            keys_out = tvm.nd.array(keys_np_out, ctx)
+            values_out = tvm.nd.array(values_np_out, ctx)
+            f(keys_in, values_in, keys_out, values_out)
+
+            ref_keys_out = np.sort(keys_np)
+            ref_values_out = np.array([values_np[i] for i in 
np.argsort(keys_np)])
+            tvm.testing.assert_allclose(keys_out.asnumpy(), ref_keys_out, 
rtol=1e-5)
+            tvm.testing.assert_allclose(values_out.asnumpy(), ref_values_out, 
rtol=1e-5)
+
+
 if __name__ == "__main__":
     test_sort()
     test_sort_np()
     test_thrust_stable_sort_by_key()
+    test_sort_by_key_gpu()

Reply via email to