You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2021/06/04 19:49:34 UTC

[tvm] branch main updated: [METAL] Fix the rest memory leaks in Metal runtime (#8175)

This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new c9db3d0  [METAL] Fix the rest memory leaks in Metal runtime (#8175)
c9db3d0 is described below

commit c9db3d0a2b03bbbe9e4e0e421e0b848d540f94ca
Author: Egor Churaev <eg...@gmail.com>
AuthorDate: Fri Jun 4 22:49:13 2021 +0300

    [METAL] Fix the rest memory leaks in Metal runtime (#8175)
    
    * [METAL] Fix the rest memory leaks in Metal runtime
    
    When an exception is thrown out of an @autoreleasepool block, the
    resources are not released properly. The documentation states: "When
    the block is exited with an exception, the pool is not drained."
    
    Link to the documentation:
    https://clang.llvm.org/docs/AutomaticReferenceCounting.html#autoreleasepool
    
    Implemented a wrapper which catches all exceptions inside the
    autoreleasepool block and rethrows them after the block.
    
    * Apply comments
    
    * Add documentation comments to wrapper and macro
---
 src/runtime/metal/metal_common.h      | 56 +++++++++++++++++++++++++++++++++++
 src/runtime/metal/metal_device_api.mm | 40 ++++++++++++++-----------
 src/runtime/metal/metal_module.mm     | 26 +++++++++-------
 3 files changed, 94 insertions(+), 28 deletions(-)

diff --git a/src/runtime/metal/metal_common.h b/src/runtime/metal/metal_common.h
index 9ebe04e..7d2ef0c 100644
--- a/src/runtime/metal/metal_common.h
+++ b/src/runtime/metal/metal_common.h
@@ -42,10 +42,66 @@
 
 #include "../workspace_pool.h"
 
+/* Convenience macro for using AutoReleasePoolWrapper.
+ * With this macro we can use AutoReleasePoolWrapper in our ObjC code in a
+ * more native way.
+ *
+ * For example, this is ObjC code with autoreleasepool:
+ *     @autoreleasepool {
+ *         // Some code
+ *     }
+ *
+ * To avoid possible memory leaks when an exception is thrown, we
+ * should rewrite this code as:
+ *     AUTORELEASEPOOL { // Replace @autoreleasepool -> AUTORELEASEPOOL
+ *         // Some code
+ *     }; // Add a semicolon after the closing brace
+ *
+ * The AUTORELEASEPOOL macro obtains the AutoReleasePoolWrapper instance and
+ * passes the block that follows it, as a lambda, to the wrapper's operator<<,
+ * which runs it inside an @autoreleasepool block.
+ *
+ * Note: The block is the body of a lambda, so a `return` inside it only leaves
+ * the lambda. To return a value from the enclosing function, declare the result
+ * variable before the AUTORELEASEPOOL macro (it is captured by reference),
+ * assign it inside the block, and write the return statement after the macro.
+ */
+#define AUTORELEASEPOOL tvm::runtime::metal::AutoReleasePoolWrapper::GetInstance() << [&]()
+
 namespace tvm {
 namespace runtime {
 namespace metal {
 /*!
+ * \brief Wrapper around autoreleasepool with exception handling
+ *
+ * \note If an exception is thrown out of an autoreleasepool block, the
+ * allocated resources are not released properly. So we catch the exception
+ * inside the autoreleasepool and rethrow it after the autoreleasepool exits.
+ */
+class AutoReleasePoolWrapper {
+ public:
+  static AutoReleasePoolWrapper& GetInstance();  // instance used by the AUTORELEASEPOOL macro
+  template <typename T>
+  void operator<<(const T& f) {  // f: callable invoked with no arguments inside the pool
+    std::exception_ptr eptr;  // stays null unless f() throws
+    @autoreleasepool {
+      try {
+        f();
+      } catch (...) {
+        eptr = std::current_exception();  // capture so the pool block exits normally
+      }
+    }
+    if (eptr) std::rethrow_exception(eptr);  // propagate to the caller after the pool block
+  }
+
+ private:
+  AutoReleasePoolWrapper() = default;  // private: instances are obtained via GetInstance()
+  ~AutoReleasePoolWrapper() = default;
+  AutoReleasePoolWrapper(const AutoReleasePoolWrapper&) = delete;  // non-copyable
+  AutoReleasePoolWrapper& operator=(const AutoReleasePoolWrapper&) = delete;
+};
+
+/*!
  * \brief Structure for error handling in queues
  */
 class Stream {
diff --git a/src/runtime/metal/metal_device_api.mm b/src/runtime/metal/metal_device_api.mm
index 193e464..1c5666d 100644
--- a/src/runtime/metal/metal_device_api.mm
+++ b/src/runtime/metal/metal_device_api.mm
@@ -29,17 +29,20 @@ namespace tvm {
 namespace runtime {
 namespace metal {
 
+AutoReleasePoolWrapper& AutoReleasePoolWrapper::GetInstance() {
+  static AutoReleasePoolWrapper instance;  // function-local static, constructed on first call
+  return instance;  // every call returns a reference to the same object
+}
+
 MetalWorkspace* MetalWorkspace::Global() {
-  @autoreleasepool {
-    // NOTE: explicitly use new to avoid exit-time destruction of global state
-    // Global state will be recycled by OS as the process exits.
-    static MetalWorkspace* inst = new MetalWorkspace();
-    return inst;
-  }
+  // NOTE: explicitly use new to avoid exit-time destruction of global state
+  // Global state will be recycled by OS as the process exits.
+  static MetalWorkspace* inst = new MetalWorkspace();  // created once, on the first call
+  return inst;
 }
 
 void MetalWorkspace::GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) {
-  @autoreleasepool {
+  AUTORELEASEPOOL {
     this->Init();
     size_t index = static_cast<size_t>(dev.device_id);
     if (kind == kExist) {
@@ -80,7 +83,7 @@ void MetalWorkspace::GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) {
       case kDriverVersion:
         return;
     }
-  }
+  };
 }
 
 static const char* kDummyKernel = R"A0B0(
@@ -161,7 +164,8 @@ void MetalWorkspace::SetDevice(Device dev) {
 
 void* MetalWorkspace::AllocDataSpace(Device device, size_t nbytes, size_t alignment,
                                      DLDataType type_hint) {
-  @autoreleasepool {
+  id<MTLBuffer> buf;
+  AUTORELEASEPOOL {
     this->Init();
     id<MTLDevice> dev = GetDevice(device);
     // GPU memory only
@@ -173,20 +177,20 @@ void* MetalWorkspace::AllocDataSpace(Device device, size_t nbytes, size_t alignm
     storage_mode = MTLResourceStorageModeManaged;
     #endif
     */
-    id<MTLBuffer> buf = [dev newBufferWithLength:nbytes options:storage_mode];
+    buf = [dev newBufferWithLength:nbytes options:storage_mode];
     ICHECK(buf != nil);
-    return (void*)(buf);
-  }
+  };
+  return (void*)(buf);
 }
 
 void MetalWorkspace::FreeDataSpace(Device dev, void* ptr) {
-  @autoreleasepool {
+  AUTORELEASEPOOL {  // body runs as a lambda inside AutoReleasePoolWrapper's @autoreleasepool
     // MTLBuffer PurgeableState should be set to empty before manual
     // release in order to prevent memory leak
     [(id<MTLBuffer>)ptr setPurgeableState:MTLPurgeableStateEmpty];
     // release the ptr.
     CFRelease(ptr);
-  }
+  };  // semicolon required: AUTORELEASEPOOL { ... } expands to an expression statement
 }
 
 Stream* GetStream(TVMStreamHandle stream, int device_id) {
@@ -199,7 +203,7 @@ Stream* GetStream(TVMStreamHandle stream, int device_id) {
 void MetalWorkspace::CopyDataFromTo(const void* from, size_t from_offset, void* to,
                                     size_t to_offset, size_t size, Device dev_from, Device dev_to,
                                     DLDataType type_hint, TVMStreamHandle stream) {
-  @autoreleasepool {
+  AUTORELEASEPOOL {
     this->Init();
     Device dev = dev_from;
     Stream* s = GetStream(stream, dev.device_id);
@@ -261,7 +265,7 @@ void MetalWorkspace::CopyDataFromTo(const void* from, size_t from_offset, void*
       LOG(FATAL) << "Expect copy from/to Metal or between Metal"
                  << ", from=" << from_dev_type << ", to=" << to_dev_type;
     }
-  }
+  };
 }
 
 TVMStreamHandle MetalWorkspace::CreateStream(Device dev) {
@@ -276,7 +280,7 @@ void MetalWorkspace::FreeStream(Device dev, TVMStreamHandle stream) {
 }
 
 void MetalWorkspace::StreamSync(Device dev, TVMStreamHandle stream) {
-  @autoreleasepool {
+  AUTORELEASEPOOL {
     Stream* s = GetStream(stream, dev.device_id);
     // commit an empty command buffer and wait until it completes.
     id<MTLCommandBuffer> cb = s->GetCommandBuffer();
@@ -285,7 +289,7 @@ void MetalWorkspace::StreamSync(Device dev, TVMStreamHandle stream) {
     if (s->HasErrorHappened()) {
       LOG(FATAL) << "Error! Some problems on GPU happaned!";
     }
-  }
+  };
 }
 
 void MetalWorkspace::SetStream(Device dev, TVMStreamHandle stream) {
diff --git a/src/runtime/metal/metal_module.mm b/src/runtime/metal/metal_module.mm
index 2920c60..8850188 100644
--- a/src/runtime/metal/metal_module.mm
+++ b/src/runtime/metal/metal_module.mm
@@ -193,7 +193,7 @@ class MetalWrappedFunc {
   }
   // invoke the function with void arguments
   void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) const {
-    @autoreleasepool {
+    AUTORELEASEPOOL {
       metal::MetalThreadEntry* t = metal::MetalThreadEntry::ThreadLocal();
       int device_id = t->device.device_id;
       auto stream = static_cast<metal::Stream*>(t->stream[device_id]);
@@ -223,7 +223,7 @@ class MetalWrappedFunc {
       [encoder dispatchThreadgroups:dimGrid threadsPerThreadgroup:dimBlock];
       [encoder endEncoding];
       [cb commit];
-    }
+    };
   }
 
  private:
@@ -248,27 +248,33 @@ class MetalWrappedFunc {
 
 PackedFunc MetalModuleNode::GetFunction(const std::string& name,
                                         const ObjectPtr<Object>& sptr_to_self) {
-  @autoreleasepool {
+  PackedFunc pf;  // result; assigned inside the pool block, returned after it
+  AUTORELEASEPOOL {
     ICHECK_EQ(sptr_to_self.get(), this);
     ICHECK_NE(name, symbol::tvm_module_main) << "Device function do not have main";
     auto it = fmap_.find(name);
-    if (it == fmap_.end()) return PackedFunc();
+    if (it == fmap_.end()) {  // name not present in fmap_
+      pf = PackedFunc();
+      return;  // leaves only the lambda; the function itself returns pf below
+    }
     const FunctionInfo& info = it->second;
     MetalWrappedFunc f;
     size_t num_buffer_args = NumBufferArgs(info.arg_types);
     f.Init(this, sptr_to_self, name, num_buffer_args, info.arg_types.size() - num_buffer_args,
            info.thread_axis_tags);
-    return PackFuncNonBufferArg(f, info.arg_types);
-  }
+    pf = PackFuncNonBufferArg(f, info.arg_types);  // stored via the by-reference capture
+  };
+  return pf;  // the return statement must sit outside the AUTORELEASEPOOL block
 }
 
 Module MetalModuleCreate(std::string data, std::string fmt,
                          std::unordered_map<std::string, FunctionInfo> fmap, std::string source) {
-  @autoreleasepool {
+  ObjectPtr<Object> n;  // declared before the block so it can be used after it
+  AUTORELEASEPOOL {
     metal::MetalWorkspace::Global()->Init();
-    auto n = make_object<MetalModuleNode>(data, fmt, fmap, source);
-    return Module(n);
-  }
+    n = make_object<MetalModuleNode>(data, fmt, fmap, source);  // assigned through the [&] capture
+  };
+  return Module(n);  // wrap and return once the pool block has finished
 }
 
 // Load module from module.