This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 571eff9223 [BugFix][Relay] fix `scatter_nd` type relation (#14773)
571eff9223 is described below

commit 571eff9223ce4f344728d5cead89700c7608d6e0
Author: Jack River <[email protected]>
AuthorDate: Sun May 7 04:38:04 2023 +0800

    [BugFix][Relay] fix `scatter_nd` type relation (#14773)
    
    * [BugFix][Relay] fix scatter_nd type relation
    
    ScatterND requires updates.shape[K:] == output.shape[M:],
    not data.shape[K:] == output.shape[M:]
    
    * [BugFix][Relay] fix scatter_nd type relation
    Add a test case for scatter_nd with m != k
    
    ---------
    
    Co-authored-by: Jiang.Zhongzhou <[email protected]>
---
 src/relay/op/tensor/transform.cc     |  6 +++---
 tests/python/relay/test_op_level3.py | 20 ++++++++++++++++++++
 2 files changed, 23 insertions(+), 3 deletions(-)

diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index cd8240d57d..a0111ff7cd 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -1185,7 +1185,7 @@ bool ScatterNDRel(const Array<Type>& types, int 
num_inputs, const Attrs& attrs,
   const size_t kdim = indices->shape.size() - 1;
   const size_t ndim = out_shape.size();
   ICHECK_LE(size_t(mdim->value), ndim)
-      << "ScatterND: Given data with shape (Y_0, ..., Y_{K-1}, X_M, ..., 
X_{N-1}), and indices "
+      << "ScatterND: Given updates with shape (Y_0, ..., Y_{K-1}, X_M, ..., 
X_{N-1}), and indices "
          "with shape (M, Y_0, ..., Y_{K-1}), M must be less than or equal to 
N.";
   // Indices: (M, Y_0, .. Y_{K-1}) data: (Y_0, .. Y_{K-1}, ...), verify Y's.
   for (size_t i = 0; i < kdim; i++) {
@@ -1197,9 +1197,9 @@ bool ScatterNDRel(const Array<Type>& types, int 
num_inputs, const Attrs& attrs,
     oshape.push_back(x);
   }
 
-  // data: (Y_0, .. Y_{K-1}, X_M, .. X_{N-1}) out: (X_0, .. X_{N-1}), verify 
X_M to X_{N-1}
+  // updates: (Y_0, .. Y_{K-1}, X_M, .. X_{N-1}) out: (X_0, .. X_{N-1}), 
verify X_M to X_{N-1}
   for (size_t i = mdim->value; i < ndim; i++) {
-    reporter->AssertEQ(data->shape[i - mdim->value + kdim], oshape[i]);
+    reporter->AssertEQ(updates->shape[i - mdim->value + kdim], oshape[i]);
   }
 
   reporter->Assign(types[3], TensorType(data->shape, data->dtype));
diff --git a/tests/python/relay/test_op_level3.py 
b/tests/python/relay/test_op_level3.py
index 493bf00fc6..5e86ab8da7 100644
--- a/tests/python/relay/test_op_level3.py
+++ b/tests/python/relay/test_op_level3.py
@@ -1942,6 +1942,26 @@ def test_scatter_nd(target, dev, executor_kind):
 
     test_scatter_nd_large_shape()
 
+    def test_scatter_nd_inequal_m_k():
+        def before():
+            data = relay.const(np.zeros((1, 1, 10), dtype="float32"), 
dtype="float32")
+            indices = relay.const(np.zeros((2, 1, 1, 1), dtype="float32"), 
dtype="int64")
+            update = relay.const(np.ones((1, 1, 1, 10), dtype="float32"), 
dtype="float32")
+            b = relay.op.scatter_nd(data, indices, update)
+            return relay.Function(relay.analysis.free_vars(b), b)
+
+        passes = tvm.transform.Sequential(
+            [
+                relay.transform.InferType(),
+                relay.transform.FoldConstant(),
+            ]
+        )
+        before_mod = tvm.IRModule.from_expr(before())
+        with tvm.transform.PassContext(opt_level=3):
+            after_mod = passes(before_mod)
+
+    test_scatter_nd_inequal_m_k()
+
     def verify_scatter_nd(
         data_np, indices_np, updates_np, ref_res, mode="add", rtol=1e-5, 
atol=1e-5
     ):

Reply via email to the mailing list.