This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new c6b766a [Relay][Op] Remove reverse attribute from reshape and
reverse_reshape operators. (#7086)
c6b766a is described below
commit c6b766a4cea4e59384c2606deecdc5321ac3d41c
Author: Josh Fromm <[email protected]>
AuthorDate: Wed Dec 30 17:29:06 2020 -0800
[Relay][Op] Remove reverse attribute from reshape and reverse_reshape
operators. (#7086)
---
include/tvm/relay/attrs/transform.h | 4 --
src/relay/op/dyn/tensor/transform.cc | 1 -
src/relay/op/tensor/transform.cc | 76 +++++++++++++++++-----
src/relay/op/tensor/transform.h | 2 +-
.../contrib/test_arm_compute_lib/test_reshape.py | 1 -
5 files changed, 59 insertions(+), 25 deletions(-)
diff --git a/include/tvm/relay/attrs/transform.h
b/include/tvm/relay/attrs/transform.h
index cbe989f..efa44e0 100644
--- a/include/tvm/relay/attrs/transform.h
+++ b/include/tvm/relay/attrs/transform.h
@@ -83,13 +83,9 @@ struct TransposeAttrs : public
tvm::AttrsNode<TransposeAttrs> {
/*! \brief Attributes used in reshape operators */
struct ReshapeAttrs : public tvm::AttrsNode<ReshapeAttrs> {
Array<Integer> newshape;
- bool reverse;
TVM_DECLARE_ATTRS(ReshapeAttrs, "relay.attrs.ReshapeAttrs") {
TVM_ATTR_FIELD(newshape).describe(
"The new shape. Should be compatible with the original shape.");
- TVM_ATTR_FIELD(reverse)
- .describe("Infer the special values from right to left if true")
- .set_default(false);
}
}; // struct ReshapeAttrs
diff --git a/src/relay/op/dyn/tensor/transform.cc
b/src/relay/op/dyn/tensor/transform.cc
index 815f24b..e4e81e3 100644
--- a/src/relay/op/dyn/tensor/transform.cc
+++ b/src/relay/op/dyn/tensor/transform.cc
@@ -90,7 +90,6 @@ Array<te::Tensor> ReshapeCompute(const Attrs& attrs, const
Array<te::Tensor>& in
Expr MakeReshape(Expr data, Expr newshape) {
auto attrs = make_object<ReshapeAttrs>();
- attrs->reverse = false;
static const Op& op = Op::Get("dyn.reshape");
return Call(op, {data, newshape}, Attrs(attrs), {});
}
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 6819ea9..19ca612 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -455,13 +455,14 @@ RELAY_REGISTER_OP("transpose")
TVM_REGISTER_NODE_TYPE(ReshapeAttrs);
TVM_REGISTER_NODE_TYPE(ReshapeLikeAttrs);
-Array<IndexExpr> infer_newshape(const Array<IndexExpr>& data_shape, const
Attrs& attrs) {
+Array<IndexExpr> InferNewShape(const Array<IndexExpr>& data_shape, const
Attrs& attrs,
+ bool reverse) {
const auto* param = attrs.as<ReshapeAttrs>();
Array<IndexExpr> oshape;
Array<IndexExpr> ishape;
Array<Integer> newshape;
- if (param->reverse) {
+ if (reverse) {
ishape.Assign(data_shape.rbegin(), data_shape.rend());
newshape.Assign(param->newshape.rbegin(), param->newshape.rend());
} else {
@@ -584,7 +585,6 @@ Array<IndexExpr> infer_newshape(const Array<IndexExpr>&
data_shape, const Attrs&
bool ReshapeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
- const auto* param = attrs.as<ReshapeAttrs>();
// types: [data, result]
ICHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
@@ -594,16 +594,12 @@ bool ReshapeRel(const Array<Type>& types, int num_inputs,
const Attrs& attrs,
return false;
}
- const auto& oshape = infer_newshape(data->shape, attrs);
+ const auto& oshape = InferNewShape(data->shape, attrs, false);
// Verify that the sum of dimensions in the output shape is the sum of
// dimensions in the input shape
Array<IndexExpr> data_shape;
- if (param->reverse) {
- data_shape.Assign(data->shape.rbegin(), data->shape.rend());
- } else {
- data_shape = data->shape;
- }
+ data_shape = data->shape;
bool found_dynamic = false;
int64_t oshape_sum = 1;
@@ -633,12 +629,58 @@ bool ReshapeRel(const Array<Type>& types, int num_inputs,
const Attrs& attrs,
<< "Input tensor shape and reshaped shape are not compatible";
}
- if (param->reverse) {
- reporter->Assign(types[1],
- TensorType(Array<IndexExpr>(oshape.rbegin(),
oshape.rend()), data->dtype));
- } else {
- reporter->Assign(types[1], TensorType(oshape, data->dtype));
+ reporter->Assign(types[1], TensorType(oshape, data->dtype));
+ return true;
+}
+
+bool ReverseReshapeRel(const Array<Type>& types, int num_inputs, const Attrs&
attrs,
+ const TypeReporter& reporter) {
+ // types: [data, result]
+ ICHECK_EQ(types.size(), 2);
+ const auto* data = types[0].as<TensorTypeNode>();
+ if (data == nullptr) {
+ ICHECK(types[0].as<IncompleteTypeNode>())
+ << "reshape: expect input type to be TensorType but get " << types[0];
+ return false;
+ }
+
+ const auto& oshape = InferNewShape(data->shape, attrs, true);
+
+ // Verify that the sum of dimensions in the output shape is the sum of
+ // dimensions in the input shape
+ Array<IndexExpr> data_shape;
+ data_shape.Assign(data->shape.rbegin(), data->shape.rend());
+
+ bool found_dynamic = false;
+ int64_t oshape_sum = 1;
+ for (auto& x : oshape) {
+ // Check if we have a dynamic shape. If we do, we can't verify if the
+  // reshape is valid. Dynamic shapes are marked by using Any, but can also
+  // occur from SizeVar's. In the case of SizeVar, the shape expression can
+  // be an AST. We can't easily check if we have an AST because of a ShapeVar
+  // or some other reason, so our check for dynamic shape is just if we can
+  // convert the shape to an integer or not.
+ if (!x->IsInstance<tvm::Integer::ContainerType>()) {
+ found_dynamic = true;
+ break;
+ }
+ oshape_sum *= Downcast<tvm::Integer>(x)->value;
}
+ int64_t data_shape_sum = 1;
+ for (auto& x : data_shape) {
+ if (!x->IsInstance<tvm::Integer::ContainerType>()) {
+ found_dynamic = true;
+ break;
+ }
+ data_shape_sum *= Downcast<tvm::Integer>(x)->value;
+ }
+ if (!found_dynamic) {
+ ICHECK_EQ(oshape_sum, data_shape_sum)
+ << "Input tensor shape and reshaped shape are not compatible";
+ }
+
+ reporter->Assign(types[1],
+ TensorType(Array<IndexExpr>(oshape.rbegin(),
oshape.rend()), data->dtype));
return true;
}
@@ -701,7 +743,7 @@ Array<te::Tensor> ReshapeCompute(const Attrs& attrs, const
Array<te::Tensor>& in
}
if (newshape_has_any) {
- newshape = infer_newshape(inputs[0]->shape, attrs);
+ newshape = InferNewShape(inputs[0]->shape, attrs, false);
}
return {topi::reshape(inputs[0], newshape)};
}
@@ -709,7 +751,6 @@ Array<te::Tensor> ReshapeCompute(const Attrs& attrs, const
Array<te::Tensor>& in
Expr MakeReshape(Expr data, Array<Integer> newshape) {
auto attrs = make_object<ReshapeAttrs>();
attrs->newshape = std::move(newshape);
- attrs->reverse = false;
static const Op& op = Op::Get("reshape");
return Call(op, {data}, Attrs(attrs), {});
}
@@ -2871,7 +2912,6 @@ RELAY_REGISTER_OP("auto_scheduler_layout_transform")
Expr MakeReverseReshape(Expr data, Array<Integer> newshape) {
auto attrs = make_object<ReshapeAttrs>();
attrs->newshape = std::move(newshape);
- attrs->reverse = true;
static const Op& op = Op::Get("contrib_reverse_reshape");
return Call(op, {data}, Attrs(attrs), {});
}
@@ -2896,7 +2936,7 @@ example below::
.set_attrs_type<ReshapeAttrs>()
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(10)
- .add_type_rel("Reshape", ReshapeRel)
+ .add_type_rel("ReverseReshape", ReverseReshapeRel)
.set_attr<FTVMCompute>("FTVMCompute", ReshapeCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);
diff --git a/src/relay/op/tensor/transform.h b/src/relay/op/tensor/transform.h
index 34aaf46..a3770ff 100644
--- a/src/relay/op/tensor/transform.h
+++ b/src/relay/op/tensor/transform.h
@@ -195,7 +195,7 @@ static inline Array<Array<Layout>> ConcatenateLayout(const
Attrs& attrs,
* \param attrs The attributes.
* \return Output shape.
*/
-Array<IndexExpr> infer_newshape(const Array<IndexExpr>& data_shape, const
Attrs& attrs);
+Array<IndexExpr> InferNewShape(const Array<IndexExpr>& data_shape, const
Attrs& attrs, bool reverse);
} // namespace relay
} // namespace tvm
diff --git a/tests/python/contrib/test_arm_compute_lib/test_reshape.py
b/tests/python/contrib/test_arm_compute_lib/test_reshape.py
index 9364c6b..9494272 100644
--- a/tests/python/contrib/test_arm_compute_lib/test_reshape.py
+++ b/tests/python/contrib/test_arm_compute_lib/test_reshape.py
@@ -50,7 +50,6 @@ def _get_expected_codegen(input_shape, output_shape, dtype):
"newshape": [[str(s) for s in output_shape]],
"shape": [[list(output_shape)]],
"dtype": [[dtype]],
- "reverse": [["0"]],
},
}