This is an automated email from the ASF dual-hosted git repository.
kevinthesun pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new aa42f97 [RELAY][TF] Support symbolic newshape for Reshape (#5429)
aa42f97 is described below
commit aa42f978b89cc643b40c4f2bfdf2fc42dc1aeadd
Author: lixiaoquan <[email protected]>
AuthorDate: Wed May 13 14:37:33 2020 +0800
[RELAY][TF] Support symbolic newshape for Reshape (#5429)
* [RELAY][TF] Support symbolic newshape for Reshape
* Only need to pass data
* Use MakeReshape() in Reshape()
* Change newshape to Expr
* Create a template for Array<T>
* Fuse reshape when newshape is constant
* Make newshape Optional
* Use bool() of Optional
Co-authored-by: Li Xiaoquan <[email protected]>
---
include/tvm/relay/attrs/transform.h | 2 +-
python/tvm/relay/_parser.py | 5 +-
python/tvm/relay/frontend/tensorflow.py | 13 +--
python/tvm/relay/op/_tensor_grad.py | 2 +-
python/tvm/relay/op/_transform.py | 81 +++++++++++++-
python/tvm/relay/op/transform.py | 8 +-
src/relay/analysis/util.cc | 40 +++++++
src/relay/backend/compile_engine.cc | 24 +----
src/relay/op/tensor/transform.cc | 132 ++++++++++++++---------
src/relay/op/tensor/transform.h | 2 +-
src/relay/transforms/fold_scale_axis.cc | 3 +-
src/relay/transforms/fuse_ops.cc | 9 +-
src/relay/transforms/pass_util.h | 14 +++
src/relay/transforms/pattern_util.h | 38 ++++++-
tests/cpp/relay_build_module_test.cc | 1 +
tests/python/frontend/tensorflow/test_forward.py | 15 +++
tests/python/relay/test_any.py | 29 +++--
vta/python/vta/top/graphpack.py | 4 +-
18 files changed, 312 insertions(+), 110 deletions(-)
diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h
index 84dda6f..c0e2272 100644
--- a/include/tvm/relay/attrs/transform.h
+++ b/include/tvm/relay/attrs/transform.h
@@ -82,7 +82,7 @@ struct TransposeAttrs : public tvm::AttrsNode<TransposeAttrs> {
/*! \brief Attributes used in reshape operators */
struct ReshapeAttrs : public tvm::AttrsNode<ReshapeAttrs> {
- Array<Integer> newshape;
+ Optional<Array<Integer>> newshape;
bool reverse;
TVM_DECLARE_ATTRS(ReshapeAttrs, "relay.attrs.ReshapeAttrs") {
TVM_ATTR_FIELD(newshape).describe(
diff --git a/python/tvm/relay/_parser.py b/python/tvm/relay/_parser.py
index 7731efe..1d97b55 100644
--- a/python/tvm/relay/_parser.py
+++ b/python/tvm/relay/_parser.py
@@ -114,7 +114,10 @@ class FuncOp(OpWrapper):
def __call__(self, args, attrs, type_args):
if attrs is None:
attrs = {}
- x = self.operator(*args, **{k: self.convert(v) for k, v in
attrs.items()})
+ if self.operator is op.reshape:
+ x = self.operator(*args)
+ else:
+ x = self.operator(*args, **{k: self.convert(v) for k, v in
attrs.items()})
if isinstance(x, expr.TupleWrapper):
x = x.astuple()
return x
diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py
index 913a016..ab9e9e6 100644
--- a/python/tvm/relay/frontend/tensorflow.py
+++ b/python/tvm/relay/frontend/tensorflow.py
@@ -1155,14 +1155,11 @@ def _reshape():
shape_arg =
tuple(params_new.asnumpy().astype('int64').flatten())
except Exception:
# Deal with symbolic shape case.
- # Currently only shape_of can be the direct ancestor.
- if not isinstance(pop_node, tvm.relay.expr.Call) or \
- "shape_of" not in str(pop_node.op):
- raise RuntimeError("If shape operator is used in reshape
to "
- "express reshape_like, shape_of must be
"
- "the direct ancestor of reshape when
input "
- "shape is symbolic.")
- return _op.reshape_like(inputs[0], pop_node.args[0])
+ if isinstance(pop_node, _expr.Call) and \
+ "shape_of" in str(pop_node.op):
+ # shape_of is the direct ancestor.
+ return _op.reshape_like(inputs[0], pop_node.args[0])
+ shape_arg = pop_node
return AttrCvt(
op_name="reshape",
extras={'newshape': shape_arg},
diff --git a/python/tvm/relay/op/_tensor_grad.py b/python/tvm/relay/op/_tensor_grad.py
index c034bcc..8be3358 100644
--- a/python/tvm/relay/op/_tensor_grad.py
+++ b/python/tvm/relay/op/_tensor_grad.py
@@ -479,7 +479,7 @@ def dense_grad(orig, grad):
@register_gradient("reshape")
def reshape_grad(orig, grad):
"""Gradient of reshape"""
- return [reshape_like(grad, orig.args[0])]
+ return [reshape_like(grad, orig.args[0]), orig.args[1]]
@register_gradient("cast")
diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py
index ee23fce..43d8d62 100644
--- a/python/tvm/relay/op/_transform.py
+++ b/python/tvm/relay/op/_transform.py
@@ -123,7 +123,7 @@ def concatenate_shape_func(attrs, inputs, _):
return [_concatenate_shape_func(inputs, convert(axis))]
@script
-def _reshape_shape_func(data_shape, newshape, ndim):
+def _reshape_shape_func_input_shape(data_shape, newshape, ndim):
out = output_tensor((ndim,), "int64")
src_idx = 0
dst_idx = 0
@@ -189,10 +189,83 @@ def _reshape_shape_func(data_shape, newshape, ndim):
out[infer_idx] = old_size // new_size
return out
-@_reg.register_shape_func("reshape", False)
+@script
+def _reshape_shape_func_input_data(data, newshape, ndim):
+ out = output_tensor((ndim,), "int64")
+ data_shape = allocate((len(data.shape),), "int64")
+ for x in const_range(len(data.shape)):
+ data_shape[x] = int64(data.shape[x])
+ src_idx = 0
+ dst_idx = 0
+ infer_idx = -1
+ copy = False
+ skip = 0
+ for i in const_range(len(newshape)):
+ if skip > 0:
+ skip -= 1
+ elif newshape[i] > 0:
+ out[dst_idx] = int64(newshape[i])
+ src_idx += 1
+ dst_idx += 1
+ elif newshape[i] == 0:
+ out[dst_idx] = data_shape[src_idx]
+ src_idx += 1
+ dst_idx += 1
+ elif newshape[i] == -1:
+ assert infer_idx < 0, "One and only one dim can be inferred"
+ out[dst_idx] = int64(1)
+ infer_idx = i
+ dst_idx += 1
+ elif newshape[i] == -2:
+ copy = True
+ elif newshape[i] == -3:
+ assert data_shape.shape[0] - src_idx > 1, \
+ "Not enough dims in input shape for -3"
+ out[dst_idx] = data_shape[src_idx] * data_shape[src_idx+1]
+ src_idx += 2
+ dst_idx += 1
+ elif newshape[i] == -4:
+ assert len(newshape) - i > 2, "Not enough dims in new shape for -4"
+ if newshape[i+1] == -1:
+ assert newshape[i+2] != -1, "Split dims cannot both be -1."
+ out[dst_idx] = data_shape[src_idx] // int64(newshape[i+2])
+ out[dst_idx+1] = int64(newshape[i+2])
+ else:
+ out[dst_idx] = int64(newshape[i+1])
+ if newshape[i+2] == -1:
+ out[dst_idx+1] = data_shape[src_idx] //
int64(newshape[i+1])
+ else:
+ out[dst_idx+1] = int64(newshape[i+2])
+ assert data_shape[src_idx] == out[dst_idx] * out[dst_idx+1],\
+ "Product of split dims doesn't match to input dim"
+ src_idx += 1
+ dst_idx += 2
+ skip = 2
+ else:
+ assert False, "Invalid special values in new shape"
+ if len(data_shape.shape) > 0:
+ # if data is not constant, we can then handle -1 and -2
+ if copy:
+ for i in range(src_idx, data_shape.shape[0]):
+ out[dst_idx] = data_shape[i]
+ dst_idx += 1
+ if infer_idx >= 0:
+ old_size = int64(1)
+ for i in const_range(data_shape.shape[0]):
+ old_size *= data_shape[i]
+ new_size = int64(1)
+ for i in const_range(out.shape[0]):
+ new_size *= out[i]
+ out[infer_idx] = old_size // new_size
+ return out
+
+@_reg.register_shape_func("reshape", True)
def reshape_shape_func(attrs, inputs, out_ndims):
- newshape = get_const_tuple(attrs.newshape)
- return [_reshape_shape_func(inputs[0], convert(newshape), out_ndims[0])]
+ if attrs.newshape is None:
+ return [_reshape_shape_func_input_data(*inputs, out_ndims[0])]
+ return [_reshape_shape_func_input_shape(inputs[0],
+ convert(attrs.newshape),
+ out_ndims[0])]
@script
def _take_no_axis_shape_func(indices_shape, out_ndim):
diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py
index 4e9bb45..2d9e4ba 100644
--- a/python/tvm/relay/op/transform.py
+++ b/python/tvm/relay/op/transform.py
@@ -201,7 +201,7 @@ def reshape(data, newshape):
data : relay.Expr
The input data to the operator.
- newshape : Union[int, Tuple[int], List[int]]
+ newshape : Union[int, Tuple[int], List[int]] or relay.Expr
The new shape. Should be compatible with the original shape.
Returns
@@ -210,8 +210,10 @@ def reshape(data, newshape):
The reshaped result.
"""
if isinstance(newshape, int):
- newshape = [newshape]
- return _make.reshape(data, list(newshape))
+ newshape = const([newshape])
+ if isinstance(newshape, (tuple, list)):
+ newshape = const(list(newshape))
+ return _make.reshape(data, newshape)
def argwhere(condition):
"""Find the indices of elements of a tensor that are
diff --git a/src/relay/analysis/util.cc b/src/relay/analysis/util.cc
index 1d84016..af23836 100644
--- a/src/relay/analysis/util.cc
+++ b/src/relay/analysis/util.cc
@@ -27,6 +27,7 @@
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op.h>
+#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/pattern_functor.h>
#include "../transforms/pass_util.h"
@@ -414,5 +415,44 @@ Expr TypeSubst(const Expr& expr, const tvm::Map<TypeVar, Type>& subst_map) {
return ret;
}
+struct IsDynamicVisitor : public TypeVisitor {
+ bool is_dyn{false};
+ void VisitType_(const TensorTypeNode* tt) {
+ for (auto dim : tt->shape) {
+ if (dim.as<Any>()) {
+ is_dyn = true;
+ break;
+ }
+ }
+ }
+};
+
+bool IsDynamic(const Type& ty) {
+ IsDynamicVisitor v;
+ v.VisitType(ty);
+ return v.is_dyn;
+}
+
+TVM_REGISTER_GLOBAL("relay.ir.IsDynamic").set_body_typed(IsDynamic);
+
+bool IsDataDependant(const CallNode* call) {
+ static auto tshape_data_dependant =
Op::GetAttr<TShapeDataDependant>("TShapeDataDependant");
+ Op op = Downcast<Op>(call->op);
+
+ if (!tshape_data_dependant.count(op)) {
+ return false;
+ }
+
+ if (op->name == "reshape") {
+ if (const auto* attrs = call->attrs.as<ReshapeAttrs>()) {
+ if (attrs->newshape) {
+ // If newshape attribute exists, it isn't data dependant.
+ return false;
+ }
+ }
+ }
+
+ return tshape_data_dependant[op];
+}
} // namespace relay
} // namespace tvm
diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc
index 3851de1..12a5add 100644
--- a/src/relay/backend/compile_engine.cc
+++ b/src/relay/backend/compile_engine.cc
@@ -45,6 +45,7 @@
#include <utility>
#include <vector>
+#include "../transforms/pass_util.h"
#include "utils.h"
namespace tvm {
@@ -70,27 +71,6 @@ CCacheKey::CCacheKey(Function source_func, Target target) {
data_ = std::move(n);
}
-struct IsDynamicVisitor : public TypeVisitor {
- bool is_dyn{false};
- void VisitType_(const TensorTypeNode* tt) {
- for (auto dim : tt->shape) {
- if (dim.as<Any>()) {
- is_dyn = true;
- break;
- }
- }
- }
-};
-
-bool IsDynamic(const Type& ty) {
- IsDynamicVisitor v;
- v.VisitType(ty);
- return v.is_dyn;
-}
-
-// TODO(@jroesch): MOVE ME
-TVM_REGISTER_GLOBAL("relay.ir.IsDynamic").set_body_typed(IsDynamic);
-
Array<IndexExpr> GetShape(const Array<IndexExpr>& shape) {
// for now, we always use int32 shape when possible
// even if the result of shape inference becomes int64.
@@ -485,7 +465,7 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator<Array<te::Tensor>>
CHECK_GT(tshape_data_dependant.count(op), 0)
<< "Internal error, cannot find TShapeDataDependant for " << op->name;
- data_dependants_.push_back(tshape_data_dependant[op]);
+ data_dependants_.push_back(IsDataDependant(call_node));
// Visit all inputs
Array<te::Tensor> inputs;
int count_tuple = 0;
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 15761f6..8b58946 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -447,10 +447,54 @@ RELAY_REGISTER_OP("transpose")
/* relay.reshape */
TVM_REGISTER_NODE_TYPE(ReshapeAttrs);
+double ToScalar(const runtime::NDArray& array, int i = 0) {
+ if (array->dtype.code == kDLInt) {
+ if (array->dtype.bits == 8) {
+ return reinterpret_cast<int8_t*>(array->data)[i];
+ } else if (array->dtype.bits == 16) {
+ return reinterpret_cast<int16_t*>(array->data)[i];
+ } else if (array->dtype.bits == 32) {
+ return reinterpret_cast<int32_t*>(array->data)[i];
+ } else if (array->dtype.bits == 64) {
+ return reinterpret_cast<int64_t*>(array->data)[i];
+ }
+ } else if (array->dtype.code == kDLUInt) {
+ if (array->dtype.bits == 8) {
+ return reinterpret_cast<uint8_t*>(array->data)[i];
+ } else if (array->dtype.bits == 16) {
+ return reinterpret_cast<uint16_t*>(array->data)[i];
+ } else if (array->dtype.bits == 32) {
+ return reinterpret_cast<uint32_t*>(array->data)[i];
+ } else if (array->dtype.bits == 64) {
+ return reinterpret_cast<uint64_t*>(array->data)[i];
+ }
+ } else if (array->dtype.code == kDLFloat) {
+#if (__ARM_FP16_FORMAT_IEEE == 1)
+ if (array->dtype.bits == 16) {
+ return reinterpret_cast<__fp16*>(array->data)[i];
+ }
+#endif
+ if (array->dtype.bits == 32) {
+ return reinterpret_cast<float*>(array->data)[i];
+ } else if (array->dtype.bits == 64) {
+ return reinterpret_cast<double*>(array->data)[i];
+ }
+ }
+ LOG(FATAL) << "Unknown data type: " <<
tvm::runtime::DLDataType2String(array->dtype);
+ // make compiler happy
+ return -std::numeric_limits<double>::infinity();
+}
+
bool ReshapeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
- // types: [data, result]
- CHECK_EQ(types.size(), 2);
+ const auto* param = attrs.as<ReshapeAttrs>();
+ if (param->reverse) {
+ // types: [data, result]
+ CHECK_EQ(types.size(), 2);
+ } else {
+ // types: [data, newshape, result]
+ CHECK_EQ(types.size(), 3);
+ }
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) {
CHECK(types[0].as<IncompleteTypeNode>())
@@ -458,17 +502,31 @@ bool ReshapeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
return false;
}
- const auto* param = attrs.as<ReshapeAttrs>();
+ Array<IndexExpr> oshape;
Array<IndexExpr> data_shape;
Array<Integer> newshape;
- if (param->reverse) {
- data_shape.assign(data->shape.rbegin(), data->shape.rend());
- newshape.assign(param->newshape.rbegin(), param->newshape.rend());
+
+ if (param->newshape) {
+ auto temp = param->newshape.value();
+ if (param->reverse) {
+ data_shape.assign(data->shape.rbegin(), data->shape.rend());
+ newshape.assign(temp.rbegin(), temp.rend());
+ } else {
+ data_shape = data->shape;
+ newshape = temp;
+ }
} else {
- data_shape = data->shape;
- newshape = param->newshape;
+ const auto* newshape = types[1].as<TensorTypeNode>();
+
+ // Doesn't support dynamic output rank
+ for (int i = 0; i < newshape->shape[0].as<IntImmNode>()->value; i++) {
+ oshape.push_back(Any::make());
+ }
+
+ reporter->Assign(types[2], TensorType(oshape, data->dtype));
+ return true;
}
- Array<IndexExpr> oshape;
+
std::unordered_set<size_t> used_input_dims;
std::unordered_set<size_t> used_output_dims;
size_t src_idx = 0;
@@ -581,7 +639,7 @@ bool ReshapeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
reporter->Assign(types[1],
TensorType(Array<IndexExpr>(oshape.rbegin(),
oshape.rend()), data->dtype));
} else {
- reporter->Assign(types[1], TensorType(oshape, data->dtype));
+ reporter->Assign(types[2], TensorType(oshape, data->dtype));
}
return true;
}
@@ -601,12 +659,19 @@ Array<te::Tensor> ReshapeCompute(const Attrs& attrs, const Array<te::Tensor>& in
return {topi::reshape(inputs[0], newshape)};
}
-Expr MakeReshape(Expr data, Array<Integer> newshape) {
+Expr MakeReshape(Expr data, Expr newshape) {
auto attrs = make_object<ReshapeAttrs>();
- attrs->newshape = std::move(newshape);
+ if (const ConstantNode* c = newshape.as<ConstantNode>()) {
+ CHECK_EQ(c->data->ndim, 1);
+ Array<Integer> newshape;
+ for (int i = 0; i < c->data->shape[0]; i++) {
+ newshape.push_back(Integer(static_cast<int>(ToScalar(c->data, i))));
+ }
+ attrs->newshape = newshape;
+ }
attrs->reverse = false;
static const Op& op = Op::Get("reshape");
- return Call(op, {data}, Attrs(attrs), {});
+ return Call(op, {data, newshape}, Attrs(attrs), {});
}
TVM_REGISTER_GLOBAL("relay.op._make.reshape").set_body_typed(MakeReshape);
@@ -662,9 +727,10 @@ Example::
- data.shape = (2,3,4), newshape = (2,-4,-1,3,-2), result.shape = (2,1,3,4)
)code" TVM_ADD_FILELINE)
- .set_num_inputs(1)
+ .set_num_inputs(2)
.set_attrs_type<ReshapeAttrs>()
.add_argument("data", "Tensor", "The input tensor.")
+ .add_argument("newshape", "Tensor", "The shape of output tensor.")
.set_support_level(3)
.add_type_rel("Reshape", ReshapeRel)
.set_attr<FTVMCompute>("FTVMCompute", ReshapeCompute)
@@ -1005,44 +1071,6 @@ and type as the input array.
// arange operator
TVM_REGISTER_NODE_TYPE(ArangeAttrs);
-double ToScalar(const runtime::NDArray& array) {
- if (array->dtype.code == kDLInt) {
- if (array->dtype.bits == 8) {
- return reinterpret_cast<int8_t*>(array->data)[0];
- } else if (array->dtype.bits == 16) {
- return reinterpret_cast<int16_t*>(array->data)[0];
- } else if (array->dtype.bits == 32) {
- return reinterpret_cast<int32_t*>(array->data)[0];
- } else if (array->dtype.bits == 64) {
- return reinterpret_cast<int64_t*>(array->data)[0];
- }
- } else if (array->dtype.code == kDLUInt) {
- if (array->dtype.bits == 8) {
- return reinterpret_cast<uint8_t*>(array->data)[0];
- } else if (array->dtype.bits == 16) {
- return reinterpret_cast<uint16_t*>(array->data)[0];
- } else if (array->dtype.bits == 32) {
- return reinterpret_cast<uint32_t*>(array->data)[0];
- } else if (array->dtype.bits == 64) {
- return reinterpret_cast<uint64_t*>(array->data)[0];
- }
- } else if (array->dtype.code == kDLFloat) {
-#if (__ARM_FP16_FORMAT_IEEE == 1)
- if (array->dtype.bits == 16) {
- return reinterpret_cast<__fp16*>(array->data)[0];
- }
-#endif
- if (array->dtype.bits == 32) {
- return reinterpret_cast<float*>(array->data)[0];
- } else if (array->dtype.bits == 64) {
- return reinterpret_cast<double*>(array->data)[0];
- }
- }
- LOG(FATAL) << "Unknown data type: " <<
tvm::runtime::DLDataType2String(array->dtype);
- // make compiler happy
- return -std::numeric_limits<double>::infinity();
-}
-
bool ArangeRel(const Array<Type>& types, int num_inputs, const Attrs&
raw_attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 4);
diff --git a/src/relay/op/tensor/transform.h b/src/relay/op/tensor/transform.h
index 1d1f9c0..bc35ed6 100644
--- a/src/relay/op/tensor/transform.h
+++ b/src/relay/op/tensor/transform.h
@@ -38,7 +38,7 @@
namespace tvm {
namespace relay {
-extern Expr MakeReshape(Expr data, Array<Integer> newshape);
+extern Expr MakeReshape(Expr data, Expr newshape);
template <typename AttrType>
bool ConcatenateRel(const Array<Type>& types, int num_inputs, const Attrs&
attrs,
diff --git a/src/relay/transforms/fold_scale_axis.cc b/src/relay/transforms/fold_scale_axis.cc
index 4c8025a..4083d08 100644
--- a/src/relay/transforms/fold_scale_axis.cc
+++ b/src/relay/transforms/fold_scale_axis.cc
@@ -329,7 +329,8 @@ static Expr ReshapeToMatchAxis(Expr scale, const Array<PrimExpr>& shape,
arr.push_back(1);
}
}
- return MakeReshape(scale, std::move(arr));
+ return MakeReshape(
+ scale, MakeConstantTensor(DataType::Int(32),
{static_cast<int64_t>(arr.size())}, arr));
}
// if only one axis, use expand dim. Else, use reshape
diff --git a/src/relay/transforms/fuse_ops.cc b/src/relay/transforms/fuse_ops.cc
index 0ca8d7c..054244d 100644
--- a/src/relay/transforms/fuse_ops.cc
+++ b/src/relay/transforms/fuse_ops.cc
@@ -31,6 +31,7 @@
#include <tvm/tir/op.h>
#include "../../support/arena.h"
+#include "pass_util.h"
#include "pattern_util.h"
namespace tvm {
@@ -237,7 +238,13 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
// need to call Update, as it may be an arbitrary expression.
OpPatternKind op_pattern = kOpaque;
if (const OpNode* opnode = call->op.as<OpNode>()) {
- op_pattern = static_cast<OpPatternKind>(fpattern[GetRef<Op>(opnode)]);
+ auto op = GetRef<Op>(opnode);
+ if (IsDynamic(call->checked_type()) && IsDataDependant(call)) {
+ // output of a shape func can't be fed to a data-dependent shape func
+ op_pattern = kOpaque;
+ } else {
+ op_pattern = static_cast<OpPatternKind>(fpattern[op]);
+ }
} else {
this->Update(call->op, node, kOpaque);
}
diff --git a/src/relay/transforms/pass_util.h b/src/relay/transforms/pass_util.h
index 32ee09f..cbdd4b4 100644
--- a/src/relay/transforms/pass_util.h
+++ b/src/relay/transforms/pass_util.h
@@ -77,6 +77,20 @@ Type TypeSubst(const Type& type, const tvm::Map<TypeVar, Type>& subst_map);
Expr TypeSubst(const Expr& expr, const tvm::Map<TypeVar, Type>& subst_map);
/*!
+ * \brief Check if type is dynamic.
+ * \param ty The type to be checked.
+ * \return Whether the type is dynamic.
+ */
+bool IsDynamic(const Type& ty);
+
+/*!
+ * \brief Check if call is data dependant.
+ * \param call The call to be checked.
+ * \return Whether the call is data dependant.
+ */
+bool IsDataDependant(const CallNode* call);
+
+/*!
* \brief Make arbitrary transformation preserve the out most function.
* \param func The transformation.
* \param e The expression
diff --git a/src/relay/transforms/pattern_util.h b/src/relay/transforms/pattern_util.h
index edb6a65..0a51404 100644
--- a/src/relay/transforms/pattern_util.h
+++ b/src/relay/transforms/pattern_util.h
@@ -283,6 +283,34 @@ static inline Constant MakeConstantTensor(DataType dtype, std::vector<int64_t> s
}
/*!
+ * \brief Create a Constant with a tensor.
+ *
+ * \param dtype The data type.
+ * \param value The array of the tensor values.
+ * \return A Constant.
+ */
+template <typename T>
+static inline Constant MakeConstantTensor(DataType dtype, std::vector<int64_t>
shape,
+ Array<T> value) {
+ runtime::NDArray arr = runtime::NDArray::Empty(shape, dtype, {kDLCPU, 0});
+ TVM_DTYPE_DISPATCH(dtype, DType, {
+ for (size_t i = 0; i < value.size(); i++) {
+ if (dtype == DataType::Float(16)) {
+ // convert to float16
+ // storage is uint16_t
+ // Similar handling as that in MakeConstantScalar
+ *(static_cast<DType*>(arr->data) + i) =
+ __truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 10>(
+ static_cast<float>(value[i]));
+ } else {
+ *(static_cast<DType*>(arr->data) + i) = value[i];
+ }
+ }
+ })
+ return Constant(arr);
+}
+
+/*!
* \brief Check if two expressions are equal scalars.
* \param a The expression to be checked.
* \param b The expression to be checked
@@ -519,12 +547,12 @@ static inline Expr Sum(Expr data, Array<Integer> axis, bool keepdims, bool exclu
return Call(op, {data}, Attrs(attrs), {});
}
+Expr MakeReshape(Expr data, Expr newshape);
+
static inline Expr Reshape(Expr data, Array<Integer> newshape) {
- auto attrs = make_object<ReshapeAttrs>();
- attrs->newshape = std::move(newshape);
- attrs->reverse = false;
- static const Op& op = Op::Get("reshape");
- return Call(op, {data}, Attrs(attrs), {});
+ auto newshape_tensor =
+ MakeConstantTensor(DataType::Int(32),
{static_cast<int64_t>(newshape.size())}, newshape);
+ return MakeReshape(data, newshape_tensor);
}
static inline Expr AvgPool2D(Expr data, Array<IndexExpr> pool_size,
Array<IndexExpr> strides,
diff --git a/tests/cpp/relay_build_module_test.cc b/tests/cpp/relay_build_module_test.cc
index 33f6061..d7ce0c0 100644
--- a/tests/cpp/relay_build_module_test.cc
+++ b/tests/cpp/relay_build_module_test.cc
@@ -105,6 +105,7 @@ TEST(Relay, BuildModule) {
}
auto fgeneric = GenericFunc::Get("test.strategy_generic").set_default(*fs);
(*reg)("add", "FTVMStrategy", fgeneric, 10);
+ (*reg)("add", "TShapeDataDependant", false, 10);
// build
auto pfb = tvm::runtime::Registry::Get("relay.build_module._BuildModule");
tvm::runtime::Module build_mod = (*pfb)();
diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py
index cd6c454..c3313b6 100644
--- a/tests/python/frontend/tensorflow/test_forward.py
+++ b/tests/python/frontend/tensorflow/test_forward.py
@@ -747,6 +747,17 @@ def _test_reshape_like(data, shape_like):
compare_tf_with_tvm(data, 'Placeholder:0', 'Reshape:0')
+def _test_reshape_symbolic(data, a_data, b_data):
+ with tf.Graph().as_default():
+ in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
+ a = array_ops.placeholder(shape=a_data.shape, dtype=a_data.dtype)
+ b = array_ops.placeholder(shape=b_data.shape, dtype=b_data.dtype)
+ newshape = tf.add(a, b)
+ out = array_ops.reshape(in_data, newshape)
+
+ for mode in ["debug", "vm"]:
+ compare_tf_with_tvm([data, a_data, b_data], [in_data.name, a.name,
b.name], out.name, mode=mode)
+
def test_forward_reshape():
_test_reshape(np.arange(6.0), [2, 3])
_test_reshape(np.arange(6), [-1, 2])
@@ -754,6 +765,10 @@ def test_forward_reshape():
_test_reshape(np.arange(6), [-1])
_test_reshape_with_call()
_test_reshape_like(np.zeros((3, 6)), np.zeros((9, 2)))
+ _test_reshape_symbolic(np.arange(6.0), np.array([2, 0]), np.array([0, 3]))
+ _test_reshape_symbolic(np.arange(6), np.array([-1, 0]), np.array([0, 2]))
+ _test_reshape_symbolic(np.arange(6), np.array([3, 0]), np.array([3, -1]))
+ _test_reshape_symbolic(np.arange(6), np.array([0]), np.array([-1]))
#######################################################################
# DepthToSpace
diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py
index 6ce59bb..c9de675 100644
--- a/tests/python/relay/test_any.py
+++ b/tests/python/relay/test_any.py
@@ -138,23 +138,36 @@ def test_any_concat():
result = ex.evaluate()(x_np, y_np)
tvm.testing.assert_allclose(result.asnumpy(), ref)
-def verify_any_reshape(x_shape, newshape, x_np_shape, out_shape):
+def verify_any_reshape(x_shape, newshape, x_np_shape, out_shape,
variable_newshape=False):
x = relay.var('x', shape=x_shape, dtype="float32")
- y = relay.reshape(x, newshape=newshape)
- mod = tvm.IRModule()
- mod["main"] = relay.Function([x], y)
+ relu_x = relay.nn.relu(x)
data = np.random.uniform(size=x_np_shape).astype('float32')
+ params = [x]
+ args = [data]
+
+ if variable_newshape:
+ newshape_var = relay.var('newshape', shape=(len(newshape),),
dtype='int64')
+ params.append(newshape_var)
+ args.append(np.array(newshape, dtype='int64'))
+ newshape = newshape_var
+
+ y = relay.reshape(relu_x, newshape=newshape)
+ mod = tvm.IRModule()
+ mod["main"] = relay.Function(params, y)
+
for kind in ["debug", "vm"]:
ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
- result = ex.evaluate()(data).asnumpy()
+ result = ex.evaluate()(*args).asnumpy()
assert result.shape == out_shape
tvm.testing.assert_allclose(result.flatten(), data.flatten())
def test_any_reshape():
- verify_any_reshape(any_dims(3), (1, -1), (2, 3, 4), (1, 24))
- verify_any_reshape(any_dims(3), (0, -1), (2, 3, 4), (2, 12))
+ for variable_newshape in [False, True]:
+ # Variable newshape only supports that output rank is the same as
newshape
+ verify_any_reshape(any_dims(3), (1, -1), (2, 3, 4), (1, 24),
variable_newshape)
+ verify_any_reshape(any_dims(3), (0, -1), (2, 3, 4), (2, 12),
variable_newshape)
+ verify_any_reshape(any_dims(3), (-4, 2, -1, -2), (6, 3, 4), (2, 3, 3,
4), variable_newshape)
verify_any_reshape(any_dims(3), (0, -2), (2, 3, 4), (2, 3, 4))
- verify_any_reshape(any_dims(3), (-4, 2, -1, -2), (6, 3, 4), (2, 3, 3, 4))
verify_any_reshape(any_dims(3), (-4, -1, 2, -3), (6, 3, 4), (3, 2, 12))
def verify_any_argwhere(x_shape, x_np_shape, dtype="bool"):
diff --git a/vta/python/vta/top/graphpack.py b/vta/python/vta/top/graphpack.py
index 2334de7..e1fdfcb 100644
--- a/vta/python/vta/top/graphpack.py
+++ b/vta/python/vta/top/graphpack.py
@@ -345,9 +345,9 @@ class ExprPack(ExprMutator):
method,
align_corners)
elif call.op == self.reshape and len(input_types[0].shape) == 4:
- data, = args
+ data, _ = args
data = op.transpose(data, axes=(0, 4, 1, 5, 2, 3))
- return op.reshape(data, input_types[0].shape)
+ return op.reshape(data, [int(x) for x in input_types[0].shape])
return relay.Call(
self.visit(call.op),