This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 2cbbcd5e2c [Refactor][Metal] Update ICHECK to TVM_FFI_ICHECK in Metal
runtime (#18811)
2cbbcd5e2c is described below
commit 2cbbcd5e2ce774d338744fa84a5719abc6125bc2
Author: Bryan <[email protected]>
AuthorDate: Mon Feb 23 07:29:37 2026 -0500
[Refactor][Metal] Update ICHECK to TVM_FFI_ICHECK in Metal runtime (#18811)
This commit updates all ICHECK macros to TVM_FFI_ICHECK in the Metal
runtime implementation to align with the new FFI refactoring.
The changes include updating ICHECK, ICHECK_LT, ICHECK_EQ macros to
their TVM_FFI_ICHECK equivalents across the following files:
- src/runtime/metal/metal_device_api.mm
- src/runtime/metal/metal_module.mm
This refactoring ensures consistency with the new TVM FFI interface and
maintains the same error checking behavior while using the updated macro
names.
---
src/runtime/metal/metal_device_api.mm | 23 ++++++++++++-----------
src/runtime/metal/metal_module.mm | 20 ++++++++++----------
2 files changed, 22 insertions(+), 21 deletions(-)
diff --git a/src/runtime/metal/metal_device_api.mm
b/src/runtime/metal/metal_device_api.mm
index c0218a5bf2..5ff9c2dfcd 100644
--- a/src/runtime/metal/metal_device_api.mm
+++ b/src/runtime/metal/metal_device_api.mm
@@ -48,7 +48,7 @@ void MetalWorkspace::GetAttr(Device dev, DeviceAttrKind kind,
ffi::Any* rv) {
*rv = int(index < devices.size());
return;
}
- ICHECK_LT(index, devices.size()) << "Invalid device id " << index;
+ TVM_FFI_ICHECK_LT(index, devices.size()) << "Invalid device id " << index;
switch (kind) {
case kMaxThreadsPerBlock: {
*rv = static_cast<int>([devices[dev.device_id]
maxThreadsPerThreadgroup].width);
@@ -125,11 +125,11 @@ int GetWarpSize(id<MTLDevice> dev) {
id<MTLLibrary> lib = [dev newLibraryWithSource:[NSString
stringWithUTF8String:kDummyKernel]
options:nil
error:&error_msg];
- ICHECK(lib != nil) << [[error_msg localizedDescription] UTF8String];
+ TVM_FFI_ICHECK(lib != nil) << [[error_msg localizedDescription] UTF8String];
id<MTLFunction> f = [lib newFunctionWithName:[NSString
stringWithUTF8String:"CopyKernel"]];
- ICHECK(f != nil);
+ TVM_FFI_ICHECK(f != nil);
id<MTLComputePipelineState> state = [dev
newComputePipelineStateWithFunction:f error:&error_msg];
- ICHECK(state != nil) << [[error_msg localizedDescription] UTF8String];
+ TVM_FFI_ICHECK(state != nil) << [[error_msg localizedDescription]
UTF8String];
int size = static_cast<int>(state.threadExecutionWidth);
[state release];
[f release];
@@ -193,7 +193,7 @@ void* MetalWorkspace::AllocDataSpace(Device device, size_t
nbytes, size_t alignm
#endif
*/
buf = [dev newBufferWithLength:nbytes options:storage_mode];
- ICHECK(buf != nil);
+ TVM_FFI_ICHECK(buf != nil);
};
return (void*)(buf);
}
@@ -214,8 +214,8 @@ void MetalWorkspace::FreeDataSpace(Device dev, void* ptr) {
Stream* MetalWorkspace::CastStreamOrGetDefault(TVMStreamHandle stream, int
device_id) {
if (stream != nullptr) return static_cast<Stream*>(stream);
- ICHECK_LT(static_cast<size_t>(device_id), default_streams_.size());
- ICHECK(default_streams_[device_id] != nullptr);
+ TVM_FFI_ICHECK_LT(static_cast<size_t>(device_id), default_streams_.size());
+ TVM_FFI_ICHECK(default_streams_[device_id] != nullptr);
return default_streams_[device_id];
}
@@ -234,7 +234,8 @@ void MetalWorkspace::CopyDataFromTo(const void* from,
size_t from_offset, void*
int to_dev_type = static_cast<int>(dev_to.device_type);
if (from_dev_type == kDLMetal && to_dev_type == kDLMetal) {
- ICHECK_EQ(dev_from.device_id, dev_to.device_id) << "Metal disallow cross
device copy.";
+ TVM_FFI_ICHECK_EQ(dev_from.device_id, dev_to.device_id)
+ << "Metal disallow cross device copy.";
id<MTLBlitCommandEncoder> encoder = [cb blitCommandEncoder];
[encoder copyFromBuffer:(id<MTLBuffer>)(from)
sourceOffset:from_offset
@@ -287,14 +288,14 @@ void MetalWorkspace::CopyDataFromTo(const void* from,
size_t from_offset, void*
}
TVMStreamHandle MetalWorkspace::CreateStream(Device dev) {
- ICHECK_LT(dev.device_id, devices.size()) << "Invalid device id " <<
dev.device_id;
+ TVM_FFI_ICHECK_LT(dev.device_id, devices.size()) << "Invalid device id " <<
dev.device_id;
Stream* stream = new Stream(devices[dev.device_id]);
return static_cast<TVMStreamHandle>(stream);
}
void MetalWorkspace::FreeStream(Device dev, TVMStreamHandle stream) {
- ICHECK(stream != nullptr);
- ICHECK_LT(dev.device_id, devices.size()) << "Invalid device id " <<
dev.device_id;
+ TVM_FFI_ICHECK(stream != nullptr);
+ TVM_FFI_ICHECK_LT(dev.device_id, devices.size()) << "Invalid device id " <<
dev.device_id;
delete static_cast<Stream*>(stream);
}
diff --git a/src/runtime/metal/metal_module.mm
b/src/runtime/metal/metal_module.mm
index cf1a1641be..deb863c69b 100644
--- a/src/runtime/metal/metal_module.mm
+++ b/src/runtime/metal/metal_module.mm
@@ -86,7 +86,7 @@ class MetalModuleNode final : public ffi::ModuleObj {
// get a from primary context in device_id
id<MTLComputePipelineState> GetPipelineState(size_t device_id, const
std::string& func_name) {
metal::MetalWorkspace* w = metal::MetalWorkspace::Global();
- ICHECK_LT(device_id, w->devices.size());
+ TVM_FFI_ICHECK_LT(device_id, w->devices.size());
// start lock scope.
std::lock_guard<std::mutex> lock(mutex_);
if (finfo_.size() <= device_id) {
@@ -100,7 +100,7 @@ class MetalModuleNode final : public ffi::ModuleObj {
id<MTLLibrary> lib = nil;
auto kernel = smap_.find(func_name);
// Directly lookup kernels
- ICHECK(kernel != smap_.end());
+ TVM_FFI_ICHECK(kernel != smap_.end());
const std::string& source = kernel->second;
if (fmt_ == "metal") {
@@ -132,18 +132,18 @@ class MetalModuleNode final : public ffi::ModuleObj {
}
}
id<MTLFunction> f = [lib newFunctionWithName:[NSString
stringWithUTF8String:func_name.c_str()]];
- ICHECK(f != nil) << "cannot find function " << func_name;
+ TVM_FFI_ICHECK(f != nil) << "cannot find function " << func_name;
id<MTLComputePipelineState> state =
[w->devices[device_id] newComputePipelineStateWithFunction:f
error:&err_msg];
- ICHECK(state != nil) << "cannot get state:"
- << " for function " << func_name
- << [[err_msg localizedDescription] UTF8String];
+ TVM_FFI_ICHECK(state != nil) << "cannot get state:"
+ << " for function " << func_name
+ << [[err_msg localizedDescription]
UTF8String];
[f release];
[lib release];
// The state.threadExecutionWidth can change dynamically according
// to the resource constraint in kernel, so it is not strictly hold
// Turn of warp aware optimziation for now.
- // ICHECK_EQ(state.threadExecutionWidth, w->warp_size[device_id]);
+ // TVM_FFI_ICHECK_EQ(state.threadExecutionWidth, w->warp_size[device_id]);
if (e.smap[func_name] != nil) [e.smap[func_name] release];
e.smap[func_name] = state;
return state;
@@ -235,7 +235,7 @@ class MetalWrappedFunc {
// attach error message with function name
[cb addCompletedHandler:^(id<MTLCommandBuffer> buffer) {
if (buffer.status == MTLCommandBufferStatusError) {
- ICHECK(buffer.error != nil);
+ TVM_FFI_ICHECK(buffer.error != nil);
std::ostringstream os;
os << "GPUError happens after running " << func_name_ << ": "
<< buffer.error.localizedDescription.UTF8String;
@@ -270,7 +270,7 @@ ffi::Optional<ffi::Function>
MetalModuleNode::GetFunction(const ffi::String& nam
ffi::Function ret;
AUTORELEASEPOOL {
ObjectPtr<Object> sptr_to_self = ffi::GetObjectPtr<Object>(this);
- ICHECK_EQ(sptr_to_self.get(), this);
+ TVM_FFI_ICHECK_EQ(sptr_to_self.get(), this);
auto opt_info = fmap_.Get(name);
if (!opt_info.has_value()) {
return;
@@ -325,7 +325,7 @@ ffi::Module MetalModuleLoadFromBytes(const ffi::Bytes&
bytes) {
stream.Read(&ver);
stream.Read(&smap);
- ICHECK(stream.Read(&fmap));
+ TVM_FFI_ICHECK(stream.Read(&fmap));
stream.Read(&fmt);
return MetalModuleCreate(smap, fmap, fmt, "");