This is an automated email from the ASF dual-hosted git repository.
mbrookhart pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 21fc3bb [TOPI] Use fixed thread block size in unique op for Vulkan (#7718)
21fc3bb is described below
commit 21fc3bb08e2cc4928e1cd06f2280fe83431c80f0
Author: masahi <[email protected]>
AuthorDate: Tue Mar 23 00:29:21 2021 +0900
[TOPI] Use fixed thread block size in unique op for Vulkan (#7718)
* [TOPI] Use fixed thread block size in unique op for Vulkan
* forgot to add min for non-vk backend
---
python/tvm/topi/cuda/unique.py | 15 ++++++++---
tests/python/unittest/test_target_codegen_spirv.py | 30 +++++++++++++++++-----
2 files changed, 35 insertions(+), 10 deletions(-)
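For context on the fix: SPIR-V requires the thread group (workgroup) size to be a compile-time constant, so the previous extent `tir.min(batch_size, max_threads)` is invalid on Vulkan whenever `batch_size` is symbolic. A minimal sketch of the condition the new helper tests, with an assumed dynamic extent `n` and block size 256 (both illustrative only):

import tvm
from tvm import te, tir

n = te.var("n")                  # dynamic batch dimension: a tir.Var, not an IntImm
max_threads = 256                # e.g. Target("vulkan").max_num_threads
block = tir.min(n, max_threads)  # symbolic expression, not a compile-time constant
# isinstance(n, tvm.tir.IntImm) is False here, so the patch falls back to
# the fixed max_threads value for the thread block extent.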
diff --git a/python/tvm/topi/cuda/unique.py b/python/tvm/topi/cuda/unique.py
index 02a5cf3..2bca3c4 100644
--- a/python/tvm/topi/cuda/unique.py
+++ b/python/tvm/topi/cuda/unique.py
@@ -24,6 +24,15 @@ from .sort import sort, argsort
from ..utils import ceil_div
+def _get_max_threads(batch_size):
+ target = tvm.target.Target.current()
+ max_threads = tvm.target.Target.current(allow_none=False).max_num_threads
+ if "vulkan" in str(target) and not isinstance(batch_size, tvm.tir.IntImm):
+ # SPIR-V does not support dynamic thread group size
+ return max_threads
+ return tir.min(batch_size, max_threads)
+
+
def _calc_adjacent_diff_ir(data, output, binop=tir.Sub):
"""Low level IR to calculate adjacent difference in an 1-D array.
@@ -46,7 +55,7 @@ def _calc_adjacent_diff_ir(data, output, binop=tir.Sub):
data_ptr = ib.buffer_ptr(data)
output_ptr = ib.buffer_ptr(output)
batch_size = data.shape[0]
- max_threads = tir.min(batch_size, tvm.target.Target.current(allow_none=False).max_num_threads)
+ max_threads = _get_max_threads(batch_size)
with ib.new_scope():
nthread_tx = max_threads
nthread_bx = ceil_div(batch_size, max_threads)
@@ -157,7 +166,7 @@ def _calc_unique_ir(
unique_seq_indices_ptr = ib.buffer_ptr(indices)
batch_size = data.shape[0]
- max_threads = tir.min(batch_size, tvm.target.Target.current(allow_none=False).max_num_threads)
+ max_threads = _get_max_threads(batch_size)
# if need to return counts
if isinstance(counts, tir.Buffer):
@@ -238,7 +247,7 @@ def _calc_first_occurence_ir(argsorted_indices, inc_scan, first_occurence):
inc_scan_ptr = ib.buffer_ptr(inc_scan)
first_occurence_ptr = ib.buffer_ptr(first_occurence)
batch_size = argsorted_indices.shape[0]
- max_threads = tir.min(batch_size, tvm.target.Target.current(allow_none=False).max_num_threads)
+ max_threads = _get_max_threads(batch_size)
with ib.new_scope():
nthread_tx = max_threads
nthread_bx = ceil_div(batch_size, max_threads)
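The launch extents above keep the usual pattern: `ceil_div(batch_size, max_threads)` blocks of `max_threads` threads each, with any overshoot skipped by a bounds check in the IR. A quick worked example (the numbers are illustrative only):

def ceil_div(a, b):
    return (a + b - 1) // b

batch_size = 1000                                # illustrative element count
max_threads = 256                                # fixed Vulkan block size
nthread_bx = ceil_div(batch_size, max_threads)   # (1000 + 255) // 256 = 4
# 4 blocks * 256 threads = 1024 >= 1000; the 24 surplus threads fall
# outside batch_size and are guarded off in the generated kernel.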
diff --git a/tests/python/unittest/test_target_codegen_spirv.py b/tests/python/unittest/test_target_codegen_spirv.py
index 68be5c4..bf47bbe 100644
--- a/tests/python/unittest/test_target_codegen_spirv.py
+++ b/tests/python/unittest/test_target_codegen_spirv.py
@@ -72,17 +72,18 @@ def test_bool_load():
tvm.testing.assert_allclose(b.asnumpy(), ref)
+def check_mod(mod, x_np, res_np):
+ target = "vulkan"
+ ctx = tvm.context(target, 0)
+ ex = relay.create_executor("vm", mod=mod, ctx=ctx, target=target)
+ res = ex.evaluate()(x_np).asnumpy()
+ tvm.testing.assert_allclose(res, res_np, atol=1e-5)
+
+
def test_pushconstants():
if not tvm.testing.device_enabled("vulkan"):
return
- def check_mod(mod, x_np, res_np):
- target = "vulkan"
- ctx = tvm.context(target, 0)
- ex = relay.create_executor("vm", mod=mod, ctx=ctx, target=target)
- res = ex.evaluate()(x_np).asnumpy()
- tvm.testing.assert_allclose(res, res_np, atol=1e-5)
-
# Three 32 bit pushconstants: any_dim, stride, stride
dtype = "float32"
x = relay.var("x", shape=(relay.Any(),), dtype=dtype)
@@ -104,6 +105,21 @@ def test_pushconstants():
check_mod(mod, x_np, res_np)
+def test_unique():
+ if not tvm.testing.device_enabled("vulkan"):
+ return
+
+ dtype = "int32"
+ x = relay.var("x", shape=(relay.Any(),), dtype=dtype)
+ mod = tvm.IRModule()
+ [unique, _, num_unique] = relay.unique(x, is_sorted=True)
+ mod["main"] = relay.Function([x], relay.op.strided_slice(unique,
begin=[0], end=num_unique))
+ x_np = np.random.randint(0, high=10, size=(10,)).astype(dtype)
+ res_np = np.unique(x_np)
+ check_mod(mod, x_np, res_np)
+
+
if __name__ == "__main__":
test_bool_load()
test_pushconstants()
+ test_unique()
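A note on the new test: `relay.unique` returns buffers sized like the input, with only the first `num_unique` entries valid, which is why the module slices the output with `strided_slice` before comparing against NumPy. The reference side, as a standalone illustration:

import numpy as np

x_np = np.array([4, 4, 1, 7, 7, 7, 0], dtype="int32")
print(np.unique(x_np))  # [0 1 4 7] -- sorted, deduplicated, length num_unique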