This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new 25bcd1c Improve error messages for memory verifier and gpu memory verifier (#6281)
25bcd1c is described below
commit 25bcd1ceae299371a56b26358b224d07d34119d7
Author: Tristan Konolige <[email protected]>
AuthorDate: Fri Aug 14 20:06:24 2020 -0700
Improve error messages for memory verifier and gpu memory verifier (#6281)
* [FIX] Print exactly what issues the GPU memory verifier encountered.
* [FIX] Print exactly why memory verifier failed.
---
src/tir/analysis/verify_gpu_code.cc | 112 +++++++++++++++++++++++++++---------
src/tir/analysis/verify_memory.cc | 56 +++++++++---------
2 files changed, 115 insertions(+), 53 deletions(-)
diff --git a/src/tir/analysis/verify_gpu_code.cc b/src/tir/analysis/verify_gpu_code.cc
index cce0823..5ef755a 100644
--- a/src/tir/analysis/verify_gpu_code.cc
+++ b/src/tir/analysis/verify_gpu_code.cc
@@ -35,9 +35,10 @@ namespace tir {
class GPUCodeVerifier : public StmtExprVisitor {
public:
- bool Verify(Stmt stmt, int64_t max_local_memory_per_block, int64_t
max_shared_memory_per_block,
- int64_t max_threads_per_block, int64_t max_thread_x, int64_t
max_thread_y,
- int64_t max_thread_z, int64_t max_vthread, int64_t
max_vector_bytes) {
+ std::vector<String> Verify(Stmt stmt, int64_t max_local_memory_per_block,
+ int64_t max_shared_memory_per_block, int64_t
max_threads_per_block,
+ int64_t max_thread_x, int64_t max_thread_y,
int64_t max_thread_z,
+ int64_t max_vthread, int64_t max_vector_bytes) {
max_local_memory_per_block_ =
static_cast<size_t>(max_local_memory_per_block);
max_shared_memory_per_block_ =
static_cast<size_t>(max_shared_memory_per_block);
max_threads_per_block_ = static_cast<size_t>(max_threads_per_block);
@@ -52,7 +53,7 @@ class GPUCodeVerifier : public StmtExprVisitor {
// TODO(jcf94): Add support of detecting CUDA Misaligned Address error
this->VisitStmt(stmt);
- return valid_;
+ return errors_;
}
void VisitStmt_(const AllocateNode* op) final {
@@ -66,7 +67,13 @@ class GPUCodeVerifier : public StmtExprVisitor {
shared_memory_per_block_ += size * op->dtype.bytes() * op->dtype.lanes();
}
if (op->dtype.lanes() > 1) {
- valid_ &= static_cast<size_t>(op->dtype.lanes() * op->dtype.bytes()) <=
max_vector_bytes_;
+ if (static_cast<size_t>(op->dtype.lanes() * op->dtype.bytes()) >
max_vector_bytes_) {
+ std::stringstream s;
+ s << "Number of lanes (" << op->dtype.lanes() << ") times number of
bytes ("
+ << op->dtype.bytes() << ") for dtype " << op->dtype
+ << " is greater than the maximum number of vector bytes (" <<
max_vector_bytes_ << ")";
+ errors_.push_back(s.str());
+ }
}
}
@@ -98,27 +105,39 @@ class GPUCodeVerifier : public StmtExprVisitor {
visited_threads_.insert(name);
thread_per_block_ *= length;
+ auto err = [this](std::string id, size_t ext, size_t m) {
+ if (ext > m) {
+ std::stringstream s;
+ s << "Extent of " << id << " (" << ext << ") is greater than
maximum allowed (" << m
+ << ");";
+ errors_.push_back(s.str());
+ }
+ };
+
if (name == "threadIdx.x") {
- valid_ &= length <= max_thread_x_;
+ err("threadIdx.x", length, max_thread_x_);
thread_x_extent_ = length;
} else if (name == "threadIdx.y") {
- valid_ &= length <= max_thread_y_;
+ err("threadIdx.y", length, max_thread_y_);
thread_y_extent_ = length;
} else if (name == "threadIdx.z") {
- valid_ &= length <= max_thread_z_;
+ err("threadIdx.z", length, max_thread_z_);
thread_z_extent_ = length;
} else if (name == "vthread") {
- valid_ &= length <= max_vthread_;
+ err("vthread", length, max_vthread_);
}
} else {
// the thread should be bound to axes with the same length
- if (name == "threadIdx.x") {
- valid_ &= length == thread_x_extent_;
- } else if (name == "threadIdx.y") {
- valid_ &= length == thread_y_extent_;
- } else if (name == "threadIdx.z") {
- valid_ &= length == thread_z_extent_;
- }
+ auto err = [this, name](std::string id, size_t ext, size_t m) {
+ if (name == id && ext != m) {
+ std::stringstream s;
+ s << "Extent of " << id << " (" << ext << ") does not match the
bound " << m;
+ errors_.push_back(s.str());
+ }
+ };
+ err("threadIdx.x", length, thread_x_extent_);
+ err("threadIdx.y", length, thread_y_extent_);
+ err("threadIdx.z", length, thread_z_extent_);
}
}
@@ -128,10 +147,17 @@ class GPUCodeVerifier : public StmtExprVisitor {
if (nest_level_ == 0) {
// exit a kernel, check the validity
- valid_ &= thread_per_block_ <= max_threads_per_block_;
-
- valid_ &= local_memory_per_block_ <= max_local_memory_per_block_;
- valid_ &= shared_memory_per_block_ <= max_shared_memory_per_block_;
+ auto err = [this](std::string id, size_t num, size_t m) {
+ if (num > m) {
+ std::stringstream s;
+ s << "Used " << id << " (" << num << ") is greater than the
allowed maximum (" << m
+ << ")";
+ errors_.push_back(s.str());
+ }
+ };
+ err("threads per block", thread_per_block_, max_threads_per_block_);
+ err("local memory per block", local_memory_per_block_,
max_local_memory_per_block_);
+ err("shared memory per block", shared_memory_per_block_,
max_shared_memory_per_block_);
}
} else {
StmtVisitor::VisitStmt_(op);
@@ -143,7 +169,13 @@ class GPUCodeVerifier : public StmtExprVisitor {
const auto* extent = op->extent.as<IntImmNode>();
CHECK(extent);
- valid_ &= static_cast<size_t>(extent->value) <= max_vthread_;
+ size_t num_vthread = static_cast<size_t>(extent->value);
+ if (num_vthread > max_vthread_) {
+ std::stringstream s;
+ s << "Number of vthreads (" << num_vthread << ") is greater than the
allowed maximum ("
+ << max_vthread_ << ")";
+ errors_.push_back(s.str());
+ }
}
StmtVisitor::VisitStmt_(op);
@@ -151,15 +183,27 @@ class GPUCodeVerifier : public StmtExprVisitor {
void VisitExpr_(const LoadNode* op) {
if (op->dtype.lanes() > 1) {
- valid_ &= static_cast<size_t>(op->dtype.lanes() * op->dtype.bytes()) <=
max_vector_bytes_;
+ if (static_cast<size_t>(op->dtype.lanes() * op->dtype.bytes()) >
max_vector_bytes_) {
+ std::stringstream s;
+ s << "Number of lanes (" << op->dtype.lanes() << ") times number of
bytes ("
+ << op->dtype.bytes() << ") for dtype " << op->dtype
+ << " is greater than the maximum number of vector bytes (" <<
max_vector_bytes_ << ")";
+ errors_.push_back(s.str());
+ }
}
ExprVisitor::VisitExpr_(op);
}
void VisitStmt_(const StoreNode* op) {
if (op->index->dtype.lanes() > 1) {
- valid_ &= static_cast<size_t>(op->index->dtype.lanes() *
op->index->dtype.bytes()) <=
- max_vector_bytes_;
+ if (static_cast<size_t>(op->index->dtype.lanes() *
op->index->dtype.bytes()) >
+ max_vector_bytes_) {
+ std::stringstream s;
+ s << "Number of lanes (" << op->index->dtype.lanes() << ") times
number of bytes ("
+ << op->index->dtype.bytes() << ") for dtype " << op->index->dtype
+ << " is greater than the maximum number of vector bytes (" <<
max_vector_bytes_ << ")";
+ errors_.push_back(s.str());
+ }
}
StmtVisitor::VisitStmt_(op);
}
@@ -183,7 +227,7 @@ class GPUCodeVerifier : public StmtExprVisitor {
size_t max_thread_x_, max_thread_y_, max_thread_z_, max_vthread_;
size_t max_vector_bytes_;
- bool valid_{true};
+ std::vector<String> errors_;
void Reset_() {
visited_local_buffers_.clear();
@@ -196,7 +240,7 @@ class GPUCodeVerifier : public StmtExprVisitor {
}
};
-bool VerifyGPUCode(const PrimFunc& func, Map<String, PrimExpr> constraints) {
+std::vector<String> VerifyGPUCode_(const PrimFunc& func, Map<String, PrimExpr>
constraints) {
GPUCodeVerifier verifier;
int64_t max_local_memory_per_block = INT64_MAX;
@@ -236,6 +280,11 @@ bool VerifyGPUCode(const PrimFunc& func, Map<String, PrimExpr> constraints) {
max_vthread, max_vector_bytes);
}
+bool VerifyGPUCode(const PrimFunc& func, Map<String, PrimExpr> constraints) {
+ auto errs = VerifyGPUCode_(func, constraints);
+ return errs.size() == 0;
+}
+
TVM_REGISTER_GLOBAL("tir.analysis.verify_gpu_code").set_body_typed(VerifyGPUCode);
namespace transform {
@@ -245,7 +294,16 @@ Pass VerifyGPUCode(Map<String, PrimExpr> constraints) {
for (auto kv : mod->functions) {
if (auto* n = kv.second.as<PrimFuncNode>()) {
auto func = GetRef<PrimFunc>(n);
- CHECK(VerifyGPUCode(func, constraints)) << "RuntimeError: GPU
constraint violated" << func;
+ auto errs = VerifyGPUCode_(func, constraints);
+ if (errs.size() != 0) {
+ std::stringstream s;
+ for (auto& err : errs) {
+ s << " " << err << std::endl;
+ }
+ LOG(FATAL) << "RuntimeError: GPU constraint(s) violated:\n"
+ << s.str() << " In function\n"
+ << func;
+ }
}
}
return mod;
diff --git a/src/tir/analysis/verify_memory.cc b/src/tir/analysis/verify_memory.cc
index dfad549..64097e1 100644
--- a/src/tir/analysis/verify_memory.cc
+++ b/src/tir/analysis/verify_memory.cc
@@ -62,20 +62,14 @@ class MemoryAccessVerifier final : protected StmtExprVisitor {
}
/// Verification result
- bool Failed() const { return failure_; }
+ std::vector<String> Errors() const { return errs_; }
protected:
/// Visitor implementation
//@{
- void VisitExpr(const PrimExpr& n) final {
- if (Failed()) return;
- StmtExprVisitor::VisitExpr(n);
- }
+ void VisitExpr(const PrimExpr& n) final { StmtExprVisitor::VisitExpr(n); }
- void VisitStmt(const Stmt& n) final {
- if (Failed()) return;
- StmtExprVisitor::VisitStmt(n);
- }
+ void VisitStmt(const Stmt& n) final { StmtExprVisitor::VisitStmt(n); }
void VisitStmt_(const LetStmtNode* op) final {
// Book keep definitions
@@ -139,7 +133,11 @@ class MemoryAccessVerifier final : protected StmtExprVisitor {
if (!IsFromFunctionArgs(var.get())) return;
// The verification fails in this case.
- SetFailure();
+ std::stringstream s;
+ s << "Variable `" << var
+ << "` is directly accessed by host memory (it is not contained in a
thread environment or in "
+ "the function arguments.";
+ errs_.push_back(s.str());
}
/// Status getter/setter
@@ -147,7 +145,6 @@ class MemoryAccessVerifier final : protected StmtExprVisitor {
bool InThreadEnv() const { return in_thread_env_; }
void EnterThreadEnv() { in_thread_env_ = true; }
void ExitThreadEnv() { in_thread_env_ = false; }
- void SetFailure() { failure_ = true; }
//@}
/// Check if a given DLDeviceType/TVMDeviceExtType value denotes GPU device.
@@ -162,7 +159,7 @@ class MemoryAccessVerifier final : protected StmtExprVisitor {
/// Status of visitor
//@{
bool in_thread_env_{false};
- bool failure_{false}; ///< If the verification fails (i.e. has illegal
access)
+ std::vector<String> errs_;
//@}
tir::PrimFunc func_{nullptr}; ///< Function to be
verified.
int dev_type_{kDLCPU}; ///< Device type
@@ -171,7 +168,7 @@ class MemoryAccessVerifier final : protected StmtExprVisitor {
} // namespace
/// Interface of VerifyMemory pass
-bool VerifyMemory(const PrimFunc& func) {
+std::vector<String> VerifyMemory_(const PrimFunc& func) {
auto target = func->GetAttr<Target>(tvm::attr::kTarget);
CHECK(target.defined()) << "LowerWarpMemory: Require the target attribute";
@@ -179,30 +176,37 @@ bool VerifyMemory(const PrimFunc& func) {
CallingConv::kDefault) {
MemoryAccessVerifier v(func, target.value()->kind->device_type);
v.Run();
- return !v.Failed();
+ return v.Errors();
} else {
- return true;
+ return {};
}
}
+bool VerifyMemory(const PrimFunc& func) { return VerifyMemory_(func).size() ==
0; }
+
TVM_REGISTER_GLOBAL("tir.analysis.verify_memory").set_body_typed(VerifyMemory);
namespace transform {
Pass VerifyMemory() {
- auto pass_func =
- [=](IRModule mod, PassContext ctx) {
- for (auto kv : mod->functions) {
- if (auto* n = kv.second.as<PrimFuncNode>()) {
- auto func = GetRef<PrimFunc>(n);
- CHECK(VerifyMemory(func))
- << "RuntimeError: Direct host side access to device memory is
detected."
- << " Did you forget to bind?\n"
- << func;
+ auto pass_func = [=](IRModule mod, PassContext ctx) {
+ for (auto kv : mod->functions) {
+ if (auto* n = kv.second.as<PrimFuncNode>()) {
+ auto func = GetRef<PrimFunc>(n);
+ auto errs = VerifyMemory_(func);
+ if (errs.size() > 0) {
+ std::stringstream s;
+ for (auto& err : errs) {
+ s << " " << err << "\n";
}
+ LOG(FATAL) << "RuntimeError: Memory verification failed with the
following errors:\n"
+ << s.str() << " Did you forget to bind?\n"
+ << func;
}
- return mod;
- };
+ }
+ }
+ return mod;
+ };
return tvm::transform::CreateModulePass(pass_func, 0, "tir.VerifyMemory",
{});
}