This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 1fac10b Fix GatherND attribute registration (#8269)
1fac10b is described below
commit 1fac10b359dec1bd6ad45ce36541a882aaba586b
Author: masahi <[email protected]>
AuthorDate: Thu Jun 17 10:16:39 2021 +0900
Fix GatherND attribute registration (#8269)
---
include/tvm/relay/attrs/transform.h | 2 +-
src/relay/op/tensor/transform.cc | 3 +++
2 files changed, 4 insertions(+), 1 deletion(-)
diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h
index 69a9c64..a8317e1 100644
--- a/include/tvm/relay/attrs/transform.h
+++ b/include/tvm/relay/attrs/transform.h
@@ -148,7 +148,7 @@ struct GatherNDAttrs : public tvm::AttrsNode<GatherNDAttrs> {
Integer batch_dims;
Optional<Integer> index_rank;
- TVM_DECLARE_ATTRS(GatherAttrs, "relay.attrs.GatherNDAttrs") {
+ TVM_DECLARE_ATTRS(GatherNDAttrs, "relay.attrs.GatherNDAttrs") {
TVM_ATTR_FIELD(batch_dims).set_default(Integer(0)).describe("The number of
batch dimensions.");
TVM_ATTR_FIELD(index_rank)
.set_default(NullValue<Integer>())
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 9361e19..7d40bf2 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -3290,6 +3290,8 @@ which must just be not null. Output will have same shape as ``indices``.
.set_attr<FTVMCompute>("FTVMCompute", GatherCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);
+TVM_REGISTER_NODE_TYPE(GatherNDAttrs);
+
// gather_nd operator
bool GatherNDRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
@@ -3367,6 +3369,7 @@ When B == 0 (the default case), the output shape will be (Y_0, ..., Y_{K-1}, X_M
In both cases, if M + B == N, the output shape will simply be (Y_0, ...,
Y_{K-1}).
)code" TVM_ADD_FILELINE)
.set_num_inputs(2)
+ .set_attrs_type<GatherNDAttrs>()
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("indices", "Tensor", "The indices of values to gather.")
.set_support_level(3)