This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new c6b766a  [Relay][Op] Remove reverse attribute from reshape and reverse_reshape operators. (#7086)
c6b766a is described below

commit c6b766a4cea4e59384c2606deecdc5321ac3d41c
Author: Josh Fromm <[email protected]>
AuthorDate: Wed Dec 30 17:29:06 2020 -0800

    [Relay][Op] Remove reverse attribute from reshape and reverse_reshape operators. (#7086)
---
 include/tvm/relay/attrs/transform.h                |  4 --
 src/relay/op/dyn/tensor/transform.cc               |  1 -
 src/relay/op/tensor/transform.cc                   | 76 +++++++++++++++++-----
 src/relay/op/tensor/transform.h                    |  2 +-
 .../contrib/test_arm_compute_lib/test_reshape.py   |  1 -
 5 files changed, 59 insertions(+), 25 deletions(-)

diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h
index cbe989f..efa44e0 100644
--- a/include/tvm/relay/attrs/transform.h
+++ b/include/tvm/relay/attrs/transform.h
@@ -83,13 +83,9 @@ struct TransposeAttrs : public tvm::AttrsNode<TransposeAttrs> {
 /*! \brief Attributes used in reshape operators */
 struct ReshapeAttrs : public tvm::AttrsNode<ReshapeAttrs> {
   Array<Integer> newshape;
-  bool reverse;
   TVM_DECLARE_ATTRS(ReshapeAttrs, "relay.attrs.ReshapeAttrs") {
     TVM_ATTR_FIELD(newshape).describe(
         "The new shape. Should be compatible with the original shape.");
-    TVM_ATTR_FIELD(reverse)
-        .describe("Infer the special values from right to left if true")
-        .set_default(false);
   }
 };  // struct ReshapeAttrs
 
diff --git a/src/relay/op/dyn/tensor/transform.cc b/src/relay/op/dyn/tensor/transform.cc
index 815f24b..e4e81e3 100644
--- a/src/relay/op/dyn/tensor/transform.cc
+++ b/src/relay/op/dyn/tensor/transform.cc
@@ -90,7 +90,6 @@ Array<te::Tensor> ReshapeCompute(const Attrs& attrs, const Array<te::Tensor>& in
 
 Expr MakeReshape(Expr data, Expr newshape) {
   auto attrs = make_object<ReshapeAttrs>();
-  attrs->reverse = false;
   static const Op& op = Op::Get("dyn.reshape");
   return Call(op, {data, newshape}, Attrs(attrs), {});
 }
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 6819ea9..19ca612 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -455,13 +455,14 @@ RELAY_REGISTER_OP("transpose")
 TVM_REGISTER_NODE_TYPE(ReshapeAttrs);
 TVM_REGISTER_NODE_TYPE(ReshapeLikeAttrs);
 
-Array<IndexExpr> infer_newshape(const Array<IndexExpr>& data_shape, const 
Attrs& attrs) {
+Array<IndexExpr> InferNewShape(const Array<IndexExpr>& data_shape, const 
Attrs& attrs,
+                               bool reverse) {
   const auto* param = attrs.as<ReshapeAttrs>();
   Array<IndexExpr> oshape;
   Array<IndexExpr> ishape;
   Array<Integer> newshape;
 
-  if (param->reverse) {
+  if (reverse) {
     ishape.Assign(data_shape.rbegin(), data_shape.rend());
     newshape.Assign(param->newshape.rbegin(), param->newshape.rend());
   } else {
@@ -584,7 +585,6 @@ Array<IndexExpr> infer_newshape(const Array<IndexExpr>& data_shape, const Attrs&
 
 bool ReshapeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                 const TypeReporter& reporter) {
-  const auto* param = attrs.as<ReshapeAttrs>();
   // types: [data, result]
   ICHECK_EQ(types.size(), 2);
   const auto* data = types[0].as<TensorTypeNode>();
@@ -594,16 +594,12 @@ bool ReshapeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
     return false;
   }
 
-  const auto& oshape = infer_newshape(data->shape, attrs);
+  const auto& oshape = InferNewShape(data->shape, attrs, false);
 
   // Verify that the sum of dimensions in the output shape is the sum of
   // dimensions in the input shape
   Array<IndexExpr> data_shape;
-  if (param->reverse) {
-    data_shape.Assign(data->shape.rbegin(), data->shape.rend());
-  } else {
-    data_shape = data->shape;
-  }
+  data_shape = data->shape;
 
   bool found_dynamic = false;
   int64_t oshape_sum = 1;
@@ -633,12 +629,58 @@ bool ReshapeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
         << "Input tensor shape and reshaped shape are not compatible";
   }
 
-  if (param->reverse) {
-    reporter->Assign(types[1],
-                     TensorType(Array<IndexExpr>(oshape.rbegin(), 
oshape.rend()), data->dtype));
-  } else {
-    reporter->Assign(types[1], TensorType(oshape, data->dtype));
+  reporter->Assign(types[1], TensorType(oshape, data->dtype));
+  return true;
+}
+
+bool ReverseReshapeRel(const Array<Type>& types, int num_inputs, const Attrs& 
attrs,
+                       const TypeReporter& reporter) {
+  // types: [data, result]
+  ICHECK_EQ(types.size(), 2);
+  const auto* data = types[0].as<TensorTypeNode>();
+  if (data == nullptr) {
+    ICHECK(types[0].as<IncompleteTypeNode>())
+        << "reshape: expect input type to be TensorType but get " << types[0];
+    return false;
+  }
+
+  const auto& oshape = InferNewShape(data->shape, attrs, true);
+
+  // Verify that the sum of dimensions in the output shape is the sum of
+  // dimensions in the input shape
+  Array<IndexExpr> data_shape;
+  data_shape.Assign(data->shape.rbegin(), data->shape.rend());
+
+  bool found_dynamic = false;
+  int64_t oshape_sum = 1;
+  for (auto& x : oshape) {
+    // Check if we have a dynamic shape. If we do, we can't verify if the
+    // reshape is valid. Dynamic shapes are marker by using Any, but can also
+    // occur from SizeVar's. In the case of SizeVar, the shape expression can
+    // be an AST. We can't easily check if we have an AST because of a ShapeVar
+    // or some other reason, so our check for dynamic shape is just if we can
+    // convert the shape to in integer or not.
+    if (!x->IsInstance<tvm::Integer::ContainerType>()) {
+      found_dynamic = true;
+      break;
+    }
+    oshape_sum *= Downcast<tvm::Integer>(x)->value;
   }
+  int64_t data_shape_sum = 1;
+  for (auto& x : data_shape) {
+    if (!x->IsInstance<tvm::Integer::ContainerType>()) {
+      found_dynamic = true;
+      break;
+    }
+    data_shape_sum *= Downcast<tvm::Integer>(x)->value;
+  }
+  if (!found_dynamic) {
+    ICHECK_EQ(oshape_sum, data_shape_sum)
+        << "Input tensor shape and reshaped shape are not compatible";
+  }
+
+  reporter->Assign(types[1],
+                   TensorType(Array<IndexExpr>(oshape.rbegin(), 
oshape.rend()), data->dtype));
   return true;
 }
 
@@ -701,7 +743,7 @@ Array<te::Tensor> ReshapeCompute(const Attrs& attrs, const Array<te::Tensor>& in
   }
 
   if (newshape_has_any) {
-    newshape = infer_newshape(inputs[0]->shape, attrs);
+    newshape = InferNewShape(inputs[0]->shape, attrs, false);
   }
   return {topi::reshape(inputs[0], newshape)};
 }
@@ -709,7 +751,6 @@ Array<te::Tensor> ReshapeCompute(const Attrs& attrs, const Array<te::Tensor>& in
 Expr MakeReshape(Expr data, Array<Integer> newshape) {
   auto attrs = make_object<ReshapeAttrs>();
   attrs->newshape = std::move(newshape);
-  attrs->reverse = false;
   static const Op& op = Op::Get("reshape");
   return Call(op, {data}, Attrs(attrs), {});
 }
@@ -2871,7 +2912,6 @@ RELAY_REGISTER_OP("auto_scheduler_layout_transform")
 Expr MakeReverseReshape(Expr data, Array<Integer> newshape) {
   auto attrs = make_object<ReshapeAttrs>();
   attrs->newshape = std::move(newshape);
-  attrs->reverse = true;
   static const Op& op = Op::Get("contrib_reverse_reshape");
   return Call(op, {data}, Attrs(attrs), {});
 }
@@ -2896,7 +2936,7 @@ example below::
     .set_attrs_type<ReshapeAttrs>()
     .add_argument("data", "Tensor", "The input tensor.")
     .set_support_level(10)
-    .add_type_rel("Reshape", ReshapeRel)
+    .add_type_rel("ReverseReshape", ReverseReshapeRel)
     .set_attr<FTVMCompute>("FTVMCompute", ReshapeCompute)
     .set_attr<TOpPattern>("TOpPattern", kInjective);
 
diff --git a/src/relay/op/tensor/transform.h b/src/relay/op/tensor/transform.h
index 34aaf46..a3770ff 100644
--- a/src/relay/op/tensor/transform.h
+++ b/src/relay/op/tensor/transform.h
@@ -195,7 +195,7 @@ static inline Array<Array<Layout>> ConcatenateLayout(const Attrs& attrs,
  * \param attrs The attributes.
  * \return Output shape.
  */
-Array<IndexExpr> infer_newshape(const Array<IndexExpr>& data_shape, const 
Attrs& attrs);
+Array<IndexExpr> InferNewShape(const Array<IndexExpr>& data_shape, const 
Attrs& attrs);
 
 }  // namespace relay
 }  // namespace tvm
diff --git a/tests/python/contrib/test_arm_compute_lib/test_reshape.py b/tests/python/contrib/test_arm_compute_lib/test_reshape.py
index 9364c6b..9494272 100644
--- a/tests/python/contrib/test_arm_compute_lib/test_reshape.py
+++ b/tests/python/contrib/test_arm_compute_lib/test_reshape.py
@@ -50,7 +50,6 @@ def _get_expected_codegen(input_shape, output_shape, dtype):
             "newshape": [[str(s) for s in output_shape]],
             "shape": [[list(output_shape)]],
             "dtype": [[dtype]],
-            "reverse": [["0"]],
         },
     }
 

Reply via email to