This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 780d6e6d12 [Unity] Allow specifying struct_info for relax constant (#15220)
780d6e6d12 is described below
commit 780d6e6d12f7d20625a3c532847bacfe49a4f3ec
Author: Hongyi Jin <[email protected]>
AuthorDate: Tue Jul 4 05:29:04 2023 -0700
[Unity] Allow specifying struct_info for relax constant (#15220)
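Add an optional struct_info parameter to the relax Constant constructor; if it is omitted, the struct info is inferred from the constant's data as before.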
---
include/tvm/relax/expr.h | 14 ++++++++++----
python/tvm/relax/expr.py | 8 ++++++--
src/relax/ir/expr.cc | 21 ++++++++++++++-------
3 files changed, 30 insertions(+), 13 deletions(-)
diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h
index 36a8109c35..2d1e805a41 100644
--- a/include/tvm/relax/expr.h
+++ b/include/tvm/relax/expr.h
@@ -550,10 +550,13 @@ class ConstantNode : public LeafExprNode {
bool SEqualReduce(const ConstantNode* other, SEqualReducer equal) const {
// struct info can be deterministically derived from data.
- return equal(data, other->data);
+ return equal(data, other->data) && equal(struct_info_, other->struct_info_);
}
- void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(data); }
+ void SHashReduce(SHashReducer hash_reduce) const {
+ hash_reduce(data);
+ hash_reduce(struct_info_);
+ }
static constexpr const char* _type_key = "relax.expr.Constant";
TVM_DECLARE_FINAL_OBJECT_INFO(ConstantNode, LeafExprNode);
@@ -564,9 +567,12 @@ class Constant : public LeafExpr {
/*!
* \brief The constructor
* \param data The data of the constant tensor.
- * \param span The source span of the expression.
+ * \param struct_info_annotation The struct info of the constant tensor. If not specified, infer
+ * it from data. \param span The source span of the expression.
*/
- TVM_DLL explicit Constant(runtime::NDArray data, Span span = Span());
+ TVM_DLL explicit Constant(runtime::NDArray data,
+ Optional<StructInfo> struct_info_annotation = NullOpt,
+ Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(Constant, LeafExpr, ConstantNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(ConstantNode);
diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py
index 22e5cbcddd..1db873a472 100644
--- a/python/tvm/relax/expr.py
+++ b/python/tvm/relax/expr.py
@@ -381,8 +381,12 @@ def make_shape(shape: Union[List[Any], typing.Tuple[Any, ...]]) -> ShapeExpr:
@tvm._ffi.register_object("relax.expr.Constant")
class Constant(ExprWithOp):
- def __init__(self, data: tvm.nd.NDArray, span: Span = None) -> None:
- self.__init_handle_by_constructor__(_ffi_api.Constant, data, span) # type: ignore
+ def __init__(
+ self, data: tvm.nd.NDArray, struct_info: Optional[StructInfo] = None, span: Span = None
+ ) -> None:
+ self.__init_handle_by_constructor__(
+ _ffi_api.Constant, data, struct_info, span
+ ) # type: ignore
@tvm._ffi.register_object("relax.expr.Var")
diff --git a/src/relax/ir/expr.cc b/src/relax/ir/expr.cc
index 3dafc0ddef..ccff18cd40 100644
--- a/src/relax/ir/expr.cc
+++ b/src/relax/ir/expr.cc
@@ -277,7 +277,7 @@ TVM_REGISTER_GLOBAL("relax.DataflowVarFromId")
return DataflowVar(vid, struct_info_annotation, span);
});
-Constant::Constant(runtime::NDArray data, Span span) {
+Constant::Constant(runtime::NDArray data, Optional<StructInfo> struct_info_annotation, Span span) {
ObjectPtr<ConstantNode> n = make_object<ConstantNode>();
n->data = std::move(data);
n->span = std::move(span);
@@ -288,18 +288,25 @@ Constant::Constant(runtime::NDArray data, Span span) {
for (size_t dim = 0; dim < shape_tuple.size(); ++dim) {
values.push_back(IntImm(DataType::Int(64), shape_tuple[dim]));
}
- TensorStructInfo tinfo(ShapeExpr(values), n->data.DataType(), span);
+ if (struct_info_annotation.defined()) {
+ n->struct_info_ = struct_info_annotation.value();
+ n->checked_type_ = GetStaticType(struct_info_annotation.value());
+ } else {
+ TensorStructInfo tinfo(ShapeExpr(values), n->data.DataType(), span);
+ n->struct_info_ = tinfo;
+ n->checked_type_ = DynTensorType(tinfo->ndim, tinfo->dtype);
+ }
- n->struct_info_ = tinfo;
- n->checked_type_ = DynTensorType(tinfo->ndim, tinfo->dtype);
data_ = std::move(n);
}
TVM_REGISTER_NODE_TYPE(ConstantNode);
-TVM_REGISTER_GLOBAL("relax.Constant").set_body_typed([](runtime::NDArray data, Span span = Span()) {
- return Constant(data, span);
-});
+TVM_REGISTER_GLOBAL("relax.Constant")
+ .set_body_typed([](runtime::NDArray data, Optional<StructInfo> struct_info_annotation = NullOpt,
+ Span span = Span()) {
+ return Constant(data, struct_info_annotation, span);
+ });
PrimValue::PrimValue(PrimExpr value, Span span) {
ObjectPtr<PrimValueNode> n = make_object<PrimValueNode>();