This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new ff0b07ba6f [TIR] Add `is_vector` Method to DataType class and update usages across Codebase (#17443)
ff0b07ba6f is described below

commit ff0b07ba6f225128fb030ebb0f45704d44812f00
Author: Lei Wang <[email protected]>
AuthorDate: Sun Oct 6 21:54:13 2024 +0800

    [TIR] Add `is_vector` Method to DataType class and update usages across Codebase (#17443)
    
    * Refactor data_type.h and c_runtime_api.h
    
    This commit refactors the `data_type.h` and `c_runtime_api.h` files. It
    introduces a new function `is_vector()` in the `DataType` class to check
    whether a type is a vector type. Additionally, it adds a new constant
    `kTVMGridConstant` to the `TVMTypeCode` enum in `c_runtime_api.h` (this
    addition is reverted later in the same PR, so the final commit does not
    modify `c_runtime_api.h`). These changes improve code organization by
    replacing open-coded `lanes() > 1` checks across the codebase with the
    named `is_vector()` predicate.
    
    * revert kTVMGridConstant
    
    * lint fix
---
 include/tvm/runtime/data_type.h        | 2 ++
 include/tvm/topi/elemwise.h            | 2 +-
 src/target/llvm/codegen_llvm.cc        | 2 +-
 src/target/llvm/intrin_rule_hexagon.cc | 8 ++++----
 src/tir/analysis/verify_gpu_code.cc    | 8 ++++----
 5 files changed, 12 insertions(+), 10 deletions(-)

diff --git a/include/tvm/runtime/data_type.h b/include/tvm/runtime/data_type.h
index a330ccbbdf..c49fde1746 100644
--- a/include/tvm/runtime/data_type.h
+++ b/include/tvm/runtime/data_type.h
@@ -148,6 +148,8 @@ class DataType {
   bool is_fixed_length_vector() const { return 
static_cast<int16_t>(data_.lanes) > 1; }
   /*! \return Whether the type is a scalable vector. */
   bool is_scalable_vector() const { return static_cast<int16_t>(data_.lanes) < -1; }
+  /*! \return whether type is a vector type. */
+  bool is_vector() const { return lanes() > 1; }
   /*! \return whether type is a bool vector type. */
   bool is_vector_bool() const { return is_scalable_or_fixed_length_vector() && 
bits() == 1; }
   /*! \return whether type is a Void type. */
diff --git a/include/tvm/topi/elemwise.h b/include/tvm/topi/elemwise.h
index 132992c57d..806ddcb662 100644
--- a/include/tvm/topi/elemwise.h
+++ b/include/tvm/topi/elemwise.h
@@ -287,7 +287,7 @@ inline Tensor cast(const Tensor& x, DataType type, 
std::string name = "T_cast",
         if (expr.dtype().code() == type.code() && expr.dtype().bits() == 
type.bits()) {
           if (expr.dtype().lanes() == type.lanes()) {
             return expr;
-          } else if (expr.dtype().lanes() == 1 && type.lanes() > 1) {
+          } else if (expr.dtype().lanes() == 1 && type.is_vector()) {
             return tvm::tir::Broadcast(expr, type.lanes());
           }
         }
diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc
index e21436e556..3d6d3a9461 100644
--- a/src/target/llvm/codegen_llvm.cc
+++ b/src/target/llvm/codegen_llvm.cc
@@ -1737,7 +1737,7 @@ void CodeGenLLVM::BufferAccessHelper(
     if (const RampNode* ramp = last_index.as<RampNode>()) {
       PrimExpr offset = ramp->base + (ramp->stride * i);
       last_index_value = MakeValue(offset);
-    } else if (last_index.dtype().lanes() > 1) {
+    } else if (last_index.dtype().is_vector()) {
       if (i == 0) {
         cached_vector_index = MakeValue(last_index);
       }
diff --git a/src/target/llvm/intrin_rule_hexagon.cc b/src/target/llvm/intrin_rule_hexagon.cc
index 7c4b38c1d7..2661f2fa65 100644
--- a/src/target/llvm/intrin_rule_hexagon.cc
+++ b/src/target/llvm/intrin_rule_hexagon.cc
@@ -66,7 +66,7 @@ inline PrimExpr DispatchTVMQHLWrapperFp16(const PrimExpr& e) {
 
   // Enable QHL library for FP16 data type
   const PrimExpr& x = call->args[0];
-  if (x->dtype.is_float16() && x->dtype.lanes() > 1 && useqhl) {
+  if (x->dtype.is_float16() && x->dtype.is_vector() && useqhl) {
     return TVMExternCall(call, tvm_wrapper);
   }
 #endif
@@ -116,7 +116,7 @@ TVM_REGISTER_OP("tir.tanh")
       }
 
       // Enable QHL library for FP16 data type
-      if (x->dtype.is_float16() && x->dtype.lanes() > 1 && useqhl) {
+      if (x->dtype.is_float16() && x->dtype.is_vector() && useqhl) {
         std::string tvm_wrapper("tvm_vect_qhmath_hvx_tanh_ahf");
         return TVMExternCall(call, tvm_wrapper);
       }
@@ -152,7 +152,7 @@ TVM_REGISTER_OP("tir.tan").set_attr<FLowerIntrinsic>(
       }
 
       // Enable QHL library for FP16 data type
-      if (x->dtype.is_float16() && x->dtype.lanes() > 1 && useqhl) {
+      if (x->dtype.is_float16() && x->dtype.is_vector() && useqhl) {
         std::string tvm_wrapper("tvm_vect_qhmath_hvx_tan_ahf");
         return TVMExternCall(call, tvm_wrapper);
       }
@@ -191,7 +191,7 @@ TVM_REGISTER_OP("tir.sigmoid")
       const tir::Call new_call = tir::Call(call->dtype, call->op, new_args);
 
       // Enable QHL library for FP16 data type
-      if (x->dtype.is_float16() && x->dtype.lanes() > 1 && useqhl) {
+      if (x->dtype.is_float16() && x->dtype.is_vector() && useqhl) {
         std::string tvm_wrapper("tvm_vect_qhmath_hvx_sigmoid_ahf");
         return TVMExternCall(new_call.get(), tvm_wrapper);
       }
diff --git a/src/tir/analysis/verify_gpu_code.cc b/src/tir/analysis/verify_gpu_code.cc
index f012f8a1b3..8eda537579 100644
--- a/src/tir/analysis/verify_gpu_code.cc
+++ b/src/tir/analysis/verify_gpu_code.cc
@@ -71,7 +71,7 @@ class GPUCodeVerifier : public StmtExprVisitor {
       size_t size = static_cast<size_t>(op->ConstantAllocationSize());
       shared_memory_per_block_ += size * op->dtype.bytes() * op->dtype.lanes();
     }
-    if (op->dtype.lanes() > 1) {
+    if (op->dtype.is_vector()) {
       if (static_cast<size_t>(op->dtype.lanes() * op->dtype.bytes()) > 
max_vector_bytes_) {
         std::stringstream s;
         s << "Number of lanes (" << op->dtype.lanes() << ") times number of 
bytes ("
@@ -202,7 +202,7 @@ class GPUCodeVerifier : public StmtExprVisitor {
   }
 
   void VisitExpr_(const CastNode* op) {
-    if (op->dtype.lanes() > 1) {
+    if (op->dtype.is_vector()) {
       if (static_cast<size_t>(op->dtype.lanes() * op->dtype.bytes()) > 
max_vector_bytes_) {
         std::stringstream s;
         s << "Number of lanes (" << op->dtype.lanes() << ") times number of 
bytes ("
@@ -215,7 +215,7 @@ class GPUCodeVerifier : public StmtExprVisitor {
   }
 
   void VisitExpr_(const BufferLoadNode* op) {
-    if (op->dtype.lanes() > 1) {
+    if (op->dtype.is_vector()) {
       if (static_cast<size_t>(op->dtype.lanes() * op->dtype.bytes()) > 
max_vector_bytes_) {
         std::stringstream s;
         s << "Number of lanes (" << op->dtype.lanes() << ") times number of 
bytes ("
@@ -229,7 +229,7 @@ class GPUCodeVerifier : public StmtExprVisitor {
   }
 
   void VisitStmt_(const BufferStoreNode* op) {
-    if (op->value->dtype.lanes() > 1) {
+    if (op->value->dtype.is_vector()) {
       if (static_cast<size_t>(op->value->dtype.lanes() * 
op->value->dtype.bytes()) >
           max_vector_bytes_) {
         std::stringstream s;

Reply via email to