This is an automated email from the ASF dual-hosted git repository. junrushao pushed a commit to branch unity in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push: new 83e1f07d82 [Unity][IR] First-class StructInfo (#13907) 83e1f07d82 is described below commit 83e1f07d82acd721df515e27ef12693074d79997 Author: Yuchen Jin <yuch...@cs.washington.edu> AuthorDate: Thu Feb 2 23:32:15 2023 -0500 [Unity][IR] First-class StructInfo (#13907) * [Unity][IR] First-class StructInfo Relax tracks structural information (such as tensor shape) via `StructInfo` about the values in Relax. * Fix rust build --------- Co-authored-by: Junru Shao <junrushao1...@gmail.com> --- CMakeLists.txt | 1 + include/tvm/relax/struct_info.h | 430 ++++++++++++++++++++++++++++++++++++++++ rust/tvm/src/ir/relay/mod.rs | 2 + src/relax/ir/struct_info.cc | 238 ++++++++++++++++++++++ 4 files changed, 671 insertions(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index e55b7174dc..88bf6472ce 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -289,6 +289,7 @@ tvm_file_glob(GLOB_RECURSE COMPILER_SRCS src/driver/*.cc src/support/*.cc src/script/*.cc + src/relax/ir/*.cc src/relax/backend/vm/*.cc ) diff --git a/include/tvm/relax/struct_info.h b/include/tvm/relax/struct_info.h new file mode 100644 index 0000000000..d21c8db86b --- /dev/null +++ b/include/tvm/relax/struct_info.h @@ -0,0 +1,430 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_RELAX_STRUCT_INFO_H_ +#define TVM_RELAX_STRUCT_INFO_H_ + +#include <tvm/ir/env_func.h> +#include <tvm/ir/source_map.h> +#include <tvm/node/node.h> +// #include <tvm/relax/block_builder.h> +#include <tvm/relax/expr.h> +#include <tvm/relax/type.h> + +namespace tvm { +namespace relax { + +/*! + * \brief Opaque object. + */ +class ObjectStructInfoNode : public StructInfoNode { + public: + void VisitAttrs(AttrVisitor* v) { v->Visit("span", &span); } + + bool SEqualReduce(const ObjectStructInfoNode* other, SEqualReducer equal) const { return true; } + + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(0); } + + static constexpr const char* _type_key = "relax.ObjectStructInfo"; + TVM_DECLARE_FINAL_OBJECT_INFO(ObjectStructInfoNode, StructInfoNode); +}; + +/*! + * \brief Managed reference to ObjectStructInfoNode. + * \sa ObjectStructInfoNode + */ +class ObjectStructInfo : public StructInfo { + public: + TVM_DLL ObjectStructInfo(Span span = Span()); + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ObjectStructInfo, StructInfo, ObjectStructInfoNode); +}; + +/*! + * \brief Primitive value. + */ +class PrimStructInfoNode : public StructInfoNode { + public: + /*! \brief Underlying data type of the primitive value */ + DataType dtype; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("dtype", &dtype); + v->Visit("span", &span); + } + + bool SEqualReduce(const PrimStructInfoNode* other, SEqualReducer equal) const { + return equal(dtype, other->dtype); + } + + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(dtype); } + + static constexpr const char* _type_key = "relax.PrimStructInfo"; + TVM_DECLARE_FINAL_OBJECT_INFO(PrimStructInfoNode, StructInfoNode); +}; + +/*! + * \brief Managed reference to PrimStructInfoNode. 
+ * \sa PrimStructInfoNode + */ +class PrimStructInfo : public StructInfo { + public: + TVM_DLL PrimStructInfo(DataType dtype, Span span = Span()); + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(PrimStructInfo, StructInfo, PrimStructInfoNode); +}; + +/*! + * \brief StructInfo of shape value. + */ +class ShapeStructInfoNode : public StructInfoNode { + public: + /*! \brief optionally stores the symbolic value patterns of the shape */ + Optional<Array<PrimExpr>> values; + /*! + * \brief The number of dimensions of the shape, can be unknown. + * \sa kUnknownNDim + */ + int ndim; + + /*! \return Whether the struct info contains unknown ndim. */ + bool IsUnknownNdim() const { return ndim == kUnknownNDim; } + + void VisitAttrs(AttrVisitor* v) { + v->Visit("values", &values); + v->Visit("ndim", &ndim); + v->Visit("span", &span); + } + + bool SEqualReduce(const ShapeStructInfoNode* other, SEqualReducer equal) const { + return equal(values, other->values) && equal(ndim, other->ndim); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(values); + hash_reduce(ndim); + } + + static constexpr const char* _type_key = "relax.ShapeStructInfo"; + TVM_DECLARE_FINAL_OBJECT_INFO(ShapeStructInfoNode, StructInfoNode); +}; + +/*! + * \brief Managed reference to ShapeStructInfoNode. + * \sa ShapeStructInfoNode + */ +class ShapeStructInfo : public StructInfo { + public: + /*! + * \brief Construction with known symbolic shape patterns + * \param values The symbolic shape values + * \param span The span of the AST. + */ + TVM_DLL ShapeStructInfo(Array<PrimExpr> values, Span span = Span()); + /*! + * \brief Construction with unknown symbolic shape patterns. + * \param ndim Number of dimensions -- can be kUnknownNDim + * \param span The span of the AST. + */ + TVM_DLL ShapeStructInfo(int ndim, Span span = Span()); + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ShapeStructInfo, StructInfo, ShapeStructInfoNode); +}; + +/*! + * \brief StructInfo of Tensor. 
+ */ +class TensorStructInfoNode : public StructInfoNode { + public: + /*! + * \brief optionally store the shape expression of the tensor. + * \note shape must be normalized: it can only be NullOpt or ShapeExpr or Var. + */ + Optional<Expr> shape; + /*! \brief The content data type, use void to denote the dtype is unknown. */ + DataType dtype; + /*! + * \brief The number of dimension of the tensor, can be unknown. + * \sa kUnknownNDim + */ + int ndim; + + /*! \return Whether the struct info contains unknown ndim. */ + bool IsUnknownNdim() const { return ndim == kUnknownNDim; } + + /*! \return Whether the struct info contains unknown dtype. */ + bool IsUnknownDtype() const { return dtype.is_void(); } + + void VisitAttrs(AttrVisitor* v) { + v->Visit("shape", &shape); + v->Visit("dtype", &dtype); + v->Visit("ndim", &ndim); + v->Visit("span", &span); + } + + bool SEqualReduce(const TensorStructInfoNode* other, SEqualReducer equal) const { + return equal(shape, other->shape) && equal(ndim, other->ndim) && equal(dtype, other->dtype); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(shape); + hash_reduce(dtype); + hash_reduce(ndim); + } + + static constexpr const char* _type_key = "relax.TensorStructInfo"; + TVM_DECLARE_FINAL_OBJECT_INFO(TensorStructInfoNode, StructInfoNode); +}; + +/*! + * \brief Managed reference to TensorStructInfoNode. + * \sa TensorStructInfoNode + */ +class TensorStructInfo : public StructInfo { + public: + /*! + * \brief Construction with a known shape expression. + * \param shape The shape of the tensor. + * \param dtype The data type of tensor's elements. + * \param span The span of the AST. + * + * \note shape must already be normalized. + */ + TVM_DLL TensorStructInfo(Expr shape, DataType dtype, Span span = Span()); + + /*! + * \brief Construction with an unknown shape expression. + * \param dtype The data type of tensor's elements. + * \param ndim The number of dimensions + * \param span The span of the AST. 
+ */ + TVM_DLL TensorStructInfo(DataType dtype, int ndim, Span span = Span()); + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TensorStructInfo, StructInfo, TensorStructInfoNode); +}; + +/*! + * \brief StructInfo of Tuple. + */ +class TupleStructInfoNode : public StructInfoNode { + public: + /*! \brief The struct info of tuple fields. */ + Array<StructInfo> fields; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("fields", &fields); + v->Visit("span", &span); + } + + bool SEqualReduce(const TupleStructInfoNode* other, SEqualReducer equal) const { + return equal(fields, other->fields); + } + + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(fields); } + + static constexpr const char* _type_key = "relax.TupleStructInfo"; + TVM_DECLARE_FINAL_OBJECT_INFO(TupleStructInfoNode, StructInfoNode); +}; + +/*! + * \brief Managed reference to TupleStructInfoNode. + * \sa TupleStructInfoNode + */ +class TupleStructInfo : public StructInfo { + public: + /*! + * \brief Constructor + * \param fields Struct info of tuple fields. + * \param span The span of the AST. + */ + TVM_DLL TupleStructInfo(Array<StructInfo> fields, Span span = Span()); + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TupleStructInfo, StructInfo, TupleStructInfoNode); +}; + +class BlockBuilder; + +/*! + * \brief custom-defined StructInfo derivation function. + * \param call The call expression to be derived. + * \param ctx The builder context. + * \return The derived struct info of the call. + */ +using StructInfoDeriveFunc = TypedEnvFunc<StructInfo(const Call& call, const BlockBuilder& ctx)>; + +/*! + * \brief Structure information about function. + * + * This data structure contains enough information for us to + * do best-effort structure information deduction. + */ +class FuncStructInfoNode : public StructInfoNode { + public: + /*! + * \brief The parameter struct info of the function. + * \note When params is NullOpt means the function can take arbitrary number of arguments. 
+ * We define such functions as opaque functions. + */ + Optional<Array<StructInfo>> params; + /*! + * \brief The struct info of the function's return value. + */ + StructInfo ret; + /*! + * \brief Derivation function of opaque functions that may take any number of parameters. + * \note When derive_func is not empty, then params should be NullOpt, + * ret should be ObjectStructInfo() + */ + Optional<StructInfoDeriveFunc> derive_func; + + /*! + * \return Whether the func struct info is opaque. + * \note We define a function as opaque if we have no constraints on params. + */ + bool IsOpaque() const { return !params.defined(); } + + void VisitAttrs(AttrVisitor* v) { + v->Visit("params", &params); + v->Visit("ret", &ret); + v->Visit("derive_func", &derive_func); + v->Visit("span", &span); + } + + bool SEqualReduce(const FuncStructInfoNode* other, SEqualReducer equal) const { + return equal.DefEqual(params, other->params) && equal(ret, other->ret) && + equal(derive_func, other->derive_func); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce.DefHash(params); + hash_reduce(ret); + hash_reduce(derive_func); + } + + static constexpr const char* _type_key = "relax.FuncStructInfo"; + TVM_DECLARE_FINAL_OBJECT_INFO(FuncStructInfoNode, StructInfoNode); +}; + +/*! + * \brief Managed reference to FuncStructInfoNode. + * \sa FuncStructInfoNode + */ +class FuncStructInfo : public StructInfo { + public: + /*! + * \brief Constructor from parameter struct info and return value struct info. + * \param params The struct info of function parameters. + * \param ret The return value struct info. + * \param span The span of the AST. + * + * \note If the ret contains variables(tir::Var and relax::Var), they must be deducible from + * params. If you are unsure, you can always erase ret to static. + */ + TVM_DLL FuncStructInfo(Array<StructInfo> params, StructInfo ret, Span span = Span()); + + /*! + * \brief Constructing an opaque function struct info using derive_func. 
+ * + * \param derive_func Derivation function. + * \param span The span of the AST. + * + * \return The FuncStructInfo for opaque packedfunc. + * \note Defaults to a derive func that always returns ObjectStructInfo if not specified. + */ + TVM_DLL static FuncStructInfo OpaqueFunc(StructInfoDeriveFunc derive_func, Span span = Span()); + + /*! + * \brief Construct an opaque function from the return struct info. + * + * \param ret The struct info of the return value. + * \param span The span of the AST. + * + * \return The FuncStructInfo for opaque packedfunc. + * \note Defaults to a derive func that always returns ObjectStructInfo if not specified. + */ + TVM_DLL static FuncStructInfo OpaqueFunc(StructInfo ret = ObjectStructInfo(), Span span = Span()); + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(FuncStructInfo, StructInfo, FuncStructInfoNode); +}; + +/*! + * \brief Match and check if expr has StructInfo T and return it. + * + * \param expr The input expression. + * \return The result of match. + * \tparam T the underlying structure info type + */ +template <typename T> +inline Optional<T> MatchStructInfo(const Expr& expr) { + using TNode = typename T::ContainerType; + if (const TNode* ptr = expr->struct_info_.as<TNode>()) { + return GetRef<T>(ptr); + } else { + return NullOpt; + } +} + +/*! + * \brief Get the structure info of a given expr and try to cast it as const T*. + * + * \param expr The input expression. + * \return The pointer. Returns nullptr if the type does not match + * \tparam T the underlying structure info type + */ +template <typename T> +inline const T* GetStructInfoAs(const Expr& expr) { + ICHECK(expr->struct_info_.defined()) + << "The struct_info is not populated, check if you have normalized the expr"; + return expr->struct_info_.as<T>(); +} + +/*! + * \brief Get the underlying structure info of expr. + * + * \param expr The input expression. + * \return underlying struct info. 
+ */ +inline StructInfo GetStructInfo(const Expr& expr) { + auto* ptr = expr->struct_info_.as<StructInfoNode>(); + ICHECK(ptr) << "The struct_info is not populated, check if you have normalized the expr"; + return GetRef<StructInfo>(ptr); +} + +/*! + * \brief Whether the expr has void struct info. + * + * \param expr The input expression. + * \return Whether the expr has void struct info. + */ +inline bool HasVoidStructInfo(const Expr& expr) { + auto* ptr = expr->struct_info_.as<TupleStructInfoNode>(); + return ptr != nullptr && ptr->fields.size() == 0; +} + +/*! + * \brief Update the struct info of an Expr. + * \param expr The Expr whose struct info to be updated. + * \param struct_info The struct_info assigned. + * \note We ensure idempotence, that is we can only update the struct_info of an Expr only + * if the original one is nullptr. + */ +TVM_DLL void UpdateStructInfo(Expr expr, StructInfo struct_info); + +} // namespace relax +} // namespace tvm +#endif // TVM_RELAX_STRUCT_INFO_H_ diff --git a/rust/tvm/src/ir/relay/mod.rs b/rust/tvm/src/ir/relay/mod.rs index abc25e89c4..08ce082c45 100644 --- a/rust/tvm/src/ir/relay/mod.rs +++ b/rust/tvm/src/ir/relay/mod.rs @@ -40,6 +40,7 @@ pub mod attrs; pub struct ExprNode { pub base: BaseExprNode, pub checked_type: Type, + pub struct_info: ObjectRef, pub virtual_device: ObjectRef, } @@ -48,6 +49,7 @@ impl ExprNode { ExprNode { base: BaseExprNode::base::<T>(span.clone()), checked_type: Type::null(), + struct_info: ObjectRef::null(), virtual_device: ObjectRef::null(), } } diff --git a/src/relax/ir/struct_info.cc b/src/relax/ir/struct_info.cc new file mode 100644 index 0000000000..88046ed81f --- /dev/null +++ b/src/relax/ir/struct_info.cc @@ -0,0 +1,238 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/ir/struct_info.cc + * \brief Relax struct info. + */ +#include <tvm/relax/struct_info.h> +#include <tvm/runtime/registry.h> + +namespace tvm { +namespace relax { + +ObjectStructInfo::ObjectStructInfo(Span span) { + ObjectPtr<ObjectStructInfoNode> n = make_object<ObjectStructInfoNode>(); + n->span = span; + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(ObjectStructInfoNode); + +TVM_REGISTER_GLOBAL("relax.ObjectStructInfo").set_body_typed([](Span span) { + return ObjectStructInfo(span); +}); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch<ObjectStructInfoNode>([](const ObjectRef& ref, ReprPrinter* p) { + p->stream << "ObjectStructInfo()"; + }); + +// Prim +PrimStructInfo::PrimStructInfo(DataType dtype, Span span) { + ObjectPtr<PrimStructInfoNode> n = make_object<PrimStructInfoNode>(); + n->dtype = dtype; + n->span = span; + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(PrimStructInfoNode); + +TVM_REGISTER_GLOBAL("relax.PrimStructInfo").set_body_typed([](DataType dtype, Span span) { + return PrimStructInfo(dtype, span); +}); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch<PrimStructInfoNode>([](const ObjectRef& ref, ReprPrinter* p) { + const auto* node = static_cast<const PrimStructInfoNode*>(ref.get()); + p->stream << "PrimStructInfo(" << node->dtype << ")"; + }); + +// Shape 
+ShapeStructInfo::ShapeStructInfo(Array<PrimExpr> values, Span span) { + ObjectPtr<ShapeStructInfoNode> n = make_object<ShapeStructInfoNode>(); + n->ndim = static_cast<int>(values.size()); + n->values = values.Map([](PrimExpr value) { + if (value->IsInstance<IntImmNode>()) { + return tvm::cast(DataType::Int(64), value); + } + ICHECK(value.dtype() == DataType::Int(64)) + << "the value in ShapeStructInfo can only have dtype of int64"; + return value; + }); + n->span = span; + data_ = std::move(n); +} + +ShapeStructInfo::ShapeStructInfo(int ndim, Span span) { + ObjectPtr<ShapeStructInfoNode> n = make_object<ShapeStructInfoNode>(); + CHECK_GE(ndim, -1) << "ndim of ShapeStructInfo must be >= -1, but got " << ndim; + n->ndim = ndim; + n->span = span; + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(ShapeStructInfoNode); + +TVM_REGISTER_GLOBAL("relax.ShapeStructInfo") + .set_body_typed([](Optional<Array<PrimExpr>> values, int ndim, Span span) { + if (values.defined()) { + CHECK_EQ(ndim, kUnknownNDim) << "ValueError: Cannot both specify values and ndim"; + return ShapeStructInfo(values.value(), span); + } else { + return ShapeStructInfo(ndim, span); + } + }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch<ShapeStructInfoNode>([](const ObjectRef& ref, ReprPrinter* p) { + const auto* node = static_cast<const ShapeStructInfoNode*>(ref.get()); + if (node->values.defined()) { + p->stream << "ShapeStructInfo(" << node->values.value() << ")"; + } else { + p->stream << "ShapeStructInfo(ndim=" << node->ndim << ")"; + } + }); + +// Tensor +TensorStructInfo::TensorStructInfo(Expr shape, DataType dtype, Span span) { + ObjectPtr<TensorStructInfoNode> n = make_object<TensorStructInfoNode>(); + // assign ndim before move + Optional<ShapeStructInfo> sinfo = MatchStructInfo<ShapeStructInfo>(shape); + ICHECK(sinfo) << "We expect shape to contain pre-set shape struct info"; + ICHECK(shape.defined()) << "Must provide a shape in this constructor"; + 
ICHECK(shape->IsInstance<ShapeExprNode>() || shape->IsInstance<VarNode>()) + << "We require shape to be normalized when constructing TensorStructInfo"; + n->ndim = sinfo.get()->ndim; + // assign rest of the fields. + n->shape = std::move(shape); + n->dtype = dtype; + n->span = span; + data_ = std::move(n); +} + +TensorStructInfo::TensorStructInfo(DataType dtype, int ndim, Span span) { + ObjectPtr<TensorStructInfoNode> n = make_object<TensorStructInfoNode>(); + CHECK_GE(ndim, -1) << "ndim of TensorStructInfo must be >= -1, but got " << ndim; + n->ndim = ndim; + n->dtype = dtype; + n->span = span; + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(TensorStructInfoNode); + +TVM_REGISTER_GLOBAL("relax.TensorStructInfo") + .set_body_typed([](Optional<Expr> shape, DataType dtype, int ndim, Span span) { + if (shape.defined()) { + CHECK_EQ(ndim, kUnknownNDim) << "ValueError: Cannot both specify shape and ndim"; + return TensorStructInfo(shape.value(), dtype, span); + } else { + return TensorStructInfo(dtype, ndim, span); + } + }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch<TensorStructInfoNode>([](const ObjectRef& ref, ReprPrinter* p) { + const auto* node = static_cast<const TensorStructInfoNode*>(ref.get()); + if (node->shape.defined()) { + p->stream << "TensorStructInfo(" << node->shape.value() << ", " << node->dtype << ")"; + } else { + p->stream << "TensorStructInfo(" << node->dtype << ", ndim=" << node->ndim << ")"; + } + }); + +// Tuple +TupleStructInfo::TupleStructInfo(Array<StructInfo> fields, Span span) { + ObjectPtr<TupleStructInfoNode> n = make_object<TupleStructInfoNode>(); + n->fields = std::move(fields); + n->span = span; + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(TupleStructInfoNode); + +TVM_REGISTER_GLOBAL("relax.TupleStructInfo") + .set_body_typed([](Array<StructInfo> fields, Span span) { + return TupleStructInfo(fields, span); + }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch<TupleStructInfoNode>([](const 
ObjectRef& ref, ReprPrinter* p) { + const auto* node = static_cast<const TupleStructInfoNode*>(ref.get()); + p->stream << "TupleStructInfo(" << node->fields << ")"; + }); + +// Func +FuncStructInfo::FuncStructInfo(Array<StructInfo> params, StructInfo ret, Span span) { + ObjectPtr<FuncStructInfoNode> n = make_object<FuncStructInfoNode>(); + n->params = std::move(params); + n->ret = std::move(ret); + n->span = span; + data_ = std::move(n); +} + +FuncStructInfo FuncStructInfo::OpaqueFunc(StructInfoDeriveFunc derive_func, Span span) { + ObjectPtr<FuncStructInfoNode> n = make_object<FuncStructInfoNode>(); + n->derive_func = std::move(derive_func); + n->ret = ObjectStructInfo(); + n->span = span; + return FuncStructInfo(n); +} + +FuncStructInfo FuncStructInfo::OpaqueFunc(StructInfo ret, Span span) { + ObjectPtr<FuncStructInfoNode> n = make_object<FuncStructInfoNode>(); + n->ret = std::move(ret); + n->span = span; + return FuncStructInfo(n); +} + +TVM_REGISTER_NODE_TYPE(FuncStructInfoNode); + +TVM_REGISTER_GLOBAL("relax.FuncStructInfo") + .set_body_typed([](Array<StructInfo> params, StructInfo ret, Span span) { + return FuncStructInfo(params, ret, span); + }); + +TVM_REGISTER_GLOBAL("relax.FuncStructInfoOpaqueFunc") + .set_body_typed([](Optional<StructInfo> ret, Optional<StructInfoDeriveFunc> derive_func, + Span span) { + if (derive_func.defined()) { + ICHECK(!ret.defined()) << "ValueError: Cannot specify both ret and derive_func"; + return FuncStructInfo::OpaqueFunc(derive_func.value(), span); + } else { + return FuncStructInfo::OpaqueFunc(ret.value_or(ObjectStructInfo()), span); + } + }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch<FuncStructInfoNode>([](const ObjectRef& ref, ReprPrinter* p) { + const auto* node = static_cast<const FuncStructInfoNode*>(ref.get()); + p->stream << "FuncStructInfo(" << node->params << ", " << node->ret << ")"; + }); + +// Helper functions +// TODO(unity-team): add UpdateStructInfo once analysis.cc is upstreamed + 
+TVM_REGISTER_GLOBAL("ir.ExprStructInfo").set_body_typed([](Expr expr) { + return GetStructInfo(expr); +}); + +} // namespace relax +} // namespace tvm