Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2017/11/07 05:43:42 UTC

[GitHub] piiswrong closed pull request #8553: Engine reserves cores from OMP. Set some defaults for dynamic and recursion

piiswrong closed pull request #8553: Engine reserves cores from OMP.  Set some defaults for dynamic and recursion
URL: https://github.com/apache/incubator-mxnet/pull/8553
 
 
   

This is a PR merged from a forked repository.
Because GitHub hides the original diff once a pull request from a fork is merged,
it is reproduced below for the sake of provenance:

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 63bc8d740b..539515b3a2 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -185,10 +185,15 @@ endif()
 
 # ---[ jemalloc
 if(USE_JEMALLOC)
+  if(USE_GPERFTOOLS)
+    message(ERROR "Only one of USE_JEMALLOC and USE_GPERFTOOLS can be defined at once")
+  endif()
   find_package(JeMalloc)
   if(JEMALLOC_FOUND)
     message(STATUS "Using JEMalloc malloc")
     add_definitions(-DUSE_JEMALLOC)
+    set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${ALT_MALLOC_FLAGS}")
+    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${ALT_MALLOC_FLAGS}")
     include_directories(${JEMALLOC_INCLUDE_DIRS})
     set(mxnet_LINKER_LIBS ${mxnet_LINKER_LIBS} ${JEMALLOC_LIBRARIES})
   endif()
diff --git a/src/engine/openmp.cc b/src/engine/openmp.cc
index a605f977b6..be7885ba75 100644
--- a/src/engine/openmp.cc
+++ b/src/engine/openmp.cc
@@ -29,44 +29,67 @@ namespace engine {
 #define ARCH_IS_INTEL_X86
 #endif
 
+static inline bool is_env_set(const char *var) {
+  return dmlc::GetEnv(var, INT_MIN) == INT_MIN;
+}
+
 OpenMP *OpenMP::Get() {
   static OpenMP openMP;
   return &openMP;
 }
 
 OpenMP::OpenMP()
-  : omp_num_threads_set_in_environment(dmlc::GetEnv("OMP_NUM_THREADS", INT_MIN) == INT_MIN) {
+  : omp_num_threads_set_in_environment(is_env_set("OMP_NUM_THREADS")) {
 #ifdef _OPENMP
-  if (!omp_num_threads_set_in_environment) {
-    omp_set_nested(true);
-    omp_set_dynamic(false);
-  }
   const int max = dmlc::GetEnv("MXNET_OMP_MAX_THREADS", INT_MIN);
   if (max != INT_MIN) {
     omp_thread_max_ = max;
   } else {
+    if (!omp_num_threads_set_in_environment) {
+      omp_thread_max_ = omp_get_num_procs();
 #ifdef ARCH_IS_INTEL_X86
-    omp_thread_max_ = omp_get_num_procs() >> 1;
+      omp_thread_max_ >>= 1;
 #endif
+      omp_set_num_threads(omp_thread_max_);
+    } else {
+      omp_thread_max_ = omp_get_max_threads();
+    }
   }
+  omp_set_nested(dmlc::GetEnv("OMP_NESTED", false));
+  omp_set_dynamic(dmlc::GetEnv("OMP_DYNAMIC", false));
 #else
   enabled_ = false;
   omp_thread_max_ = 1;
 #endif
 }
 
-int OpenMP::GetRecommendedOMPThreadCount() const {
+void OpenMP::set_reserve_cores(int cores) {
+  CHECK_GE(cores, 0);
+  reserve_cores_ = cores;
+#ifdef _OPENMP
+  if (reserve_cores_ >= omp_thread_max_) {
+    omp_set_num_threads(1);
+  } else {
+    omp_set_num_threads(omp_thread_max_ - reserve_cores_);
+  }
+#endif
+}
+
+int OpenMP::GetRecommendedOMPThreadCount(bool exclude_reserved) const {
 #ifdef _OPENMP
   if (omp_num_threads_set_in_environment) {
     return omp_get_max_threads();
   }
   if (enabled_) {
-#ifdef ARCH_IS_INTEL_X86
-    // x86 does hyperthreading, but do to cache issues, it's faster to only use # true CPUs
-    const int thread_count = omp_get_max_threads() >> 1;
-#else
-    const int thread_count = omp_get_max_threads();
-#endif
+    int thread_count = omp_get_max_threads();
+    if (exclude_reserved) {
+      if (reserve_cores_ >= thread_count) {
+        thread_count = 1;
+      } else {
+        thread_count -= reserve_cores_;
+      }
+    }
+    // Check that OMP doesn't suggest more than our 'omp_thread_max_' value
     if (!omp_thread_max_ || thread_count < omp_thread_max_) {
       return thread_count;
     }
@@ -78,6 +101,8 @@ int OpenMP::GetRecommendedOMPThreadCount() const {
 #endif
 }
 
+OpenMP *__init_omp__ = OpenMP::Get();
+
 }  // namespace engine
 }  // namespace mxnet
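
As context for the hunks above: GetRecommendedOMPThreadCount() is the value operator code is
meant to pass to its OMP pragmas, already capped by omp_thread_max_ and reduced by any
reserved cores. A minimal sketch of such a caller (illustrative only, not part of this PR;
the include path and function name are assumptions):

    #include <omp.h>
    #include "../engine/openmp.h"  // assumed relative path to the header changed below

    void scale_in_parallel(float *data, int n, float alpha) {
      // Ask the engine singleton for a thread count that already excludes
      // any cores it has reserved for its own worker threads.
      const int nthr = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
    #pragma omp parallel for num_threads(nthr)
      for (int i = 0; i < n; ++i) {
        data[i] *= alpha;
      }
    }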
 
diff --git a/src/engine/openmp.h b/src/engine/openmp.h
index 8b995a6357..02e73c0955 100644
--- a/src/engine/openmp.h
+++ b/src/engine/openmp.h
@@ -36,7 +36,7 @@ class OpenMP {
    * \brief Get the recommended number of OMP threads to use given the current context
    * \return Recommended number of OMP threads to use in a parallel operation
    */
-  int GetRecommendedOMPThreadCount() const;
+  int GetRecommendedOMPThreadCount(bool exclude_reserved = true) const;
 
   /*!
    * \brief Set whether clients of this class receive pro-OMP behavior guidance
@@ -57,8 +57,19 @@ class OpenMP {
   int thread_max() const { return omp_thread_max_; }
 
   /*!
+   * \brief Reserve cores to be excluded from OMP regions
+   * \param cores Number of cores to be excluded from OMP region usage
+   */
+  void set_reserve_cores(int cores);
+  /*!
+   * \brief Get number of cores to be excluded from OMP regions
+   * \return Number of cores to be excluded from OMP regions
+   */
+  int reserve_cores() const { return reserve_cores_; }
+
+  /*!
    * \brief Get the OpenMP object's singleton pointer
-   * \return
+   * \return Singleton OpenMP object pointer
    */
   static OpenMP *Get();
 
@@ -73,6 +84,10 @@ class OpenMP {
    */
   volatile int omp_thread_max_ = 0;
   /*!
+   * \brief Number of cores to reserve for non-OMP regions
+   */
+  volatile int reserve_cores_ = 0;
+  /*!
    * \brief Whether OMP_NUM_THREADS was set in the environment.  If it is, we fall back to
    *        the OMP's implementation's handling of that environment variable
    */
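
The header addition above suggests a call pattern along these lines; a hedged sketch (the
wrapper function below is hypothetical): reserving N cores makes later recommendations come
out as thread_max() - N, floored at one thread.

    #include "../engine/openmp.h"  // assumed relative path

    void reserve_cores_for_engine_workers() {
      using mxnet::engine::OpenMP;
      OpenMP *omp = OpenMP::Get();
      omp->set_reserve_cores(2);  // keep two cores out of OMP regions
      // With exclude_reserved = true (the default), the recommendation is
      // reduced by reserve_cores(), but never drops below one thread.
      const int nthr = omp->GetRecommendedOMPThreadCount(true);
      (void)nthr;
    }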
diff --git a/src/engine/threaded_engine_perdevice.cc b/src/engine/threaded_engine_perdevice.cc
index 60868797a6..e01dd4ed45 100644
--- a/src/engine/threaded_engine_perdevice.cc
+++ b/src/engine/threaded_engine_perdevice.cc
@@ -113,8 +113,8 @@ class ThreadedEnginePerDevice : public ThreadedEngine {
         if (is_copy) {
           auto ptr =
           gpu_copy_workers_.Get(ctx.dev_id, [this, ctx, is_copy, nthread]() {
-            // Signify to kernel that GPU is being used,  no Kernel Launch OMP (temporary behavior)
-            OpenMP::Get()->set_enabled(false);
+            // Signify to kernel that GPU is being used, so reserve cores as necessary
+            OpenMP::Get()->set_reserve_cores(GetReserveCoreCount(true));
             auto blk = new ThreadWorkerBlock<kCopyQueue>();
               blk->pool.reset(new ThreadPool(
                 nthread,
@@ -133,8 +133,8 @@ class ThreadedEnginePerDevice : public ThreadedEngine {
           }
         } else {
           auto ptr = gpu_normal_workers_.Get(ctx.dev_id, [this, ctx, is_copy, nthread]() {
-            // Signify to kernel that GPU is being used,  no Kernel Launch OMP (temporary behavior)
-            OpenMP::Get()->set_enabled(false);
+            // Signify to kernel that GPU is being used, so reserve cores as necessary
+            OpenMP::Get()->set_reserve_cores(GetReserveCoreCount(true));
               auto blk = new ThreadWorkerBlock<kWorkerQueue>();
               blk->pool.reset(new ThreadPool(
                 nthread,
@@ -234,6 +234,27 @@ class ThreadedEnginePerDevice : public ThreadedEngine {
     }
   }
 
+  /*!
+   * \brief Get number of cores this engine should reserve for its own use
+   * \param using_gpu Whether there is GPU usage
+   * \return number of cores that this engine wishes to be reserved
+   * \note Testing found no degradation of performance using these values
+   *       running cifar10 with resnet50 on various GPU systems,
+   *       including AWS p2.16xlarge, which has 16 GPUs
+   */
+  int GetReserveCoreCount(const bool using_gpu) const {
+    int reserve = 0;
+    if (using_gpu) {
+      // Save at least one for GPU tasks
+      ++reserve;
+      // If we have 8 or more real cores, reserve another core for GPU tasks
+      if (OpenMP::Get()->GetRecommendedOMPThreadCount(true) >= 8) {
+        ++reserve;
+      }
+    }
+    return reserve;
+  }
+
   /*! \brief Signal a single queue for shutdown */
   template<typename Object>
   static inline void SignalQueueForKill(common::LazyAllocArray<Object> *array) {
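
The reservation policy added above is small enough to restate on its own; a standalone
sketch (hypothetical free function, same thresholds as GetReserveCoreCount):

    // One core is always held back while a GPU is in use; a second core is
    // held back on machines with eight or more usable cores.
    int ReserveCoreCountFor(bool using_gpu, int usable_cores) {
      int reserve = 0;
      if (using_gpu) {
        ++reserve;
        if (usable_cores >= 8) {
          ++reserve;
        }
      }
      return reserve;
    }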
diff --git a/src/initialize.cc b/src/initialize.cc
index ca78a76cee..a3cc1164fa 100644
--- a/src/initialize.cc
+++ b/src/initialize.cc
@@ -77,5 +77,10 @@ LibraryInitializer* LibraryInitializer::Get() {
   return &inst;
 }
 
+#ifdef __GNUC__
+// Don't print an unused variable message since this is intentional
+#pragma GCC diagnostic ignored "-Wunused-variable"
+#endif
+
 static LibraryInitializer* __library_init = LibraryInitializer::Get();
 }  // namespace mxnet
diff --git a/src/operator/dropout-inl.h b/src/operator/dropout-inl.h
index b2fb7823be..7fcd7adf86 100644
--- a/src/operator/dropout-inl.h
+++ b/src/operator/dropout-inl.h
@@ -35,6 +35,7 @@
 #include <algorithm>
 #include "./operator_common.h"
 #include "./mshadow_op.h"
+#include "../engine/openmp.h"
 
 #if defined(USE_MKL) && defined(_OPENMP)
 #include <omp.h>
@@ -55,8 +56,8 @@ namespace op {
 
 #if defined(USE_MKL) && defined(_OPENMP)
 static void bernoulli_generate(int n, double p, int* r) {
-  int seed = 17 + rand() % 4096;  // NOLINT(runtime/threadsafe_fn)
-  int nthr = omp_get_max_threads();
+  const int seed = 17 + rand() % 4096;  // NOLINT(runtime/threadsafe_fn)
+  const int nthr = engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
 # pragma omp parallel num_threads(nthr)
   {
     const int ithr = omp_get_thread_num();
@@ -117,12 +118,13 @@ class DropoutOp : public Operator {
 #if !defined(__CUDACC__) && defined(USE_MKL) && defined(_OPENMP)
       DType* outptr = out.dptr_;
       DType* dataptr = data.dptr_;
-      int* maskptr = reinterpret_cast<int*>(mask.dptr_);
+      auto maskptr = reinterpret_cast<int*>(mask.dptr_);
       int count = mask.shape_[0]*mask.shape_[1];
       bernoulli_generate(count, this->pkeep_, maskptr);
-  #pragma omp parallel for
+      const float pk_1 = 1.0f / pkeep_;
+      #pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount())
       for (int i = 0; i < count; ++i) {
-        outptr[i] = dataptr[i] * maskptr[i] * (1.0f / pkeep_);
+        outptr[i] = dataptr[i] * maskptr[i] * pk_1;
       }
 #else
       Random<xpu> *prnd = ctx.requested[dropout::kRandom].get_random<xpu, real_t>(s);
@@ -154,15 +156,15 @@ class DropoutOp : public Operator {
 #if !defined(__CUDACC__) && defined(USE_MKL) && defined(_OPENMP)
       DType* ingradptr = gdata.dptr_;
       DType* outgradptr = grad.dptr_;
-      int* maskptr = reinterpret_cast<int*>(mask.dptr_);
-
+      auto maskptr = reinterpret_cast<int*>(mask.dptr_);
       int count = mask.shape_[0]*mask.shape_[1];
-
-      #pragma omp parallel for
+      const float pk_1 = 1.0f / pkeep_;
+      #pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount())
       for (int i = 0; i < count; ++i) {
-        ingradptr[i] = outgradptr[i] * maskptr[i] * (1.0f / pkeep_);
+        ingradptr[i] = outgradptr[i] * maskptr[i] * pk_1;
       }
 #else  // USE_MKL && _OPENMP
+      CHECK_EQ(grad.shape_.Size(), mask.shape_.Size());
       Assign(gdata, req[dropout::kData], grad * mask);
 #endif  // USE_MKL && _OPENMP
     } else {
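
The arithmetic in the parallelized dropout loops reduces to one scale factor per element;
a minimal sketch of the forward path (hypothetical free function, float-only, with the mask
already holding 0/1 values from bernoulli_generate):

    #include <omp.h>

    void dropout_forward(const float *in, const int *mask, float *out,
                         int count, float pkeep, int nthr) {
      const float pk_1 = 1.0f / pkeep;  // hoisted out of the loop, as in the hunk above
    #pragma omp parallel for num_threads(nthr)
      for (int i = 0; i < count; ++i) {
        out[i] = in[i] * mask[i] * pk_1;
      }
    }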
diff --git a/src/operator/mxnet_op.h b/src/operator/mxnet_op.h
index c74b21605b..06e2393524 100644
--- a/src/operator/mxnet_op.h
+++ b/src/operator/mxnet_op.h
@@ -77,6 +77,7 @@ inline int get_num_threads<gpu>(const int N) {
   using namespace mshadow::cuda;
   return kBaseThreadNum * cuda_get_num_blocks(N);
 }
+
 #endif  // __CUDACC__
 
 template<>
diff --git a/src/operator/operator_common.h b/src/operator/operator_common.h
index 875c79c2ae..d036355392 100644
--- a/src/operator/operator_common.h
+++ b/src/operator/operator_common.h
@@ -129,7 +129,7 @@ inline std::string shape_string(const TShape& x) {
   return os.str();
 }
 
-/*! \brief get string representation of shape */
+/*! \brief get string representation of data type */
 inline std::string type_string(const int& x) {
   switch (x) {
     case mshadow::kFloat32:
@@ -138,10 +138,14 @@ inline std::string type_string(const int& x) {
       return "float64";
     case mshadow::kFloat16:
       return "float16";
+    case mshadow::kInt8:
+      return "int8";
     case mshadow::kUint8:
       return "uint8";
     case mshadow::kInt32:
       return "int32";
+    case mshadow::kInt64:
+      return "int64";
   }
   return "unknown";
 }
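
type_string() is typically used when composing dtype error messages; a small hypothetical
example of a caller (the helper and include path below are illustrative, not from this PR):

    #include <sstream>
    #include <string>
    #include "operator_common.h"  // for mxnet::op::type_string

    std::string dtype_mismatch_msg(int expected, int got) {
      std::ostringstream os;
      os << "expected dtype " << mxnet::op::type_string(expected)
         << " but got " << mxnet::op::type_string(got);
      return os.str();
    }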
diff --git a/src/operator/slice_channel-inl.h b/src/operator/slice_channel-inl.h
index a48c52f0b7..791b90e570 100644
--- a/src/operator/slice_channel-inl.h
+++ b/src/operator/slice_channel-inl.h
@@ -41,7 +41,6 @@ namespace op {
 
 namespace slice_enum {
 enum SliceChannelOpInputs {kData};
-enum SliceChannelOpOutputs {kOut0, kOut1, kOut2, kOut3, kOut4};
 }  // namespace slice_enum
 
 struct SliceChannelParam : public dmlc::Parameter<SliceChannelParam> {
diff --git a/tests/cpp/include/test_util.h b/tests/cpp/include/test_util.h
index 29260559ae..95ab141954 100644
--- a/tests/cpp/include/test_util.h
+++ b/tests/cpp/include/test_util.h
@@ -518,14 +518,17 @@ inline std::string demangle(const char *name) {
   return status ? name : res.get();
 }
 
+template<typename T>
+inline std::string type_name() { return demangle(typeid(T).name()); }
+
 #define PRINT_NDARRAYS(__ctx$, __var)  test::print(__ctx$, __FUNCTION__, #__var, __var)
 #define PRINT_OP_AND_ARRAYS(__ctx$, __op, __var)  test::print(__ctx$, __FUNCTION__, \
   static_cast<std::stringstream *>(&(std::stringstream() << #__var << \
-  "<" << test::demangle(typeid(__op).name()) << ">"))->str(), __var)
+  "<" << type_name<__op>() << ">"))->str(), __var)
 #define PRINT_OP2_AND_ARRAYS(__ctx$, __op1, __op2, __var)  test::print(__ctx$, __FUNCTION__, \
   static_cast<std::stringstream *>(&(std::stringstream() << #__var << \
-  "<" << test::demangle(typeid(__op1).name()) << ", " \
-  << test::demangle(typeid(__op2).name()) << ">"))->str(), __var)
+  "<" << type_name<__op1>() << ", " \
+  << type_name<__op2>() << ">"))->str(), __var)
 
 /*! \brief Fill blob with some pattern defined by the getNextData() callback
  * Pattern fill in the defined order (important for analysis):
@@ -606,6 +609,19 @@ inline ScalarType rangedRand(const ScalarType min, const ScalarType max) {
   return static_cast<ScalarType>(x / bin_size + min);
 }
 
+inline std::string pretty_num(uint64_t val) {
+  std::string res, s = std::to_string(val);
+  size_t ctr = 0;
+  for (int i = static_cast<int>(s.size()) - 1; i >= 0; --i, ++ctr) {
+    if (ctr && (ctr % 3) == 0) {
+      res += ",";
+    }
+    res.push_back(s[i]);
+  }
+  std::reverse(res.begin(), res.end());
+  return res;
+}
+
 /*! \brief Change a value during the scope of this declaration */
 template<typename T>
 struct ScopeSet {
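
pretty_num() simply groups digits with commas so large byte counts stay readable; a quick
hypothetical check of its behaviour (include path and namespace qualification assumed):

    #include <cassert>
    #include "test_util.h"

    void pretty_num_example() {
      assert(mxnet::test::pretty_num(1234567ULL) == "1,234,567");
      assert(mxnet::test::pretty_num(999ULL) == "999");
    }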
diff --git a/tests/cpp/misc/memory_test.cc b/tests/cpp/misc/memory_test.cc
index e39399d01a..a36f7f93ae 100644
--- a/tests/cpp/misc/memory_test.cc
+++ b/tests/cpp/misc/memory_test.cc
@@ -42,19 +42,6 @@ static typename Container::value_type average(const Container& cont) {
   return avg;
 }
 
-static std::string pretty_num(uint64_t val) {
-  std::string res, s = std::to_string(val);
-  size_t ctr = 0;
-  for (int i = static_cast<int>(s.size()) - 1; i >= 0; --i, ++ctr) {
-    if (ctr && (ctr % 3) == 0) {
-      res += ",";
-    }
-    res.push_back(s[i]);
-  }
-  std::reverse(res.begin(), res.end());
-  return res;
-}
-
 static int GetOMPThreadCount() {
   return omp_get_max_threads() >> 1;
 }
@@ -75,7 +62,7 @@ TEST(MEMORY_TEST, MemsetAndMemcopyPerformance) {
 
     const size_t test_size = 2 * base;
     std::cout << "====================================" << std::endl
-              << "Data size: " << pretty_num(test_size) << std::endl << std::flush;
+              << "Data size: " << test::pretty_num(test_size) << std::endl << std::flush;
 
     std::unique_ptr<uint8_t> buffer_1(new uint8_t[test_size]), buffer_2(new uint8_t[test_size]);
     uint8_t *src = buffer_1.get(), *dest = buffer_2.get();
@@ -117,19 +104,21 @@ TEST(MEMORY_TEST, MemsetAndMemcopyPerformance) {
       memcpy_times.push_back(memcpy_time);
       omp_copy_times.push_back(omp_copy_time);
 
-      std::cout << "memset time:   " << pretty_num(memcpy_time) << " ns" << std::endl
-                << "omp set time:  " << pretty_num(omp_set_time) << " ns" << std::endl
+      std::cout << "memset time:   " << test::pretty_num(memcpy_time) << " ns" << std::endl
+                << "omp set time:  " << test::pretty_num(omp_set_time) << " ns" << std::endl
                 << std::endl;
-      std::cout << "memcpy time:   " << pretty_num(memcpy_time) << " ns" << std::endl
-                << "omp copy time: " << pretty_num(omp_copy_time) << " ns" << std::endl
+      std::cout << "memcpy time:   " << test::pretty_num(memcpy_time) << " ns" << std::endl
+                << "omp copy time: " << test::pretty_num(omp_copy_time) << " ns" << std::endl
                 << std::endl;
     }
     std::cout << "------------------------------------" << std::endl;
     if (average(memset_times) > average(omp_set_times)) {
-      std::cout << "<< MEMSET SLOWER FOR " << pretty_num(test_size) << " items >>" << std::endl;
+      std::cout << "<< MEMSET SLOWER FOR " << test::pretty_num(test_size)
+                << " items >>" << std::endl;
     }
     if (average(memcpy_times) > average(omp_copy_times)) {
-      std::cout << "<< MEMCPY SLOWER FOR " << pretty_num(test_size) << " items >>" << std::endl;
+      std::cout << "<< MEMCPY SLOWER FOR " << test::pretty_num(test_size)
+                << " items >>" << std::endl;
     }
     if (!pass) {
       GTEST_ASSERT_LE(average(memset_times), average(omp_set_times));
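
The benchmark contrasts libc memset/memcpy with OMP-parallel loops; the OMP side of that
comparison looks roughly like the following sketch (not the test's exact code):

    #include <omp.h>
    #include <cstddef>
    #include <cstdint>

    void omp_set_bytes(uint8_t *dst, uint8_t value, size_t n, int nthr) {
    #pragma omp parallel for num_threads(nthr)
      for (int64_t i = 0; i < static_cast<int64_t>(n); ++i) {
        dst[i] = value;
      }
    }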
diff --git a/tests/cpp/unittest.mk b/tests/cpp/unittest.mk
index 11ea6d141a..030b24026e 100644
--- a/tests/cpp/unittest.mk
+++ b/tests/cpp/unittest.mk
@@ -1,6 +1,6 @@
 TEST_SRC = $(shell find tests/cpp/ -name "*.cc")
 TEST_OBJ = $(patsubst %.cc, build/%.o, $(TEST_SRC))
-TEST = build/tests/cpp/mxnet_test
+TEST = build/tests/cpp/mxnet_unit_tests
 
 GTEST_LIB=$(GTEST_PATH)/lib/
 GTEST_INC=$(GTEST_PATH)/include/
diff --git a/tests/python/unittest/test_sparse_operator.py b/tests/python/unittest/test_sparse_operator.py
index d44055363b..2f0ebd812a 100644
--- a/tests/python/unittest/test_sparse_operator.py
+++ b/tests/python/unittest/test_sparse_operator.py
@@ -727,6 +727,19 @@ def check_binary_op_with_scalar(stype,
                                        force_overlap=force_overlap,
                                        verbose=False)
 
+        # minus_scalar
+        check_sparse_mathematical_core("minus_scalar", stype,
+                                       lambda x, y: x - y,
+                                       lambda x, y: x - y,
+                                       lambda input, rhs: 1,
+                                       rhs_arg=5.0,
+                                       data_init=2, grad_init=3,
+                                       output_grad_stype=output_grad_stype,
+                                       input_grad_stype=input_grad_stype,
+                                       density=density, ograd_density=ograd_density,
+                                       force_overlap=force_overlap,
+                                       verbose=False)
+
     # Check many basic unary operators
     def check_mathematical_core(stype, output_grad_stype=None,
                                 input_grad_stype=None, force_overlap=False,
@@ -1677,4 +1690,3 @@ def check_scatter_ops(name, shape, lhs_stype, rhs_stype, forward_mxnet_call, for
 if __name__ == '__main__':
     import nose
     nose.runmodule()
-


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services