masahi commented on a change in pull request #6978:
URL: https://github.com/apache/tvm/pull/6978#discussion_r544676738
##########
File path: python/tvm/topi/cuda/sort.py
##########
@@ -316,6 +316,89 @@ def argsort_nms_thrust(data, valid_count, axis=-1,
is_ascend=1, dtype="float32")
return out[1]
+def sort(data, axis=-1, is_ascend=1):
+ """Performs sorting along the given axis and returns an array of
+ sorted values with the same shape as the input data.
+
+ Parameters
+ ----------
+ data: tvm.te.Tensor
+ The input array.
+
+ axis : int, optional
+ Axis along which to sort the input tensor.
+
+ is_ascend : boolean, optional
+ Whether to sort in ascending or descending order.
+
+ Returns
+ -------
+ out : tvm.te.Tensor
+ The output of this function.
+ """
+ dtype = "float32"
+ value_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "value_buf",
data_alignment=8)
+ indices_buf = tvm.tir.decl_buffer(data.shape, dtype, "out_buf",
data_alignment=8)
+ out = te.extern(
+ [data.shape, data.shape],
+ [data],
+ lambda ins, outs: sort_ir(ins[0], outs[0], axis, is_ascend,
indices_out=outs[1]),
+ out_buffers=[value_buf, indices_buf],
+ name="sort_gpu",
+ tag="sort_gpu",
+ )[0]
+ return out
+
+
+def sort_thrust(data, axis=-1, is_ascend=1):
+ """Performs sorting along the given axis and returns an array of indices
+ having same shape as an input array that index data in sorted order.
+
+ Parameters
+ ----------
+ data: tvm.te.Tensor
+ The input array.
+
+ axis : int, optional
+ Axis along which to sort the input tensor.
+
+ is_ascend : boolean, optional
+ Whether to sort in ascending or descending order.
+
+ Returns
+ -------
+ out : tvm.te.Tensor
+ The output of this function.
+ """
+ dtype = "float32"
+
+ ndim = len(data.shape)
+ axis = ndim + axis if axis < 0 else axis
+
+ if axis != ndim - 1:
+ # Prepare for sorting along axis -1.
+ axes = swap(list(range(ndim)), axis)
+ data = transpose(data, axes)
+
+ value_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "value_buf",
data_alignment=8)
+ indices_buf = tvm.tir.decl_buffer(data.shape, dtype, "out_buf",
data_alignment=8)
+ out = te.extern(
+ [data.shape, data.shape],
+ [data],
+ lambda ins, outs: tvm.tir.call_packed(
+ "tvm.contrib.thrust.sort", ins[0], outs[0], outs[1], is_ascend
+ ),
+ out_buffers=[value_buf, indices_buf],
Review comment:
Yeah that would be nice
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]