This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 8ada2b1  [TIR] Fix VerifyGPUCode for vectorized halfx8 store (#9420)
8ada2b1 is described below

commit 8ada2b10312ac91a5a166e83d6791414ef27cc16
Author: Wuwei Lin <[email protected]>
AuthorDate: Wed Nov 3 04:45:05 2021 -0400

    [TIR] Fix VerifyGPUCode for vectorized halfx8 store (#9420)
---
 src/tir/analysis/verify_gpu_code.cc                |  8 +++---
 .../unittest/test_tir_analysis_verify_gpu_code.py  | 29 ++++++++++++++++++++++
 2 files changed, 33 insertions(+), 4 deletions(-)

diff --git a/src/tir/analysis/verify_gpu_code.cc b/src/tir/analysis/verify_gpu_code.cc
index efffa90..dc1ed1c 100644
--- a/src/tir/analysis/verify_gpu_code.cc
+++ b/src/tir/analysis/verify_gpu_code.cc
@@ -198,12 +198,12 @@ class GPUCodeVerifier : public StmtExprVisitor {
   }
 
   void VisitStmt_(const StoreNode* op) {
-    if (op->index->dtype.lanes() > 1) {
-      if (static_cast<size_t>(op->index->dtype.lanes() * op->index->dtype.bytes()) >
+    if (op->value->dtype.lanes() > 1) {
+      if (static_cast<size_t>(op->value->dtype.lanes() * op->value->dtype.bytes()) >
           max_vector_bytes_) {
         std::stringstream s;
-        s << "Number of lanes (" << op->index->dtype.lanes() << ") times number of bytes ("
-          << op->index->dtype.bytes() << ") for dtype " << op->index->dtype
+        s << "Number of lanes (" << op->value->dtype.lanes() << ") times number of bytes ("
+          << op->value->dtype.bytes() << ") for dtype " << op->value->dtype
           << " is greater than the maximum number of vector bytes (" << max_vector_bytes_ << ")";
         errors_.push_back(s.str());
       }
diff --git a/tests/python/unittest/test_tir_analysis_verify_gpu_code.py b/tests/python/unittest/test_tir_analysis_verify_gpu_code.py
index 9e9563a..b7d78aa 100644
--- a/tests/python/unittest/test_tir_analysis_verify_gpu_code.py
+++ b/tests/python/unittest/test_tir_analysis_verify_gpu_code.py
@@ -344,6 +344,34 @@ def test_vectorize():
 
 
 @tvm.testing.requires_gpu
+def test_vectorize_half():
+    N = 1024
+
+    A = te.placeholder((N, N), name="A", dtype="float16")
+    B = te.compute((N, N), lambda i, j: A[i, j])
+
+    s = te.create_schedule([B.op])
+
+    i, j = s[B].op.axis
+
+    s[B].bind(i, te.thread_axis("blockIdx.x"))
+    jo, ji = s[B].split(j, factor=8)
+    s[B].bind(jo, te.thread_axis("threadIdx.x"))
+    s[B].vectorize(ji)
+
+    for target in ["opencl", "cuda"]:
+        if not tvm.testing.device_enabled(target):
+            continue
+
+        valid = [None]
+        with tvm.transform.PassContext(
+            config={"tir.add_lower_pass": [(2, get_verify_pass(valid, max_vector_bytes=16))]}
+        ):
+            tvm.lower(s, [A, B])
+        assert valid[0]
+
+
+@tvm.testing.requires_gpu
 def test_vthread():
     N = 1024
 
@@ -409,5 +437,6 @@ if __name__ == "__main__":
     test_multiple_kernels()
     test_wrong_bind()
     test_vectorize()
+    test_vectorize_half()
     test_vthread()
     test_redundant_kernels()

Reply via email to