This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 825dc1ffb5 [TOPI] Remove `blockIdx.z` in topi sort (#16977)
825dc1ffb5 is described below

commit 825dc1ffb51c25506600136d2ec8fb336f476c84
Author: Siyuan Feng <[email protected]>
AuthorDate: Fri May 10 21:08:17 2024 +0800

    [TOPI] Remove `blockIdx.z` in topi sort (#16977)
    
    As `blockIdx.z` is not allowed in WebGPU, this PR folds the `blockIdx.z`
    extent into `blockIdx.y` (recovering both indices with `%` and `//`) to
    support WebGPU.
---
 python/tvm/topi/cuda/sort.py | 31 ++++++++++++++-----------------
 1 file changed, 14 insertions(+), 17 deletions(-)

diff --git a/python/tvm/topi/cuda/sort.py b/python/tvm/topi/cuda/sort.py
index dc72aa8cc1..9151744b69 100644
--- a/python/tvm/topi/cuda/sort.py
+++ b/python/tvm/topi/cuda/sort.py
@@ -57,18 +57,16 @@ def _schedule_sort(outs):
     return s
 
 
-def _get_threads(ib, nthread_tx, nthread_bx, nthread_by, nthread_bz):
+def _get_threads(ib, nthread_tx, nthread_bx, nthread_by):
     tx = te.thread_axis("threadIdx.x")
     bx = te.thread_axis("blockIdx.x")
     ib.scope_attr(tx, "thread_extent", nthread_tx)
     ib.scope_attr(bx, "thread_extent", nthread_bx)
 
     by = te.thread_axis("blockIdx.y")
-    bz = te.thread_axis("blockIdx.z")
     ib.scope_attr(by, "thread_extent", nthread_by)
-    ib.scope_attr(bz, "thread_extent", nthread_bz)
 
-    return tx, bx, by, bz
+    return tx, bx, by
 
 
 def _sort_init(ib, shape, axis, keys_in, keys_out, values_out=None, value_init_func=None):
@@ -87,13 +85,13 @@ def _sort_init(ib, shape, axis, keys_in, keys_out, values_out=None, value_init_f
     max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
     nthread_tx = max_threads
     nthread_bx = ceil_div(shape[axis], max_threads)
-    nthread_by = axis_mul_before
-    nthread_bz = axis_mul_after
+    nthread_by = axis_mul_before * axis_mul_after
 
     # Copy the keys_in to initial output
     with ib.new_scope():
-        tx, bx, by, bz = _get_threads(ib, nthread_tx, nthread_bx, nthread_by, nthread_bz)
+        tx, bx, by = _get_threads(ib, nthread_tx, nthread_bx, nthread_by)
         tid = bx * nthread_tx + tx
+        by, bz = by % axis_mul_before, by // axis_mul_before
         idx = (by * shape[axis] + tid) * axis_mul_after + bz
         with ib.if_scope(tid < shape[axis]):
             keys_out[idx] = keys_in[idx]
@@ -122,11 +120,11 @@ def _odd_even_sort(
 ):
     nthread_tx = block_size // 2
     nthread_bx = ceil_div(size, block_size)
-    nthread_by = axis_mul_before
-    nthread_bz = axis_mul_after
+    nthread_by = axis_mul_before * axis_mul_after
     with ib.new_scope():
         ib.scope_attr(tvm.tir.const(0), "hand_threaded", 0)
-        tx, bx, by, bz = _get_threads(ib, nthread_tx, nthread_bx, nthread_by, nthread_bz)
+        tx, bx, by = _get_threads(ib, nthread_tx, nthread_bx, nthread_by)
+        by, bz = by % axis_mul_before, by // axis_mul_before
         tid = 2 * tx
         start = bx * block_size
 
@@ -222,7 +220,6 @@ def _sort_common(
 
     max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
     nthread_by = axis_mul_before * axis_mul_after
-    nthread_bz = 1
     nthread_tx = max_threads
     nthread_bx = ceil_div(size, nthread_tx)
 
@@ -334,12 +331,13 @@ def _sort_common(
                 ntx = max_threads
                 nbx = tvm.tir.generic.cast(ceil_div(width, max_threads * thread_work), "int32")
                 nbz = tvm.tir.generic.cast(ceil_div(size, width), "int32")
-                tx, bx, by, bz = _get_threads(ib, ntx, nbx, nthread_by, nbz)
+                tx, bx, by = _get_threads(ib, ntx, nbx, nthread_by * nbz)
             else:
                 ntx = tvm.tir.generic.cast(tvm.te.min(max_threads, width), "int32")
                 nbx = tvm.tir.generic.cast(ceil_div(width, max_threads * thread_work), "int32")
                 nbz = tvm.tir.generic.cast(ceil_div(size, width), "int32")
-                tx, bx, by, bz = _get_threads(ib, ntx, nbx, nthread_by, nbz)
+                tx, bx, by = _get_threads(ib, ntx, nbx, nthread_by * nbz)
+            by, bz = by % nthread_by, by // nthread_by
 
             def mergepath(
                 source,
@@ -471,8 +469,7 @@ def _sort_common(
                 width,
                 tvm.tir.indexmod(l2_width, 2) == 0,
             )
-    nthread_by = axis_mul_before
-    nthread_bz = axis_mul_after
+    nthread_by = axis_mul_before * axis_mul_after
     nthread_tx = max_threads
     nthread_bx = ceil_div(size, nthread_tx)
     ## if the final sorted data ended up in the swap, copy it to the real output
@@ -480,9 +477,9 @@ def _sort_common(
         tvm.tir.all(upper_lim > lower_lim, tvm.tir.indexmod(upper_lim - lower_lim, 2) == 1)
     ):
         with ib.new_scope():
-            tx, bx, by, bz = _get_threads(ib, nthread_tx, nthread_bx, nthread_by, nthread_bz)
+            tx, bx, by = _get_threads(ib, nthread_tx, nthread_bx, nthread_by)
             tid = bx * nthread_tx + tx
-            idx = (by * axis_mul_after + bz) * size + tid
+            idx = by * size + tid
             with ib.if_scope(tid < size):
                 keys[idx] = keys_swap[idx]
                 if values is not None:

Reply via email to