This is an automated email from the ASF dual-hosted git repository.

kevinthesun pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new aa42f97  [RELAY][TF] Support symbolic newshape for Reshape (#5429)
aa42f97 is described below

commit aa42f978b89cc643b40c4f2bfdf2fc42dc1aeadd
Author: lixiaoquan <[email protected]>
AuthorDate: Wed May 13 14:37:33 2020 +0800

    [RELAY][TF] Support symbolic newshape for Reshape (#5429)
    
    * [RELAY][TF] Support symbolic newshape for Reshape
    
    * Only need to pass data
    
    * Use MakeReshape() in Reshape()
    
    * Change newshape to Expr
    
    * Create a template for Array<T>
    
    * Fuse reshape when newshape is constant
    
    * Make newshape Optional
    
    * Use bool() of Optional
    
    Co-authored-by: Li Xiaoquan <[email protected]>
---
 include/tvm/relay/attrs/transform.h              |   2 +-
 python/tvm/relay/_parser.py                      |   5 +-
 python/tvm/relay/frontend/tensorflow.py          |  13 +--
 python/tvm/relay/op/_tensor_grad.py              |   2 +-
 python/tvm/relay/op/_transform.py                |  81 +++++++++++++-
 python/tvm/relay/op/transform.py                 |   8 +-
 src/relay/analysis/util.cc                       |  40 +++++++
 src/relay/backend/compile_engine.cc              |  24 +----
 src/relay/op/tensor/transform.cc                 | 132 ++++++++++++++---------
 src/relay/op/tensor/transform.h                  |   2 +-
 src/relay/transforms/fold_scale_axis.cc          |   3 +-
 src/relay/transforms/fuse_ops.cc                 |   9 +-
 src/relay/transforms/pass_util.h                 |  14 +++
 src/relay/transforms/pattern_util.h              |  38 ++++++-
 tests/cpp/relay_build_module_test.cc             |   1 +
 tests/python/frontend/tensorflow/test_forward.py |  15 +++
 tests/python/relay/test_any.py                   |  29 +++--
 vta/python/vta/top/graphpack.py                  |   4 +-
 18 files changed, 312 insertions(+), 110 deletions(-)

diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h
index 84dda6f..c0e2272 100644
--- a/include/tvm/relay/attrs/transform.h
+++ b/include/tvm/relay/attrs/transform.h
@@ -82,7 +82,7 @@ struct TransposeAttrs : public tvm::AttrsNode<TransposeAttrs> {
 
 /*! \brief Attributes used in reshape operators */
 struct ReshapeAttrs : public tvm::AttrsNode<ReshapeAttrs> {
-  Array<Integer> newshape;
+  Optional<Array<Integer>> newshape;
   bool reverse;
   TVM_DECLARE_ATTRS(ReshapeAttrs, "relay.attrs.ReshapeAttrs") {
     TVM_ATTR_FIELD(newshape).describe(
diff --git a/python/tvm/relay/_parser.py b/python/tvm/relay/_parser.py
index 7731efe..1d97b55 100644
--- a/python/tvm/relay/_parser.py
+++ b/python/tvm/relay/_parser.py
@@ -114,7 +114,10 @@ class FuncOp(OpWrapper):
     def __call__(self, args, attrs, type_args):
         if attrs is None:
             attrs = {}
-        x = self.operator(*args, **{k: self.convert(v) for k, v in attrs.items()})
+        if self.operator is op.reshape:
+            x = self.operator(*args)
+        else:
+            x = self.operator(*args, **{k: self.convert(v) for k, v in attrs.items()})
         if isinstance(x, expr.TupleWrapper):
             x = x.astuple()
         return x
diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py
index 913a016..ab9e9e6 100644
--- a/python/tvm/relay/frontend/tensorflow.py
+++ b/python/tvm/relay/frontend/tensorflow.py
@@ -1155,14 +1155,11 @@ def _reshape():
                 shape_arg = tuple(params_new.asnumpy().astype('int64').flatten())
             except Exception:
                 # Deal with symbolic shape case.
-                # Currently only shape_of can be the direct ancestor.
-                if not isinstance(pop_node, tvm.relay.expr.Call) or \
-                        "shape_of" not in str(pop_node.op):
-                    raise RuntimeError("If shape operator is used in reshape to "
-                                       "express reshape_like, shape_of must be "
-                                       "the direct ancestor of reshape when input "
-                                       "shape is symbolic.")
-                return _op.reshape_like(inputs[0], pop_node.args[0])
+                if isinstance(pop_node, _expr.Call) and \
+                        "shape_of" in str(pop_node.op):
+                    # shape_of is the direct ancestor.
+                    return _op.reshape_like(inputs[0], pop_node.args[0])
+                shape_arg = pop_node
         return AttrCvt(
             op_name="reshape",
             extras={'newshape': shape_arg},
diff --git a/python/tvm/relay/op/_tensor_grad.py b/python/tvm/relay/op/_tensor_grad.py
index c034bcc..8be3358 100644
--- a/python/tvm/relay/op/_tensor_grad.py
+++ b/python/tvm/relay/op/_tensor_grad.py
@@ -479,7 +479,7 @@ def dense_grad(orig, grad):
 @register_gradient("reshape")
 def reshape_grad(orig, grad):
     """Gradient of reshape"""
-    return [reshape_like(grad, orig.args[0])]
+    return [reshape_like(grad, orig.args[0]), orig.args[1]]
 
 
 @register_gradient("cast")
diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py
index ee23fce..43d8d62 100644
--- a/python/tvm/relay/op/_transform.py
+++ b/python/tvm/relay/op/_transform.py
@@ -123,7 +123,7 @@ def concatenate_shape_func(attrs, inputs, _):
     return [_concatenate_shape_func(inputs, convert(axis))]
 
 @script
-def _reshape_shape_func(data_shape, newshape, ndim):
+def _reshape_shape_func_input_shape(data_shape, newshape, ndim):
     out = output_tensor((ndim,), "int64")
     src_idx = 0
     dst_idx = 0
@@ -189,10 +189,83 @@ def _reshape_shape_func(data_shape, newshape, ndim):
             out[infer_idx] = old_size // new_size
     return out
 
-@_reg.register_shape_func("reshape", False)
+@script
+def _reshape_shape_func_input_data(data, newshape, ndim):
+    out = output_tensor((ndim,), "int64")
+    data_shape = allocate((len(data.shape),), "int64")
+    for x in const_range(len(data.shape)):
+        data_shape[x] = int64(data.shape[x])
+    src_idx = 0
+    dst_idx = 0
+    infer_idx = -1
+    copy = False
+    skip = 0
+    for i in const_range(len(newshape)):
+        if skip > 0:
+            skip -= 1
+        elif newshape[i] > 0:
+            out[dst_idx] = int64(newshape[i])
+            src_idx += 1
+            dst_idx += 1
+        elif newshape[i] == 0:
+            out[dst_idx] = data_shape[src_idx]
+            src_idx += 1
+            dst_idx += 1
+        elif newshape[i] == -1:
+            assert infer_idx < 0, "One and only one dim can be inferred"
+            out[dst_idx] = int64(1)
+            infer_idx = i
+            dst_idx += 1
+        elif newshape[i] == -2:
+            copy = True
+        elif newshape[i] == -3:
+            assert data_shape.shape[0] - src_idx > 1, \
+                "Not enough dims in input shape for -3"
+            out[dst_idx] = data_shape[src_idx] * data_shape[src_idx+1]
+            src_idx += 2
+            dst_idx += 1
+        elif newshape[i] == -4:
+            assert len(newshape) - i > 2, "Not enough dims in new shape for -4"
+            if newshape[i+1] == -1:
+                assert newshape[i+2] != -1, "Split dims cannot both be -1."
+                out[dst_idx] = data_shape[src_idx] // int64(newshape[i+2])
+                out[dst_idx+1] = int64(newshape[i+2])
+            else:
+                out[dst_idx] = int64(newshape[i+1])
+                if newshape[i+2] == -1:
+                    out[dst_idx+1] = data_shape[src_idx] // int64(newshape[i+1])
+                else:
+                    out[dst_idx+1] = int64(newshape[i+2])
+            assert data_shape[src_idx] == out[dst_idx] * out[dst_idx+1],\
+                "Product of split dims doesn't match to input dim"
+            src_idx += 1
+            dst_idx += 2
+            skip = 2
+        else:
+            assert False, "Invalid special values in new shape"
+    if len(data_shape.shape) > 0:
+        # if data is not constant, we can then handle -1 and -2
+        if copy:
+            for i in range(src_idx, data_shape.shape[0]):
+                out[dst_idx] = data_shape[i]
+                dst_idx += 1
+        if infer_idx >= 0:
+            old_size = int64(1)
+            for i in const_range(data_shape.shape[0]):
+                old_size *= data_shape[i]
+            new_size = int64(1)
+            for i in const_range(out.shape[0]):
+                new_size *= out[i]
+            out[infer_idx] = old_size // new_size
+    return out
+
+@_reg.register_shape_func("reshape", True)
 def reshape_shape_func(attrs, inputs, out_ndims):
-    newshape = get_const_tuple(attrs.newshape)
-    return [_reshape_shape_func(inputs[0], convert(newshape), out_ndims[0])]
+    if attrs.newshape is None:
+        return [_reshape_shape_func_input_data(*inputs, out_ndims[0])]
+    return [_reshape_shape_func_input_shape(inputs[0],
+                                            convert(attrs.newshape),
+                                            out_ndims[0])]
 
 @script
 def _take_no_axis_shape_func(indices_shape, out_ndim):
diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py
index 4e9bb45..2d9e4ba 100644
--- a/python/tvm/relay/op/transform.py
+++ b/python/tvm/relay/op/transform.py
@@ -201,7 +201,7 @@ def reshape(data, newshape):
     data : relay.Expr
         The input data to the operator.
 
-    newshape : Union[int, Tuple[int], List[int]]
+    newshape : Union[int, Tuple[int], List[int]] or relay.Expr
         The new shape. Should be compatible with the original shape.
 
     Returns
@@ -210,8 +210,10 @@ def reshape(data, newshape):
         The reshaped result.
     """
     if isinstance(newshape, int):
-        newshape = [newshape]
-    return _make.reshape(data, list(newshape))
+        newshape = const([newshape])
+    if isinstance(newshape, (tuple, list)):
+        newshape = const(list(newshape))
+    return _make.reshape(data, newshape)
 
 def argwhere(condition):
     """Find the indices of elements of a tensor that are
diff --git a/src/relay/analysis/util.cc b/src/relay/analysis/util.cc
index 1d84016..af23836 100644
--- a/src/relay/analysis/util.cc
+++ b/src/relay/analysis/util.cc
@@ -27,6 +27,7 @@
 #include <tvm/relay/analysis.h>
 #include <tvm/relay/expr_functor.h>
 #include <tvm/relay/op.h>
+#include <tvm/relay/op_attr_types.h>
 #include <tvm/relay/pattern_functor.h>
 
 #include "../transforms/pass_util.h"
@@ -414,5 +415,44 @@ Expr TypeSubst(const Expr& expr, const tvm::Map<TypeVar, Type>& subst_map) {
   return ret;
 }
 
+struct IsDynamicVisitor : public TypeVisitor {
+  bool is_dyn{false};
+  void VisitType_(const TensorTypeNode* tt) {
+    for (auto dim : tt->shape) {
+      if (dim.as<Any>()) {
+        is_dyn = true;
+        break;
+      }
+    }
+  }
+};
+
+bool IsDynamic(const Type& ty) {
+  IsDynamicVisitor v;
+  v.VisitType(ty);
+  return v.is_dyn;
+}
+
+TVM_REGISTER_GLOBAL("relay.ir.IsDynamic").set_body_typed(IsDynamic);
+
+bool IsDataDependant(const CallNode* call) {
+  static auto tshape_data_dependant = Op::GetAttr<TShapeDataDependant>("TShapeDataDependant");
+  Op op = Downcast<Op>(call->op);
+
+  if (!tshape_data_dependant.count(op)) {
+    return false;
+  }
+
+  if (op->name == "reshape") {
+    if (const auto* attrs = call->attrs.as<ReshapeAttrs>()) {
+      if (attrs->newshape) {
+        // If newshape attribute exists, it isn't data dependant.
+        return false;
+      }
+    }
+  }
+
+  return tshape_data_dependant[op];
+}
 }  // namespace relay
 }  // namespace tvm
diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc
index 3851de1..12a5add 100644
--- a/src/relay/backend/compile_engine.cc
+++ b/src/relay/backend/compile_engine.cc
@@ -45,6 +45,7 @@
 #include <utility>
 #include <vector>
 
+#include "../transforms/pass_util.h"
 #include "utils.h"
 
 namespace tvm {
@@ -70,27 +71,6 @@ CCacheKey::CCacheKey(Function source_func, Target target) {
   data_ = std::move(n);
 }
 
-struct IsDynamicVisitor : public TypeVisitor {
-  bool is_dyn{false};
-  void VisitType_(const TensorTypeNode* tt) {
-    for (auto dim : tt->shape) {
-      if (dim.as<Any>()) {
-        is_dyn = true;
-        break;
-      }
-    }
-  }
-};
-
-bool IsDynamic(const Type& ty) {
-  IsDynamicVisitor v;
-  v.VisitType(ty);
-  return v.is_dyn;
-}
-
-// TODO(@jroesch): MOVE ME
-TVM_REGISTER_GLOBAL("relay.ir.IsDynamic").set_body_typed(IsDynamic);
-
 Array<IndexExpr> GetShape(const Array<IndexExpr>& shape) {
   // for now, we always use int32 shape when possible
   // even if the result of shape inference becomes int64.
@@ -485,7 +465,7 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator<Array<te::Tensor>>
     CHECK_GT(tshape_data_dependant.count(op), 0)
         << "Internal error, cannot find TShapeDataDependant for " << op->name;
 
-    data_dependants_.push_back(tshape_data_dependant[op]);
+    data_dependants_.push_back(IsDataDependant(call_node));
     // Visit all inputs
     Array<te::Tensor> inputs;
     int count_tuple = 0;
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 15761f6..8b58946 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -447,10 +447,54 @@ RELAY_REGISTER_OP("transpose")
 /* relay.reshape */
 TVM_REGISTER_NODE_TYPE(ReshapeAttrs);
 
+double ToScalar(const runtime::NDArray& array, int i = 0) {
+  if (array->dtype.code == kDLInt) {
+    if (array->dtype.bits == 8) {
+      return reinterpret_cast<int8_t*>(array->data)[i];
+    } else if (array->dtype.bits == 16) {
+      return reinterpret_cast<int16_t*>(array->data)[i];
+    } else if (array->dtype.bits == 32) {
+      return reinterpret_cast<int32_t*>(array->data)[i];
+    } else if (array->dtype.bits == 64) {
+      return reinterpret_cast<int64_t*>(array->data)[i];
+    }
+  } else if (array->dtype.code == kDLUInt) {
+    if (array->dtype.bits == 8) {
+      return reinterpret_cast<uint8_t*>(array->data)[i];
+    } else if (array->dtype.bits == 16) {
+      return reinterpret_cast<uint16_t*>(array->data)[i];
+    } else if (array->dtype.bits == 32) {
+      return reinterpret_cast<uint32_t*>(array->data)[i];
+    } else if (array->dtype.bits == 64) {
+      return reinterpret_cast<uint64_t*>(array->data)[i];
+    }
+  } else if (array->dtype.code == kDLFloat) {
+#if (__ARM_FP16_FORMAT_IEEE == 1)
+    if (array->dtype.bits == 16) {
+      return reinterpret_cast<__fp16*>(array->data)[i];
+    }
+#endif
+    if (array->dtype.bits == 32) {
+      return reinterpret_cast<float*>(array->data)[i];
+    } else if (array->dtype.bits == 64) {
+      return reinterpret_cast<double*>(array->data)[i];
+    }
+  }
+  LOG(FATAL) << "Unknown data type: " << tvm::runtime::DLDataType2String(array->dtype);
+  // make compiler happy
+  return -std::numeric_limits<double>::infinity();
+}
+
 bool ReshapeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                 const TypeReporter& reporter) {
-  // types: [data, result]
-  CHECK_EQ(types.size(), 2);
+  const auto* param = attrs.as<ReshapeAttrs>();
+  if (param->reverse) {
+    // types: [data, result]
+    CHECK_EQ(types.size(), 2);
+  } else {
+    // types: [data, newshape, result]
+    CHECK_EQ(types.size(), 3);
+  }
   const auto* data = types[0].as<TensorTypeNode>();
   if (data == nullptr) {
     CHECK(types[0].as<IncompleteTypeNode>())
@@ -458,17 +502,31 @@ bool ReshapeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
     return false;
   }
 
-  const auto* param = attrs.as<ReshapeAttrs>();
+  Array<IndexExpr> oshape;
   Array<IndexExpr> data_shape;
   Array<Integer> newshape;
-  if (param->reverse) {
-    data_shape.assign(data->shape.rbegin(), data->shape.rend());
-    newshape.assign(param->newshape.rbegin(), param->newshape.rend());
+
+  if (param->newshape) {
+    auto temp = param->newshape.value();
+    if (param->reverse) {
+      data_shape.assign(data->shape.rbegin(), data->shape.rend());
+      newshape.assign(temp.rbegin(), temp.rend());
+    } else {
+      data_shape = data->shape;
+      newshape = temp;
+    }
   } else {
-    data_shape = data->shape;
-    newshape = param->newshape;
+    const auto* newshape = types[1].as<TensorTypeNode>();
+
+    // Doesn't support dynamic output rank
+    for (int i = 0; i < newshape->shape[0].as<IntImmNode>()->value; i++) {
+      oshape.push_back(Any::make());
+    }
+
+    reporter->Assign(types[2], TensorType(oshape, data->dtype));
+    return true;
   }
-  Array<IndexExpr> oshape;
+
   std::unordered_set<size_t> used_input_dims;
   std::unordered_set<size_t> used_output_dims;
   size_t src_idx = 0;
@@ -581,7 +639,7 @@ bool ReshapeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
     reporter->Assign(types[1],
                      TensorType(Array<IndexExpr>(oshape.rbegin(), oshape.rend()), data->dtype));
   } else {
-    reporter->Assign(types[1], TensorType(oshape, data->dtype));
+    reporter->Assign(types[2], TensorType(oshape, data->dtype));
   }
   return true;
 }
@@ -601,12 +659,19 @@ Array<te::Tensor> ReshapeCompute(const Attrs& attrs, const Array<te::Tensor>& in
   return {topi::reshape(inputs[0], newshape)};
 }
 
-Expr MakeReshape(Expr data, Array<Integer> newshape) {
+Expr MakeReshape(Expr data, Expr newshape) {
   auto attrs = make_object<ReshapeAttrs>();
-  attrs->newshape = std::move(newshape);
+  if (const ConstantNode* c = newshape.as<ConstantNode>()) {
+    CHECK_EQ(c->data->ndim, 1);
+    Array<Integer> newshape;
+    for (int i = 0; i < c->data->shape[0]; i++) {
+      newshape.push_back(Integer(static_cast<int>(ToScalar(c->data, i))));
+    }
+    attrs->newshape = newshape;
+  }
   attrs->reverse = false;
   static const Op& op = Op::Get("reshape");
-  return Call(op, {data}, Attrs(attrs), {});
+  return Call(op, {data, newshape}, Attrs(attrs), {});
 }
 
 TVM_REGISTER_GLOBAL("relay.op._make.reshape").set_body_typed(MakeReshape);
@@ -662,9 +727,10 @@ Example::
 - data.shape = (2,3,4), newshape = (2,-4,-1,3,-2), result.shape = (2,1,3,4)
 
 )code" TVM_ADD_FILELINE)
-    .set_num_inputs(1)
+    .set_num_inputs(2)
     .set_attrs_type<ReshapeAttrs>()
     .add_argument("data", "Tensor", "The input tensor.")
+    .add_argument("newshape", "Tensor", "The shape of output tensor.")
     .set_support_level(3)
     .add_type_rel("Reshape", ReshapeRel)
     .set_attr<FTVMCompute>("FTVMCompute", ReshapeCompute)
@@ -1005,44 +1071,6 @@ and type as the input array.
 // arange operator
 TVM_REGISTER_NODE_TYPE(ArangeAttrs);
 
-double ToScalar(const runtime::NDArray& array) {
-  if (array->dtype.code == kDLInt) {
-    if (array->dtype.bits == 8) {
-      return reinterpret_cast<int8_t*>(array->data)[0];
-    } else if (array->dtype.bits == 16) {
-      return reinterpret_cast<int16_t*>(array->data)[0];
-    } else if (array->dtype.bits == 32) {
-      return reinterpret_cast<int32_t*>(array->data)[0];
-    } else if (array->dtype.bits == 64) {
-      return reinterpret_cast<int64_t*>(array->data)[0];
-    }
-  } else if (array->dtype.code == kDLUInt) {
-    if (array->dtype.bits == 8) {
-      return reinterpret_cast<uint8_t*>(array->data)[0];
-    } else if (array->dtype.bits == 16) {
-      return reinterpret_cast<uint16_t*>(array->data)[0];
-    } else if (array->dtype.bits == 32) {
-      return reinterpret_cast<uint32_t*>(array->data)[0];
-    } else if (array->dtype.bits == 64) {
-      return reinterpret_cast<uint64_t*>(array->data)[0];
-    }
-  } else if (array->dtype.code == kDLFloat) {
-#if (__ARM_FP16_FORMAT_IEEE == 1)
-    if (array->dtype.bits == 16) {
-      return reinterpret_cast<__fp16*>(array->data)[0];
-    }
-#endif
-    if (array->dtype.bits == 32) {
-      return reinterpret_cast<float*>(array->data)[0];
-    } else if (array->dtype.bits == 64) {
-      return reinterpret_cast<double*>(array->data)[0];
-    }
-  }
-  LOG(FATAL) << "Unknown data type: " << tvm::runtime::DLDataType2String(array->dtype);
-  // make compiler happy
-  return -std::numeric_limits<double>::infinity();
-}
-
 bool ArangeRel(const Array<Type>& types, int num_inputs, const Attrs& raw_attrs,
                const TypeReporter& reporter) {
   CHECK_EQ(types.size(), 4);
diff --git a/src/relay/op/tensor/transform.h b/src/relay/op/tensor/transform.h
index 1d1f9c0..bc35ed6 100644
--- a/src/relay/op/tensor/transform.h
+++ b/src/relay/op/tensor/transform.h
@@ -38,7 +38,7 @@
 namespace tvm {
 namespace relay {
 
-extern Expr MakeReshape(Expr data, Array<Integer> newshape);
+extern Expr MakeReshape(Expr data, Expr newshape);
 
 template <typename AttrType>
 bool ConcatenateRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
diff --git a/src/relay/transforms/fold_scale_axis.cc b/src/relay/transforms/fold_scale_axis.cc
index 4c8025a..4083d08 100644
--- a/src/relay/transforms/fold_scale_axis.cc
+++ b/src/relay/transforms/fold_scale_axis.cc
@@ -329,7 +329,8 @@ static Expr ReshapeToMatchAxis(Expr scale, const Array<PrimExpr>& shape,
       arr.push_back(1);
     }
   }
-  return MakeReshape(scale, std::move(arr));
+  return MakeReshape(
+      scale, MakeConstantTensor(DataType::Int(32), {static_cast<int64_t>(arr.size())}, arr));
 }
 
 // if only one axis, use expand dim. Else, use reshape
diff --git a/src/relay/transforms/fuse_ops.cc b/src/relay/transforms/fuse_ops.cc
index 0ca8d7c..054244d 100644
--- a/src/relay/transforms/fuse_ops.cc
+++ b/src/relay/transforms/fuse_ops.cc
@@ -31,6 +31,7 @@
 #include <tvm/tir/op.h>
 
 #include "../../support/arena.h"
+#include "pass_util.h"
 #include "pattern_util.h"
 
 namespace tvm {
@@ -237,7 +238,13 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
     // need to call Update, as it may be an arbitrary expression.
     OpPatternKind op_pattern = kOpaque;
     if (const OpNode* opnode = call->op.as<OpNode>()) {
-      op_pattern = static_cast<OpPatternKind>(fpattern[GetRef<Op>(opnode)]);
+      auto op = GetRef<Op>(opnode);
+      if (IsDynamic(call->checked_type()) && IsDataDependant(call)) {
+        // output of a shape func can't be fed to a data-dependent shape func
+        op_pattern = kOpaque;
+      } else {
+        op_pattern = static_cast<OpPatternKind>(fpattern[op]);
+      }
     } else {
       this->Update(call->op, node, kOpaque);
     }
diff --git a/src/relay/transforms/pass_util.h b/src/relay/transforms/pass_util.h
index 32ee09f..cbdd4b4 100644
--- a/src/relay/transforms/pass_util.h
+++ b/src/relay/transforms/pass_util.h
@@ -77,6 +77,20 @@ Type TypeSubst(const Type& type, const tvm::Map<TypeVar, Type>& subst_map);
 Expr TypeSubst(const Expr& expr, const tvm::Map<TypeVar, Type>& subst_map);
 
 /*!
+ * \brief Check if type is dynamic.
+ * \param ty The type to be checked.
+ * \return Whether the type is dynamic.
+ */
+bool IsDynamic(const Type& ty);
+
+/*!
+ * \brief Check if call is data dependant.
+ * \param call The call to be checked.
+ * \return Whether the call is data dependant.
+ */
+bool IsDataDependant(const CallNode* call);
+
+/*!
  * \brief Make arbitrary transformation preserve the out most function.
  * \param func The transformation.
  * \param e The expression
diff --git a/src/relay/transforms/pattern_util.h b/src/relay/transforms/pattern_util.h
index edb6a65..0a51404 100644
--- a/src/relay/transforms/pattern_util.h
+++ b/src/relay/transforms/pattern_util.h
@@ -283,6 +283,34 @@ static inline Constant MakeConstantTensor(DataType dtype, std::vector<int64_t> s
 }
 
 /*!
+ * \brief Create a Constant with a tensor.
+ *
+ * \param dtype The data type.
+ * \param value The array of the tensor values.
+ * \return A Constant.
+ */
+template <typename T>
+static inline Constant MakeConstantTensor(DataType dtype, std::vector<int64_t> shape,
+                                          Array<T> value) {
+  runtime::NDArray arr = runtime::NDArray::Empty(shape, dtype, {kDLCPU, 0});
+  TVM_DTYPE_DISPATCH(dtype, DType, {
+    for (size_t i = 0; i < value.size(); i++) {
+      if (dtype == DataType::Float(16)) {
+        // convert to float16
+        // storage is uint16_t
+        // Similar handling as that in MakeConstantScalar
+        *(static_cast<DType*>(arr->data) + i) =
+            __truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 10>(
+                static_cast<float>(value[i]));
+      } else {
+        *(static_cast<DType*>(arr->data) + i) = value[i];
+      }
+    }
+  })
+  return Constant(arr);
+}
+
+/*!
  * \brief Check if two expressions are equal scalars.
  * \param a The expression to be checked.
  * \param b The expression to be checked
@@ -519,12 +547,12 @@ static inline Expr Sum(Expr data, Array<Integer> axis, bool keepdims, bool exclu
   return Call(op, {data}, Attrs(attrs), {});
 }
 
+Expr MakeReshape(Expr data, Expr newshape);
+
 static inline Expr Reshape(Expr data, Array<Integer> newshape) {
-  auto attrs = make_object<ReshapeAttrs>();
-  attrs->newshape = std::move(newshape);
-  attrs->reverse = false;
-  static const Op& op = Op::Get("reshape");
-  return Call(op, {data}, Attrs(attrs), {});
+  auto newshape_tensor =
+      MakeConstantTensor(DataType::Int(32), {static_cast<int64_t>(newshape.size())}, newshape);
+  return MakeReshape(data, newshape_tensor);
 }
 
 static inline Expr AvgPool2D(Expr data, Array<IndexExpr> pool_size, Array<IndexExpr> strides,
diff --git a/tests/cpp/relay_build_module_test.cc b/tests/cpp/relay_build_module_test.cc
index 33f6061..d7ce0c0 100644
--- a/tests/cpp/relay_build_module_test.cc
+++ b/tests/cpp/relay_build_module_test.cc
@@ -105,6 +105,7 @@ TEST(Relay, BuildModule) {
   }
   auto fgeneric = GenericFunc::Get("test.strategy_generic").set_default(*fs);
   (*reg)("add", "FTVMStrategy", fgeneric, 10);
+  (*reg)("add", "TShapeDataDependant", false, 10);
   // build
   auto pfb = tvm::runtime::Registry::Get("relay.build_module._BuildModule");
   tvm::runtime::Module build_mod = (*pfb)();
diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py
index cd6c454..c3313b6 100644
--- a/tests/python/frontend/tensorflow/test_forward.py
+++ b/tests/python/frontend/tensorflow/test_forward.py
@@ -747,6 +747,17 @@ def _test_reshape_like(data, shape_like):
 
         compare_tf_with_tvm(data, 'Placeholder:0', 'Reshape:0')
 
+def _test_reshape_symbolic(data, a_data, b_data):
+    with tf.Graph().as_default():
+        in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
+        a = array_ops.placeholder(shape=a_data.shape, dtype=a_data.dtype)
+        b = array_ops.placeholder(shape=b_data.shape, dtype=b_data.dtype)
+        newshape = tf.add(a, b)
+        out = array_ops.reshape(in_data, newshape)
+
+        for mode in ["debug", "vm"]:
+            compare_tf_with_tvm([data, a_data, b_data], [in_data.name, a.name, b.name], out.name, mode=mode)
+
 def test_forward_reshape():
     _test_reshape(np.arange(6.0), [2, 3])
     _test_reshape(np.arange(6), [-1, 2])
@@ -754,6 +765,10 @@ def test_forward_reshape():
     _test_reshape(np.arange(6), [-1])
     _test_reshape_with_call()
     _test_reshape_like(np.zeros((3, 6)), np.zeros((9, 2)))
+    _test_reshape_symbolic(np.arange(6.0), np.array([2, 0]), np.array([0, 3]))
+    _test_reshape_symbolic(np.arange(6), np.array([-1, 0]), np.array([0, 2]))
+    _test_reshape_symbolic(np.arange(6), np.array([3, 0]), np.array([3, -1]))
+    _test_reshape_symbolic(np.arange(6), np.array([0]), np.array([-1]))
 
 #######################################################################
 # DepthToSpace
diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py
index 6ce59bb..c9de675 100644
--- a/tests/python/relay/test_any.py
+++ b/tests/python/relay/test_any.py
@@ -138,23 +138,36 @@ def test_any_concat():
         result = ex.evaluate()(x_np, y_np)
         tvm.testing.assert_allclose(result.asnumpy(), ref)
 
-def verify_any_reshape(x_shape, newshape, x_np_shape, out_shape):
+def verify_any_reshape(x_shape, newshape, x_np_shape, out_shape, variable_newshape=False):
     x = relay.var('x', shape=x_shape, dtype="float32")
-    y = relay.reshape(x, newshape=newshape)
-    mod = tvm.IRModule()
-    mod["main"] = relay.Function([x], y)
+    relu_x = relay.nn.relu(x)
     data = np.random.uniform(size=x_np_shape).astype('float32')
+    params = [x]
+    args = [data]
+
+    if variable_newshape:
+        newshape_var = relay.var('newshape', shape=(len(newshape),), dtype='int64')
+        params.append(newshape_var)
+        args.append(np.array(newshape, dtype='int64'))
+        newshape = newshape_var
+
+    y = relay.reshape(relu_x, newshape=newshape)
+    mod = tvm.IRModule()
+    mod["main"] = relay.Function(params, y)
+
     for kind in ["debug", "vm"]:
         ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
-        result = ex.evaluate()(data).asnumpy()
+        result = ex.evaluate()(*args).asnumpy()
         assert result.shape == out_shape
         tvm.testing.assert_allclose(result.flatten(), data.flatten())
 
 def test_any_reshape():
-    verify_any_reshape(any_dims(3), (1, -1), (2, 3, 4), (1, 24))
-    verify_any_reshape(any_dims(3), (0, -1), (2, 3, 4), (2, 12))
+    for variable_newshape in [False, True]:
+        # Variable newshape only supports that output rank is the same as newshape
+        verify_any_reshape(any_dims(3), (1, -1), (2, 3, 4), (1, 24), variable_newshape)
+        verify_any_reshape(any_dims(3), (0, -1), (2, 3, 4), (2, 12), variable_newshape)
+        verify_any_reshape(any_dims(3), (-4, 2, -1, -2), (6, 3, 4), (2, 3, 3, 4), variable_newshape)
     verify_any_reshape(any_dims(3), (0, -2), (2, 3, 4), (2, 3, 4))
-    verify_any_reshape(any_dims(3), (-4, 2, -1, -2), (6, 3, 4), (2, 3, 3, 4))
     verify_any_reshape(any_dims(3), (-4, -1, 2, -3), (6, 3, 4), (3, 2, 12))
 
 def verify_any_argwhere(x_shape, x_np_shape, dtype="bool"):
diff --git a/vta/python/vta/top/graphpack.py b/vta/python/vta/top/graphpack.py
index 2334de7..e1fdfcb 100644
--- a/vta/python/vta/top/graphpack.py
+++ b/vta/python/vta/top/graphpack.py
@@ -345,9 +345,9 @@ class ExprPack(ExprMutator):
                                         method,
                                         align_corners)
             elif call.op == self.reshape and len(input_types[0].shape) == 4:
-                data, = args
+                data, _ = args
                 data = op.transpose(data, axes=(0, 4, 1, 5, 2, 3))
-                return op.reshape(data, input_types[0].shape)
+                return op.reshape(data, [int(x) for x in input_types[0].shape])
 
         return relay.Call(
             self.visit(call.op),

Reply via email to