Posted to commits@mxnet.apache.org by ha...@apache.org on 2018/10/19 06:44:09 UTC

[incubator-mxnet] branch master updated: [MXNET-1107] Fix CPUPinned unexpected behaviour (#12031)

This is an automated email from the ASF dual-hosted git repository.

haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 3154ec3  [MXNET-1107] Fix CPUPinned unexpected behaviour (#12031)
3154ec3 is described below

commit 3154ec3e63d1fb72ca05f898a0a2216b290d7b85
Author: Carl Yang <ca...@gmail.com>
AuthorDate: Thu Oct 18 23:43:51 2018 -0700

    [MXNET-1107] Fix CPUPinned unexpected behaviour (#12031)
    
    * Fix CPUPinned unexpected behaviour
    
    * fix lint
    
    * add guards
    
    * Actually, this may affect perf
    
    * trigger ci
    
    * fix lint
    
    * fix documentation
    
    * fix for dist_sync_device
    
    * add guard
    
    * fix bug with memory
    
    * try fix for gluon mp interaction
    
    * blah
    
    * trigger jenkins
    
    * Try fix for gluon multiprocessing bug
    
    Thanks Nvidia!
    
    * edit
    
    * try nvidia fix
    
    * address Haibin and Lin's comments
    
    * get rid of blank line in Makefile
---
 include/mxnet/base.h        |   4 +-
 src/common/cuda_utils.h     | 222 ++++++++++++++++++++++++--------------------
 src/common/rtc.cc           |   3 +-
 src/engine/stream_manager.h |  10 +-
 src/kvstore/comm.h          |   5 +-
 src/kvstore/comm_tree.h     |   3 +-
 src/kvstore/kvstore_nccl.h  |   6 +-
 src/storage/storage.cc      |  24 +++++
 8 files changed, 168 insertions(+), 109 deletions(-)

diff --git a/include/mxnet/base.h b/include/mxnet/base.h
index dfe1899..783f74a 100644
--- a/include/mxnet/base.h
+++ b/include/mxnet/base.h
@@ -153,10 +153,10 @@ struct Context {
     return dev_type;
   }
   /*!
-   * \brief Returns dev_id for kGPU, 0 otherwise
+   * \brief Returns dev_id for kGPU and kCPUPinned, 0 otherwise
    */
   inline int real_dev_id() const {
-    if (dev_type == kGPU) return dev_id;
+    if (dev_type == kCPUPinned || dev_type == kGPU) return dev_id;
     return 0;
   }
   /*!
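
For illustration, a hypothetical snippet (not part of the patch; Context::CPUPinned
is the existing factory in this header and the device id 1 is arbitrary):

    #include <mxnet/base.h>

    // After this change a CPUPinned context reports the GPU it is pinned against
    // instead of always returning 0 from real_dev_id().
    int pinned_device_example() {
      mxnet::Context pinned = mxnet::Context::CPUPinned(1);
      return pinned.real_dev_id();   // 1 with this patch; 0 before it
    }
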
diff --git a/src/common/cuda_utils.h b/src/common/cuda_utils.h
index 0ada350..047edde 100644
--- a/src/common/cuda_utils.h
+++ b/src/common/cuda_utils.h
@@ -68,6 +68,110 @@ inline __device__ bool __is_supported_cuda_architecture() {
 }
 #endif  // __CUDACC__
 
+/*!
+ * \brief Check CUDA error.
+ * \param msg Message to print if an error occurred.
+ */
+#define CHECK_CUDA_ERROR(msg)                                                \
+  {                                                                          \
+    cudaError_t e = cudaGetLastError();                                      \
+    CHECK_EQ(e, cudaSuccess) << (msg) << " CUDA: " << cudaGetErrorString(e); \
+  }
+
+/*!
+ * \brief Protected CUDA call.
+ * \param func Expression to call.
+ *
+ * It checks for CUDA errors after invocation of the expression.
+ */
+#define CUDA_CALL(func)                                            \
+  {                                                                \
+    cudaError_t e = (func);                                        \
+    CHECK(e == cudaSuccess || e == cudaErrorCudartUnloading)       \
+        << "CUDA: " << cudaGetErrorString(e);                      \
+  }
+
+/*!
+ * \brief Protected cuBLAS call.
+ * \param func Expression to call.
+ *
+ * It checks for cuBLAS errors after invocation of the expression.
+ */
+#define CUBLAS_CALL(func)                                       \
+  {                                                             \
+    cublasStatus_t e = (func);                                  \
+    CHECK_EQ(e, CUBLAS_STATUS_SUCCESS)                          \
+        << "cuBLAS: " << mxnet::common::cuda::CublasGetErrorString(e); \
+  }
+
+/*!
+ * \brief Protected cuSolver call.
+ * \param func Expression to call.
+ *
+ * It checks for cuSolver errors after invocation of the expression.
+ */
+#define CUSOLVER_CALL(func)                                         \
+  {                                                                 \
+    cusolverStatus_t e = (func);                                    \
+    CHECK_EQ(e, CUSOLVER_STATUS_SUCCESS)                            \
+        << "cuSolver: " << mxnet::common::cuda::CusolverGetErrorString(e); \
+  }
+
+/*!
+ * \brief Protected cuRAND call.
+ * \param func Expression to call.
+ *
+ * It checks for cuRAND errors after invocation of the expression.
+ */
+#define CURAND_CALL(func)                                       \
+  {                                                             \
+    curandStatus_t e = (func);                                  \
+    CHECK_EQ(e, CURAND_STATUS_SUCCESS)                          \
+        << "cuRAND: " << mxnet::common::cuda::CurandGetErrorString(e); \
+  }
+
+/*!
+ * \brief Protected NVRTC call.
+ * \param func Expression to call.
+ *
+ * It checks for NVRTC errors after invocation of the expression.
+ */
+#define NVRTC_CALL(x)                                   \
+  {                                                     \
+    nvrtcResult result = x;                             \
+    CHECK_EQ(result, NVRTC_SUCCESS)                     \
+      << #x " failed with error "                       \
+      << nvrtcGetErrorString(result);                   \
+  }
+
+/*!
+ * \brief Protected CUDA driver call.
+ * \param func Expression to call.
+ *
+ * It checks for CUDA driver errors after invocation of the expression.
+ */
+#define CUDA_DRIVER_CALL(func)                                          \
+  {                                                                     \
+    CUresult e = (func);                                                \
+    if (e != CUDA_SUCCESS) {                                            \
+      char const * err_msg = nullptr;                                         \
+      if (cuGetErrorString(e, &err_msg) == CUDA_ERROR_INVALID_VALUE) {  \
+        LOG(FATAL) << "CUDA Driver: Unknown error " << e;               \
+      } else {                                                          \
+        LOG(FATAL) << "CUDA Driver: " << err_msg;                       \
+      }                                                                 \
+    }                                                                   \
+  }
+
+
+#if !defined(_MSC_VER)
+#define CUDA_UNROLL _Pragma("unroll")
+#define CUDA_NOUNROLL _Pragma("nounroll")
+#else
+#define CUDA_UNROLL
+#define CUDA_NOUNROLL
+#endif
+
 namespace mxnet {
 namespace common {
 /*! \brief common utils for cuda */
@@ -179,113 +283,31 @@ inline DType __device__ CudaMin(DType a, DType b) {
     return a < b ? a : b;
 }
 
-}  // namespace cuda
-}  // namespace common
-}  // namespace mxnet
-
-/*!
- * \brief Check CUDA error.
- * \param msg Message to print if an error occurred.
- */
-#define CHECK_CUDA_ERROR(msg)                                                \
-  {                                                                          \
-    cudaError_t e = cudaGetLastError();                                      \
-    CHECK_EQ(e, cudaSuccess) << (msg) << " CUDA: " << cudaGetErrorString(e); \
+class DeviceStore {
+ public:
+  /*! \brief default constructor - only optionally restores the previous device */
+  explicit DeviceStore(bool restore = true) : restore_(restore) {
+    if (restore_)
+      CUDA_CALL(cudaGetDevice(&restore_device_));
   }
 
-/*!
- * \brief Protected CUDA call.
- * \param func Expression to call.
- *
- * It checks for CUDA errors after invocation of the expression.
- */
-#define CUDA_CALL(func)                                            \
-  {                                                                \
-    cudaError_t e = (func);                                        \
-    CHECK(e == cudaSuccess || e == cudaErrorCudartUnloading)       \
-        << "CUDA: " << cudaGetErrorString(e);                      \
+  ~DeviceStore() {
+    if (restore_)
+      CUDA_CALL(cudaSetDevice(restore_device_));
   }
 
-/*!
- * \brief Protected cuBLAS call.
- * \param func Expression to call.
- *
- * It checks for cuBLAS errors after invocation of the expression.
- */
-#define CUBLAS_CALL(func)                                       \
-  {                                                             \
-    cublasStatus_t e = (func);                                  \
-    CHECK_EQ(e, CUBLAS_STATUS_SUCCESS)                          \
-        << "cuBLAS: " << mxnet::common::cuda::CublasGetErrorString(e); \
+  void SetDevice(int device) {
+    CUDA_CALL(cudaSetDevice(device));
   }
 
-/*!
- * \brief Protected cuSolver call.
- * \param func Expression to call.
- *
- * It checks for cuSolver errors after invocation of the expression.
- */
-#define CUSOLVER_CALL(func)                                         \
-  {                                                                 \
-    cusolverStatus_t e = (func);                                    \
-    CHECK_EQ(e, CUSOLVER_STATUS_SUCCESS)                            \
-        << "cuSolver: " << mxnet::common::cuda::CusolverGetErrorString(e); \
-  }
+ private:
+  int restore_device_;
+  bool restore_;
+};
 
-/*!
- * \brief Protected cuRAND call.
- * \param func Expression to call.
- *
- * It checks for cuRAND errors after invocation of the expression.
- */
-#define CURAND_CALL(func)                                       \
-  {                                                             \
-    curandStatus_t e = (func);                                  \
-    CHECK_EQ(e, CURAND_STATUS_SUCCESS)                          \
-        << "cuRAND: " << mxnet::common::cuda::CurandGetErrorString(e); \
-  }
-
-/*!
- * \brief Protected NVRTC call.
- * \param func Expression to call.
- *
- * It checks for NVRTC errors after invocation of the expression.
- */
-#define NVRTC_CALL(x)                                   \
-  {                                                     \
-    nvrtcResult result = x;                             \
-    CHECK_EQ(result, NVRTC_SUCCESS)                     \
-      << #x " failed with error "                       \
-      << nvrtcGetErrorString(result);                   \
-  }
-
-/*!
- * \brief Protected CUDA driver call.
- * \param func Expression to call.
- *
- * It checks for CUDA driver errors after invocation of the expression.
- */
-#define CUDA_DRIVER_CALL(func)                                          \
-  {                                                                     \
-    CUresult e = (func);                                                \
-    if (e != CUDA_SUCCESS) {                                            \
-      char const * err_msg = nullptr;                                         \
-      if (cuGetErrorString(e, &err_msg) == CUDA_ERROR_INVALID_VALUE) {  \
-        LOG(FATAL) << "CUDA Driver: Unknown error " << e;               \
-      } else {                                                          \
-        LOG(FATAL) << "CUDA Driver: " << err_msg;                       \
-      }                                                                 \
-    }                                                                   \
-  }
-
-
-#if !defined(_MSC_VER)
-#define CUDA_UNROLL _Pragma("unroll")
-#define CUDA_NOUNROLL _Pragma("nounroll")
-#else
-#define CUDA_UNROLL
-#define CUDA_NOUNROLL
-#endif
+}  // namespace cuda
+}  // namespace common
+}  // namespace mxnet
 
 /*!
  * \brief Determine major version number of the gpu's cuda compute architecture.
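
The macro block above is moved verbatim from further down in this header; the new
piece is the DeviceStore class, a small RAII guard that now lives inside
mxnet::common::cuda. A minimal usage sketch, assuming a placeholder device id and
buffer size (the include path is illustrative; DeviceStore, CUDA_CALL and the CUDA
runtime calls are the only pieces taken from this header):

    #include "common/cuda_utils.h"   // illustrative path within the MXNet source tree

    void alloc_on_device_example(int target_dev_id) {
      void* ptr = nullptr;
      {
        // Constructor records the currently active device (restore defaults to true).
        mxnet::common::cuda::DeviceStore device_store;
        device_store.SetDevice(target_dev_id);   // make target_dev_id current
        CUDA_CALL(cudaMalloc(&ptr, 1 << 20));    // this 1 MiB lands on target_dev_id
        CUDA_CALL(cudaFree(ptr));
      }  // destructor restores whatever device was current before this block
    }
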
diff --git a/src/common/rtc.cc b/src/common/rtc.cc
index da083c9..ea20a60 100644
--- a/src/common/rtc.cc
+++ b/src/common/rtc.cc
@@ -77,11 +77,12 @@ CUfunction CudaModule::Chunk::GetFunction(
   CHECK_EQ(ctx.dev_mask(), Context::kGPU)
       << "CUDA Runtime compilation only supports Nvidia GPU.";
   auto iter = mod_.find(ctx.dev_id);
+  mxnet::common::cuda::DeviceStore device_store;
   CUmodule module;
   if (iter != mod_.end()) {
     module = iter->second;
   } else {
-    CUDA_CALL(cudaSetDevice(ctx.dev_id));
+    device_store.SetDevice(ctx.dev_id);
     CUDA_DRIVER_CALL(cuModuleLoadDataEx(&module, ptx_, 0, 0, 0));
     mod_[ctx.dev_id] = module;
   }
diff --git a/src/engine/stream_manager.h b/src/engine/stream_manager.h
index ddbfde8..d4ac042 100644
--- a/src/engine/stream_manager.h
+++ b/src/engine/stream_manager.h
@@ -65,6 +65,9 @@ template <std::size_t kNumGpus, std::size_t kStreams>
 RunContext StreamManager<kNumGpus, kStreams>::GetRunContext(
     Context const& ctx) {
   RunContext ret;
+#if MXNET_USE_CUDA
+  mxnet::common::cuda::DeviceStore device_store;
+#endif
   switch (ctx.dev_mask()) {
     case cpu::kDevMask:
       ret = RunContext{ctx, nullptr};
@@ -72,7 +75,7 @@ RunContext StreamManager<kNumGpus, kStreams>::GetRunContext(
     case gpu::kDevMask: {
 #if MXNET_USE_CUDA
       std::size_t use_counter;
-      CUDA_CALL(cudaSetDevice(ctx.dev_id));
+      device_store.SetDevice(ctx.dev_id);
       {
         std::lock_guard<std::mutex> lock{mutex_};
         auto&& counter = gpu_cnt_.at(ctx.dev_id);
@@ -101,13 +104,16 @@ template <std::size_t kNumGpus, std::size_t kStreams>
 RunContext StreamManager<kNumGpus, kStreams>::GetIORunContext(
     Context const& ctx) {
   RunContext ret;
+#if MXNET_USE_CUDA
+  mxnet::common::cuda::DeviceStore device_store;
+#endif
   switch (ctx.dev_mask()) {
     case cpu::kDevMask:
       ret = RunContext{ctx, nullptr};
       break;
     case gpu::kDevMask: {
 #if MXNET_USE_CUDA
-      CUDA_CALL(cudaSetDevice(ctx.dev_id));
+      device_store.SetDevice(ctx.dev_id);
       {
         std::lock_guard<std::mutex> lock{mutex_};
         if (gpu_io_streams_.at(ctx.dev_id) == nullptr) {
diff --git a/src/kvstore/comm.h b/src/kvstore/comm.h
index 61370a5..581ef81 100644
--- a/src/kvstore/comm.h
+++ b/src/kvstore/comm.h
@@ -723,8 +723,11 @@ class CommDevice : public Comm {
     int n = static_cast<int>(gpus.size());
     int enabled = 0;
     std::vector<int> p2p(n*n);
+
+    // Restores active device to what it was before EnableP2P
+    mxnet::common::cuda::DeviceStore device_store;
     for (int i = 0; i < n; ++i) {
-      cudaSetDevice(gpus[i]);
+      device_store.SetDevice(gpus[i]);
       for (int j = 0; j < n; j++) {
         int access;
         cudaDeviceCanAccessPeer(&access, gpus[i], gpus[j]);
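
The comment added above is the point of the change: enabling peer access makes each
GPU current in turn, so without a guard the last gpus[i] would stay active on the
calling thread after EnableP2P returns. A hedged sketch of the surrounding pattern
(the enable call and error codes are standard CUDA runtime API; the loop shape
mirrors the hunk above and is illustrative, not a copy of the full function):

    // Probe and enable P2P between all GPU pairs; the DeviceStore guard puts the
    // caller's device back once it goes out of scope at the end of EnableP2P.
    mxnet::common::cuda::DeviceStore device_store;
    for (int i = 0; i < n; ++i) {
      device_store.SetDevice(gpus[i]);                    // gpus[i] becomes current
      for (int j = 0; j < n; ++j) {
        int access = 0;
        cudaDeviceCanAccessPeer(&access, gpus[i], gpus[j]);
        if (access != 0 && i != j) {
          cudaError_t e = cudaDeviceEnablePeerAccess(gpus[j], 0);
          if (e == cudaSuccess || e == cudaErrorPeerAccessAlreadyEnabled) {
            ++enabled;
            p2p[i * n + j] = 1;
          }
        }
      }
    }
    // device_store's destructor then restores the device active before EnableP2P.
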
diff --git a/src/kvstore/comm_tree.h b/src/kvstore/comm_tree.h
index e857f33..8d36803 100644
--- a/src/kvstore/comm_tree.h
+++ b/src/kvstore/comm_tree.h
@@ -339,8 +339,9 @@ class CommDeviceTree : public CommDevice {
     int n = static_cast<int>(gpus.size());
     int enabled = 0;
     std::vector<int> p2p(n*n);
+    mxnet::common::cuda::DeviceStore device_store;
     for (int i = 0; i < n; ++i) {
-      cudaSetDevice(gpus[i]);
+      device_store.SetDevice(gpus[i]);
       for (int j = 0; j < n; j++) {
         int access;
         cudaDeviceCanAccessPeer(&access, gpus[i], gpus[j]);
diff --git a/src/kvstore/kvstore_nccl.h b/src/kvstore/kvstore_nccl.h
index 485cd95..d0f397c 100644
--- a/src/kvstore/kvstore_nccl.h
+++ b/src/kvstore/kvstore_nccl.h
@@ -428,8 +428,9 @@ class KVStoreNCCL : public KVStoreLocal {
         mutate_vars.push_back(ptr(dst[i])->var());
     }
     Engine::Get()->PushSync([this](RunContext rctx) {
+        mxnet::common::cuda::DeviceStore device_store;
         for (auto cur : nccl_data_) {
-          CUDA_CALL(cudaSetDevice(cur.second.dev_id));
+          device_store.SetDevice(cur.second.dev_id);
           CUDA_CALL(cudaStreamSynchronize(cur.second.stream));
         }
       },
@@ -479,12 +480,13 @@ class KVStoreNCCL : public KVStoreLocal {
     std::lock_guard<std::mutex> l(Storage::Get()->GetMutex(Context::kGPU));
     std::vector<ncclComm_t> comms(devs.size());
     ncclCommInitAll(&(comms[0]), devs.size(), &(device_ids_[0]));
+    mxnet::common::cuda::DeviceStore device_store;
     for (size_t i = 0; i < devs.size(); ++i) {
       NCCLEntry e;
       e.dev_id = device_ids_[i];
       e.comm = comms[i];
       e.rank = i;
-      cudaSetDevice(e.dev_id);
+      device_store.SetDevice(e.dev_id);
       cudaStreamCreate(&(e.stream));
       nccl_data_[device_ids_[i]] = e;
     }
diff --git a/src/storage/storage.cc b/src/storage/storage.cc
index a0a3ed7..c7100a4 100644
--- a/src/storage/storage.cc
+++ b/src/storage/storage.cc
@@ -51,7 +51,13 @@ class StorageImpl : public Storage {
   static void ActivateDevice(Context ctx) {
     switch (ctx.dev_type) {
       case Context::kCPU:
+        break;
       case Context::kCPUPinned:
+#if MXNET_USE_CUDA
+        if (num_gpu_device > 0) {
+          CUDA_CALL(cudaSetDevice(ctx.real_dev_id()));
+        }
+#endif  // MXNET_USE_CUDA
         break;
       case Context::kCPUShared: {
 #if defined(ANDROID) || defined(__ANDROID__)
@@ -143,6 +149,12 @@ void StorageImpl::Alloc(Storage::Handle* handle) {
         return ptr;
       });
 
+#if MXNET_USE_CUDA
+  // Will restore gpu device to before ActivateDevice if necessary
+  bool restore = handle->ctx.dev_type == Context::kCPUPinned ||
+                 handle->ctx.dev_type == Context::kGPU;
+  mxnet::common::cuda::DeviceStore device_store(restore);
+#endif
   this->ActivateDevice(handle->ctx);
   manager->Alloc(handle);
   profiler_.OnAlloc(*handle);
@@ -156,6 +168,12 @@ void StorageImpl::Free(Storage::Handle handle) {
         LOG(FATAL) <<  "Cannot Free space to a device you have not allocated";
         return nullptr;
       });
+
+#if MXNET_USE_CUDA
+  // Will restore gpu device to before ActivateDevice if necessary
+  bool restore = ctx.dev_type == Context::kCPUPinned || ctx.dev_type == Context::kGPU;
+  mxnet::common::cuda::DeviceStore device_store(restore);
+#endif
   this->ActivateDevice(ctx);
   manager->Free(handle);
   profiler_.OnFree(handle);
@@ -169,6 +187,12 @@ void StorageImpl::DirectFree(Storage::Handle handle) {
         LOG(FATAL) <<  "Cannot Free space to a device you have not allocated";
         return nullptr;
       });
+
+#if MXNET_USE_CUDA
+  // Will restore gpu device to before ActivateDevice if necessary
+  bool restore = ctx.dev_type == Context::kCPUPinned || ctx.dev_type == Context::kGPU;
+  mxnet::common::cuda::DeviceStore device_store(restore);
+#endif
   this->ActivateDevice(ctx);
   manager->DirectFree(handle);
   profiler_.OnFree(handle);
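
Taken together, the storage changes are the user-visible part of the fix: allocating
or freeing CPUPinned (or GPU) memory no longer leaves a different device current on
the calling thread. A hedged before/after sketch (user-level code, not part of the
patch, assuming Storage::Get()->Alloc as the public entry point that reaches
StorageImpl::Alloc above; the size and device ids are arbitrary):

    #include <cuda_runtime.h>
    #include <mxnet/base.h>
    #include <mxnet/storage.h>

    void pinned_alloc_keeps_device_example() {
      cudaSetDevice(0);                               // caller is working on GPU 0
      auto handle = mxnet::Storage::Get()->Alloc(
          4096, mxnet::Context::CPUPinned(1));        // pinned host memory tied to GPU 1
      int current = -1;
      cudaGetDevice(&current);
      // Before this commit ActivateDevice left GPU 1 current here (current == 1);
      // with the DeviceStore(restore) guard in Alloc, current is 0 again.
      mxnet::Storage::Get()->Free(handle);
    }
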