You are viewing a plain text version of this content. The canonical link for it was a hyperlink, which is not preserved in this plain-text version.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2017/12/18 19:07:02 UTC

[GitHub] piiswrong closed pull request #8967: add shared storage in windows

piiswrong closed pull request #8967: add shared storage in windows
URL: https://github.com/apache/incubator-mxnet/pull/8967
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below, because GitHub does not otherwise display it once the pull request is merged:

diff --git a/amalgamation/amalgamation.py b/amalgamation/amalgamation.py
index b378817e14..9419898135 100644
--- a/amalgamation/amalgamation.py
+++ b/amalgamation/amalgamation.py
@@ -43,6 +43,7 @@
 
 if platform.system() != 'Windows':
   blacklist.append('windows.h')
+  blacklist.append('process.h')
 
 def pprint(lst):
     for item in lst:
diff --git a/python/mxnet/gluon/data/dataloader.py b/python/mxnet/gluon/data/dataloader.py
index beb228ec24..8dea59fae9 100644
--- a/python/mxnet/gluon/data/dataloader.py
+++ b/python/mxnet/gluon/data/dataloader.py
@@ -25,9 +25,7 @@
 from multiprocessing.reduction import ForkingPickler
 import pickle
 import io
-import os
 import sys
-import warnings
 import numpy as np
 
 from . import sampler as _sampler
@@ -52,7 +50,7 @@ class ConnectionWrapper(object):
     NDArray via shared memory."""
 
     def __init__(self, conn):
-        self.conn = conn
+        self._conn = conn
 
     def send(self, obj):
         """Send object"""
@@ -67,7 +65,8 @@ def recv(self):
 
     def __getattr__(self, name):
         """Emmulate conn"""
-        return getattr(self.conn, name)
+        attr = self.__dict__.get('_conn', None)
+        return getattr(attr, name)
 
 
 class Queue(multiprocessing.queues.Queue):
@@ -188,9 +187,6 @@ def __init__(self, dataset, batch_size=None, shuffle=False, sampler=None,
                              "not be specified if batch_sampler is specified.")
 
         self._batch_sampler = batch_sampler
-        if num_workers > 0 and os.name == 'nt':
-            warnings.warn("DataLoader does not support num_workers > 0 on Windows yet.")
-            num_workers = 0
         self._num_workers = num_workers
         if batchify_fn is None:
             if num_workers > 0:
diff --git a/src/storage/cpu_shared_storage_manager.h b/src/storage/cpu_shared_storage_manager.h
index 9f0f2a354d..98f706b802 100644
--- a/src/storage/cpu_shared_storage_manager.h
+++ b/src/storage/cpu_shared_storage_manager.h
@@ -31,6 +31,9 @@
 #include <unistd.h>
 #include <sys/types.h>
 #include <sys/stat.h>
+#else
+#include <Windows.h>
+#include <process.h>
 #endif  // _WIN32
 
 #include <unordered_map>
@@ -64,6 +67,9 @@ class CPUSharedStorageManager final : public StorageManager {
     for (const auto& kv : pool_) {
       FreeImpl(kv.second);
     }
+#ifdef _WIN32
+    CheckAndRealFree();
+#endif
   }
 
   void Alloc(Storage::Handle* handle) override;
@@ -91,11 +97,18 @@ class CPUSharedStorageManager final : public StorageManager {
  private:
   static constexpr size_t alignment_ = 16;
 
-  std::mutex mutex_;
+  std::recursive_mutex mutex_;
   std::mt19937 rand_gen_;
   std::unordered_map<void*, Storage::Handle> pool_;
+#ifdef _WIN32
+  std::unordered_map<void*, Storage::Handle> is_free_;
+  std::unordered_map<void*, HANDLE> map_handle_map_;
+#endif
 
   void FreeImpl(const Storage::Handle& handle);
+#ifdef _WIN32
+  void CheckAndRealFree();
+#endif
 
   std::string SharedHandleToString(int shared_pid, int shared_id) {
     std::stringstream name;
@@ -106,14 +119,44 @@ class CPUSharedStorageManager final : public StorageManager {
 };  // class CPUSharedStorageManager
 
 void CPUSharedStorageManager::Alloc(Storage::Handle* handle) {
-  std::lock_guard<std::mutex> lock(mutex_);
+  std::lock_guard<std::recursive_mutex> lock(mutex_);
   std::uniform_int_distribution<> dis(0, std::numeric_limits<int>::max());
   int fid = -1;
   bool is_new = false;
   size_t size = handle->size + alignment_;
-  void* ptr = nullptr;
-#ifdef _WIN32
-  LOG(FATAL) << "Shared memory is not supported on Windows yet.";
+  void *ptr = nullptr;
+  #ifdef _WIN32
+  CheckAndRealFree();
+  HANDLE map_handle = nullptr;
+  uint32_t error = 0;
+  if (handle->shared_id == -1 && handle->shared_pid == -1) {
+    is_new = true;
+    handle->shared_pid = _getpid();
+    for (int i = 0; i < 10; ++i) {
+      handle->shared_id = dis(rand_gen_);
+      auto filename = SharedHandleToString(handle->shared_pid, handle->shared_id);
+      map_handle = CreateFileMapping(INVALID_HANDLE_VALUE,
+                                     NULL, PAGE_READWRITE, 0, size, filename.c_str());
+      if ((error = GetLastError()) == ERROR_SUCCESS) {
+        break;;
+      }
+    }
+  } else {
+    auto filename = SharedHandleToString(handle->shared_pid, handle->shared_id);
+    map_handle = OpenFileMapping(FILE_MAP_READ | FILE_MAP_WRITE,
+                                 FALSE, filename.c_str());
+    error = GetLastError();
+  }
+
+  if (error != ERROR_SUCCESS && map_handle == nullptr) {
+    LOG(FATAL) << "Failed to open shared memory. CreateFileMapping failed with error "
+               << error;
+  }
+
+  ptr = MapViewOfFile(map_handle, FILE_MAP_READ | FILE_MAP_WRITE, 0, 0, 0);
+  CHECK_NE(ptr, (void *)0)
+      << "Failed to map shared memory. MapViewOfFile failed with error " << GetLastError();
+  map_handle_map_[ptr] = map_handle;
 #else
   if (handle->shared_id == -1 && handle->shared_pid == -1) {
     is_new = true;
@@ -153,7 +196,7 @@ void CPUSharedStorageManager::FreeImpl(const Storage::Handle& handle) {
   int count = DecrementRefCount(handle);
   CHECK_GE(count, 0);
 #ifdef _WIN32
-  LOG(FATAL) << "Shared memory is not supported on Windows yet.";
+  is_free_[handle.dptr] = handle;
 #else
   CHECK_EQ(munmap(static_cast<char*>(handle.dptr) - alignment_,
                   handle.size + alignment_), 0)
@@ -169,6 +212,26 @@ void CPUSharedStorageManager::FreeImpl(const Storage::Handle& handle) {
 #endif  // _WIN32
 }
 
+#ifdef _WIN32
+inline void CPUSharedStorageManager::CheckAndRealFree() {
+  std::lock_guard<std::recursive_mutex> lock(mutex_);
+  for (auto it = std::begin(is_free_); it != std::end(is_free_);) {
+    void* ptr = static_cast<char*>(it->second.dptr) - alignment_;
+    std::atomic<int>* counter = reinterpret_cast<std::atomic<int>*>(
+      static_cast<char*>(it->second.dptr) - alignment_);
+    if ((*counter) == 0) {
+      CHECK_NE(UnmapViewOfFile(ptr), 0)
+        << "Failed to UnmapViewOfFile shared memory ";
+      CHECK_NE(CloseHandle(map_handle_map_[ptr]), 0)
+        << "Failed to CloseHandle shared memory ";
+      map_handle_map_.erase(ptr);
+      it = is_free_.erase(it);
+    } else {
+      ++it;
+    }
+  }
+}
+#endif  // _WIN32
 }  // namespace storage
 }  // namespace mxnet
 


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services