This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 0978ab656c [RUNTIME][METAL] Provide richer runtime when error happens (#16713)
0978ab656c is described below

commit 0978ab656c0b76fe69e116f3254b55084996c5ba
Author: Tianqi Chen <[email protected]>
AuthorDate: Thu Mar 14 09:06:12 2024 -0400

    [RUNTIME][METAL] Provide richer runtime when error happens (#16713)
    
    This PR enhances the Metal runtime to include more detailed error messages
    when an error happens.
---
 src/runtime/metal/metal_common.h      | 27 +++++++++++++++++++--------
 src/runtime/metal/metal_device_api.mm |  4 ++--
 src/runtime/metal/metal_module.mm     | 16 +++++++++++++++-
 3 files changed, 36 insertions(+), 11 deletions(-)

diff --git a/src/runtime/metal/metal_common.h b/src/runtime/metal/metal_common.h
index dc7b344800..e5339e6366 100644
--- a/src/runtime/metal/metal_common.h
+++ b/src/runtime/metal/metal_common.h
@@ -38,6 +38,7 @@
 #include <memory>
 #include <mutex>
 #include <string>
+#include <utility>
 #include <vector>
 
 #include "../workspace_pool.h"
@@ -106,25 +107,35 @@ class AutoReleasePoolWrapper {
  */
 class Stream {
  public:
-  explicit Stream(id<MTLDevice> device) : error_happened_(false) {
-    queue_ = [device newCommandQueue];
-  }
+  explicit Stream(id<MTLDevice> device) { queue_ = [device newCommandQueue]; }
   ~Stream() { [queue_ release]; }
-  id<MTLCommandBuffer> GetCommandBuffer() {
+  id<MTLCommandBuffer> GetCommandBuffer(bool attach_error_callback = true) {
     id<MTLCommandBuffer> cb = [queue_ commandBuffer];
     [cb addCompletedHandler:^(id<MTLCommandBuffer> buffer) {
-      if (buffer.status == MTLCommandBufferStatusError) SetErrorStatus();
+      if (attach_error_callback && buffer.status == MTLCommandBufferStatusError) {
+        ICHECK(buffer.error != nil);
+        this->SetError(buffer.error.localizedDescription.UTF8String);
+      }
     }];
     return cb;
   }
-  bool HasErrorHappened() { return error_happened_; }
+
+  void SetError(std::string error_description) {
+    error_happened_ = true;
+    error_description_ = std::move(error_description);
+  }
+
+  bool HasErrorHappened() const { return error_happened_; }
+
+  const std::string& ErrorDescription() const { return error_description_; }

  private:
-  void SetErrorStatus() { error_happened_ = true; }
   // Queue
   id<MTLCommandQueue> queue_;
   // Check if error happened in one previous run
-  bool error_happened_;
+  bool error_happened_{false};
+  // description passed to the most recent SetError() call
+  std::string error_description_;
 };
 
 /*!
diff --git a/src/runtime/metal/metal_device_api.mm b/src/runtime/metal/metal_device_api.mm
index 3b01bc65b1..37fb9dc347 100644
--- a/src/runtime/metal/metal_device_api.mm
+++ b/src/runtime/metal/metal_device_api.mm
@@ -222,7 +222,7 @@ void MetalWorkspace::CopyDataFromTo(const void* from, size_t from_offset, void*
     if (dev_from.device_type == kDLCPU) dev = dev_to;
     Stream* s = this->CastStreamOrGetDefault(stream, dev.device_id);
     if (s->HasErrorHappened()) {
-      LOG(FATAL) << "Error! Some problems on GPU happaned! Cannot copy data to current stream";
+      LOG(FATAL) << "GPUError: " << s->ErrorDescription();
     }
     id<MTLCommandBuffer> cb = s->GetCommandBuffer();
     int from_dev_type = static_cast<int>(dev_from.device_type);
@@ -301,7 +301,7 @@ void MetalWorkspace::StreamSync(Device dev, TVMStreamHandle stream) {
     [cb commit];
     [cb waitUntilCompleted];
     if (s->HasErrorHappened()) {
-      LOG(FATAL) << "Error! Some problems on GPU happaned!";
+      LOG(FATAL) << "GPUError: " << s->ErrorDescription();
     }
   };
 }
diff --git a/src/runtime/metal/metal_module.mm b/src/runtime/metal/metal_module.mm
index 01d1079426..16956ed611 100644
--- a/src/runtime/metal/metal_module.mm
+++ b/src/runtime/metal/metal_module.mm
@@ -194,7 +194,10 @@ class MetalWrappedFunc {
       // obtain the stream
       auto stream =
           metal::MetalWorkspace::Global()->CastStreamOrGetDefault(t->stream[device_id], device_id);
+
+      // skip launching so the error can be printed during sync
       if (stream->HasErrorHappened()) return;
+
       if (scache_[device_id] == nil) {
         scache_[device_id] = m_->GetPipelineState(device_id, func_name_);
       }
@@ -202,7 +205,8 @@ class MetalWrappedFunc {
       int blockSize = wl.block_dim(0) * wl.block_dim(1) * wl.block_dim(2);
       auto maxTotalThreadsPerThreadgroup = scache_[device_id].maxTotalThreadsPerThreadgroup;
       CHECK_LE(blockSize, maxTotalThreadsPerThreadgroup);
-      id<MTLCommandBuffer> cb = stream->GetCommandBuffer();
+      // error callback is attached below so the message can include the function name
+      id<MTLCommandBuffer> cb = stream->GetCommandBuffer(/* attach_error_callback= */ false);
       id<MTLComputeCommandEncoder> encoder = [cb computeCommandEncoder];
       [encoder setComputePipelineState:scache_[device_id]];
       for (size_t i = 0; i < num_buffer_args_; ++i) {
@@ -219,6 +223,16 @@ class MetalWrappedFunc {
       MTLSize dimBlock = MTLSizeMake(wl.block_dim(0), wl.block_dim(1), wl.block_dim(2));
       [encoder dispatchThreadgroups:dimGrid threadsPerThreadgroup:dimBlock];
       [encoder endEncoding];
+      // attach error message with function name
+      [cb addCompletedHandler:^(id<MTLCommandBuffer> buffer) {
+        if (buffer.status == MTLCommandBufferStatusError) {
+          ICHECK(buffer.error != nil);
+          std::ostringstream os;
+          os << "GPUError happens after running " << func_name_ << ": "
+             << buffer.error.localizedDescription.UTF8String;
+          stream->SetError(os.str());
+        }
+      }];
       [cb commit];
     };
   }

Reply via email to