This is an automated email from the ASF dual-hosted git repository.
lmzheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new 074a07e CUDA device API & VerifyGPUCode pass update (#5898)
074a07e is described below
commit 074a07ede031762afb755f8b9273c0afb55c8b4a
Author: Chenfan <[email protected]>
AuthorDate: Thu Jun 25 13:44:39 2020 +0800
CUDA device API & VerifyGPUCode pass update (#5898)
* Add kMaxRegistersPerBlock device api for cuda
* Add vectorize check to verify_gpu_code
* Lint fix
* Cast fix
---
include/tvm/runtime/device_api.h | 3 +-
src/runtime/cuda/cuda_device_api.cc | 4 +++
src/runtime/metal/metal_device_api.mm | 4 ++-
src/runtime/opencl/opencl_device_api.cc | 2 ++
src/runtime/rocm/rocm_device_api.cc | 2 ++
src/runtime/vulkan/vulkan.cc | 2 ++
src/tir/analysis/verify_gpu_code.cc | 42 ++++++++++++++++------
.../unittest/test_tir_analysis_verify_gpu_code.py | 25 +++++++++++++
8 files changed, 72 insertions(+), 12 deletions(-)
diff --git a/include/tvm/runtime/device_api.h b/include/tvm/runtime/device_api.h
index 421811a..3cf5566 100644
--- a/include/tvm/runtime/device_api.h
+++ b/include/tvm/runtime/device_api.h
@@ -44,7 +44,8 @@ enum DeviceAttrKind : int {
kMaxClockRate = 6,
kMultiProcessorCount = 7,
kMaxThreadDimensions = 8,
- kGcnArch = 9
+ kMaxRegistersPerBlock = 9,
+ kGcnArch = 10
};
/*! \brief Number of bytes each allocation must align to */
diff --git a/src/runtime/cuda/cuda_device_api.cc b/src/runtime/cuda/cuda_device_api.cc
index a6d4a54..ccd8e91 100644
--- a/src/runtime/cuda/cuda_device_api.cc
+++ b/src/runtime/cuda/cuda_device_api.cc
@@ -92,6 +92,10 @@ class CUDADeviceAPI final : public DeviceAPI {
*rv = ss.str();
return;
}
+ case kMaxRegistersPerBlock: {
+ CUDA_CALL(cudaDeviceGetAttribute(&value, cudaDevAttrMaxRegistersPerBlock, ctx.device_id));
+ break;
+ }
case kGcnArch:
return;
}
diff --git a/src/runtime/metal/metal_device_api.mm b/src/runtime/metal/metal_device_api.mm
index 3bad2c3..a64f35c 100644
--- a/src/runtime/metal/metal_device_api.mm
+++ b/src/runtime/metal/metal_device_api.mm
@@ -64,7 +64,9 @@ void MetalWorkspace::GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* r
case kMaxThreadDimensions:
return;
case kExist:
- break;
+ return;
+ case kMaxRegistersPerBlock:
+ return;
case kGcnArch:
return;
}
diff --git a/src/runtime/opencl/opencl_device_api.cc b/src/runtime/opencl/opencl_device_api.cc
index 6d9835e..72d03fb 100644
--- a/src/runtime/opencl/opencl_device_api.cc
+++ b/src/runtime/opencl/opencl_device_api.cc
@@ -107,6 +107,8 @@ void OpenCLWorkspace::GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue*
*rv = ss.str();
break;
}
+ case kMaxRegistersPerBlock:
+ return;
case kGcnArch:
return;
}
diff --git a/src/runtime/rocm/rocm_device_api.cc b/src/runtime/rocm/rocm_device_api.cc
index 475c4fb..e3dbef5 100644
--- a/src/runtime/rocm/rocm_device_api.cc
+++ b/src/runtime/rocm/rocm_device_api.cc
@@ -102,6 +102,8 @@ class ROCMDeviceAPI final : public DeviceAPI {
*rv = ss.str();
return;
}
+ case kMaxRegistersPerBlock:
+ return;
case kGcnArch: {
hipDeviceProp_t prop;
ROCM_CALL(hipGetDeviceProperties(&prop, ctx.device_id));
diff --git a/src/runtime/vulkan/vulkan.cc b/src/runtime/vulkan/vulkan.cc
index 4481011..ade4ddc 100644
--- a/src/runtime/vulkan/vulkan.cc
+++ b/src/runtime/vulkan/vulkan.cc
@@ -413,6 +413,8 @@ void VulkanDeviceAPI::GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue*
*rv = ss.str();
break;
}
+ case kMaxRegistersPerBlock:
+ return;
case kGcnArch:
return;
}
diff --git a/src/tir/analysis/verify_gpu_code.cc b/src/tir/analysis/verify_gpu_code.cc
index 1fbae0f..9477e04 100644
--- a/src/tir/analysis/verify_gpu_code.cc
+++ b/src/tir/analysis/verify_gpu_code.cc
@@ -33,20 +33,22 @@
namespace tvm {
namespace tir {
-class GPUCodeVerifier : public StmtVisitor {
+class GPUCodeVerifier : public StmtExprVisitor {
public:
bool Verify(Stmt stmt, int64_t max_local_memory_per_block, int64_t max_shared_memory_per_block,
int64_t max_threads_per_block, int64_t max_thread_x, int64_t max_thread_y,
- int64_t max_thread_z) {
+ int64_t max_thread_z, int64_t max_vector_bytes) {
max_local_memory_per_block_ = static_cast<size_t>(max_local_memory_per_block);
max_shared_memory_per_block_ = static_cast<size_t>(max_shared_memory_per_block);
max_threads_per_block_ = static_cast<size_t>(max_threads_per_block);
max_thread_x_ = static_cast<size_t>(max_thread_x);
max_thread_y_ = static_cast<size_t>(max_thread_y);
max_thread_z_ = static_cast<size_t>(max_thread_z);
+ max_vector_bytes_ = static_cast<size_t>(max_vector_bytes);
Reset_();
+ // TODO(jcf94): Add support of detecting CUDA Misaligned Address error
this->VisitStmt(stmt);
return valid_;
@@ -62,6 +64,9 @@ class GPUCodeVerifier : public StmtVisitor {
size_t size = static_cast<size_t>(op->constant_allocation_size());
shared_memory_per_block_ += size * op->dtype.bytes() * op->dtype.lanes();
}
+ if (op->dtype.lanes() > 1) {
+ valid_ &= static_cast<size_t>(op->dtype.lanes() * op->dtype.bytes()) <= max_vector_bytes_;
+ }
}
void VisitStmt_(const AttrStmtNode* op) final {
@@ -129,6 +134,17 @@ class GPUCodeVerifier : public StmtVisitor {
}
}
+ void VisitExpr_(const LoadNode* op) {
+ // Currently not able to check out: If the index expression failed
+ // to be simplified to a RampNode
+ if (op->index->IsInstance<RampNode>()) {
+ if (op->dtype.lanes() > 1) {
+ valid_ &= static_cast<size_t>(op->dtype.lanes() * op->dtype.bytes()) <= max_vector_bytes_;
+ }
+ }
+ ExprVisitor::VisitExpr_(op);
+ }
+
private:
int nest_level_{0};
@@ -146,6 +162,7 @@ class GPUCodeVerifier : public StmtVisitor {
size_t max_shared_memory_per_block_;
size_t max_threads_per_block_;
size_t max_thread_x_, max_thread_y_, max_thread_z_;
+ size_t max_vector_bytes_;
bool valid_{true};
@@ -169,27 +186,32 @@ bool VerifyGPUCode(const PrimFunc& func, Map<String, PrimExpr> constraints) {
int64_t max_thread_x = INT64_MAX;
int64_t max_thread_y = INT64_MAX;
int64_t max_thread_z = INT64_MAX;
+ int64_t max_vector_bytes = INT64_MAX;
for (auto iter : constraints) {
const IntImmNode* val = iter.second.as<IntImmNode>();
- if (iter.first == "max_local_memory_per_block")
+ if (iter.first == "max_local_memory_per_block") {
max_local_memory_per_block = val->value;
- else if (iter.first == "max_shared_memory_per_block")
+ } else if (iter.first == "max_shared_memory_per_block") {
max_shared_memory_per_block = val->value;
- else if (iter.first == "max_threads_per_block")
+ } else if (iter.first == "max_threads_per_block") {
max_threads_per_block = val->value;
- else if (iter.first == "max_thread_x")
+ } else if (iter.first == "max_thread_x") {
max_thread_x = val->value;
- else if (iter.first == "max_thread_y")
+ } else if (iter.first == "max_thread_y") {
max_thread_y = val->value;
- else if (iter.first == "max_thread_z")
+ } else if (iter.first == "max_thread_z") {
max_thread_z = val->value;
- else
+ } else if (iter.first == "max_vector_bytes") {
+ max_vector_bytes = val->value;
+ } else {
LOG(FATAL) << "Invalid check item: " << iter.first;
+ }
}
return verifier.Verify(func->body, max_local_memory_per_block, max_shared_memory_per_block,
- max_threads_per_block, max_thread_x, max_thread_y, max_thread_z);
+ max_threads_per_block, max_thread_x, max_thread_y, max_thread_z,
+ max_vector_bytes);
}
TVM_REGISTER_GLOBAL("tir.analysis.verify_gpu_code").set_body_typed(VerifyGPUCode);
diff --git a/tests/python/unittest/test_tir_analysis_verify_gpu_code.py b/tests/python/unittest/test_tir_analysis_verify_gpu_code.py
index 11960ca..ece8402 100644
--- a/tests/python/unittest/test_tir_analysis_verify_gpu_code.py
+++ b/tests/python/unittest/test_tir_analysis_verify_gpu_code.py
@@ -208,6 +208,30 @@ def test_wrong_bind():
tvm.build(s, [A, B], target)
assert not valid[0]
+def test_vectorize():
+ N = 1024
+
+ A = te.placeholder((N, N), name='A')
+ B = te.compute((N, N), lambda i, j: A[i, j])
+
+ s = te.create_schedule([B.op])
+
+ i, j = s[B].op.axis
+
+ s[B].bind(i, te.thread_axis("blockIdx.x"))
+ jo, ji = s[B].split(j, factor=64)
+ s[B].bind(jo, te.thread_axis("threadIdx.x"))
+ s[B].vectorize(ji)
+
+ for target in ['opencl', 'cuda']:
+ if not tvm.context(target).exist:
+ continue
+
+ valid = [None]
+ with tvm.transform.PassContext(config={"tir.add_lower_pass": [
+ (2, get_verify_pass(valid, max_vector_bytes=16))]}):
+ tvm.lower(s, [A, B])
+ assert not valid[0]
if __name__ == "__main__":
test_local_memory()
@@ -215,3 +239,4 @@ if __name__ == "__main__":
test_num_thread()
test_multiple_kernels()
test_wrong_bind()
+ test_vectorize()