This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new b4c1c3870f [CUDA] Fixed the call of the min function in the schedule for cuda (#14751)
b4c1c3870f is described below
commit b4c1c3870f7901d2171297aa1a15d94b78715d83
Author: Matveenko Valery <[email protected]>
AuthorDate: Mon May 15 07:06:02 2023 +0200
[CUDA] Fixed the call of the min function in the schedule for cuda (#14751)
* fixed the call of the minimum function in the schedule for cuda
* add test for scatter_nd
* update test only for cuda target
* fix lint
* update test
* fix lint
* apply comments
---
python/tvm/topi/cuda/scatter.py | 2 +-
tests/python/relay/test_any.py | 23 +++++++++++++++++++++++
2 files changed, 24 insertions(+), 1 deletion(-)
diff --git a/python/tvm/topi/cuda/scatter.py b/python/tvm/topi/cuda/scatter.py
index 39ef5a5a42..7f5fb8aa87 100644
--- a/python/tvm/topi/cuda/scatter.py
+++ b/python/tvm/topi/cuda/scatter.py
@@ -227,8 +227,8 @@ def scatter_nd(data, indices, updates, mode):
fused_shape *= i
max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
- tdim = min(max_threads, fused_updates_dimension)
+ tdim = tvm.tir.min(max_threads, fused_updates_dimension)
with ib.new_scope():
bdim = ceil_div(fused_shape, tdim)
bx = te.thread_axis("blockIdx.x")
diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py
index 443637276e..3cf4e53106 100644
--- a/tests/python/relay/test_any.py
+++ b/tests/python/relay/test_any.py
@@ -2148,6 +2148,29 @@ def test_scatter_nd():
verify_scatter_nd(data, indices, updates, out)
+@tvm.testing.uses_gpu
+def test_scatter_nd_any_updates():
+ def verify_scatter_nd_any_updates(data_np, indices_np, updates_np,
ref_res):
+ indices_shape = (2, relay.Any())
+ updates_shape = (2, relay.Any())
+ data = relay.var("data", shape=data_np.shape, dtype=str(data_np.dtype))
+        indices = relay.var("indices", relay.TensorType(indices_shape, str(indices_np.dtype)))
+        updates = relay.var("updates", relay.TensorType(updates_shape, str(updates_np.dtype)))
+
+ out = relay.op.scatter_nd(data, indices, updates, "add")
+
+ mod = tvm.IRModule()
+ mod["main"] = relay.Function([data, indices, updates], out)
+
+        check_result([data_np, indices_np, updates_np], mod, [ref_res], only_vm=True)
+
+ data = np.zeros((3, 3)).astype("int64")
+ indices = np.array([[1, 1], [0, 1]])
+ updates = np.array([[2, 2], [1, 1]])
+ out = np.array([[0, 0, 0], [0, 0, 0], [2, 2, 1]])
+ verify_scatter_nd_any_updates(data, indices, updates, out)
+
+
@tvm.testing.uses_gpu
def test_gather():
def verify_gather(data_shape, indices_shape, data_shape_np, indices_shape_np, axis):