mbrookhart commented on a change in pull request #7611:
URL: https://github.com/apache/tvm/pull/7611#discussion_r590506792



##########
File path: python/tvm/topi/cuda/sort.py
##########
@@ -136,93 +236,223 @@ def compare(a, b):
             out = b <= a
         return out
 
-    def bottom_up_merge(source, dest, source_idx, dest_idx, start, middle, 
end, even):
-        """
-        Merge the two sections of the array assigned to this thread
-        """
-        # pylint: disable=arguments-out-of-order
-        # initialize iterators
-        i[0] = start
-        j[0] = middle
-        # set up indexes
-        base_idx = by * size * axis_mul_after + bz
-        # iterate over the output loop
-        with ib.for_range(0, end - start) as k:
-            i_idx = base_idx + i[0] * axis_mul_after
-            j_idx = base_idx + j[0] * axis_mul_after
-            k_idx = base_idx + (k + start) * axis_mul_after
-
-            def swap_values(source, dest, source_idx, dest_idx):
-                def assign_i():
-                    """assign i value to current output"""
-                    dest[k_idx] = source[i_idx]
-                    if values is not None:
-                        dest_idx[k_idx] = source_idx[i_idx]
-                    i[0] += 1
-
-                def assign_j():
-                    """assign j value to current output"""
-                    dest[k_idx] = source[j_idx]
-                    if values is not None:
-                        dest_idx[k_idx] = source_idx[j_idx]
-                    j[0] += 1
-
-                ## if both of the iterators are in range
-                with ib.if_scope(tvm.tir.all(i[0] < middle, j[0] < end)):
-                    # compare them and insert whichever is next into the output
-                    with ib.if_scope(compare(source[i_idx], source[j_idx])):
-                        assign_i()
-                    with ib.else_scope():
-                        assign_j()
-                # otherwise, simply copy the remainder of the valid iterator 
to the output
-                with ib.else_scope():
-                    with ib.if_scope(i[0] < middle):
-                        assign_i()
-                    with ib.else_scope():
-                        assign_j()
+    # Sort the lower levels of the merge using odd-even sort, it's fast for 
small inputs
+    lower_lim = tvm.tir.generic.cast(
+        tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(block_size, 
"float64"))), "int64"
+    )
 
-            # Switch which input is the source and which is the destination 
each iteration
-            with ib.if_scope(even):
-                swap_values(source, dest, source_idx, dest_idx)
-            with ib.else_scope():
-                swap_values(dest, source, dest_idx, source_idx)
-
-    def mergesort(source, dest, source_idx, dest_idx, size, width, even):
-        # calculate the start, mid, and end points of this section
-        start[0] = width * tid
-        with ib.if_scope(start[0] < size):
-            middle[0] = tvm.te.min(start[0] + tvm.tir.indexdiv(width, 2), size)
-            end[0] = tvm.te.min(start[0] + width, size)
-            ## merge the start->middle and middle->end arrays
-            bottom_up_merge(source, dest, source_idx, dest_idx, start[0], 
middle[0], end[0], even)
-
-    lim = tvm.tir.generic.cast(
+    _odd_even_sort(
+        ib,
+        size,
+        axis_mul_before * axis_mul_after,
+        1,
+        is_ascend,
+        keys,
+        keys_swap,
+        values,
+        values_swap,
+    )
+
+    upper_lim = tvm.tir.generic.cast(
         tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(size, "float64"))), 
"int64"
     )
-    with ib.for_range(0, lim, dtype="int64") as l2_width:
-        width = 2 << l2_width
+
+    def get_merge_begin(source, base_idx, aCount, bCount, aStart, bStart, 
diag, step_count):
+        first = ib.allocate("int64", (1,), name="first", scope="local")
+        mid = ib.allocate("int64", (1,), name="mid", scope="local")
+        last = ib.allocate("int64", (1,), name="last", scope="local")
+        first[0] = tvm.te.max(0, diag - bCount)
+        last[0] = tvm.te.min(diag, aCount)
+        with ib.while_loop(first[0] < last[0]):
+            mid[0] = (first[0] + last[0]) >> 1
+            a = source[base_idx + (aStart + mid[0])]
+            b = source[base_idx + (bStart + diag - 1 - mid[0])]
+            with ib.if_scope(compare(a, b)):
+                first[0] = mid[0] + 1
+            with ib.else_scope():
+                last[0] = mid[0]
+        return first, last

Review comment:
       I'm not sure it matters: because this is IR builder code, everything in this 
function ends up getting inlined. But I'll take another look at the style. :+1: 




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to