This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 83e1f07d82 [Unity][IR] First-class StructInfo (#13907)
83e1f07d82 is described below
commit 83e1f07d82acd721df515e27ef12693074d79997
Author: Yuchen Jin <[email protected]>
AuthorDate: Thu Feb 2 23:32:15 2023 -0500
[Unity][IR] First-class StructInfo (#13907)
* [Unity][IR] First-class StructInfo
Relax tracks structural information (such as tensor shape) via `StructInfo`
about the values in Relax.
* Fix rust build
---------
Co-authored-by: Junru Shao <[email protected]>
---
CMakeLists.txt | 1 +
include/tvm/relax/struct_info.h | 430 ++++++++++++++++++++++++++++++++++++++++
rust/tvm/src/ir/relay/mod.rs | 2 +
src/relax/ir/struct_info.cc | 238 ++++++++++++++++++++++
4 files changed, 671 insertions(+)
diff --git a/CMakeLists.txt b/CMakeLists.txt
index e55b7174dc..88bf6472ce 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -289,6 +289,7 @@ tvm_file_glob(GLOB_RECURSE COMPILER_SRCS
src/driver/*.cc
src/support/*.cc
src/script/*.cc
+ src/relax/ir/*.cc
src/relax/backend/vm/*.cc
)
diff --git a/include/tvm/relax/struct_info.h b/include/tvm/relax/struct_info.h
new file mode 100644
index 0000000000..d21c8db86b
--- /dev/null
+++ b/include/tvm/relax/struct_info.h
@@ -0,0 +1,430 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_RELAX_STRUCT_INFO_H_
+#define TVM_RELAX_STRUCT_INFO_H_
+
+#include <tvm/ir/env_func.h>
+#include <tvm/ir/source_map.h>
+#include <tvm/node/node.h>
+// #include <tvm/relax/block_builder.h>
+#include <tvm/relax/expr.h>
+#include <tvm/relax/type.h>
+
+namespace tvm {
+namespace relax {
+
+/*!
+ * \brief Struct info of an opaque object.
+ *
+ * The node declares no fields of its own: attribute visiting only reports
+ * `span`, structural equality always returns true, and the structural hash
+ * is fed the constant 0.
+ */
+class ObjectStructInfoNode : public StructInfoNode {
+ public:
+ void VisitAttrs(AttrVisitor* v) { v->Visit("span", &span); }
+
+ bool SEqualReduce(const ObjectStructInfoNode* other, SEqualReducer equal)
const { return true; }
+
+ void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(0); }
+
+ static constexpr const char* _type_key = "relax.ObjectStructInfo";
+ TVM_DECLARE_FINAL_OBJECT_INFO(ObjectStructInfoNode, StructInfoNode);
+};
+
+/*!
+ * \brief Managed reference to ObjectStructInfoNode.
+ *
+ * The only constructor declared here takes the source span, which defaults
+ * to a default-constructed `Span()`.
+ * \sa ObjectStructInfoNode
+ */
+class ObjectStructInfo : public StructInfo {
+ public:
+ TVM_DLL ObjectStructInfo(Span span = Span());
+
+ TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ObjectStructInfo, StructInfo,
ObjectStructInfoNode);
+};
+
+/*!
+ * \brief Struct info of a primitive value.
+ *
+ * `dtype` is the only field compared by SEqualReduce and fed to
+ * SHashReduce; `span` is reported to attribute visitors but takes part in
+ * neither.
+ */
+class PrimStructInfoNode : public StructInfoNode {
+ public:
+ /*! \brief Underlying data type of the primitive value. */
+ DataType dtype;
+
+ void VisitAttrs(AttrVisitor* v) {
+ v->Visit("dtype", &dtype);
+ v->Visit("span", &span);
+ }
+
+ bool SEqualReduce(const PrimStructInfoNode* other, SEqualReducer equal)
const {
+ return equal(dtype, other->dtype);
+ }
+
+ void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(dtype); }
+
+ static constexpr const char* _type_key = "relax.PrimStructInfo";
+ TVM_DECLARE_FINAL_OBJECT_INFO(PrimStructInfoNode, StructInfoNode);
+};
+
+/*!
+ * \brief Managed reference to PrimStructInfoNode.
+ *
+ * Constructed from the data type of the primitive value and an optional
+ * source span (defaults to `Span()`).
+ * \sa PrimStructInfoNode
+ */
+class PrimStructInfo : public StructInfo {
+ public:
+ TVM_DLL PrimStructInfo(DataType dtype, Span span = Span());
+
+ TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(PrimStructInfo, StructInfo,
PrimStructInfoNode);
+};
+
+/*!
+ * \brief StructInfo of a shape value.
+ *
+ * Holds the optional symbolic `values` of the shape together with its
+ * `ndim`. Both fields take part in structural equality and hashing; `span`
+ * is only reported to attribute visitors.
+ */
+class ShapeStructInfoNode : public StructInfoNode {
+ public:
+ /*! \brief Optionally stores the symbolic value patterns of the shape. */
+ Optional<Array<PrimExpr>> values;
+ /*!
+ * \brief The number of dimensions of the shape, can be unknown.
+ * \sa kUnknownNDim
+ */
+ int ndim;
+
+ /*! \return Whether ndim equals kUnknownNDim, i.e. the ndim is unknown. */
+ bool IsUnknownNdim() const { return ndim == kUnknownNDim; }
+
+ void VisitAttrs(AttrVisitor* v) {
+ v->Visit("values", &values);
+ v->Visit("ndim", &ndim);
+ v->Visit("span", &span);
+ }
+
+ bool SEqualReduce(const ShapeStructInfoNode* other, SEqualReducer equal)
const {
+ return equal(values, other->values) && equal(ndim, other->ndim);
+ }
+
+ void SHashReduce(SHashReducer hash_reduce) const {
+ hash_reduce(values);
+ hash_reduce(ndim);
+ }
+
+ static constexpr const char* _type_key = "relax.ShapeStructInfo";
+ TVM_DECLARE_FINAL_OBJECT_INFO(ShapeStructInfoNode, StructInfoNode);
+};
+
+/*!
+ * \brief Managed reference to ShapeStructInfoNode.
+ * \sa ShapeStructInfoNode
+ */
+class ShapeStructInfo : public StructInfo {
+ public:
+ /*!
+ * \brief Construction with known symbolic shape patterns.
+ * \param values The symbolic shape values.
+ * \param span The span of the AST.
+ */
+ TVM_DLL ShapeStructInfo(Array<PrimExpr> values, Span span = Span());
+ /*!
+ * \brief Construction with unknown symbolic shape patterns.
+ * \param ndim Number of dimensions -- can be kUnknownNDim.
+ * \param span The span of the AST.
+ */
+ TVM_DLL ShapeStructInfo(int ndim, Span span = Span());
+
+ TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ShapeStructInfo, StructInfo,
ShapeStructInfoNode);
+};
+
+/*!
+ * \brief StructInfo of Tensor.
+ *
+ * Records an optional shape expression, the element dtype and the number of
+ * dimensions. All three take part in structural equality and hashing;
+ * `span` is only reported to attribute visitors.
+ */
+class TensorStructInfoNode : public StructInfoNode {
+ public:
+ /*!
+ * \brief Optionally stores the shape expression of the tensor.
+ * \note shape must be normalized: it can only be NullOpt, ShapeExpr or Var.
+ */
+ Optional<Expr> shape;
+ /*! \brief The content data type, use void to denote the dtype is unknown. */
+ DataType dtype;
+ /*!
+ * \brief The number of dimensions of the tensor, can be unknown.
+ * \sa kUnknownNDim
+ */
+ int ndim;
+
+ /*! \return Whether ndim equals kUnknownNDim, i.e. the ndim is unknown. */
+ bool IsUnknownNdim() const { return ndim == kUnknownNDim; }
+
+ /*! \return Whether dtype is void, i.e. the dtype is unknown. */
+ bool IsUnknownDtype() const { return dtype.is_void(); }
+
+ void VisitAttrs(AttrVisitor* v) {
+ v->Visit("shape", &shape);
+ v->Visit("dtype", &dtype);
+ v->Visit("ndim", &ndim);
+ v->Visit("span", &span);
+ }
+
+ bool SEqualReduce(const TensorStructInfoNode* other, SEqualReducer equal)
const {
+ return equal(shape, other->shape) && equal(ndim, other->ndim) &&
equal(dtype, other->dtype);
+ }
+
+ void SHashReduce(SHashReducer hash_reduce) const {
+ hash_reduce(shape);
+ hash_reduce(dtype);
+ hash_reduce(ndim);
+ }
+
+ static constexpr const char* _type_key = "relax.TensorStructInfo";
+ TVM_DECLARE_FINAL_OBJECT_INFO(TensorStructInfoNode, StructInfoNode);
+};
+
+/*!
+ * \brief Managed reference to TensorStructInfoNode.
+ * \sa TensorStructInfoNode
+ */
+class TensorStructInfo : public StructInfo {
+ public:
+ /*!
+ * \brief Construction with a known shape expression.
+ * \param shape The shape of the tensor.
+ * \param dtype The data type of tensor's elements.
+ * \param span The span of the AST.
+ *
+ * \note shape must already be normalized (a ShapeExpr or a Var).
+ */
+ TVM_DLL TensorStructInfo(Expr shape, DataType dtype, Span span = Span());
+
+ /*!
+ * \brief Construction with an unknown shape expression.
+ * \param dtype The data type of tensor's elements.
+ * \param ndim The number of dimensions; may be unknown (see kUnknownNDim).
+ * \param span The span of the AST.
+ */
+ TVM_DLL TensorStructInfo(DataType dtype, int ndim, Span span = Span());
+
+ TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TensorStructInfo, StructInfo,
TensorStructInfoNode);
+};
+
+/*!
+ * \brief StructInfo of Tuple.
+ *
+ * `fields` is the only member compared by SEqualReduce and fed to
+ * SHashReduce; `span` is only reported to attribute visitors.
+ */
+class TupleStructInfoNode : public StructInfoNode {
+ public:
+ /*! \brief The struct info of tuple fields. */
+ Array<StructInfo> fields;
+
+ void VisitAttrs(AttrVisitor* v) {
+ v->Visit("fields", &fields);
+ v->Visit("span", &span);
+ }
+
+ bool SEqualReduce(const TupleStructInfoNode* other, SEqualReducer equal)
const {
+ return equal(fields, other->fields);
+ }
+
+ void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(fields); }
+
+ static constexpr const char* _type_key = "relax.TupleStructInfo";
+ TVM_DECLARE_FINAL_OBJECT_INFO(TupleStructInfoNode, StructInfoNode);
+};
+
+/*!
+ * \brief Managed reference to TupleStructInfoNode.
+ * \sa TupleStructInfoNode
+ */
+class TupleStructInfo : public StructInfo {
+ public:
+ /*!
+ * \brief Constructor.
+ * \param fields Struct info of the tuple fields, in field order.
+ * \param span The span of the AST.
+ */
+ TVM_DLL TupleStructInfo(Array<StructInfo> fields, Span span = Span());
+
+ TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TupleStructInfo, StructInfo,
TupleStructInfoNode);
+};
+
+class BlockBuilder;
+
+/*!
+ * \brief Custom-defined StructInfo derivation function.
+ *
+ * A typed environment function with the signature below.
+ * \param call The call expression whose struct info is to be derived.
+ * \param ctx The builder context.
+ * \return The derived struct info of the call.
+ */
+using StructInfoDeriveFunc = TypedEnvFunc<StructInfo(const Call& call, const
BlockBuilder& ctx)>;
+
+/*!
+ * \brief Structure information about function.
+ *
+ * This data structure contains enough information for us to
+ * do best-effort structure information deduction.
+ */
+class FuncStructInfoNode : public StructInfoNode {
+ public:
+  /*!
+   * \brief The parameter struct info of the function.
+   * \note When params is NullOpt, the function can take an arbitrary number of arguments.
+   *       We define such functions as opaque functions.
+   */
+  Optional<Array<StructInfo>> params;
+  /*!
+   * \brief The struct info of the function's return value.
+   */
+  StructInfo ret;
+  /*!
+   * \brief Derivation function of opaque functions that may take any number of parameters.
+   * \note When derive_func is not empty, params should be NullOpt and
+   *       ret should be ObjectStructInfo().
+   */
+  Optional<StructInfoDeriveFunc> derive_func;
+
+  /*!
+   * \return Whether the func struct info is opaque.
+   * \note We define a function as opaque if we have no constraints on params,
+   *       i.e. params is not defined.
+   */
+  bool IsOpaque() const { return !params.defined(); }
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("params", &params);
+    v->Visit("ret", &ret);
+    v->Visit("derive_func", &derive_func);
+    v->Visit("span", &span);
+  }
+
+  bool SEqualReduce(const FuncStructInfoNode* other, SEqualReducer equal) const {
+    // params goes through DefEqual rather than the plain reducer -- presumably because
+    // variables appearing in params act as definition sites; confirm against SEqualReducer.
+    return equal.DefEqual(params, other->params) && equal(ret, other->ret) &&
+           equal(derive_func, other->derive_func);
+  }
+
+  void SHashReduce(SHashReducer hash_reduce) const {
+    // Mirrors SEqualReduce: params is hashed via DefHash, the rest via the plain reducer.
+    hash_reduce.DefHash(params);
+    hash_reduce(ret);
+    hash_reduce(derive_func);
+  }
+
+  static constexpr const char* _type_key = "relax.FuncStructInfo";
+  TVM_DECLARE_FINAL_OBJECT_INFO(FuncStructInfoNode, StructInfoNode);
+};
+
+/*!
+ * \brief Managed reference to FuncStructInfoNode.
+ * \sa FuncStructInfoNode
+ */
+class FuncStructInfo : public StructInfo {
+ public:
+ /*!
+ * \brief Constructor from parameter struct info and return value struct info.
+ * \param params The struct info of function parameters.
+ * \param ret The return value struct info.
+ * \param span The span of the AST.
+ *
+ * \note If ret contains variables (tir::Var and relax::Var), they must be
+ * deducible from params. If you are unsure, you can always erase ret to
+ * static.
+ */
+ TVM_DLL FuncStructInfo(Array<StructInfo> params, StructInfo ret, Span span =
Span());
+
+ /*!
+ * \brief Construct an opaque function struct info using derive_func.
+ *
+ * \param derive_func Derivation function.
+ * \param span The span of the AST.
+ *
+ * \return The FuncStructInfo for opaque packedfunc.
+ * \note derive_func has no default value in this overload; use the overload
+ * taking ret when no derivation function is available.
+ */
+ TVM_DLL static FuncStructInfo OpaqueFunc(StructInfoDeriveFunc derive_func,
Span span = Span());
+
+ /*!
+ * \brief Construct an opaque function struct info from the return struct
+ * info.
+ *
+ * \param ret The struct info of the return value.
+ * \param span The span of the AST.
+ *
+ * \return The FuncStructInfo for opaque packedfunc.
+ * \note ret defaults to ObjectStructInfo() if not specified.
+ */
+ TVM_DLL static FuncStructInfo OpaqueFunc(StructInfo ret =
ObjectStructInfo(), Span span = Span());
+
+ TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(FuncStructInfo, StructInfo,
FuncStructInfoNode);
+};
+
+/*!
+ * \brief Try to view the struct info of an expression as struct info type T.
+ *
+ * \param expr The input expression.
+ * \return A reference of type T when expr's struct info is of that type,
+ *         NullOpt otherwise.
+ * \tparam T The struct info reference type to match against.
+ */
+template <typename T>
+inline Optional<T> MatchStructInfo(const Expr& expr) {
+  using TNode = typename T::ContainerType;
+  const TNode* ptr = expr->struct_info_.as<TNode>();
+  // Guard clause: no match means there is nothing to wrap.
+  if (ptr == nullptr) {
+    return NullOpt;
+  }
+  return GetRef<T>(ptr);
+}
+
+/*!
+ * \brief Fetch the struct info of an expression and downcast it to const T*.
+ *
+ * Fails an internal check when the struct info has not been populated.
+ *
+ * \param expr The input expression.
+ * \return The node pointer, or nullptr if the struct info is not a T.
+ * \tparam T The struct info node type to cast to.
+ */
+template <typename T>
+inline const T* GetStructInfoAs(const Expr& expr) {
+  // An unpopulated struct_info_ is reported as an error rather than as "no match".
+  ICHECK(expr->struct_info_.defined())
+      << "The struct_info is not populated, check if you have normalized the expr";
+  const T* casted = expr->struct_info_.as<T>();
+  return casted;
+}
+
+/*!
+ * \brief Fetch the struct info attached to an expression.
+ *
+ * Fails an internal check when expr carries no StructInfoNode.
+ *
+ * \param expr The input expression.
+ * \return The struct info stored on expr.
+ */
+inline StructInfo GetStructInfo(const Expr& expr) {
+  const StructInfoNode* ptr = expr->struct_info_.as<StructInfoNode>();
+  ICHECK(ptr) << "The struct_info is not populated, check if you have normalized the expr";
+  return GetRef<StructInfo>(ptr);
+}
+
+/*!
+ * \brief Check whether an expression has void struct info.
+ *
+ * Void is represented as a tuple struct info with zero fields.
+ *
+ * \param expr The input expression.
+ * \return true iff expr's struct info is a TupleStructInfo without fields.
+ */
+inline bool HasVoidStructInfo(const Expr& expr) {
+  const TupleStructInfoNode* tuple_sinfo = expr->struct_info_.as<TupleStructInfoNode>();
+  if (tuple_sinfo == nullptr) {
+    return false;
+  }
+  return tuple_sinfo->fields.size() == 0;
+}
+
+/*!
+ * \brief Update the struct info of an Expr.
+ * \param expr The Expr whose struct info is to be updated.
+ * \param struct_info The struct_info to assign.
+ * \note We ensure idempotence: the struct_info of an Expr can be updated
+ * only if the original one is nullptr.
+ */
+TVM_DLL void UpdateStructInfo(Expr expr, StructInfo struct_info);
+
+} // namespace relax
+} // namespace tvm
+#endif // TVM_RELAX_STRUCT_INFO_H_
diff --git a/rust/tvm/src/ir/relay/mod.rs b/rust/tvm/src/ir/relay/mod.rs
index abc25e89c4..08ce082c45 100644
--- a/rust/tvm/src/ir/relay/mod.rs
+++ b/rust/tvm/src/ir/relay/mod.rs
@@ -40,6 +40,7 @@ pub mod attrs;
pub struct ExprNode {
pub base: BaseExprNode,
pub checked_type: Type,
+ pub struct_info: ObjectRef,
pub virtual_device: ObjectRef,
}
@@ -48,6 +49,7 @@ impl ExprNode {
ExprNode {
base: BaseExprNode::base::<T>(span.clone()),
checked_type: Type::null(),
+ struct_info: ObjectRef::null(),
virtual_device: ObjectRef::null(),
}
}
diff --git a/src/relax/ir/struct_info.cc b/src/relax/ir/struct_info.cc
new file mode 100644
index 0000000000..88046ed81f
--- /dev/null
+++ b/src/relax/ir/struct_info.cc
@@ -0,0 +1,238 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relax/ir/struct_info.cc
+ * \brief Relax struct info.
+ */
+#include <tvm/relax/struct_info.h>
+#include <tvm/runtime/registry.h>
+
+namespace tvm {
+namespace relax {
+
+// Build an ObjectStructInfoNode that only carries the given span.
+ObjectStructInfo::ObjectStructInfo(Span span) {
+  auto node = make_object<ObjectStructInfoNode>();
+  node->span = span;
+  data_ = std::move(node);
+}
+
+TVM_REGISTER_NODE_TYPE(ObjectStructInfoNode);
+
+// FFI constructor: forwards the span to the C++ constructor.
+TVM_REGISTER_GLOBAL("relax.ObjectStructInfo").set_body_typed([](Span span) {
+  return ObjectStructInfo(span);
+});
+
+// Repr: the node has no fields to show.
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<ObjectStructInfoNode>([](const ObjectRef& ref, ReprPrinter* p) {
+      p->stream << "ObjectStructInfo()";
+    });
+
+// Prim
+// Build a PrimStructInfoNode from the dtype of the primitive value.
+PrimStructInfo::PrimStructInfo(DataType dtype, Span span) {
+  auto node = make_object<PrimStructInfoNode>();
+  node->dtype = dtype;
+  node->span = span;
+  data_ = std::move(node);
+}
+
+TVM_REGISTER_NODE_TYPE(PrimStructInfoNode);
+
+// FFI constructor: forwards both arguments to the C++ constructor.
+TVM_REGISTER_GLOBAL("relax.PrimStructInfo").set_body_typed([](DataType dtype, Span span) {
+  return PrimStructInfo(dtype, span);
+});
+
+// Repr: PrimStructInfo(<dtype>).
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<PrimStructInfoNode>([](const ObjectRef& ref, ReprPrinter* p) {
+      const auto* node = static_cast<const PrimStructInfoNode*>(ref.get());
+      p->stream << "PrimStructInfo(" << node->dtype << ")";
+    });
+
+// Shape
+// Build from explicit symbolic values; ndim is taken from the number of values.
+ShapeStructInfo::ShapeStructInfo(Array<PrimExpr> values, Span span) {
+  // Integer immediates are cast to int64; every other value must already be int64.
+  auto to_int64 = [](PrimExpr value) -> PrimExpr {
+    if (!value->IsInstance<IntImmNode>()) {
+      ICHECK(value.dtype() == DataType::Int(64))
+          << "the value in ShapeStructInfo can only have dtype of int64";
+      return value;
+    }
+    return tvm::cast(DataType::Int(64), value);
+  };
+
+  auto node = make_object<ShapeStructInfoNode>();
+  node->ndim = static_cast<int>(values.size());
+  node->values = values.Map(to_int64);
+  node->span = span;
+  data_ = std::move(node);
+}
+
+// Build from the number of dimensions only; values stays undefined.
+ShapeStructInfo::ShapeStructInfo(int ndim, Span span) {
+  CHECK_GE(ndim, -1) << "ndim of ShapeStructInfo must be >= -1, but got " << ndim;
+  auto node = make_object<ShapeStructInfoNode>();
+  node->ndim = ndim;
+  node->span = span;
+  data_ = std::move(node);
+}
+
+TVM_REGISTER_NODE_TYPE(ShapeStructInfoNode);
+
+// FFI constructor: picks the overload depending on whether values is given.
+TVM_REGISTER_GLOBAL("relax.ShapeStructInfo")
+    .set_body_typed([](Optional<Array<PrimExpr>> values, int ndim, Span span) {
+      if (!values.defined()) {
+        return ShapeStructInfo(ndim, span);
+      }
+      // With explicit values, ndim is derived from them and must not be passed as well.
+      CHECK_EQ(ndim, kUnknownNDim) << "ValueError: Cannot both specify values and ndim";
+      return ShapeStructInfo(values.value(), span);
+    });
+
+// Repr: prints the values when present, otherwise only ndim.
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<ShapeStructInfoNode>([](const ObjectRef& ref, ReprPrinter* p) {
+      const auto* node = static_cast<const ShapeStructInfoNode*>(ref.get());
+      if (!node->values.defined()) {
+        p->stream << "ShapeStructInfo(ndim=" << node->ndim << ")";
+      } else {
+        p->stream << "ShapeStructInfo(" << node->values.value() << ")";
+      }
+    });
+
+// Tensor
+// Build from a normalized shape expression (ShapeExpr or Var) that already carries
+// ShapeStructInfo; ndim is copied from that struct info.
+TensorStructInfo::TensorStructInfo(Expr shape, DataType dtype, Span span) {
+  // Reject a null handle before anything dereferences it: MatchStructInfo reads
+  // shape->struct_info_, so this check has to come first to be effective.
+  ICHECK(shape.defined()) << "Must provide a shape in this constructor";
+  Optional<ShapeStructInfo> sinfo = MatchStructInfo<ShapeStructInfo>(shape);
+  ICHECK(sinfo) << "We expect shape to contain pre-set shape struct info";
+  ICHECK(shape->IsInstance<ShapeExprNode>() || shape->IsInstance<VarNode>())
+      << "We require shape to be normalized when constructing TensorStructInfo";
+
+  ObjectPtr<TensorStructInfoNode> n = make_object<TensorStructInfoNode>();
+  // assign ndim before shape is moved into the node
+  n->ndim = sinfo.get()->ndim;
+  // assign rest of the fields.
+  n->shape = std::move(shape);
+  n->dtype = dtype;
+  n->span = span;
+  data_ = std::move(n);
+}
+
+TensorStructInfo::TensorStructInfo(DataType dtype, int ndim, Span span) {
+ ObjectPtr<TensorStructInfoNode> n = make_object<TensorStructInfoNode>();
+ CHECK_GE(ndim, -1) << "ndim of TensorStructInfo must be >= -1, but got " <<
ndim;
+ n->ndim = ndim;
+ n->dtype = dtype;
+ n->span = span;
+ data_ = std::move(n);
+}
+
+TVM_REGISTER_NODE_TYPE(TensorStructInfoNode);
+
+TVM_REGISTER_GLOBAL("relax.TensorStructInfo")
+ .set_body_typed([](Optional<Expr> shape, DataType dtype, int ndim, Span
span) {
+ if (shape.defined()) {
+ CHECK_EQ(ndim, kUnknownNDim) << "ValueError: Cannot both specify shape
and ndim";
+ return TensorStructInfo(shape.value(), dtype, span);
+ } else {
+ return TensorStructInfo(dtype, ndim, span);
+ }
+ });
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+ .set_dispatch<TensorStructInfoNode>([](const ObjectRef& ref, ReprPrinter*
p) {
+ const auto* node = static_cast<const TensorStructInfoNode*>(ref.get());
+ if (node->shape.defined()) {
+ p->stream << "TensorStructInfo(" << node->shape.value() << ", " <<
node->dtype << ")";
+ } else {
+ p->stream << "TensorStructInfo(" << node->dtype << ", ndim=" <<
node->ndim << ")";
+ }
+ });
+
+// Tuple
+// Build a TupleStructInfoNode that takes ownership of the field struct infos.
+TupleStructInfo::TupleStructInfo(Array<StructInfo> fields, Span span) {
+  auto node = make_object<TupleStructInfoNode>();
+  node->fields = std::move(fields);
+  node->span = span;
+  data_ = std::move(node);
+}
+
+TVM_REGISTER_NODE_TYPE(TupleStructInfoNode);
+
+// FFI constructor: forwards both arguments to the C++ constructor.
+TVM_REGISTER_GLOBAL("relax.TupleStructInfo")
+    .set_body_typed([](Array<StructInfo> fields, Span span) {
+      return TupleStructInfo(fields, span);
+    });
+
+// Repr: TupleStructInfo(<fields>).
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<TupleStructInfoNode>([](const ObjectRef& ref, ReprPrinter* p) {
+      const auto* node = static_cast<const TupleStructInfoNode*>(ref.get());
+      p->stream << "TupleStructInfo(" << node->fields << ")";
+    });
+
+// Func
+// Build a function struct info with known parameters; derive_func stays undefined.
+FuncStructInfo::FuncStructInfo(Array<StructInfo> params, StructInfo ret, Span span) {
+  auto node = make_object<FuncStructInfoNode>();
+  node->params = std::move(params);
+  node->ret = std::move(ret);
+  node->span = span;
+  data_ = std::move(node);
+}
+
+// Opaque function described by a derivation function: params is left undefined and
+// ret is set to ObjectStructInfo().
+FuncStructInfo FuncStructInfo::OpaqueFunc(StructInfoDeriveFunc derive_func, Span span) {
+  auto node = make_object<FuncStructInfoNode>();
+  node->derive_func = std::move(derive_func);
+  node->ret = ObjectStructInfo();
+  node->span = span;
+  return FuncStructInfo(node);
+}
+
+// Opaque function described only by its return struct info: params and derive_func
+// are both left undefined.
+FuncStructInfo FuncStructInfo::OpaqueFunc(StructInfo ret, Span span) {
+  auto node = make_object<FuncStructInfoNode>();
+  node->ret = std::move(ret);
+  node->span = span;
+  return FuncStructInfo(node);
+}
+
+TVM_REGISTER_NODE_TYPE(FuncStructInfoNode);
+
+// FFI constructor for the non-opaque form.
+TVM_REGISTER_GLOBAL("relax.FuncStructInfo")
+    .set_body_typed([](Array<StructInfo> params, StructInfo ret, Span span) {
+      return FuncStructInfo(params, ret, span);
+    });
+
+// FFI constructor for the opaque forms: derive_func wins when given, otherwise ret
+// (falling back to ObjectStructInfo()) is used.
+TVM_REGISTER_GLOBAL("relax.FuncStructInfoOpaqueFunc")
+    .set_body_typed([](Optional<StructInfo> ret, Optional<StructInfoDeriveFunc> derive_func,
+                       Span span) {
+      if (!derive_func.defined()) {
+        return FuncStructInfo::OpaqueFunc(ret.value_or(ObjectStructInfo()), span);
+      }
+      // NOTE(review): the sibling registrations report "ValueError" conditions with
+      // CHECK_* rather than ICHECK -- confirm which macro is intended here.
+      ICHECK(!ret.defined()) << "ValueError: Cannot specify both ret and derive_func";
+      return FuncStructInfo::OpaqueFunc(derive_func.value(), span);
+    });
+
+// Repr: FuncStructInfo(<params>, <ret>); derive_func is not printed.
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<FuncStructInfoNode>([](const ObjectRef& ref, ReprPrinter* p) {
+      const auto* node = static_cast<const FuncStructInfoNode*>(ref.get());
+      p->stream << "FuncStructInfo(" << node->params << ", " << node->ret << ")";
+    });
+
+// Helper functions
+// TODO(unity-team): add UpdateStructInfo once analysis.cc is upstreamed
+
+// FFI accessor: returns the struct info stored on an expression via GetStructInfo.
+TVM_REGISTER_GLOBAL("ir.ExprStructInfo").set_body_typed([](Expr expr) -> StructInfo {
+  return GetStructInfo(expr);
+});
+
+} // namespace relax
+} // namespace tvm