This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 1fac10b  Fix GatherND attribute registration (#8269)
1fac10b is described below

commit 1fac10b359dec1bd6ad45ce36541a882aaba586b
Author: masahi <[email protected]>
AuthorDate: Thu Jun 17 10:16:39 2021 +0900

    Fix GatherND attribute registration (#8269)
---
 include/tvm/relay/attrs/transform.h | 2 +-
 src/relay/op/tensor/transform.cc    | 3 +++
 2 files changed, 4 insertions(+), 1 deletion(-)

diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h
index 69a9c64..a8317e1 100644
--- a/include/tvm/relay/attrs/transform.h
+++ b/include/tvm/relay/attrs/transform.h
@@ -148,7 +148,7 @@ struct GatherNDAttrs : public tvm::AttrsNode<GatherNDAttrs> {
   Integer batch_dims;
   Optional<Integer> index_rank;
 
-  TVM_DECLARE_ATTRS(GatherAttrs, "relay.attrs.GatherNDAttrs") {
+  TVM_DECLARE_ATTRS(GatherNDAttrs, "relay.attrs.GatherNDAttrs") {
    TVM_ATTR_FIELD(batch_dims).set_default(Integer(0)).describe("The number of batch dimensions.");
     TVM_ATTR_FIELD(index_rank)
         .set_default(NullValue<Integer>())
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 9361e19..7d40bf2 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -3290,6 +3290,8 @@ which must just be not null. Output will have same shape as ``indices``.
     .set_attr<FTVMCompute>("FTVMCompute", GatherCompute)
     .set_attr<TOpPattern>("TOpPattern", kInjective);
 
+TVM_REGISTER_NODE_TYPE(GatherNDAttrs);
+
 // gather_nd operator
 bool GatherNDRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                  const TypeReporter& reporter) {
@@ -3367,6 +3369,7 @@ When B == 0 (the default case), the output shape will be (Y_0, ..., Y_{K-1}, X_M
 In both cases, if M + B == N, the output shape will simply be (Y_0, ..., Y_{K-1}).
 )code" TVM_ADD_FILELINE)
     .set_num_inputs(2)
+    .set_attrs_type<GatherNDAttrs>()
     .add_argument("data", "Tensor", "The input tensor.")
     .add_argument("indices", "Tensor", "The indices of values to gather.")
     .set_support_level(3)

Reply via email to