Posted to commits@mxnet.apache.org by cj...@apache.org on 2017/11/21 14:50:00 UTC

[incubator-mxnet] branch master updated: Kernel operator tuning (#8686)

This is an automated email from the ASF dual-hosted git repository.

cjolivier01 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 068b589  Kernel operator tuning (#8686)
068b589 is described below

commit 068b589ad77afc125e9242e516dec1fde1e3f133
Author: Chris Olivier <cj...@gmail.com>
AuthorDate: Tue Nov 21 06:49:51 2017 -0800

    Kernel operator tuning (#8686)
    
    * Refreshed branch bc_tune
    
    * local-build openmp as static
    
    * trigger
    
    * Somehow broadcast found its way back in, removed again
    
    * Trigger rebuild
---
 CMakeLists.txt                                     |  24 +-
 Makefile                                           |   4 +
 make/config.mk                                     |   6 +
 src/operator/mshadow_op.h                          |  94 ++-
 src/operator/mxnet_op.h                            | 118 +++-
 src/operator/operator_tune-inl.h                   | 758 +++++++++++++++++++++
 src/operator/operator_tune.cc                      | 347 ++++++++++
 src/operator/operator_tune.h                       | 331 +++++++++
 src/operator/tensor/elemwise_binary_broadcast_op.h |  14 +-
 src/operator/tensor/init_op.h                      |   1 +
 tests/cpp/include/test_core_op.h                   |  12 +-
 tests/cpp/include/test_op_runner.h                 |  36 +-
 tests/cpp/include/test_tune.h                      | 333 +++++++++
 tests/cpp/include/test_util.h                      |  21 +-
 tests/cpp/operator/broadcast_perf.cc               | 101 ---
 tests/cpp/operator/tune/operator_tune_test.cc      | 173 +++++
 tests/cpp/test_main.cc                             |   9 +-
 17 files changed, 2212 insertions(+), 170 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index af681d0..2b7aba9 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -35,6 +35,7 @@ mxnet_option(USE_LAPACK           "Build with lapack support" ON IF NOT MSVC)
 mxnet_option(USE_MKL_IF_AVAILABLE "Use MKL if found" ON)
 mxnet_option(USE_MKLML_MKL        "Use MKLML variant of MKL (if MKL found)" ON IF USE_MKL_IF_AVAILABLE AND UNIX AND (NOT APPLE))
 mxnet_option(USE_MKL_EXPERIMENTAL "Use experimental MKL (if MKL enabled and found)" OFF)
+mxnet_option(USE_OPERATOR_TUNING  "Enable auto-tuning of operators" ON AND NOT MSVC)
 mxnet_option(USE_GPERFTOOLS       "Build with GPerfTools support (if found)" ON)
 mxnet_option(USE_JEMALLOC         "Build with Jemalloc support"   ON)
 mxnet_option(USE_PROFILER         "Build with Profiler support"   OFF)
@@ -143,6 +144,8 @@ if(USE_MKL_IF_AVAILABLE)
     if(NOT MSVC)
       list(APPEND mxnet_LINKER_LIBS dl)
     endif()
+    # If using MKL, use the Intel OMP libraries
+    list(APPEND mxnet_LINKER_LIBS iomp5)
     if(USE_MKL_EXPERIMENTAL)
       add_definitions(-DMKL_EXPERIMENTAL=1)
     else()
@@ -260,11 +263,22 @@ endif()
 # ---[ OpenMP
 if(USE_OPENMP)
   find_package(OpenMP REQUIRED)
-  if(OPENMP_FOUND)
+  if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/openmp/CMakeLists.txt)
+    # Intel/llvm OpenMP: https://github.com/llvm-mirror/openmp
+    set(OPENMP_STANDALONE_BUILD TRUE)
+    set(LIBOMP_ENABLE_SHARED FALSE)
+    add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/openmp)
+    list(REMOVE_ITEM mxnet_LINKER_LIBS iomp5)
+    list(APPEND mxnet_LINKER_LIBS omp)
     set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}")
     set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
-    set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} ${OpenMP_EXE_LINKER_FLAGS}")
-    set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} ${OpenMP_EXE_LINKER_FLAGS}")
+  else()
+    if(OPENMP_FOUND)
+      set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}")
+      set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
+      set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} ${OpenMP_EXE_LINKER_FLAGS}")
+      set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} ${OpenMP_EXE_LINKER_FLAGS}")
+    endif()
   endif()
 elseif(UNIX)
   list(APPEND mxnet_LINKER_LIBS pthread)
@@ -353,6 +367,10 @@ if(USE_PLUGINS_WARPCTC)
 	list(APPEND CUDA ${PLUGINS_CUSRC})
 endif()
 
+if(USE_OPERATOR_TUNING)
+  add_definitions(-DMXNET_USE_OPERATOR_TUNING=1)
+endif()
+
 if(USE_PLUGIN_CAFFE)
   if(NOT USE_CUDA)
     set(CPU_ONLY ON)
diff --git a/Makefile b/Makefile
index 8c7ae6e..8659482 100644
--- a/Makefile
+++ b/Makefile
@@ -131,6 +131,10 @@ ifeq ($(USE_MKL2017), 1)
 	LDFLAGS +=  -liomp5
 endif
 
+ifeq ($(USE_OPERATOR_TUNING), 1)
+	CFLAGS += -DMXNET_USE_OPERATOR_TUNING=1
+endif
+
 # verify existence of separate lapack library when using blas/openblas/atlas
 # switch off lapack support in case it can't be found
 # issue covered with this
diff --git a/make/config.mk b/make/config.mk
index a4774f0..eeda36b 100644
--- a/make/config.mk
+++ b/make/config.mk
@@ -153,6 +153,12 @@ LIBJVM=$(JAVA_HOME)/jre/lib/amd64/server
 # sudo apt-get install -y libcurl4-openssl-dev
 USE_S3 = 0
 
+#----------------------------
+# performance settings
+#----------------------------
+# Use operator tuning
+USE_OPERATOR_TUNING = 1
+
 # Use gperftools if found
 USE_GPERFTOOLS = 1
 
diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h
index a34c117..10be627 100644
--- a/src/operator/mshadow_op.h
+++ b/src/operator/mshadow_op.h
@@ -30,6 +30,7 @@
 #include "math.h"
 #include "math_functions-inl.h"
 #include "special_functions-inl.h"
+#include "./mxnet_op.h"
 
 #ifdef __CUDACC__
 #include <cuda_fp16.h>
@@ -39,6 +40,24 @@ namespace mxnet {
 namespace op {
 namespace mshadow_op {
 
+/*!
+ * \brief Use the 'MXNET_TUNABLE_MSHADOW_OP_FWD_AND_BWD' macro outside of the mshadow_op namespace
+ *        See mxnet_op.h for a description of 'MXNET_TUNABLE_MSHADOW_OP_FWD_AND_BWD'
+ *
+ * \note An entry for the operator must also be added in operator_tune.cc, which will register it
+ *       for auto-tuning and also hold its workload weight
+ */
+#define MSHADOW_OP_DECLARE_TUNABLE_FWD_AND_BWD(__op$) \
+  } MXNET_TUNABLE_MSHADOW_OP_FWD_AND_BWD(mshadow_op::__op$) namespace mshadow_op {  // NOLINT(*)
+/*!
+ * \brief Use the 'MXNET_TUNABLE_MSHADOW_OP_BACKWARD' macro outside of the mshadow_op namespace
+ *        See mxnet_op.h for a description of 'MXNET_TUNABLE_MSHADOW_OP_BACKWARD'
+ *
+ * \note An entry for the operator must also be added in operator_tune.cc, which will register it
+ *       for auto-tuning and also hold its workload weight
+ */
+#define MSHADOW_OP_DECLARE_TUNABLE_BACKWARD(__op$) \
+  }  MXNET_TUNABLE_MSHADOW_OP_BACKWARD(mshadow_op::__op$) namespace mshadow_op {  // NOLINT(*)
 #ifdef __CUDA_ARCH__
 __constant__ const float PI = 3.14159265358979323846;
 #else
@@ -49,36 +68,41 @@ using std::enable_if;
 using std::is_unsigned;
 
 #define MXNET_UNARY_MATH_OP(name, expr) \
-struct name { \
-  template<typename DType> \
-  MSHADOW_XINLINE static DType Map(DType a) { \
-    return DType(expr); \
-  } \
-}
+  struct name { \
+    template<typename DType> \
+    MSHADOW_XINLINE static DType Map(DType a) { \
+      return DType(expr); \
+    } \
+  }; \
+  MSHADOW_OP_DECLARE_TUNABLE_FWD_AND_BWD(name)
+
 
 #define MXNET_UNARY_MATH_OP_NC(name, expr) \
-struct name { \
-  template<typename DType> \
-  MSHADOW_XINLINE static DType Map(DType a) { \
-    return (expr); \
-  } \
-}
+  struct name { \
+    template<typename DType> \
+    MSHADOW_XINLINE static DType Map(DType a) { \
+      return (expr); \
+    } \
+  }; \
+  MSHADOW_OP_DECLARE_TUNABLE_FWD_AND_BWD(name)
 
 #define MXNET_BINARY_MATH_OP(name, expr) \
-struct name { \
-  template<typename DType> \
-  MSHADOW_XINLINE static DType Map(DType a, DType b) { \
-    return DType(expr); \
-  } \
-}
+  struct name { \
+    template<typename DType> \
+    MSHADOW_XINLINE static DType Map(DType a, DType b) { \
+      return DType(expr); \
+    } \
+  }; \
+  MSHADOW_OP_DECLARE_TUNABLE_FWD_AND_BWD(name)
 
 #define MXNET_BINARY_MATH_OP_NC(name, expr) \
-struct name { \
-  template<typename DType> \
-  MSHADOW_XINLINE static DType Map(DType a, DType b) { \
-    return (expr); \
-  } \
-}
+  struct name { \
+    template<typename DType> \
+    MSHADOW_XINLINE static DType Map(DType a, DType b) { \
+      return (expr); \
+    } \
+  }; \
+  MSHADOW_OP_DECLARE_TUNABLE_FWD_AND_BWD(name)
 
 #define MXNET_SIMPLE_UNARY_MATH_OP(name) MXNET_UNARY_MATH_OP(name, math::name(a))
 
@@ -134,6 +158,7 @@ struct softrelu {
     }
   }
 };
+MSHADOW_OP_DECLARE_TUNABLE_FWD_AND_BWD(softrelu)
 
 MXNET_UNARY_MATH_OP(softrelu_grad, -math::expm1(-a));
 
@@ -154,6 +179,7 @@ struct log10_grad {
     return DType(0.4342944819f / static_cast<float>(a));
   }
 };
+MSHADOW_OP_DECLARE_TUNABLE_BACKWARD(log10_grad)
 
 template<>
 MSHADOW_XINLINE double log10_grad::Map<double>(double a) {
@@ -169,6 +195,7 @@ struct log2_grad {
     return DType(1.442695041f / static_cast<float>(a));
   }
 };
+MSHADOW_OP_DECLARE_TUNABLE_BACKWARD(log2_grad)
 
 template<>
 MSHADOW_XINLINE double log2_grad::Map<double>(double a) {
@@ -263,6 +290,7 @@ struct sign {
     return DType(0);
   }
 };
+MSHADOW_OP_DECLARE_TUNABLE_FWD_AND_BWD(sign)
 
 MXNET_UNARY_MATH_OP_NC(sign_grad, DType(0));
 
@@ -333,6 +361,7 @@ struct rint {
     return DType((af - floor) <= (ceil - af) ? floor : ceil);
   }
 };
+MSHADOW_OP_DECLARE_TUNABLE_FWD_AND_BWD(rint)
 
 /*! \brief used to round number to integer nearest to 0 */
 struct fix {
@@ -343,6 +372,7 @@ struct fix {
     return DType((floor > 0 ? floor : -floor) < (ceil > 0 ? ceil : -ceil) ? floor : ceil);
   }
 };
+MSHADOW_OP_DECLARE_TUNABLE_FWD_AND_BWD(fix)
 
 /*! \brief used for generate gradient of MAE loss*/
 MXNET_BINARY_MATH_OP_NC(minus_sign, a - b > DType(0) ? DType(1) : -DType(1));
@@ -405,6 +435,7 @@ struct mod {
     }
   }
 };
+MSHADOW_OP_DECLARE_TUNABLE_FWD_AND_BWD(mod)
 
 template<>
 MSHADOW_XINLINE mshadow::half::half2_t mod::Map<mshadow::half::half2_t>
@@ -419,6 +450,8 @@ struct mod_grad {
     return DType(0);
   }
 };
+MSHADOW_OP_DECLARE_TUNABLE_BACKWARD(mod_grad)
+
 template<>
 MSHADOW_XINLINE double mod_grad::Map<double>(double a, double b) {
   return 1.0;
@@ -454,6 +487,8 @@ struct mod_rgrad {
     return DType(0);
   }
 };
+MSHADOW_OP_DECLARE_TUNABLE_BACKWARD(mod_rgrad)
+
 template<>
 MSHADOW_XINLINE double mod_rgrad::Map<double>(double a, double b) {
   return -::floor(a/b);
@@ -517,6 +552,7 @@ struct rmod {
     }
   }
 };
+MSHADOW_OP_DECLARE_TUNABLE_FWD_AND_BWD(rmod)
 
 template<>
 MSHADOW_XINLINE mshadow::half::half2_t rmod::Map<mshadow::half::half2_t>
@@ -531,6 +567,8 @@ struct rmod_grad {
     return DType(0);
   }
 };
+MSHADOW_OP_DECLARE_TUNABLE_BACKWARD(rmod_grad)
+
 template<>
 MSHADOW_XINLINE double rmod_grad::Map<double>(double a, double b) {
   return -::floor(b/a);
@@ -572,6 +610,7 @@ struct clip {
     }
   }
 };
+MSHADOW_OP_DECLARE_TUNABLE_FWD_AND_BWD(clip)
 
 /***** gamma ******/
 
@@ -585,6 +624,7 @@ struct gamma_grad {
     return DType(math::tgamma(af) * special_functions::cephes::psi<float>(af));
   }
 };
+MSHADOW_OP_DECLARE_TUNABLE_BACKWARD(gamma_grad)
 
 template<>
 MSHADOW_XINLINE double gamma_grad::Map<double>(double a) {
@@ -602,6 +642,7 @@ struct gammaln_grad {
     return DType(special_functions::cephes::psi<float>(a));
   }
 };
+MSHADOW_OP_DECLARE_TUNABLE_BACKWARD(gammaln_grad)
 
 template<>
 MSHADOW_XINLINE double gammaln_grad::Map<double>(double a) {
@@ -633,6 +674,7 @@ struct smooth_l1_loss {
     }
   }
 };  // struct smooth_l1_loss
+MSHADOW_OP_DECLARE_TUNABLE_FWD_AND_BWD(smooth_l1_loss)
 
 /* The derivative of smooth l1 loss is
  * f'(x) = sigma^2 * x, |x| < 1 / sigma^2
@@ -654,6 +696,7 @@ struct smooth_l1_gradient {
     }
   }
 };  // struct smooth_l1_derivative
+MSHADOW_OP_DECLARE_TUNABLE_BACKWARD(smooth_l1_gradient)
 
 /*! \brief product reducer */
 struct product {
@@ -755,6 +798,7 @@ struct nansum_grad {
     return isnan_typed::IsNan(a) ? DType(0) : DType(1);
   }
 };
+MSHADOW_OP_DECLARE_TUNABLE_BACKWARD(nansum_grad)
 
 /*! \brief product reducer that ignores NaN values in the input */
 struct nanprod {
@@ -791,7 +835,7 @@ struct nanprod_grad {
     return isnan_typed::IsNan(a) ? DType(0) : b / a;
   }
 };
-
+MSHADOW_OP_DECLARE_TUNABLE_BACKWARD(nanprod_grad)
 }  // namespace mshadow_op
 }  // namespace op
 }  // namespace mxnet
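
The tunable-op plumbing above works by briefly closing the mshadow_op namespace, emitting the MXNET_TUNABLE_MSHADOW_OP_* specialization at mxnet::op scope, and then reopening the namespace. As a rough sketch (not part of the patch), a hypothetical operator declared with the updated macro, say MXNET_UNARY_MATH_OP(my_square, a * a), would now expand to approximately:

    // Hypothetical example (not in this patch): inside namespace mxnet::op::mshadow_op,
    // MXNET_UNARY_MATH_OP(my_square, a * a) now expands to roughly:
    struct my_square {
      template<typename DType>
      MSHADOW_XINLINE static DType Map(DType a) {
        return DType(a * a);
      }
    };
    }  // temporarily leave mshadow_op so the tuning macro lands in mxnet::op scope
    // Routes Kernel<..., cpu>::Launch for this op (and its backward form) through the
    // tuned launch path described in mxnet_op.h.
    MXNET_TUNABLE_MSHADOW_OP_FWD_AND_BWD(mshadow_op::my_square)
    namespace mshadow_op {  // reopen the namespace for the next declaration

As the notes in mshadow_op.h say, a matching IMPLEMENT_UNARY_WORKLOAD_FWD/_BWD entry must also be added in operator_tune.cc so the operator is registered for auto-tuning and gets a workload weight.
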
diff --git a/src/operator/mxnet_op.h b/src/operator/mxnet_op.h
index c34d9c9..1d47943 100644
--- a/src/operator/mxnet_op.h
+++ b/src/operator/mxnet_op.h
@@ -31,6 +31,7 @@
 #include <mxnet/engine.h>
 #include <mxnet/op_attr_types.h>
 #include <algorithm>
+#include "./operator_tune.h"
 #include "../engine/openmp.h"
 
 #ifdef __CUDACC__
@@ -190,8 +191,9 @@ template<int ndim>
 MSHADOW_XINLINE int dot(const Shape<ndim>& coord, const Shape<ndim>& stride) {
   int ret = 0;
   #pragma unroll
-  for (int i = 0; i < ndim; ++i)
+  for (int i = 0; i < ndim; ++i) {
     ret += coord[i] * stride[i];
+  }
   return ret;
 }
 
@@ -346,15 +348,26 @@ struct op_with_req {
 template<typename OP, typename xpu>
 struct Kernel;
 
+/*!
+ * \brief CPU Kernel launcher
+ * \tparam OP Operator to launch
+ */
 template<typename OP>
 struct Kernel<OP, cpu> {
-  /*! \brief Launch CPU kernel */
+  /*!
+   * \brief Launch a generic CPU kernel.
+   * When using this for a new kernel op, add declaration and tuning objects to
+   * operator_tune.cc
+   * \tparam Args Varargs type to eventually pass to the OP::Map() function
+   * \param N Number of iterations
+   * \param dest Destination pointer (used to infer DType)
+   * \param args Varargs to eventually pass to the OP::Map() function
+   */
   template<typename ...Args>
   inline static void Launch(mshadow::Stream<cpu> *, const int N, Args... args) {
 #ifdef _OPENMP
     const int omp_threads = engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
     if (omp_threads < 2) {
-      // Zero means not to use OMP, but don't interfere with external OMP behavior
       for (int i = 0; i < N; ++i) {
         OP::Map(i, args...);
       }
@@ -371,14 +384,54 @@ struct Kernel<OP, cpu> {
 #endif
   }
 
+  /*!
+   * \brief Launch CPU kernel which has OMP tuning data available.
+   * When using this for a new kernel op, add declaration and tuning objects to
+   * operator_tune.cc
+   * \tparam PRIMITIVE_OP The primitive operation to use for tuning
+   * \tparam DType Data type
+   * \tparam Args Varargs type to eventually pass to the OP::Map() function
+   * \param N Number of iterations
+   * \param dest Destination pointer (used to infer DType)
+   * \param args Varargs to eventually pass to the OP::Map() function
+   */
+  template<typename PRIMITIVE_OP, typename DType, typename ...Args>
+  static void LaunchTuned(mshadow::Stream<cpu> *, const int N, Args... args) {
+#ifdef _OPENMP
+    const int omp_threads = engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
+    if (omp_threads < 2 || !tuned_op<PRIMITIVE_OP, DType>::UseOMP(
+      static_cast<size_t>(N), static_cast<size_t>(omp_threads))) {
+      for (int i = 0; i < N; ++i) {
+        OP::Map(i, args...);
+      }
+    } else {
+      #pragma omp parallel for num_threads(omp_threads)
+      for (int i = 0; i < N; ++i) {
+        OP::Map(i, args...);
+      }
+    }
+#else
+    for (int i = 0; i < N; ++i) {
+      OP::Map(i, args...);
+    }
+#endif
+  }
+
+  /*!
+   * \brief Launch custom-tuned kernel where each thread is set to
+   *        operate on a contiguous partition
+   * \tparam Args Varargs type to eventually pass to the OP::Map() function
+   * \param N Number of iterations
+   * \param args Varargs to eventually pass to the UseOMP() and OP::Map() functions
+   */
   template<typename ...Args>
   inline static void LaunchEx(mshadow::Stream<cpu> *s, const int N, Args... args) {
 #ifdef _OPENMP
     const int omp_threads = engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
-    if (omp_threads <= 1) {
+    if (omp_threads < 2) {
       OP::Map(0, N, args...);
     } else {
-      int length = (N + omp_threads - 1) / omp_threads;
+      const int length = (N + omp_threads - 1) / omp_threads;
       #pragma omp parallel for num_threads(omp_threads)
       for (int i = 0; i < N; i += length) {
         OP::Map(i, i + length > N ? N - i : length, args...);
@@ -418,7 +471,7 @@ struct Kernel<OP, gpu> {
   }
 
   template<typename ...Args>
-  inline static void LaunchEx(mshadow::Stream<gpu> *s, int N, Args... args) {
+  inline static void LaunchEx(mshadow::Stream<gpu> *s, const int N, Args... args) {
     using namespace mshadow::cuda;
     int ngrid = std::min(kMaxGridNum, (N + kBaseThreadNum - 1) / kBaseThreadNum);
     mxnet_generic_kernel_ex<OP, Args...>
@@ -429,6 +482,43 @@ struct Kernel<OP, gpu> {
 #endif  // __CUDACC__
 
 /*!
+ * \brief Wrap Kernel<OP, xpu>::Launch* with some special-case helpers
+ */
+template<typename OP, typename xpu>
+struct KernelWrapper {
+  /*!
+   * \brief Launch 'mshadow_op-type' op (i.e. DType (*)( ... ) { return <operation> }
+   * \tparam Args Varargs type to eventually pass to the OP::Map() function
+   * \param s Stream object pointer (unused)
+   * \param N Number of iterations
+   * \param args Varargs to eventually pass to the OP::Map() function
+   */
+  template<typename DType, typename ...Args>
+  MSHADOW_CINLINE static void LaunchMShadowOpEx(mshadow::Stream<xpu> *s,
+                                                const int N,
+                                                DType *dest,
+                                                Args... args) {
+    mxnet::op::mxnet_op::Kernel<OP, xpu>::template LaunchTuned<
+      typename OP::Operation, DType>(s, N, dest, args...);
+  }
+
+  /*!
+   * \brief Launch 'mxnet_op-type' op (i.e. void (*)(int N, DType *out, ... )
+   * \tparam Args Varargs type to eventually pass to the OP::Map() function
+   * \param s Stream object pointer (unused)
+   * \param N Number of iterations
+   * \param args Varargs to eventually pass to the OP::Map() function
+   */
+  template<typename DType, typename ...Args>
+  MSHADOW_CINLINE static void LaunchMXNetOpEx(mshadow::Stream<xpu> *s,
+                                              const int N,
+                                              DType *dest,
+                                              Args... args) {
+    mxnet::op::mxnet_op::Kernel<OP, xpu>::template LaunchTuned<OP, DType>(s, N, dest, args...);
+  }
+};
+
+/*!
  * \brief Set to immediate scalar value kernel
  * \tparam val Scalar immediate
  */
@@ -450,7 +540,23 @@ struct set_to_int {
  */
 using set_zero = set_to_int<0>;
 using set_one  = set_to_int<1>;
+_MXNET_TUNABLE_MXNET_OP_FWD(set_zero);  // _ prefix denotes "already in mxnet_op namespace"
+_MXNET_TUNABLE_MXNET_OP_FWD(set_one);
 }  // namespace mxnet_op
+
+/*!
+ * \brief Tuning specializations for the simple ops in <mshadow/base.h>
+ *        Basically, this overrides mxnet::op::mxnet_op::Kernel<OP, cpu>::Launch() and
+ *        redirects to mxnet::op::mxnet_op::KernelWrapper<OP, cpu>::Launch????OpEx(),
+ *        which eventually leads back to mxnet::op::mxnet_op::Kernel<OP, cpu>::LaunchTuned()
+ */
+MXNET_TUNABLE_MSHADOW_OP_FWD_AND_BWD(mshadow::op::identity)
+MXNET_TUNABLE_MSHADOW_OP_FWD_AND_BWD(mshadow::op::plus)
+MXNET_TUNABLE_MSHADOW_OP_FWD_AND_BWD(mshadow::op::minus)
+MXNET_TUNABLE_MSHADOW_OP_FWD_AND_BWD(mshadow::op::mul)
+MXNET_TUNABLE_MSHADOW_OP_FWD_AND_BWD(mshadow::op::div)
+MXNET_TUNABLE_MSHADOW_OP_FWD_AND_BWD(mshadow::op::right)
+
 }  // namespace op
 }  // namespace mxnet
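
From a kernel author's point of view the launch API is unchanged: Kernel<OP, cpu>::Launch() is still the entry point, and for operators that were registered as tunable the specialized path ends up in LaunchTuned(), which asks tuned_op<PRIMITIVE_OP, DType>::UseOMP() whether the parallel loop is worth the OMP fork/join overhead. A minimal usage sketch follows; the fill_half kernel and RunFillHalf helper are hypothetical, assuming mxnet_op.h is included:

    #include "mxnet_op.h"  // assumed include path for the header modified above

    namespace example {

    // Hypothetical element-wise kernel in the mxnet_op style:
    // Map(i, out, args...) computes one output element.
    struct fill_half {
      template<typename DType>
      MSHADOW_XINLINE static void Map(int i, DType* out, const DType* in) {
        out[i] = in[i] / DType(2);
      }
    };

    template<typename DType>
    void RunFillHalf(mshadow::Stream<mshadow::cpu>* s, DType* out, const DType* in, int n) {
      // Launch(stream, N, args...) calls fill_half::Map(i, out, in) for i in [0, N).
      // If fill_half were registered for tuning, the specialization generated by the
      // MXNET_TUNABLE_* macros would route this through LaunchTuned<PRIMITIVE_OP, DType>(),
      // which checks tuned_op<...>::UseOMP(N, omp_threads) before choosing between the
      // serial loop and the '#pragma omp parallel for' loop.
      mxnet::op::mxnet_op::Kernel<fill_half, mshadow::cpu>::Launch(s, n, out, in);
    }

    }  // namespace example
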
 
diff --git a/src/operator/operator_tune-inl.h b/src/operator/operator_tune-inl.h
new file mode 100644
index 0000000..d0cf7e7
--- /dev/null
+++ b/src/operator/operator_tune-inl.h
@@ -0,0 +1,758 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef MXNET_OPERATOR_OPERATOR_TUNE_INL_H_
+#define MXNET_OPERATOR_OPERATOR_TUNE_INL_H_
+
+#include <dmlc/base.h>
+#include <dmlc/logging.h>
+#include <mshadow/base.h>
+#include <atomic>
+#include <cstdint>
+#include <chrono>
+#include <thread>
+#include <string>
+#include <vector>
+#include <algorithm>
+#include <list>
+#include <random>
+#include <unordered_set>
+#include "./mxnet_op.h"
+#include "./operator_tune.h"
+
+#if (__GNUC__ >= 4 || (__GNUC__ >= 3 && __GNUC_MINOR__ >= 4)) && !defined(__mips__)
+#  define HAS_CXA_DEMANGLE 1
+#else
+#  define HAS_CXA_DEMANGLE 0
+#endif
+
+#if HAS_CXA_DEMANGLE
+#include <cxxabi.h>
+#endif
+
+namespace mxnet {
+namespace op {
+
+#ifndef MXNET_NO_INLINE
+#ifdef _MSC_VER
+#define MXNET_NO_INLINE __declspec(noinline)
+#else
+#define MXNET_NO_INLINE __attribute__((noinline))
+#endif
+#endif  // MXNET_NO_INLINE
+
+#define OUTSIDE_COUNT_SHIFT    9
+
+namespace tune {
+
+/*!
+ * \brief Convert TuningMode value to a string representation
+ * \param tm  Scalar TuningMode value
+ * \return Character pointer to a string representing the TuningMode value
+ */
+inline const char *TuningModeToString(const TuningMode tm) {
+  switch (tm) {
+    case kAuto:
+      return "Auto";
+    case kNeverOMP:
+      return "NeverOMP";
+    case kAlwaysOMP:
+      return "AlwaysOMP";
+    default:
+      CHECK(false) << "Unknown TuningMode type: " << static_cast<int>(tm);
+      return "<unknown>";
+  }
+}
+}  // namespace tune
+
+/*!
+ * \brief Engine to tune kernel operations
+ * \tparam DType Data type to be used when tuning the kernel operations
+ * \remarks The basic concept here is that we time how long a trivial loop takes with and without
+ * OMP, subtracting the non-OMP run from the OMP run, which gives us the time
+ * that the OMP overhead takes.  Times were found to be relatively invariant with
+ * regard to the number of threads/cores on a given machine.
+ * Secondly, supplied operators are run and timed (for each data type) in order to determine
+ * their individual time cost.
+ *
+ * Knowing the following items, we can determine how long the OMP and non-OMP run
+ * is expected to take:
+ *  1) OMP overhead time
+ *  2) Number of iterations required
+ *  3) Number of threads to be used if we choose the OMP method
+ *  4) The data type
+ *
+ * Therefore, at Kernel::Launch() time, we can estimate whether it is faster to use OMP or not
+ * for the given kernel operator.
+ *
+ * Results and efficiency of the tuning is tested in the gtest OMP_TUNING test suite
+ */
+template<typename DType>
+class OperatorTune : public OperatorTuneByType<DType> {
+ public:
+  using Tick = OperatorTuneBase::Tick;
+  using duration_t = OperatorTuneBase::duration_t;
+  using OperatorTuneByType<DType>::tuning_mode_;
+
+  /*!
+   * \brief Constructor
+   */
+  OperatorTune() {
+    TuneAll();
+  }
+
+  /*!
+   * \brief Initialize the OperatorTune object
+   * \return Whether the OperatorTune object was successfully initialized
+   */
+  static bool Initialize() {
+    if (!initialized_) {
+      initialized_ = true;
+      // Generate some random data for calling the operator kernels
+      data_set_.reserve(0x100);
+      std::random_device rd;
+      std::mt19937 gen(rd());
+      if (!std::is_integral<DType>::value) {
+        std::uniform_real_distribution<> dis(-1, 1);
+        for (int n = 0; n < 0x100; ++n) {
+          const auto val = static_cast<DType>(dis(gen));
+          // If too close to zero, try again
+          if (std::fabs(static_cast<double>(val)) < 1e-5) {
+            --n;
+            continue;
+          }
+          data_set_.emplace_back(val);
+        }
+      } else {
+        std::uniform_int_distribution<> dis(-128, 127);
+        for (int n = 0; n < 0x100; ++n) {
+          const auto val = static_cast<DType>(dis(gen));
+          // If zero, try again
+          if (!val) {
+            --n;
+            continue;
+          }
+          data_set_.emplace_back(val);
+        }
+      }
+      // Use this environment variable to generate new tuning statistics
+      // In order to avoid printing too many copies, only the float32 object prints
+      output_tuning_data_ = mshadow::DataType<DType>::kFlag == mshadow::kFloat32
+                            && dmlc::GetEnv("MXNET_OUTPUT_TUNING_DATA", false);
+      // If outputting tuning data, then also output verbose logging info
+      OperatorTuneBase::verbose_tuning_info_ = dmlc::GetEnv("MXNET_VERBOSE_TUNING_INFO", false);
+
+      OperatorTuneBase::tuning_weight_scale_ = dmlc::GetEnv("MXNET_TUNING_WEIGHT_SCALE", 0.0);
+
+      // This isn't actually supposed to be multithreaded init, but just to be sure the change is
+      // seen everywhere, using atomic bool.
+      if (!OperatorTuneBase::calculated_.load()) {
+        // Not especially concerned with a race condition, since this should
+        // run when only one thread is active (static init), just don't cache this variable
+        OperatorTuneBase::calculated_.store(true);
+        OperatorTuneBase::omp_overhead_ns_ = GetOMPLoopOverhead();
+        std::string config = dmlc::GetEnv("MXNET_USE_OPERATOR_TUNING", std::string());
+        ParseEnablerConfig(config);
+      }
+
+      if (OperatorTuneBase::verbose_tuning_info_) {
+        LOG(INFO) << "OMP overhead: " << OperatorTuneBase::omp_overhead_ns_ << " nanoseconds";
+      }
+    }
+    return true;
+  }
+
+  /*!
+   * \brief Schedule a tuning run
+   * \tparam OP Operator to tune
+   * \param tune_func Function to call which tunes the operator
+   * \return true if the tune operation was scheduled
+   */
+  template<typename OP>
+  static bool ScheduleTune(void (*tune_func)()) {
+#ifdef MXNET_USE_OPERATOR_TUNING
+    if (tune_func) {
+      GetTuningList()->push_back(tune_func);
+      operator_names_.insert(demangle(typeid(OP).name()));
+      return true;
+    }
+    return false;
+#else
+    return true;
+#endif
+  }
+
+  /*!
+   * \brief Is the template parameter type a tuned kernel?
+   * \tparam OP kernel operator type
+   * \return true if the operator/kernel is tuned
+   */
+  template<typename OP>
+  static bool IsTuned() {
+    return operator_names_.find(demangle(typeid(OP).name())) != operator_names_.end();
+  }
+
+  /*!
+   * \brief Tune all registered kernel operators that haven't already been tuned
+   */
+  static bool TuneAll() {
+    Initialize();
+    std::list<void (*)()> *tl = GetTuningList();
+    const size_t size_save = tl->size();  // For checking if anything asynchronous is
+    // adding or removing items, which is forbidden
+    if (output_tuning_data_ && !tl->empty()) {
+      // Only emit this once, use the most common case, 'float32'
+      if (mshadow::DataType<DType>::kFlag == mshadow::kFloat32) {
+        std::cout << "OperatorTuneBase::duration_t "
+                  << "OperatorTuneBase::omp_overhead_ns_ = " << OperatorTuneBase::omp_overhead_ns_
+                  << ";" << std::endl << std::flush;
+      }
+    }
+    const Tick start = std::chrono::high_resolution_clock::now();
+    for (auto i : *tl) {
+      (*i)();
+    }
+    if (OperatorTuneBase::verbose_tuning_info_) {
+      const duration_t duration = OperatorTune::GetDurationInNanoseconds(start);
+      LOG(INFO) << "Op Tuning for " << type_name<DType>()
+                << " took " << (duration / 1000000) << " ms";
+    }
+    CHECK_EQ(size_save, tl->size()) << "Tuning list size should not have changed while tuning";
+    tl->clear();
+    return true;
+  }
+
+  /*!
+   * \brief Return set of operator names that were registered to be tuned. Does not imply
+   *        that the operator has been tuned.
+   * \return Set of operator/kernel names that were registered for tuning
+   */
+  static const std::unordered_set<std::string>& TunedOperatorNames() {
+    return operator_names_;
+  }
+
+ protected:
+  /*!
+   * \brief Get the list of tuning function calls for the operators
+   * \return Pointer to list of tuning function calls
+   */
+  static std::list<void (*)()> *GetTuningList();
+
+  /*!
+   * \brief Demangle typeid::name() in order to generate source macros
+   * \param name C++ Mangled name
+   * \return Demangled name as string
+   */
+  static inline std::string demangle(const char *name) {
+#if HAS_CXA_DEMANGLE
+    int status = -4;  // some arbitrary value to eliminate the compiler warning
+    std::unique_ptr<char, void (*)(void *)> res{
+      abi::__cxa_demangle(name, nullptr, nullptr, &status),
+      &std::free
+    };
+    return status ? name : res.get();
+#else
+    return name;
+#endif
+  }
+
+  /*!
+   * \brief Type name as string
+   * \tparam T Type
+   * \return std::string representing the human-readable demangled type name
+   */
+  template<typename T> static inline std::string type_name() {
+    return demangle(typeid(T).name());
+  }
+
+  /*! \brief Measure OMP overhead for a trivial OMP loop using all cores
+   * \param omp_thread_count - Number of OMP threads to use in the timing test
+   * \returns Duration in nanoseconds for the OMP overhead (time to initiate and close the
+   *          OMP session)
+   */
+  static duration_t GetOMPLoopOverhead(const size_t omp_thread_count) {
+    CHECK_GT(omp_thread_count, 1);  // Don't try to use OMP for one thread
+    int wl_count = OperatorTuneBase::WORKLOAD_COUNT;
+
+    Tick start = std::chrono::high_resolution_clock::now();
+    // Use two loops in order to simulate OMP outside timing
+    for (size_t i = 0; i < OUTSIDE_COUNT; ++i) {
+      for (int x = 0; x < wl_count; ++x) {
+        // trivial operation
+        volatile_int_ += x;
+      }
+    }
+    const OperatorTuneBase::duration_t no_omp_duration =
+      OperatorTuneBase::GetDurationInNanoseconds(start);
+
+    // Scale OMP iterations by type calculation complexity
+    double factor;
+
+    // if tuning_weight_scale_ is a number that looks valid, use it as the factor
+    if (OperatorTuneBase::tuning_weight_scale_ > 0.01) {
+      factor = OperatorTuneBase::tuning_weight_scale_;
+    } else {
+      // These are empirically-determined constants found by balancing between
+      // a desktop (8 and 12 CPUs) and large cloud instances (32 and 64 CPUs)
+      switch (mshadow::DataType<DType>::kFlag) {
+        case mshadow::kUint8:
+        case mshadow::kInt8:
+          factor = 8.5;
+          break;
+        case mshadow::kInt32:
+          factor = 4.5;
+          break;
+        case mshadow::kInt64:
+          factor = 2;
+          break;
+        case mshadow::kFloat64:
+          factor = 1.25;
+          break;
+        case mshadow::kFloat32:
+        default:
+          factor = 1.0;
+          break;
+      }
+    }
+
+    wl_count = static_cast<int>(factor * OperatorTuneBase::WORKLOAD_COUNT * omp_thread_count);
+    start = std::chrono::high_resolution_clock::now();
+    for (size_t i = 0; i < OUTSIDE_COUNT; ++i) {
+      #pragma omp parallel for num_threads(omp_thread_count)
+      for (int x = 0; x < wl_count; ++x) {
+        // trivial operation
+        volatile_int_ += x;
+      }
+    }
+    const duration_t omp_duration = OperatorTuneBase::GetDurationInNanoseconds(start)
+                                    - no_omp_duration;
+    return omp_duration >> OUTSIDE_COUNT_SHIFT;
+  }
+
+  /*! \brief Measure OMP overhead for a trivial OMP loop using all cores
+   * \returns Time in nanoseconds to initialize/cleanup when executing an OMP block
+   */
+  static duration_t GetOMPLoopOverhead() {
+    // It was found empirically that OMP times were not heavily tied to the number of cores,
+    // so take an average across all core counts
+    const auto max_cores = static_cast<size_t>(omp_get_num_procs()) >> 1;
+    if (max_cores >= 2) {
+      std::vector<duration_t> core_times;
+      // Take care of any OMP lazy-init with a throwaway call
+      for (size_t omp_threads = 2; omp_threads <= max_cores; ++omp_threads) {
+        GetOMPLoopOverhead(omp_threads);
+      }
+      std::vector<duration_t> durations;
+      durations.reserve(max_cores - 1);
+      for (size_t omp_threads = 2; omp_threads <= max_cores; ++omp_threads) {
+        const duration_t duration = GetOMPLoopOverhead(omp_threads);
+        if (OperatorTuneBase::verbose_tuning_info_) {
+          LOG(INFO) << "OMP Thread Count: " << omp_threads << ", overhead: " << duration << " ns";
+        }
+        durations.emplace_back(duration);
+      }
+      // return median
+      std::sort(durations.begin(), durations.end());
+      return durations[durations.size() >> 1];
+    }
+    return INT_MAX;  // If only one core, then never use OMP (say the overhead is huge)
+  }
+
+  /*!
+   * \brief Some string utility functions that aren't specific to tuning
+   */
+  struct StringUtil {
+    /*!
+     * \brief Trim whitespace from the beginning and end of a string
+     * \param s String to trim
+     * \return reference to the modified string. This is the same std::string object as what was
+     *         supplied in the parameters
+     */
+    static std::string &trim(std::string *s) {
+      s->erase(s->begin(), std::find_if(s->begin(), s->end(), [](int ch) {
+        return !std::isspace(ch);
+      }));
+      s->erase(std::find_if(s->rbegin(), s->rend(), [](int ch) {
+        return !std::isspace(ch);
+      }).base(), s->end());
+      return *s;
+    }
+
+    /*!
+     * \brief Tokenize a string into a list of tokens
+     * \param s String to tokenize
+     * \return std::list of tokens
+     */
+    static std::list<std::string> string2list(const std::string &s) {
+      std::list<std::string> res;
+      std::istringstream iss(s);
+      std::string token;
+      while (std::getline(iss, token, ',')) {
+        trim(&token);
+        if (!token.empty()) {
+          res.push_back(token);
+        }
+      }
+      return std::move(res);
+    }
+  };
+
+  /*!
+   * \brief Get data type from string representation
+   * \warning Do not call from a performance-sensitive area
+   */
+  static int type_from_string(const std::string& type_string) {
+    if (type_string == "float32")
+      return mshadow::kFloat32;
+    if (type_string == "float64")
+      return mshadow::kFloat64;
+    if (type_string == "float16")
+      return mshadow::kFloat16;
+    if (type_string == "int8")
+      return mshadow::kInt8;
+    if (type_string == "uint8")
+      return mshadow::kUint8;
+    if (type_string == "int32")
+      return mshadow::kInt32;
+    if (type_string == "int64")
+      return mshadow::kInt64;
+    return -1;  // invalid
+  }
+
+  /*!
+   * \brief Parse the MXNET_USE_OPERATOR_TUNING environment variable
+   * \param config String value of the MXNET_USE_OPERATOR_TUNING environment variable
+   *        Values:
+   *            0 = disable all
+   *            1 = enable all
+   *            comma-separated type list (e.g. float32,float16) = enable only the listed types
+   */
+  static void ParseEnablerConfig(std::string config) {
+    StringUtil::trim(&config);
+    if (!config.empty()) {
+      // First disable all
+      OperatorTuneByType<float>::set_tuning_mode(tune::kAlwaysOMP);
+      OperatorTuneByType<double>::set_tuning_mode(tune::kAlwaysOMP);
+      OperatorTuneByType<int8_t>::set_tuning_mode(tune::kAlwaysOMP);
+      OperatorTuneByType<uint8_t>::set_tuning_mode(tune::kAlwaysOMP);
+      OperatorTuneByType<int32_t>::set_tuning_mode(tune::kAlwaysOMP);
+      OperatorTuneByType<int64_t>::set_tuning_mode(tune::kAlwaysOMP);
+      // See if it's a non-number (ie type or list of types)
+      if (!::isdigit(config[0])) {
+        OperatorTuneByType<mshadow::half::half_t>::set_tuning_mode(tune::kAuto);
+        std::list<std::string> tokens = StringUtil::string2list(config);
+        for (const std::string& stype : tokens) {
+          // We don't have an enum for half_t
+          const int typ = type_from_string(stype);
+          if (typ >= 0) {
+            switch (typ) {
+              case mshadow::kFloat32:
+                OperatorTuneByType<float>::set_tuning_mode(tune::kAuto);
+                break;
+              case mshadow::kFloat64:
+                OperatorTuneByType<double>::set_tuning_mode(tune::kAuto);
+                break;
+              case mshadow::kFloat16:
+                OperatorTuneByType<mshadow::half::half_t>::set_tuning_mode(tune::kAuto);
+                break;
+              case mshadow::kInt8:
+                OperatorTuneByType<int8_t>::set_tuning_mode(tune::kAuto);
+                break;
+              case mshadow::kUint8:
+                OperatorTuneByType<uint8_t>::set_tuning_mode(tune::kAuto);
+                break;
+              case mshadow::kInt32:
+                OperatorTuneByType<int32_t>::set_tuning_mode(tune::kAuto);
+                break;
+              case mshadow::kInt64:
+                OperatorTuneByType<int64_t>::set_tuning_mode(tune::kAuto);
+                break;
+              default:
+                CHECK(false) << "Unsupported tuning data type: " << stype;
+                break;
+            }
+          } else {
+            // -1 is error
+            LOG(WARNING) << "Unknown data type to be tuned: " << stype;
+          }
+        }
+      } else {
+        if (std::atoi(config.c_str()) > 0) {
+          OperatorTuneByType<float>::set_tuning_mode(tune::kAuto);
+          OperatorTuneByType<double>::set_tuning_mode(tune::kAuto);
+          OperatorTuneByType<int8_t>::set_tuning_mode(tune::kAuto);
+          OperatorTuneByType<uint8_t>::set_tuning_mode(tune::kAuto);
+          OperatorTuneByType<int32_t>::set_tuning_mode(tune::kAuto);
+          OperatorTuneByType<int64_t>::set_tuning_mode(tune::kAuto);
+          OperatorTuneByType<mshadow::half::half_t>::set_tuning_mode(tune::kAuto);
+        }
+      }
+    }
+  }
+
+  /*! \brief Whether this object has been initialized */
+  static bool initialized_;
+  /*! \brief Number of passes to obtain an average */
+  static constexpr duration_t OUTSIDE_COUNT = (1 << OUTSIDE_COUNT_SHIFT);
+  /*! \brief Random data for timing operator calls */
+  static std::vector<DType> data_set_;
+  /*! \brief Operators tuned */
+  static std::unordered_set<std::string> operator_names_;
+  /*! \brief Arbitrary object to modify in OMP loop */
+  static volatile int volatile_int_;
+  /*! \brief Output insertable (into code) instantiation+default-value macros */
+  static bool output_tuning_data_;
+};
+
+/*!
+ * \brief Class that tunes unary operators
+ * \tparam DType Data type to be used when tuning the kernel operations
+ */
+template<typename DType>
+class UnaryOpTune : public OperatorTune<DType> {
+ protected:
+  typedef OperatorTune<DType> Super;
+  using duration_t = typename Super::duration_t;
+  using Tick = typename Super::Tick;
+
+  /*!
+   * \brief Determine the time it takes a kernel operator to execute WORKLOAD_COUNT iterations
+   *        Used for kernels that take no arguments (ie set_zero)
+   * \tparam OP Kernel operator
+   * \return Duration in nanoseconds for the 'WORKLOAD_COUNT' operations
+   */
+  template<typename OP>
+  static duration_t GetBlankWorkload() {
+    DType tmp;
+    volatile DType *res = &tmp;
+    const Tick start = std::chrono::high_resolution_clock::now();
+    for (size_t i = 0; i < Super::WORKLOAD_COUNT; ++i) {
+      // Accumulate through a volatile pointer so the compiler can't optimize the calls away
+      *res += OP::Map();
+    }
+    const duration_t omp_duration = Super::GetDurationInNanoseconds(start);
+    return omp_duration ? omp_duration : 1;
+  }
+
+  /*!
+   * \brief Determine the time it takes a kernel operator to execute WORKLOAD_COUNT iterations
+   *        Used for kernels that take one argument (ie sqrt())
+   * \tparam OP Kernel operator
+   * \return Duration in nanoseconds for the 'WORKLOAD_COUNT' operations
+   */
+  template<typename OP>
+  static duration_t GetUnaryWorkload() {
+    DType tmp;
+    volatile DType *res = &tmp;
+    const Tick start = std::chrono::high_resolution_clock::now();
+    for (size_t i = 0; i < Super::WORKLOAD_COUNT; ++i) {
+      // Use a logical AND instead of mod to avoid affecting the timing result with a slow divide
+      *res = OP::Map(Super::data_set_[i & 0xFF]);
+    }
+    const duration_t omp_duration = Super::GetDurationInNanoseconds(start);
+    return omp_duration ? omp_duration : 1;
+  }
+
+  /*!
+   * \brief Determine the time it takes a kernel operator to execute WORKLOAD_COUNT iterations
+   *        Used for kernels that take two arguments (ie elemwise_add())
+   * \tparam OP Kernel operator
+   * \return Duration in nanoseconds for the 'WORKLOAD_COUNT' operations
+   */
+  template<typename OP>
+  static inline duration_t GetBinaryWorkload() {
+    DType tmp;
+    volatile DType *res = &tmp;
+    const Tick start = std::chrono::high_resolution_clock::now();
+    for (size_t i = 0; i < Super::WORKLOAD_COUNT; ++i) {
+      // Use a logical AND instead of mod to avoid affecting the timing result with a slow divide
+      *res = OP::Map(Super::data_set_[i & 0xFF], Super::data_set_[(i + 1) & 0xFF]);
+    }
+    const duration_t omp_duration = Super::GetDurationInNanoseconds(start);
+    return omp_duration ? omp_duration : 1;
+  }
+
+  /*!
+   * \brief Determine the time it takes a kernel operator to execute WORKLOAD_COUNT iterations
+   *        Used for kernels that take three arguments (ie backwards_grad<elemwise_add>())
+   * \tparam OP Kernel operator
+   * \return Duration in nanoseconds for the 'WORKLOAD_COUNT' operations
+   */
+  template<typename OP>
+  static duration_t GetTertiaryWorkload() {
+    DType tmp;
+    volatile DType *res = &tmp;
+    const Tick start = std::chrono::high_resolution_clock::now();
+    for (size_t i = 0; i < Super::WORKLOAD_COUNT; ++i) {
+      // Use a logical AND instead of mod to avoid affecting the timing result with a slow divide
+      *res = OP::Map(Super::data_set_[i & 0xFF],
+                     Super::data_set_[(i + 1) & 0xFF],
+                     Super::data_set_[i & 0xFF]);
+    }
+    const duration_t omp_duration = Super::GetDurationInNanoseconds(start);
+    return omp_duration ? omp_duration : 1;
+  }
+
+  /*!
+   * \brief Determine the time it takes a kernel operator to execute WORKLOAD_COUNT iterations
+ *        Used for 'mxnet_op-type' kernels that take no arguments
+   * \tparam OP Kernel operator
+   * \return Duration in nanoseconds for the 'WORKLOAD_COUNT' operations
+   */
+  template<typename OP>
+  static duration_t GetBlankWorkloadEx() {
+    std::unique_ptr<DType[]> tmp(new DType[Super::WORKLOAD_COUNT]);
+    DType *tmp_ptr = tmp.get();
+    const Tick start = std::chrono::high_resolution_clock::now();
+    for (size_t i = 0; i < Super::WORKLOAD_COUNT; ++i) {
+      OP::Map(i, tmp_ptr);
+    }
+    const duration_t omp_duration = Super::GetDurationInNanoseconds(start);
+    return omp_duration ? omp_duration : 1;
+  }
+
+ public:
+  /*!
+   * \brief Tune the specified kernel operator.  Optionally print out C++ macro that defines the
+   *        tuning data variable and the default tuned value
+   *        This function tunes an operator which takes no arguments
+   * \tparam OP The kernel operator to be tuned
+   */
+  template<typename OP>
+  static void TuneBlankOperator() {
+    mxnet::op::mxnet_op::tuned_op<OP, DType>::workload_ = GetBlankWorkload<OP>();
+    if (Super::output_tuning_data_) {
+      std::cout << "IMPLEMENT_UNARY_WORKLOAD_FWD("
+                << Super::template type_name<OP>()
+                << ");  // NOLINT()" << std::endl << std::flush;  // For long lines
+    }
+  }
+
+  /*!
+   * \brief Tune the specified kernel operator.  Optionally print out C++ macro that defines the
+   *        tuning data variable and the default tuned value
+   *        This function tunes an operator which takes one argument
+   * \tparam OP The kernel operator to be tuned
+   */
+  template<typename OP>
+  static void TuneUnaryOperator() {
+    mxnet::op::mxnet_op::tuned_op<OP, DType>::workload_ = GetUnaryWorkload<OP>();
+    if (Super::output_tuning_data_) {
+      std::cout << "IMPLEMENT_UNARY_WORKLOAD_FWD("
+                << Super::template type_name<OP>()
+                << ");  // NOLINT()" << std::endl << std::flush;  // For long lines
+    }
+  }
+
+  /*!
+   * \brief Tune the specified kernel operator.  Optionally print out C++ macro that defines the
+   *        tuning data variable and the default tuned value
+   *        This function tunes a backward operator which takes one argument
+   * \tparam OP The kernel operator to be tuned
+   */
+  template<typename OP>
+  static void TuneUnaryBackwardOperator() {
+    mxnet::op::mxnet_op::tuned_op<mxnet_op::backward_grad<OP>, DType>::workload_ =
+      GetBinaryWorkload<mxnet::op::mxnet_op::backward_grad<OP>>();
+    if (Super::output_tuning_data_) {
+      std::cout << "IMPLEMENT_UNARY_WORKLOAD_BWD("
+                << Super::template type_name<OP>()
+                << ");  // NOLINT()" << std::endl << std::flush;  // For long lines
+    }
+  }
+
+  /*!
+   * \brief Tune the specified "mxnet_op-type" kernel operator.
+   *        Optionally print out C++ macro that defines the
+   *        tuning data variable and the default tuned value
+   *        This function tunes an operator which takes no arguments
+   * \tparam OP The kernel operator to be tuned
+   */
+  template<typename OP>
+  static void TuneBlankOperatorEx() {
+    mxnet::op::mxnet_op::tuned_op<OP, DType>::workload_ = GetBlankWorkloadEx<OP>();
+    if (Super::output_tuning_data_) {
+      std::cout << "IMPLEMENT_BLANK_WORKLOAD_FWD("
+                << Super::template type_name<OP>()
+                << ");  // NOLINT()" << std::endl << std::flush;  // For long lines
+    }
+  }
+
+  /*!
+   * \brief Determine whether to use OMP based upon both timing and configuration using the
+   *        given (templated) operator's workload
+   * \tparam OP Operator whose workload to use (tuned_op::workload_)
+   * \param N Number of iterations desired
+   * \param thread_count Number of OMP threads available to perform the iterations
+   * \returns Whether it's faster to use OMP for these iterations
+   */
+  template<typename OP>
+  inline static bool UseOMP(size_t N, size_t thread_count) {
+      return OperatorTune<DType>::UseOMP(N,
+                                         thread_count,
+                                         static_cast<uint64_t>(N) * OP::workload_);
+  }
+};
+
+/*!
+ * \brief Class that tunes binary and unary operators
+ * \tparam DType Data type to be used when tuning the kernel operations
+ */
+template<typename DType>
+class BinaryOpTune : public UnaryOpTune<DType> {
+ protected:
+  typedef UnaryOpTune<DType> Super;
+
+ public:
+  /*!
+   * \brief Tune a generic binary operator
+   * \tparam OP Operator type
+   */
+  template<typename OP>
+  static void TuneBinaryOperator() {
+    mxnet_op::tuned_op<OP, DType>::workload_ = Super::template GetBinaryWorkload<OP>();
+    if (Super::Super::output_tuning_data_) {
+      std::cout << "IMPLEMENT_BINARY_WORKLOAD_FWD("
+                << Super::template type_name<OP>()
+                << ");  // NOLINT()" << std::endl << std::flush;  // For long lines
+    }
+  }
+
+  /*!
+   * \brief Tune binary backward operator
+   * \tparam OP - operator
+   */
+  template<typename OP>
+  static void TuneBinaryBackwardOperator() {
+    mxnet::op::mxnet_op::tuned_op<mxnet_op::backward_grad<OP>, DType>::workload_ =
+      Super::template GetTertiaryWorkload<mxnet::op::mxnet_op::backward_grad<OP>>();
+    if (Super::Super::output_tuning_data_) {
+      std::cout << "IMPLEMENT_BINARY_WORKLOAD_BWD("
+                << Super::template type_name<OP>()
+                << ");  // NOLINT()" << std::endl << std::flush;  // For long lines
+    }
+  }
+};
+
+#undef OUTSIDE_COUNT_SHIFT
+#undef WORKLOAD_COUNT_SHIFT
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_OPERATOR_OPERATOR_TUNE_INL_H_
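
At launch time the decision described in the OperatorTune remarks reduces to a simple cost comparison: a serial loop costs roughly N times the measured per-iteration workload for the operator and data type, while the OMP loop costs the measured fork/join overhead plus the same work divided across the available threads. The sketch below only illustrates that estimate; the actual UseOMP() is implemented in operator_tune.h, which is outside this diff, and may weigh additional factors.

    // Illustrative only: the real decision is made by tuned_op<OP, DType>::UseOMP()
    // in operator_tune.h.
    #include <cstddef>
    #include <cstdint>

    inline bool EstimateUseOMP(std::size_t N,                     // requested iterations
                               std::size_t omp_threads,           // threads OMP would use
                               std::uint64_t per_iter_ns,         // tuned per-iteration cost
                               std::uint64_t omp_overhead_ns) {   // measured OMP fork/join cost
      const std::uint64_t serial_ns = N * per_iter_ns;                        // single-thread estimate
      const std::uint64_t omp_ns = omp_overhead_ns + serial_ns / omp_threads; // parallel estimate
      return omp_ns < serial_ns;  // use OMP only when it is expected to be faster
    }

Which data types participate in this decision can be changed at runtime through MXNET_USE_OPERATOR_TUNING: a positive integer enables auto-tuning for every type, while a comma-separated list such as float32,float16 enables it only for the listed types (see ParseEnablerConfig above).
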
diff --git a/src/operator/operator_tune.cc b/src/operator/operator_tune.cc
new file mode 100644
index 0000000..525a66b
--- /dev/null
+++ b/src/operator/operator_tune.cc
@@ -0,0 +1,347 @@
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <atomic>
+#include "./mxnet_op.h"
+#include "./mshadow_op.h"
+#include "./tensor/init_op.h"
+#include "./operator_tune-inl.h"
+#include "./tensor/elemwise_binary_broadcast_op.h"
+
+namespace mxnet {
+namespace op {
+
+/*!
+ * \brief Shared static variables for all OperatorTune data types
+ */
+std::atomic<bool> OperatorTuneBase::calculated_(false);
+bool OperatorTuneBase::verbose_tuning_info_ = false;
+double OperatorTuneBase::tuning_weight_scale_ = 0.0;
+
+/*!
+ * \brief Instantiate static variables for OperatorTune<DType>, where 'DType' is specified
+ */
+#define IMPLEMENT_OPERATOR_TUNE_STATICS_FOR_TYPE(__typ$) \
+  template<> bool OperatorTune<__typ$>::initialized_ = false; \
+  template<> std::vector<__typ$> OperatorTune<__typ$>::data_set_ = {}; \
+  template<> volatile tune::TuningMode OperatorTuneByType<__typ$>::tuning_mode_ = tune::kAuto; \
+  template<> volatile int OperatorTune<__typ$>::volatile_int_ = 9;  /* arbitrary number */ \
+  template<> std::unordered_set<std::string> OperatorTune<__typ$>::operator_names_ = {}; \
+  template<> bool OperatorTune<__typ$>::output_tuning_data_ = false; \
+  template<> std::list<void (*)()> *OperatorTune<__typ$>::GetTuningList() { \
+    static std::list<void (*)()> ll; \
+    return &ll; \
+  }
+
+/*!
+ * \brief Static variables for different types (i.e. OperatorTune<float>, OperatorTune<double>, etc.)
+ */
+IMPLEMENT_OPERATOR_TUNE_STATICS_FOR_TYPE(float);
+IMPLEMENT_OPERATOR_TUNE_STATICS_FOR_TYPE(double);
+IMPLEMENT_OPERATOR_TUNE_STATICS_FOR_TYPE(mshadow::half::half_t);
+IMPLEMENT_OPERATOR_TUNE_STATICS_FOR_TYPE(int8_t);
+IMPLEMENT_OPERATOR_TUNE_STATICS_FOR_TYPE(uint8_t);
+IMPLEMENT_OPERATOR_TUNE_STATICS_FOR_TYPE(int32_t);
+IMPLEMENT_OPERATOR_TUNE_STATICS_FOR_TYPE(int64_t);
+
+/*!
+ * \brief Init variable used to facilitate registering a tunable operator during
+ *        static initialization
+ * \tparam OP Operator type
+ * \tparam DType Data type
+ */
+template<typename OP, typename DType>
+struct static_init_var {
+  static bool init_;
+};
+
+/*!
+ * \brief Repeat the given macro and associated arguments for each data type,
+ *        appending the data type to the end of the arguments
+ */
+#define MSHADOW_MACRO_FOREACH_TYPE(__macro$, ...) \
+  __macro$(__VA_ARGS__, float); \
+  __macro$(__VA_ARGS__, double); \
+  __macro$(__VA_ARGS__, mshadow::half::half_t); \
+  __macro$(__VA_ARGS__, uint8_t); \
+  __macro$(__VA_ARGS__, int8_t); \
+  __macro$(__VA_ARGS__, int32_t); \
+  __macro$(__VA_ARGS__, int64_t);
+
+
+#define IMPLEMENT_WORKLOAD_VALUE_FOR_TYPE(__op$, __typ$) \
+  namespace mxnet_op { \
+  template<> size_t mxnet::op::mxnet_op::tuned_op<__op$, __typ$>::workload_ = INT_MAX / 4; \
+  template<> std::vector<float> mxnet::op::mxnet_op::tuned_op<__op$, __typ$>::workload_ex_ = {}; \
+  }  /* namespace mxnet_op */
+
+/*!
+ * \brief Implement tuning objects for a forward blank (no arguments) kernel operator
+ */
+#define _IMPLEMENT_BLANK_WORKLOAD_FWD(__op$, __typ$) \
+  IMPLEMENT_WORKLOAD_VALUE_FOR_TYPE(__op$, __typ$); \
+  namespace mxnet_op { \
+  template<> bool mxnet::op::mxnet_op::tuned_op<__op$, __typ$>::UseOMP( \
+    size_t N, size_t omp_threads) { \
+    return mxnet::op::UnaryOpTune<__typ$>::UseOMP<mxnet_op::tuned_op<__op$, __typ$>>( \
+      N, omp_threads); \
+  }}  /* namespace mxnet_op */ \
+  template<> bool static_init_var<__op$, __typ$>::init_ = \
+    mxnet::op::OperatorTune<__typ$>::ScheduleTune<__op$>( \
+      mxnet::op::UnaryOpTune<__typ$>::TuneBlankOperatorEx<__op$>)
+
+/*!
+ * \brief Implement tuning objects for a forward unary kernel operator
+ */
+#define _IMPLEMENT_UNARY_WORKLOAD_FWD(__op$, __typ$) \
+  IMPLEMENT_WORKLOAD_VALUE_FOR_TYPE(__op$, __typ$); \
+  namespace mxnet_op { \
+  template<> bool mxnet::op::mxnet_op::tuned_op<__op$, __typ$>::UseOMP( \
+    size_t N, size_t omp_threads) { \
+    return mxnet::op::UnaryOpTune<__typ$>::UseOMP<mxnet_op::tuned_op<__op$, __typ$>>( \
+      N, omp_threads); \
+  }}  /* namespace mxnet_op */ \
+  template<> bool static_init_var<__op$, __typ$>::init_ = \
+    mxnet::op::OperatorTune<__typ$>::ScheduleTune<__op$>( \
+      mxnet::op::UnaryOpTune<__typ$>::TuneUnaryOperator<__op$>)
+
+/*!
+ * \brief Implement tuning objects for a backward unary kernel operator
+ */
+#define _IMPLEMENT_UNARY_WORKLOAD_BWD(__op$, __typ$) \
+  IMPLEMENT_WORKLOAD_VALUE_FOR_TYPE(mxnet::op::mxnet_op::backward_grad<__op$>, __typ$); \
+  namespace mxnet_op { \
+  template<> \
+  bool mxnet::op::mxnet_op::tuned_op<mxnet::op::mxnet_op::backward_grad<__op$>, __typ$>::UseOMP( \
+    size_t N, size_t omp_threads) { \
+    return mxnet::op::UnaryOpTune<__typ$>::UseOMP<mxnet_op::tuned_op< \
+      mxnet::op::mxnet_op::backward_grad<__op$>, __typ$>>(N, omp_threads); \
+  }}  /* namespace mxnet_op */ \
+  template<> bool static_init_var<mxnet::op::mxnet_op::backward_grad<__op$>, __typ$>::init_ = \
+    mxnet::op::OperatorTune<__typ$>::ScheduleTune<__op$>( \
+      mxnet::op::UnaryOpTune<__typ$>::TuneUnaryBackwardOperator<__op$>)
+
+/*!
+ * \brief Implement tuning objects for a forward binary kernel operator
+ */
+#define _IMPLEMENT_BINARY_WORKLOAD_FWD(__op$, __typ$) \
+  IMPLEMENT_WORKLOAD_VALUE_FOR_TYPE(__op$, __typ$); \
+  namespace mxnet_op { \
+  template<> bool mxnet::op::mxnet_op::tuned_op<__op$, __typ$>::UseOMP( \
+    size_t N, size_t omp_threads) { \
+    return mxnet::op::BinaryOpTune<__typ$>::UseOMP<mxnet_op::tuned_op<__op$, __typ$>>( \
+      N, omp_threads); \
+  }}  /* namespace mxnet_op */ \
+  template<> bool static_init_var<__op$, __typ$>::init_ = \
+    mxnet::op::OperatorTune<__typ$>::ScheduleTune<__op$>( \
+      mxnet::op::BinaryOpTune<__typ$>::TuneBinaryOperator<__op$>)
+
+/*!
+ * \brief Implement tuning objects for a backward binary kernel operator
+ */
+#define _IMPLEMENT_BINARY_WORKLOAD_BWD(__op$, __typ$) \
+  IMPLEMENT_WORKLOAD_VALUE_FOR_TYPE(mxnet::op::mxnet_op::backward_grad<__op$>, __typ$); \
+  namespace mxnet_op { \
+  template<> \
+    bool mxnet::op::mxnet_op::tuned_op<mxnet::op::mxnet_op::backward_grad<__op$>, __typ$>::UseOMP( \
+    size_t N, size_t omp_threads) { \
+    return mxnet::op::BinaryOpTune<__typ$>::UseOMP<mxnet_op::tuned_op< \
+      mxnet::op::mxnet_op::backward_grad<__op$>, __typ$>>(N, omp_threads); \
+  }}  /* namespace mxnet_op */ \
+  template<> bool static_init_var<mxnet::op::mxnet_op::backward_grad<__op$>, __typ$>::init_ = \
+    mxnet::op::OperatorTune<__typ$>::ScheduleTune<__op$>(  \
+      mxnet::op::BinaryOpTune<__typ$>::TuneBinaryBackwardOperator<__op$>)
+
+/*!
+ * \brief Implement tuning objects for a custom forward kernel operator
+ */
+#define _IMPLEMENT_CUSTOM_WORKLOAD_FWD(__op$, __typ$) \
+  IMPLEMENT_WORKLOAD_VALUE_FOR_TYPE(__op$<__typ$>, __typ$); \
+  template<> bool static_init_var<__op$<__typ$>, __typ$>::init_ = \
+    mxnet::op::OperatorTune<__typ$>::ScheduleTune<__op$<__typ$>>(\
+      __op$<__typ$>::Tune)
+
+/*!
+ * \brief Macros for manually adding new blank, unary and binary operators to the tuning set
+ */
+#define IMPLEMENT_UNARY_WORKLOAD_FWD(__op$) \
+  MSHADOW_MACRO_FOREACH_TYPE(_IMPLEMENT_UNARY_WORKLOAD_FWD, __op$)
+
+#define IMPLEMENT_BLANK_WORKLOAD_FWD(__op$) \
+  MSHADOW_MACRO_FOREACH_TYPE(_IMPLEMENT_BLANK_WORKLOAD_FWD, __op$)
+
+#define IMPLEMENT_UNARY_WORKLOAD_BWD(__op$) \
+  MSHADOW_MACRO_FOREACH_TYPE(_IMPLEMENT_UNARY_WORKLOAD_BWD, __op$)
+
+#define IMPLEMENT_BINARY_WORKLOAD_FWD(__op$) \
+  MSHADOW_MACRO_FOREACH_TYPE(_IMPLEMENT_BINARY_WORKLOAD_FWD, __op$)
+
+#define IMPLEMENT_BINARY_WORKLOAD_BWD(__op$) \
+  MSHADOW_MACRO_FOREACH_TYPE(_IMPLEMENT_BINARY_WORKLOAD_BWD, __op$)
+
+#define IMPLEMENT_CUSTOM_WORKLOAD_FWD(__op$) \
+  MSHADOW_MACRO_FOREACH_TYPE(_IMPLEMENT_CUSTOM_WORKLOAD_FWD, __op$)
+
+/*!
+ * \brief Tuning data and default weights in the case that MXNET_ENABLE_OPERATOR_AUTOTUNE is set
+ *        to zero (thus turning off auto-tuning)
+ * \note This code can be automatically generated
+ *       by setting the environment variable MXNET_OUTPUT_TUNING_DATA to a positive
+ *       integer value
+ */
+OperatorTuneBase::duration_t OperatorTuneBase::omp_overhead_ns_ = 5000;
+IMPLEMENT_UNARY_WORKLOAD_FWD(mshadow::op::identity);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::identity);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::identity_grad);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::negation);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::reciprocal);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::reciprocal_grad);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::sigmoid);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::sigmoid_grad);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::relu);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::relu_grad);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::tanh);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::tanh_grad);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::softrelu);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::softrelu_grad);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::exp);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::exp);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::expm1);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::log);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::log_grad);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::log1p);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::log1p_grad);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::log2);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::log2_grad);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::log10);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::log10_grad);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::sin);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::sin_grad);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::sinh);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::sinh_grad);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::arcsin);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::arcsin_grad);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::arcsinh);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::arcsinh_grad);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::cos);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::cos_grad);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::cosh);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::cosh_grad);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::arccos);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::arccos_grad);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::arccosh);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::arccosh_grad);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::tan);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::tan_grad);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::arctan);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::arctan_grad);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::arctanh);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::arctanh_grad);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::square);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::square_grad);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::square_root);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::square_root_grad);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::reciprocal_square_root);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::reciprocal_square_root_grad);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::cube_root);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::cube_root_grad);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::reciprocal_cube_root);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::reciprocal_cube_root_grad);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::abs);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::sign);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::sign);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::sign_grad);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::round);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::floor);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::trunc);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rint);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::fix);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::gamma);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::gamma_grad);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::gammaln);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::gammaln_grad);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::ceil);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::degrees);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::degrees_grad);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::radians);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::radians_grad);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD(mshadow::op::plus);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD(mshadow::op::minus);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD(mshadow::op::mul);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD(mshadow::op::div);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD(mshadow::op::right);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rminus);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rdiv);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::div_grad);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::div_grad);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::div_rgrad);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::div_rgrad);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::rdiv_grad);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::mod);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::mod_grad);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::mod_rgrad);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rmod);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::rmod_grad);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::left);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::left);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::right);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::right);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::power);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rpower);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::power_grad);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::rpower_grad);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::power_rgrad);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::maximum);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::minimum);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::hypot);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::hypot_grad_left);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::hypot_grad_left);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::hypot_grad_right);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::hypot_grad_right);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::lt);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::lt);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::le);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::le);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::gt);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::gt);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::ge);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::ge);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::ne);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::ne);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::eq);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::eq);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::smooth_l1_loss);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::smooth_l1_gradient);  // NOLINT()
+IMPLEMENT_BLANK_WORKLOAD_FWD(mxnet::op::mxnet_op::set_to_int<0>);  // NOLINT()
+IMPLEMENT_BLANK_WORKLOAD_FWD(mxnet::op::mxnet_op::set_to_int<1>);  // NOLINT()
+IMPLEMENT_BLANK_WORKLOAD_FWD(mxnet::op::PopulateFullIdxRspKernel);  // NOLINT()
+/*!
+ * \brief Tuner objects, *not* automatically generated
+ */
+#ifdef MXNET_USE_OPERATOR_TUNING
+static BinaryOpTune<float>                  binaryOpTuneFloat;
+static BinaryOpTune<double>                 binaryOpTuneDouble;
+static BinaryOpTune<mshadow::half::half_t>  binaryOpTuneHalf;
+static BinaryOpTune<int8_t>                 binaryOpTuneInt8;
+static BinaryOpTune<uint8_t>                binaryOpTuneUInt8;
+static BinaryOpTune<int32_t>                binaryOpTuneInt32;
+static BinaryOpTune<int64_t>                binaryOpTuneInt64;
+#endif  // MXNET_USE_OPERATOR_TUNING
+}  // namespace op
+}  // namespace mxnet
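
As an aside for readers of the macro block above: every _IMPLEMENT_*_WORKLOAD_* macro reduces to
the same static-initializer registration idiom, a per-(operator, type) static bool whose
initializer schedules a tuning pass. A minimal, self-contained C++ sketch of that idiom follows;
the names are illustrative only, not MXNet's actual API, and the code is not part of this patch:

    #include <functional>
    #include <iostream>
    #include <vector>

    // Stand-in for OperatorTune's internal queue of pending tuning jobs.
    static std::vector<std::function<void()>> g_tuning_jobs;

    template<typename OP, typename DType>
    struct static_init_var_sketch {
      static bool init_;  // the explicit specialization below forces the call at static-init time
    };

    // Stand-in for OperatorTune<DType>::ScheduleTune<OP>(): queue a job and return true.
    template<typename OP, typename DType>
    bool schedule_tune_sketch() {
      g_tuning_jobs.emplace_back([]() { /* time OP::Map over DType and record the workload */ });
      return true;
    }

    struct my_op {};  // a kernel operator's Map() would live here

    // What each macro expansion ultimately boils down to, one line per (operator, type) pair:
    template<> bool static_init_var_sketch<my_op, float>::init_ =
      schedule_tune_sketch<my_op, float>();

    int main() {
      std::cout << "scheduled tuning jobs: " << g_tuning_jobs.size() << std::endl;  // prints 1
    }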
diff --git a/src/operator/operator_tune.h b/src/operator/operator_tune.h
new file mode 100644
index 0000000..4f92c9d
--- /dev/null
+++ b/src/operator/operator_tune.h
@@ -0,0 +1,331 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef MXNET_OPERATOR_OPERATOR_TUNE_H_
+#define MXNET_OPERATOR_OPERATOR_TUNE_H_
+
+#include <mshadow/base.h>
+#include <mshadow/tensor.h>
+#include <vector>
+#include <set>
+#include <atomic>
+
+namespace mxnet {
+namespace op {
+
+#define WORKLOAD_COUNT_SHIFT  11
+
+/*!
+ * \brief Shared data for all data types being tuned, acts as a base class for the higher-level
+ *        templated tuning classes
+ */
+class OperatorTuneBase {
+ public:
+  typedef int64_t duration_t;
+
+ protected:
+  /*! \brief Have calculated omp_overhead_ yet? */
+  static std::atomic<bool> calculated_;
+  /*! \brief Time in nanoseconds for OMP overhead */
+  static duration_t omp_overhead_ns_;
+  /*! \brief Print debug/trace output for tuning info */
+  static bool verbose_tuning_info_;
+  /*! \brief Tuning scale factor */
+  static double tuning_weight_scale_;
+
+ public:
+  typedef std::chrono::high_resolution_clock::time_point Tick;
+
+  /*!
+   * \brief Get timestamp for "now"
+   * \return Tick object representing the current timestamp
+   */
+  static MSHADOW_CINLINE Tick Now() {
+    return std::move(std::chrono::high_resolution_clock::now());
+  }
+
+  /*!
+   * \brief Get duration in nanoseconds
+   * \param t1 Start time tick
+   * \param t2 End time tick
+   * \return duration in nanoseconds between t1 and t2
+   */
+  static MSHADOW_CINLINE duration_t GetDurationInNanoseconds(const Tick &t1, const Tick &t2) {
+    return static_cast<duration_t>(
+      std::chrono::duration_cast<std::chrono::nanoseconds>(t2 - t1).count());
+  }
+
+  /*!
+   * \brief Get duration in nanoseconds between the given 'since' value and now
+   * \param since Reference time which to calculate the duration
+   * \return Duration in nanoseconds between the given 'since' value and now
+   */
+  static MSHADOW_CINLINE duration_t GetDurationInNanoseconds(const Tick &since) {
+    return GetDurationInNanoseconds(since, Now());
+  }
+
+  /*! \brief Loop size to be timed (single op nanos may be too small to store accurately) */
+  static constexpr duration_t WORKLOAD_COUNT = (1 << WORKLOAD_COUNT_SHIFT);
+
+  /*!
+   * \brief Timer convenience class, sets start time as "now" in the constructor
+   */
+  struct Timer {
+    /*!
+     * \brief Constructor, sets start time
+     */
+    MSHADOW_CINLINE Timer()
+      : start_(OperatorTuneBase::Now()) {}
+    /*!
+     * \brief Get duration in nanoseconds since construction
+     * \return Duration in nanoseconds since construction
+     */
+    MSHADOW_CINLINE int64_t duration() const {
+      return OperatorTuneBase::GetDurationInNanoseconds(start_);
+    }
+
+    /*!
+     * \brief Reference start time, set in constructor
+     */
+    const OperatorTuneBase::Tick start_;
+  };
+
+  /*!
+   * \brief Estimate the time to compute with and without OMP, then return whether OMP is faster
+   * \param N - Number of iterations desired
+   * \param thread_count - Number of OMP threads available to perform the iterations
+   * \param serial_workload - Estimated serial cost of processing all N items, in nanoseconds
+   *        scaled up by WORKLOAD_COUNT
+   * \returns Whether it's faster to use OMP for these iterations
+   */
+  inline static bool IsOMPFaster(size_t N, size_t thread_count, const uint64_t serial_workload) {
+    if (thread_count >= 2) {
+      // Compute serial time required
+      const uint64_t total_serial_time_ns = serial_workload >> WORKLOAD_COUNT_SHIFT;
+
+      // Compute time required for OMP + # items per thread
+      const uint64_t omp_compute_time_ns = (serial_workload / thread_count) >> WORKLOAD_COUNT_SHIFT;
+      const uint64_t total_omp_time_ns = omp_overhead_ns_ + omp_compute_time_ns;
+
+      const bool rc = total_omp_time_ns < total_serial_time_ns;
+      return rc;
+    }
+    return false;
+  }
+};
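
To make the IsOMPFaster() break-even arithmetic concrete, here is a standalone sketch using the
default 5000 ns OMP overhead; it is not part of this patch and the serial_workload figure is
invented purely for illustration:

    #include <cstdint>
    #include <iostream>

    int main() {
      const int      workload_count_shift = 11;        // matches WORKLOAD_COUNT_SHIFT above
      const uint64_t omp_overhead_ns      = 5000;      // default omp_overhead_ns_ in operator_tune.cc
      const uint64_t thread_count         = 8;
      const uint64_t serial_workload      = 40000000;  // invented: serial cost scaled by WORKLOAD_COUNT

      const uint64_t serial_ns = serial_workload >> workload_count_shift;                     // ~19531 ns
      const uint64_t omp_ns    = omp_overhead_ns +
                                 ((serial_workload / thread_count) >> workload_count_shift);  // ~7441 ns
      std::cout << (omp_ns < serial_ns ? "use OMP" : "stay serial") << std::endl;             // "use OMP"
    }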
+
+namespace tune {
+/*!
+ * \brief Tuning mode for registered kernel operators
+ */
+enum TuningMode {
+  kAuto,         // Based upon tuning data, choose whether to use OMP for kernel CPU Launch() loops
+  kNeverOMP,     // Don't use OMP for parallelism (legacy behavior for GPU builds)
+  kAlwaysOMP     // Always use OMP for parallelism (legacy behavior for CPU builds)
+};
+}  // namespace tune
+
+template<typename DType>
+class OperatorTuneByType : public OperatorTuneBase {
+ public:
+  /*!
+   * \brief Set tuning mode
+   * \param tuning_mode The tune::TuningMode tuning mode value to set
+   */
+  static MSHADOW_CINLINE void set_tuning_mode(const tune::TuningMode tuning_mode) {
+    // Use const_cast to get past the "assigning non-volatile to volatile" warning
+    const_cast<tune::TuningMode &>(tuning_mode_) = tuning_mode;
+  }
+
+  /*!
+   * \brief Get the current tuning mode
+   * \return tune::TuningMode value for the current tuning mode
+   */
+  static MSHADOW_CINLINE volatile tune::TuningMode tuning_mode() {
+    return tuning_mode_;
+  }
+
+  /*!
+   * \brief Determine whether to use OMP based upon both timing and configuration
+   * \param N - Number of iterations desired
+   * \param thread_count - Number of OMP threads available to perform the iterations
+   * \param serial_workload - Estimated serial cost of processing all N items, in nanoseconds
+   *        scaled up by WORKLOAD_COUNT
+   * \returns Whether it's faster to use OMP for these iterations
+   */
+  inline static bool UseOMP(size_t N, size_t thread_count, const uint64_t serial_workload) {
+#ifdef MXNET_USE_OPERATOR_TUNING
+    switch (tuning_mode()) {
+      case tune::kAuto:
+        return OperatorTuneBase::IsOMPFaster(N, thread_count, serial_workload);
+      case tune::kNeverOMP:
+        return false;
+      case tune::kAlwaysOMP:
+      default:
+        return thread_count > 1;
+    }
+#else
+    return true;
+#endif
+  }
+
+ protected:
+  /*! \brief Tuning mode */
+  static volatile tune::TuningMode tuning_mode_;
+};
+
+namespace mxnet_op {
+/*!
+ * \brief Kernel operator wrapper used for tuning data
+ */
+template<typename Operation, typename DType>
+struct tuned_op : public Operation {
+  /*! \brief nanoseconds to perform WORKLOAD_COUNT operations
+   *  \note It is conceivable that a vector of values could be used for more complex tuning,
+   *        but the need hasn't yet arisen
+   *  \remarks This variable generally needs to be implemented somewhere.  Currently this is mostly
+   *           done via macros in operator_tune.cc.  If you get undefined reference errors when
+   *           linking, then try to use one of the macros in that file to instantiate the required
+   *           data/functions
+   */
+  static size_t workload_;
+
+  /*!
+   * \brief Extra workload-calculating information (ie times for sub-portions of the calculation)
+   */
+  static std::vector<float> workload_ex_;
+
+  /*!
+   * \brief Calls parent class (Operation)'s UseOMP
+   * \tparam Args Variable arguments passed
+   * \param N Number of iterations
+   * \param thread_count Number of threads available
+   * \param args Variable arguments passed
+   * \return true if OMP parallelism is recommended
+   */
+  template<typename ...Args>
+  static MSHADOW_CINLINE bool UseOMP(size_t N, size_t thread_count, Args... args) {
+    return Operation::UseOMP(N, thread_count, args...);
+  }
+
+  /*!
+   * \brief Call a standard UseOMP() implementation (if it exists). Currently, these
+   *        are implemented in operator_tune.cc for standard unary, binary,
+   *        and argumentless kernels (e.g. mshadow_op::sqrt)
+   * \param N Number of iterations
+   * \param thread_count Number of threads available
+   * \return true if OMP parallelism is recommended
+   */
+  static bool UseOMP(size_t N, size_t thread_count);
+};
+}  // namespace mxnet_op
+
+/*!
+ * \brief Calculate workload for a given lambda function
+ * \tparam Function Lambda type to time for WORKLOAD_COUNT calls
+ * \param function Lambda to time for WORKLOAD_COUNT calls
+ * \return median workload for function call (nanoseconds for WORKLOAD_COUNT calls)
+ */
+template<typename Function>
+inline int64_t get_workload(Function function) {
+  std::multiset<int64_t> durations;
+  typename OperatorTuneBase::Timer timer;
+  for (int pass = 0; pass < 3; ++pass) {
+    for (int i = 0; i < OperatorTuneBase::WORKLOAD_COUNT; ++i) {
+      function();
+    }
+  }
+  const OperatorTuneBase::duration_t dd = timer.duration();
+  durations.insert(dd);
+  return *++durations.begin();  // return median value
+}
+
+/*!
+ * \brief Declare a template specialization for the Kernel::Launch call for the given OP
+ *        wrapped with mxnet_op::op_with_req, using the given OpReqType as the 'req'
+ *        template parameter for 'op_with_req'.  This is useful for the standard mshadow_op
+ *        operators which need to be wrapped with op_with_req in order to be used with the
+ *        Kernel::Launch command.
+ *
+ * \note Expects to be used within the mxnet::op namespace
+ *
+ * For example:
+ *
+ * namespace mxnet_op {
+ * template <>
+ * template <typename... Args>
+ * inline void Kernel<typename mxnet_op::op_with_req<mshadow::op::identity, kNullOp>, cpu>
+ *   ::Launch(mshadow::Stream<cpu>* s, const int N, Args... args) {
+ *   ::mxnet::op::mxnet_op::Kernel<typename mxnet_op::op_with_req<mshadow::op::identity, kNullOp>,
+ *     cpu>::LaunchMShadowOpEx(s, N, args...);
+ *   }
+ * }
+ *
+ */
+#define MXNET_TUNABLE_MSHADOW_OP_WITH_REQ(__op$, __req$) \
+  namespace mxnet_op { \
+  template<> template<typename ...Args> \
+  inline void Kernel<typename mxnet_op::op_with_req<__op$, __req$>, ::mshadow::cpu>:: \
+    Launch(mshadow::Stream<::mshadow::cpu> *s, const int N, Args... args) { \
+      /* Launch via LaunchMShadowOpEx() */ \
+      KernelWrapper<typename mxnet_op::op_with_req<__op$, __req$>, ::mshadow::cpu>:: \
+        LaunchMShadowOpEx(s, N, args...); \
+  } \
+  }  /* namespace mxnet_op */
+
+/*!
+ * \brief Declare template specializations for the Kernel::Launch call for the given OP
+ *        wrapped with mxnet_op::op_with_req, using the all supported OpReqType as the 'req'
+ *        template parameter for 'op_with_req'.  This is useful for the standard mshadow_op
+ *        operators which need to be wrapped with op_with_req in order to be used with the
+ *        Kernel::Launch command.
+ * \note Expects to be used within the mxnet::op namespace
+ */
+#define MXNET_TUNABLE_MSHADOW_OP(__op$) \
+  MXNET_TUNABLE_MSHADOW_OP_WITH_REQ(__op$, kNullOp); \
+  MXNET_TUNABLE_MSHADOW_OP_WITH_REQ(__op$, kWriteTo); \
+  MXNET_TUNABLE_MSHADOW_OP_WITH_REQ(__op$, kWriteInplace); \
+  MXNET_TUNABLE_MSHADOW_OP_WITH_REQ(__op$, kAddTo);
+
+#define MXNET_TUNABLE_MSHADOW_OP_BACKWARD(__op$) \
+  MXNET_TUNABLE_MSHADOW_OP(mxnet::op::mxnet_op::backward_grad<__op$>)
+
+#define MXNET_TUNABLE_MSHADOW_OP_FWD_AND_BWD(__op$) \
+  MXNET_TUNABLE_MSHADOW_OP(__op$) \
+  MXNET_TUNABLE_MSHADOW_OP_BACKWARD(__op$)
+
+/*!
+ * \brief mxnet::op::mxnet_op format ops (work directly with Kernel<>::Launch())
+ *        Used from within mxnet::op::mxnet_op namespace
+ */
+#define _MXNET_TUNABLE_MXNET_OP_FWD(__op$) \
+  template<> template<typename ...Args> inline void Kernel<__op$, ::mshadow::cpu>::Launch( \
+    mshadow::Stream<::mshadow::cpu> *s, const int N, Args... args) { \
+      /* Launch via LaunchMXNetOpEx() */ \
+      KernelWrapper<__op$, ::mshadow::cpu>::LaunchMXNetOpEx(s, N, args...); \
+  }
+
+/*!
+ * \brief mxnet::op::mxnet_op format ops (work directly with Kernel<>::Launch())
+ *        Used from within mxnet::op
+ */
+#define MXNET_TUNABLE_MXNET_OP_FWD(__op$) \
+  namespace mxnet_op { _MXNET_TUNABLE_MXNET_OP_FWD(__op$) }  /* namespace mxnet_op */
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_OPERATOR_OPERATOR_TUNE_H_
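
For orientation, a short sketch of how this header is driven at runtime; it is not part of this
patch, the include path is an assumption, and the workload figure is invented. The tuning mode can
be forced per element type, and the same per-type policy answers the OMP question for a workload:

    #include "operator_tune.h"  // illustrative path; adjust to the MXNet source layout

    void tuning_mode_example() {
      using ::mxnet::op::OperatorTuneByType;
      namespace tune = ::mxnet::op::tune;

      OperatorTuneByType<float>::set_tuning_mode(tune::kAuto);  // or tune::kNeverOMP / tune::kAlwaysOMP
      const bool use_omp = OperatorTuneByType<float>::UseOMP(
          /* N */ 1 << 20, /* thread_count */ 8, /* serial_workload */ 40000000ULL);
      (void)use_omp;
    }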
diff --git a/src/operator/tensor/elemwise_binary_broadcast_op.h b/src/operator/tensor/elemwise_binary_broadcast_op.h
index 211b567..2317c98 100644
--- a/src/operator/tensor/elemwise_binary_broadcast_op.h
+++ b/src/operator/tensor/elemwise_binary_broadcast_op.h
@@ -141,16 +141,16 @@ struct binary_broadcast_kernel {
                                   const Shape<ndim>& lstride, const Shape<ndim>& rstride,
                                   const Shape<ndim>& oshape, DType* lhs, DType* rhs,
                                   DType* out, int lsize, int rsize) {
-    Shape<ndim> coord = unravel(base, oshape);
+      Shape<ndim> coord = unravel(base, oshape);
     index_t lidx = dot(coord, lstride);
     index_t ridx = dot(coord, rstride);
-    KERNEL_ASSIGN(out[base], req, OP::Map(lhs[lidx], rhs[ridx]));
-    // starts from 1 to avoid extra inc at end of loop
-    for (int i = 1; i < length; ++i) {
-      inc(&coord, oshape, &lidx, lstride, &ridx, rstride);
-      KERNEL_ASSIGN(out[base+i], req, OP::Map(lhs[lidx], rhs[ridx]));
+      KERNEL_ASSIGN(out[base], req, OP::Map(lhs[lidx], rhs[ridx]));
+      // starts from 1 to avoid extra inc at end of loop
+      for (int i = 1; i < length; ++i) {
+        inc(&coord, oshape, &lidx, lstride, &ridx, rstride);
+        KERNEL_ASSIGN(out[base + i], req, OP::Map(lhs[lidx], rhs[ridx]));
+      }
     }
-  }
 };
 
 }  // namespace mxnet_op
diff --git a/src/operator/tensor/init_op.h b/src/operator/tensor/init_op.h
index 1d30c88..95e8184 100644
--- a/src/operator/tensor/init_op.h
+++ b/src/operator/tensor/init_op.h
@@ -272,6 +272,7 @@ struct PopulateFullIdxRspKernel {
     KERNEL_ASSIGN(out[i], kWriteTo, i);
   }
 };
+MXNET_TUNABLE_MXNET_OP_FWD(PopulateFullIdxRspKernel);
 
 // Fill in the indices and values of a RowSparse NDArray to represent a zeros NDArray,
 // instead of the usual compact representation.
diff --git a/tests/cpp/include/test_core_op.h b/tests/cpp/include/test_core_op.h
index 1bcd0e2..51cbcd7 100644
--- a/tests/cpp/include/test_core_op.h
+++ b/tests/cpp/include/test_core_op.h
@@ -36,15 +36,15 @@ namespace op {
 #define COREOP_BWD_OP_NAME_VALUE_NONE   "[none]"
 
 enum TimingDirection {
-  Forward,
-  Backward
+  kForward,
+  kBackward
 };
 
 inline const char *TimingDirectionAsString(const TimingDirection td) {
   switch (td) {
-    case Forward:
+    case kForward:
       return "Forward";
-    case Backward:
+    case kBackward:
       return "Backward";
     default:
       CHECK(false) << "Unknown timing direction: " << static_cast<int>(td);
@@ -426,7 +426,7 @@ class CoreOpExecutor : public test::op::OperatorDataInitializer<DType>
   inline bool initBackward(const OpProp &opProp, std::vector<int> *in_type) { return true; }
 
   inline void forward(const size_t count) {
-    perf::TimingItem timeF(&OperatorExecutorTiming::GetTiming(), Forward, "Forward", count);
+    perf::TimingItem timeF(&OperatorExecutorTiming::GetTiming(), kForward, "Forward", count);
     VTuneResume profile;
     for (size_t i = 0; i < count; ++i) {
       Execute();
@@ -435,7 +435,7 @@ class CoreOpExecutor : public test::op::OperatorDataInitializer<DType>
 
   inline void backward(const size_t count) {
     CHECK(HasBackward());
-    perf::TimingItem timeF(&OperatorExecutorTiming::GetTiming(), Backward, "Backward", count);
+    perf::TimingItem timeF(&OperatorExecutorTiming::GetTiming(), kBackward, "Backward", count);
     VTuneResume profile;
     for (size_t i = 0; i < count; ++i) {
       ExecuteBackward();
diff --git a/tests/cpp/include/test_op_runner.h b/tests/cpp/include/test_op_runner.h
index 3b06b1a..0992c41 100644
--- a/tests/cpp/include/test_op_runner.h
+++ b/tests/cpp/include/test_op_runner.h
@@ -44,6 +44,14 @@ class OperatorRunner {
  public:
   typedef typename OperatorExecutor::DataType    DType;
 
+  OperatorRunner() {
+#ifdef NDEBUG
+    total_iterations_ = 50;
+#else
+    total_iterations_ = 5;
+#endif
+  }
+
   /*!
    * \brief Test operator forward pass
    * \param isGPU Whether this test is for GPU
@@ -130,33 +138,34 @@ class OperatorRunner {
              int dim = 0,
              size_t count = 1,
              const std::vector<TShape>& timing_shapes = {}) {
-#ifdef NDEBUG
-    size_t COUNT = 50;
-#else
-    size_t COUNT = 5;
-#endif
     if (mxnet::test::quick_test) {
-      COUNT = 2;
+      total_iterations_ = 2;
       count = 1;
     }
 
     test::perf::TimingInstrument timing;
 
     std::stringstream ss;
-    ss << "Timing: " << COUNT << " iterations of " << count << " calls";
+    ss << "Timing: " << total_iterations_ << " iterations of " << count << " calls";
     if (timing_shapes[0].ndim()) {
+      size_t lhs_total = 0;
       ss << ", shape = ";
       for (size_t i = 0, n = timing_shapes.size(); i < n; ++i) {
         if (i) {
           ss << ", ";
         }
         ss << timing_shapes[i];
+        if (!i) {
+          lhs_total = timing_shapes[i].Size();
+        }
       }
-      ss << std::endl << std::flush;
+      ss << " = " << test::pretty_num(lhs_total) << " items " << std::endl << std::flush;
+    }
+    if (!mxnet::test::csv) {
+      std::cout << ss.str();
     }
-    std::cout << ss.str();
 
-    for (size_t i = 0; i < COUNT; ++i) {
+    for (size_t i = 0; i < total_iterations_; ++i) {
       index_t batchSize = 1;
       index_t channels = 1;
       index_t depth = 1;
@@ -223,16 +232,17 @@ class OperatorRunner {
       }
     }
 
-    if (verbose_) {
+    if (verbose_ && !mxnet::test::csv) {
       timing.print(&std::cout, label);
       std::cout << std::endl << std::flush;
     }
-
     return timing.data();
   }
 
   void set_verbose(bool verbose) { verbose_ = verbose; }
 
+  void set_total_iterations(size_t iterations) { total_iterations_ = iterations; }
+
  protected:
   static constexpr int TEST_BATCH_SIZE = 5;
   static constexpr int TEST_CHANNELS = 3;
@@ -247,6 +257,8 @@ class OperatorRunner {
   static constexpr int TIMING_DW = 64;
   /*! \brief verbose output */
   bool verbose_ = true;
+  /*! \brief Total iterations */
+  size_t total_iterations_ = 10;
 };
 
 }  // namespace test
diff --git a/tests/cpp/include/test_tune.h b/tests/cpp/include/test_tune.h
new file mode 100644
index 0000000..725aa90
--- /dev/null
+++ b/tests/cpp/include/test_tune.h
@@ -0,0 +1,333 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file test_tune.h
+ * \brief operator tuning tester
+ * \author Chris Olivier
+*/
+
+#ifndef TEST_TUNE_H_
+#define TEST_TUNE_H_
+
+#include <sys/time.h>
+#include <dmlc/logging.h>
+#include <iomanip>
+#include <iostream>
+#include <atomic>
+#include <unordered_set>
+#include <unordered_map>
+#include <mutex>
+#include <vector>
+#include <utility>
+#include <algorithm>
+#include <string>
+#include <map>
+#include "../../src/operator/operator_tune-inl.h"
+#include "./test_util.h"
+#include "./test_op.h"
+#include "./test_core_op.h"
+
+namespace mxnet {
+namespace test {
+namespace tune {
+
+/*!
+ * \brief Tuning tests, which check whether the correct tuning mode is selected by Auto
+ * \note This class makes no attempt at being performant (i.e. it does all sorts of slow
+ *       deep copies and that sort of thing), so don't insert any of this code into the main
+ *       trunk unless you've verified the performance characteristics for that chunk of code
+ * \tparam DType Data type to test
+ */
+template<typename DType>
+class TuningTester {
+ public:
+  using kwargs_t = test::op::kwargs_t;
+
+  using bool_mode_pair = std::pair<bool, ::mxnet::op::tune::TuningMode>;
+
+  using shape_vect = std::vector<TShape>;
+  using shape_vec_to_bool_map = std::map<shape_vect, bool_mode_pair, test::less_shapevect>;
+
+ private:
+  using ShapesToPerfTimingMap =
+  std::map<shape_vect, test::perf::timing_map_t, test::less_shapevect>;
+
+  /*!
+   * \brief Run timing test on various data shapes and sizes
+   * \param isGPU true if the GPU should be used for the timing test
+   * \param op_kwargs operator parameters
+   * \param shapes Data shape sets on which to run the timing test
+   * \param op_name The operator's registered name (with nnvm)
+   * \param backward_op_name The backward operator's registered name (with nnvm)
+   * \return ShapesToPerfTimingMap holding timing data for each set of shapes
+   */
+  ShapesToPerfTimingMap RunCoreOpTimingTest(const bool isGPU,
+                                            const kwargs_t &op_kwargs,
+                                            const std::vector<shape_vect>& shapes,
+                                            const char *op_name,
+                                            const char *backward_op_name = "") {
+    ShapesToPerfTimingMap res;
+    const kwargs_t kwargs = test::op::CoreOpExecutor<DType>::ArgsWithOpName(
+      op_kwargs, op_name, backward_op_name);
+
+    // prime code and cache before the performance runs
+    test::op::CoreOperatorRunner<DType> runner;
+    runner.set_total_iterations(total_iterations_);
+    runner.set_verbose(false);
+    runner.RunBidirectional(false, {{10, 3, 18, 128}}, kwargs, 1);
+
+    // Do the performance runs
+    const char *pu = isGPU ? "GPU" : "CPU";
+    for (const std::vector<TShape> &this_run_shapes : shapes) {
+      test::perf::timing_map_t tmap = runner.TimingTest(std::string(op_name) + " Operator " + pu,
+                                                        isGPU, false, kwargs,
+                                                        0, calls_per_iteration_,
+                                                        this_run_shapes);
+      CHECK(res.find(this_run_shapes) == res.end());
+      res[this_run_shapes] = tmap;
+    }
+    return std::move(res);
+  }
+
+  using tuned_timing_t = std::map<
+    shape_vect,
+    std::map<::mxnet::op::tune::TuningMode, test::perf::timing_map_t>, test::less_shapevect>;
+
+  using modesort_t = std::multimap<double, ::mxnet::op::tune::TuningMode>;
+
+  /*!
+   * \brief Check if the tuning succeeded
+   * \param mode_sort modesort_t structure produced by 'CalculateModeSort'
+   * \param closeness_factor fraction of the faster standard time (OMP vs. no-OMP) that is
+   *        accepted as "close enough"
+   * \return a pair <bool, TuningMode> consisting of true or false signifying if the test appears to
+   *         have made the correct decision, and the TuningMode which was closest in timing to
+   *         the Auto mode.
+   */
+  static bool_mode_pair CheckCorrectTuning(const modesort_t &mode_sort,
+                                           const double closeness_factor = 0.25) {
+    CHECK_EQ(mode_sort.size(), 3U);
+
+    // Determine fastest normal mode
+    ::mxnet::op::tune::TuningMode fastest_standard_mode = ::mxnet::op::tune::kAuto;
+    for (auto i = mode_sort.begin(), e = mode_sort.end(); i != e; ++i) {
+      if (i->second != ::mxnet::op::tune::kAuto) {
+        fastest_standard_mode = i->second;
+        break;
+      }
+    }
+    CHECK_NE(fastest_standard_mode, ::mxnet::op::tune::kAuto);
+
+    // We should be closest to the faster of kNeverOMP and kAlwaysOMP
+    // Take into account some variance, especially if kNeverOMP and kAlwaysOMP are close together
+    std::map<::mxnet::op::tune::TuningMode, double> mode2time;
+    for (auto i = mode_sort.begin(), e = mode_sort.end(); i != e; ++i) {
+      mode2time[i->second] = i->first;
+    }
+    const double time_auto = mode2time[::mxnet::op::tune::kAuto];
+    const double time_no_omp = mode2time[::mxnet::op::tune::kNeverOMP];
+    const double time_omp = mode2time[::mxnet::op::tune::kAlwaysOMP];
+
+    // Figure out which one we are closest to and return that to help in the analysis
+    ::mxnet::op::tune::TuningMode closest_to;
+    if (fabs(time_auto - time_no_omp) < fabs(time_auto - time_omp)) {
+      closest_to = ::mxnet::op::tune::kNeverOMP;
+    } else {
+      closest_to = ::mxnet::op::tune::kAlwaysOMP;
+    }
+
+    // If the difference between OMP and no-OMP is < closeness_factor of the faster of the two,
+    // then we just want to make sure the Auto time is close to both of these
+    const double fastest_standard_time = std::min(time_no_omp, time_omp);
+    const double allowed_difference = closeness_factor * fastest_standard_time;
+    const double mustbe_asfast = fastest_standard_time + allowed_difference;
+
+    return { time_auto <= mustbe_asfast || closest_to == fastest_standard_mode,
+             closest_to };
+  }
+
+ public:
+  /*!
+   * \brief Given timing statistics, determine if 'Auto' mode made the correct choice.
+   * \param direction Compute direction for which to check (Forward or Backward)
+   * \param verbose If true, print the statistical info
+   * \return A map of shape vectors to a pair <bool, TuningMode> consisting of true or false
+   *         signifying if the test appears to have made the correct decision, and the TuningMode
+   *         which was closest in timing to the Auto mode.
+   */
+  shape_vec_to_bool_map CalculateModeSort(const test::op::TimingDirection direction,
+                                          bool verbose = true) const {
+    if (test::csv) {
+      verbose = false;
+    }
+    shape_vec_to_bool_map results;
+    // Incredibly inefficient method of grouping the results
+    for (const auto &i : timing_) {
+      // print shapes
+      const shape_vect &shapes = i.first;
+      if (verbose || test::csv) {
+        if (!test::csv) {
+          for (size_t x = 0, n = shapes.size(); x < n; ++x) {
+            const TShape &shape = shapes[x];
+            if (x) {
+              std::cout << ", ";
+            }
+            std::cout << shape;
+          }
+          const TShape &lhs_shape = shapes[0];
+          std::cout << " lhs=" << test::pretty_num(lhs_shape.Size()) << " items";
+          std::cout << "\t(" << TimingDirectionAsString(direction) << ")" << std::endl;
+        } else {
+          std::cout << test::pretty_num(shapes[0].Size()) << ",";
+        }
+      }
+      const auto &mode2timing = i.second;
+      modesort_t mode_sort;
+      for (const auto &j : mode2timing) {
+        const ::mxnet::op::tune::TuningMode mode = j.first;
+        const test::perf::timing_map_t &tm = j.second;
+        if (tm.find(direction) != tm.end()) {
+          const test::perf::TimingInstrument::Info &info = tm.find(direction)->second;
+          double duration = info.TimeEach();
+          mode_sort.insert({duration, mode});
+          if (test::csv) {
+            std::cout << TimingDirectionAsString(direction) << ","
+                      << ::mxnet::op::tune::TuningModeToString(mode) << ","
+                      << duration << ",";
+          }
+        }
+      }
+      if (test::csv) {
+        std::cout << std::endl << std::flush;
+      }
+      if (!mode_sort.empty()) {
+        // Now we have modes sorted by performance, fastest to slowest
+        const bool_mode_pair result = CheckCorrectTuning(mode_sort);
+        if (verbose && !test::csv) {
+          for (const auto &k : mode_sort) {
+            std::cout << "\t" << ::mxnet::op::tune::TuningModeToString(k.second)
+                      << ": " << k.first << " ms";
+            if (k.second == ::mxnet::op::tune::kAuto) {
+              std::cout << " (" << ::mxnet::op::tune::TuningModeToString(result.second) << ")";
+            }
+            std::cout << std::endl;
+          }
+          std::cout << std::flush;
+          if (!result.first) {
+            std::cout << "*** WARNING: Wrong OMP state selected ***" << std::endl << std::flush;
+          }
+        }
+        CHECK(results.find(shapes) == results.end()) << "Duplicate entry for set of shapes";
+        results[shapes] = result;
+      }
+    }
+    return std::move(results);
+  }
+
+  /*!
+   * \brief Perform execution runs for a given forward (and optionally backward) operator
+   * \param kwargs Parameters for the operator
+   * \param verbose Whether to print per-mode progress information
+   * \param shapevec_vectors Data shape sets on which to run the timing passes
+   * \param op_name Name by which the operator is registered with nnvm
+   * \param backward_op_name Backward operator name
+   */
+  void TestTunedOperator(const kwargs_t &kwargs,
+                         const bool verbose,
+                         const std::vector<shape_vect>& shapevec_vectors,
+                         const char *op_name,
+                         const char *backward_op_name = COREOP_BWD_OP_NAME_VALUE_NONE) {
+    timing_.clear();
+    using namespace mxnet::op;
+    tuned_timing_t timing;
+    for (int x = 0; x < 1; ++x) {
+      for (auto mode : {::mxnet::op::tune::kNeverOMP,
+                        ::mxnet::op::tune::kAuto,
+                        ::mxnet::op::tune::kAlwaysOMP
+                        }) {
+        if (verbose && !test::csv) {
+          std::cout << std::endl << ::mxnet::op::tune::TuningModeToString(mode)
+                    << std::endl << std::flush;
+        }
+
+        mxnet::op::OperatorTune<DType>::set_tuning_mode(mode);
+        const ShapesToPerfTimingMap shapes2perfmap = RunCoreOpTimingTest(false,
+                                                                         kwargs,
+                                                                         shapevec_vectors,
+                                                                         op_name,
+                                                                         backward_op_name);
+        for (const auto &item : shapes2perfmap) {
+          const shape_vect &shapes = item.first;
+          const test::perf::timing_map_t &tm = item.second;
+          timing_[shapes][mode] = tm;
+        }
+      }
+    }
+  }
+
+  /*!
+   * \brief Calculate the success rate of the run based upon Auto being close to the faster
+   *        OMP/non-OMP attempt
+   * \param directions List of directions to use in the calculation (Forward, Backward); an empty list means all
+   * \param verbose Whether to print info
+   * \return Success rate ratio (#success / #total), in the range 0.0-1.0
+   */
+  float CalculateSuccessRate(std::vector<test::op::TimingDirection> directions = {},
+                             bool verbose = true) const {
+    size_t count = 0, success = 0;
+    if (directions.empty()) {
+      directions = {test::op::kForward, test::op::kBackward};
+    }
+    for (const test::op::TimingDirection direction : directions) {
+      typename test::tune::TuningTester<DType>::shape_vec_to_bool_map res_fwd =
+        CalculateModeSort(direction, verbose);
+      for (auto iter = res_fwd.begin(), e = res_fwd.end(); iter != e; ++iter) {
+        ++count;
+        if (iter->second.first) {
+          ++success;
+        }
+      }
+    }
+    if (count) {
+      return static_cast<float>(success) / static_cast<float>(count);
+    }
+    return 1.0f;  // nothing ventured, nothing failed (glass-is-half-full angle)
+  }
+
+  void set_calls_per_iteration(size_t calls_per_iterations) {
+    calls_per_iteration_ = calls_per_iterations;
+  }
+  size_t calls_per_iteration() const {
+    return calls_per_iteration_;
+  }
+  void set_total_iterations(size_t iterations) { total_iterations_ = iterations; }
+  size_t total_iterations() const { return total_iterations_; }
+
+ private:
+  /*! \brief Number of iterations */
+  size_t          total_iterations_ = 10;
+  /*! \brief Calls per iteration */
+  size_t          calls_per_iteration_ = 50;
+  /*! \brief Raw timing data */
+  tuned_timing_t  timing_;
+};
+
+}  // namespace tune
+}  // namespace test
+}  // namespace mxnet
+
+#endif  // TEST_TUNE_H_
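
A minimal usage sketch of TuningTester; it mirrors the pattern used in
tests/cpp/operator/tune/operator_tune_test.cc later in this patch, and the shapes and iteration
counts here are illustrative only:

    // inside a test body, with the test headers above included
    mxnet::test::tune::TuningTester<float> tester;
    tester.set_calls_per_iteration(10);
    tester.set_total_iterations(5);
    tester.TestTunedOperator({}, /* verbose */ true,
                             {{{1, 1, 28, 28}}, {{50, 3, 18, 32}}},
                             "elemwise_add", "_backward_add");
    const float success_rate = tester.CalculateSuccessRate();  // 1.0 means Auto always chose well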
diff --git a/tests/cpp/include/test_util.h b/tests/cpp/include/test_util.h
index edfa2d0..8347a8a 100644
--- a/tests/cpp/include/test_util.h
+++ b/tests/cpp/include/test_util.h
@@ -44,6 +44,7 @@ extern bool unitTestsWithCuda;
 extern bool debug_output;
 extern bool quick_test;
 extern bool performance_run;
+extern bool csv;
 
 /*! \brief Pause VTune analysis */
 struct VTunePause {
@@ -672,16 +673,20 @@ struct less_shapevect {
 };
 
 inline std::string pretty_num(uint64_t val) {
-  std::string res, s = std::to_string(val);
-  size_t ctr = 0;
-  for (int i = static_cast<int>(s.size()) - 1; i >= 0; --i, ++ctr) {
-    if (ctr && (ctr % 3) == 0) {
-      res += ",";
+  if (!test::csv) {
+    std::string res, s = std::to_string(val);
+    size_t ctr = 0;
+    for (int i = static_cast<int>(s.size()) - 1; i >= 0; --i, ++ctr) {
+      if (ctr && (ctr % 3) == 0) {
+        res += ",";
+      }
+      res.push_back(s[i]);
     }
-    res.push_back(s[i]);
+    std::reverse(res.begin(), res.end());
+    return res;
+  } else {
+    return std::to_string(val);
   }
-  std::reverse(res.begin(), res.end());
-  return res;
 }
 
 /*! \brief Change a value during the scope of this declaration */
diff --git a/tests/cpp/operator/broadcast_perf.cc b/tests/cpp/operator/broadcast_perf.cc
deleted file mode 100644
index 5edba0b..0000000
--- a/tests/cpp/operator/broadcast_perf.cc
+++ /dev/null
@@ -1,101 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- *  \file broadcast_perf.cc
- *  \brief Perf/profile run of broadcast kernel
- *  \author Chris Olivier
- */
-#include <gtest/gtest.h>
-#include <mxnet/tensor_blob.h>
-#include "../include/test_op_runner.h"
-#include "../include/test_core_op.h"
-
-using namespace mxnet;
-
-using kwargs_t = test::op::kwargs_t;
-
-/*!
- * \brief Generic bidirectional sanity test
- */
-TEST(BROADCAST_PERF, ExecuteBidirectional) {
-  test::op::BasicRunCoreOpBidirectional(false, true, {},
-                                        { {2, 3}, {2, 1} },
-                                        "broadcast_add", "_backward_broadcast_add");
-}
-
-static const std::vector<std::vector<TShape>> broadcast_shapes() {
-  std::vector<std::vector<TShape>> shapes;
-  if (test::performance_run) {
-    shapes = {
-      { {28,  28},  {28, 1} },
-      { {64,  28},  {1, 28} },
-      { {28,  28, 28},  {28, 28, 1} },
-      { {128, 128}, {1, 128} },
-      { {1024, 12, 256}, {1024, 1, 1} },
-      { {2560, 1280}, {2560, 1} }
-    };
-  } else {
-    shapes = {
-      // Non-performance dataset acts as a sanity test
-      { {28,  28},  {28, 1} },
-      { {128, 128}, {128, 1} },
-      { {28,  28, 28},  {28, 28, 1} }
-    };
-  }
-  return std::move(shapes);
-}
-
-template<typename DType = float>
-static void RunCoreOpTimingTest(const bool isGPU,
-                                const kwargs_t& op_kwargs,
-                                const char *op_name,
-                                const char *backward_op_name = "") {
-  const kwargs_t kwargs = test::op::CoreOpExecutor<DType>::ArgsWithOpName(
-    op_kwargs, op_name, backward_op_name);
-
-  // prime code and cache before the performance runs
-  test::op::CoreOperatorRunner<DType> runner;
-  runner.RunBidirectional(false, { {2, 3}, {2, 1} }, kwargs, 1);
-
-  // Do the performance runs
-  std::vector<std::vector<TShape>> shapes = broadcast_shapes();
-  const char *pu = isGPU ? "GPU" : "CPU";
-  for (const std::vector<TShape> &shape : shapes) {
-    runner.TimingTest(std::string(op_name) + " Operator " + pu, isGPU, false, kwargs,
-                      2, 10, shape);
-  }
-}
-
-/*!
- * \brief ActivationOp timing test for CPU
- */
-TEST(BROADCAST_PERF, TimingCPU) {
-  RunCoreOpTimingTest(false, {}, "broadcast_add", "_backward_broadcast_add");
-}
-
-#if MXNET_USE_CUDA == 1
-/*!
- * \brief ActivationOp timing test for GPU
- */
-TEST(BROADCAST_PERF, TimingGPU) {
-  RunCoreOpTimingTest(true, {}, "broadcast_add", "_backward_broadcast_add");
-}
-#endif  // MXNET_USE_CUDA == 1
-
diff --git a/tests/cpp/operator/tune/operator_tune_test.cc b/tests/cpp/operator/tune/operator_tune_test.cc
new file mode 100644
index 0000000..5ecb03c
--- /dev/null
+++ b/tests/cpp/operator/tune/operator_tune_test.cc
@@ -0,0 +1,173 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <gtest/gtest.h>
+#include <mxnet/tensor_blob.h>
+#include "../../src/operator/activation-inl.h"
+#include "../../src/operator/operator_tune-inl.h"
+#include "../include/test_op_runner.h"
+#include "../include/test_core_op.h"
+#include "../include/test_tune.h"
+
+using namespace mxnet;
+
+/*!
+ * \brief Print the names of all kernel operators registered for tuning
+ */
+TEST(OMP_TUNING, ShowAllTunedOps) {
+  const std::unordered_set<std::string>& op_names = op::OperatorTune<float>::TunedOperatorNames();
+  for (auto iter = op_names.begin(), e_iter = op_names.end(); iter != e_iter; ++iter) {
+    std::cout << *iter << std::endl;
+  }
+}
+
+using kwargs_t = test::op::kwargs_t;
+
+static std::vector<std::vector<TShape>> tuning_shapes() {
+  std::vector<std::vector<TShape>> shapes;
+  if (test::performance_run || test::csv) {
+    shapes = {
+      {{1,  1, 28,  28}},
+      {{1,  3, 28,  28}},
+      {{50, 1, 18,  32}},
+      {{25, 3, 64,  64}},
+      {{10, 3, 128, 128}},
+      {{20, 3, 128, 128}},
+      {{30, 3, 128, 128}},
+      {{30, 3, 256, 128}},
+    };
+  } else {
+    shapes = {
+      // Non-performance dataset acts as a sanity test
+      {{1,  1, 28, 28}},
+      {{50, 3, 18, 32}}
+    };
+  }
+  return std::move(shapes);
+}
+
+/*!
+ * \brief Generic bidirectional sanity test
+ */
+TEST(OMP_TUNING, ExecuteBidirectional) {
+  test::op::BasicRunCoreOpBidirectional(false, true, {}, {tuning_shapes()[0]},
+                                        "elemwise_add", "_backward_add");
+}
+
+/* Some test results:
+ * AWS c4.8xlarge:
+  Success rate for type float: 0.90278
+  Success rate for type double: 0.88889
+  Success rate for type mshadow::half::half_t: 0.83333
+  Success rate for type unsigned char: 0.86111
+  Success rate for type int: 0.95833
+  Success rate for type long: 0.88889
+ * desktop: 12-core (6 real CPU cores + hyperthreading)
+  Success rate for type float: 0.78125
+  Success rate for type double: 0.85417
+  Success rate for type mshadow::half::half_t: 0.84375
+  Success rate for type unsigned char: 0.80208
+  Success rate for type int: 0.94444
+  Success rate for type long: 1.00000
+ */
+
+/*!
+ * \brief Run a tuning evaluation
+ * \tparam DType Data type for which to evaluate tuning
+ */
+template<typename DType>
+static float EvaluateTune(const bool verbose = true) {
+  std::vector<std::pair<std::string, std::string>> binary_operators;
+  if (test::csv) {
+    binary_operators = {
+      {"elemwise_add", COREOP_BWD_OP_NAME_VALUE_NONE}
+    };
+  } else if (test::performance_run) {
+    binary_operators = {
+      {"relu",         ""},  // Code can figure out what the backward op is for some
+      {"sigmoid",      ""},
+      {"sqrt",         ""},
+      {"elemwise_add", "_backward_add"},
+      {"elemwise_mul", "_backward_mul"},
+      {"elemwise_div", "_backward_div"}
+    };
+  } else {
+    binary_operators = {
+      {"elemwise_add", "_backward_add"}
+    };
+  }
+  std::vector<float> rates;
+  for (size_t i = 0, n = binary_operators.size(); i < n; ++i) {
+    test::tune::TuningTester<DType> tuningTester;
+    tuningTester.set_calls_per_iteration(10);
+    tuningTester.set_total_iterations(5);
+    std::cout << "******************************" << std::endl;
+    std::cout << "Operators: " << binary_operators[i].first
+              << ", " << binary_operators[i].second
+              << " for type: " << test::type_name<DType>()
+              << std::endl;
+    std::cout << "******************************" << std::endl;
+
+    // Do the performance runs
+    std::vector<std::vector<TShape>> shapes = tuning_shapes();
+
+    tuningTester.TestTunedOperator({}, verbose, shapes,
+                                   binary_operators[i].first.c_str(),
+                                   binary_operators[i].second.c_str());
+    rates.push_back(tuningTester.CalculateSuccessRate());
+  }
+  return std::accumulate(rates.begin(), rates.end(), 0.0f) / rates.size();
+}
+
+/*! \brief Tuning evaluation test for CPU for float */
+TEST(OMP_TUNING, EvaluateTuneTestFloat) {
+  typedef float DType;
+  const float result = EvaluateTune<DType>();
+  std::cout << "Success rate for type " << test::type_name<DType>() << ": " << result << std::endl;
+}
+/*! \brief Tuning evaluation test for CPU for double */
+TEST(OMP_TUNING, EvaluateTuneTestDouble) {
+  typedef double DType;
+  const float result = EvaluateTune<DType>();
+  std::cout << "Success rate for type " << test::type_name<DType>() << ": " << result << std::endl;
+}
+/*! \brief Tuning evaluation test for CPU for float16 */
+TEST(OMP_TUNING, EvaluateTuneTestFloat16) {
+  typedef mshadow::half::half_t DType;
+  const float result = EvaluateTune<DType>();
+  std::cout << "Success rate for type " << test::type_name<DType>() << ": " << result << std::endl;
+}
+/*! \brief Tuning evaluation test for CPU for uint8_t */
+TEST(OMP_TUNING, EvaluateTuneTestInt8) {
+  typedef uint8_t DType;
+  const float result = EvaluateTune<DType>();
+  std::cout << "Success rate for type " << test::type_name<DType>() << ": " << result << std::endl;
+}
+/*! \brief Tuning evaluation test for CPU for int32_t */
+TEST(OMP_TUNING, EvaluateTuneTestInt32) {
+  typedef int32_t DType;
+  const float result = EvaluateTune<DType>();
+  std::cout << "Success rate for type " << test::type_name<DType>() << ": " << result << std::endl;
+}
+/*! \brief Tuning evaluation test for CPU for int64_t */
+TEST(OMP_TUNING, EvaluateTuneTestInt64) {
+  typedef int64_t DType;
+  const float result = EvaluateTune<DType>();
+  std::cout << "Success rate for type " << test::type_name<DType>() << ": " << result << std::endl;
+}
+
diff --git a/tests/cpp/test_main.cc b/tests/cpp/test_main.cc
index fff1ca2..a882b0b 100644
--- a/tests/cpp/test_main.cc
+++ b/tests/cpp/test_main.cc
@@ -36,7 +36,8 @@ static bool dumpCallback(const google_breakpad::MinidumpDescriptor& descriptor,
 }
 #endif
 
-namespace mxnet { namespace test {
+namespace mxnet {
+namespace test {
 bool unitTestsWithCuda = false;
 #ifdef NDEBUG
 bool debug_output = false;
@@ -45,7 +46,9 @@ bool debug_output = false;
 #endif
 bool quick_test = false;
 bool performance_run = false;
-}}
+bool csv = false;
+}  // namespace test
+}  // namespace mxnet
 
 #if MXNET_USE_CUDA
 
@@ -90,6 +93,8 @@ int main(int argc, char ** argv) {
       mxnet::test::debug_output = true;
     } else if (!strcmp(argv[x], "--perf")) {
       mxnet::test::performance_run = true;
+    } else if (!strcmp(argv[x], "--csv")) {
+      mxnet::test::csv = true;
     } else if (!strcmp(argv[x], "--quick") || !strcmp(argv[x], "-q")) {
       mxnet::test::quick_test = true;
     }

-- 
To stop receiving notification emails like this one, please contact
['"commits@mxnet.apache.org" <co...@mxnet.apache.org>'].