This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch vk-i64
in repository https://gitbox.apache.org/repos/asf/tvm.git
commit 1bea761d9423d1a79f05caaa9d043a106be2bfe4 Author: Masahiro Masuda <masahi...@gmail.com> AuthorDate: Wed Mar 3 08:18:24 2021 +0900 test cumsum on vulkan --- tests/python/topi/python/test_topi_cumsum.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/python/topi/python/test_topi_cumsum.py b/tests/python/topi/python/test_topi_cumsum.py index a01a496..bf962d9 100644 --- a/tests/python/topi/python/test_topi_cumsum.py +++ b/tests/python/topi/python/test_topi_cumsum.py @@ -28,6 +28,7 @@ def test_cumsum(ctx, target): "generic": (lambda x: topi.cumsum(x, axis, dtype), topi.generic.schedule_extern), "cuda": (lambda x: topi.cuda.cumsum(x, axis, dtype), topi.cuda.schedule_scan), "nvptx": (lambda x: topi.cuda.cumsum(x, axis, dtype), topi.cuda.schedule_scan), + "vulkan": (lambda x: topi.cuda.cumsum(x, axis, dtype), topi.cuda.schedule_scan), } fcompute, fschedule = tvm.topi.testing.dispatch(target, implementations) tvm.topi.testing.compare_numpy_tvm([data], np_ref, target, ctx, fcompute, fschedule) @@ -40,8 +41,10 @@ def test_cumsum(ctx, target): check_cumsum(np.cumsum(data, dtype=np.int32), data) check_cumsum(np.cumsum(data), data, dtype="int64") - data = np.random.rand(10) > 0.5 - check_cumsum(np.cumsum(data, dtype=np.int32), data, dtype="int32") + if str(target.kind) != "vulkan": + # TODO(masahi): Support bool tensor in SPIRV codegen + data = np.random.rand(10) > 0.5 + check_cumsum(np.cumsum(data, dtype=np.int32), data, dtype="int32") for in_dtype in ["float32", "float64"]: data = np.random.randn(10, 10).astype(in_dtype) @@ -70,3 +73,4 @@ if __name__ == "__main__": test_cumsum(tvm.context("cpu"), tvm.target.Target("llvm")) test_cumsum(tvm.context("cuda"), tvm.target.Target("cuda")) test_cumsum(tvm.context("nvptx"), tvm.target.Target("nvptx")) + test_cumsum(tvm.context("vulkan"), tvm.target.Target("vulkan"))