You are viewing a plain-text version of this content. The canonical link to it was a hyperlink ("here") that is not shown in this plain-text rendering.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/06/06 05:57:52 UTC

[GitHub] piiswrong closed pull request #11159: fix shared_storage free

piiswrong closed pull request #11159: fix shared_storage free
URL: https://github.com/apache/incubator-mxnet/pull/11159
 
 
   

This is a pull request (PR) merged from a forked repository.
Because GitHub hides the original diff once such a PR is merged, the diff is
reproduced below for provenance:

Because this is a foreign pull request (from a fork), the diff is supplied
below, since GitHub would not otherwise display it:

diff --git a/python/mxnet/gluon/data/dataloader.py b/python/mxnet/gluon/data/dataloader.py
index 151b49d457a..29b9b81aca0 100644
--- a/python/mxnet/gluon/data/dataloader.py
+++ b/python/mxnet/gluon/data/dataloader.py
@@ -57,6 +57,8 @@ def rebuild_ndarray(pid, fd, shape, dtype):
 
     def reduce_ndarray(data):
         """Reduce ndarray to shared memory handle"""
+        # keep a local ref before duplicating fd
+        data = data.as_in_context(context.Context('cpu_shared', 0))
         pid, fd, shape, dtype = data._to_shared_mem()
         if sys.version_info[0] == 2:
             fd = multiprocessing.reduction.reduce_handle(fd)
diff --git a/src/storage/cpu_shared_storage_manager.h b/src/storage/cpu_shared_storage_manager.h
index 85c6a352afd..a52d779d231 100644
--- a/src/storage/cpu_shared_storage_manager.h
+++ b/src/storage/cpu_shared_storage_manager.h
@@ -174,8 +174,12 @@ void CPUSharedStorageManager::Alloc(Storage::Handle* handle) {
   }
 
   if (fid == -1) {
-    LOG(FATAL) << "Failed to open shared memory. shm_open failed with error "
-               << strerror(errno);
+    if (is_new) {
+      LOG(FATAL) << "Failed to open shared memory. shm_open failed with error "
+                 << strerror(errno);
+    } else {
+      LOG(FATAL) << "Invalid file descriptor from shared array.";
+    }
   }
 
   if (is_new) CHECK_EQ(ftruncate(fid, size), 0);
@@ -216,9 +220,11 @@ void CPUSharedStorageManager::FreeImpl(const Storage::Handle& handle) {
       << strerror(errno);
 
 #ifdef __linux__
+  if (handle.shared_id != -1) {
   CHECK_EQ(close(handle.shared_id), 0)
       << "Failed to close shared memory. close failed with error "
       << strerror(errno);
+  }
 #else
   if (count == 0) {
     auto filename = SharedHandleToString(handle.shared_pid, handle.shared_id);
diff --git a/tests/python/unittest/test_gluon_data.py b/tests/python/unittest/test_gluon_data.py
index 93160aa0940..751886b8e7f 100644
--- a/tests/python/unittest/test_gluon_data.py
+++ b/tests/python/unittest/test_gluon_data.py
@@ -140,6 +140,16 @@ def __getitem__(self, idx):
         def __len__(self):
             return 50
 
+        def batchify_list(self, data):
+            """
+            return list of ndarray without stack/concat/pad
+            """
+            if isinstance(data, (tuple, list)):
+                return list(data)
+            if isinstance(data, mx.nd.NDArray):
+                return [data]
+            return data
+
         def batchify(self, data):
             """
             Collate data into batch. Use shared memory for stacking.
@@ -194,6 +204,14 @@ def batchify(self, data):
                     print(data)
                     print('{}:{}'.format(epoch, i))
 
+        data = Dummy(True)
+        loader = DataLoader(data, batch_size=40, batchify_fn=data.batchify_list, num_workers=2)
+        for epoch in range(1):
+            for i, data in enumerate(loader):
+                if i % 100 == 0:
+                    print(data)
+                    print('{}:{}'.format(epoch, i))
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()
diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py
index 496f80f927f..a0604658ee1 100644
--- a/tests/python/unittest/test_ndarray.py
+++ b/tests/python/unittest/test_ndarray.py
@@ -1304,7 +1304,6 @@ def test_norm(ctx=default_context()):
             assert arr1.shape == arr2.shape
             mx.test_utils.assert_almost_equal(arr1, arr2.asnumpy())
 
-
 if __name__ == '__main__':
     import nose
     nose.runmodule()


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log in to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services