This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 0978ab656c [RUNTIME][METAL] Provide richer runtime when error happens
(#16713)
0978ab656c is described below
commit 0978ab656c0b76fe69e116f3254b55084996c5ba
Author: Tianqi Chen <[email protected]>
AuthorDate: Thu Mar 14 09:06:12 2024 -0400
[RUNTIME][METAL] Provide richer runtime when error happens (#16713)
This PR enhances the Metal runtime to include more detailed error messages
when an error happens.
---
src/runtime/metal/metal_common.h | 27 +++++++++++++++++++--------
src/runtime/metal/metal_device_api.mm | 4 ++--
src/runtime/metal/metal_module.mm | 16 +++++++++++++++-
3 files changed, 36 insertions(+), 11 deletions(-)
diff --git a/src/runtime/metal/metal_common.h b/src/runtime/metal/metal_common.h
index dc7b344800..e5339e6366 100644
--- a/src/runtime/metal/metal_common.h
+++ b/src/runtime/metal/metal_common.h
@@ -38,6 +38,7 @@
#include <memory>
#include <mutex>
#include <string>
+#include <utility>
#include <vector>
#include "../workspace_pool.h"
@@ -106,25 +107,35 @@ class AutoReleasePoolWrapper {
*/
class Stream {
public:
- explicit Stream(id<MTLDevice> device) : error_happened_(false) {
- queue_ = [device newCommandQueue];
- }
+ explicit Stream(id<MTLDevice> device) { queue_ = [device newCommandQueue]; }
~Stream() { [queue_ release]; }
- id<MTLCommandBuffer> GetCommandBuffer() {
+ id<MTLCommandBuffer> GetCommandBuffer(bool attach_error_callback = true) {
id<MTLCommandBuffer> cb = [queue_ commandBuffer];
[cb addCompletedHandler:^(id<MTLCommandBuffer> buffer) {
- if (buffer.status == MTLCommandBufferStatusError) SetErrorStatus();
+ if (buffer.status == MTLCommandBufferStatusError) {
+ ICHECK(buffer.error != nil);
+ this->SetError(buffer.error.localizedDescription.UTF8String);
+ }
}];
return cb;
}
- bool HasErrorHappened() { return error_happened_; }
+
+ void SetError(std::string error_description) {
+ error_happened_ = true;
+ error_description_ = std::move(error_description);
+ }
+
+ bool HasErrorHappened() const { return error_happened_; }
+
+ const std::string& ErrorDescription() const { return error_description_; }
private:
- void SetErrorStatus() { error_happened_ = true; }
// Queue
id<MTLCommandQueue> queue_;
// Check if error happened in one previous run
- bool error_happened_;
+ bool error_happened_{false};
+ // error description
+ std::string error_description_;
};
/*!
diff --git a/src/runtime/metal/metal_device_api.mm
b/src/runtime/metal/metal_device_api.mm
index 3b01bc65b1..37fb9dc347 100644
--- a/src/runtime/metal/metal_device_api.mm
+++ b/src/runtime/metal/metal_device_api.mm
@@ -222,7 +222,7 @@ void MetalWorkspace::CopyDataFromTo(const void* from,
size_t from_offset, void*
if (dev_from.device_type == kDLCPU) dev = dev_to;
Stream* s = this->CastStreamOrGetDefault(stream, dev.device_id);
if (s->HasErrorHappened()) {
- LOG(FATAL) << "Error! Some problems on GPU happaned! Cannot copy data to
current stream";
+ LOG(FATAL) << "GPUError: " << s->ErrorDescription();
}
id<MTLCommandBuffer> cb = s->GetCommandBuffer();
int from_dev_type = static_cast<int>(dev_from.device_type);
@@ -301,7 +301,7 @@ void MetalWorkspace::StreamSync(Device dev, TVMStreamHandle
stream) {
[cb commit];
[cb waitUntilCompleted];
if (s->HasErrorHappened()) {
- LOG(FATAL) << "Error! Some problems on GPU happaned!";
+ LOG(FATAL) << "GPUError: " << s->ErrorDescription();
}
};
}
diff --git a/src/runtime/metal/metal_module.mm
b/src/runtime/metal/metal_module.mm
index 01d1079426..16956ed611 100644
--- a/src/runtime/metal/metal_module.mm
+++ b/src/runtime/metal/metal_module.mm
@@ -194,7 +194,10 @@ class MetalWrappedFunc {
// obtain the stream
auto stream =
metal::MetalWorkspace::Global()->CastStreamOrGetDefault(t->stream[device_id],
device_id);
+
+ // skip launching so the error can be printed during sync
if (stream->HasErrorHappened()) return;
+
if (scache_[device_id] == nil) {
scache_[device_id] = m_->GetPipelineState(device_id, func_name_);
}
@@ -202,7 +205,8 @@ class MetalWrappedFunc {
int blockSize = wl.block_dim(0) * wl.block_dim(1) * wl.block_dim(2);
auto maxTotalThreadsPerThreadgroup =
scache_[device_id].maxTotalThreadsPerThreadgroup;
CHECK_LE(blockSize, maxTotalThreadsPerThreadgroup);
- id<MTLCommandBuffer> cb = stream->GetCommandBuffer();
+ // attach error message directly in this function
+ id<MTLCommandBuffer> cb = stream->GetCommandBuffer(/*
attach_error_callback= */ false);
id<MTLComputeCommandEncoder> encoder = [cb computeCommandEncoder];
[encoder setComputePipelineState:scache_[device_id]];
for (size_t i = 0; i < num_buffer_args_; ++i) {
@@ -219,6 +223,16 @@ class MetalWrappedFunc {
MTLSize dimBlock = MTLSizeMake(wl.block_dim(0), wl.block_dim(1),
wl.block_dim(2));
[encoder dispatchThreadgroups:dimGrid threadsPerThreadgroup:dimBlock];
[encoder endEncoding];
+ // attach error message with function name
+ [cb addCompletedHandler:^(id<MTLCommandBuffer> buffer) {
+ if (buffer.status == MTLCommandBufferStatusError) {
+ ICHECK(buffer.error != nil);
+ std::ostringstream os;
+ os << "GPUError happens after running " << func_name_ << ": "
+ << buffer.error.localizedDescription.UTF8String;
+ stream->SetError(os.str());
+ }
+ }];
[cb commit];
};
}