masahi commented on a change in pull request #7099:
URL: https://github.com/apache/tvm/pull/7099#discussion_r543301010



##########
File path: python/tvm/topi/cuda/sort.py
##########
@@ -94,64 +109,182 @@ def sort_ir(data, values_out, axis, is_ascend, indices_out=None):
             axis_mul_before *= value
         elif i > axis:
             axis_mul_after *= value
-    max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
+
     ib = tvm.tir.ir_builder.create()
+
     data = ib.buffer_ptr(data)
     values_out = ib.buffer_ptr(values_out)
+    values_out_swap = ib.buffer_ptr(values_out_swap)
     if indices_out is not None:
         indices_out = ib.buffer_ptr(indices_out)
-    nthread_tx = max_threads
-    nthread_bx = shape[axis] // max_threads + 1
+        assert indices_out_swap is not None
+        indices_out_swap = ib.buffer_ptr(indices_out_swap)
 
-    tx = te.thread_axis("threadIdx.x")
-    bx = te.thread_axis("blockIdx.x")
-    ib.scope_attr(tx, "thread_extent", nthread_tx)
-    ib.scope_attr(bx, "thread_extent", nthread_bx)
-    tid = bx * nthread_tx + tx
-    temp_data = ib.allocate(values_out.dtype, (1,), name="temp_data", scope="local")
-    if indices_out is not None:
-        temp_index = ib.allocate(indices_out.dtype, (1,), name="temp_index", scope="local")
+    # Set up threading
+    max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
+    nthread_tx = max_threads
+    nthread_bx = ceil_div(shape[axis], max_threads)
+    nthread_by = axis_mul_before
+    nthread_bz = axis_mul_after
+
+    # Copy the data to initial output
+    with ib.new_scope():
+        tx = te.thread_axis("threadIdx.x")
+        bx = te.thread_axis("blockIdx.x")
+        ib.scope_attr(tx, "thread_extent", nthread_tx)
+        ib.scope_attr(bx, "thread_extent", nthread_bx)
+        tid = bx * nthread_tx + tx
+
+        by = te.thread_axis("blockIdx.y")
+        bz = te.thread_axis("blockIdx.z")
+        ib.scope_attr(by, "thread_extent", nthread_by)
+        ib.scope_attr(bz, "thread_extent", nthread_bz)
+        idx = (by * shape[axis] + tid) * axis_mul_after + bz
+        with ib.if_scope(tid < shape[axis]):
+            values_out[idx] = data[idx]
+            if indices_out is not None:
+                indices_out[idx] = tvm.tir.generic.cast(tid, indices_out.dtype)
+
+    ## we are looping over the array doing mergesort from the bottom up.
+    ## The outer loop runs on the host and launches a cuda kernel for each iteration
+    ## of the algorithm.
+    ## The basic idea is that at iteration 0, each thread does sort on 2 elements.
+    ## On iteration 1, each thread merges 2 sorted arrays of 2 elements,
+    ## to deal with 4 total elements.
+    ## On iteration 2, each thread merges 2 sorted arrays of 4 elements,
+    ## to deal with 8 total elements. On iteration 3, each thread deals with 16 elements, etc
+    ## On the final iteration of the algorithm, one thread will merge two sorted lists
+    ## to sort the entire array
+    lim = tvm.tir.generic.cast(
+        tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(shape[axis], "float64"))), "int64"
+    )
+    with ib.for_range(0, lim, dtype="int64") as l2_width:
+        width = 2 << l2_width
+        # Define and launch the cuda kernel
+        with ib.new_scope():
+            i = ib.allocate("int64", (1,), name="i", scope="local")
+            j = ib.allocate("int64", (1,), name="j", scope="local")
+            start = ib.allocate("int64", (1,), name="start", scope="local")
+            middle = ib.allocate("int64", (1,), name="middle", scope="local")
+            end = ib.allocate("int64", (1,), name="end", scope="local")
+            tx = te.thread_axis("threadIdx.x")
+            bx = te.thread_axis("blockIdx.x")
+            ib.scope_attr(tx, "thread_extent", nthread_tx)
+            # Reduce the number of blocks as the work per thread grows
+            ib.scope_attr(
+                bx,
+                "thread_extent",
+                tvm.tir.generic.cast(ceil_div(shape[axis], width * max_threads), "int32"),
+            )
+            tid = bx * nthread_tx + tx
+
+            by = te.thread_axis("blockIdx.y")
+            bz = te.thread_axis("blockIdx.z")
+            ib.scope_attr(by, "thread_extent", nthread_by)
+            ib.scope_attr(bz, "thread_extent", nthread_bz)
+
+            def compare(a, b):
+                """
+                Compare a and b in proper ascending or descending order
+                """
+                if is_ascend:
+                    out = a <= b
+                else:
+                    out = b <= a
+                return out
+
+            def BottomUpMerge(source, dest, source_idx, dest_idx, start, middle, end, even):

Review comment:
       Please rename `BottomUpMerge` to `bottom_up_merge` (snake_case for function names).




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


Reply via email to