masahi commented on a change in pull request #7099:
URL: https://github.com/apache/tvm/pull/7099#discussion_r543301254



##########
File path: python/tvm/topi/cuda/sort.py
##########
@@ -94,64 +109,182 @@ def sort_ir(data, values_out, axis, is_ascend, 
indices_out=None):
             axis_mul_before *= value
         elif i > axis:
             axis_mul_after *= value
-    max_threads = 
int(tvm.target.Target.current(allow_none=False).max_num_threads)
+
     ib = tvm.tir.ir_builder.create()
+
     data = ib.buffer_ptr(data)
     values_out = ib.buffer_ptr(values_out)
+    values_out_swap = ib.buffer_ptr(values_out_swap)
     if indices_out is not None:
         indices_out = ib.buffer_ptr(indices_out)
-    nthread_tx = max_threads
-    nthread_bx = shape[axis] // max_threads + 1
+        assert indices_out_swap is not None
+        indices_out_swap = ib.buffer_ptr(indices_out_swap)
 
-    tx = te.thread_axis("threadIdx.x")
-    bx = te.thread_axis("blockIdx.x")
-    ib.scope_attr(tx, "thread_extent", nthread_tx)
-    ib.scope_attr(bx, "thread_extent", nthread_bx)
-    tid = bx * nthread_tx + tx
-    temp_data = ib.allocate(values_out.dtype, (1,), name="temp_data", 
scope="local")
-    if indices_out is not None:
-        temp_index = ib.allocate(indices_out.dtype, (1,), name="temp_index", 
scope="local")
+    # Set up threading
+    max_threads = 
int(tvm.target.Target.current(allow_none=False).max_num_threads)
+    nthread_tx = max_threads
+    nthread_bx = ceil_div(shape[axis], max_threads)
+    nthread_by = axis_mul_before
+    nthread_bz = axis_mul_after
+
+    # Copy the data to initial output
+    with ib.new_scope():
+        tx = te.thread_axis("threadIdx.x")
+        bx = te.thread_axis("blockIdx.x")
+        ib.scope_attr(tx, "thread_extent", nthread_tx)
+        ib.scope_attr(bx, "thread_extent", nthread_bx)
+        tid = bx * nthread_tx + tx
+
+        by = te.thread_axis("blockIdx.y")
+        bz = te.thread_axis("blockIdx.z")
+        ib.scope_attr(by, "thread_extent", nthread_by)
+        ib.scope_attr(bz, "thread_extent", nthread_bz)
+        idx = (by * shape[axis] + tid) * axis_mul_after + bz
+        with ib.if_scope(tid < shape[axis]):
+            values_out[idx] = data[idx]
+            if indices_out is not None:
+                indices_out[idx] = tvm.tir.generic.cast(tid, indices_out.dtype)
+
+    ## we are looping over the array doing mergesort from the bottom up.
+    ## The outer loop runs on the host and launches a cuda kernel for each 
iteration
+    ## of the algorithm.
+    ## The basic idea is that at iteration 0, each thread does sort on 2 
elements.
+    ## On iteration 1, each thread merges 2 sorted arrays of 2 elements,
+    ## to deal with 4 total elements.
+    ## On iteration 2, each thread merges 2 sorted arrays of 4 elements,
+    ## to deal with 8 total elements. On iteration 3, each thread deals with 
16 elements, etc
+    ## On the final iteration of the algorithm, one thread will merge two 
sorted lists
+    ## to sort the entire array
+    lim = tvm.tir.generic.cast(
+        tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(shape[axis], 
"float64"))), "int64"
+    )
+    with ib.for_range(0, lim, dtype="int64") as l2_width:
+        width = 2 << l2_width
+        # Define and launch the cuda kernel
+        with ib.new_scope():
+            i = ib.allocate("int64", (1,), name="i", scope="local")
+            j = ib.allocate("int64", (1,), name="j", scope="local")
+            start = ib.allocate("int64", (1,), name="start", scope="local")
+            middle = ib.allocate("int64", (1,), name="middle", scope="local")
+            end = ib.allocate("int64", (1,), name="end", scope="local")
+            tx = te.thread_axis("threadIdx.x")
+            bx = te.thread_axis("blockIdx.x")
+            ib.scope_attr(tx, "thread_extent", nthread_tx)
+            # Reduce the number of blocks as the work per thread grows
+            ib.scope_attr(
+                bx,
+                "thread_extent",
+                tvm.tir.generic.cast(ceil_div(shape[axis], width * 
max_threads), "int32"),
+            )
+            tid = bx * nthread_tx + tx
+
+            by = te.thread_axis("blockIdx.y")
+            bz = te.thread_axis("blockIdx.z")
+            ib.scope_attr(by, "thread_extent", nthread_by)
+            ib.scope_attr(bz, "thread_extent", nthread_bz)
+
+            def compare(a, b):
+                """
+                Compare a and b in proper ascending or descending order
+                """
+                if is_ascend:
+                    out = a <= b
+                else:
+                    out = b <= a
+                return out
+
+            def BottomUpMerge(source, dest, source_idx, dest_idx, start, 
middle, end, even):
+                """
+                Merge the two sections of the array assigned to this thread
+                """
+                # pylint: disable=arguments-out-of-order
+                # initialize iterators
+                i[0] = start
+                j[0] = middle
+                # set up indexes
+                base_idx = by * shape[axis] * axis_mul_after + bz
+                # iterate over the output loop
+                with ib.for_range(0, end - start) as k:
+                    i_idx = base_idx + i[0] * axis_mul_after
+                    j_idx = base_idx + j[0] * axis_mul_after
+                    k_idx = base_idx + (k + start) * axis_mul_after
+
+                    def swap_values(source, dest, source_idx, dest_idx):
+                        def assign_i():
+                            """assign i value to current output"""
+                            dest[k_idx] = source[i_idx]
+                            if indices_out is not None:
+                                dest_idx[k_idx] = source_idx[i_idx]
+                            i[0] += 1
+
+                        def assign_j():
+                            """assign j value to current output"""
+                            dest[k_idx] = source[j_idx]
+                            if indices_out is not None:
+                                dest_idx[k_idx] = source_idx[j_idx]
+                            j[0] += 1
+
+                        ## if both of the iterators are in range
+                        with ib.if_scope(tvm.tir.all(i[0] < middle, j[0] < 
end)):
+                            # compare them and insert whichever is next into 
the output
+                            with ib.if_scope(compare(source[i_idx], 
source[j_idx])):
+                                assign_i()
+                            with ib.else_scope():
+                                assign_j()
+                        # otherwise, simply copy the remainder of the valid 
iterator to the output
+                        with ib.else_scope():
+                            with ib.if_scope(i[0] < middle):
+                                assign_i()
+                            with ib.else_scope():
+                                assign_j()
+
+                    # Switch which input is the source and which is the 
destination each iteration
+                    with ib.if_scope(even):
+                        swap_values(source, dest, source_idx, dest_idx)
+                    with ib.else_scope():
+                        swap_values(dest, source, dest_idx, source_idx)
+
+            def MergeSort(source, dest, source_idx, dest_idx, size, width, 
even):

Review comment:
       Please rename `MergeSort` to `merge_sort` (snake_case function naming, per PEP 8).




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to