You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2021/06/18 06:56:35 UTC

[tvm] branch main updated: [Metal] Fix bad stream after interrupted tuning session (#8244)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 77536da  [Metal] Fix bad stream after interrupted tuning session (#8244)
77536da is described below

commit 77536da8ff743f66d9c9ad2c030aa5d6fafd949e
Author: Egor Churaev <eg...@gmail.com>
AuthorDate: Fri Jun 18 09:56:08 2021 +0300

    [Metal] Fix bad stream after interrupted tuning session (#8244)
    
    * [Metal] Fix bad stream after interrupted tuning session
    
    After an interrupted tuning session, the stream object could already
    have been released without a new one being created. In that case it
    was not possible to run a new Metal task on the device without
    restarting the RPC application.
    
    Added a global function `metal.ResetGlobalState`, which should be
    called in the RPC application when the connection is closed. This
    function reinitializes the streams of the Metal devices, which
    guarantees that the new RPC session will work with the correct
    streams.
    
    * Refactor metal_device_api
    
    - Rename function GetStream -> CastStreamOrGetCurrent
    - Add several checks on device id
    - When we use `SetStream` with nullptr, then the default stream will be
      associated with the device.
---
 src/runtime/metal/metal_common.h      |  1 +
 src/runtime/metal/metal_device_api.mm | 48 +++++++++++++++++++++++++----------
 2 files changed, 35 insertions(+), 14 deletions(-)

diff --git a/src/runtime/metal/metal_common.h b/src/runtime/metal/metal_common.h
index 7d2ef0c..47a5999 100644
--- a/src/runtime/metal/metal_common.h
+++ b/src/runtime/metal/metal_common.h
@@ -163,6 +163,7 @@ class MetalWorkspace final : public DeviceAPI {
   void SetStream(Device dev, TVMStreamHandle stream) final;
   void* AllocWorkspace(Device dev, size_t size, DLDataType type_hint) final;
   void FreeWorkspace(Device dev, void* data) final;
+  void ReinitializeStreams();
 
   // get the global workspace
   static MetalWorkspace* Global();
diff --git a/src/runtime/metal/metal_device_api.mm b/src/runtime/metal/metal_device_api.mm
index 43d8ccd..0ef07b1 100644
--- a/src/runtime/metal/metal_device_api.mm
+++ b/src/runtime/metal/metal_device_api.mm
@@ -1,4 +1,3 @@
-
 /*
  * Licensed to the Apache Software Foundation (ASF) under one
  * or more contributor license agreements.  See the NOTICE file
@@ -131,6 +130,23 @@ MetalWorkspace::~MetalWorkspace() {
   }
 }
 
+void MetalWorkspace::ReinitializeStreams() {
+  std::vector<Stream*>& threadStreams = MetalThreadEntry::ThreadLocal()->stream;
+  ICHECK_EQ(default_streams_.size(), threadStreams.size());
+  for (size_t i = 0; i < default_streams_.size(); ++i) {
+    if (threadStreams[i] != nullptr && default_streams_[i] != threadStreams[i])
+      delete threadStreams[i];
+    delete default_streams_[i];
+  }
+  default_streams_.resize(devices.size());
+  threadStreams.resize(devices.size());
+  for (size_t i = 0; i < devices.size(); ++i) {
+    Stream* stream = new Stream(devices[i]);
+    default_streams_[i] = stream;
+    threadStreams[i] = stream;
+  }
+}
+
 void MetalWorkspace::Init() {
   if (initialized_) return;
   std::lock_guard<std::mutex> lock(this->mutex);
@@ -141,21 +157,16 @@ void MetalWorkspace::Init() {
   // on iPhone
   id<MTLDevice> d = MTLCreateSystemDefaultDevice();
   devices.push_back(d);
-  Stream* stream = new Stream(d);
-  MetalThreadEntry::ThreadLocal()->stream.push_back(stream);
-  default_streams_.push_back(stream);
 #else
   NSArray<id<MTLDevice> >* devs = MTLCopyAllDevices();
   for (size_t i = 0; i < devs.count; ++i) {
     id<MTLDevice> d = [devs objectAtIndex:i];
     devices.push_back(d);
-    Stream* stream = new Stream(d);
-    MetalThreadEntry::ThreadLocal()->stream.push_back(stream);
-    default_streams_.push_back(stream);
     LOG(INFO) << "Intializing Metal device " << i << ", name=" << [d.name UTF8String];
     warp_size.push_back(GetWarpSize(d));
   }
 #endif
+  ReinitializeStreams();
 }
 
 void MetalWorkspace::SetDevice(Device dev) {
@@ -193,11 +204,10 @@ void MetalWorkspace::FreeDataSpace(Device dev, void* ptr) {
   };
 }
 
-Stream* GetStream(TVMStreamHandle stream, int device_id) {
-  if (stream != nullptr)
-    return static_cast<Stream*>(stream);
-  else
-    return MetalThreadEntry::ThreadLocal()->stream[device_id];
+Stream* CastStreamOrGetCurrent(TVMStreamHandle stream, int device_id) {
+  if (stream != nullptr) return static_cast<Stream*>(stream);
+  ICHECK(MetalThreadEntry::ThreadLocal()->stream[device_id] != nullptr);
+  return MetalThreadEntry::ThreadLocal()->stream[device_id];
 }
 
 void MetalWorkspace::CopyDataFromTo(const void* from, size_t from_offset, void* to,
@@ -207,7 +217,7 @@ void MetalWorkspace::CopyDataFromTo(const void* from, size_t from_offset, void*
     this->Init();
     Device dev = dev_from;
     if (dev_from.device_type == kDLCPU) dev = dev_to;
-    Stream* s = GetStream(stream, dev.device_id);
+    Stream* s = CastStreamOrGetCurrent(stream, dev.device_id);
     if (s->HasErrorHappened()) {
       LOG(FATAL) << "Error! Some problems on GPU happaned! Cannot copy data to current stream";
     }
@@ -269,19 +279,23 @@ void MetalWorkspace::CopyDataFromTo(const void* from, size_t from_offset, void*
 }
 
 TVMStreamHandle MetalWorkspace::CreateStream(Device dev) {
+  ICHECK_LT(dev.device_id, devices.size()) << "Invalid device id " << dev.device_id;
   Stream* stream = new Stream(devices[dev.device_id]);
   return static_cast<TVMStreamHandle>(stream);
 }
 
 void MetalWorkspace::FreeStream(Device dev, TVMStreamHandle stream) {
   ICHECK(stream != nullptr);
+  ICHECK_LT(dev.device_id, devices.size()) << "Invalid device id " << dev.device_id;
   Stream* s = static_cast<Stream*>(stream);
+  if (MetalThreadEntry::ThreadLocal()->stream[dev.device_id] == s)
+    MetalThreadEntry::ThreadLocal()->stream[dev.device_id] = nullptr;
   delete s;
 }
 
 void MetalWorkspace::StreamSync(Device dev, TVMStreamHandle stream) {
   AUTORELEASEPOOL {
-    Stream* s = GetStream(stream, dev.device_id);
+    Stream* s = CastStreamOrGetCurrent(stream, dev.device_id);
     // commit an empty command buffer and wait until it completes.
     id<MTLCommandBuffer> cb = s->GetCommandBuffer();
     [cb commit];
@@ -293,6 +307,8 @@ void MetalWorkspace::StreamSync(Device dev, TVMStreamHandle stream) {
 }
 
 void MetalWorkspace::SetStream(Device dev, TVMStreamHandle stream) {
+  ICHECK_LT(dev.device_id, devices.size()) << "Invalid device id " << dev.device_id;
+  ICHECK(stream != nullptr);
   MetalThreadEntry::ThreadLocal()->stream[dev.device_id] = static_cast<Stream*>(stream);
 }
 
@@ -337,6 +353,10 @@ TVM_REGISTER_GLOBAL("device_api.metal").set_body([](TVMArgs args, TVMRetValue* r
   *rv = static_cast<void*>(ptr);
 });
 
+TVM_REGISTER_GLOBAL("metal.ResetGlobalState").set_body_typed([]() {
+  MetalWorkspace::Global()->ReinitializeStreams();
+});
+
 }  // namespace metal
 }  // namespace runtime
 }  // namespace tvm