Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2017/11/16 23:13:36 UTC

[GitHub] piiswrong closed pull request #8677: multi processing and fork fix

URL: https://github.com/apache/incubator-mxnet/pull/8677

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

diff --git a/dmlc-core b/dmlc-core
index 595d02c0e8..87b7ffa59e 160000
--- a/dmlc-core
+++ b/dmlc-core
@@ -1 +1 @@
-Subproject commit 595d02c0e87be8a0846700462b6f45f1b1031e39
+Subproject commit 87b7ffa59eb78f753073ac56f5f60e46d930b93c
diff --git a/include/mxnet/base.h b/include/mxnet/base.h
index cceee70ffd..f8cc7d5547 100644
--- a/include/mxnet/base.h
+++ b/include/mxnet/base.h
@@ -143,7 +143,8 @@ struct Context {
   enum DeviceType {
     kCPU = cpu::kDevMask,
     kGPU = gpu::kDevMask,
-    kCPUPinned = 3
+    kCPUPinned = 3,
+    kCPUShared = 5,
   };
   /*! \brief the device type we run the op on */
   DeviceType dev_type;
@@ -155,11 +156,18 @@ struct Context {
    * \brief Get corresponding device mask
    * \return cpu::kDevMask or gpu::kDevMask
    */
-  inline int dev_mask() const {
-    if (dev_type == kCPUPinned) return cpu::kDevMask;
+  inline DeviceType dev_mask() const {
+    if (dev_type == kCPUPinned || dev_type == kCPUShared) return kCPU;
     return dev_type;
   }
   /*!
+   * \brief Returns dev_id for kGPU, 0 otherwise
+   */
+  inline int real_dev_id() const {
+    if (dev_type == kGPU) return dev_id;
+    return 0;
+  }
+  /*!
    * \brief Comparator, used to enable Context as std::map key.
    * \param b another context to compare
    * \return compared result
@@ -200,7 +208,7 @@ struct Context {
     return true;
   }
   /*! \brief the maximal device type */
-  static const int32_t kMaxDevType = 4;
+  static const int32_t kMaxDevType = 6;
   /*! \brief the maximal device index */
   static const int32_t kMaxDevID = 16;
   /*!
@@ -224,6 +232,12 @@ struct Context {
    */
   inline static Context CPUPinned(int32_t dev_id = -1);
   /*!
+   * Create a CPU shared memory context.
+   * \param dev_id dummy device id.
+   * \return CPU shared memory context.
+   */
+  inline static Context CPUShared(int32_t dev_id = 0);
+  /*!
    * Create a context from string of the format [cpu|gpu|cpu_pinned](n)
    * \param str the string pattern
    * \return Context
@@ -273,7 +287,7 @@ inline Context Context::Create(DeviceType dev_type, int32_t dev_id) {
   ctx.dev_type = dev_type;
   if (dev_id < 0) {
     ctx.dev_id = 0;
-    if (dev_type != kCPU) {
+    if (dev_type & kGPU) {
 #if MXNET_USE_CUDA
       CHECK_EQ(cudaGetDevice(&ctx.dev_id), cudaSuccess);
 #else
@@ -293,6 +307,10 @@ inline Context Context::CPUPinned(int32_t dev_id) {
   return Create(kCPUPinned, dev_id);
 }
 
+inline Context Context::CPUShared(int32_t dev_id) {
+  return Create(kCPUShared, dev_id);
+}
+
 inline Context Context::GPU(int32_t dev_id) {
   return Create(kGPU, dev_id);
 }
@@ -313,6 +331,8 @@ inline Context Context::FromString(std::string str) {
       ret = GPU(id);
     } else if (type == "cpu_pinned") {
       ret = CPUPinned(id);
+    } else if (type == "cpu_shared") {
+      ret = CPUShared(id);
     } else {
       LOG(FATAL) << "Invalid context string " << str;
     }
@@ -329,6 +349,8 @@ inline std::ostream& operator<<(std::ostream &out, const Context &ctx) {
     out << "gpu(";
   } else if (ctx.dev_type == Context::kCPUPinned) {
     out << "cpu_pinned(";
+  } else if (ctx.dev_type == Context::kCPUShared) {
+    out << "cpu_shared(";
   } else {
     out << "unknown(";
   }
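
A minimal Python sketch of how the new cpu_shared device type is meant to surface once this patch is applied; the literal device-type id 5 mirrors the kCPUShared enum value added above, and the printed output is illustrative:

    import mxnet as mx

    # Build a shared-memory CPU context by its string name, matching
    # Context::FromString("cpu_shared(0)") on the C++ side.
    shared_ctx = mx.Context('cpu_shared', 0)
    print(shared_ctx)                  # expected: cpu_shared(0)
    print(shared_ctx.device_typeid)    # expected: 5, i.e. kCPUShared
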
diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index 8ea2b0e0e5..08c7109851 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -2007,6 +2007,26 @@ MXNET_DLL int MXRtcCudaKernelCall(CudaKernelHandle handle, int dev_id, void** ar
                                   mx_uint grid_dim_z, mx_uint block_dim_x,
                                   mx_uint block_dim_y, mx_uint block_dim_z,
                                   mx_uint shared_mem);
+/*!
+ * \brief Get shared memory handle from NDArray
+ * \param handle NDArray handle.
+ * \param shared_pid output PID
+ * \param shared_id output shared memory id.
+ */
+MXNET_DLL int MXNDArrayGetSharedMemHandle(NDArrayHandle handle, int* shared_pid,
+                                          int* shared_id);
+/*!
+ * \brief Reconstruct NDArray from shared memory handle
+ * \param shared_pid shared PID
+ * \param shared_id shared memory id
+ * \param shape pointer to NDArray dimensions
+ * \param ndim number of NDArray dimensions
+ * \param dtype data type of NDArray
+ * \param out constructed NDArray
+ */
+MXNET_DLL int MXNDArrayCreateFromSharedMem(int shared_pid, int shared_id, const mx_uint *shape,
+                                           mx_uint ndim, int dtype, NDArrayHandle *out);
+
 
 #ifdef __cplusplus
 }
diff --git a/include/mxnet/engine.h b/include/mxnet/engine.h
index 5a4697df4b..9ca2cedde9 100644
--- a/include/mxnet/engine.h
+++ b/include/mxnet/engine.h
@@ -113,6 +113,18 @@ class MXNET_API Engine {
    */
   virtual void NotifyShutdown() = 0;
   /*!
+   *\brief Stop all workers in the engine
+   */
+  virtual void Stop() {
+    LOG(FATAL) << "Engine cannot be stopped";
+  }
+  /*!
+   * \brief Restart all workers in the engine
+   */
+  virtual void Start() {
+    LOG(FATAL) << "Engine cannot be restarted";
+  }
+  /*!
    * \brief Allocate a new variable, the variable can then
    *        be used to schedule the operation concurrently via dependency
    *        patterns.
diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h
index 498a47f6ad..c2d55dc18f 100644
--- a/include/mxnet/ndarray.h
+++ b/include/mxnet/ndarray.h
@@ -160,6 +160,14 @@ class NDArray {
     Mkl_mem_ = std::make_shared<MKLMemHolder>();
 #endif
   }
+  /*! \brief create ndarray from shared memory */
+  NDArray(int shared_pid, int shared_id, const TShape& shape, int dtype)
+      : ptr_(std::make_shared<Chunk>(shared_pid, shared_id, shape, dtype)), shape_(shape),
+        dtype_(dtype), storage_type_(kDefaultStorage), entry_({nullptr, 0, 0}) {
+#if MKL_EXPERIMENTAL == 1
+    Mkl_mem_ = std::make_shared<MKLMemHolder>();
+#endif
+  }
 
   /*!
    * \brief constructing a static NDArray of non-default storage that shares data with TBlob
@@ -317,6 +325,13 @@ class NDArray {
     }
     return true;
   }
+  /*! \brief get storage handle */
+  inline Storage::Handle storage_handle() const {
+    CHECK(!is_none());
+    CHECK_EQ(storage_type(), kDefaultStorage);
+    CheckAndAlloc();
+    return ptr_->shandle;
+  }
   /*!
    * \brief Block until all the pending write operations with respect
    *    to current NDArray are finished, and read can be performed.
@@ -682,6 +697,18 @@ class NDArray {
       shandle.size = data.shape_.Size() * mshadow::mshadow_sizeof(data.type_flag_);
       storage_shape = data.shape_;
     }
+
+    Chunk(int shared_pid, int shared_id, const TShape& shape, int dtype)
+        : static_data(false), delay_alloc(false) {
+      var = Engine::Get()->NewVariable();
+      ctx = Context::CPUShared(0);
+      shandle.size = shape.Size() * mshadow::mshadow_sizeof(dtype);
+      shandle.ctx = ctx;
+      shandle.shared_pid = shared_pid;
+      shandle.shared_id = shared_id;
+      Storage::Get()->Alloc(&shandle);
+      storage_shape = shape;
+    }
     // Constructor for a non-default storage chunk
     Chunk(NDArrayStorageType storage_type_, const TShape &storage_shape_, Context ctx_,
           bool delay_alloc_, int dtype, const std::vector<int> &aux_types_,
diff --git a/include/mxnet/storage.h b/include/mxnet/storage.h
index 7e3af8eeca..cbe24c28be 100644
--- a/include/mxnet/storage.h
+++ b/include/mxnet/storage.h
@@ -50,6 +50,11 @@ class Storage {
      * \brief Context information about device and ID.
      */
     Context ctx;
+    /*!
+     * \brief Id for IPC shared memory
+     */
+    int shared_pid{-1};
+    int shared_id{-1};
   };
   /*!
    * \brief Allocate a new contiguous memory for a given size.
@@ -57,7 +62,23 @@ class Storage {
    * \param ctx Context information about the device and ID.
    * \return Handle struct.
    */
-  virtual Handle Alloc(size_t size, Context ctx) = 0;
+  Handle Alloc(size_t size, Context ctx) {
+    Handle hd;
+    hd.size = size;
+    hd.ctx = ctx;
+    this->Alloc(&hd);
+    return hd;
+  }
+  /*!
+   * \brief Allocate a new contiguous memory for a given size.
+   * \param handle handle initialized with size and ctx
+   */
+  virtual void Alloc(Handle* handle) = 0;
+  /*!
+   * \brief Increase ref counter on shared memory.
+   * \param handle handle to shared memory.
+   */
+  virtual void SharedIncrementRefCount(Handle handle) = 0;
   /*!
    * \brief Free storage.
    * \param handle Handle struect.
diff --git a/python/mxnet/context.py b/python/mxnet/context.py
index 9798b480d2..beccaebcef 100644
--- a/python/mxnet/context.py
+++ b/python/mxnet/context.py
@@ -62,8 +62,8 @@ class Context(object):
     """
     # static class variable
     default_ctx = None
-    devtype2str = {1: 'cpu', 2: 'gpu', 3: 'cpu_pinned'}
-    devstr2type = {'cpu': 1, 'gpu': 2, 'cpu_pinned': 3}
+    devtype2str = {1: 'cpu', 2: 'gpu', 3: 'cpu_pinned', 5: 'cpu_shared'}
+    devstr2type = {'cpu': 1, 'gpu': 2, 'cpu_pinned': 3, 'cpu_shared': 5}
     def __init__(self, device_type, device_id=0):
         if isinstance(device_type, Context):
             self.device_typeid = device_type.device_typeid
@@ -128,14 +128,13 @@ def cpu(device_id=0):
 
     Examples
     ----------
-    >>> with mx.Context('cpu', 1):
+    >>> with mx.cpu():
     ...     cpu_array = mx.nd.ones((2, 3))
     >>> cpu_array.context
-    cpu(1)
-    >>> with mx.cpu(1):
-    ...    cpu_array = mx.nd.ones((2, 3))
+    cpu(0)
+    >>> cpu_array = mx.nd.ones((2, 3), ctx=mx.cpu())
     >>> cpu_array.context
-    cpu(1)
+    cpu(0)
 
     Parameters
     ----------
@@ -151,6 +150,36 @@ def cpu(device_id=0):
     return Context('cpu', device_id)
 
 
+def cpu_pinned(device_id=0):
+    """Returns a CPU pinned memory context. Copying from CPU pinned memory to GPU
+    is faster than from normal CPU memory.
+
+    This function is a shortcut for ``Context('cpu_pinned', device_id)``.
+
+    Examples
+    ----------
+    >>> with mx.cpu_pinned():
+    ...     cpu_array = mx.nd.ones((2, 3))
+    >>> cpu_array.context
+    cpu_pinned(0)
+    >>> cpu_array = mx.nd.ones((2, 3), ctx=mx.cpu_pinned())
+    >>> cpu_array.context
+    cpu_pinned(0)
+
+    Parameters
+    ----------
+    device_id : int, optional
+        The device id of the device. `device_id` is not needed for CPU.
+        This is included to make interface compatible with GPU.
+
+    Returns
+    -------
+    context : Context
+        The corresponding CPU pinned memory context.
+    """
+    return Context('cpu_pinned', device_id)
+
+
 def gpu(device_id=0):
     """Returns a GPU context.
 
@@ -159,12 +188,14 @@ def gpu(device_id=0):
 
     Examples
     ----------
-    >>> with mx.Context('gpu', 1):
+    >>> cpu_array = mx.nd.ones((2, 3))
+    >>> cpu_array.context
+    cpu(0)
+    >>> with mx.gpu(1):
     ...     gpu_array = mx.nd.ones((2, 3))
     >>> gpu_array.context
     gpu(1)
-    >>> with mx.gpu(1):
-    ...    gpu_array = mx.nd.ones((2, 3))
+    >>> gpu_array = mx.nd.ones((2, 3), ctx=mx.gpu(1))
     >>> gpu_array.context
     gpu(1)
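
The cpu_pinned helper added above is mainly useful as a staging area for host-to-GPU copies. A small hedged sketch, assuming a CUDA-enabled build with at least one GPU:

    import mxnet as mx

    # Allocate the host-side buffer in page-locked (pinned) memory, then copy
    # it to the device; pinned source memory enables faster async transfers.
    staging = mx.nd.ones((1024, 1024), ctx=mx.cpu_pinned())
    on_gpu = staging.copyto(mx.gpu(0))
    on_gpu.wait_to_read()
    print(staging.context, on_gpu.context)    # cpu_pinned(0) gpu(0)
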
 
diff --git a/python/mxnet/gluon/data/dataloader.py b/python/mxnet/gluon/data/dataloader.py
index 4f029bf409..beb228ec24 100644
--- a/python/mxnet/gluon/data/dataloader.py
+++ b/python/mxnet/gluon/data/dataloader.py
@@ -20,24 +20,107 @@
 """Dataset generator."""
 __all__ = ['DataLoader']
 
+import multiprocessing
+import multiprocessing.queues
+from multiprocessing.reduction import ForkingPickler
+import pickle
+import io
+import os
+import sys
+import warnings
 import numpy as np
 
 from . import sampler as _sampler
-from ... import nd
+from ... import nd, context
 
 
-def _batchify(data):
+def rebuild_ndarray(*args):
+    """Rebuild ndarray from pickled shared memory"""
+    # pylint: disable=no-value-for-parameter
+    return nd.NDArray(nd.ndarray._new_from_shared_mem(*args))
+
+
+def reduce_ndarray(data):
+    """Reduce ndarray to shared memory handle"""
+    return rebuild_ndarray, data._to_shared_mem()
+
+ForkingPickler.register(nd.NDArray, reduce_ndarray)
+
+
+class ConnectionWrapper(object):
+    """Connection wrapper for multiprocessing that supports sending
+    NDArray via shared memory."""
+
+    def __init__(self, conn):
+        self.conn = conn
+
+    def send(self, obj):
+        """Send object"""
+        buf = io.BytesIO()
+        ForkingPickler(buf, pickle.HIGHEST_PROTOCOL).dump(obj)
+        self.send_bytes(buf.getvalue())
+
+    def recv(self):
+        """Receive object"""
+        buf = self.recv_bytes()
+        return pickle.loads(buf)
+
+    def __getattr__(self, name):
+        """Emulate conn"""
+        return getattr(self.conn, name)
+
+
+class Queue(multiprocessing.queues.Queue):
+    """Wrapper for multiprocessing queue that dumps NDArray with shared memory."""
+    def __init__(self, *args, **kwargs):
+        if sys.version_info[0] <= 2:
+            super(Queue, self).__init__(*args, **kwargs)
+        else:
+            super(Queue, self).__init__(*args, ctx=multiprocessing.get_context(),
+                                        **kwargs)
+        self._reader = ConnectionWrapper(self._reader)
+        self._writer = ConnectionWrapper(self._writer)
+        self._send = self._writer.send
+        self._recv = self._reader.recv
+
+
+def default_batchify_fn(data):
     """Collate data into batch."""
     if isinstance(data[0], nd.NDArray):
         return nd.stack(*data)
     elif isinstance(data[0], tuple):
         data = zip(*data)
-        return [_batchify(i) for i in data]
+        return [default_batchify_fn(i) for i in data]
     else:
         data = np.asarray(data)
         return nd.array(data, dtype=data.dtype)
 
 
+def default_mp_batchify_fn(data):
+    """Collate data into batch. Use shared memory for stacking."""
+    if isinstance(data[0], nd.NDArray):
+        out = nd.empty((len(data),) + data[0].shape, dtype=data[0].dtype,
+                       ctx=context.Context('cpu_shared', 0))
+        return nd.stack(*data, out=out)
+    elif isinstance(data[0], tuple):
+        data = zip(*data)
+        return [default_mp_batchify_fn(i) for i in data]
+    else:
+        data = np.asarray(data)
+        return nd.array(data, dtype=data.dtype,
+                        ctx=context.Context('cpu_shared', 0))
+
+
+def worker_loop(dataset, key_queue, data_queue, batchify_fn):
+    """Worker loop for multiprocessing DataLoader."""
+    while True:
+        idx, samples = key_queue.get()
+        if idx is None:
+            break
+        batch = batchify_fn([dataset[i] for i in samples])
+        data_queue.put((idx, batch))
+
+
 class DataLoader(object):
     """Loads data from a dataset and returns mini-batches of data.
 
@@ -62,9 +145,27 @@ class DataLoader(object):
     batch_sampler : Sampler
         A sampler that returns mini-batches. Do not specify batch_size,
         shuffle, sampler, and last_batch if batch_sampler is specified.
+    batchify_fn : callable
+        Callback function to allow users to specify how to merge samples
+        into a batch. Defaults to `default_batchify_fn`::
+
+            def default_batchify_fn(data):
+                if isinstance(data[0], nd.NDArray):
+                    return nd.stack(*data)
+                elif isinstance(data[0], tuple):
+                    data = zip(*data)
+                    return [default_batchify_fn(i) for i in data]
+                else:
+                    data = np.asarray(data)
+                    return nd.array(data, dtype=data.dtype)
+
+    num_workers : int, default 0
+        The number of multiprocessing workers to use for data preprocessing.
+        `num_workers > 0` is not supported on Windows yet.
     """
     def __init__(self, dataset, batch_size=None, shuffle=False, sampler=None,
-                 last_batch=None, batch_sampler=None):
+                 last_batch=None, batch_sampler=None, batchify_fn=None,
+                 num_workers=0):
         self._dataset = dataset
 
         if batch_sampler is None:
@@ -87,10 +188,53 @@ def __init__(self, dataset, batch_size=None, shuffle=False, sampler=None,
                              "not be specified if batch_sampler is specified.")
 
         self._batch_sampler = batch_sampler
+        if num_workers > 0 and os.name == 'nt':
+            warnings.warn("DataLoader does not support num_workers > 0 on Windows yet.")
+            num_workers = 0
+        self._num_workers = num_workers
+        if batchify_fn is None:
+            if num_workers > 0:
+                self._batchify_fn = default_mp_batchify_fn
+            else:
+                self._batchify_fn = default_batchify_fn
+        else:
+            self._batchify_fn = batchify_fn
 
     def __iter__(self):
-        for batch in self._batch_sampler:
-            yield _batchify([self._dataset[idx] for idx in batch])
+        if self._num_workers == 0:
+            for batch in self._batch_sampler:
+                yield self._batchify_fn([self._dataset[idx] for idx in batch])
+            return
+
+        key_queue = Queue()
+        data_queue = Queue(2*self._num_workers)
+
+        workers = []
+        for _ in range(self._num_workers):
+            worker = multiprocessing.Process(
+                target=worker_loop,
+                args=(self._dataset, key_queue, data_queue, self._batchify_fn))
+            worker.daemon = True
+            worker.start()
+            workers.append(worker)
+
+        for idx, batch in enumerate(self._batch_sampler):
+            key_queue.put((idx, batch))
+
+        data_buffer = {}
+        curr_idx = 0
+        for _ in range(len(self._batch_sampler)):
+            idx, batch = data_queue.get()
+            data_buffer[idx] = batch
+            while curr_idx in data_buffer:
+                yield data_buffer.pop(curr_idx)
+                curr_idx += 1
+
+        for _ in range(self._num_workers):
+            key_queue.put((None, None))
+
+        for worker in workers:
+            worker.join()
 
     def __len__(self):
         return len(self._batch_sampler)
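
The multiprocessing path above is driven entirely by the num_workers argument. A hedged end-to-end sketch with a toy dataset (RangeDataset below is illustrative, not part of the patch):

    import mxnet as mx
    from mxnet import gluon

    class RangeDataset(gluon.data.Dataset):
        """Toy dataset returning a length-4 vector filled with its index."""
        def __len__(self):
            return 32
        def __getitem__(self, idx):
            return mx.nd.full((4,), idx)

    # num_workers=2 forks two worker processes; batches are collated with
    # default_mp_batchify_fn and travel back through the shared-memory Queue.
    loader = gluon.data.DataLoader(RangeDataset(), batch_size=8, num_workers=2)
    for batch in loader:
        print(batch.shape, batch.context)    # expected: (8, 4) cpu_shared(0)
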
diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py
index 6cbf3284e5..d7536bc76f 100644
--- a/python/mxnet/ndarray/ndarray.py
+++ b/python/mxnet/ndarray/ndarray.py
@@ -139,6 +139,18 @@ def _new_alloc_handle(shape, ctx, delay_alloc, dtype=mx_real_t):
     return hdl
 
 
+def _new_from_shared_mem(shared_pid, shared_id, shape, dtype):
+    hdl = NDArrayHandle()
+    check_call(_LIB.MXNDArrayCreateFromSharedMem(
+        ctypes.c_int(shared_pid),
+        ctypes.c_int(shared_id),
+        c_array(mx_uint, shape),
+        mx_uint(len(shape)),
+        ctypes.c_int(int(_DTYPE_NP_TO_MX[np.dtype(dtype).type])),
+        ctypes.byref(hdl)))
+    return hdl
+
+
 def waitall():
     """Wait for all async operations to finish in MXNet.
 
@@ -173,6 +185,13 @@ def __repr__(self):
     def __reduce__(self):
         return NDArray, (None,), self.__getstate__()
 
+    def _to_shared_mem(self):
+        shared_pid = ctypes.c_int()
+        shared_id = ctypes.c_int()
+        check_call(_LIB.MXNDArrayGetSharedMemHandle(
+            self.handle, ctypes.byref(shared_pid), ctypes.byref(shared_id)))
+        return shared_pid.value, shared_id.value, self.shape, self.dtype
+
     def __add__(self, other):
         """x.__add__(y) <=> x+y <=> mx.nd.add(x, y) """
         return add(self, other)
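
Taken together, _to_shared_mem and _new_from_shared_mem give a manual round trip through POSIX shared memory, which is what the ForkingPickler hook in dataloader.py relies on. A hedged sketch using these internal, underscore-prefixed helpers:

    import mxnet as mx
    from mxnet import nd

    src = nd.arange(12).reshape((3, 4)).copyto(mx.Context('cpu_shared', 0))
    pid, shm_id, shape, dtype = src._to_shared_mem()

    # Rebuild a second NDArray over the same shared buffer; no data is copied.
    dst = nd.NDArray(nd.ndarray._new_from_shared_mem(pid, shm_id, shape, dtype))
    print((dst.asnumpy() == src.asnumpy()).all())    # expected: True
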
diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc
index 15cd061579..da759fe2f4 100644
--- a/src/c_api/c_api.cc
+++ b/src/c_api/c_api.cc
@@ -34,6 +34,7 @@
 #include <mxnet/c_api.h>
 #include <mxnet/kvstore.h>
 #include <mxnet/rtc.h>
+#include <mxnet/storage.h>
 #include <vector>
 #include <sstream>
 #include <string>
@@ -1241,3 +1242,31 @@ int MXRtcCudaKernelCall(CudaKernelHandle handle, int dev_id, void** args,
 #endif
   API_END();
 }
+
+
+int MXNDArrayGetSharedMemHandle(NDArrayHandle handle, int* shared_pid, int* shared_id) {
+  API_BEGIN();
+  NDArray* arr = reinterpret_cast<NDArray*>(handle);
+  Storage::Handle shandle;
+  if (arr->ctx().dev_type == Context::kCPUShared) {
+    arr->WaitToRead();
+    shandle = arr->storage_handle();
+    Storage::Get()->SharedIncrementRefCount(shandle);
+  } else {
+    NDArray new_arr(arr->shape(), Context::CPUShared(0), false, arr->dtype());
+    CopyFromTo(*arr, new_arr);
+    new_arr.WaitToRead();
+    shandle = new_arr.storage_handle();
+    Storage::Get()->SharedIncrementRefCount(shandle);
+  }
+  *shared_pid = shandle.shared_pid;
+  *shared_id = shandle.shared_id;
+  API_END();
+}
+
+int MXNDArrayCreateFromSharedMem(int shared_pid, int shared_id, const mx_uint *shape,
+                                 mx_uint ndim, int dtype, NDArrayHandle *out) {
+  API_BEGIN();
+  *out = new NDArray(shared_pid, shared_id, TShape(shape, shape + ndim), dtype);
+  API_END();
+}
diff --git a/src/common/lazy_alloc_array.h b/src/common/lazy_alloc_array.h
index aa2cd4a139..36e6af53fc 100644
--- a/src/common/lazy_alloc_array.h
+++ b/src/common/lazy_alloc_array.h
@@ -56,8 +56,6 @@ class LazyAllocArray {
   /*! \brief clear all the allocated elements in array */
   inline void Clear();
 
-  void SignalForKill();
-
  private:
   template<typename SyncObject>
   class unique_unlock {
@@ -86,12 +84,12 @@ class LazyAllocArray {
   /*! \brief overflow array of more elements */
   std::vector<std::shared_ptr<TElem> > more_;
   /*! \brief Signal shutdown of array */
-  std::atomic<bool> exit_now_;
+  std::atomic<bool> is_clearing_;
 };
 
 template<typename TElem>
 inline LazyAllocArray<TElem>::LazyAllocArray()
-  : exit_now_(false) {
+  : is_clearing_(false) {
 }
 
 // implementations
@@ -106,7 +104,7 @@ inline std::shared_ptr<TElem> LazyAllocArray<TElem>::Get(int index, FCreate crea
       return ptr;
     } else {
       std::lock_guard<std::mutex> lock(create_mutex_);
-      if (!exit_now_.load()) {
+      if (!is_clearing_.load()) {
         std::shared_ptr<TElem> ptr = head_[idx];
         if (ptr) {
           return ptr;
@@ -117,7 +115,7 @@ inline std::shared_ptr<TElem> LazyAllocArray<TElem>::Get(int index, FCreate crea
     }
   } else {
     std::lock_guard<std::mutex> lock(create_mutex_);
-    if (!exit_now_.load()) {
+    if (!is_clearing_.load()) {
       idx -= kInitSize;
       if (more_.size() <= idx) {
         more_.reserve(idx + 1);
@@ -139,7 +137,7 @@ inline std::shared_ptr<TElem> LazyAllocArray<TElem>::Get(int index, FCreate crea
 template<typename TElem>
 inline void LazyAllocArray<TElem>::Clear() {
   std::unique_lock<std::mutex> lock(create_mutex_);
-  exit_now_.store(true);
+  is_clearing_.store(true);
   // Currently, head_ and more_ never get smaller, so it's safe to
   // iterate them outside of the lock.  The loops should catch
   // any growth which might happen when create_mutex_ is unlocked
@@ -155,6 +153,8 @@ inline void LazyAllocArray<TElem>::Clear() {
     unique_unlock<std::mutex> unlocker(&lock);
     p = std::shared_ptr<TElem>(nullptr);
   }
+  more_.clear();
+  is_clearing_.store(false);
 }
 
 template<typename TElem>
@@ -173,12 +173,6 @@ inline void LazyAllocArray<TElem>::ForEach(FVisit fvisit) {
   }
 }
 
-template<typename TElem>
-inline void LazyAllocArray<TElem>::SignalForKill() {
-  std::lock_guard<std::mutex> lock(create_mutex_);
-  exit_now_.store(true);
-}
-
 }  // namespace common
 }  // namespace mxnet
 #endif  // MXNET_COMMON_LAZY_ALLOC_ARRAY_H_
diff --git a/src/common/rtc.cc b/src/common/rtc.cc
index cd26f0e05a..cc51aaa108 100644
--- a/src/common/rtc.cc
+++ b/src/common/rtc.cc
@@ -124,7 +124,7 @@ void CudaModule::Kernel::Launch(
     uint32_t grid_dim_x, uint32_t grid_dim_y, uint32_t grid_dim_z,
     uint32_t block_dim_x, uint32_t block_dim_y, uint32_t block_dim_z,
     uint32_t shared_mem) {
-  CHECK_EQ(ctx.dev_mask(), gpu::kDevMask)
+  CHECK_EQ(ctx.dev_mask(), Context::kGPU)
       << "CUDA Runtime compilation only supports Nvidia GPU.";
 
   auto mod = mod_;
diff --git a/src/engine/stream_manager.h b/src/engine/stream_manager.h
index cd6db53f14..c052924825 100644
--- a/src/engine/stream_manager.h
+++ b/src/engine/stream_manager.h
@@ -89,6 +89,8 @@ RunContext StreamManager<kNumGpus, kStreams>::GetRunContext(
 #else
       LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
 #endif  // MXNET_USE_CUDA
+    default:
+      LOG(FATAL) << "Not Reached";
     }
   }
   return ret;
@@ -116,6 +118,8 @@ RunContext StreamManager<kNumGpus, kStreams>::GetIORunContext(
 #else
       LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
 #endif  // MXNET_USE_CUDA
+    default:
+      LOG(FATAL) << "Not Reached";
     }
   }
   return ret;
diff --git a/src/engine/threaded_engine_perdevice.cc b/src/engine/threaded_engine_perdevice.cc
index e01dd4ed45..413a7a5e97 100644
--- a/src/engine/threaded_engine_perdevice.cc
+++ b/src/engine/threaded_engine_perdevice.cc
@@ -50,6 +50,38 @@ class ThreadedEnginePerDevice : public ThreadedEngine {
   static auto constexpr kWorkerQueue = kFIFO;
 
   ThreadedEnginePerDevice() noexcept(false) {
+    this->Start();
+#ifndef _WIN32
+    pthread_atfork(
+      []() {
+        Engine::Get()->WaitForAll();
+        Engine::Get()->Stop();
+      },
+      []() {
+        Engine::Get()->Start();
+      },
+      []() {
+        // Make children single threaded since they are typically workers
+        dmlc::SetEnv("MXNET_CPU_WORKER_NTHREADS", 1);
+        dmlc::SetEnv("OMP_NUM_THREADS", 1);
+        OpenMP::Get()->set_enabled(false);
+        Engine::Get()->Start();
+      });
+#endif
+  }
+  ~ThreadedEnginePerDevice() noexcept(false) {
+    this->Stop();
+  }
+
+  void Stop() override {
+    SignalQueuesForKill();
+    gpu_normal_workers_.Clear();
+    gpu_copy_workers_.Clear();
+    cpu_normal_workers_.Clear();
+    cpu_priority_worker_.reset(nullptr);
+  }
+
+  void Start() override {
     gpu_worker_nthreads_ = common::GetNumThreadPerGPU();
     cpu_worker_nthreads_ = dmlc::GetEnv("MXNET_CPU_WORKER_NTHREADS", 1);
     // create CPU task
@@ -61,13 +93,6 @@ class ThreadedEnginePerDevice : public ThreadedEngine {
         }));
     // GPU tasks will be created lazily
   }
-  ~ThreadedEnginePerDevice() noexcept(false) {
-    SignalQueuesForKill();
-    gpu_normal_workers_.Clear();
-    gpu_copy_workers_.Clear();
-    cpu_normal_workers_.Clear();
-    cpu_priority_worker_.reset(nullptr);
-  }
 
  protected:
   void PushToExecute(OprBlock *opr_block, bool pusher_thread) override {
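
The pthread_atfork handlers registered above are what make the following pattern safe on POSIX systems: use the engine in the parent, then fork workers that also call into MXNet. A hedged sketch of that scenario (previously prone to hangs, since engine worker threads do not survive fork); assumes the default fork start method:

    import multiprocessing
    import mxnet as mx

    def child_job(q):
        # Runs in the forked child; the atfork child handler restarted the
        # engine single-threaded, so ordinary ops work here.
        q.put(mx.nd.ones((2, 2)).sum().asscalar())

    if __name__ == '__main__':
        _ = mx.nd.ones((2, 2)) + 1           # touch the engine in the parent
        q = multiprocessing.Queue()
        p = multiprocessing.Process(target=child_job, args=(q,))
        p.start()
        print(q.get())                        # expected: 4.0
        p.join()
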
diff --git a/src/imperative/imperative_utils.h b/src/imperative/imperative_utils.h
index 34099d029c..ecc40314e9 100644
--- a/src/imperative/imperative_utils.h
+++ b/src/imperative/imperative_utils.h
@@ -66,9 +66,9 @@ inline Context GetContext(const nnvm::NodeAttrs& attrs,
   } else {
     ctx = default_ctx;
   }
-  // Pinned context doesn't propagate
-  if (ctx.dev_type == Context::kCPUPinned) {
-    ctx = Context::CPU();
+  // Non-default context (pinned, shared) does not propagate
+  if (ctx.dev_mask() != ctx.dev_type) {
+    ctx = Context::Create(ctx.dev_mask(), ctx.dev_id);
   }
 #if !MXNET_USE_CUDA
   if (ctx.dev_mask() == gpu::kDevMask) {
@@ -659,9 +659,12 @@ inline std::vector<Context> PlaceDevice(const nnvm::IndexedGraph& idx) {
       vctx[j.node_id] = vctx[i];
     }
   }
+  // check that all contexts are initialized
   for (size_t i = 0; i < idx.num_nodes(); ++i) {
     CHECK_NE(vctx[i].dev_type, -1)
         << "Cannot decide context for node " << idx[i].source->attrs.name;
+    // Non-default contexts do not propagate.
+    vctx[i].dev_type = vctx[i].dev_mask();
   }
 
   return vctx;
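
In user terms, the "does not propagate" rule above means an operator that reads cpu_shared (or cpu_pinned) inputs writes its output to the plain cpu context. A hedged sketch:

    import mxnet as mx

    a = mx.nd.ones((2, 3), ctx=mx.Context('cpu_shared', 0))
    b = a * 2
    print(a.context, b.context)    # expected: cpu_shared(0) cpu(0)
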
diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc
index 18ca2117a4..3a1fed080f 100644
--- a/src/ndarray/ndarray.cc
+++ b/src/ndarray/ndarray.cc
@@ -580,8 +580,8 @@ void ElementwiseSum(const std::vector<NDArray> &source, NDArray *out, int priori
     }
     CHECK_EQ(source[i].shape() , out->shape())
         << "operands shape mismatch";
-    if (out->ctx().dev_mask() == cpu::kDevMask) {
-      CHECK_EQ(source[i].ctx().dev_mask(),  cpu::kDevMask)
+    if (out->ctx().dev_mask() == Context::kCPU) {
+      CHECK_EQ(source[i].ctx().dev_mask(), Context::kCPU)
           << "operands context mismatch";
     } else {
       CHECK(source[i].ctx() == out->ctx())
@@ -1361,7 +1361,7 @@ void Imdecode(NDArray *ret, NDArray mean, size_t index,
     CHECK_EQ(ret->shape().ndim(), 4U);
     buff = ret->Slice(index, index+1);
   }
-  CHECK_EQ(buff.ctx().dev_mask(), cpu::kDevMask);
+  CHECK_EQ(buff.ctx().dev_mask(), Context::kCPU);
   CHECK_EQ(n_channels, buff.shape()[1]);
   CHECK_EQ(y1-y0, buff.shape()[2]);
   CHECK_EQ(x1-x0, buff.shape()[3]);
@@ -1381,7 +1381,7 @@ void Imdecode(NDArray *ret, NDArray mean, size_t index,
     })
   } else {
     CHECK_EQ(mean.dtype(), buff.dtype());
-    CHECK_EQ(mean.ctx().dev_mask(), cpu::kDevMask);
+    CHECK_EQ(mean.ctx().dev_mask(), Context::kCPU);
     CHECK_EQ(mean.shape()[0], buff.shape()[1]);
     CHECK_EQ(mean.shape()[1], buff.shape()[2]);
     CHECK_EQ(mean.shape()[2], buff.shape()[3]);
diff --git a/src/storage/cpu_shared_storage_manager.h b/src/storage/cpu_shared_storage_manager.h
new file mode 100644
index 0000000000..d623cf2c7b
--- /dev/null
+++ b/src/storage/cpu_shared_storage_manager.h
@@ -0,0 +1,174 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef MXNET_STORAGE_CPU_SHARED_STORAGE_MANAGER_H_
+#define MXNET_STORAGE_CPU_SHARED_STORAGE_MANAGER_H_
+
+#if MXNET_USE_CUDA
+  #include <cuda_runtime.h>
+#endif  // MXNET_USE_CUDA
+#include <mxnet/base.h>
+
+#ifndef _WIN32
+#include <sys/mman.h>
+#include <sys/fcntl.h>
+#include <unistd.h>
+#include <sys/types.h>
+#include <sys/stat.h>
+#endif  // _WIN32
+
+#include <unordered_map>
+#include <vector>
+#include <atomic>
+#include <iostream>
+#include <mutex>
+#include <new>
+#include <string>
+#include <limits>
+
+#include "./storage_manager.h"
+#include "../common/cuda_utils.h"
+
+
+namespace mxnet {
+namespace storage {
+/*!
+ * \brief Storage manager for cpu shared memory
+ */
+class CPUSharedStorageManager final : public StorageManager {
+ public:
+  /*!
+   * \brief Default constructor.
+   */
+  CPUSharedStorageManager() : rand_gen_(std::random_device()()) {}
+  /*!
+   * \brief Default destructor.
+   */
+  ~CPUSharedStorageManager() {
+    for (const auto& kv : pool_) {
+      FreeImpl(kv.second);
+    }
+  }
+
+  void Alloc(Storage::Handle* handle) override;
+  void Free(Storage::Handle handle) override {
+    pool_.erase(handle.dptr);
+    FreeImpl(handle);
+  }
+
+  void DirectFree(Storage::Handle handle) override {
+    Free(handle);
+  }
+
+  void IncrementRefCount(const Storage::Handle& handle) {
+    std::atomic<int>* counter = reinterpret_cast<std::atomic<int>*>(
+        static_cast<char*>(handle.dptr) - alignment_);
+    ++(*counter);
+  }
+
+  int DecrementRefCount(const Storage::Handle& handle) {
+    std::atomic<int>* counter = reinterpret_cast<std::atomic<int>*>(
+        static_cast<char*>(handle.dptr) - alignment_);
+    return --(*counter);
+  }
+
+ private:
+  static constexpr size_t alignment_ = 16;
+
+  std::mutex mutex_;
+  std::mt19937 rand_gen_;
+  std::unordered_map<void*, Storage::Handle> pool_;
+
+  void FreeImpl(const Storage::Handle& handle);
+
+  std::string SharedHandleToString(int shared_pid, int shared_id) {
+    std::stringstream name;
+    name << "/mx_" << std::hex << shared_pid << "_" << std::hex << shared_id;
+    return name.str();
+  }
+  DISALLOW_COPY_AND_ASSIGN(CPUSharedStorageManager);
+};  // class CPUSharedStorageManager
+
+void CPUSharedStorageManager::Alloc(Storage::Handle* handle) {
+  std::lock_guard<std::mutex> lock(mutex_);
+  std::uniform_int_distribution<> dis(0, std::numeric_limits<int>::max());
+  int fid = -1;
+  bool is_new = false;
+  size_t size = handle->size + alignment_;
+  void* ptr = nullptr;
+#ifdef _WIN32
+  LOG(FATAL) << "Shared memory is not supported on Windows yet.";
+#else
+  if (handle->shared_id == -1 && handle->shared_pid == -1) {
+    is_new = true;
+    handle->shared_pid = getpid();
+    for (int i = 0; i < 10; ++i) {
+      handle->shared_id = dis(rand_gen_);
+      auto filename = SharedHandleToString(handle->shared_pid, handle->shared_id);
+      fid = shm_open(filename.c_str(), O_EXCL|O_CREAT|O_RDWR, 0666);
+      if (fid != -1) break;
+    }
+  } else {
+    auto filename = SharedHandleToString(handle->shared_pid, handle->shared_id);
+    fid = shm_open(filename.c_str(), O_RDWR, 0666);
+  }
+
+  if (fid == -1) {
+    LOG(FATAL) << "Failed to open shared memory. shm_open failed with error "
+               << strerror(errno);
+  }
+
+  if (is_new) CHECK_EQ(ftruncate(fid, size), 0);
+
+  ptr = mmap(NULL, size, PROT_READ|PROT_WRITE, MAP_SHARED, fid, 0);
+  CHECK_NE(ptr, MAP_FAILED)
+      << "Failed to map shared memory. mmap failed with error " << strerror(errno);
+#endif  // _WIN32
+
+  if (is_new) {
+    new (ptr) std::atomic<int>(1);
+  }
+  handle->dptr = static_cast<char*>(ptr) + alignment_;
+  pool_[handle->dptr] = *handle;
+}
+
+void CPUSharedStorageManager::FreeImpl(const Storage::Handle& handle) {
+  int count = DecrementRefCount(handle);
+  CHECK_GE(count, 0);
+#ifdef _WIN32
+  LOG(FATAL) << "Shared memory is not supported on Windows yet.";
+#else
+  CHECK_EQ(munmap(static_cast<char*>(handle.dptr) - alignment_,
+                  handle.size + alignment_), 0)
+      << "Failed to unmap shared memory. munmap failed with error "
+      << strerror(errno);
+
+  if (count == 0) {
+    auto filename = SharedHandleToString(handle.shared_pid, handle.shared_id);
+    CHECK_EQ(shm_unlink(filename.c_str()), 0)
+        << "Failed to unlink shared memory. shm_unlink failed with error "
+        << strerror(errno);
+  }
+#endif  // _WIN32
+}
+
+}  // namespace storage
+}  // namespace mxnet
+
+#endif  // MXNET_STORAGE_CPU_SHARED_STORAGE_MANAGER_H_
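
The SharedHandleToString scheme above names each segment "/mx_<pid>_<id>" in hex, so on most Linux systems a live cpu_shared allocation shows up under /dev/shm. A hedged, Linux-only sketch; the listing check is purely illustrative:

    import os
    import mxnet as mx

    arr = mx.nd.zeros((16,), ctx=mx.Context('cpu_shared', 0))
    pid, shm_id, _, _ = arr._to_shared_mem()
    name = 'mx_%x_%x' % (pid, shm_id)
    print(name, name in os.listdir('/dev/shm'))    # expected: ... True
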
diff --git a/src/storage/naive_storage_manager.h b/src/storage/naive_storage_manager.h
index 731f374bbf..2479039cdf 100644
--- a/src/storage/naive_storage_manager.h
+++ b/src/storage/naive_storage_manager.h
@@ -44,11 +44,11 @@ class NaiveStorageManager final : public StorageManager {
    * \brief Default destructor.
    */
   ~NaiveStorageManager() = default;
-  void* Alloc(size_t size) override;
-  void Free(void* ptr, size_t) override;
+  void Alloc(Storage::Handle* handle) override;
+  void Free(Storage::Handle handle) override;
 
-  void DirectFree(void* ptr, size_t size) override {
-    DeviceStorage::Free(ptr);
+  void DirectFree(Storage::Handle handle) override {
+    DeviceStorage::Free(handle.dptr);
   }
 
  private:
@@ -56,13 +56,13 @@ class NaiveStorageManager final : public StorageManager {
 };  // class NaiveStorageManager
 
 template <class DeviceStorage>
-void* NaiveStorageManager<DeviceStorage>::Alloc(size_t size) {
-  return DeviceStorage::Alloc(size);
+void NaiveStorageManager<DeviceStorage>::Alloc(Storage::Handle* handle) {
+  handle->dptr = DeviceStorage::Alloc(handle->size);
 }
 
 template <class DeviceStorage>
-void NaiveStorageManager<DeviceStorage>::Free(void* ptr, size_t) {
-  DeviceStorage::Free(ptr);
+void NaiveStorageManager<DeviceStorage>::Free(Storage::Handle handle) {
+  DeviceStorage::Free(handle.dptr);
 }
 
 }  // namespace storage
diff --git a/src/storage/pooled_storage_manager.h b/src/storage/pooled_storage_manager.h
index b2c6633a80..634e57a8d1 100644
--- a/src/storage/pooled_storage_manager.h
+++ b/src/storage/pooled_storage_manager.h
@@ -58,12 +58,12 @@ class GPUPooledStorageManager final : public StorageManager {
     ReleaseAll();
   }
 
-  void* Alloc(size_t raw_size) override;
-  void Free(void* ptr, size_t raw_size) override;
+  void Alloc(Storage::Handle* handle) override;
+  void Free(Storage::Handle handle) override;
 
-  void DirectFree(void* ptr, size_t raw_size) override {
-    cudaError_t err = cudaFree(ptr);
-    size_t size = raw_size + NDEV;
+  void DirectFree(Storage::Handle handle) override {
+    cudaError_t err = cudaFree(handle.dptr);
+    size_t size = handle.size + NDEV;
     // ignore unloading error, as memory has already been recycled
     if (err != cudaSuccess && err != cudaErrorCudartUnloading) {
       LOG(FATAL) << "CUDA: " << cudaGetErrorString(err);
@@ -86,9 +86,9 @@ class GPUPooledStorageManager final : public StorageManager {
   DISALLOW_COPY_AND_ASSIGN(GPUPooledStorageManager);
 };  // class GPUPooledStorageManager
 
-void* GPUPooledStorageManager::Alloc(size_t raw_size) {
+void GPUPooledStorageManager::Alloc(Storage::Handle* handle) {
   std::lock_guard<std::mutex> lock(mutex_);
-  size_t size = raw_size + NDEV;
+  size_t size = handle->size + NDEV;
   auto&& reuse_it = memory_pool_.find(size);
   if (reuse_it == memory_pool_.end() || reuse_it->second.size() == 0) {
     size_t free, total;
@@ -102,26 +102,29 @@ void* GPUPooledStorageManager::Alloc(size_t raw_size) {
       LOG(FATAL) << "cudaMalloc failed: " << cudaGetErrorString(e);
     }
     used_memory_ += size;
-    return ret;
+    handle->dptr = ret;
   } else {
     auto&& reuse_pool = reuse_it->second;
     auto ret = reuse_pool.back();
     reuse_pool.pop_back();
-    return ret;
+    handle->dptr = ret;
   }
 }
 
-void GPUPooledStorageManager::Free(void* ptr, size_t raw_size) {
+void GPUPooledStorageManager::Free(Storage::Handle handle) {
   std::lock_guard<std::mutex> lock(mutex_);
-  size_t size = raw_size + NDEV;
+  size_t size = handle.size + NDEV;
   auto&& reuse_pool = memory_pool_[size];
-  reuse_pool.push_back(ptr);
+  reuse_pool.push_back(handle.dptr);
 }
 
 void GPUPooledStorageManager::ReleaseAll() {
   for (auto&& i : memory_pool_) {
     for (auto&& j : i.second) {
-      DirectFree(j, i.first - NDEV);
+      Storage::Handle handle;
+      handle.dptr = j;
+      handle.size = i.first - NDEV;
+      DirectFree(handle);
     }
   }
   memory_pool_.clear();
diff --git a/src/storage/storage.cc b/src/storage/storage.cc
index fa15a44b4f..7a41900483 100644
--- a/src/storage/storage.cc
+++ b/src/storage/storage.cc
@@ -26,6 +26,7 @@
 #include "./storage_manager.h"
 #include "./naive_storage_manager.h"
 #include "./pooled_storage_manager.h"
+#include "./cpu_shared_storage_manager.h"
 #include "./cpu_device_storage.h"
 #include "./pinned_memory_storage.h"
 #include "../common/cuda_utils.h"
@@ -36,9 +37,10 @@ namespace mxnet {
 // consider change storage as a pure abstract class
 class StorageImpl : public Storage {
  public:
-  Handle Alloc(size_t size, Context ctx) override;
+  void Alloc(Handle* handle) override;
   void Free(Handle handle) override;
   void DirectFree(Handle handle) override;
+  void SharedIncrementRefCount(Handle handle) override;
   StorageImpl() {}
   virtual ~StorageImpl() = default;
 
@@ -51,12 +53,13 @@ class StorageImpl : public Storage {
 
   static void ActivateDevice(Context ctx) {
     switch (ctx.dev_type) {
-      case Context::kCPU: break;
+      case Context::kCPU:
+      case Context::kCPUShared: break;
       case Context::kGPU:
       case Context::kCPUPinned: {
 #if MXNET_USE_CUDA
           if (num_gpu_device > 0) {
-            CUDA_CALL(cudaSetDevice(ctx.dev_id));
+            CUDA_CALL(cudaSetDevice(ctx.real_dev_id()));
           }
 #endif  // MXNET_USE_CUDA
           break;
@@ -73,20 +76,21 @@ class StorageImpl : public Storage {
 int StorageImpl::num_gpu_device = 0;
 #endif  // MXNET_USE_CUDA
 
-Storage::Handle StorageImpl::Alloc(size_t size, Context ctx) {
+void StorageImpl::Alloc(Storage::Handle* handle) {
   // space already recycled, ignore request
-  Handle hd;
-  hd.ctx = ctx;
-  hd.size = size;
-  auto&& device = storage_managers_.at(ctx.dev_type);
+  auto&& device = storage_managers_.at(handle->ctx.dev_type);
   std::shared_ptr<storage::StorageManager> manager = device.Get(
-      ctx.dev_id, [ctx]() {
+      handle->ctx.real_dev_id(), [handle]() {
         storage::StorageManager *ptr = nullptr;
-        switch (ctx.dev_type) {
+        switch (handle->ctx.dev_type) {
           case Context::kCPU: {
             ptr = new storage::NaiveStorageManager<storage::CPUDeviceStorage>();
             break;
           }
+          case Context::kCPUShared: {
+            ptr = new storage::CPUSharedStorageManager();
+            break;
+          }
           case Context::kCPUPinned: {
 #if MXNET_USE_CUDA
             num_gpu_device = 0;
@@ -114,38 +118,47 @@ Storage::Handle StorageImpl::Alloc(size_t size, Context ctx) {
 #endif  // MXNET_USE_CUDA
             break;
           }
-          default: LOG(FATAL) <<  "Unimplemented device " << ctx.dev_type;
+          default: LOG(FATAL) <<  "Unimplemented device " << handle->ctx.dev_type;
         }
         return ptr;
       });
-  this->ActivateDevice(ctx);
-  hd.dptr = manager->Alloc(size);
-  return hd;
+
+  this->ActivateDevice(handle->ctx);
+  manager->Alloc(handle);
 }
 
 void StorageImpl::Free(Storage::Handle handle) {
   const Context &ctx = handle.ctx;
   auto&& device = storage_managers_.at(ctx.dev_type);
   std::shared_ptr<storage::StorageManager> manager = device.Get(
-      ctx.dev_id, []() {
+      ctx.real_dev_id(), []() {
         LOG(FATAL) <<  "Cannot Free space to a device you have not allocated";
         return nullptr;
       });
   this->ActivateDevice(ctx);
-  manager->Free(handle.dptr, handle.size);
+  manager->Free(handle);
 }
 
 void StorageImpl::DirectFree(Storage::Handle handle) {
   const Context &ctx = handle.ctx;
   auto&& device = storage_managers_.at(ctx.dev_type);
   std::shared_ptr<storage::StorageManager> manager = device.Get(
-      ctx.dev_id, []() {
+      ctx.real_dev_id(), []() {
         LOG(FATAL) <<  "Cannot Free space to a device you have not allocated";
         return nullptr;
       });
   this->ActivateDevice(ctx);
-  // directly free ths data.
-  manager->DirectFree(handle.dptr, handle.size);
+  manager->DirectFree(handle);
+}
+
+void StorageImpl::SharedIncrementRefCount(Storage::Handle handle) {
+  CHECK_EQ(handle.ctx.dev_type, Context::kCPUShared);
+  auto&& device = storage_managers_.at(Context::kCPUShared);
+  auto manager = device.Get(0, []() {
+      LOG(FATAL) << "Cannot increment ref count before allocating any shared memory.";
+      return nullptr;
+    });
+  dynamic_cast<storage::CPUSharedStorageManager*>(manager.get())->IncrementRefCount(handle);
 }
 
 std::shared_ptr<Storage> Storage::_GetSharedRef() {
diff --git a/src/storage/storage_manager.h b/src/storage/storage_manager.h
index 924d2ed48b..326af4c590 100644
--- a/src/storage/storage_manager.h
+++ b/src/storage/storage_manager.h
@@ -17,13 +17,10 @@
  * under the License.
  */
 
-/*!
- * \file storage_manager.h
- * \brief Storage manager.
- */
 #ifndef MXNET_STORAGE_STORAGE_MANAGER_H_
 #define MXNET_STORAGE_STORAGE_MANAGER_H_
 
+#include <mxnet/storage.h>
 #include <cstddef>
 
 namespace mxnet {
@@ -39,19 +36,19 @@ class StorageManager {
    * \param size Size to allocate.
    * \return Pointer to the storage.
    */
-  virtual void* Alloc(size_t size) = 0;
+  virtual void Alloc(Storage::Handle* handle) = 0;
   /*!
    * \brief Deallocation.
    * \param ptr Pointer to deallocate.
    * \param size Size of the storage.
    */
-  virtual void Free(void* ptr, size_t size) = 0;
+  virtual void Free(Storage::Handle handle) = 0;
   /*!
    * \brief Direct de-allocation.
    * \param ptr Pointer to deallocate.
    * \param size Size of the storage.
    */
-  virtual void DirectFree(void* ptr, size_t size) = 0;
+  virtual void DirectFree(Storage::Handle handle) = 0;
   /*!
    * \brief Destructor.
    */
diff --git a/tests/python/unittest/test_gluon_data.py b/tests/python/unittest/test_gluon_data.py
index 397fbbd33e..63c5d28b7c 100644
--- a/tests/python/unittest/test_gluon_data.py
+++ b/tests/python/unittest/test_gluon_data.py
@@ -17,6 +17,7 @@
 
 import os
 import tarfile
+import unittest
 import mxnet as mx
 import numpy as np
 from mxnet import gluon
@@ -92,6 +93,20 @@ def test_image_folder_dataset():
     assert len(dataset.items) == 16
 
 
+class Dataset(gluon.data.Dataset):
+    def __len__(self):
+        return 100
+    def __getitem__(self, key):
+        return mx.nd.full((10,), key)
+
+@unittest.skip("Somehow fails with MKL. Cannot reproduce locally")
+def test_multi_worker():
+    data = Dataset()
+    loader = gluon.data.DataLoader(data, batch_size=1, num_workers=5)
+    for i, batch in enumerate(loader):
+        assert (batch.asnumpy() == i).all()
+
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()
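
As a closing usage note, default_mp_batchify_fn also handles (data, label) tuples recursively, so the common supervised-dataset layout works unchanged with worker processes. A hedged sketch (PairDataset below is illustrative, not part of the patch):

    import mxnet as mx
    from mxnet import gluon

    class PairDataset(gluon.data.Dataset):
        """Toy dataset yielding (data, label) pairs."""
        def __len__(self):
            return 20
        def __getitem__(self, idx):
            return mx.nd.full((3,), idx), float(idx)

    loader = gluon.data.DataLoader(PairDataset(), batch_size=5, num_workers=2)
    for data, label in loader:
        print(data.shape, label.shape)    # expected: (5, 3) (5,)
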


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services