This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new ff0b07ba6f [TIR] Add `is_vector` Method to DataType class and update
usages across Codebase (#17443)
ff0b07ba6f is described below
commit ff0b07ba6f225128fb030ebb0f45704d44812f00
Author: Lei Wang <[email protected]>
AuthorDate: Sun Oct 6 21:54:13 2024 +0800
[TIR] Add `is_vector` Method to DataType class and update usages across
Codebase (#17443)
* Refactor data_type.h and c_runtime_api.h
This commit refactors the `data_type.h` and `c_runtime_api.h` files. It
introduces a new method `is_vector()` in the `DataType` class to check whether
a type is a vector type. Additionally, it adds a new constant `kTVMGridConstant`
to the `TVMTypeCode` enum in `c_runtime_api.h` (this addition is undone by the
"revert kTVMGridConstant" step below, so the final change does not touch
`c_runtime_api.h`). These changes improve code organization and provide better
support for vector types.
* revert kTVMGridConstant
* lint fix
---
include/tvm/runtime/data_type.h | 2 ++
include/tvm/topi/elemwise.h | 2 +-
src/target/llvm/codegen_llvm.cc | 2 +-
src/target/llvm/intrin_rule_hexagon.cc | 8 ++++----
src/tir/analysis/verify_gpu_code.cc | 8 ++++----
5 files changed, 12 insertions(+), 10 deletions(-)
diff --git a/include/tvm/runtime/data_type.h b/include/tvm/runtime/data_type.h
index a330ccbbdf..c49fde1746 100644
--- a/include/tvm/runtime/data_type.h
+++ b/include/tvm/runtime/data_type.h
@@ -148,6 +148,8 @@ class DataType {
bool is_fixed_length_vector() const { return
static_cast<int16_t>(data_.lanes) > 1; }
/*! \return Whether the type is a scalable vector. */
bool is_scalable_vector() const { return static_cast<int16_t>(data_.lanes) <
-1; }
+ /*! \return whether type is a vector type. */
+ bool is_vector() const { return lanes() > 1; }
/*! \return whether type is a bool vector type. */
bool is_vector_bool() const { return is_scalable_or_fixed_length_vector() &&
bits() == 1; }
/*! \return whether type is a Void type. */
diff --git a/include/tvm/topi/elemwise.h b/include/tvm/topi/elemwise.h
index 132992c57d..806ddcb662 100644
--- a/include/tvm/topi/elemwise.h
+++ b/include/tvm/topi/elemwise.h
@@ -287,7 +287,7 @@ inline Tensor cast(const Tensor& x, DataType type,
std::string name = "T_cast",
if (expr.dtype().code() == type.code() && expr.dtype().bits() ==
type.bits()) {
if (expr.dtype().lanes() == type.lanes()) {
return expr;
- } else if (expr.dtype().lanes() == 1 && type.lanes() > 1) {
+ } else if (expr.dtype().lanes() == 1 && type.is_vector()) {
return tvm::tir::Broadcast(expr, type.lanes());
}
}
diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc
index e21436e556..3d6d3a9461 100644
--- a/src/target/llvm/codegen_llvm.cc
+++ b/src/target/llvm/codegen_llvm.cc
@@ -1737,7 +1737,7 @@ void CodeGenLLVM::BufferAccessHelper(
if (const RampNode* ramp = last_index.as<RampNode>()) {
PrimExpr offset = ramp->base + (ramp->stride * i);
last_index_value = MakeValue(offset);
- } else if (last_index.dtype().lanes() > 1) {
+ } else if (last_index.dtype().is_vector()) {
if (i == 0) {
cached_vector_index = MakeValue(last_index);
}
diff --git a/src/target/llvm/intrin_rule_hexagon.cc
b/src/target/llvm/intrin_rule_hexagon.cc
index 7c4b38c1d7..2661f2fa65 100644
--- a/src/target/llvm/intrin_rule_hexagon.cc
+++ b/src/target/llvm/intrin_rule_hexagon.cc
@@ -66,7 +66,7 @@ inline PrimExpr DispatchTVMQHLWrapperFp16(const PrimExpr& e) {
// Enable QHL library for FP16 data type
const PrimExpr& x = call->args[0];
- if (x->dtype.is_float16() && x->dtype.lanes() > 1 && useqhl) {
+ if (x->dtype.is_float16() && x->dtype.is_vector() && useqhl) {
return TVMExternCall(call, tvm_wrapper);
}
#endif
@@ -116,7 +116,7 @@ TVM_REGISTER_OP("tir.tanh")
}
// Enable QHL library for FP16 data type
- if (x->dtype.is_float16() && x->dtype.lanes() > 1 && useqhl) {
+ if (x->dtype.is_float16() && x->dtype.is_vector() && useqhl) {
std::string tvm_wrapper("tvm_vect_qhmath_hvx_tanh_ahf");
return TVMExternCall(call, tvm_wrapper);
}
@@ -152,7 +152,7 @@ TVM_REGISTER_OP("tir.tan").set_attr<FLowerIntrinsic>(
}
// Enable QHL library for FP16 data type
- if (x->dtype.is_float16() && x->dtype.lanes() > 1 && useqhl) {
+ if (x->dtype.is_float16() && x->dtype.is_vector() && useqhl) {
std::string tvm_wrapper("tvm_vect_qhmath_hvx_tan_ahf");
return TVMExternCall(call, tvm_wrapper);
}
@@ -191,7 +191,7 @@ TVM_REGISTER_OP("tir.sigmoid")
const tir::Call new_call = tir::Call(call->dtype, call->op, new_args);
// Enable QHL library for FP16 data type
- if (x->dtype.is_float16() && x->dtype.lanes() > 1 && useqhl) {
+ if (x->dtype.is_float16() && x->dtype.is_vector() && useqhl) {
std::string tvm_wrapper("tvm_vect_qhmath_hvx_sigmoid_ahf");
return TVMExternCall(new_call.get(), tvm_wrapper);
}
diff --git a/src/tir/analysis/verify_gpu_code.cc
b/src/tir/analysis/verify_gpu_code.cc
index f012f8a1b3..8eda537579 100644
--- a/src/tir/analysis/verify_gpu_code.cc
+++ b/src/tir/analysis/verify_gpu_code.cc
@@ -71,7 +71,7 @@ class GPUCodeVerifier : public StmtExprVisitor {
size_t size = static_cast<size_t>(op->ConstantAllocationSize());
shared_memory_per_block_ += size * op->dtype.bytes() * op->dtype.lanes();
}
- if (op->dtype.lanes() > 1) {
+ if (op->dtype.is_vector()) {
if (static_cast<size_t>(op->dtype.lanes() * op->dtype.bytes()) >
max_vector_bytes_) {
std::stringstream s;
s << "Number of lanes (" << op->dtype.lanes() << ") times number of
bytes ("
@@ -202,7 +202,7 @@ class GPUCodeVerifier : public StmtExprVisitor {
}
void VisitExpr_(const CastNode* op) {
- if (op->dtype.lanes() > 1) {
+ if (op->dtype.is_vector()) {
if (static_cast<size_t>(op->dtype.lanes() * op->dtype.bytes()) >
max_vector_bytes_) {
std::stringstream s;
s << "Number of lanes (" << op->dtype.lanes() << ") times number of
bytes ("
@@ -215,7 +215,7 @@ class GPUCodeVerifier : public StmtExprVisitor {
}
void VisitExpr_(const BufferLoadNode* op) {
- if (op->dtype.lanes() > 1) {
+ if (op->dtype.is_vector()) {
if (static_cast<size_t>(op->dtype.lanes() * op->dtype.bytes()) >
max_vector_bytes_) {
std::stringstream s;
s << "Number of lanes (" << op->dtype.lanes() << ") times number of
bytes ("
@@ -229,7 +229,7 @@ class GPUCodeVerifier : public StmtExprVisitor {
}
void VisitStmt_(const BufferStoreNode* op) {
- if (op->value->dtype.lanes() > 1) {
+ if (op->value->dtype.is_vector()) {
if (static_cast<size_t>(op->value->dtype.lanes() *
op->value->dtype.bytes()) >
max_vector_bytes_) {
std::stringstream s;