This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 83e1f07d82 [Unity][IR] First-class StructInfo (#13907)
83e1f07d82 is described below

commit 83e1f07d82acd721df515e27ef12693074d79997
Author: Yuchen Jin <yuch...@cs.washington.edu>
AuthorDate: Thu Feb 2 23:32:15 2023 -0500

    [Unity][IR] First-class StructInfo (#13907)
    
    * [Unity][IR] First-class StructInfo
    
    Relax tracks structural information (such as tensor shape) about the values in Relax via `StructInfo`.
    
    * Fix rust build
    
    ---------
    
    Co-authored-by: Junru Shao <junrushao1...@gmail.com>
---
 CMakeLists.txt                  |   1 +
 include/tvm/relax/struct_info.h | 430 ++++++++++++++++++++++++++++++++++++++++
 rust/tvm/src/ir/relay/mod.rs    |   2 +
 src/relax/ir/struct_info.cc     | 238 ++++++++++++++++++++++
 4 files changed, 671 insertions(+)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index e55b7174dc..88bf6472ce 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -289,6 +289,7 @@ tvm_file_glob(GLOB_RECURSE COMPILER_SRCS
     src/driver/*.cc
     src/support/*.cc
     src/script/*.cc
+    src/relax/ir/*.cc
     src/relax/backend/vm/*.cc
     )
 
diff --git a/include/tvm/relax/struct_info.h b/include/tvm/relax/struct_info.h
new file mode 100644
index 0000000000..d21c8db86b
--- /dev/null
+++ b/include/tvm/relax/struct_info.h
@@ -0,0 +1,430 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_RELAX_STRUCT_INFO_H_
+#define TVM_RELAX_STRUCT_INFO_H_
+
+#include <tvm/ir/env_func.h>
+#include <tvm/ir/source_map.h>
+#include <tvm/node/node.h>
+// #include <tvm/relax/block_builder.h>
+#include <tvm/relax/expr.h>
+#include <tvm/relax/type.h>
+
+namespace tvm {
+namespace relax {
+
+/*!
+ * \brief Opaque object.
+ *
+ * StructInfo that records no structural information about a value: besides
+ * `span` it has no fields, so any two ObjectStructInfo compare structurally
+ * equal and hash identically (span takes no part in equality or hashing).
+ */
+class ObjectStructInfoNode : public StructInfoNode {
+ public:
+  void VisitAttrs(AttrVisitor* v) { v->Visit("span", &span); }
+
+  bool SEqualReduce(const ObjectStructInfoNode* other, SEqualReducer equal) const { return true; }  // no fields to compare
+
+  void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(0); }  // constant hash, consistent with SEqualReduce
+
+  static constexpr const char* _type_key = "relax.ObjectStructInfo";
+  TVM_DECLARE_FINAL_OBJECT_INFO(ObjectStructInfoNode, StructInfoNode);
+};
+
+/*!
+ * \brief Managed reference to ObjectStructInfoNode.
+ * \sa ObjectStructInfoNode
+ */
+class ObjectStructInfo : public StructInfo {
+ public:
+  TVM_DLL ObjectStructInfo(Span span = Span());  // span: source location of the AST node
+
+  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ObjectStructInfo, StructInfo, ObjectStructInfoNode);
+};
+
+/*!
+ * \brief Primitive value.
+ *
+ * Describes a value only by its data type. Structural equality and hashing
+ * consider `dtype` alone; `span` is visited for reflection but ignored there.
+ */
+class PrimStructInfoNode : public StructInfoNode {
+ public:
+  /*! \brief Underlying data type of the primitive value */
+  DataType dtype;
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("dtype", &dtype);
+    v->Visit("span", &span);
+  }
+
+  bool SEqualReduce(const PrimStructInfoNode* other, SEqualReducer equal) const {
+    return equal(dtype, other->dtype);
+  }
+
+  void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(dtype); }
+
+  static constexpr const char* _type_key = "relax.PrimStructInfo";
+  TVM_DECLARE_FINAL_OBJECT_INFO(PrimStructInfoNode, StructInfoNode);
+};
+
+/*!
+ * \brief Managed reference to PrimStructInfoNode.
+ * \sa PrimStructInfoNode
+ */
+class PrimStructInfo : public StructInfo {
+ public:
+  TVM_DLL PrimStructInfo(DataType dtype, Span span = Span());  // dtype: the value's data type; span: source location
+
+  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(PrimStructInfo, StructInfo, PrimStructInfoNode);
+};
+
+/*!
+ * \brief StructInfo of shape value.
+ *
+ * Either the symbolic dimension values are known (`values` is defined and
+ * `ndim` equals its length), or only the rank is recorded in `ndim`, which
+ * may itself be kUnknownNDim.
+ */
+class ShapeStructInfoNode : public StructInfoNode {
+ public:
+  /*! \brief optionally stores the symbolic value patterns of the shape */
+  Optional<Array<PrimExpr>> values;
+  /*!
+   * \brief The number of dimension of the shape, can be unknown.
+   * \sa kUnknownNDim
+   */
+  int ndim;
+
+  /*! \return Whether the struct info contains unknown ndim. */
+  bool IsUnknownNdim() const { return ndim == kUnknownNDim; }
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("values", &values);
+    v->Visit("ndim", &ndim);
+    v->Visit("span", &span);
+  }
+
+  bool SEqualReduce(const ShapeStructInfoNode* other, SEqualReducer equal) const {
+    return equal(values, other->values) && equal(ndim, other->ndim);
+  }
+
+  void SHashReduce(SHashReducer hash_reduce) const {
+    hash_reduce(values);
+    hash_reduce(ndim);
+  }
+
+  static constexpr const char* _type_key = "relax.ShapeStructInfo";
+  TVM_DECLARE_FINAL_OBJECT_INFO(ShapeStructInfoNode, StructInfoNode);
+};
+
+/*!
+ * \brief Managed reference to ShapeStructInfoNode.
+ * \sa ShapeStructInfoNode
+ */
+class ShapeStructInfo : public StructInfo {
+ public:
+  /*!
+   * \brief Construction with known symbolic shape patterns
+   * \param values The symbolic shape values
+   * \param span The span of the AST.
+   * \note ndim is set to values.size(). Integer constants are cast to int64;
+   *       any other value must already have dtype int64.
+   */
+  TVM_DLL ShapeStructInfo(Array<PrimExpr> values, Span span = Span());
+  /*!
+   * \brief Construction with unknown symbolic shape values and a given rank.
+   * \param ndim Number of dimensions -- can be kUnknownNDim; values below -1 are rejected.
+   * \param span The span of the AST.
+   */
+  TVM_DLL ShapeStructInfo(int ndim, Span span = Span());
+
+  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ShapeStructInfo, StructInfo, ShapeStructInfoNode);
+};
+
+/*!
+ * \brief StructInfo of Tensor.
+ *
+ * Either `shape` is defined (and `ndim` is taken from the shape's own
+ * ShapeStructInfo), or `shape` is NullOpt and only `dtype`/`ndim` are known.
+ */
+class TensorStructInfoNode : public StructInfoNode {
+ public:
+  /*!
+   * \brief optionally store the shape expression of the tensor.
+   * \note shape must be normalized: it can only be NullOpt or ShapeExpr or Var.
+   */
+  Optional<Expr> shape;
+  /*! \brief The content data type, use void to denote the dtype is unknown. */
+  DataType dtype;
+  /*!
+   * \brief The number of dimension of the tensor, can be unknown.
+   * \sa kUnknownNDim
+   */
+  int ndim;
+
+  /*! \return Whether the struct info contains unknown ndim. */
+  bool IsUnknownNdim() const { return ndim == kUnknownNDim; }
+
+  /*! \return Whether the struct info contains unknown dtype. */
+  bool IsUnknownDtype() const { return dtype.is_void(); }
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("shape", &shape);
+    v->Visit("dtype", &dtype);
+    v->Visit("ndim", &ndim);
+    v->Visit("span", &span);
+  }
+
+  bool SEqualReduce(const TensorStructInfoNode* other, SEqualReducer equal) const {
+    return equal(shape, other->shape) && equal(ndim, other->ndim) && equal(dtype, other->dtype);
+  }
+
+  void SHashReduce(SHashReducer hash_reduce) const {
+    hash_reduce(shape);
+    hash_reduce(dtype);
+    hash_reduce(ndim);
+  }
+
+  static constexpr const char* _type_key = "relax.TensorStructInfo";
+  TVM_DECLARE_FINAL_OBJECT_INFO(TensorStructInfoNode, StructInfoNode);
+};
+
+/*!
+ * \brief Managed reference to TensorStructInfoNode.
+ * \sa TensorStructInfoNode
+ */
+class TensorStructInfo : public StructInfo {
+ public:
+  /*!
+   * \brief Construction with a known shape expression.
+   * \param shape The shape of the tensor.
+   * \param dtype The data type of tensor's elements.
+   * \param span The span of the AST.
+   *
+   * \note shape must already be normalized (a ShapeExpr or a Var) and must
+   *       already carry a ShapeStructInfo; ndim is copied from that struct info.
+   */
+  TVM_DLL TensorStructInfo(Expr shape, DataType dtype, Span span = Span());
+
+  /*!
+   * \brief Construction with an unknown shape expression.
+   * \param dtype The data type of tensor's elements.
+   * \param ndim The number of dimensions; kUnknownNDim when unknown (values below -1 are rejected).
+   * \param span The span of the AST.
+   */
+  TVM_DLL TensorStructInfo(DataType dtype, int ndim, Span span = Span());
+
+  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TensorStructInfo, StructInfo, TensorStructInfoNode);
+};
+
+/*!
+ * \brief StructInfo of Tuple.
+ *
+ * Holds one StructInfo per tuple field. A tuple with zero fields is what
+ * HasVoidStructInfo treats as "void".
+ */
+class TupleStructInfoNode : public StructInfoNode {
+ public:
+  /*! \brief The struct info of tuple fields. */
+  Array<StructInfo> fields;
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("fields", &fields);
+    v->Visit("span", &span);
+  }
+
+  bool SEqualReduce(const TupleStructInfoNode* other, SEqualReducer equal) const {
+    return equal(fields, other->fields);
+  }
+
+  void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(fields); }
+
+  static constexpr const char* _type_key = "relax.TupleStructInfo";
+  TVM_DECLARE_FINAL_OBJECT_INFO(TupleStructInfoNode, StructInfoNode);
+};
+
+/*!
+ * \brief Managed reference to TupleStructInfoNode.
+ * \sa TupleStructInfoNode
+ */
+class TupleStructInfo : public StructInfo {
+ public:
+  /*!
+   * \brief Constructor
+   * \param fields Struct info of tuple fields, in tuple order.
+   * \param span The span of the AST.
+   */
+  TVM_DLL TupleStructInfo(Array<StructInfo> fields, Span span = Span());
+
+  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TupleStructInfo, StructInfo, TupleStructInfoNode);
+};
+
+class BlockBuilder;  // forward declaration only; <tvm/relax/block_builder.h> is not included by this header
+
+/*!
+ * \brief Custom-defined StructInfo derivation function.
+ *
+ * Stored in FuncStructInfoNode::derive_func for opaque functions, where the
+ * struct info of a call result is computed per call rather than from params.
+ *
+ * \param call The call expression to be derived.
+ * \param ctx The builder context.
+ * \return The derived struct info of the call.
+ */
+using StructInfoDeriveFunc = TypedEnvFunc<StructInfo(const Call& call, const BlockBuilder& ctx)>;
+
+/*!
+ * \brief Structure information about function.
+ *
+ * This data structure contains enough information for us to
+ * do best-effort structure information deduction.
+ */
+class FuncStructInfoNode : public StructInfoNode {
+ public:
+  /*!
+   * \brief The parameter struct info of the function.
+   * \note When params is NullOpt, the function can take an arbitrary number of arguments.
+   *       We call such functions opaque functions.
+   */
+  Optional<Array<StructInfo>> params;
+  /*!
+   * \brief The struct info of the function's return value.
+   */
+  StructInfo ret;
+  /*!
+   * \brief Derivation function of opaque functions that may take any number of parameters.
+   * \note When derive_func is not empty, then params should be NullOpt,
+   *       ret should be ObjectStructInfo()
+   */
+  Optional<StructInfoDeriveFunc> derive_func;
+
+  /*!
+   * \return Whether the func struct info is opaque.
+   * \note A function is opaque when we have no constraints on its params,
+   *       i.e. params is NullOpt.
+   */
+  bool IsOpaque() const { return !params.defined(); }
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("params", &params);
+    v->Visit("ret", &ret);
+    v->Visit("derive_func", &derive_func);
+    v->Visit("span", &span);
+  }
+
+  bool SEqualReduce(const FuncStructInfoNode* other, SEqualReducer equal) const {
+    return equal.DefEqual(params, other->params) && equal(ret, other->ret) &&  // NOTE(review): params use DefEqual, presumably so symbolic vars in params act as definition sites -- confirm
+           equal(derive_func, other->derive_func);
+  }
+
+  void SHashReduce(SHashReducer hash_reduce) const {
+    hash_reduce.DefHash(params);  // mirrors the DefEqual used for params in SEqualReduce
+    hash_reduce(ret);
+    hash_reduce(derive_func);
+  }
+
+  static constexpr const char* _type_key = "relax.FuncStructInfo";
+  TVM_DECLARE_FINAL_OBJECT_INFO(FuncStructInfoNode, StructInfoNode);
+};
+
+/*!
+ * \brief Managed reference to FuncStructInfoNode.
+ * \sa FuncStructInfoNode
+ */
+class FuncStructInfo : public StructInfo {
+ public:
+  /*!
+   * \brief Constructor from parameter struct info and return value struct info.
+   * \param params The struct info of function parameters.
+   * \param ret The return value struct info.
+   * \param span The span of the AST.
+   *
+   * \note If the ret contains variables (tir::Var and relax::Var), they must be deducible from
+   * params. If you are unsure, you can always erase ret to static.
+   */
+  TVM_DLL FuncStructInfo(Array<StructInfo> params, StructInfo ret, Span span = Span());
+
+  /*!
+   * \brief Construct an opaque function struct info using derive_func.
+   *
+   * \param derive_func Derivation function.
+   * \param span The span of the AST.
+   *
+   * \return The FuncStructInfo for opaque packedfunc.
+   * \note params is left as NullOpt and ret is set to ObjectStructInfo(); the
+   *       struct info of a call result is computed by derive_func.
+   */
+  TVM_DLL static FuncStructInfo OpaqueFunc(StructInfoDeriveFunc derive_func, Span span = Span());
+
+  /*!
+   * \brief Construct an opaque function from the struct info of its return value.
+   *
+   * \param ret The struct info of the return value.
+   * \param span The span of the AST.
+   *
+   * \return The FuncStructInfo for opaque packedfunc.
+   * \note ret defaults to ObjectStructInfo() when not specified; params and
+   *       derive_func are left unset.
+   */
+  TVM_DLL static FuncStructInfo OpaqueFunc(StructInfo ret = ObjectStructInfo(), Span span = Span());
+
+  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(FuncStructInfo, StructInfo, FuncStructInfoNode);
+};
+
+/*!
+ * \brief Check whether the struct info of \p expr is of kind T and, if so, return it.
+ *
+ * \param expr The input expression.
+ * \return The struct info downcast to T, or NullOpt otherwise.
+ * \tparam T the underlying structure info type
+ */
+template <typename T>
+inline Optional<T> MatchStructInfo(const Expr& expr) {
+  const auto* matched = expr->struct_info_.as<typename T::ContainerType>();
+  if (matched == nullptr) {
+    return NullOpt;
+  }
+  return GetRef<T>(matched);
+}
+
+/*!
+ * \brief Fetch the struct info of \p expr viewed as `const T*`.
+ *
+ * Fails (ICHECK) when the struct info has not been populated at all.
+ *
+ * \param expr The input expression.
+ * \return The pointer, or nullptr if the populated struct info is not a T.
+ * \tparam T the underlying structure info type
+ */
+template <typename T>
+inline const T* GetStructInfoAs(const Expr& expr) {
+  const auto& sinfo = expr->struct_info_;
+  ICHECK(sinfo.defined()) << "The struct_info is not populated, check if you have normalized the expr";
+  return sinfo.as<T>();
+}
+
+/*!
+ * \brief Return the struct info of \p expr as a StructInfo reference.
+ *
+ * Fails (ICHECK) when the struct info is missing or is not a StructInfo.
+ *
+ * \param expr The input expression.
+ * \return underlying struct info.
+ */
+inline StructInfo GetStructInfo(const Expr& expr) {
+  const StructInfoNode* sinfo_node = expr->struct_info_.as<StructInfoNode>();
+  ICHECK(sinfo_node != nullptr) << "The struct_info is not populated, check if you have normalized the expr";
+  return GetRef<StructInfo>(sinfo_node);
+}
+
+/*!
+ * \brief Whether \p expr carries "void" struct info, i.e. an empty TupleStructInfo.
+ *
+ * \param expr The input expression.
+ * \return true only if the struct info is a TupleStructInfo with zero fields.
+ */
+inline bool HasVoidStructInfo(const Expr& expr) {
+  if (const auto* tuple_sinfo = expr->struct_info_.as<TupleStructInfoNode>()) {
+    return tuple_sinfo->fields.empty();
+  }
+  return false;
+}
+
+/*!
+ * \brief Update the struct info of an Expr.
+ * \param expr The Expr whose struct info to be updated.
+ * \param struct_info The struct_info assigned.
+ * \note Intended to be idempotent: the struct_info of an Expr may only be set
+ *  when it is currently unset (nullptr).
+ * \note NOTE(review): only declared here; this patch provides no definition
+ *  (struct_info.cc leaves a TODO for it), so callers will fail to link until it lands.
+ */
+TVM_DLL void UpdateStructInfo(Expr expr, StructInfo struct_info);
+
+}  // namespace relax
+}  // namespace tvm
+#endif  // TVM_RELAX_STRUCT_INFO_H_
diff --git a/rust/tvm/src/ir/relay/mod.rs b/rust/tvm/src/ir/relay/mod.rs
index abc25e89c4..08ce082c45 100644
--- a/rust/tvm/src/ir/relay/mod.rs
+++ b/rust/tvm/src/ir/relay/mod.rs
@@ -40,6 +40,7 @@ pub mod attrs;
 pub struct ExprNode {
     pub base: BaseExprNode,
     pub checked_type: Type,
+    pub struct_info: ObjectRef, // NOTE(review): presumably mirrors the C++ ExprNode::struct_info_ field and must keep the C++ field order -- confirm
     pub virtual_device: ObjectRef,
 }
 
