This is an automated email from the ASF dual-hosted git repository.
zhic pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 9956b5b [TOPI] GPU sort IR refactor to enable sort by keys (#7157)
9956b5b is described below
commit 9956b5b859a24c8f4dec1abf308723b1257ffc66
Author: masahi <[email protected]>
AuthorDate: Thu Dec 24 07:33:47 2020 +0900
[TOPI] GPU sort IR refactor to enable sort by keys (#7157)
* sort refactor initial import
* sort test working
* scatter 1d with positive indices working
* remove negative indices, using extern for now
* minor fix
* minor fix
* add sort by key test
* revert scatter change
* add document
* fix py format
Co-authored-by: masa <[email protected]>
---
python/tvm/topi/cuda/nms.py | 4 +-
python/tvm/topi/cuda/scatter.py | 12 +-
python/tvm/topi/cuda/sort.py | 561 +++++++++++++++++++++-----------------
tests/python/contrib/test_sort.py | 35 ++-
4 files changed, 343 insertions(+), 269 deletions(-)
diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py
index cea287e..020cf9b 100644
--- a/python/tvm/topi/cuda/nms.py
+++ b/python/tvm/topi/cuda/nms.py
@@ -737,9 +737,7 @@ def non_max_suppression(
score_tensor, valid_count=None, axis=1, is_ascend=False,
dtype=valid_count_dtype
)
else:
- sort_tensor = argsort(
- score_tensor, valid_count=None, axis=1, is_ascend=False,
dtype=valid_count_dtype
- )
+ sort_tensor = argsort(score_tensor, axis=1, is_ascend=False,
dtype=valid_count_dtype)
sort_tensor_buf = tvm.tir.decl_buffer(
sort_tensor.shape, sort_tensor.dtype, "sort_tensor_buf",
data_alignment=8
diff --git a/python/tvm/topi/cuda/scatter.py b/python/tvm/topi/cuda/scatter.py
index 9916e2a..be602c8 100644
--- a/python/tvm/topi/cuda/scatter.py
+++ b/python/tvm/topi/cuda/scatter.py
@@ -424,6 +424,8 @@ def gen_scatter_1d_thrust(data, indices_sorted,
updates_sorted, axis, out, _):
Sorting of indices, and sorting of updates with respect to indices, can be
done
at the same time by thrust's sort_by_key function. It is important that
sorting
be done in a "stable" way via stable_sort, to guarantee deterministic
output.
+ Negative indices are assumed to have been converted to corresponding
positive
+ indices.
Parameters
----------
@@ -473,12 +475,6 @@ def gen_scatter_1d_thrust(data, indices_sorted,
updates_sorted, axis, out, _):
ni = indices_sorted.shape[0]
- def do_update(ib, index, update):
- with ib.if_scope(index < 0):
- out_ptr[index + n] = update
- with ib.else_scope():
- out_ptr[index] = update
-
with ib.new_scope():
nthread_bx = ceil_div(ni, nthread_tx)
tx = te.thread_axis("threadIdx.x")
@@ -491,7 +487,7 @@ def gen_scatter_1d_thrust(data, indices_sorted,
updates_sorted, axis, out, _):
# The last element can always update.
index = indices_ptr[tid]
update = updates_ptr[tid]
- do_update(ib, index, update)
+ out_ptr[index] = update
with ib.else_scope():
with ib.if_scope(tid < ni - 1):
@@ -503,7 +499,7 @@ def gen_scatter_1d_thrust(data, indices_sorted,
updates_sorted, axis, out, _):
# This thread can update the output.
with ib.if_scope(index != index_next):
update = updates_ptr[tid]
- do_update(ib, index, update)
+ out_ptr[index] = update
return ib.get()
diff --git a/python/tvm/topi/cuda/sort.py b/python/tvm/topi/cuda/sort.py
index 039ebe3..18872a2 100644
--- a/python/tvm/topi/cuda/sort.py
+++ b/python/tvm/topi/cuda/sort.py
@@ -21,7 +21,6 @@ from tvm import te
from tvm._ffi import get_global_func
from .injective import schedule_injective_from_existing
-from ..math import identity
from ..transform import strided_slice, transpose
from .. import tag
@@ -62,46 +61,14 @@ def _schedule_sort(outs):
return s
-def sort_ir(
- data, values_out, values_out_swap, axis, is_ascend, indices_out=None,
indices_out_swap=None
-):
- """Low level IR to do nms sorting on the GPU, same usage as
tvm.contrib.sort.argsort on the CPU.
-
- Parameters
- ----------
- data: Buffer
- Buffer of input data. Data will be sorted in place.
-
- values_out : Buffer
- Output buffer of values of sorted tensor with same shape as data.
-
- values_out_swap : Buffer
- Output buffer of values with same shape as data to use as swap.
-
- axis : Int
- Axis long which to sort the input tensor.
-
- is_ascend : Boolean
- Whether to sort in ascending or descending order.
-
- indicess_out : Buffer
- Output buffer of indices of sorted tensor with same shape as data.
-
- indices_out_swap : Buffer
- Output buffer of indices with same shape as data to use as swap.
-
- Returns
- -------
- stmt : Stmt
- The result IR statement.
- """
+def ceil_div(a, b):
+ return tvm.tir.indexdiv(a + b - 1, b)
- def ceil_div(a, b):
- return tvm.tir.indexdiv(a + b - 1, b)
+def _sort_init(ib, shape, axis, keys_in, keys_out, values_out=None,
value_init_func=None):
+ """Initialize the output buffers by copying from inputs"""
axis_mul_before = 1
axis_mul_after = 1
- shape = data.shape
if axis < 0:
axis = len(shape) + axis
for i, value in enumerate(shape, 0):
@@ -110,16 +77,6 @@ def sort_ir(
elif i > axis:
axis_mul_after *= value
- ib = tvm.tir.ir_builder.create()
-
- data = ib.buffer_ptr(data)
- values_out = ib.buffer_ptr(values_out)
- values_out_swap = ib.buffer_ptr(values_out_swap)
- if indices_out is not None:
- indices_out = ib.buffer_ptr(indices_out)
- assert indices_out_swap is not None
- indices_out_swap = ib.buffer_ptr(indices_out_swap)
-
# Set up threading
max_threads =
int(tvm.target.Target.current(allow_none=False).max_num_threads)
nthread_tx = max_threads
@@ -127,7 +84,7 @@ def sort_ir(
nthread_by = axis_mul_before
nthread_bz = axis_mul_after
- # Copy the data to initial output
+ # Copy the keys_in to initial output
with ib.new_scope():
tx = te.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")
@@ -141,9 +98,25 @@ def sort_ir(
ib.scope_attr(bz, "thread_extent", nthread_bz)
idx = (by * shape[axis] + tid) * axis_mul_after + bz
with ib.if_scope(tid < shape[axis]):
- values_out[idx] = data[idx]
- if indices_out is not None:
- indices_out[idx] = tvm.tir.generic.cast(tid, indices_out.dtype)
+ keys_out[idx] = keys_in[idx]
+ if values_out is not None:
+ values_out[idx] = value_init_func(idx, tid)
+
+ return axis_mul_before, axis_mul_after
+
+
+def _sort_common(
+ ib,
+ size,
+ axis_mul_before,
+ axis_mul_after,
+ is_ascend,
+ keys,
+ keys_swap,
+ values=None,
+ values_swap=None,
+):
+ """Either sort only values or sort values by keys."""
## we are looping over the array doing mergesort from the bottom up.
## The outer loop runs on the host and launches a cuda kernel for each
iteration
@@ -155,8 +128,85 @@ def sort_ir(
## to deal with 8 total elements. On iteration 3, each thread deals with
16 elements, etc
## On the final iteration of the algorithm, one thread will merge two
sorted lists
## to sort the entire array
+
+ max_threads =
int(tvm.target.Target.current(allow_none=False).max_num_threads)
+ nthread_tx = max_threads
+ nthread_bx = ceil_div(size, max_threads)
+ nthread_by = axis_mul_before
+ nthread_bz = axis_mul_after
+
+ def compare(a, b):
+ """
+ Compare a and b in proper ascending or descending order
+ """
+ if is_ascend:
+ out = a <= b
+ else:
+ out = b <= a
+ return out
+
+ def bottom_up_merge(source, dest, source_idx, dest_idx, start, middle,
end, even):
+ """
+ Merge the two sections of the array assigned to this thread
+ """
+ # pylint: disable=arguments-out-of-order
+ # initialize iterators
+ i[0] = start
+ j[0] = middle
+ # set up indexes
+ base_idx = by * size * axis_mul_after + bz
+ # iterate over the output loop
+ with ib.for_range(0, end - start) as k:
+ i_idx = base_idx + i[0] * axis_mul_after
+ j_idx = base_idx + j[0] * axis_mul_after
+ k_idx = base_idx + (k + start) * axis_mul_after
+
+ def swap_values(source, dest, source_idx, dest_idx):
+ def assign_i():
+ """assign i value to current output"""
+ dest[k_idx] = source[i_idx]
+ if values is not None:
+ dest_idx[k_idx] = source_idx[i_idx]
+ i[0] += 1
+
+ def assign_j():
+ """assign j value to current output"""
+ dest[k_idx] = source[j_idx]
+ if values is not None:
+ dest_idx[k_idx] = source_idx[j_idx]
+ j[0] += 1
+
+ ## if both of the iterators are in range
+ with ib.if_scope(tvm.tir.all(i[0] < middle, j[0] < end)):
+ # compare them and insert whichever is next into the output
+ with ib.if_scope(compare(source[i_idx], source[j_idx])):
+ assign_i()
+ with ib.else_scope():
+ assign_j()
+ # otherwise, simply copy the remainder of the valid iterator
to the output
+ with ib.else_scope():
+ with ib.if_scope(i[0] < middle):
+ assign_i()
+ with ib.else_scope():
+ assign_j()
+
+ # Switch which input is the source and which is the destination
each iteration
+ with ib.if_scope(even):
+ swap_values(source, dest, source_idx, dest_idx)
+ with ib.else_scope():
+ swap_values(dest, source, dest_idx, source_idx)
+
+ def mergesort(source, dest, source_idx, dest_idx, size, width, even):
+ # calculate the start, mid, and end points of this section
+ start[0] = width * tid
+ with ib.if_scope(start[0] < size):
+ middle[0] = tvm.te.min(start[0] + tvm.tir.indexdiv(width, 2), size)
+ end[0] = tvm.te.min(start[0] + width, size)
+ ## merge the start->middle and middle->end arrays
+ bottom_up_merge(source, dest, source_idx, dest_idx, start[0],
middle[0], end[0], even)
+
lim = tvm.tir.generic.cast(
- tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(shape[axis],
"float64"))), "int64"
+ tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(size, "float64"))),
"int64"
)
with ib.for_range(0, lim, dtype="int64") as l2_width:
width = 2 << l2_width
@@ -174,7 +224,7 @@ def sort_ir(
ib.scope_attr(
bx,
"thread_extent",
- tvm.tir.generic.cast(ceil_div(shape[axis], width *
max_threads), "int32"),
+ tvm.tir.generic.cast(ceil_div(size, width * max_threads),
"int32"),
)
tid = bx * nthread_tx + tx
@@ -183,85 +233,13 @@ def sort_ir(
ib.scope_attr(by, "thread_extent", nthread_by)
ib.scope_attr(bz, "thread_extent", nthread_bz)
- def compare(a, b):
- """
- Compare a and b in proper ascending or descending order
- """
- if is_ascend:
- out = a <= b
- else:
- out = b <= a
- return out
-
- def bottom_up_merge(source, dest, source_idx, dest_idx, start,
middle, end, even):
- """
- Merge the two sections of the array assigned to this thread
- """
- # pylint: disable=arguments-out-of-order
- # initialize iterators
- i[0] = start
- j[0] = middle
- # set up indexes
- base_idx = by * shape[axis] * axis_mul_after + bz
- # iterate over the output loop
- with ib.for_range(0, end - start) as k:
- i_idx = base_idx + i[0] * axis_mul_after
- j_idx = base_idx + j[0] * axis_mul_after
- k_idx = base_idx + (k + start) * axis_mul_after
-
- def swap_values(source, dest, source_idx, dest_idx):
- def assign_i():
- """assign i value to current output"""
- dest[k_idx] = source[i_idx]
- if indices_out is not None:
- dest_idx[k_idx] = source_idx[i_idx]
- i[0] += 1
-
- def assign_j():
- """assign j value to current output"""
- dest[k_idx] = source[j_idx]
- if indices_out is not None:
- dest_idx[k_idx] = source_idx[j_idx]
- j[0] += 1
-
- ## if both of the iterators are in range
- with ib.if_scope(tvm.tir.all(i[0] < middle, j[0] <
end)):
- # compare them and insert whichever is next into
the output
- with ib.if_scope(compare(source[i_idx],
source[j_idx])):
- assign_i()
- with ib.else_scope():
- assign_j()
- # otherwise, simply copy the remainder of the valid
iterator to the output
- with ib.else_scope():
- with ib.if_scope(i[0] < middle):
- assign_i()
- with ib.else_scope():
- assign_j()
-
- # Switch which input is the source and which is the
destination each iteration
- with ib.if_scope(even):
- swap_values(source, dest, source_idx, dest_idx)
- with ib.else_scope():
- swap_values(dest, source, dest_idx, source_idx)
-
- def mergesort(source, dest, source_idx, dest_idx, size, width,
even):
- # calculate the start, mid, and end points of this section
- start[0] = width * tid
- with ib.if_scope(start[0] < size):
- middle[0] = tvm.te.min(start[0] + tvm.tir.indexdiv(width,
2), size)
- end[0] = tvm.te.min(start[0] + width, size)
- ## merge the start->middle and middle->end arrays
- bottom_up_merge(
- source, dest, source_idx, dest_idx, start[0],
middle[0], end[0], even
- )
-
# Call the kernel
mergesort(
- values_out,
- values_out_swap,
- indices_out,
- indices_out_swap,
- shape[axis],
+ keys,
+ keys_swap,
+ values,
+ values_swap,
+ size,
width,
tvm.tir.indexmod(l2_width, 2) == 0,
)
@@ -279,29 +257,31 @@ def sort_ir(
bz = te.thread_axis("blockIdx.z")
ib.scope_attr(by, "thread_extent", nthread_by)
ib.scope_attr(bz, "thread_extent", nthread_bz)
- idx = (by * shape[axis] + tid) * axis_mul_after + bz
- with ib.if_scope(tid < shape[axis]):
- idx = (by * shape[axis] + tid) * axis_mul_after + bz
- values_out[idx] = values_out_swap[idx]
- if indices_out is not None:
- indices_out[idx] = indices_out_swap[idx]
+ idx = (by * size + tid) * axis_mul_after + bz
+ with ib.if_scope(tid < size):
+ idx = (by * size + tid) * axis_mul_after + bz
+ keys[idx] = keys_swap[idx]
+ if values is not None:
+ values[idx] = values_swap[idx]
return ib.get()
-def sort_nms_ir(data, valid_count, output, axis, is_ascend):
- """Low level IR to do nms sorting on the GPU, same usage as
tvm.contrib.sort.argsort on the CPU.
+def sort_ir(
+ data, values_out, values_out_swap, axis, is_ascend, indices_out=None,
indices_out_swap=None
+):
+ """Low level IR to do sorting on the GPU, same usage as
tvm.contrib.sort.argsort on the CPU.
Parameters
----------
data: Buffer
- Buffer of input data.
+ Buffer of input data. Data will be sorted in place.
- valid_count : Buffer
- 1D Buffer of number of valid number of boxes.
+ values_out : Buffer
+ Output buffer of values of sorted tensor with same shape as data.
- output : Buffer
- Output buffer of indicies of sorted tensor with same shape as data.
+ values_out_swap : Buffer
+ Output buffer of values with same shape as data to use as swap.
axis : Int
Axis long which to sort the input tensor.
@@ -309,82 +289,124 @@ def sort_nms_ir(data, valid_count, output, axis,
is_ascend):
is_ascend : Boolean
Whether to sort in ascending or descending order.
+ indicess_out : Buffer
+ Output buffer of indices of sorted tensor with same shape as data.
+
+ indices_out_swap : Buffer
+ Output buffer of indices with same shape as data to use as swap.
+
Returns
-------
stmt : Stmt
The result IR statement.
"""
-
- size = 1
- axis_mul_before = 1
- axis_mul_after = 1
- shape = data.shape
- if axis < 0:
- axis = len(shape) + axis
- for i, value in enumerate(shape, 0):
- size *= value
- if i < axis:
- axis_mul_before *= value
- elif i > axis:
- axis_mul_after *= value
- max_threads =
int(tvm.target.Target.current(allow_none=False).max_num_threads)
ib = tvm.tir.ir_builder.create()
+ shape = data.shape
+
data = ib.buffer_ptr(data)
- valid_count = ib.buffer_ptr(valid_count)
- output = ib.buffer_ptr(output)
- nthread_tx = max_threads
- nthread_bx = size // max_threads + 1
- tx = te.thread_axis("threadIdx.x")
- bx = te.thread_axis("blockIdx.x")
- ib.scope_attr(tx, "thread_extent", nthread_tx)
- ib.scope_attr(bx, "thread_extent", nthread_bx)
- tid = bx * nthread_tx + tx
- temp_data = ib.allocate("float32", (1,), name="temp_data", scope="local")
- temp_index = ib.allocate("int32", (1,), name="temp_index", scope="local")
- is_ascend = tvm.tir.IntImm("int32", is_ascend)
-
- idxd = tvm.tir.indexdiv
- idxm = tvm.tir.indexmod
-
- with ib.for_range(0, axis_mul_before) as i:
- with ib.for_range(0, axis_mul_after) as j:
- current_sort_num = valid_count[i * axis_mul_after + j]
- base_idx = i * shape[axis] * axis_mul_after + j
- with ib.if_scope(tid < shape[axis]):
- output[base_idx + tid * axis_mul_after] = tid
- # OddEvenTransposeSort
- with ib.for_range(0, current_sort_num) as k:
- with ib.if_scope(tid < idxd(current_sort_num + 1, 2)):
- offset = base_idx + (2 * tid + idxm(k, 2)) * axis_mul_after
- with ib.if_scope(
- tvm.tir.all(
- is_ascend == 1,
- 2 * tid + idxm(k, 2) + 1 < current_sort_num,
- data[offset] > data[offset + axis_mul_after],
- )
- ):
- temp_data[0] = data[offset]
- data[offset] = data[offset + axis_mul_after]
- data[offset + axis_mul_after] = temp_data[0]
- temp_index[0] = output[offset]
- output[offset] = output[offset + axis_mul_after]
- output[offset + axis_mul_after] = temp_index[0]
- with ib.if_scope(
- tvm.tir.all(
- is_ascend == 0,
- 2 * tid + idxm(k, 2) + 1 < current_sort_num,
- data[offset] < data[offset + axis_mul_after],
- )
- ):
- temp_data[0] = data[offset]
- data[offset] = data[offset + axis_mul_after]
- data[offset + axis_mul_after] = temp_data[0]
- temp_index[0] = output[offset]
- output[offset] = output[offset + axis_mul_after]
- output[offset + axis_mul_after] = temp_index[0]
- ib.emit(tvm.tir.Call(None, "tir.tvm_storage_sync",
tvm.runtime.convert(["shared"])))
+ values_out = ib.buffer_ptr(values_out)
+ values_out_swap = ib.buffer_ptr(values_out_swap)
+ if indices_out is not None:
+ indices_out = ib.buffer_ptr(indices_out)
+ assert indices_out_swap is not None
+ indices_out_swap = ib.buffer_ptr(indices_out_swap)
- return ib.get()
+ axis_mul_before, axis_mul_after = _sort_init(
+ ib,
+ shape,
+ axis,
+ data,
+ values_out,
+ indices_out,
+ value_init_func=lambda _, tid: tvm.tir.generic.cast(tid,
indices_out.dtype),
+ )
+
+ return _sort_common(
+ ib,
+ shape[axis],
+ axis_mul_before,
+ axis_mul_after,
+ is_ascend,
+ values_out,
+ values_out_swap,
+ values=indices_out,
+ values_swap=indices_out_swap,
+ )
+
+
+def sort_by_key_ir(
+ keys_in, values_in, keys_out, values_out, keys_out_swap, values_out_swap,
axis, is_ascend
+):
+ """Low level IR to do sort by key on the GPU.
+
+ Parameters
+ ----------
+ keys_in: Buffer
+ Buffer of input keys.
+
+ values_in: Buffer
+ Buffer of input keys.
+
+ keys_out : Buffer
+ Buffer of output sorted keys.
+
+ values_out : Buffer
+ Buffer of output sorted values.
+
+ keys_out_swap : Buffer
+ Output buffer of values with same shape as keys_in to use as swap.
+
+ values_out_swap : Buffer
+ Output buffer of values with same shape as values_in to use as swap.
+
+ axis : Int
+ Axis long which to sort the input tensor.
+
+ is_ascend : Boolean
+ Whether to sort in ascending or descending order.
+
+ indicess_out : Buffer
+ Output buffer of indices of sorted tensor with same shape as keys_in.
+
+ values_out_swap : Buffer
+ Output buffer of indices with same shape as keys_in to use as swap.
+
+ Returns
+ -------
+ stmt : Stmt
+ The result IR statement.
+ """
+ ib = tvm.tir.ir_builder.create()
+ shape = keys_in.shape
+
+ keys_in = ib.buffer_ptr(keys_in)
+ values_in = ib.buffer_ptr(values_in)
+ keys_out = ib.buffer_ptr(keys_out)
+ keys_out_swap = ib.buffer_ptr(keys_out_swap)
+ values_out = ib.buffer_ptr(values_out)
+ values_out_swap = ib.buffer_ptr(values_out_swap)
+
+ axis_mul_before, axis_mul_after = _sort_init(
+ ib,
+ shape,
+ axis,
+ keys_in,
+ keys_out,
+ values_out,
+ value_init_func=lambda idx, _: values_in[idx],
+ )
+
+ return _sort_common(
+ ib,
+ shape[axis],
+ axis_mul_before,
+ axis_mul_after,
+ is_ascend,
+ keys_out,
+ keys_out_swap,
+ values=values_out,
+ values_swap=values_out_swap,
+ )
def argsort_nms_thrust(data, valid_count, axis=-1, is_ascend=1,
dtype="float32"):
@@ -534,7 +556,7 @@ def sort_thrust(data, axis=-1, is_ascend=1):
return out
-def argsort(data, valid_count=None, axis=-1, is_ascend=1, dtype="float32"):
+def argsort(data, axis=-1, is_ascend=1, dtype="float32"):
"""Performs sorting along the given axis and returns an array of indicies
having same shape as an input array that index data in sorted order.
@@ -543,9 +565,6 @@ def argsort(data, valid_count=None, axis=-1, is_ascend=1,
dtype="float32"):
data: tvm.te.Tensor
The input array.
- valid_count : tvm.te.Tensor, optional
- The number of valid elements to be sorted.
-
axis : int, optional
Axis long which to sort the input tensor.
@@ -560,48 +579,26 @@ def argsort(data, valid_count=None, axis=-1, is_ascend=1,
dtype="float32"):
out : tvm.te.Tensor
The output of this function.
"""
- if valid_count is not None:
- sorted_data = identity(data)
- sorted_data_buf = tvm.tir.decl_buffer(
- data.shape, data.dtype, "sorted_data_buf", data_alignment=8
- )
- valid_count_buf = tvm.tir.decl_buffer(
- valid_count.shape, valid_count.dtype, "valid_count_buf",
data_alignment=4
- )
- out_buf = tvm.tir.decl_buffer(data.shape, "int32", "out_buf",
data_alignment=4)
- out = te.extern(
- [data.shape],
- [sorted_data, valid_count],
- lambda ins, outs: sort_nms_ir(ins[0], ins[1], outs[0], axis,
is_ascend),
- dtype="int32",
- in_buffers=[sorted_data_buf, valid_count_buf],
- out_buffers=[out_buf],
- name="argsort_nms_gpu",
- tag="argsort_nms_gpu",
- )
- else:
- value_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "value_buf",
data_alignment=8)
- value_swap_buf = tvm.tir.decl_buffer(
- data.shape, data.dtype, "value_swap_buf", data_alignment=8
- )
- indices_buf = tvm.tir.decl_buffer(data.shape, dtype, "out_buf",
data_alignment=8)
- indices_swap_buf = tvm.tir.decl_buffer(data.shape, dtype,
"out_swap_buf", data_alignment=8)
- out = te.extern(
- [data.shape, data.shape, data.shape, data.shape],
- [data],
- lambda ins, outs: sort_ir(
- ins[0],
- outs[0],
- outs[2],
- axis,
- is_ascend,
- indices_out=outs[1],
- indices_out_swap=outs[3],
- ),
- out_buffers=[value_buf, indices_buf, value_swap_buf,
indices_swap_buf],
- name="argsort_gpu",
- tag="argsort_gpu",
- )[1]
+ value_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "value_buf",
data_alignment=8)
+ value_swap_buf = tvm.tir.decl_buffer(data.shape, data.dtype,
"value_swap_buf", data_alignment=8)
+ indices_buf = tvm.tir.decl_buffer(data.shape, dtype, "out_buf",
data_alignment=8)
+ indices_swap_buf = tvm.tir.decl_buffer(data.shape, dtype, "out_swap_buf",
data_alignment=8)
+ out = te.extern(
+ [data.shape, data.shape, data.shape, data.shape],
+ [data],
+ lambda ins, outs: sort_ir(
+ ins[0],
+ outs[0],
+ outs[2],
+ axis,
+ is_ascend,
+ indices_out=outs[1],
+ indices_out_swap=outs[3],
+ ),
+ out_buffers=[value_buf, indices_buf, value_swap_buf, indices_swap_buf],
+ name="argsort_gpu",
+ tag="argsort_gpu",
+ )[1]
return out
@@ -862,6 +859,56 @@ def schedule_topk(outs):
return _schedule_sort(outs)
+def sort_by_key(keys, values, axis=-1, is_ascend=1):
+ """Sort values with respect to keys. Both keys and values will
+ be sorted and returned.
+
+ Parameters
+ ----------
+ keys: tvm.te.Tensor
+ The input keys.
+
+ values : tvm.te.Tensor,
+ The input values.
+
+ axis : int, optional
+ Axis long which to sort the input tensor.
+
+ is_ascend : boolean, optional
+ Whether to sort in ascending or descending order.
+
+ Returns
+ -------
+ keys_sorted : tvm.te.Tensor
+ The sorted keys
+
+ values_sorted : tvm.te.Tensor
+ The values sorted with respect to the keys
+ """
+ keys_buf = tvm.tir.decl_buffer(keys.shape, keys.dtype, "keys_buf",
data_alignment=8)
+ values_buf = tvm.tir.decl_buffer(values.shape, values.dtype, "values_buf",
data_alignment=8)
+
+ out_bufs = [
+ tvm.tir.decl_buffer(keys.shape, keys.dtype, "keys_buf",
data_alignment=8),
+ tvm.tir.decl_buffer(values.shape, values.dtype, "values_buf",
data_alignment=8),
+ tvm.tir.decl_buffer(keys.shape, keys.dtype, "keys_swap_buf",
data_alignment=8),
+ tvm.tir.decl_buffer(values.shape, values.dtype, "values_swap_buf",
data_alignment=8),
+ ]
+ out = te.extern(
+ [keys.shape, values.shape, keys.shape, values.shape],
+ [keys, values],
+ lambda ins, outs: sort_by_key_ir(
+ ins[0], ins[1], outs[0], outs[1], outs[2], outs[3], axis, is_ascend
+ ),
+ in_buffers=[keys_buf, values_buf],
+ out_buffers=out_bufs,
+ dtype=[keys.dtype, values.dtype],
+ name="sort_by_key",
+ tag="sort_by_key",
+ )
+ return out[0], out[1]
+
+
def stable_sort_by_key_thrust(keys, values, for_scatter=False):
"""Sort values with respect to keys using thrust.
Both keys and values will be sorted and returned.
diff --git a/tests/python/contrib/test_sort.py
b/tests/python/contrib/test_sort.py
index 9d6eb7c..f338276 100644
--- a/tests/python/contrib/test_sort.py
+++ b/tests/python/contrib/test_sort.py
@@ -17,7 +17,7 @@
import tvm
import tvm.testing
from tvm import te
-from tvm.topi.cuda import stable_sort_by_key_thrust, is_thrust_available
+from tvm.topi.cuda import stable_sort_by_key_thrust, is_thrust_available,
sort_by_key
import numpy as np
@@ -123,7 +123,40 @@ def test_thrust_stable_sort_by_key():
tvm.testing.assert_allclose(values_out.asnumpy(), ref_values_out,
rtol=1e-5)
+def test_sort_by_key_gpu():
+ size = 6
+ keys = te.placeholder((size,), name="keys", dtype="int32")
+ values = te.placeholder((size,), name="values", dtype="int32")
+
+ for target in ["cuda", "nvptx", "opencl", "rocm"]:
+ if not tvm.testing.device_enabled(target):
+ print("Skip because %s is not enabled" % target)
+ continue
+
+ with tvm.target.Target(target):
+ keys_out, values_out = sort_by_key(keys, values)
+ ctx = tvm.context(target)
+ s = te.create_schedule([keys_out.op, values_out.op])
+ f = tvm.build(s, [keys, values, keys_out, values_out], target)
+
+ keys_np = np.array([1, 4, 2, 8, 2, 7], np.int32)
+ values_np = np.random.randint(0, 10, size=(size,)).astype(np.int32)
+ keys_np_out = np.zeros(keys_np.shape, np.int32)
+ values_np_out = np.zeros(values_np.shape, np.int32)
+ keys_in = tvm.nd.array(keys_np, ctx)
+ values_in = tvm.nd.array(values_np, ctx)
+ keys_out = tvm.nd.array(keys_np_out, ctx)
+ values_out = tvm.nd.array(values_np_out, ctx)
+ f(keys_in, values_in, keys_out, values_out)
+
+ ref_keys_out = np.sort(keys_np)
+ ref_values_out = np.array([values_np[i] for i in
np.argsort(keys_np)])
+ tvm.testing.assert_allclose(keys_out.asnumpy(), ref_keys_out,
rtol=1e-5)
+ tvm.testing.assert_allclose(values_out.asnumpy(), ref_values_out,
rtol=1e-5)
+
+
if __name__ == "__main__":
test_sort()
test_sort_np()
test_thrust_stable_sort_by_key()
+ test_sort_by_key_gpu()