This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 8ada2b1 [TIR] Fix VerifyGPUCode for vectorized halfx8 store (#9420)
8ada2b1 is described below
commit 8ada2b10312ac91a5a166e83d6791414ef27cc16
Author: Wuwei Lin <[email protected]>
AuthorDate: Wed Nov 3 04:45:05 2021 -0400
[TIR] Fix VerifyGPUCode for vectorized halfx8 store (#9420)
---
src/tir/analysis/verify_gpu_code.cc | 8 +++---
.../unittest/test_tir_analysis_verify_gpu_code.py | 29 ++++++++++++++++++++++
2 files changed, 33 insertions(+), 4 deletions(-)
diff --git a/src/tir/analysis/verify_gpu_code.cc b/src/tir/analysis/verify_gpu_code.cc
index efffa90..dc1ed1c 100644
--- a/src/tir/analysis/verify_gpu_code.cc
+++ b/src/tir/analysis/verify_gpu_code.cc
@@ -198,12 +198,12 @@ class GPUCodeVerifier : public StmtExprVisitor {
}
void VisitStmt_(const StoreNode* op) {
-    if (op->index->dtype.lanes() > 1) {
-      if (static_cast<size_t>(op->index->dtype.lanes() * op->index->dtype.bytes()) >
+    if (op->value->dtype.lanes() > 1) {
+      if (static_cast<size_t>(op->value->dtype.lanes() * op->value->dtype.bytes()) >
           max_vector_bytes_) {
         std::stringstream s;
-        s << "Number of lanes (" << op->index->dtype.lanes() << ") times number of bytes ("
-          << op->index->dtype.bytes() << ") for dtype " << op->index->dtype
+        s << "Number of lanes (" << op->value->dtype.lanes() << ") times number of bytes ("
+          << op->value->dtype.bytes() << ") for dtype " << op->value->dtype
           << " is greater than the maximum number of vector bytes (" << max_vector_bytes_ << ")";
         errors_.push_back(s.str());
       }
diff --git a/tests/python/unittest/test_tir_analysis_verify_gpu_code.py b/tests/python/unittest/test_tir_analysis_verify_gpu_code.py
index 9e9563a..b7d78aa 100644
--- a/tests/python/unittest/test_tir_analysis_verify_gpu_code.py
+++ b/tests/python/unittest/test_tir_analysis_verify_gpu_code.py
@@ -344,6 +344,34 @@ def test_vectorize():
@tvm.testing.requires_gpu
+def test_vectorize_half():
+ N = 1024
+
+ A = te.placeholder((N, N), name="A", dtype="float16")
+ B = te.compute((N, N), lambda i, j: A[i, j])
+
+ s = te.create_schedule([B.op])
+
+ i, j = s[B].op.axis
+
+ s[B].bind(i, te.thread_axis("blockIdx.x"))
+ jo, ji = s[B].split(j, factor=8)
+ s[B].bind(jo, te.thread_axis("threadIdx.x"))
+ s[B].vectorize(ji)
+
+ for target in ["opencl", "cuda"]:
+ if not tvm.testing.device_enabled(target):
+ continue
+
+ valid = [None]
+ with tvm.transform.PassContext(
+        config={"tir.add_lower_pass": [(2, get_verify_pass(valid, max_vector_bytes=16))]}
+ ):
+ tvm.lower(s, [A, B])
+ assert valid[0]
+
+
[email protected]_gpu
def test_vthread():
N = 1024
@@ -409,5 +437,6 @@ if __name__ == "__main__":
test_multiple_kernels()
test_wrong_bind()
test_vectorize()
+ test_vectorize_half()
test_vthread()
test_redundant_kernels()