This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 9dfa116c94 [Metal] Batched command dispatch and staging buffer pool (#18877)
9dfa116c94 is described below

commit 9dfa116c942180e90092505d679b422ff3073410
Author: Miti <[email protected]>
AuthorDate: Sat Mar 7 17:25:54 2026 +0100

    [Metal] Batched command dispatch and staging buffer pool (#18877)
---
 src/runtime/metal/metal_common.h      | 229 ++++++++++++++++++++++++++++++++--
 src/runtime/metal/metal_device_api.mm | 110 ++++++++++++----
 src/runtime/metal/metal_module.mm     |  19 +--
 3 files changed, 309 insertions(+), 49 deletions(-)

diff --git a/src/runtime/metal/metal_common.h b/src/runtime/metal/metal_common.h
index 8d72fac97a..cc538f84dc 100644
--- a/src/runtime/metal/metal_common.h
+++ b/src/runtime/metal/metal_common.h
@@ -103,13 +103,37 @@ class AutoReleasePoolWrapper {
 };
 
 /*!
- * \brief Structure for error handling in queues
+ * \brief Metal command stream with batched dispatch support.
+ *
+ * Compute dispatches are batched into a single command buffer via
+ * GetPendingComputeEncoder(). Blit operations (copies) are interleaved
+ * on the same command buffer via GetBlitEncoderOnPendingBuffer().
+ * The command buffer is committed when FlushCommandBuffer() is called.
+ *
+ * Must call FlushCommandBuffer() before:
+ * - GPU→CPU readback (need data in CPU memory)
+ * - Buffer deallocation (FreeDataSpace, setPurgeableState:Empty on
+ *   a buffer referenced by an uncommitted CB crashes Metal)
+ * - Stream sync (StreamSync / Synchronize)
  */
 class Stream {
  public:
   explicit Stream(id<MTLDevice> device) { queue_ = [device newCommandQueue]; }
-  ~Stream() { [queue_ release]; }
-  id<MTLCommandBuffer> GetCommandBuffer(std::string label = "", bool 
attach_error_callback = true) {
+  /*!
+   * \brief Commit any batched work and release the command queue.
+   *
+   * Assumes the stream is destroyed only during MetalWorkspace teardown
+   * (process exit or ReinitializeDefaultStreams), when no GPU work is
+   * expected to be in flight. Pending work is committed but NOT waited on.
+   *
+   * NOTE(review): the completion handlers registered by
+   * GetOrCreatePendingCommandBuffer() and Synchronize() capture `this`.
+   * If a command buffer committed here reports an error after the Stream
+   * has been deleted, the handler calls SetError() on a destroyed object.
+   * Waiting here would avoid that, provided blocking is acceptable at
+   * teardown (e.g. after fork) -- TODO confirm.
+   */
+  ~Stream() {
+    FlushCommandBuffer();
+    [queue_ release];
+  }
+
+  /*!
+   * \brief Get a standalone command buffer (for GPU→CPU readback only).
+   *
+   * Used when we need a separate command buffer that we can commit
+   * and waitUntilCompleted on independently.
+   */
+  id<MTLCommandBuffer> GetCommandBuffer(std::string label = "") {
     id<MTLCommandBuffer> cb = [queue_ commandBuffer];
     if (!label.empty()) {
       cb.label = [NSString stringWithUTF8String:label.c_str()];
@@ -123,6 +147,99 @@ class Stream {
     return cb;
   }
 
+  /*!
+   * \brief Return the compute encoder of the pending command buffer,
+   *        lazily opening one when no compute encoder is active.
+   *
+   * Consecutive kernel launches share one encoder and one command buffer;
+   * nothing is submitted until FlushCommandBuffer() commits the buffer.
+   * Blit work can be interleaved through GetBlitEncoderOnPendingBuffer(),
+   * which ends the active compute encoder first.
+   *
+   * Flush before:
+   * - reading GPU results on the CPU,
+   * - freeing a buffer that pending work may reference (FreeDataSpace),
+   * - a stream sync (StreamSync).
+   *
+   * \param kernel_name Optional kernel name remembered for GPU error reports.
+   * \return The active compute encoder. The stream ends it itself (in
+   *         EndPendingComputeEncoder()), so callers must not end it.
+   */
+  id<MTLComputeCommandEncoder> GetPendingComputeEncoder(const std::string& kernel_name = "") {
+    if (!pending_compute_encoder_) {
+      id<MTLCommandBuffer> batch = GetOrCreatePendingCommandBuffer();
+      pending_compute_encoder_ = [[batch computeCommandEncoder] retain];
+    }
+    if (kernel_name.size() != 0) last_dispatched_kernel_ = kernel_name;
+    ++profile.dispatches;
+    return pending_compute_encoder_;
+  }
+
+  /*!
+   * \brief Open a blit encoder on the pending command buffer.
+   *
+   * The active compute encoder (if any) is ended first, so the blit is
+   * encoded after every previously recorded dispatch and before any later
+   * one; Metal orders the encoders of a single command buffer. The caller
+   * must call [encoder endEncoding] on the returned encoder. The next
+   * GetPendingComputeEncoder() opens a fresh compute encoder on the same
+   * command buffer.
+   */
+  id<MTLBlitCommandEncoder> GetBlitEncoderOnPendingBuffer() {
+    EndPendingComputeEncoder();
+    id<MTLCommandBuffer> batch = GetOrCreatePendingCommandBuffer();
+    ++profile.blits;
+    id<MTLBlitCommandEncoder> blit = [batch blitCommandEncoder];
+    return blit;
+  }
+
+  /*!
+   * \brief Submit all batched work: end the active compute encoder and
+   *        commit the pending command buffer, without waiting for it.
+   *
+   * A no-op when no command buffer is pending.
+   */
+  void FlushCommandBuffer() {
+    EndPendingComputeEncoder();
+    if (pending_command_buffer_ == nil) return;
+    [pending_command_buffer_ commit];
+    [pending_command_buffer_ release];
+    pending_command_buffer_ = nil;
+    ++profile.flushes;
+  }
+
+  /*!
+   * \brief Commit any batched work, then block until an empty marker
+   *        command buffer committed after it has completed.
+   *
+   * Relies on the queue completing command buffers in submission order,
+   * so returning implies earlier work on this queue is done. A GPU error
+   * reported on the marker is recorded through SetError().
+   */
+  void Synchronize() {
+    FlushCommandBuffer();
+    id<MTLCommandBuffer> marker = [queue_ commandBuffer];
+    [marker addCompletedHandler:^(id<MTLCommandBuffer> done) {
+      if (done.status != MTLCommandBufferStatusError) return;
+      TVM_FFI_ICHECK(done.error != nil);
+      this->SetError(done.error.localizedDescription.UTF8String);
+    }];
+    [marker commit];
+    [marker waitUntilCompleted];
+    ++profile.syncs;
+  }
+
+  /*!
+   * \brief Whether a command buffer is open but not yet committed.
+   *
+   * Turns false as soon as FlushCommandBuffer() commits; it does not report
+   * committed work that the GPU may still be executing.
+   */
+  bool HasPendingWork() const { return pending_command_buffer_ != nil; }
+
+  /*!
+   * \brief Profiling counters for diagnosing dispatch/copy/sync overhead.
+   *
+   * Plain, unsynchronized counters; read and reset through the
+   * "metal.GetProfileCounters" / "metal.ResetProfileCounters" globals.
+   */
+  struct ProfileCounters {
+    size_t dispatches = 0;  // GetPendingComputeEncoder() calls
+    size_t flushes = 0;     // command buffers committed by FlushCommandBuffer()
+    size_t syncs = 0;       // Synchronize() calls
+    size_t blits = 0;       // encoders opened by GetBlitEncoderOnPendingBuffer()
+    size_t gpu_to_cpu = 0;  // CopyDataFromTo: Metal -> CPU
+    size_t cpu_to_gpu = 0;  // CopyDataFromTo: CPU -> Metal
+    size_t gpu_to_gpu = 0;  // CopyDataFromTo: Metal -> Metal
+    size_t free_syncs = 0;  // FreeDataSpace calls that triggered a sync
+
+    /*! \brief Zero every counter. */
+    void Reset() { *this = ProfileCounters(); }
+  };
+  ProfileCounters profile;
+
   void SetError(std::string error_description) {
     error_happened_ = true;
     error_description_ = std::move(error_description);
@@ -133,8 +250,42 @@ class Stream {
   const std::string& ErrorDescription() const { return error_description_; }
 
  private:
+  /*!
+   * \brief Return the pending command buffer, creating it on first use.
+   *
+   * Compute and blit encoders share this buffer. On creation it is labeled
+   * "TVMBatched" and given a completion handler that records GPU errors on
+   * the stream, prefixed with the most recently dispatched kernel name.
+   *
+   * NOTE(review): the completion handler may run on a thread other than the
+   * dispatching one, and reads last_dispatched_kernel_ while
+   * GetPendingComputeEncoder() may be writing it for the next batch (a data
+   * race on std::string). The name reported is also whatever was dispatched
+   * last at completion time, not necessarily a kernel in the failing buffer.
+   * Snapshotting the name when the buffer is committed would fix both.
+   */
+  id<MTLCommandBuffer> GetOrCreatePendingCommandBuffer() {
+    if (pending_command_buffer_ != nil) return pending_command_buffer_;
+    pending_command_buffer_ = [[queue_ commandBuffer] retain];
+    pending_command_buffer_.label = @"TVMBatched";
+    [pending_command_buffer_ addCompletedHandler:^(id<MTLCommandBuffer> done) {
+      if (done.status != MTLCommandBufferStatusError) return;
+      TVM_FFI_ICHECK(done.error != nil);
+      std::string detail = done.error.localizedDescription.UTF8String;
+      if (!this->last_dispatched_kernel_.empty()) {
+        detail = "GPUError after kernel " + this->last_dispatched_kernel_ + ": " + detail;
+      }
+      this->SetError(detail);
+    }];
+    return pending_command_buffer_;
+  }
+
+  /*!
+   * \brief Close the active compute encoder, if any, while leaving the
+   *        pending command buffer open (uncommitted).
+   */
+  void EndPendingComputeEncoder() {
+    if (pending_compute_encoder_ == nil) return;
+    [pending_compute_encoder_ endEncoding];
+    [pending_compute_encoder_ release];
+    pending_compute_encoder_ = nil;
+  }
+
   // Queue
   id<MTLCommandQueue> queue_;
+  // Pending command buffer (shared by compute and blit encoders)
+  id<MTLCommandBuffer> pending_command_buffer_ = nil;
+  // Active compute encoder on the pending command buffer (nil when 
paused/blit)
+  id<MTLComputeCommandEncoder> pending_compute_encoder_ = nil;
+  // Last dispatched kernel name (for error diagnostics)
+  std::string last_dispatched_kernel_;
   // Check if error happened in one previous run
   bool error_happened_{false};
   // error description
@@ -201,8 +352,67 @@ class MetalThreadEntry {
   Device device;
   /*! \brief The current stream */
   std::vector<TVMStreamHandle> stream;
-  /*! \brief The shared buffer used for copy. */
+  /*! \brief The shared buffer used for GPU→CPU readback. */
   std::vector<id<MTLBuffer>> temp_buffer_;
+  /*!
+   * \brief Pool of staging buffers for CPU→GPU copies that are inlined
+   * into the pending command buffer. Each inlined copy needs its own
+   * staging buffer because the GPU reads them asynchronously.
+   *
+   * Buffers are handed out sequentially and become reusable only through
+   * ResetIndex(). Call ResetIndex() only once every command buffer that may
+   * read the handed-out buffers has COMPLETED (e.g. after
+   * Stream::Synchronize()). FlushCommandBuffer() commits without waiting,
+   * so resetting right after a bare flush lets the CPU overwrite staging
+   * data the GPU has not copied yet.
+   */
+  struct StagingBufferPool {
+   public:
+    /*! \brief Maximum staging buffers handed out before NeedsFlush() is true.
+     * Prevents unbounded pool growth in workloads with many CPU→GPU copies
+     * between syncs. When this limit is reached, the caller must submit the
+     * pending work AND wait for it to complete before calling ResetIndex()
+     * and requesting more buffers. */
+    static constexpr size_t kMaxStagingBuffers = 64;
+
+    StagingBufferPool() = default;
+    // The pool releases the buffers it owns in its destructor, so the
+    // implicit (shallow) copy would release them twice. Copying is disabled;
+    // moves transfer ownership (std::vector needs them to relocate elements).
+    StagingBufferPool(const StagingBufferPool&) = delete;
+    StagingBufferPool& operator=(const StagingBufferPool&) = delete;
+    StagingBufferPool(StagingBufferPool&& other) noexcept : next_index_(other.next_index_) {
+      pool_.swap(other.pool_);
+      other.next_index_ = 0;
+    }
+    StagingBufferPool& operator=(StagingBufferPool&& other) noexcept {
+      if (this != &other) {
+        ReleaseAll();
+        pool_.swap(other.pool_);
+        next_index_ = other.next_index_;
+        other.next_index_ = 0;
+      }
+      return *this;
+    }
+    ~StagingBufferPool() { ReleaseAll(); }
+
+    /*!
+     * \brief Hand out the next staging buffer with room for \p nbytes,
+     *        (re)allocating the slot when it is empty or too small.
+     * \param dev Metal device used for allocation.
+     * \param nbytes Required capacity in bytes.
+     * \return A shared-storage buffer owned by the pool.
+     */
+    id<MTLBuffer> GetOrCreate(id<MTLDevice> dev, size_t nbytes) {
+      if (next_index_ < pool_.size() && pool_[next_index_].buffer != nil &&
+          pool_[next_index_].size >= nbytes) {
+        return pool_[next_index_++].buffer;
+      }
+      if (next_index_ >= pool_.size()) {
+        pool_.push_back(Entry());
+      }
+      Entry& slot = pool_[next_index_];
+      if (slot.buffer != nil) {
+        // Too small: drop it, and clear the bookkeeping so a failed
+        // allocation below cannot leave a nil buffer with a stale size.
+        [slot.buffer release];
+        slot.buffer = nil;
+        slot.size = 0;
+      }
+      // newBufferWithLength:options: takes MTLResourceOptions, so pass the
+      // resource-option constant rather than the MTLStorageMode enum value.
+      slot.buffer = [dev newBufferWithLength:nbytes options:MTLResourceStorageModeShared];
+      TVM_FFI_ICHECK(slot.buffer != nil) << "Failed to allocate staging buffer of size " << nbytes;
+      slot.size = nbytes;
+      ++next_index_;
+      return slot.buffer;
+    }
+
+    /*! \brief Make all handed-out buffers reusable (see the struct comment). */
+    void ResetIndex() { next_index_ = 0; }
+
+    /*! \brief Number of staging buffers handed out since the last reset. */
+    size_t Size() const { return next_index_; }
+
+    /*! \brief True once kMaxStagingBuffers buffers are outstanding. */
+    bool NeedsFlush() const { return next_index_ >= kMaxStagingBuffers; }
+
+   private:
+    struct Entry {
+      id<MTLBuffer> buffer = nil;
+      size_t size = 0;
+    };
+
+    /*! \brief Release every owned buffer and empty the pool. */
+    void ReleaseAll() {
+      for (Entry& e : pool_) {
+        if (e.buffer != nil) {
+          [e.buffer release];
+        }
+      }
+      pool_.clear();
+      next_index_ = 0;
+    }
+
+    std::vector<Entry> pool_;
+    size_t next_index_ = 0;  // next slot to hand out; rewound by ResetIndex()
+  };
+  std::vector<StagingBufferPool> staging_pools_;  // per device
   /*! \brief workspace pool */
   WorkspacePool pool;
   // constructor
@@ -210,13 +420,18 @@ class MetalThreadEntry {
     device.device_id = 0;
     device.device_type = static_cast<DLDeviceType>(kDLMetal);
     MetalWorkspace* global_ws = MetalWorkspace::Global();
-    // by default, set the stream to nullptr, which indicate
-    // that we are using default stream
     this->stream.resize(global_ws->devices.size(), nullptr);
+    this->staging_pools_.resize(global_ws->devices.size());
   }
   ~MetalThreadEntry();
-  // Get temp buffer with at least size under dev.
+  // Get temp buffer with at least size under dev (for GPU→CPU readback).
   id<MTLBuffer> GetTempBuffer(Device dev, size_t size);
+  // Get a staging buffer for inlined CPU→GPU copy (from pool).
+  id<MTLBuffer> GetOrCreateStagingBuffer(Device dev, size_t size);
+  // Check if the staging pool has reached its limit and needs a flush.
+  bool StagingPoolNeedsFlush(Device dev);
+  // Reset the staging pool index after a flush.
+  void ResetStagingPool(Device dev);
   // get the global workspace
   static MetalThreadEntry* ThreadLocal();
 };
diff --git a/src/runtime/metal/metal_device_api.mm 
b/src/runtime/metal/metal_device_api.mm
index 5ff9c2dfcd..f240f589c1 100644
--- a/src/runtime/metal/metal_device_api.mm
+++ b/src/runtime/metal/metal_device_api.mm
@@ -200,14 +200,18 @@ void* MetalWorkspace::AllocDataSpace(Device device, 
size_t nbytes, size_t alignm
 
 void MetalWorkspace::FreeDataSpace(Device dev, void* ptr) {
   AUTORELEASEPOOL {
-    // need to make sure buffer is not in use in command buffer
-    // before set the purgeable state to empty
-    // otherwise can cause issues sometimes
-    this->StreamSync(dev, nullptr);
-    // MTLBuffer PurgeableState should be set to empty before manual
-    // release in order to prevent memory leak
+    // Release a buffer obtained from AllocDataSpace. If the default stream
+    // still has an uncommitted batch, fully sync before touching the buffer.
+    Stream* s = CastStreamOrGetDefault(nullptr, dev.device_id);
+    if (s->HasPendingWork()) {
+      s->profile.free_syncs++;
+      // Buffer may be referenced by pending compute/blit encoders.
+      // Must fully sync, setPurgeableState:Empty on a buffer in an
+      // uncommitted or incomplete CB crashes Metal.
+      this->StreamSync(dev, nullptr);
+    }
+    // Reached with no *uncommitted* work on the default stream. Note that
+    // HasPendingWork() turns false as soon as FlushCommandBuffer() commits,
+    // and that flush does not wait, so committed work may still be running.
+    // NOTE(review): the GPU→CPU path for shared-storage buffers appears to
+    // flush and then memcpy without waiting, so a free right after such a
+    // readback can reach this line while the GPU is still executing -- the
+    // "incomplete CB" case described above. Confirm this is safe, or sync
+    // whenever work may be in flight (the previous code always called
+    // StreamSync here).
     [(id<MTLBuffer>)ptr setPurgeableState:MTLPurgeableStateEmpty];
-    // release the ptr.
     CFRelease(ptr);
   };
 }
@@ -229,25 +233,30 @@ void MetalWorkspace::CopyDataFromTo(const void* from, 
size_t from_offset, void*
     if (s->HasErrorHappened()) {
       LOG(FATAL) << "GPUError: " << s->ErrorDescription();
     }
-    id<MTLCommandBuffer> cb = 
s->GetCommandBuffer(/*label=*/"TVMCopyDataFromTo");
     int from_dev_type = static_cast<int>(dev_from.device_type);
     int to_dev_type = static_cast<int>(dev_to.device_type);
 
     if (from_dev_type == kDLMetal && to_dev_type == kDLMetal) {
+      s->profile.gpu_to_gpu++;
+      // GPU→GPU: inline blit into the pending command buffer.
+      // No flush needed, Metal guarantees encoder ordering within a CB.
       TVM_FFI_ICHECK_EQ(dev_from.device_id, dev_to.device_id)
           << "Metal disallow cross device copy.";
-      id<MTLBlitCommandEncoder> encoder = [cb blitCommandEncoder];
+      id<MTLBlitCommandEncoder> encoder = s->GetBlitEncoderOnPendingBuffer();
       [encoder copyFromBuffer:(id<MTLBuffer>)(from)
                  sourceOffset:from_offset
                      toBuffer:(id<MTLBuffer>)(to)destinationOffset:to_offset
                          size:size];
       [encoder endEncoding];
-      [cb commit];
+
     } else if (from_dev_type == kDLMetal && to_dev_type == kDLCPU) {
-      // copy to a local buffer before get into global buffer.
+      s->profile.gpu_to_cpu++;
+      // GPU→CPU: must flush and wait, we need data in CPU memory.
+      s->FlushCommandBuffer();
       id<MTLBuffer> from_buf = (id<MTLBuffer>)(from);
       if (from_buf.storageMode != MTLStorageModeShared) {
         id<MTLBuffer> temp = 
MetalThreadEntry::ThreadLocal()->GetTempBuffer(dev_from, size);
+        id<MTLCommandBuffer> cb = s->GetCommandBuffer("TVMCopyGPUtoCPU");
         id<MTLBlitCommandEncoder> encoder = [cb blitCommandEncoder];
         [encoder copyFromBuffer:from_buf
                    sourceOffset:from_offset
@@ -262,24 +271,36 @@ void MetalWorkspace::CopyDataFromTo(const void* from, 
size_t from_offset, void*
         memcpy(static_cast<char*>(to) + to_offset,
                static_cast<char*>([from_buf contents]) + from_offset, size);
       }
+
     } else if (from_dev_type == kDLCPU && to_dev_type == kDLMetal) {
+      s->profile.cpu_to_gpu++;
+      // CPU→GPU: inline blit into the pending command buffer.
+      // We use a staging buffer from the pool (not the single temp_buffer_)
+      // so multiple CPU→GPU copies can be inlined before a flush.
       id<MTLBuffer> to_buf = (id<MTLBuffer>)(to);
       if (to_buf.storageMode != MTLStorageModeShared) {
-        id<MTLBuffer> temp = 
MetalThreadEntry::ThreadLocal()->GetTempBuffer(dev_to, size);
-        memcpy([temp contents], static_cast<const char*>(from) + from_offset, 
size);
-        id<MTLBlitCommandEncoder> encoder = [cb blitCommandEncoder];
-        [encoder copyFromBuffer:temp
+        MetalThreadEntry* t = MetalThreadEntry::ThreadLocal();
+        // If the staging pool is full, flush pending work so buffers can be 
reused.
+        if (t->StagingPoolNeedsFlush(dev_to)) {
+          s->FlushCommandBuffer();
+          t->ResetStagingPool(dev_to);
+        }
+        id<MTLBuffer> staging = t->GetOrCreateStagingBuffer(dev_to, size);
+        memcpy([staging contents], static_cast<const char*>(from) + 
from_offset, size);
+        id<MTLBlitCommandEncoder> encoder = s->GetBlitEncoderOnPendingBuffer();
+        [encoder copyFromBuffer:staging
                    sourceOffset:0
                        toBuffer:to_buf
               destinationOffset:to_offset
                            size:size];
         [encoder endEncoding];
-        [cb commit];
-        [cb waitUntilCompleted];
+        // No flush, no wait. Metal executes encoders in order within the CB.
+        // The staging buffer stays alive until flush, when the pool resets.
       } else {
         memcpy(static_cast<char*>([to_buf contents]) + to_offset,
                static_cast<const char*>(from) + from_offset, size);
       }
+
     } else {
       LOG(FATAL) << "Expect copy from/to Metal or between Metal"
                  << ", from=" << from_dev_type << ", to=" << to_dev_type;
@@ -302,10 +323,9 @@ void MetalWorkspace::FreeStream(Device dev, 
TVMStreamHandle stream) {
 void MetalWorkspace::StreamSync(Device dev, TVMStreamHandle stream) {
   AUTORELEASEPOOL {
     Stream* s = CastStreamOrGetDefault(stream, dev.device_id);
-    // commit an empty command buffer and wait until it completes.
-    id<MTLCommandBuffer> cb = s->GetCommandBuffer(/*label=*/"TVMStreamSync");
-    [cb commit];
-    [cb waitUntilCompleted];
+    s->Synchronize();
+    // After sync, all staging buffers are safe to reuse.
+    
MetalThreadEntry::ThreadLocal()->staging_pools_[dev.device_id].ResetIndex();
     if (s->HasErrorHappened()) {
       LOG(FATAL) << "GPUError: " << s->ErrorDescription();
     }
@@ -336,10 +356,17 @@ id<MTLBuffer> MetalThreadEntry::GetTempBuffer(Device dev, 
size_t size) {
   if (temp_buffer_[dev.device_id] == nil || temp_buffer_[dev.device_id].length 
< size) {
     id<MTLDevice> mtl_dev = MetalWorkspace::Global()->GetDevice(dev);
     if (temp_buffer_[dev.device_id] != nil) {
-      // need to make sure buffer is not in use in command buffer
-      // before set the purgeable state to empty
-      // otherwise can cause issues sometimes
-      MetalWorkspace::Global()->StreamSync(dev, nullptr);
+      // The caller (GPU→CPU path in CopyDataFromTo) already called
+      // FlushCommandBuffer() before calling us, so all pending work
+      // using this buffer has been committed. We just need to wait
+      // for completion before releasing.
+      auto* ws = MetalWorkspace::Global();
+      Stream* s = ws->CastStreamOrGetDefault(nullptr, dev.device_id);
+      if (s->HasPendingWork()) {
+        // Only sync if there's actually pending work (shouldn't happen
+        // since caller flushed, but be safe).
+        ws->StreamSync(dev, nullptr);
+      }
       [temp_buffer_[dev.device_id] setPurgeableState:MTLPurgeableStateEmpty];
       [temp_buffer_[dev.device_id] release];
     }
@@ -348,6 +375,17 @@ id<MTLBuffer> MetalThreadEntry::GetTempBuffer(Device dev, 
size_t size) {
   return temp_buffer_[dev.device_id];
 }
 
+// Hand out the next staging buffer (at least `size` bytes) from the pool
+// of `dev`, for a CPU→GPU copy inlined into the pending command buffer.
+id<MTLBuffer> MetalThreadEntry::GetOrCreateStagingBuffer(Device dev, size_t size) {
+  id<MTLDevice> device = MetalWorkspace::Global()->GetDevice(dev);
+  StagingBufferPool& pool = staging_pools_[dev.device_id];
+  return pool.GetOrCreate(device, size);
+}
+
+// True once the pool of `dev` has handed out its maximum number of buffers.
+bool MetalThreadEntry::StagingPoolNeedsFlush(Device dev) {
+  const StagingBufferPool& pool = staging_pools_[dev.device_id];
+  return pool.NeedsFlush();
+}
+
+// Make the staging buffers of `dev` reusable.
+// The pool-full path of CopyDataFromTo calls this right after
+// FlushCommandBuffer(), which commits but does not wait. Rewinding the pool
+// at that point would let the next memcpy overwrite a staging buffer the GPU
+// has not copied yet. So wait for the device's default stream to drain first
+// (Synchronize() flushes, which is a no-op after the caller's flush, then
+// waits).
+// NOTE(review): the pool is per device, not per stream; if CopyDataFromTo
+// encodes into a non-default stream, that stream is not waited on here.
+void MetalThreadEntry::ResetStagingPool(Device dev) {
+  auto* ws = MetalWorkspace::Global();
+  Stream* s = ws->CastStreamOrGetDefault(nullptr, dev.device_id);
+  s->Synchronize();
+  staging_pools_[dev.device_id].ResetIndex();
+}
+
 MetalThreadEntry* MetalThreadEntry::ThreadLocal() {
   static thread_local MetalThreadEntry inst;
   return &inst;
@@ -362,7 +400,27 @@ TVM_FFI_STATIC_INIT_BLOCK() {
                     *rv = static_cast<void*>(ptr);
                   })
       .def("metal.ResetGlobalState",
-           []() { MetalWorkspace::Global()->ReinitializeDefaultStreams(); });
+           []() { MetalWorkspace::Global()->ReinitializeDefaultStreams(); })
+      .def("metal.GetProfileCounters",
+           [](int device_id) {
+             auto* ws = MetalWorkspace::Global();
+             Stream* s = ws->CastStreamOrGetDefault(nullptr, device_id);
+             const auto& p = s->profile;
+             ffi::Map<ffi::String, int64_t> result;
+             result.Set("dispatches", static_cast<int64_t>(p.dispatches));
+             result.Set("flushes", static_cast<int64_t>(p.flushes));
+             result.Set("syncs", static_cast<int64_t>(p.syncs));
+             result.Set("blits", static_cast<int64_t>(p.blits));
+             result.Set("gpu_to_cpu", static_cast<int64_t>(p.gpu_to_cpu));
+             result.Set("cpu_to_gpu", static_cast<int64_t>(p.cpu_to_gpu));
+             result.Set("gpu_to_gpu", static_cast<int64_t>(p.gpu_to_gpu));
+             result.Set("free_syncs", static_cast<int64_t>(p.free_syncs));
+             return result;
+           })
+      .def("metal.ResetProfileCounters", [](int device_id) {
+        auto* ws = MetalWorkspace::Global();
+        ws->CastStreamOrGetDefault(nullptr, device_id)->profile.Reset();
+      });
 }
 
 class MetalTimerNode : public TimerNode {
diff --git a/src/runtime/metal/metal_module.mm 
b/src/runtime/metal/metal_module.mm
index 0066b651fc..6837404ad3 100644
--- a/src/runtime/metal/metal_module.mm
+++ b/src/runtime/metal/metal_module.mm
@@ -213,10 +213,9 @@ class MetalWrappedFunc {
       int blockSize = wl.block_dim(0) * wl.block_dim(1) * wl.block_dim(2);
       auto maxTotalThreadsPerThreadgroup = 
scache_[device_id].maxTotalThreadsPerThreadgroup;
       TVM_FFI_ICHECK_LE(blockSize, maxTotalThreadsPerThreadgroup);
-      // attach error message directly in this functio
-      id<MTLCommandBuffer> cb = 
stream->GetCommandBuffer(/*label=*/"TVMKernel:" + func_name_,
-                                                         
/*attach_error_callback=*/false);
-      id<MTLComputeCommandEncoder> encoder = [cb computeCommandEncoder];
+      // Reuse the pending compute encoder to batch dispatches.
+      // The encoder is flushed on sync, copy, or buffer deallocation.
+      id<MTLComputeCommandEncoder> encoder = 
stream->GetPendingComputeEncoder(func_name_);
       [encoder setComputePipelineState:scache_[device_id]];
       for (size_t i = 0; i < num_buffer_args_; ++i) {
         void* buf = args[static_cast<int>(i)].cast<void*>();
@@ -231,18 +230,6 @@ class MetalWrappedFunc {
       MTLSize dimGrid = MTLSizeMake(wl.grid_dim(0), wl.grid_dim(1), 
wl.grid_dim(2));
       MTLSize dimBlock = MTLSizeMake(wl.block_dim(0), wl.block_dim(1), 
wl.block_dim(2));
       [encoder dispatchThreadgroups:dimGrid threadsPerThreadgroup:dimBlock];
-      [encoder endEncoding];
-      // attach error message with function name
-      [cb addCompletedHandler:^(id<MTLCommandBuffer> buffer) {
-        if (buffer.status == MTLCommandBufferStatusError) {
-          TVM_FFI_ICHECK(buffer.error != nil);
-          std::ostringstream os;
-          os << "GPUError happens after running " << func_name_ << ": "
-             << buffer.error.localizedDescription.UTF8String;
-          stream->SetError(os.str());
-        }
-      }];
-      [cb commit];
     };
   }
 

Reply via email to