This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new 8c31d0d Remove PrimExpr from String (#5311)
8c31d0d is described below
commit 8c31d0dd9fadec6d901915bf729728e3d11deffb
Author: Zhi <[email protected]>
AuthorDate: Sun Apr 12 09:12:23 2020 -0700
Remove PrimExpr from String (#5311)
---
include/tvm/ir/expr.h | 6 ------
src/ir/expr.cc | 3 ---
src/target/target.cc | 2 +-
src/tir/ir/stmt.cc | 43 +++++++++++++++++++++----------------
topi/include/topi/contrib/cublas.h | 4 ++--
topi/include/topi/contrib/rocblas.h | 2 +-
6 files changed, 28 insertions(+), 32 deletions(-)
diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h
index 4e0a301..859a134 100644
--- a/include/tvm/ir/expr.h
+++ b/include/tvm/ir/expr.h
@@ -108,12 +108,6 @@ class PrimExpr : public BaseExpr {
*/
TVM_DLL PrimExpr(float value); // NOLINT(*)
- /*!
- * \brief construct from runtime String.
- * \param value The value to be constructed.
- */
- TVM_DLL PrimExpr(runtime::String value); // NOLINT(*)
-
/*! \return the data type of this expression. */
DataType dtype() const {
return static_cast<const PrimExprNode*>(get())->dtype;
diff --git a/src/ir/expr.cc b/src/ir/expr.cc
index e08d832..7272213 100644
--- a/src/ir/expr.cc
+++ b/src/ir/expr.cc
@@ -40,9 +40,6 @@ PrimExpr::PrimExpr(int32_t value)
PrimExpr::PrimExpr(float value)
: PrimExpr(FloatImm(DataType::Float(32), value)) {}
-PrimExpr::PrimExpr(runtime::String value)
- : PrimExpr(tir::StringImmNode::make(value)) {}
-
PrimExpr PrimExpr::FromObject_(ObjectRef ref) {
using runtime::ObjectTypeChecker;
if (auto* ptr = ref.as<tir::IterVarNode>()) {
diff --git a/src/target/target.cc b/src/target/target.cc
index 61d5f6f..50856d6 100644
--- a/src/target/target.cc
+++ b/src/target/target.cc
@@ -137,7 +137,7 @@ Target CreateTarget(const std::string& target_name,
} else if (target_name == "hybrid") {
t->device_type = kDLCPU;
} else if (target_name == "hexagon") {
- t->keys_array.push_back(runtime::String("hexagon"));
+ t->keys_array.push_back("hexagon");
t->device_type = kDLHexagon;
} else {
LOG(ERROR) << "Unknown target name " << target_name;
diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc
index 64e7ef5..705fe7b 100644
--- a/src/tir/ir/stmt.cc
+++ b/src/tir/ir/stmt.cc
@@ -58,7 +58,6 @@ Stmt AttrStmtNode::make(ObjectRef node,
TVM_REGISTER_GLOBAL("tir.AttrStmt")
.set_body_typed(AttrStmtNode::make);
-
Stmt AssertStmtNode::make(PrimExpr condition, PrimExpr message, Stmt body) {
CHECK(condition.defined());
CHECK(message.dtype() == DataType::Int(32) ||
@@ -74,8 +73,14 @@ Stmt AssertStmtNode::make(PrimExpr condition, PrimExpr message, Stmt body) {
}
TVM_REGISTER_GLOBAL("tir.AssertStmt")
-.set_body_typed(AssertStmtNode::make);
-
+.set_body_typed([](PrimExpr condition, ObjectRef message, Stmt body) {
+ if (const auto* str = message.as<StringObj>()) {
+ auto msg = StringImmNode::make(str->data);
+ return AssertStmtNode::make(condition, msg, body);
+ } else {
+ return AssertStmtNode::make(condition, Downcast<PrimExpr>(message), body);
+ }
+});
Stmt ProducerConsumerNode::make(FunctionRef func, bool is_producer, Stmt body) {
CHECK(body.defined());
@@ -92,11 +97,11 @@ TVM_REGISTER_GLOBAL("tir.ProducerConsumer")
Stmt ForNode::make(Var loop_var,
- PrimExpr min,
- PrimExpr extent,
- ForType for_type,
- DeviceAPI device_api,
- Stmt body) {
+ PrimExpr min,
+ PrimExpr extent,
+ ForType for_type,
+ DeviceAPI device_api,
+ Stmt body) {
CHECK(min.defined());
CHECK(extent.defined());
CHECK(min.dtype().is_scalar());
@@ -119,11 +124,11 @@ TVM_REGISTER_GLOBAL("tir.For")
Var loop_var, PrimExpr min, PrimExpr extent,
int for_type, int device_api, Stmt body) {
return ForNode::make(loop_var,
- min,
- extent,
- static_cast<ForType>(for_type),
- static_cast<DeviceAPI>(device_api),
- body);
+ min,
+ extent,
+ static_cast<ForType>(for_type),
+ static_cast<DeviceAPI>(device_api),
+ body);
});
@@ -176,12 +181,12 @@ TVM_REGISTER_GLOBAL("tir.Provide")
Stmt AllocateNode::make(Var buffer_var,
- DataType dtype,
- Array<PrimExpr> extents,
- PrimExpr condition,
- Stmt body,
- PrimExpr new_expr,
- std::string free_function) {
+ DataType dtype,
+ Array<PrimExpr> extents,
+ PrimExpr condition,
+ Stmt body,
+ PrimExpr new_expr,
+ std::string free_function) {
for (size_t i = 0; i < extents.size(); ++i) {
CHECK(extents[i].defined());
CHECK(extents[i].dtype().is_scalar());
diff --git a/topi/include/topi/contrib/cublas.h b/topi/include/topi/contrib/cublas.h
index ee18dea..f2ed029 100644
--- a/topi/include/topi/contrib/cublas.h
+++ b/topi/include/topi/contrib/cublas.h
@@ -53,7 +53,7 @@ inline Tensor cublas_matmul(const Tensor& lhs,
{ { n, m } }, { lhs->dtype }, { lhs, rhs },
[&](Array<Buffer> ins, Array<Buffer> outs) {
return call_packed({
- runtime::String("tvm.contrib.cublas.matmul"),
+ StringImmNode::make("tvm.contrib.cublas.matmul"),
pack_buffer(ins[0]),
pack_buffer(ins[1]),
pack_buffer(outs[0]),
@@ -85,7 +85,7 @@ inline Tensor cublas_batch_matmul(const Tensor& lhs,
{ { b, n, m } }, { lhs->dtype }, { lhs, rhs },
[&](Array<Buffer> ins, Array<Buffer> outs) {
return call_packed({
- runtime::String("tvm.contrib.cublas.batch_matmul"),
+ StringImmNode::make("tvm.contrib.cublas.batch_matmul"),
pack_buffer(ins[0]),
pack_buffer(ins[1]),
pack_buffer(outs[0]),
diff --git a/topi/include/topi/contrib/rocblas.h b/topi/include/topi/contrib/rocblas.h
index 9fe1825..f0bf926 100644
--- a/topi/include/topi/contrib/rocblas.h
+++ b/topi/include/topi/contrib/rocblas.h
@@ -52,7 +52,7 @@ inline Tensor rocblas_matmul(const Tensor& lhs,
{ { n, m } }, { lhs->dtype }, { lhs, rhs },
[&](Array<Buffer> ins, Array<Buffer> outs) {
return call_packed({
- runtime::String("tvm.contrib.rocblas.matmul"),
+ StringImmNode::make("tvm.contrib.rocblas.matmul"),
pack_buffer(ins[0]),
pack_buffer(ins[1]),
pack_buffer(outs[0]),