This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 825dc1ffb5 [TOPI] Remove `blockIdx.z` in topi sort (#16977)
825dc1ffb5 is described below
commit 825dc1ffb51c25506600136d2ec8fb336f476c84
Author: Siyuan Feng <[email protected]>
AuthorDate: Fri May 10 21:08:17 2024 +0800
[TOPI] Remove `blockIdx.z` in topi sort (#16977)
Because `blockIdx.z` is not allowed in WebGPU, this PR removes it from the
topi sort kernels by folding its extent into `blockIdx.y`.
---
python/tvm/topi/cuda/sort.py | 31 ++++++++++++++-----------------
1 file changed, 14 insertions(+), 17 deletions(-)
diff --git a/python/tvm/topi/cuda/sort.py b/python/tvm/topi/cuda/sort.py
index dc72aa8cc1..9151744b69 100644
--- a/python/tvm/topi/cuda/sort.py
+++ b/python/tvm/topi/cuda/sort.py
@@ -57,18 +57,16 @@ def _schedule_sort(outs):
return s
-def _get_threads(ib, nthread_tx, nthread_bx, nthread_by, nthread_bz):
+def _get_threads(ib, nthread_tx, nthread_bx, nthread_by):
tx = te.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
by = te.thread_axis("blockIdx.y")
- bz = te.thread_axis("blockIdx.z")
ib.scope_attr(by, "thread_extent", nthread_by)
- ib.scope_attr(bz, "thread_extent", nthread_bz)
- return tx, bx, by, bz
+ return tx, bx, by
def _sort_init(ib, shape, axis, keys_in, keys_out, values_out=None,
value_init_func=None):
@@ -87,13 +85,13 @@ def _sort_init(ib, shape, axis, keys_in, keys_out,
values_out=None, value_init_f
max_threads =
int(tvm.target.Target.current(allow_none=False).max_num_threads)
nthread_tx = max_threads
nthread_bx = ceil_div(shape[axis], max_threads)
- nthread_by = axis_mul_before
- nthread_bz = axis_mul_after
+ nthread_by = axis_mul_before * axis_mul_after
# Copy the keys_in to initial output
with ib.new_scope():
- tx, bx, by, bz = _get_threads(ib, nthread_tx, nthread_bx, nthread_by,
nthread_bz)
+ tx, bx, by = _get_threads(ib, nthread_tx, nthread_bx, nthread_by)
tid = bx * nthread_tx + tx
+ by, bz = by % axis_mul_before, by // axis_mul_before
idx = (by * shape[axis] + tid) * axis_mul_after + bz
with ib.if_scope(tid < shape[axis]):
keys_out[idx] = keys_in[idx]
@@ -122,11 +120,11 @@ def _odd_even_sort(
):
nthread_tx = block_size // 2
nthread_bx = ceil_div(size, block_size)
- nthread_by = axis_mul_before
- nthread_bz = axis_mul_after
+ nthread_by = axis_mul_before * axis_mul_after
with ib.new_scope():
ib.scope_attr(tvm.tir.const(0), "hand_threaded", 0)
- tx, bx, by, bz = _get_threads(ib, nthread_tx, nthread_bx, nthread_by,
nthread_bz)
+ tx, bx, by = _get_threads(ib, nthread_tx, nthread_bx, nthread_by)
+ by, bz = by % axis_mul_before, by // axis_mul_before
tid = 2 * tx
start = bx * block_size
@@ -222,7 +220,6 @@ def _sort_common(
max_threads =
int(tvm.target.Target.current(allow_none=False).max_num_threads)
nthread_by = axis_mul_before * axis_mul_after
- nthread_bz = 1
nthread_tx = max_threads
nthread_bx = ceil_div(size, nthread_tx)
@@ -334,12 +331,13 @@ def _sort_common(
ntx = max_threads
nbx = tvm.tir.generic.cast(ceil_div(width, max_threads *
thread_work), "int32")
nbz = tvm.tir.generic.cast(ceil_div(size, width), "int32")
- tx, bx, by, bz = _get_threads(ib, ntx, nbx, nthread_by, nbz)
+ tx, bx, by = _get_threads(ib, ntx, nbx, nthread_by * nbz)
else:
ntx = tvm.tir.generic.cast(tvm.te.min(max_threads, width),
"int32")
nbx = tvm.tir.generic.cast(ceil_div(width, max_threads *
thread_work), "int32")
nbz = tvm.tir.generic.cast(ceil_div(size, width), "int32")
- tx, bx, by, bz = _get_threads(ib, ntx, nbx, nthread_by, nbz)
+ tx, bx, by = _get_threads(ib, ntx, nbx, nthread_by * nbz)
+ by, bz = by % nthread_by, by // nthread_by
def mergepath(
source,
@@ -471,8 +469,7 @@ def _sort_common(
width,
tvm.tir.indexmod(l2_width, 2) == 0,
)
- nthread_by = axis_mul_before
- nthread_bz = axis_mul_after
+ nthread_by = axis_mul_before * axis_mul_after
nthread_tx = max_threads
nthread_bx = ceil_div(size, nthread_tx)
## if the final sorted data ended up in the swap, copy it to the real
output
@@ -480,9 +477,9 @@ def _sort_common(
tvm.tir.all(upper_lim > lower_lim, tvm.tir.indexmod(upper_lim -
lower_lim, 2) == 1)
):
with ib.new_scope():
- tx, bx, by, bz = _get_threads(ib, nthread_tx, nthread_bx,
nthread_by, nthread_bz)
+ tx, bx, by = _get_threads(ib, nthread_tx, nthread_bx, nthread_by)
tid = bx * nthread_tx + tx
- idx = (by * axis_mul_after + bz) * size + tid
+ idx = by * size + tid
with ib.if_scope(tid < size):
keys[idx] = keys_swap[idx]
if values is not None: