You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2021/06/18 06:56:35 UTC
[tvm] branch main updated: [Metal] Fix bad stream after interrupted
tuning session (#8244)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 77536da [Metal] Fix bad stream after interrupted tuning session (#8244)
77536da is described below
commit 77536da8ff743f66d9c9ad2c030aa5d6fafd949e
Author: Egor Churaev <eg...@gmail.com>
AuthorDate: Fri Jun 18 09:56:08 2021 +0300
[Metal] Fix bad stream after interrupted tuning session (#8244)
* [Metal] Fix bad stream after interrupted tuning session
After an interrupted tuning session, the stream object could be released
without a new one being created. In that case it was impossible to run a
new Metal task on the device without restarting the RPC application.
Added a global function `metal.ResetGlobalState`, which the RPC
application should call when a connection is closed. The function
reinitializes the streams of the Metal devices, which ensures that the
next RPC session starts with valid, freshly created streams.
* Refactor metal_device_api
- Rename function GetStream -> CastStreamOrGetCurrent
- Add several checks on device id
- When `SetStream` is called with nullptr, the default stream is
associated with the device. (Note: the patch below actually rejects a
null handle in `SetStream` via `ICHECK(stream != nullptr)`.)
---
src/runtime/metal/metal_common.h | 1 +
src/runtime/metal/metal_device_api.mm | 48 +++++++++++++++++++++++++----------
2 files changed, 35 insertions(+), 14 deletions(-)
diff --git a/src/runtime/metal/metal_common.h b/src/runtime/metal/metal_common.h
index 7d2ef0c..47a5999 100644
--- a/src/runtime/metal/metal_common.h
+++ b/src/runtime/metal/metal_common.h
@@ -163,6 +163,7 @@ class MetalWorkspace final : public DeviceAPI {
void SetStream(Device dev, TVMStreamHandle stream) final;
void* AllocWorkspace(Device dev, size_t size, DLDataType type_hint) final;
void FreeWorkspace(Device dev, void* data) final;
+ void ReinitializeStreams();
// get the global workspace
static MetalWorkspace* Global();
diff --git a/src/runtime/metal/metal_device_api.mm b/src/runtime/metal/metal_device_api.mm
index 43d8ccd..0ef07b1 100644
--- a/src/runtime/metal/metal_device_api.mm
+++ b/src/runtime/metal/metal_device_api.mm
@@ -1,4 +1,3 @@
-
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
@@ -131,6 +130,23 @@ MetalWorkspace::~MetalWorkspace() {
}
}
+void MetalWorkspace::ReinitializeStreams() {  // Deletes the old streams, then creates one new default stream per device and installs it as the calling thread's current stream.
+ std::vector<Stream*>& threadStreams = MetalThreadEntry::ThreadLocal()->stream;  // Only the calling thread's per-device slots are touched here.
+ ICHECK_EQ(default_streams_.size(), threadStreams.size());  // NOTE(review): presumably fails when called from a thread whose stream vector was never sized like default_streams_ -- confirm.
+ for (size_t i = 0; i < default_streams_.size(); ++i) {  // Release the previous generation of streams.
+ if (threadStreams[i] != nullptr && default_streams_[i] != threadStreams[i])  // A non-default current stream (e.g. from CreateStream + SetStream) is deleted too; NOTE(review): a later FreeStream on that handle would delete it twice.
+ delete threadStreams[i];
+ delete default_streams_[i];  // NOTE(review): double delete if this default stream was already released via FreeStream; other threads' current-stream slots may still point here.
+ }
+ default_streams_.resize(devices.size());
+ threadStreams.resize(devices.size());  // Presumably both vectors are still empty on the first call (from Init), so they are sized here.
+ for (size_t i = 0; i < devices.size(); ++i) {
+ Stream* stream = new Stream(devices[i]);
+ default_streams_[i] = stream;
+ threadStreams[i] = stream;  // The fresh default stream also becomes this thread's current stream.
+ }
+}
+
void MetalWorkspace::Init() {
if (initialized_) return;
std::lock_guard<std::mutex> lock(this->mutex);
@@ -141,21 +157,16 @@ void MetalWorkspace::Init() {
// on iPhone
id<MTLDevice> d = MTLCreateSystemDefaultDevice();
devices.push_back(d);
- Stream* stream = new Stream(d);
- MetalThreadEntry::ThreadLocal()->stream.push_back(stream);
- default_streams_.push_back(stream);
#else
NSArray<id<MTLDevice> >* devs = MTLCopyAllDevices();
for (size_t i = 0; i < devs.count; ++i) {
id<MTLDevice> d = [devs objectAtIndex:i];
devices.push_back(d);
- Stream* stream = new Stream(d);
- MetalThreadEntry::ThreadLocal()->stream.push_back(stream);
- default_streams_.push_back(stream);
LOG(INFO) << "Intializing Metal device " << i << ", name=" << [d.name UTF8String];
warp_size.push_back(GetWarpSize(d));
}
#endif
+ ReinitializeStreams();
}
void MetalWorkspace::SetDevice(Device dev) {
@@ -193,11 +204,10 @@ void MetalWorkspace::FreeDataSpace(Device dev, void* ptr) {
};
}
-Stream* GetStream(TVMStreamHandle stream, int device_id) {
- if (stream != nullptr)
- return static_cast<Stream*>(stream);
- else
- return MetalThreadEntry::ThreadLocal()->stream[device_id];
+Stream* CastStreamOrGetCurrent(TVMStreamHandle stream, int device_id) {  // Returns the explicit stream handle if given, otherwise the calling thread's current stream for device_id.
+ if (stream != nullptr) return static_cast<Stream*>(stream);
+ ICHECK(MetalThreadEntry::ThreadLocal()->stream[device_id] != nullptr);  // The slot is null after FreeStream released the current stream; NOTE(review): device_id itself is not bounds-checked here.
+ return MetalThreadEntry::ThreadLocal()->stream[device_id];
+}
void MetalWorkspace::CopyDataFromTo(const void* from, size_t from_offset, void* to,
@@ -207,7 +217,7 @@ void MetalWorkspace::CopyDataFromTo(const void* from, size_t from_offset, void*
this->Init();
Device dev = dev_from;
if (dev_from.device_type == kDLCPU) dev = dev_to;
- Stream* s = GetStream(stream, dev.device_id);
+ Stream* s = CastStreamOrGetCurrent(stream, dev.device_id);
if (s->HasErrorHappened()) {
LOG(FATAL) << "Error! Some problems on GPU happaned! Cannot copy data to current stream";
}
@@ -269,19 +279,23 @@ void MetalWorkspace::CopyDataFromTo(const void* from, size_t from_offset, void*
}
TVMStreamHandle MetalWorkspace::CreateStream(Device dev) {
+ ICHECK_LT(dev.device_id, devices.size()) << "Invalid device id " << dev.device_id;  // NOTE(review): unlike CopyDataFromTo, this path does not call Init(), so on a never-initialized workspace `devices` is presumably empty and this check fails.
Stream* stream = new Stream(devices[dev.device_id]);
return static_cast<TVMStreamHandle>(stream);
}
void MetalWorkspace::FreeStream(Device dev, TVMStreamHandle stream) {
ICHECK(stream != nullptr);
+ ICHECK_LT(dev.device_id, devices.size()) << "Invalid device id " << dev.device_id;  // Reject out-of-range device ids before indexing the per-device slots.
Stream* s = static_cast<Stream*>(stream);
+ if (MetalThreadEntry::ThreadLocal()->stream[dev.device_id] == s)  // Only the calling thread's slot is cleared; NOTE(review): freeing a default stream leaves default_streams_ pointing at freed memory.
+ MetalThreadEntry::ThreadLocal()->stream[dev.device_id] = nullptr;  // Avoid a dangling current stream; CastStreamOrGetCurrent(nullptr, ...) then ICHECK-fails until a new stream is set.
delete s;
}
void MetalWorkspace::StreamSync(Device dev, TVMStreamHandle stream) {
AUTORELEASEPOOL {
- Stream* s = GetStream(stream, dev.device_id);
+ Stream* s = CastStreamOrGetCurrent(stream, dev.device_id);
// commit an empty command buffer and wait until it completes.
id<MTLCommandBuffer> cb = s->GetCommandBuffer();
[cb commit];
@@ -293,6 +307,8 @@ void MetalWorkspace::StreamSync(Device dev, TVMStreamHandle stream) {
}
void MetalWorkspace::SetStream(Device dev, TVMStreamHandle stream) {
+ ICHECK_LT(dev.device_id, devices.size()) << "Invalid device id " << dev.device_id;  // Reject out-of-range device ids before indexing the per-device slots.
+ ICHECK(stream != nullptr);  // NOTE(review): a null handle is rejected here; it does not fall back to the device's default stream as the commit message describes.
MetalThreadEntry::ThreadLocal()->stream[dev.device_id] = static_cast<Stream*>(stream);
}
@@ -337,6 +353,10 @@ TVM_REGISTER_GLOBAL("device_api.metal").set_body([](TVMArgs args, TVMRetValue* r
*rv = static_cast<void*>(ptr);
});
+TVM_REGISTER_GLOBAL("metal.ResetGlobalState").set_body_typed([]() {  // Per the commit message, to be called by the RPC application when a connection is closed.
+ MetalWorkspace::Global()->ReinitializeStreams();  // NOTE(review): unlike Init(), this path does not take MetalWorkspace::mutex, and it only resets the calling thread's current streams.
+});
+
} // namespace metal
} // namespace runtime
} // namespace tvm