This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 37e6df1  [METAL] Fix memory leaks in Metal runtime (#7714)
37e6df1 is described below

commit 37e6df1a2654c3a06f3bdfb36fb107fa7a8265eb
Author: Egor Churaev <egor.chur...@gmail.com>
AuthorDate: Tue Mar 23 16:39:10 2021 +0300

    [METAL] Fix memory leaks in Metal runtime (#7714)
    
    * [METAL] Fix memory leaks in Metal runtime
    
    1. When the runtime is built without ARC, memory may not be released
       properly. Because some Objective-C methods return autoreleased
       pointers, we need `@autoreleasepool` blocks to bound the lifetime
       of these pointers.
    2. Added a workaround for an oversized work group.
       The auto-scheduler sometimes generates parameters whose work group
       size exceeds the maximum the device supports, which triggers an
       assert inside the Metal library. A check for this case now runs
       before dispatch, so the Metal assert is avoided.
    3. Fixed a memory leak when filling a tensor with random data.
       The DLManagedTensor returned by `ToDLPack()` increments the
       NDArray's reference counter, but that DLManagedTensor was never
       deleted, so the memory allocated by the NDArray was never released.
    4. Removed unnecessary retains. Because the Metal runtime is built
       without ARC, `retain` was not needed in several of the places where
       it was used.
    
    * Use const_cast instead of creating a DLManagedTensor
---
 src/runtime/contrib/random/mt_random_engine.cc |   5 +-
 src/runtime/metal/metal_device_api.mm          | 258 +++++++++++++------------
 src/runtime/metal/metal_module.mm              |  88 +++++----
 3 files changed, 189 insertions(+), 162 deletions(-)

diff --git a/src/runtime/contrib/random/mt_random_engine.cc 
b/src/runtime/contrib/random/mt_random_engine.cc
index 699f6bb..81f46b2 100644
--- a/src/runtime/contrib/random/mt_random_engine.cc
+++ b/src/runtime/contrib/random/mt_random_engine.cc
@@ -126,8 +126,9 @@ class RandomEngine {
     } else {
       runtime::NDArray local = runtime::NDArray::Empty(
           std::vector<int64_t>{data->shape, data->shape + data->ndim}, 
data->dtype, {kDLCPU, 0});
-      FillData(&local.ToDLPack()->dl_tensor, size);
-      runtime::NDArray::CopyFromTo(&local.ToDLPack()->dl_tensor, data);
+      DLTensor* tensor = const_cast<DLTensor*>(local.operator->());
+      FillData(tensor, size);
+      runtime::NDArray::CopyFromTo(tensor, data);
     }
   }
 
diff --git a/src/runtime/metal/metal_device_api.mm 
b/src/runtime/metal/metal_device_api.mm
index 0169a4c..3d7abd1 100644
--- a/src/runtime/metal/metal_device_api.mm
+++ b/src/runtime/metal/metal_device_api.mm
@@ -30,50 +30,54 @@ namespace runtime {
 namespace metal {
 
 MetalWorkspace* MetalWorkspace::Global() {
-  // NOTE: explicitly use new to avoid exit-time destruction of global state
-  // Global state will be recycled by OS as the process exits.
-  static MetalWorkspace* inst = new MetalWorkspace();
-  return inst;
+  @autoreleasepool {
+    // NOTE: explicitly use new to avoid exit-time destruction of global state
+    // Global state will be recycled by OS as the process exits.
+    static MetalWorkspace* inst = new MetalWorkspace();
+    return inst;
+  }
 }
 
 void MetalWorkspace::GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* 
rv) {
-  this->Init();
-  size_t index = static_cast<size_t>(ctx.device_id);
-  if (kind == kExist) {
-    *rv = int(index < devices.size());
-    return;
-  }
-  ICHECK_LT(index, devices.size()) << "Invalid device id " << index;
-  switch (kind) {
-    case kMaxThreadsPerBlock: {
-      *rv = static_cast<int>([devices[ctx.device_id] 
maxThreadsPerThreadgroup].width);
-      break;
+  @autoreleasepool {
+    this->Init();
+    size_t index = static_cast<size_t>(ctx.device_id);
+    if (kind == kExist) {
+      *rv = int(index < devices.size());
+      return;
     }
-    case kWarpSize: {
-      // Set warp size to be 1 for safty reason.
-      *rv = 1;
-      break;
+    ICHECK_LT(index, devices.size()) << "Invalid device id " << index;
+    switch (kind) {
+      case kMaxThreadsPerBlock: {
+        *rv = static_cast<int>([devices[ctx.device_id] 
maxThreadsPerThreadgroup].width);
+        break;
+      }
+      case kWarpSize: {
+        // Set warp size to be 1 for safty reason.
+        *rv = 1;
+        break;
+      }
+      case kMaxSharedMemoryPerBlock:
+        return;
+      case kComputeVersion:
+        return;
+      case kDeviceName:
+        return;
+      case kMaxClockRate:
+        return;
+      case kMultiProcessorCount:
+        return;
+      case kMaxThreadDimensions:
+        return;
+      case kExist:
+        return;
+      case kMaxRegistersPerBlock:
+        return;
+      case kGcnArch:
+        return;
+      case kApiVersion:
+        return;
     }
-    case kMaxSharedMemoryPerBlock:
-      return;
-    case kComputeVersion:
-      return;
-    case kDeviceName:
-      return;
-    case kMaxClockRate:
-      return;
-    case kMultiProcessorCount:
-      return;
-    case kMaxThreadDimensions:
-      return;
-    case kExist:
-      return;
-    case kMaxRegistersPerBlock:
-      return;
-    case kGcnArch:
-      return;
-    case kApiVersion:
-      return;
   }
 }
 
@@ -106,7 +110,11 @@ int GetWarpSize(id<MTLDevice> dev) {
   ICHECK(f != nil);
   id<MTLComputePipelineState> state = [dev 
newComputePipelineStateWithFunction:f error:&error_msg];
   ICHECK(state != nil) << [[error_msg localizedDescription] UTF8String];
-  return static_cast<int>(state.threadExecutionWidth);
+  int size = static_cast<int>(state.threadExecutionWidth);
+  [state release];
+  [f release];
+  [lib release];
+  return size;
 }
 
 MetalWorkspace::~MetalWorkspace() {
@@ -127,14 +135,14 @@ void MetalWorkspace::Init() {
 #if TARGET_OS_IPHONE
   // on iPhone
   id<MTLDevice> d = MTLCreateSystemDefaultDevice();
-  devices.push_back([d retain]);
-  queues.push_back([[d newCommandQueue] retain]);
+  devices.push_back(d);
+  queues.push_back([d newCommandQueue]);
 #else
   NSArray<id<MTLDevice> >* devs = MTLCopyAllDevices();
   for (size_t i = 0; i < devs.count; ++i) {
     id<MTLDevice> d = [devs objectAtIndex:i];
-    devices.push_back([d retain]);
-    queues.push_back([[d newCommandQueue] retain]);
+    devices.push_back(d);
+    queues.push_back([d newCommandQueue]);
     LOG(INFO) << "Intializing Metal device " << i << ", name=" << [d.name 
UTF8String];
     warp_size.push_back(GetWarpSize(d));
   }
@@ -147,102 +155,110 @@ void MetalWorkspace::SetDevice(TVMContext ctx) {
 
 void* MetalWorkspace::AllocDataSpace(TVMContext ctx, size_t nbytes, size_t 
alignment,
                                      DLDataType type_hint) {
-  this->Init();
-  id<MTLDevice> dev = GetDevice(ctx);
-  // GPU memory only
-  MTLResourceOptions storage_mode = MTLResourceStorageModePrivate;
-  /*
-  #if TARGET_OS_IPHONE
-  storage_mode = MTLResourceStorageModeShared;
-  #else
-  storage_mode = MTLResourceStorageModeManaged;
-  #endif
-  */
-  id<MTLBuffer> buf = [dev newBufferWithLength:nbytes options:storage_mode];
-  ICHECK(buf != nil);
-  return (void*)(CFBridgingRetain(buf));
+  @autoreleasepool {
+    this->Init();
+    id<MTLDevice> dev = GetDevice(ctx);
+    // GPU memory only
+    MTLResourceOptions storage_mode = MTLResourceStorageModePrivate;
+    /*
+    #if TARGET_OS_IPHONE
+    storage_mode = MTLResourceStorageModeShared;
+    #else
+    storage_mode = MTLResourceStorageModeManaged;
+    #endif
+    */
+    id<MTLBuffer> buf = [dev newBufferWithLength:nbytes options:storage_mode];
+    ICHECK(buf != nil);
+    return (void*)(buf);
+  }
 }
 
 void MetalWorkspace::FreeDataSpace(TVMContext ctx, void* ptr) {
-  // MTLBuffer PurgeableState should be set to empty before manual
-  // release in order to prevent memory leak
-  [(id<MTLBuffer>)ptr setPurgeableState:MTLPurgeableStateEmpty];
-  // release the ptr.
-  CFRelease(ptr);
+  @autoreleasepool {
+    // MTLBuffer PurgeableState should be set to empty before manual
+    // release in order to prevent memory leak
+    [(id<MTLBuffer>)ptr setPurgeableState:MTLPurgeableStateEmpty];
+    // release the ptr.
+    CFRelease(ptr);
+  }
 }
 
 void MetalWorkspace::CopyDataFromTo(const void* from, size_t from_offset, 
void* to,
                                     size_t to_offset, size_t size, TVMContext 
ctx_from,
                                     TVMContext ctx_to, DLDataType type_hint,
                                     TVMStreamHandle stream) {
-  this->Init();
-  ICHECK(stream == nullptr);
-  TVMContext ctx = ctx_from;
-  if (ctx_from.device_type == kDLCPU) ctx = ctx_to;
-  id<MTLCommandQueue> queue = GetCommandQueue(ctx);
-  id<MTLCommandBuffer> cb = [queue commandBuffer];
-  int from_dev_type = static_cast<int>(ctx_from.device_type);
-  int to_dev_type = static_cast<int>(ctx_to.device_type);
+  @autoreleasepool {
+    this->Init();
+    ICHECK(stream == nullptr);
+    TVMContext ctx = ctx_from;
+    if (ctx_from.device_type == kDLCPU) ctx = ctx_to;
+    id<MTLCommandQueue> queue = GetCommandQueue(ctx);
+    id<MTLCommandBuffer> cb = [queue commandBuffer];
+    int from_dev_type = static_cast<int>(ctx_from.device_type);
+    int to_dev_type = static_cast<int>(ctx_to.device_type);
 
-  if (from_dev_type == kDLMetal && to_dev_type == kDLMetal) {
-    ICHECK_EQ(ctx_from.device_id, ctx_to.device_id) << "Metal disallow cross 
device copy.";
-    id<MTLBlitCommandEncoder> encoder = [cb blitCommandEncoder];
-    [encoder copyFromBuffer:(__bridge id<MTLBuffer>)(from)
-               sourceOffset:from_offset
-                   toBuffer:(__bridge 
id<MTLBuffer>)(to)destinationOffset:to_offset
-                       size:size];
-    [encoder endEncoding];
-    [cb commit];
-  } else if (from_dev_type == kDLMetal && to_dev_type == kDLCPU) {
-    // copy to a local buffer before get into global buffer.
-    id<MTLBuffer> from_buf = (__bridge id<MTLBuffer>)(from);
-    if (from_buf.storageMode != MTLStorageModeShared) {
-      id<MTLBuffer> temp = 
MetalThreadEntry::ThreadLocal()->GetTempBuffer(ctx_from, size);
+    if (from_dev_type == kDLMetal && to_dev_type == kDLMetal) {
+      ICHECK_EQ(ctx_from.device_id, ctx_to.device_id) << "Metal disallow cross 
device copy.";
       id<MTLBlitCommandEncoder> encoder = [cb blitCommandEncoder];
-      [encoder copyFromBuffer:from_buf
+      [encoder copyFromBuffer:(id<MTLBuffer>)(from)
                  sourceOffset:from_offset
-                     toBuffer:temp
-            destinationOffset:0
-                         size:size];
-      [encoder endEncoding];
-      [cb commit];
-      [cb waitUntilCompleted];
-      memcpy(static_cast<char*>(to) + to_offset, static_cast<char*>([temp 
contents]), size);
-    } else {
-      memcpy(static_cast<char*>(to) + to_offset,
-             static_cast<char*>([from_buf contents]) + from_offset, size);
-    }
-  } else if (from_dev_type == kDLCPU && to_dev_type == kDLMetal) {
-    id<MTLBuffer> to_buf = (__bridge id<MTLBuffer>)(to);
-    if (to_buf.storageMode != MTLStorageModeShared) {
-      id<MTLBuffer> temp = 
MetalThreadEntry::ThreadLocal()->GetTempBuffer(ctx_to, size);
-      memcpy([temp contents], static_cast<const char*>(from) + from_offset, 
size);
-      id<MTLBlitCommandEncoder> encoder = [cb blitCommandEncoder];
-      [encoder copyFromBuffer:temp
-                 sourceOffset:0
-                     toBuffer:to_buf
-            destinationOffset:to_offset
+                     toBuffer:(id<MTLBuffer>)(to)destinationOffset:to_offset
                          size:size];
       [encoder endEncoding];
       [cb commit];
-      [cb waitUntilCompleted];
+    } else if (from_dev_type == kDLMetal && to_dev_type == kDLCPU) {
+      // copy to a local buffer before get into global buffer.
+      id<MTLBuffer> from_buf = (id<MTLBuffer>)(from);
+      if (from_buf.storageMode != MTLStorageModeShared) {
+        id<MTLBuffer> temp = 
MetalThreadEntry::ThreadLocal()->GetTempBuffer(ctx_from, size);
+        id<MTLBlitCommandEncoder> encoder = [cb blitCommandEncoder];
+        [encoder copyFromBuffer:from_buf
+                   sourceOffset:from_offset
+                       toBuffer:temp
+              destinationOffset:0
+                           size:size];
+        [encoder endEncoding];
+        [cb commit];
+        [cb waitUntilCompleted];
+        memcpy(static_cast<char*>(to) + to_offset, static_cast<char*>([temp 
contents]), size);
+      } else {
+        memcpy(static_cast<char*>(to) + to_offset,
+               static_cast<char*>([from_buf contents]) + from_offset, size);
+      }
+    } else if (from_dev_type == kDLCPU && to_dev_type == kDLMetal) {
+      id<MTLBuffer> to_buf = (id<MTLBuffer>)(to);
+      if (to_buf.storageMode != MTLStorageModeShared) {
+        id<MTLBuffer> temp = 
MetalThreadEntry::ThreadLocal()->GetTempBuffer(ctx_to, size);
+        memcpy([temp contents], static_cast<const char*>(from) + from_offset, 
size);
+        id<MTLBlitCommandEncoder> encoder = [cb blitCommandEncoder];
+        [encoder copyFromBuffer:temp
+                   sourceOffset:0
+                       toBuffer:to_buf
+              destinationOffset:to_offset
+                           size:size];
+        [encoder endEncoding];
+        [cb commit];
+        [cb waitUntilCompleted];
+      } else {
+        memcpy(static_cast<char*>([to_buf contents]) + to_offset,
+               static_cast<const char*>(from) + from_offset, size);
+      }
     } else {
-      memcpy(static_cast<char*>([to_buf contents]) + to_offset,
-             static_cast<const char*>(from) + from_offset, size);
+      LOG(FATAL) << "Expect copy from/to Metal or between Metal"
+                 << ", from=" << from_dev_type << ", to=" << to_dev_type;
     }
-  } else {
-    LOG(FATAL) << "Expect copy from/to Metal or between Metal"
-               << ", from=" << from_dev_type << ", to=" << to_dev_type;
   }
 }
 
 void MetalWorkspace::StreamSync(TVMContext ctx, TVMStreamHandle stream) {
-  ICHECK(stream == nullptr);
-  // commit an empty command buffer and wait until it completes.
-  id<MTLCommandQueue> queue = GetCommandQueue(ctx);
-  id<MTLCommandBuffer> cb = [queue commandBuffer];
-  [cb commit];
-  [cb waitUntilCompleted];
+  @autoreleasepool {
+    ICHECK(stream == nullptr);
+    // commit an empty command buffer and wait until it completes.
+    id<MTLCommandQueue> queue = GetCommandQueue(ctx);
+    id<MTLCommandBuffer> cb = [queue commandBuffer];
+    [cb commit];
+    [cb waitUntilCompleted];
+  }
 }
 
 void* MetalWorkspace::AllocWorkspace(TVMContext ctx, size_t size, DLDataType 
type_hint) {
@@ -269,10 +285,10 @@ id<MTLBuffer> MetalThreadEntry::GetTempBuffer(TVMContext 
ctx, size_t size) {
   if (temp_buffer_[ctx.device_id] == nil || temp_buffer_[ctx.device_id].length 
< size) {
     id<MTLDevice> dev = MetalWorkspace::Global()->GetDevice(ctx);
     if (temp_buffer_[ctx.device_id] != nil) {
+      [temp_buffer_[ctx.device_id] setPurgeableState:MTLPurgeableStateEmpty];
       [temp_buffer_[ctx.device_id] release];
     }
-    temp_buffer_[ctx.device_id] = [[dev newBufferWithLength:size
-                                                    
options:MTLStorageModeShared] retain];
+    temp_buffer_[ctx.device_id] = [dev newBufferWithLength:size 
options:MTLStorageModeShared];
   }
   return temp_buffer_[ctx.device_id];
 }
diff --git a/src/runtime/metal/metal_module.mm 
b/src/runtime/metal/metal_module.mm
index 8f1fde8..c7e2d8b 100644
--- a/src/runtime/metal/metal_module.mm
+++ b/src/runtime/metal/metal_module.mm
@@ -113,7 +113,6 @@ class MetalModuleNode final : public runtime::ModuleNode {
           LOG(FATAL) << "Fail to compile metal lib:" << [[err_msg 
localizedDescription] UTF8String];
         }
       }
-      [e.lib retain];
     }
     id<MTLFunction> f =
         [e.lib newFunctionWithName:[NSString 
stringWithUTF8String:func_name.c_str()]];
@@ -123,11 +122,13 @@ class MetalModuleNode final : public runtime::ModuleNode {
     ICHECK(state != nil) << "cannot get state:"
                          << " for function " << func_name
                          << [[err_msg localizedDescription] UTF8String];
+    [f release];
     // The state.threadExecutionWidth can change dynamically according
     // to the resource constraint in kernel, so it is not strictly hold
     // Turn of warp aware optimziation for now.
     // ICHECK_EQ(state.threadExecutionWidth, w->warp_size[device_id]);
-    e.smap[func_name] = [state retain];
+    if (e.smap[func_name] != nil) [e.smap[func_name] release];
+    e.smap[func_name] = state;
     return state;
   }
 
@@ -181,31 +182,36 @@ class MetalWrappedFunc {
   }
   // invoke the function with void arguments
   void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) 
const {
-    metal::MetalThreadEntry* t = metal::MetalThreadEntry::ThreadLocal();
-    int device_id = t->context.device_id;
-    if (scache_[device_id] == nil) {
-      scache_[device_id] = m_->GetPipelineState(device_id, func_name_);
-    }
-    ThreadWorkLoad wl = thread_axis_cfg_.Extract(args);
-    id<MTLCommandQueue> queue = w_->GetCommandQueue(t->context);
-    id<MTLCommandBuffer> cb = [queue commandBuffer];
-    id<MTLComputeCommandEncoder> encoder = [cb computeCommandEncoder];
-    [encoder setComputePipelineState:scache_[device_id]];
-    for (size_t i = 0; i < num_buffer_args_; ++i) {
-      void* buf = args[static_cast<int>(i)];
-      [encoder setBuffer:(__bridge id<MTLBuffer>)(buf) offset:0 atIndex:i];
-    }
-    if (num_pack_args_ != 0) {
-      [encoder setBytes:pack_args
-                 length:num_pack_args_ * sizeof(ArgUnion64)
-                atIndex:num_buffer_args_];
+    @autoreleasepool {
+      metal::MetalThreadEntry* t = metal::MetalThreadEntry::ThreadLocal();
+      int device_id = t->context.device_id;
+      if (scache_[device_id] == nil) {
+        scache_[device_id] = m_->GetPipelineState(device_id, func_name_);
+      }
+      ThreadWorkLoad wl = thread_axis_cfg_.Extract(args);
+      int blockSize = wl.block_dim(0) * wl.block_dim(1) * wl.block_dim(2);
+      auto maxTotalThreadsPerThreadgroup = 
scache_[device_id].maxTotalThreadsPerThreadgroup;
+      CHECK_LE(blockSize, maxTotalThreadsPerThreadgroup);
+      id<MTLCommandQueue> queue = w_->GetCommandQueue(t->context);
+      id<MTLCommandBuffer> cb = [queue commandBuffer];
+      id<MTLComputeCommandEncoder> encoder = [cb computeCommandEncoder];
+      [encoder setComputePipelineState:scache_[device_id]];
+      for (size_t i = 0; i < num_buffer_args_; ++i) {
+        void* buf = args[static_cast<int>(i)];
+        [encoder setBuffer:(id<MTLBuffer>)(buf) offset:0 atIndex:i];
+      }
+      if (num_pack_args_ != 0) {
+        [encoder setBytes:pack_args
+                   length:num_pack_args_ * sizeof(ArgUnion64)
+                  atIndex:num_buffer_args_];
+      }
+      // launch
+      MTLSize dimGrid = MTLSizeMake(wl.grid_dim(0), wl.grid_dim(1), 
wl.grid_dim(2));
+      MTLSize dimBlock = MTLSizeMake(wl.block_dim(0), wl.block_dim(1), 
wl.block_dim(2));
+      [encoder dispatchThreadgroups:dimGrid threadsPerThreadgroup:dimBlock];
+      [encoder endEncoding];
+      [cb commit];
     }
-    // launch
-    MTLSize dimGrid = MTLSizeMake(wl.grid_dim(0), wl.grid_dim(1), 
wl.grid_dim(2));
-    MTLSize dimBlock = MTLSizeMake(wl.block_dim(0), wl.block_dim(1), 
wl.block_dim(2));
-    [encoder dispatchThreadgroups:dimGrid threadsPerThreadgroup:dimBlock];
-    [encoder endEncoding];
-    [cb commit];
   }
 
  private:
@@ -230,23 +236,27 @@ class MetalWrappedFunc {
 
 PackedFunc MetalModuleNode::GetFunction(const std::string& name,
                                         const ObjectPtr<Object>& sptr_to_self) 
{
-  ICHECK_EQ(sptr_to_self.get(), this);
-  ICHECK_NE(name, symbol::tvm_module_main) << "Device function do not have 
main";
-  auto it = fmap_.find(name);
-  if (it == fmap_.end()) return PackedFunc();
-  const FunctionInfo& info = it->second;
-  MetalWrappedFunc f;
-  size_t num_buffer_args = NumBufferArgs(info.arg_types);
-  f.Init(this, sptr_to_self, name, num_buffer_args, info.arg_types.size() - 
num_buffer_args,
-         info.thread_axis_tags);
-  return PackFuncNonBufferArg(f, info.arg_types);
+  @autoreleasepool {
+    ICHECK_EQ(sptr_to_self.get(), this);
+    ICHECK_NE(name, symbol::tvm_module_main) << "Device function do not have 
main";
+    auto it = fmap_.find(name);
+    if (it == fmap_.end()) return PackedFunc();
+    const FunctionInfo& info = it->second;
+    MetalWrappedFunc f;
+    size_t num_buffer_args = NumBufferArgs(info.arg_types);
+    f.Init(this, sptr_to_self, name, num_buffer_args, info.arg_types.size() - 
num_buffer_args,
+           info.thread_axis_tags);
+    return PackFuncNonBufferArg(f, info.arg_types);
+  }
 }
 
 Module MetalModuleCreate(std::string data, std::string fmt,
                          std::unordered_map<std::string, FunctionInfo> fmap, 
std::string source) {
-  metal::MetalWorkspace::Global()->Init();
-  auto n = make_object<MetalModuleNode>(data, fmt, fmap, source);
-  return Module(n);
+  @autoreleasepool {
+    metal::MetalWorkspace::Global()->Init();
+    auto n = make_object<MetalModuleNode>(data, fmt, fmap, source);
+    return Module(n);
+  }
 }
 
 // Load module from module.

Reply via email to