This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new b4c1c3870f [CUDA] Fixed the call of the min function in the schedule for cuda (#14751)
b4c1c3870f is described below

commit b4c1c3870f7901d2171297aa1a15d94b78715d83
Author: Matveenko Valery <[email protected]>
AuthorDate: Mon May 15 07:06:02 2023 +0200

    [CUDA] Fixed the call of the min function in the schedule for cuda (#14751)
    
    * fixed the call of the minimum function in the schedule for cuda
    
    * add test for scatter_nd
    
    * update test only for cuda target
    
    * fix lint
    
    * update test
    
    * fix lint
    
    * apply comments
---
 python/tvm/topi/cuda/scatter.py |  2 +-
 tests/python/relay/test_any.py  | 23 +++++++++++++++++++++++
 2 files changed, 24 insertions(+), 1 deletion(-)

diff --git a/python/tvm/topi/cuda/scatter.py b/python/tvm/topi/cuda/scatter.py
index 39ef5a5a42..7f5fb8aa87 100644
--- a/python/tvm/topi/cuda/scatter.py
+++ b/python/tvm/topi/cuda/scatter.py
@@ -227,8 +227,8 @@ def scatter_nd(data, indices, updates, mode):
             fused_shape *= i
 
         max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
-        tdim = min(max_threads, fused_updates_dimension)
 
+        tdim = tvm.tir.min(max_threads, fused_updates_dimension)
         with ib.new_scope():
             bdim = ceil_div(fused_shape, tdim)
             bx = te.thread_axis("blockIdx.x")
diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py
index 443637276e..3cf4e53106 100644
--- a/tests/python/relay/test_any.py
+++ b/tests/python/relay/test_any.py
@@ -2148,6 +2148,29 @@ def test_scatter_nd():
     verify_scatter_nd(data, indices, updates, out)
 
 
+@tvm.testing.uses_gpu
+def test_scatter_nd_any_updates():
+    def verify_scatter_nd_any_updates(data_np, indices_np, updates_np, ref_res):
+        indices_shape = (2, relay.Any())
+        updates_shape = (2, relay.Any())
+        data = relay.var("data", shape=data_np.shape, dtype=str(data_np.dtype))
+        indices = relay.var("indices", relay.TensorType(indices_shape, str(indices_np.dtype)))
+        updates = relay.var("updates", relay.TensorType(updates_shape, str(updates_np.dtype)))
+
+        out = relay.op.scatter_nd(data, indices, updates, "add")
+
+        mod = tvm.IRModule()
+        mod["main"] = relay.Function([data, indices, updates], out)
+
+        check_result([data_np, indices_np, updates_np], mod, [ref_res], only_vm=True)
+
+    data = np.zeros((3, 3)).astype("int64")
+    indices = np.array([[1, 1], [0, 1]])
+    updates = np.array([[2, 2], [1, 1]])
+    out = np.array([[0, 0, 0], [0, 0, 0], [2, 2, 1]])
+    verify_scatter_nd_any_updates(data, indices, updates, out)
+
+
 @tvm.testing.uses_gpu
 def test_gather():
     def verify_gather(data_shape, indices_shape, data_shape_np, indices_shape_np, axis):

Reply via email to