This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new c6415d1492 Canonicalize type annotation during construction of Var and SizeVar (#11443)
c6415d1492 is described below
commit c6415d14928d1e09f4bd3105c7a5ddf87f92166b
Author: Wuwei Lin <[email protected]>
AuthorDate: Mon May 30 04:53:29 2022 -0700
Canonicalize type annotation during construction of Var and SizeVar (#11443)
* Canonicalize type annotation during construction of Var and SizeVar
* Update tests/cpp/expr_test.cc
Co-authored-by: Junru Shao <[email protected]>
* lint
* fix
Co-authored-by: Junru Shao <[email protected]>
---
include/tvm/tir/op.h | 9 +++++++++
src/tir/ir/expr.cc | 3 +++
src/tir/op/op.cc | 4 ++++
tests/cpp/expr_test.cc | 11 +++++++++++
4 files changed, 27 insertions(+)
diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h
index 905c67f1c5..34935aec61 100644
--- a/include/tvm/tir/op.h
+++ b/include/tvm/tir/op.h
@@ -60,6 +60,15 @@ namespace tvm {
*/
TVM_DLL Type GetType(const PrimExpr& expr);
+/*!
+ * \brief Get the Type corresponding to a runtime::DataType (the dtype-only part of GetType).
+ * \param dtype The data type
+ * \return The result type
+ *
+ * \sa tvm/ir/type.h for discussion about the relation between Type and runtime::DataType.
+ */
+TVM_DLL Type GetTypeFromRuntimeDataType(const DataType& dtype);
+
/*!
* \brief Get the implied DataType for storing values with type during runtime.
*
diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc
index f4dbc238c1..7979c9f47a 100644
--- a/src/tir/ir/expr.cc
+++ b/src/tir/ir/expr.cc
@@ -65,6 +65,7 @@ namespace tir {
Var::Var(String name_hint, DataType dtype, Span span) {
auto n = make_object<VarNode>();
n->name_hint = std::move(name_hint);
+ n->type_annotation = GetTypeFromRuntimeDataType(dtype);
n->dtype = std::move(dtype);
n->span = std::move(span);
data_ = std::move(n);
@@ -99,6 +100,7 @@ Var Var::copy_with_dtype(DataType dtype) const {
} else {
new_ptr = make_object<VarNode>(*node);
}
+ new_ptr->type_annotation = GetTypeFromRuntimeDataType(dtype);
new_ptr->dtype = std::move(dtype);
return Var(new_ptr);
}
@@ -126,6 +128,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
SizeVar::SizeVar(String name_hint, DataType dtype, Span span) {
auto n = make_object<SizeVarNode>();
n->name_hint = std::move(name_hint);
+ n->type_annotation = GetTypeFromRuntimeDataType(dtype);
n->dtype = std::move(dtype);
n->span = std::move(span);
data_ = std::move(n);
diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc
index 696d82be72..73249921bf 100644
--- a/src/tir/op/op.cc
+++ b/src/tir/op/op.cc
@@ -73,6 +73,10 @@ Type GetType(const PrimExpr& expr) {
}
// Default: return the type indicated by the dtype.
runtime::DataType dtype = expr.dtype();
+ return GetTypeFromRuntimeDataType(dtype);
+}
+
+Type GetTypeFromRuntimeDataType(const DataType& dtype) {
if (dtype.is_void()) {
return VoidType();
}
diff --git a/tests/cpp/expr_test.cc b/tests/cpp/expr_test.cc
index 9c9ea756bb..f10d99eb1f 100644
--- a/tests/cpp/expr_test.cc
+++ b/tests/cpp/expr_test.cc
@@ -19,6 +19,7 @@
#include <dmlc/logging.h>
#include <gtest/gtest.h>
+#include <tvm/node/structural_equal.h>
#include <tvm/te/operation.h>
TEST(Expr, Basic) {
@@ -34,6 +35,16 @@ TEST(Expr, Basic) {
ICHECK(os.str() == "max(((x + 1) + 2), 100)");
}
+TEST(Expr, VarTypeAnnotation) {  // Var(name, DataType) must agree with Var(name, PrimType)
+ using namespace tvm;
+ using namespace tvm::tir;
+ Var x("x", DataType::Float(32));  // DataType ctor: type_annotation now derived from dtype
+ Var y("y", PrimType(DataType::Float(32)));  // Type ctor: explicit PrimType annotation
+ StructuralEqual checker;
+ ICHECK(checker(x->dtype, y->dtype));
+ ICHECK(checker(x->type_annotation, y->type_annotation));  // expected to fail if x's annotation were left unset
+}
+
TEST(ExprNodeRef, Basic) {
using namespace tvm;
using namespace tvm::tir;