This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 9dfa116c94 [Metal] Batched command dispatch and staging buffer pool (#18877)
9dfa116c94 is described below
commit 9dfa116c942180e90092505d679b422ff3073410
Author: Miti <[email protected]>
AuthorDate: Sat Mar 7 17:25:54 2026 +0100
[Metal] Batched command dispatch and staging buffer pool (#18877)
---
src/runtime/metal/metal_common.h | 229 ++++++++++++++++++++++++++++++++--
src/runtime/metal/metal_device_api.mm | 110 ++++++++++++----
src/runtime/metal/metal_module.mm | 19 +--
3 files changed, 309 insertions(+), 49 deletions(-)
diff --git a/src/runtime/metal/metal_common.h b/src/runtime/metal/metal_common.h
index 8d72fac97a..cc538f84dc 100644
--- a/src/runtime/metal/metal_common.h
+++ b/src/runtime/metal/metal_common.h
@@ -103,13 +103,37 @@ class AutoReleasePoolWrapper {
};
/*!
- * \brief Structure for error handling in queues
+ * \brief Metal command stream with batched dispatch support.
+ *
+ * Compute dispatches are batched into a single command buffer via
+ * GetPendingComputeEncoder(). Blit operations (copies) are interleaved
+ * on the same command buffer via GetBlitEncoderOnPendingBuffer().
+ * The command buffer is committed when FlushCommandBuffer() is called.
+ *
+ * Must call FlushCommandBuffer() before:
+ * - GPU→CPU readback (need data in CPU memory)
+ * - Buffer deallocation (FreeDataSpace, setPurgeableState:Empty on
+ * a buffer referenced by an uncommitted CB crashes Metal)
+ * - Stream sync (StreamSync / Synchronize)
*/
class Stream {
public:
explicit Stream(id<MTLDevice> device) { queue_ = [device newCommandQueue]; }
- ~Stream() { [queue_ release]; }
- id<MTLCommandBuffer> GetCommandBuffer(std::string label = "", bool attach_error_callback = true) {
+ // Stream is only destroyed during MetalWorkspace teardown (process exit
+ // or ReinitializeDefaultStreams), so no GPU work is in flight. We flush
+ // to commit any pending CB but do not wait for completion.
+ ~Stream() {
+ FlushCommandBuffer();
+ [queue_ release];
+ }
+
+ /*!
+ * \brief Get a standalone command buffer (for GPU→CPU readback only).
+ *
+ * Used when we need a separate command buffer that we can commit
+ * and waitUntilCompleted on independently.
+ */
+ id<MTLCommandBuffer> GetCommandBuffer(std::string label = "") {
id<MTLCommandBuffer> cb = [queue_ commandBuffer];
if (!label.empty()) {
cb.label = [NSString stringWithUTF8String:label.c_str()];
@@ -123,6 +147,99 @@ class Stream {
return cb;
}
+ /*!
+ * \brief Get the pending compute command encoder, creating one if needed.
+ *
+ * Multiple compute dispatches are batched into a single command buffer
+ * and encoder. Blit operations (copies) can be interleaved on the same
+ * command buffer via GetBlitEncoderOnPendingBuffer(). The entire command
+ * buffer is committed when FlushCommandBuffer() is called.
+ *
+ * Must flush before:
+ * - GPU→CPU readback (need data on CPU immediately)
+ * - Buffer deallocation (FreeDataSpace)
+ * - Stream sync (StreamSync)
+ */
+ id<MTLComputeCommandEncoder> GetPendingComputeEncoder(const std::string& kernel_name = "") {
+ if (pending_compute_encoder_ == nil) {
+ id<MTLCommandBuffer> cb = GetOrCreatePendingCommandBuffer();
+ pending_compute_encoder_ = [[cb computeCommandEncoder] retain];
+ }
+ if (!kernel_name.empty()) {
+ last_dispatched_kernel_ = kernel_name;
+ }
+ profile.dispatches++;
+ return pending_compute_encoder_;
+ }
+
+ /*!
+ * \brief Get a blit encoder on the pending command buffer.
+ *
+ * Pauses the active compute encoder (if any), creates a blit encoder
+ * on the same command buffer. Caller must call [encoder endEncoding]
+ * when done. The next GetPendingComputeEncoder() call will create a
+ * new compute encoder on the same command buffer.
+ *
+ * Metal guarantees sequential ordering of encoders within a command
+ * buffer, so blits encoded here execute after prior compute dispatches
+ * and before subsequent ones.
+ */
+ id<MTLBlitCommandEncoder> GetBlitEncoderOnPendingBuffer() {
+ EndPendingComputeEncoder();
+ id<MTLCommandBuffer> cb = GetOrCreatePendingCommandBuffer();
+ profile.blits++;
+ return [cb blitCommandEncoder];
+ }
+
+ /*!
+ * \brief Flush: end active encoder, commit the command buffer.
+ *
+ * Safe to call when nothing is pending (no-op).
+ */
+ void FlushCommandBuffer() {
+ EndPendingComputeEncoder();
+ if (pending_command_buffer_ != nil) {
+ [pending_command_buffer_ commit];
+ [pending_command_buffer_ release];
+ pending_command_buffer_ = nil;
+ profile.flushes++;
+ }
+ }
+
+ /*!
+ * \brief Flush pending work, then wait for all submitted work to complete.
+ */
+ void Synchronize() {
+ FlushCommandBuffer();
+ id<MTLCommandBuffer> cb = [queue_ commandBuffer];
+ [cb addCompletedHandler:^(id<MTLCommandBuffer> buffer) {
+ if (buffer.status == MTLCommandBufferStatusError) {
+ TVM_FFI_ICHECK(buffer.error != nil);
+ this->SetError(buffer.error.localizedDescription.UTF8String);
+ }
+ }];
+ [cb commit];
+ [cb waitUntilCompleted];
+ profile.syncs++;
+ }
+
+ bool HasPendingWork() const { return pending_command_buffer_ != nil; }
+
+ /*! \brief Profiling counters for diagnosing dispatch/copy/sync overhead. */
+ struct ProfileCounters {
+ size_t dispatches = 0;
+ size_t flushes = 0;
+ size_t syncs = 0;
+ size_t blits = 0;
+ size_t gpu_to_cpu = 0;
+ size_t cpu_to_gpu = 0;
+ size_t gpu_to_gpu = 0;
+ size_t free_syncs = 0; // FreeDataSpace calls that triggered a sync
+
+ void Reset() { *this = ProfileCounters(); }
+ };
+ ProfileCounters profile;
+
void SetError(std::string error_description) {
error_happened_ = true;
error_description_ = std::move(error_description);
@@ -133,8 +250,42 @@ class Stream {
const std::string& ErrorDescription() const { return error_description_; }
private:
+ /*! \brief Get or create the pending command buffer (shared by compute and blit). */
+ id<MTLCommandBuffer> GetOrCreatePendingCommandBuffer() {
+ if (pending_command_buffer_ == nil) {
+ pending_command_buffer_ = [[queue_ commandBuffer] retain];
+ pending_command_buffer_.label = @"TVMBatched";
+ [pending_command_buffer_ addCompletedHandler:^(id<MTLCommandBuffer> buffer) {
+ if (buffer.status == MTLCommandBufferStatusError) {
+ TVM_FFI_ICHECK(buffer.error != nil);
+ std::string msg = buffer.error.localizedDescription.UTF8String;
+ if (!this->last_dispatched_kernel_.empty()) {
+ msg = "GPUError after kernel " + this->last_dispatched_kernel_ + ": " + msg;
+ }
+ this->SetError(msg);
+ }
+ }];
+ }
+ return pending_command_buffer_;
+ }
+
+ /*! \brief End the active compute encoder without committing the command buffer. */
+ void EndPendingComputeEncoder() {
+ if (pending_compute_encoder_ != nil) {
+ [pending_compute_encoder_ endEncoding];
+ [pending_compute_encoder_ release];
+ pending_compute_encoder_ = nil;
+ }
+ }
+
// Queue
id<MTLCommandQueue> queue_;
+ // Pending command buffer (shared by compute and blit encoders)
+ id<MTLCommandBuffer> pending_command_buffer_ = nil;
+ // Active compute encoder on the pending command buffer (nil when paused/blit)
+ id<MTLComputeCommandEncoder> pending_compute_encoder_ = nil;
+ // Last dispatched kernel name (for error diagnostics)
+ std::string last_dispatched_kernel_;
// Check if error happened in one previous run
bool error_happened_{false};
// error description
@@ -201,8 +352,67 @@ class MetalThreadEntry {
Device device;
/*! \brief The current stream */
std::vector<TVMStreamHandle> stream;
- /*! \brief The shared buffer used for copy. */
+ /*! \brief The shared buffer used for GPU→CPU readback. */
std::vector<id<MTLBuffer>> temp_buffer_;
+ /*!
+ * \brief Pool of staging buffers for CPU→GPU copies that are inlined
+ * into the pending command buffer. Each inlined copy needs its own
+ * staging buffer because the GPU reads them asynchronously.
+ * Buffers are recycled after FlushCommandBuffer()/Synchronize().
+ */
+ struct StagingBufferPool {
+ public:
+ /*! \brief Maximum staging buffers before requiring a flush.
+ * Prevents unbounded pool growth in workloads with many CPU→GPU copies
+ * between syncs. When this limit is reached, the caller must flush the
+ * stream (to make all pending staging buffers safe to reuse) before
+ * requesting more buffers. */
+ static constexpr size_t kMaxStagingBuffers = 64;
+
+ id<MTLBuffer> GetOrCreate(id<MTLDevice> dev, size_t nbytes) {
+ if (next_index_ < pool_.size() && pool_[next_index_].size >= nbytes) {
+ return pool_[next_index_++].buffer;
+ }
+ // Need a new or bigger buffer at this index
+ if (next_index_ < pool_.size() && pool_[next_index_].buffer != nil) {
+ [pool_[next_index_].buffer release];
+ }
+ if (next_index_ >= pool_.size()) {
+ pool_.push_back({nil, 0});
+ }
+ pool_[next_index_].buffer = [dev newBufferWithLength:nbytes options:MTLStorageModeShared];
+ TVM_FFI_ICHECK(pool_[next_index_].buffer != nil)
+ << "Failed to allocate staging buffer of size " << nbytes;
+ pool_[next_index_].size = nbytes;
+ return pool_[next_index_++].buffer;
+ }
+
+ // Called after flush/sync, all staging buffers are safe to reuse
+ void ResetIndex() { next_index_ = 0; }
+
+ // Number of staging buffers used in the current batch
+ size_t Size() const { return next_index_; }
+
+ // True when the pool has reached its limit and needs a flush before more allocations
+ bool NeedsFlush() const { return next_index_ >= kMaxStagingBuffers; }
+
+ ~StagingBufferPool() {
+ for (auto& e : pool_) {
+ if (e.buffer != nil) {
+ [e.buffer release];
+ }
+ }
+ }
+
+ private:
+ struct Entry {
+ id<MTLBuffer> buffer = nil;
+ size_t size = 0;
+ };
+ std::vector<Entry> pool_;
+ size_t next_index_ = 0; // sequential within current batch, reset on sync
+ };
+ std::vector<StagingBufferPool> staging_pools_; // per device
/*! \brief workspace pool */
WorkspacePool pool;
// constructor
@@ -210,13 +420,18 @@ class MetalThreadEntry {
device.device_id = 0;
device.device_type = static_cast<DLDeviceType>(kDLMetal);
MetalWorkspace* global_ws = MetalWorkspace::Global();
- // by default, set the stream to nullptr, which indicate
- // that we are using default stream
this->stream.resize(global_ws->devices.size(), nullptr);
+ this->staging_pools_.resize(global_ws->devices.size());
}
~MetalThreadEntry();
- // Get temp buffer with at least size under dev.
+ // Get temp buffer with at least size under dev (for GPU→CPU readback).
id<MTLBuffer> GetTempBuffer(Device dev, size_t size);
+ // Get a staging buffer for inlined CPU→GPU copy (from pool).
+ id<MTLBuffer> GetOrCreateStagingBuffer(Device dev, size_t size);
+ // Check if the staging pool has reached its limit and needs a flush.
+ bool StagingPoolNeedsFlush(Device dev);
+ // Reset the staging pool index after a flush.
+ void ResetStagingPool(Device dev);
// get the global workspace
static MetalThreadEntry* ThreadLocal();
};
diff --git a/src/runtime/metal/metal_device_api.mm b/src/runtime/metal/metal_device_api.mm
index 5ff9c2dfcd..f240f589c1 100644
--- a/src/runtime/metal/metal_device_api.mm
+++ b/src/runtime/metal/metal_device_api.mm
@@ -200,14 +200,18 @@ void* MetalWorkspace::AllocDataSpace(Device device, size_t nbytes, size_t alignm
void MetalWorkspace::FreeDataSpace(Device dev, void* ptr) {
AUTORELEASEPOOL {
- // need to make sure buffer is not in use in command buffer
- // before set the purgeable state to empty
- // otherwise can cause issues sometimes
- this->StreamSync(dev, nullptr);
- // MTLBuffer PurgeableState should be set to empty before manual
- // release in order to prevent memory leak
+ Stream* s = CastStreamOrGetDefault(nullptr, dev.device_id);
+ if (s->HasPendingWork()) {
+ s->profile.free_syncs++;
+ // Buffer may be referenced by pending compute/blit encoders.
+ // Must fully sync, setPurgeableState:Empty on a buffer in an
+ // uncommitted or incomplete CB crashes Metal.
+ this->StreamSync(dev, nullptr);
+ }
+ // No pending work, safe to release immediately.
+ // Either nothing was dispatched since last sync, or the GPU→CPU
+ // readback path already flushed+waited.
[(id<MTLBuffer>)ptr setPurgeableState:MTLPurgeableStateEmpty];
- // release the ptr.
CFRelease(ptr);
};
}
@@ -229,25 +233,30 @@ void MetalWorkspace::CopyDataFromTo(const void* from, size_t from_offset, void*
if (s->HasErrorHappened()) {
LOG(FATAL) << "GPUError: " << s->ErrorDescription();
}
- id<MTLCommandBuffer> cb = s->GetCommandBuffer(/*label=*/"TVMCopyDataFromTo");
int from_dev_type = static_cast<int>(dev_from.device_type);
int to_dev_type = static_cast<int>(dev_to.device_type);
if (from_dev_type == kDLMetal && to_dev_type == kDLMetal) {
+ s->profile.gpu_to_gpu++;
+ // GPU→GPU: inline blit into the pending command buffer.
+ // No flush needed, Metal guarantees encoder ordering within a CB.
TVM_FFI_ICHECK_EQ(dev_from.device_id, dev_to.device_id)
<< "Metal disallow cross device copy.";
- id<MTLBlitCommandEncoder> encoder = [cb blitCommandEncoder];
+ id<MTLBlitCommandEncoder> encoder = s->GetBlitEncoderOnPendingBuffer();
[encoder copyFromBuffer:(id<MTLBuffer>)(from)
sourceOffset:from_offset
toBuffer:(id<MTLBuffer>)(to)destinationOffset:to_offset
size:size];
[encoder endEncoding];
- [cb commit];
+
} else if (from_dev_type == kDLMetal && to_dev_type == kDLCPU) {
- // copy to a local buffer before get into global buffer.
+ s->profile.gpu_to_cpu++;
+ // GPU→CPU: must flush and wait, we need data in CPU memory.
+ s->FlushCommandBuffer();
id<MTLBuffer> from_buf = (id<MTLBuffer>)(from);
if (from_buf.storageMode != MTLStorageModeShared) {
id<MTLBuffer> temp = MetalThreadEntry::ThreadLocal()->GetTempBuffer(dev_from, size);
+ id<MTLCommandBuffer> cb = s->GetCommandBuffer("TVMCopyGPUtoCPU");
id<MTLBlitCommandEncoder> encoder = [cb blitCommandEncoder];
[encoder copyFromBuffer:from_buf
sourceOffset:from_offset
@@ -262,24 +271,36 @@ void MetalWorkspace::CopyDataFromTo(const void* from, size_t from_offset, void*
memcpy(static_cast<char*>(to) + to_offset,
static_cast<char*>([from_buf contents]) + from_offset, size);
}
+
} else if (from_dev_type == kDLCPU && to_dev_type == kDLMetal) {
+ s->profile.cpu_to_gpu++;
+ // CPU→GPU: inline blit into the pending command buffer.
+ // We use a staging buffer from the pool (not the single temp_buffer_)
+ // so multiple CPU→GPU copies can be inlined before a flush.
id<MTLBuffer> to_buf = (id<MTLBuffer>)(to);
if (to_buf.storageMode != MTLStorageModeShared) {
- id<MTLBuffer> temp = MetalThreadEntry::ThreadLocal()->GetTempBuffer(dev_to, size);
- memcpy([temp contents], static_cast<const char*>(from) + from_offset, size);
- id<MTLBlitCommandEncoder> encoder = [cb blitCommandEncoder];
- [encoder copyFromBuffer:temp
+ MetalThreadEntry* t = MetalThreadEntry::ThreadLocal();
+ // If the staging pool is full, flush pending work so buffers can be reused.
+ if (t->StagingPoolNeedsFlush(dev_to)) {
+ s->FlushCommandBuffer();
+ t->ResetStagingPool(dev_to);
+ }
+ id<MTLBuffer> staging = t->GetOrCreateStagingBuffer(dev_to, size);
+ memcpy([staging contents], static_cast<const char*>(from) + from_offset, size);
+ id<MTLBlitCommandEncoder> encoder = s->GetBlitEncoderOnPendingBuffer();
+ [encoder copyFromBuffer:staging
sourceOffset:0
toBuffer:to_buf
destinationOffset:to_offset
size:size];
[encoder endEncoding];
- [cb commit];
- [cb waitUntilCompleted];
+ // No flush, no wait. Metal executes encoders in order within the CB.
+ // The staging buffer stays alive until flush, when the pool resets.
} else {
memcpy(static_cast<char*>([to_buf contents]) + to_offset,
static_cast<const char*>(from) + from_offset, size);
}
+
} else {
LOG(FATAL) << "Expect copy from/to Metal or between Metal"
<< ", from=" << from_dev_type << ", to=" << to_dev_type;
@@ -302,10 +323,9 @@ void MetalWorkspace::FreeStream(Device dev, TVMStreamHandle stream) {
void MetalWorkspace::StreamSync(Device dev, TVMStreamHandle stream) {
AUTORELEASEPOOL {
Stream* s = CastStreamOrGetDefault(stream, dev.device_id);
- // commit an empty command buffer and wait until it completes.
- id<MTLCommandBuffer> cb = s->GetCommandBuffer(/*label=*/"TVMStreamSync");
- [cb commit];
- [cb waitUntilCompleted];
+ s->Synchronize();
+ // After sync, all staging buffers are safe to reuse.
+ MetalThreadEntry::ThreadLocal()->staging_pools_[dev.device_id].ResetIndex();
if (s->HasErrorHappened()) {
LOG(FATAL) << "GPUError: " << s->ErrorDescription();
}
@@ -336,10 +356,17 @@ id<MTLBuffer> MetalThreadEntry::GetTempBuffer(Device dev, size_t size) {
if (temp_buffer_[dev.device_id] == nil || temp_buffer_[dev.device_id].length < size) {
id<MTLDevice> mtl_dev = MetalWorkspace::Global()->GetDevice(dev);
if (temp_buffer_[dev.device_id] != nil) {
- // need to make sure buffer is not in use in command buffer
- // before set the purgeable state to empty
- // otherwise can cause issues sometimes
- MetalWorkspace::Global()->StreamSync(dev, nullptr);
+ // The caller (GPU→CPU path in CopyDataFromTo) already called
+ // FlushCommandBuffer() before calling us, so all pending work
+ // using this buffer has been committed. We just need to wait
+ // for completion before releasing.
+ auto* ws = MetalWorkspace::Global();
+ Stream* s = ws->CastStreamOrGetDefault(nullptr, dev.device_id);
+ if (s->HasPendingWork()) {
+ // Only sync if there's actually pending work (shouldn't happen
+ // since caller flushed, but be safe).
+ ws->StreamSync(dev, nullptr);
+ }
[temp_buffer_[dev.device_id] setPurgeableState:MTLPurgeableStateEmpty];
[temp_buffer_[dev.device_id] release];
}
@@ -348,6 +375,17 @@ id<MTLBuffer> MetalThreadEntry::GetTempBuffer(Device dev, size_t size) {
return temp_buffer_[dev.device_id];
}
+id<MTLBuffer> MetalThreadEntry::GetOrCreateStagingBuffer(Device dev, size_t size) {
+ id<MTLDevice> mtl_dev = MetalWorkspace::Global()->GetDevice(dev);
+ return staging_pools_[dev.device_id].GetOrCreate(mtl_dev, size);
+}
+
+bool MetalThreadEntry::StagingPoolNeedsFlush(Device dev) {
+ return staging_pools_[dev.device_id].NeedsFlush();
+}
+
+void MetalThreadEntry::ResetStagingPool(Device dev) { staging_pools_[dev.device_id].ResetIndex(); }
+
MetalThreadEntry* MetalThreadEntry::ThreadLocal() {
static thread_local MetalThreadEntry inst;
return &inst;
@@ -362,7 +400,27 @@ TVM_FFI_STATIC_INIT_BLOCK() {
*rv = static_cast<void*>(ptr);
})
.def("metal.ResetGlobalState",
- []() { MetalWorkspace::Global()->ReinitializeDefaultStreams(); });
+ []() { MetalWorkspace::Global()->ReinitializeDefaultStreams(); })
+ .def("metal.GetProfileCounters",
+ [](int device_id) {
+ auto* ws = MetalWorkspace::Global();
+ Stream* s = ws->CastStreamOrGetDefault(nullptr, device_id);
+ const auto& p = s->profile;
+ ffi::Map<ffi::String, int64_t> result;
+ result.Set("dispatches", static_cast<int64_t>(p.dispatches));
+ result.Set("flushes", static_cast<int64_t>(p.flushes));
+ result.Set("syncs", static_cast<int64_t>(p.syncs));
+ result.Set("blits", static_cast<int64_t>(p.blits));
+ result.Set("gpu_to_cpu", static_cast<int64_t>(p.gpu_to_cpu));
+ result.Set("cpu_to_gpu", static_cast<int64_t>(p.cpu_to_gpu));
+ result.Set("gpu_to_gpu", static_cast<int64_t>(p.gpu_to_gpu));
+ result.Set("free_syncs", static_cast<int64_t>(p.free_syncs));
+ return result;
+ })
+ .def("metal.ResetProfileCounters", [](int device_id) {
+ auto* ws = MetalWorkspace::Global();
+ ws->CastStreamOrGetDefault(nullptr, device_id)->profile.Reset();
+ });
}
class MetalTimerNode : public TimerNode {
diff --git a/src/runtime/metal/metal_module.mm b/src/runtime/metal/metal_module.mm
index 0066b651fc..6837404ad3 100644
--- a/src/runtime/metal/metal_module.mm
+++ b/src/runtime/metal/metal_module.mm
@@ -213,10 +213,9 @@ class MetalWrappedFunc {
int blockSize = wl.block_dim(0) * wl.block_dim(1) * wl.block_dim(2);
auto maxTotalThreadsPerThreadgroup = scache_[device_id].maxTotalThreadsPerThreadgroup;
TVM_FFI_ICHECK_LE(blockSize, maxTotalThreadsPerThreadgroup);
- // attach error message directly in this functio
- id<MTLCommandBuffer> cb = stream->GetCommandBuffer(/*label=*/"TVMKernel:" + func_name_,
- /*attach_error_callback=*/false);
- id<MTLComputeCommandEncoder> encoder = [cb computeCommandEncoder];
+ // Reuse the pending compute encoder to batch dispatches.
+ // The encoder is flushed on sync, copy, or buffer deallocation.
+ id<MTLComputeCommandEncoder> encoder = stream->GetPendingComputeEncoder(func_name_);
[encoder setComputePipelineState:scache_[device_id]];
for (size_t i = 0; i < num_buffer_args_; ++i) {
void* buf = args[static_cast<int>(i)].cast<void*>();
@@ -231,18 +230,6 @@ class MetalWrappedFunc {
MTLSize dimGrid = MTLSizeMake(wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2));
MTLSize dimBlock = MTLSizeMake(wl.block_dim(0), wl.block_dim(1), wl.block_dim(2));
[encoder dispatchThreadgroups:dimGrid threadsPerThreadgroup:dimBlock];
- [encoder endEncoding];
- // attach error message with function name
- [cb addCompletedHandler:^(id<MTLCommandBuffer> buffer) {
- if (buffer.status == MTLCommandBufferStatusError) {
- TVM_FFI_ICHECK(buffer.error != nil);
- std::ostringstream os;
- os << "GPUError happens after running " << func_name_ << ": "
- << buffer.error.localizedDescription.UTF8String;
- stream->SetError(os.str());
- }
- }];
- [cb commit];
};
}