@@ -48,6 +49,7 @@ impl ExprNode {
         ExprNode {
             base: BaseExprNode::base::<T>(span.clone()),
             checked_type: Type::null(),
+            struct_info: ObjectRef::null(),
             virtual_device: ObjectRef::null(),
         }
     }
diff --git a/src/relax/ir/struct_info.cc b/src/relax/ir/struct_info.cc
new file mode 100644
index 0000000000..88046ed81f
--- /dev/null
+++ b/src/relax/ir/struct_info.cc
@@ -0,0 +1,238 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relax/ir/struct_info.cc
+ * \brief Relax struct info.
+ */
+#include <tvm/relax/struct_info.h>
+#include <tvm/runtime/registry.h>
+
+namespace tvm {
+namespace relax {
+
+ObjectStructInfo::ObjectStructInfo(Span span) {
+  auto node = make_object<ObjectStructInfoNode>();
+  node->span = std::move(span);
+  data_ = std::move(node);
+}
+
+TVM_REGISTER_NODE_TYPE(ObjectStructInfoNode);
+
+// FFI constructor exposed as "relax.ObjectStructInfo".
+TVM_REGISTER_GLOBAL("relax.ObjectStructInfo").set_body_typed([](Span span) -> ObjectStructInfo {
+  return ObjectStructInfo(std::move(span));
+});
+
+// ObjectStructInfo has no fields, so its printed form is constant.
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<ObjectStructInfoNode>([](const ObjectRef& /*ref*/, ReprPrinter* printer) {
+      printer->stream << "ObjectStructInfo()";
+    });
+
+// ----- PrimStructInfo -----
+PrimStructInfo::PrimStructInfo(DataType dtype, Span span) {
+  auto node = make_object<PrimStructInfoNode>();
+  node->dtype = dtype;
+  node->span = std::move(span);
+  data_ = std::move(node);
+}
+
+TVM_REGISTER_NODE_TYPE(PrimStructInfoNode);
+
+// FFI constructor exposed as "relax.PrimStructInfo".
+TVM_REGISTER_GLOBAL("relax.PrimStructInfo")
+    .set_body_typed([](DataType dtype, Span span) -> PrimStructInfo {
+      return PrimStructInfo(dtype, std::move(span));
+    });
+
+// Printed as PrimStructInfo(<dtype>).
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<PrimStructInfoNode>([](const ObjectRef& ref, ReprPrinter* printer) {
+      const auto* prim = static_cast<const PrimStructInfoNode*>(ref.get());
+      printer->stream << "PrimStructInfo(" << prim->dtype << ")";
+    });
+
+// Shape
+ShapeStructInfo::ShapeStructInfo(Array<PrimExpr> values, Span span) {
+  ObjectPtr<ShapeStructInfoNode> n = make_object<ShapeStructInfoNode>();
+  // With explicit values the rank is fully known.
+  n->ndim = static_cast<int>(values.size());
+  // Canonicalize every dimension to int64: integer constants of any width are
+  // cast, while symbolic expressions must already be int64.
+  n->values = values.Map([](PrimExpr value) {
+    if (value->IsInstance<IntImmNode>()) {
+      return tvm::cast(DataType::Int(64), value);
+    }
+    ICHECK(value.dtype() == DataType::Int(64))
+        << "the value in ShapeStructInfo can only have dtype of int64";
+    return value;
+  });
+  n->span = span;
+  data_ = std::move(n);
+}
+
+ShapeStructInfo::ShapeStructInfo(int ndim, Span span) {
+  ObjectPtr<ShapeStructInfoNode> n = make_object<ShapeStructInfoNode>();
+  // kUnknownNDim is the "unknown rank" sentinel; anything smaller is invalid.
+  // Compare against the named constant, as the FFI registration below does.
+  CHECK_GE(ndim, kUnknownNDim) << "ndim of ShapeStructInfo must be >= -1, but got " << ndim;
+  n->ndim = ndim;
+  n->span = span;
+  data_ = std::move(n);
+}
+
+TVM_REGISTER_NODE_TYPE(ShapeStructInfoNode);
+
+// FFI constructor: at most one of `values` and a known `ndim` may be given.
+TVM_REGISTER_GLOBAL("relax.ShapeStructInfo")
+    .set_body_typed([](Optional<Array<PrimExpr>> values, int ndim, Span span) {
+      if (values.defined()) {
+        CHECK_EQ(ndim, kUnknownNDim) << "ValueError: Cannot both specify values and ndim";
+        return ShapeStructInfo(values.value(), span);
+      } else {
+        return ShapeStructInfo(ndim, span);
+      }
+    });
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<ShapeStructInfoNode>([](const ObjectRef& ref, ReprPrinter* p) {
+      const auto* node = static_cast<const ShapeStructInfoNode*>(ref.get());
+      if (node->values.defined()) {
+        p->stream << "ShapeStructInfo(" << node->values.value() << ")";
+      } else {
+        p->stream << "ShapeStructInfo(ndim=" << node->ndim << ")";
+      }
+    });
+
+// Tensor
+TensorStructInfo::TensorStructInfo(Expr shape, DataType dtype, Span span) {
+  ObjectPtr<TensorStructInfoNode> n = make_object<TensorStructInfoNode>();
+  // MatchStructInfo dereferences `shape`, so a null shape must be rejected
+  // before it is called; otherwise this check could never report.
+  ICHECK(shape.defined()) << "Must provide a shape in this constructor";
+  // ndim comes from the shape's own struct info; read it before `shape` is moved.
+  Optional<ShapeStructInfo> sinfo = MatchStructInfo<ShapeStructInfo>(shape);
+  ICHECK(sinfo) << "We expect shape to contain pre-set shape struct info";
+  ICHECK(shape->IsInstance<ShapeExprNode>() || shape->IsInstance<VarNode>())
+      << "We require shape to be normalized when constructing TensorStructInfo";
+  n->ndim = sinfo.value()->ndim;
+  // assign rest of the fields.
+  n->shape = std::move(shape);
+  n->dtype = dtype;
+  n->span = span;
+  data_ = std::move(n);
+}
+
+TensorStructInfo::TensorStructInfo(DataType dtype, int ndim, Span span) {
+  ObjectPtr<TensorStructInfoNode> n = make_object<TensorStructInfoNode>();
+  // kUnknownNDim is the "unknown rank" sentinel; anything smaller is invalid.
+  CHECK_GE(ndim, kUnknownNDim) << "ndim of TensorStructInfo must be >= -1, but got " << ndim;
+  n->ndim = ndim;
+  n->dtype = dtype;
+  n->span = span;
+  data_ = std::move(n);
+}
+
+TVM_REGISTER_NODE_TYPE(TensorStructInfoNode);
+
+// FFI constructor: at most one of `shape` and a known `ndim` may be given.
+TVM_REGISTER_GLOBAL("relax.TensorStructInfo")
+    .set_body_typed([](Optional<Expr> shape, DataType dtype, int ndim, Span span) {
+      if (shape.defined()) {
+        CHECK_EQ(ndim, kUnknownNDim) << "ValueError: Cannot both specify shape and ndim";
+        return TensorStructInfo(shape.value(), dtype, span);
+      } else {
+        return TensorStructInfo(dtype, ndim, span);
+      }
+    });
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<TensorStructInfoNode>([](const ObjectRef& ref, ReprPrinter* p) {
+      const auto* node = static_cast<const TensorStructInfoNode*>(ref.get());
+      if (node->shape.defined()) {
+        p->stream << "TensorStructInfo(" << node->shape.value() << ", " << node->dtype << ")";
+      } else {
+        p->stream << "TensorStructInfo(" << node->dtype << ", ndim=" << node->ndim << ")";
+      }
+    });
+
+// ----- TupleStructInfo -----
+TupleStructInfo::TupleStructInfo(Array<StructInfo> fields, Span span) {
+  auto node = make_object<TupleStructInfoNode>();
+  node->fields = std::move(fields);
+  node->span = std::move(span);
+  data_ = std::move(node);
+}
+
+TVM_REGISTER_NODE_TYPE(TupleStructInfoNode);
+
+// FFI constructor exposed as "relax.TupleStructInfo".
+TVM_REGISTER_GLOBAL("relax.TupleStructInfo")
+    .set_body_typed([](Array<StructInfo> fields, Span span) -> TupleStructInfo {
+      return TupleStructInfo(std::move(fields), std::move(span));
+    });
+
+// Printed as TupleStructInfo(<fields>).
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<TupleStructInfoNode>([](const ObjectRef& ref, ReprPrinter* printer) {
+      const auto* tuple = static_cast<const TupleStructInfoNode*>(ref.get());
+      printer->stream << "TupleStructInfo(" << tuple->fields << ")";
+    });
+
+// Func
+FuncStructInfo::FuncStructInfo(Array<StructInfo> params, StructInfo ret, Span span) {
+  ObjectPtr<FuncStructInfoNode> n = make_object<FuncStructInfoNode>();
+  n->params = std::move(params);
+  n->ret = std::move(ret);
+  n->span = span;
+  data_ = std::move(n);
+}
+
+FuncStructInfo FuncStructInfo::OpaqueFunc(StructInfoDeriveFunc derive_func, Span span) {
+  // Opaque: params stays NullOpt. The per-call result is computed by
+  // derive_func, so the static ret is the most general ObjectStructInfo().
+  ObjectPtr<FuncStructInfoNode> n = make_object<FuncStructInfoNode>();
+  n->derive_func = std::move(derive_func);
+  n->ret = ObjectStructInfo();
+  n->span = span;
+  return FuncStructInfo(n);
+}
+
+FuncStructInfo FuncStructInfo::OpaqueFunc(StructInfo ret, Span span) {
+  // Opaque: params and derive_func stay unset; only the return struct info is known.
+  ObjectPtr<FuncStructInfoNode> n = make_object<FuncStructInfoNode>();
+  n->ret = std::move(ret);
+  n->span = span;
+  return FuncStructInfo(n);
+}
+
+TVM_REGISTER_NODE_TYPE(FuncStructInfoNode);
+
+TVM_REGISTER_GLOBAL("relax.FuncStructInfo")
+    .set_body_typed([](Array<StructInfo> params, StructInfo ret, Span span) {
+      return FuncStructInfo(params, ret, span);
+    });
+
+TVM_REGISTER_GLOBAL("relax.FuncStructInfoOpaqueFunc")
+    .set_body_typed([](Optional<StructInfo> ret, Optional<StructInfoDeriveFunc> derive_func,
+                       Span span) {
+      if (derive_func.defined()) {
+        // This validates caller-supplied arguments, so use CHECK like the other
+        // FFI constructors in this file; ICHECK is reserved for internal invariants.
+        CHECK(!ret.defined()) << "ValueError: Cannot specify both ret and derive_func";
+        return FuncStructInfo::OpaqueFunc(derive_func.value(), span);
+      } else {
+        return FuncStructInfo::OpaqueFunc(ret.value_or(ObjectStructInfo()), span);
+      }
+    });
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<FuncStructInfoNode>([](const ObjectRef& ref, ReprPrinter* p) {
+      const auto* node = static_cast<const FuncStructInfoNode*>(ref.get());
+      p->stream << "FuncStructInfo(" << node->params << ", " << node->ret << ")";
+    });
+
+// Helper functions
+// TODO(unity-team): add UpdateStructInfo once analysis.cc is upstreamed
+
+// FFI accessor "ir.ExprStructInfo": the populated struct info of an expression
+// (GetStructInfo fails if it is absent).
+TVM_REGISTER_GLOBAL("ir.ExprStructInfo").set_body_typed([](Expr expr) -> StructInfo {
+  StructInfo sinfo = GetStructInfo(expr);
+  return sinfo;
+});
+
+}  // namespace relax
+}  // namespace tvm

Reply via email to