GoodNight-bye commented on issue #8374:
URL: https://github.com/apache/tvm/issues/8374#issuecomment-913451485
```diff
diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h
index a8317e1e5..449e0f55f 100644
--- a/include/tvm/relay/attrs/transform.h
+++ b/include/tvm/relay/attrs/transform.h
@@ -489,6 +489,14 @@ struct UniqueAttrs : public tvm::AttrsNode<UniqueAttrs> {
   }
 };  // struct UniqueAttrs
+/*! \brief Attributes used in adv_index operator */
+struct AdvIndexAttrs : public tvm::AttrsNode<AdvIndexAttrs> {
+ Array<Integer> iter_axes;
+ TVM_DECLARE_ATTRS(AdvIndexAttrs, "relay.attrs.AdvIndexAttrs") {
+ TVM_ATTR_FIELD(iter_axes).describe("The iterative axes, iterate all the data of the corresponding axis.");
+ }
+}; // struct AdvIndexAttrs
+
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_TRANSFORM_H_
diff --git a/python/tvm/relay/op/op_attrs.py b/python/tvm/relay/op/op_attrs.py
index 2d185bcee..4922242d9 100644
--- a/python/tvm/relay/op/op_attrs.py
+++ b/python/tvm/relay/op/op_attrs.py
@@ -592,3 +592,7 @@ class NLLLossAttrs(Attrs):
@tvm._ffi.register_object("relay.attrs.FixedPointMultiplyAttrs")
class FixedPointMultiplyAttrs(Attrs):
"""Attributes used in fixed_point_multiply operators"""
+
+@tvm._ffi.register_object("relay.attrs.AdvIndexAttrs")
+class AdvIndexAttrs(Attrs):
+ """Attributes for transform.adv_index"""
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 9f9ed1c07..a79becba0 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -3787,12 +3787,15 @@ RELAY_REGISTER_OP("matrix_set_diag")
.set_attr<FTVMCompute>("FTVMCompute", MatrixSetDiagCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);
+TVM_REGISTER_NODE_TYPE(AdvIndexAttrs);
+
// adv_index
bool AdvIndexRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                 const TypeReporter& reporter) {
ICHECK_EQ(num_inputs, 1);
auto inputs = types[0].as<TupleTypeNode>();
auto data = inputs->fields[0].as<TensorTypeNode>();
+ const auto param = attrs.as<AdvIndexAttrs>();
if (inputs == nullptr || data == nullptr) {
return false;
@@ -3832,10 +3835,22 @@ bool AdvIndexRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
}
}
- for (const auto& dim : broadcast_shape) {
- oshape.push_back(dim);
+ for (auto axes : param->iter_axes) {
+ oshape.push_back(data->shape[axes]);
+ }
+
+ for (size_t i = 0; i < param->iter_axes.size(); ++i) {
+ if (param->iter_axes[i] != static_cast<int>(i)) {
+ oshape.insert(oshape.begin() + i, broadcast_shape.begin(), broadcast_shape.end());
+ break;
+ }
+ }
+
+ if (param->iter_axes.size() == oshape.size()) {
+ oshape.insert(oshape.end(), broadcast_shape.begin(), broadcast_shape.end());
}
- for (size_t i = inputs->fields.size() - 1; i < data->shape.size(); ++i) {
+
+ for (size_t i = inputs->fields.size() - 1 + param->iter_axes.size(); i < data->shape.size(); ++i) {
oshape.push_back(data->shape[i]);
}
reporter->Assign(types[1], TensorType(oshape, data->dtype));
@@ -3852,8 +3867,20 @@ Array<te::Tensor> AdvIndexCompute(const Attrs& attrs, const Array<te::Tensor>& i
}
Expr MakeAdvIndex(Expr inputs) {
+ auto tuple_node = inputs.as<TupleNode>();
+ auto attrs = make_object<AdvIndexAttrs>();
+ std::vector<Expr> new_inputs;
+
+ for (size_t i = 0; i < tuple_node->fields.size(); ++i) {
+ if (tuple_node->fields[i].as<ExprNode>() != nullptr) {
+ new_inputs.push_back(tuple_node->fields[i]);
+ } else {
+ attrs->iter_axes.push_back(i - 1);
+ }
+ }
+
static const Op& op = Op::Get("adv_index");
- return Call(op, {inputs}, Attrs(), {});
+ return Call(op, {Tuple(new_inputs)}, Attrs(attrs), {});
}
TVM_REGISTER_GLOBAL("relay.op._make.adv_index").set_body_typed(MakeAdvIndex);
@@ -3861,6 +3888,7 @@ TVM_REGISTER_GLOBAL("relay.op._make.adv_index").set_body_typed(MakeAdvIndex);
RELAY_REGISTER_OP("adv_index")
.describe(R"code(Numpy style advanced indexing. Index with a list of tensors.
)code" TVM_ADD_FILELINE)
+ .set_attrs_type<AdvIndexAttrs>()
.set_num_inputs(1)
.set_support_level(3)
.add_argument("inputs", "Tuple of Tensors", "Input tensor and indices.")
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]