This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 571eff9223 [BugFix][Relay] fix `scatter_nd` type relation (#14773)
571eff9223 is described below
commit 571eff9223ce4f344728d5cead89700c7608d6e0
Author: Jack River <[email protected]>
AuthorDate: Sun May 7 04:38:04 2023 +0800
[BugFix][Relay] fix `scatter_nd` type relation (#14773)
* [BugFix][Relay] fix scatter_nd type relation
ScatterND requires updates.shape[K:] == output.shape[M:], not
data.shape[K:] == output.shape[M:], where M = indices.shape[0] and
K = indices.ndim - 1.
* [BugFix][Relay] fix scatter_nd type relation
Add a test case for scatter_nd with M != K.
---------
Co-authored-by: Jiang.Zhongzhou <[email protected]>
---
src/relay/op/tensor/transform.cc | 6 +++---
tests/python/relay/test_op_level3.py | 20 ++++++++++++++++++++
2 files changed, 23 insertions(+), 3 deletions(-)
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index cd8240d57d..a0111ff7cd 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -1185,7 +1185,7 @@ bool ScatterNDRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const size_t kdim = indices->shape.size() - 1;
const size_t ndim = out_shape.size();
ICHECK_LE(size_t(mdim->value), ndim)
- << "ScatterND: Given data with shape (Y_0, ..., Y_{K-1}, X_M, ...,
X_{N-1}), and indices "
+ << "ScatterND: Given updates with shape (Y_0, ..., Y_{K-1}, X_M, ...,
X_{N-1}), and indices "
"with shape (M, Y_0, ..., Y_{K-1}), M must be less than or equal to
N.";
// Indices: (M, Y_0, .. Y_{K-1}) data: (Y_0, .. Y_{K-1}, ...), verify Y's.
for (size_t i = 0; i < kdim; i++) {
@@ -1197,9 +1197,9 @@ bool ScatterNDRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
oshape.push_back(x);
}
- // data: (Y_0, .. Y_{K-1}, X_M, .. X_{N-1}) out: (X_0, .. X_{N-1}), verify
X_M to X_{N-1}
+ // updates: (Y_0, .. Y_{K-1}, X_M, .. X_{N-1}) out: (X_0, .. X_{N-1}),
verify X_M to X_{N-1}
for (size_t i = mdim->value; i < ndim; i++) {
- reporter->AssertEQ(data->shape[i - mdim->value + kdim], oshape[i]);
+ reporter->AssertEQ(updates->shape[i - mdim->value + kdim], oshape[i]);
}
reporter->Assign(types[3], TensorType(data->shape, data->dtype));
diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py
index 493bf00fc6..5e86ab8da7 100644
--- a/tests/python/relay/test_op_level3.py
+++ b/tests/python/relay/test_op_level3.py
@@ -1942,6 +1942,26 @@ def test_scatter_nd(target, dev, executor_kind):
test_scatter_nd_large_shape()
+ def test_scatter_nd_inequal_m_k():
+ def before():
+ data = relay.const(np.zeros((1, 1, 10), dtype="float32"),
dtype="float32")
+ indices = relay.const(np.zeros((2, 1, 1, 1), dtype="float32"),
dtype="int64")
+ update = relay.const(np.ones((1, 1, 1, 10), dtype="float32"),
dtype="float32")
+ b = relay.op.scatter_nd(data, indices, update)
+ return relay.Function(relay.analysis.free_vars(b), b)
+
+ passes = tvm.transform.Sequential(
+ [
+ relay.transform.InferType(),
+ relay.transform.FoldConstant(),
+ ]
+ )
+ before_mod = tvm.IRModule.from_expr(before())
+ with tvm.transform.PassContext(opt_level=3):
+ after_mod = passes(before_mod)
+
+ test_scatter_nd_inequal_m_k()
+
def verify_scatter_nd(
data_np, indices_np, updates_np, ref_res, mode="add", rtol=1e-5,
atol=1e-5
):