This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new c6415d1492 Canonicalize type annotation during construction of Var and SizeVar (#11443)
c6415d1492 is described below

commit c6415d14928d1e09f4bd3105c7a5ddf87f92166b
Author: Wuwei Lin <[email protected]>
AuthorDate: Mon May 30 04:53:29 2022 -0700

    Canonicalize type annotation during construction of Var and SizeVar (#11443)
    
    * Canonicalize type annotation during construction of Var and SizeVar
    
    * Update tests/cpp/expr_test.cc
    
    Co-authored-by: Junru Shao <[email protected]>
    
    * lint
    
    * fix
    
    Co-authored-by: Junru Shao <[email protected]>
---
 include/tvm/tir/op.h   |  9 +++++++++
 src/tir/ir/expr.cc     |  3 +++
 src/tir/op/op.cc       |  4 ++++
 tests/cpp/expr_test.cc | 11 +++++++++++
 4 files changed, 27 insertions(+)

diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h
index 905c67f1c5..34935aec61 100644
--- a/include/tvm/tir/op.h
+++ b/include/tvm/tir/op.h
@@ -60,6 +60,15 @@ namespace tvm {
  */
 TVM_DLL Type GetType(const PrimExpr& expr);
 
+/*!
+ * \brief Get the type corresponding to DataType
+ * \param dtype The data type
+ * \return The result type
+ *
+ * \sa tvm/ir/type.h for discussion about the relation between Type and runtime::DataType.
+ */
+TVM_DLL Type GetTypeFromRuntimeDataType(const DataType& dtype);
+
 /*!
  * \brief Get the implied DataType for storing values with type during runtime.
  *
diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc
index f4dbc238c1..7979c9f47a 100644
--- a/src/tir/ir/expr.cc
+++ b/src/tir/ir/expr.cc
@@ -65,6 +65,7 @@ namespace tir {
 Var::Var(String name_hint, DataType dtype, Span span) {
   auto n = make_object<VarNode>();
   n->name_hint = std::move(name_hint);
+  n->type_annotation = GetTypeFromRuntimeDataType(dtype);
   n->dtype = std::move(dtype);
   n->span = std::move(span);
   data_ = std::move(n);
@@ -99,6 +100,7 @@ Var Var::copy_with_dtype(DataType dtype) const {
   } else {
     new_ptr = make_object<VarNode>(*node);
   }
+  new_ptr->type_annotation = GetTypeFromRuntimeDataType(dtype);
   new_ptr->dtype = std::move(dtype);
   return Var(new_ptr);
 }
@@ -126,6 +128,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
 SizeVar::SizeVar(String name_hint, DataType dtype, Span span) {
   auto n = make_object<SizeVarNode>();
   n->name_hint = std::move(name_hint);
+  n->type_annotation = GetTypeFromRuntimeDataType(dtype);
   n->dtype = std::move(dtype);
   n->span = std::move(span);
   data_ = std::move(n);
diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc
index 696d82be72..73249921bf 100644
--- a/src/tir/op/op.cc
+++ b/src/tir/op/op.cc
@@ -73,6 +73,10 @@ Type GetType(const PrimExpr& expr) {
   }
   // Default: return the type indicated by the dtype.
   runtime::DataType dtype = expr.dtype();
+  return GetTypeFromRuntimeDataType(dtype);
+}
+
+Type GetTypeFromRuntimeDataType(const DataType& dtype) {
   if (dtype.is_void()) {
     return VoidType();
   }
diff --git a/tests/cpp/expr_test.cc b/tests/cpp/expr_test.cc
index 9c9ea756bb..f10d99eb1f 100644
--- a/tests/cpp/expr_test.cc
+++ b/tests/cpp/expr_test.cc
@@ -19,6 +19,7 @@
 
 #include <dmlc/logging.h>
 #include <gtest/gtest.h>
+#include <tvm/node/structural_equal.h>
 #include <tvm/te/operation.h>
 
 TEST(Expr, Basic) {
@@ -34,6 +35,16 @@ TEST(Expr, Basic) {
   ICHECK(os.str() == "max(((x + 1) + 2), 100)");
 }
 
+TEST(Expr, VarTypeAnnotation) {
+  using namespace tvm;
+  using namespace tvm::tir;
+  Var x("x", DataType::Float(32));
+  Var y("y", PrimType(DataType::Float(32)));
+  StructuralEqual checker;
+  ICHECK(checker(x->dtype, y->dtype));
+  ICHECK(checker(x->type_annotation, y->type_annotation));
+}
+
 TEST(ExprNodeRef, Basic) {
   using namespace tvm;
   using namespace tvm::tir;

Reply via email to