You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2017/11/07 23:15:35 UTC

[GitHub] cjolivier01 closed pull request #8579: Automatic OMP operator tuning based upon kernel operation workload

cjolivier01 closed pull request #8579: Automatic OMP operator tuning based upon kernel operation workload
URL: https://github.com/apache/incubator-mxnet/pull/8579
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 539515b3a2..c79d5049e0 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -17,6 +17,7 @@ mxnet_option(USE_LAPACK           "Build with lapack support" ON IF NOT MSVC)
 mxnet_option(USE_MKL_IF_AVAILABLE "Use MKL if found" ON)
 mxnet_option(USE_MKLML_MKL        "Use MKLML variant of MKL (if MKL found)" ON IF USE_MKL_IF_AVAILABLE AND UNIX AND (NOT APPLE))
 mxnet_option(USE_MKL_EXPERIMENTAL "Use experimental MKL (if MKL enabled and found)" OFF)
+mxnet_option(USE_OPERATOR_TUNING  "Enable auto-tuning of operators" ON)
 mxnet_option(USE_GPERFTOOLS       "Build with GPerfTools support (if found)" ON)
 mxnet_option(USE_JEMALLOC         "Build with Jemalloc support"   ON)
 mxnet_option(USE_PROFILER         "Build with Profiler support"   OFF)
@@ -318,6 +319,10 @@ if(USE_PLUGINS_WARPCTC)
 	list(APPEND CUDA ${PLUGINS_CUSRC})
 endif()
 
+if(USE_OPERATOR_TUNING)
+  add_definitions(-DMXNET_USE_OPERATOR_TUNING=1)
+endif()
+
 if(USE_PLUGIN_CAFFE)
   if(NOT USE_CUDA)
     set(CPU_ONLY ON)
diff --git a/Makefile b/Makefile
index 8c7ae6e6fd..8659482f26 100644
--- a/Makefile
+++ b/Makefile
@@ -131,6 +131,10 @@ ifeq ($(USE_MKL2017), 1)
 	LDFLAGS +=  -liomp5
 endif
 
+ifeq ($(USE_OPERATOR_TUNING), 1)
+	CFLAGS += -DMXNET_USE_OPERATOR_TUNING=1
+endif
+
 # verify existence of separate lapack library when using blas/openblas/atlas
 # switch off lapack support in case it can't be found
 # issue covered with this
diff --git a/docs/faq/new_op.md b/docs/faq/new_op.md
index 55b7409ca2..994a2a6f82 100644
--- a/docs/faq/new_op.md
+++ b/docs/faq/new_op.md
@@ -339,7 +339,7 @@ NNVM_REGISTER_OP(_backward_abs)
 [](const NodeAttrs& attrs){
   return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}};
 })
-.set_attr<FCompute>("FCompute<cpu>", BinaryCompute<cpu, unary_bwd<mshadow_op::sign> >);
+.set_attr<FCompute>("FCompute<cpu>", BinaryCompute<cpu, backward_grad<mshadow_op::sign> >);
 ```
 
 ### Legacy Operators
diff --git a/include/mxnet/engine.h b/include/mxnet/engine.h
index 4048d5a1a3..4c2314e176 100644
--- a/include/mxnet/engine.h
+++ b/include/mxnet/engine.h
@@ -267,11 +267,6 @@ class MXNET_API Engine {
     }
     read_vars->resize(rtop - read_vars->begin());
   }
-
-  /*! \brief Return the number of OMP threads that should be used per worker
-   * \return Number of OMP threads that should be used per worker
-   */
-  virtual int num_omp_threads_per_worker() const = 0;
 };  // class Engine
 #endif  // DMLC_USE_CXX11
 }  // namespace mxnet
diff --git a/make/config.mk b/make/config.mk
index d47d4d6931..dd5314f728 100644
--- a/make/config.mk
+++ b/make/config.mk
@@ -147,6 +147,12 @@ LIBJVM=$(JAVA_HOME)/jre/lib/amd64/server
 # sudo apt-get install -y libcurl4-openssl-dev
 USE_S3 = 0
 
+#----------------------------
+# performance settings
+#----------------------------
+# Use operator tuning
+USE_OPERATOR_TUNING = 1
+
 # Use gperftools if found
 USE_GPERFTOOLS = 1
 
diff --git a/src/engine/naive_engine.cc b/src/engine/naive_engine.cc
index 7e3554ab1c..4d63749f82 100644
--- a/src/engine/naive_engine.cc
+++ b/src/engine/naive_engine.cc
@@ -188,13 +188,6 @@ class NaiveEngine final : public Engine {
     shutdown_phase_.store(true);
   }
 
-  /*! \brief Return the number of OMP threads that should be used per worker
-   * \return Number of OMP threads that should be used per worker
-   */
-  int num_omp_threads_per_worker() const override {
-    return OpenMP::Get()->GetRecommendedOMPThreadCount();
-  }
-
  private:
   // callback to oncomplete
   static void OnComplete(Engine *engine, void *param) {
diff --git a/src/engine/openmp.cc b/src/engine/openmp.cc
index be7885ba75..ad0c5740ec 100644
--- a/src/engine/openmp.cc
+++ b/src/engine/openmp.cc
@@ -53,7 +53,7 @@ OpenMP::OpenMP()
       omp_set_num_threads(omp_thread_max_);
     } else {
       omp_thread_max_ = omp_get_max_threads();
-  }
+    }
   }
   omp_set_nested(dmlc::GetEnv("OMP_NESTED", false));
   omp_set_dynamic(dmlc::GetEnv("OMP_DYNAMIC", false));
diff --git a/src/engine/threaded_engine.h b/src/engine/threaded_engine.h
index 3cf6653712..e000a22c22 100644
--- a/src/engine/threaded_engine.h
+++ b/src/engine/threaded_engine.h
@@ -297,25 +297,6 @@ class ThreadedEngine : public Engine {
     finished_cv_.notify_all();
   }
 
-  /*! \brief Return default OMP thread count. Currently, this is whatever OMP shows as number
-   * of procs
-   * \warning Do not call this in any performance-sensitive use-case since checking the environment
-   * is slow
-   */
-  static int DefaultOMPThreadsPerWorker() {
-#ifdef _OPENMP
-    // If OMP_NUM_THREADS is set, use omp_get_max_threads(), which will be the value
-    // interpreted by the implemetation from the OMP_NUM_THREADS environment variable.
-    // Otherwise, return the number of processors, not counting hyperthreading.
-    // Test for set OMP_NUM_THREADS by checking against some nonsensical value
-    const int max_threads = dmlc::GetEnv("OMP_NUM_THREADS", INT_MIN) == INT_MIN ?
-                            omp_get_num_procs() : omp_get_max_threads();
-    return max_threads;
-#else
-    return 1;
-#endif
-  }
-
  protected:
   /*!
    * \brief Push the opr block to execution queue to be executed.
@@ -383,13 +364,6 @@ class ThreadedEngine : public Engine {
     }
   }
 
-  /*! \brief Return the number of OMP threads that should be used per worker
-   * \return Number of OMP threads that should be used per worker
-   */
-  int num_omp_threads_per_worker() const override {
-    return OpenMP::Get()->GetRecommendedOMPThreadCount();
-  }
-
  private:
   /*!
    * \brief check if thee is duplication in const_vars and mutable_vars.
diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h
index 04db326496..cc6a95d6eb 100644
--- a/src/operator/mshadow_op.h
+++ b/src/operator/mshadow_op.h
@@ -29,6 +29,7 @@
 #include "math.h"
 #include "math_functions-inl.h"
 #include "special_functions-inl.h"
+#include "./mxnet_op.h"
 
 #ifdef __CUDACC__
 #include <cuda_fp16.h>
@@ -38,6 +39,24 @@ namespace mxnet {
 namespace op {
 namespace mshadow_op {
 
+/*!
+ * \brief Close namespace mshadow_op, apply 'MXNET_TUNABLE_MSHADOW_OP_FWD_AND_BWD' to
+ *        mshadow_op::op_name (see mxnet_op.h for that macro), then reopen mshadow_op
+ *
+ * \note An entry for the operator must also be added in operator_tune.cc, which will register it
+ *       for auto-tuning and also hold its workload weight
+ */
+#define MSHADOW_OP_DECLARE_TUNABLE_FWD_AND_BWD(op_name) \
+  } MXNET_TUNABLE_MSHADOW_OP_FWD_AND_BWD(mshadow_op::op_name) namespace mshadow_op {  // NOLINT(*)
+/*!
+ * \brief Close namespace mshadow_op, apply 'MXNET_TUNABLE_MSHADOW_OP_BACKWARD' to
+ *        mshadow_op::op_name (see mxnet_op.h for that macro), then reopen mshadow_op
+ *
+ * \note An entry for the operator must also be added in operator_tune.cc, which will register it
+ *       for auto-tuning and also hold its workload weight
+ */
+#define MSHADOW_OP_DECLARE_TUNABLE_BACKWARD(op_name) \
+  } MXNET_TUNABLE_MSHADOW_OP_BACKWARD(mshadow_op::op_name) namespace mshadow_op {  // NOLINT(*)
 #ifdef __CUDA_ARCH__
 __constant__ const float PI = 3.14159265358979323846;
 #else
@@ -48,36 +67,41 @@ using std::enable_if;
 using std::is_unsigned;
 
 #define MXNET_UNARY_MATH_OP(name, expr) \
-struct name { \
-  template<typename DType> \
-  MSHADOW_XINLINE static DType Map(DType a) { \
-    return DType(expr); \
-  } \
-}
+  struct name { \
+    template<typename DType> \
+    MSHADOW_XINLINE static DType Map(DType a) { \
+      return DType(expr); \
+    } \
+  }; \
+  MSHADOW_OP_DECLARE_TUNABLE_FWD_AND_BWD(name)  // declare 'name' tunable (fwd + bwd)
+

 #define MXNET_UNARY_MATH_OP_NC(name, expr) \
-struct name { \
-  template<typename DType> \
-  MSHADOW_XINLINE static DType Map(DType a) { \
-    return (expr); \
-  } \
-}
+  struct name { \
+    template<typename DType> \
+    MSHADOW_XINLINE static DType Map(DType a) { \
+      return (expr); \
+    } \
+  }; \
+  MSHADOW_OP_DECLARE_TUNABLE_FWD_AND_BWD(name)  // declare 'name' tunable (fwd + bwd)

 #define MXNET_BINARY_MATH_OP(name, expr) \
-struct name { \
-  template<typename DType> \
-  MSHADOW_XINLINE static DType Map(DType a, DType b) { \
-    return DType(expr); \
-  } \
-}
+  struct name { \
+    template<typename DType> \
+    MSHADOW_XINLINE static DType Map(DType a, DType b) { \
+      return DType(expr); \
+    } \
+  }; \
+  MSHADOW_OP_DECLARE_TUNABLE_FWD_AND_BWD(name)  // declare 'name' tunable (fwd + bwd)

 #define MXNET_BINARY_MATH_OP_NC(name, expr) \
-struct name { \
-  template<typename DType> \
-  MSHADOW_XINLINE static DType Map(DType a, DType b) { \
-    return (expr); \
-  } \
-}
+  struct name { \
+    template<typename DType> \
+    MSHADOW_XINLINE static DType Map(DType a, DType b) { \
+      return (expr); \
+    } \
+  }; \
+  MSHADOW_OP_DECLARE_TUNABLE_FWD_AND_BWD(name)  // declare 'name' tunable (fwd + bwd)
 
 #define MXNET_SIMPLE_UNARY_MATH_OP(name) MXNET_UNARY_MATH_OP(name, math::name(a))
 
@@ -133,6 +157,7 @@ struct softrelu {
     }
   }
 };
+MSHADOW_OP_DECLARE_TUNABLE_FWD_AND_BWD(softrelu)
 
 MXNET_UNARY_MATH_OP(softrelu_grad, -math::expm1(-a));
 
@@ -153,6 +178,7 @@ struct log10_grad {
     return DType(0.4342944819f / static_cast<float>(a));
   }
 };
+MSHADOW_OP_DECLARE_TUNABLE_BACKWARD(log10_grad)
 
 template<>
 MSHADOW_XINLINE double log10_grad::Map<double>(double a) {
@@ -168,6 +194,7 @@ struct log2_grad {
     return DType(1.442695041f / static_cast<float>(a));
   }
 };
+MSHADOW_OP_DECLARE_TUNABLE_BACKWARD(log2_grad)
 
 template<>
 MSHADOW_XINLINE double log2_grad::Map<double>(double a) {
@@ -262,6 +289,7 @@ struct sign {
     return DType(0);
   }
 };
+MSHADOW_OP_DECLARE_TUNABLE_FWD_AND_BWD(sign)
 
 MXNET_UNARY_MATH_OP_NC(sign_grad, DType(0));
 
@@ -332,6 +360,7 @@ struct rint {
     return DType((af - floor) <= (ceil - af) ? floor : ceil);
   }
 };
+MSHADOW_OP_DECLARE_TUNABLE_FWD_AND_BWD(rint)
 
 /*! \brief used to round number to integer nearest to 0 */
 struct fix {
@@ -342,6 +371,7 @@ struct fix {
     return DType((floor > 0 ? floor : -floor) < (ceil > 0 ? ceil : -ceil) ? floor : ceil);
   }
 };
+MSHADOW_OP_DECLARE_TUNABLE_FWD_AND_BWD(fix)
 
 /*! \brief used for generate gradient of MAE loss*/
 MXNET_BINARY_MATH_OP_NC(minus_sign, a - b > DType(0) ? DType(1) : -DType(1));
@@ -404,6 +434,7 @@ struct mod {
     }
   }
 };
+MSHADOW_OP_DECLARE_TUNABLE_FWD_AND_BWD(mod)
 
 template<>
 MSHADOW_XINLINE mshadow::half::half2_t mod::Map<mshadow::half::half2_t>
@@ -418,6 +449,8 @@ struct mod_grad {
     return DType(0);
   }
 };
+MSHADOW_OP_DECLARE_TUNABLE_BACKWARD(mod_grad)
+
 template<>
 MSHADOW_XINLINE double mod_grad::Map<double>(double a, double b) {
   return 1.0;
@@ -453,6 +486,8 @@ struct mod_rgrad {
     return DType(0);
   }
 };
+MSHADOW_OP_DECLARE_TUNABLE_BACKWARD(mod_rgrad)
+
 template<>
 MSHADOW_XINLINE double mod_rgrad::Map<double>(double a, double b) {
   return -::floor(a/b);
@@ -516,6 +551,7 @@ struct rmod {
     }
   }
 };
+MSHADOW_OP_DECLARE_TUNABLE_FWD_AND_BWD(rmod)
 
 template<>
 MSHADOW_XINLINE mshadow::half::half2_t rmod::Map<mshadow::half::half2_t>
@@ -530,6 +566,8 @@ struct rmod_grad {
     return DType(0);
   }
 };
+MSHADOW_OP_DECLARE_TUNABLE_BACKWARD(rmod_grad)
+
 template<>
 MSHADOW_XINLINE double rmod_grad::Map<double>(double a, double b) {
   return -::floor(b/a);
@@ -571,6 +609,7 @@ struct clip {
     }
   }
 };
+MSHADOW_OP_DECLARE_TUNABLE_FWD_AND_BWD(clip)
 
 /***** gamma ******/
 
@@ -584,6 +623,7 @@ struct gamma_grad {
     return DType(math::tgamma(af) * special_functions::cephes::psi<float>(af));
   }
 };
+MSHADOW_OP_DECLARE_TUNABLE_BACKWARD(gamma_grad)
 
 template<>
 MSHADOW_XINLINE double gamma_grad::Map<double>(double a) {
@@ -601,6 +641,7 @@ struct gammaln_grad {
     return DType(special_functions::cephes::psi<float>(a));
   }
 };
+MSHADOW_OP_DECLARE_TUNABLE_BACKWARD(gammaln_grad)
 
 template<>
 MSHADOW_XINLINE double gammaln_grad::Map<double>(double a) {
@@ -632,6 +673,7 @@ struct smooth_l1_loss {
     }
   }
 };  // struct smooth_l1_loss
+MSHADOW_OP_DECLARE_TUNABLE_FWD_AND_BWD(smooth_l1_loss)
 
 /* The derivative of smooth l1 loss is
  * f'(x) = sigma^2 * x, |x| < 1 / sigma^2
@@ -653,6 +695,7 @@ struct smooth_l1_gradient {
     }
   }
 };  // struct smooth_l1_derivative
+MSHADOW_OP_DECLARE_TUNABLE_BACKWARD(smooth_l1_gradient)
 
 /*! \brief product reducer */
 struct product {
@@ -754,6 +797,7 @@ struct nansum_grad {
     return isnan_typed::IsNan(a) ? DType(0) : DType(1);
   }
 };
+MSHADOW_OP_DECLARE_TUNABLE_BACKWARD(nansum_grad)
 
 /*! \brief product reducer that ignores NaN values in the input */
 struct nanprod {
@@ -790,7 +834,7 @@ struct nanprod_grad {
     return isnan_typed::IsNan(a) ? DType(0) : b / a;
   }
 };
-
+MSHADOW_OP_DECLARE_TUNABLE_BACKWARD(nanprod_grad)
 }  // namespace mshadow_op
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/mxnet_op.h b/src/operator/mxnet_op.h
index 06e2393524..4acbc49a11 100644
--- a/src/operator/mxnet_op.h
+++ b/src/operator/mxnet_op.h
@@ -25,11 +25,14 @@
 #ifndef MXNET_OPERATOR_MXNET_OP_H_
 #define MXNET_OPERATOR_MXNET_OP_H_
 
+#include <cxxabi.h>
 #include <dmlc/omp.h>
 #include <mxnet/base.h>
-#include <mxnet/engine.h>
 #include <mxnet/op_attr_types.h>
 #include <algorithm>
+#include <string>
+#include "./operator_tune.h"
+#include "../engine/openmp.h"
 
 #ifdef __CUDACC__
 #include "../common/cuda_utils.h"
@@ -288,55 +291,84 @@ struct op_with_req {
   }
 };
 
-/*!
- * \brief Set to immediate scalar value kernel
- * \tparam val Scalar immediate
- */
-template<int val>
-struct set_to_int {
-  // mxnet_op version (when used directly with Kernel<>::Launch()) */
-  template<typename DType>
-  MSHADOW_XINLINE static void Map(int i, DType* out) {
-    out[i] = DType(val);
-  }
-  // mshadow_op version (when used with op_with_req<>)
-  MSHADOW_XINLINE static int Map() {
-    return val;
-  }
-};
 
-/*! \brief Special-case kernel shortcut for setting to zero */
-using set_zero = set_to_int<0>;
+/*! \brief Per-(Operation, DType) tuning data: Operation plus workload_ and UseOMP() */
+template<typename Operation, typename DType>
+struct tuned_op : public Operation {
+  static size_t workload_;       // nanos per operation * Tuner's WORKLOAD_COUNT
+  // Decide whether OMP pays off for N iterations on thread_count threads (nonzero => use OMP)
+  // TODO(cjolivier01): For more complex kernels, add a shape parameter version (diff LaunchEx)
+  static int UseOMP(size_t N, size_t thread_count);
+};
 
 template<typename OP, typename xpu>
 struct Kernel;
 
-
 template<typename OP>
 struct Kernel<OP, cpu> {
+  /*! \brief Launch CPU kernel: OMP-parallel whenever >= 2 threads are recommended */
   template<typename ...Args>
-  inline static void Launch(mshadow::Stream<cpu> *s, const int N, Args... args) {
+  inline static void Launch(mshadow::Stream<cpu> *, const int N, Args... args) {
 #ifdef _OPENMP
-    const int omp_cores = Engine::Get()->num_omp_threads_per_worker();
-    if (omp_cores <= 1) {
+    const int omp_threads = engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
+    if (omp_threads < 2) {
       // Zero means not to use OMP, but don't interfere with external OMP behavior
       for (int i = 0; i < N; ++i) {
         OP::Map(i, args...);
       }
     } else {
-      #pragma omp parallel for num_threads(omp_cores)
+      #pragma omp parallel for num_threads(omp_threads)
       for (int i = 0; i < N; ++i) {
         OP::Map(i, args...);
       }
     }
 #else
     for (int i = 0; i < N; ++i) {
+      OP::Map(i, args...);
+    }
+#endif
+  }
+
+  /*! \brief Launch CPU kernel, using OMP only when tuned_op<BasicOperation, DType> predicts a
+   * gain. 'dest' is passed to OP::Map() right after the index. When using this for a new
+   * kernel op, add declaration and tuning objects to operator_tune.cc
+   */
+  template<typename BasicOperation, typename DType, typename ...Args>
+  static void LaunchEx(mshadow::Stream<cpu> *, const int N, DType *dest, Args... args) {
+#ifdef _OPENMP
+    const int omp_threads = engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
+    if (omp_threads < 2 || !tuned_op<BasicOperation, DType>::UseOMP(N, omp_threads)) {
+      // Too few threads, or tuning says OMP overhead outweighs the work: run serially
+      for (int i = 0; i < N; ++i) {
+        OP::Map(i, dest, args...);
+      }
+    } else {
+      #pragma omp parallel for num_threads(omp_threads)
+      for (int i = 0; i < N; ++i) {
+        OP::Map(i, dest, args...);
+      }
+    }
+#else
+    for (int i = 0; i < N; ++i) {
-        OP::Map(i, args...);
+      OP::Map(i, dest, args...);
     }
 #endif
   }
-};

+  /*! \brief Launch mshadow_op-type op (tuned via OP::Operation, a DType Map(...) functor) */
+  template<typename ...Args>
+  MSHADOW_CINLINE static void LaunchMShadowOpEx(mshadow::Stream<cpu> *s,
+                                                const int N, Args... args) {
+    Kernel<OP, cpu>::template LaunchEx<typename OP::Operation>(s, N, args...);
+  }
+
+  /*! \brief Launch mxnet_op-type op (OP::Map(int i, DType *out, ...)), tuned as OP itself */
+  template<typename ...Args>
+  MSHADOW_CINLINE static void LaunchMXNetOpEx(mshadow::Stream<cpu> *s,
+                                              const int N, Args... args) {
+    Kernel<OP, cpu>::template LaunchEx<OP>(s, N, args...);
+  }
+};
 
 #ifdef __CUDACC__
 template<typename OP, typename ...Args>
@@ -348,6 +380,7 @@ __global__ void mxnet_generic_kernel(int N, Args... args) {
 
 template<typename OP>
 struct Kernel<OP, gpu> {
+  /*! \brief Launch GPU kernel */
   template<typename ...Args>
   inline static void Launch(mshadow::Stream<gpu> *s, int N, Args... args) {
     using namespace mshadow::cuda;
@@ -356,10 +389,53 @@ struct Kernel<OP, gpu> {
       <<<ngrid, kBaseThreadNum, 0, mshadow::Stream<gpu>::GetStream(s)>>>(
         N, args...);
   }
+  /*! \brief GPU LaunchEx: no tuning, forwards straight to Launch (no BasicOperation tparam) */
+  template<typename ...Args>  // NOTE(review): a CPU-style LaunchEx<Op>(...) call won't match
+  MSHADOW_CINLINE static void LaunchEx(mshadow::Stream<gpu> *s, const int N, Args... args) {
+    Launch(s, N, args...);
+  }
 };
 #endif  // __CUDACC__
 
+/*!
+ * \brief Set to immediate scalar value kernel
+ * \tparam val Scalar immediate
+ */
+template<int val>
+struct set_to_int {
+  // mxnet_op version (when used directly with Kernel<>::Launch()): out[i] = val
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(int i, DType *out) {
+    out[i] = DType(val);
+  }
+  // mshadow_op version (when used with op_with_req<>): returns val
+  MSHADOW_XINLINE static int Map() {
+    return val;
+  }
+};
+
+/*!
+ * \brief Special-case kernel shortcut for setting to zero and one
+ */
+using set_zero = set_to_int<0>;
+using set_one  = set_to_int<1>;
+_MXNET_TUNABLE_MXNET_OP_FWD(set_zero);  // forward-only tunable mxnet_op kernels
+_MXNET_TUNABLE_MXNET_OP_FWD(set_one);
+
 }  // namespace mxnet_op
+
+
+/*!
+ * \brief Forward+backward tuning declarations for the simple mshadow::op ops (<mshadow/base.h>)
+ */
+MXNET_TUNABLE_MSHADOW_OP_FWD_AND_BWD(mshadow::op::identity)
+MXNET_TUNABLE_MSHADOW_OP_FWD_AND_BWD(mshadow::op::plus)
+MXNET_TUNABLE_MSHADOW_OP_FWD_AND_BWD(mshadow::op::minus)
+MXNET_TUNABLE_MSHADOW_OP_FWD_AND_BWD(mshadow::op::mul)
+MXNET_TUNABLE_MSHADOW_OP_FWD_AND_BWD(mshadow::op::div)
+MXNET_TUNABLE_MSHADOW_OP_FWD_AND_BWD(mshadow::op::right)
+
 }  // namespace op
 }  // namespace mxnet
+
 #endif  // MXNET_OPERATOR_MXNET_OP_H_
diff --git a/src/operator/operator_tune-inl.h b/src/operator/operator_tune-inl.h
new file mode 100644
index 0000000000..4a69b0b926
--- /dev/null
+++ b/src/operator/operator_tune-inl.h
@@ -0,0 +1,873 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef MXNET_OPERATOR_OPERATOR_TUNE_INL_H_
+#define MXNET_OPERATOR_OPERATOR_TUNE_INL_H_
+
+#include <dmlc/base.h>
+#include <dmlc/logging.h>
+#include <mshadow/base.h>
+#include <cxxabi.h>
+#include <algorithm>
+#include <atomic>
+#include <chrono>
+#include <cmath>
+#include <cstdint>
+#include <cstdlib>
+#include <iostream>
+#include <limits>
+#include <list>
+#include <memory>
+#include <random>
+#include <string>
+#include <thread>
+#include <typeinfo>
+#include <unordered_set>
+#include <vector>
+#include "./mxnet_op.h"
+
+namespace mxnet {
+namespace op {
+
+// Portable "never inline this function" attribute: MSVC spelling vs GCC/Clang spelling.
+// Defined only if no earlier header already provided MXNET_NO_INLINE.
+#ifndef MXNET_NO_INLINE
+#ifdef _MSC_VER
+#define MXNET_NO_INLINE __declspec(noinline)
+#else
+#define MXNET_NO_INLINE __attribute__((noinline))
+#endif
+#endif  // MXNET_NO_INLINE
+
+// Shift amounts for the timing-loop counts. GetOMPLoopOverhead() divides its measured total by
+// shifting right by OUTSIDE_COUNT_SHIFT; OUTSIDE_COUNT / WORKLOAD_COUNT are presumably
+// 1 << these shifts (they are defined outside this view -- TODO confirm).
+#define OUTSIDE_COUNT_SHIFT    9
+#define WORKLOAD_COUNT_SHIFT  11
+
+namespace tune {
+
+/*!
+ * \brief Tuning mode for registered kernel operators
+ */
+enum TuningMode {
+  Auto,         // Based upon tuning data, choose whether to use OMP for kernel CPU Launch() loops
+  NeverOMP,     // Don't use OMP for parallelism (legacy behavior for GPU builds)
+  AlwaysOMP     // Always use OMP for parallelism (legacy behavior for CPU builds)
+};
+
+/*!
+ * \brief Convert TuningMode value to a string representation
+ * \param tm  Scalar TuningMode value
+ * \return Pointer to a string literal naming the TuningMode value. An unknown value fails the
+ *         CHECK below; the trailing return only exists to satisfy the compiler.
+ */
+inline const char *TuningModeToString(const TuningMode tm) {
+  switch (tm) {
+    case Auto:
+      return "Auto";
+    case NeverOMP:
+      return "NeverOMP";
+    case AlwaysOMP:
+      return "AlwaysOMP";
+    default:
+      CHECK(false) << "Unknown TuningMode type: " << static_cast<int>(tm);
+      return "<unknown>";
+  }
+}
+}  // namespace tune
+
+/*!
+ * \brief Shared data for all data types being tuned, acts as a base class for the higher-level
+ *        templated tuning classes. Being a non-template base, these statics are shared by every
+ *        OperatorTune<DType> instantiation.
+ */
+class OperatorTuneBase {
+ protected:
+  /*! \brief Nanosecond count type; an unsigned int (typically 32-bit, i.e. at most ~4.29 s) */
+  typedef unsigned int duration_t;
+  /*! \brief Have calculated omp_overhead_ yet? */
+  static std::atomic<bool> calculated_;
+  /*! \brief Time in nanoseconds for OMP overhead */
+  static duration_t omp_overhead_;
+  /*! \brief Output insertable (into code) instantiation+default-value macros
+   *         (read from MXNET_OUTPUT_TUNING_DATA) */
+  static bool output_tuning_data_;
+  /*! \brief Print debug/trace output for tuning info (read from MXNET_VERBOSE_TUNING_INFO) */
+  static bool verbose_tuning_info_;
+  /*! \brief Tuning scale factor (MXNET_TUNING_WEIGHT_SCALE); when > 0.01 it replaces the
+   *         per-DType factor used in GetOMPLoopOverhead() */
+  static double tuning_weight_scale_;
+  /*! \brief Enable auto-tune (ie retune at startup) rather than use stored data
+   *         (read from MXNET_ENABLE_OPERATOR_AUTOTUNE) */
+  static bool enable_auto_tune_;
+};
+
+/*!
+ * \brief Engine to tune kernel operations
+ * \tparam DType Data type to be used when tuning the kernel operations
+ * \remarks The basic concept here is that we time how long a trivial loop takes with and without
+ * OMP, subtracting the non-OMP run from the OMP run, which gives us the time
+ * that the OMP overhead takes.  Times were found to be relatively invariant with
+ * regard to the number of threads/cores on a given machine.
+ * Secondly, supplied operators are run and timed (for each data type) in order to determine
+ * their individual time cost.
+ *
+ * Knowing the following items, we can determine how long the OMP and non-OMP run
+ * is expected to take:
+ *  1) OMP overhead time
+ *  2) Number of iterations required
+ *  3) Number of threads to be used if we choose the OMP method
+ *  4) The data type
+ *
+ * Therefore, at Kernel::Launch() time, we can estimate whether it is faster to use OMP or not
+ * for the given kernel operator.
+ *
+ * Results and efficiency of the tuning is tested in the gtest OMP_TUNING test suite
+ */
+template<typename DType>
+class OperatorTune : public OperatorTuneBase {
+ public:
+  typedef std::chrono::high_resolution_clock::time_point Tick;
+
+  /*! \brief Constructing a tuner immediately runs TuneAll() for this DType */
+  OperatorTune() { TuneAll(); }
+
+  /*!
+   * \brief One-time setup for OperatorTune<DType>; calls after the first one do nothing
+   * \remarks The first call fills data_set_ with 0x100 random non-zero DType values and reads
+   *          the tuning environment variables. Values are drawn uniformly from [-1, 1) with
+   *          |v| >= 1e-5 for non-integral types, and from [-128, 127] before conversion to
+   *          DType for integral types. The generator is seeded from std::random_device, so the
+   *          data differs between runs. Once across all DTypes, guarded by calculated_, it also
+   *          measures omp_overhead_ and parses the MXNET_USE_OPERATOR_TUNING config string.
+   * \return Always true
+   */
+  static bool Initialize() {
+    if (!initialized_) {
+      initialized_ = true;
+      // Generate some random data for calling the operator kernels
+      data_set_.reserve(0x100);
+      std::random_device rd;
+      std::mt19937 gen(rd());
+      if (!std::is_integral<DType>::value) {
+        std::uniform_real_distribution<> dis(-1, 1);
+        for (int n = 0; n < 0x100; ++n) {
+          const auto val = static_cast<DType>(dis(gen));
+          // If too close to zero, try again
+          if (fabs(val) < 1e-5) {
+            --n;
+            continue;
+          }
+          data_set_.emplace_back(val);
+        }
+      } else {
+        std::uniform_int_distribution<> dis(-128, 127);
+        for (int n = 0; n < 0x100; ++n) {
+          const auto val = static_cast<DType>(dis(gen));
+          // If zero, try again
+          if (!val) {
+            --n;
+            continue;
+          }
+          data_set_.emplace_back(val);
+        }
+      }
+      // Use this environment variable to generate new tuning statistics
+      output_tuning_data_ = dmlc::GetEnv("MXNET_OUTPUT_TUNING_DATA", false);
+      // Verbose tuning logs; read independently of MXNET_OUTPUT_TUNING_DATA (default off)
+      verbose_tuning_info_ = dmlc::GetEnv("MXNET_VERBOSE_TUNING_INFO", false);
+
+      // A value > 0.01 overrides the per-DType factor in GetOMPLoopOverhead(); 0.0 = defaults
+      tuning_weight_scale_ = dmlc::GetEnv("MXNET_TUNING_WEIGHT_SCALE", 0.0);
+
+      // This isn't actually supposed to be multithreaded init, but just to be sure the change is
+      // seen everywhere, using atomic bool.
+      if (!OperatorTuneBase::calculated_.load()) {
+        // Not especially concerned with a race condition, since this should
+        // run when only one thread is active (static init), just don't cache this variable
+        OperatorTuneBase::calculated_.store(true);
+        OperatorTuneBase::enable_auto_tune_ = dmlc::GetEnv("MXNET_ENABLE_OPERATOR_AUTOTUNE", true);
+        OperatorTuneBase::omp_overhead_ = GetOMPLoopOverhead();
+        std::string config = dmlc::GetEnv("MXNET_USE_OPERATOR_TUNING", std::string());
+        ParseEnablerConfig(config);
+      }
+
+      if (verbose_tuning_info_) {
+        LOG(INFO) << "OMP overhead: " << omp_overhead_ << " nanoseconds";
+      }
+    }
+    return true;
+  }
+
+  /*!
+   * \brief Nanoseconds elapsed from 'since' up to the current high_resolution_clock time
+   * \param since Starting tick of the interval
+   * \return Elapsed time in nanoseconds, as computed by the two-tick overload
+   */
+  static MSHADOW_CINLINE duration_t GetDurationInNanoseconds(const Tick &since) {
+    const Tick now = std::chrono::high_resolution_clock::now();
+    return GetDurationInNanoseconds(since, now);
+  }
+
+  /*!
+   * \brief Queue an operator's tuning function so that TuneAll() runs it
+   * \tparam OP Operator whose tuning function is being queued; its demangled name is recorded
+   * \param tune_func Tuning function to queue; a null pointer is rejected
+   * \return true if tune_func was queued, false if it was null. When MXNET_USE_OPERATOR_TUNING
+   *         is not defined nothing is queued and true is returned unconditionally.
+   */
+  template<typename OP>
+  static bool ScheduleTune(void (*tune_func)()) {
+#ifdef MXNET_USE_OPERATOR_TUNING
+    if (tune_func == nullptr) {
+      return false;
+    }
+    GetTuningList()->push_back(tune_func);
+    operator_names_.insert(type_name<OP>());
+    return true;
+#else
+    return true;
+#endif
+  }
+
+  /*!
+   * \brief Was the kernel operator registered for tuning (via ScheduleTune())?
+   * \tparam OP kernel operator type
+   * \return true if OP's demangled type name is in the registered-name set; registration
+   *         alone does not guarantee that tuning has already run
+   */
+  template<typename OP>
+  static bool IsTuned() {
+    const std::string op_name = demangle(typeid(OP).name());
+    return operator_names_.count(op_name) > 0;
+  }
+
+  /*!
+   * \brief Tune all registered kernel operators that haven't already been tuned
+   * \note Calls Initialize() first. When auto-tuning is enabled, every function queued through
+   *       ScheduleTune() on this OperatorTune<DType> is run once, and the queue is cleared
+   *       afterwards. When auto-tuning is disabled, the queue is left untouched.
+   */
+  static void TuneAll() {
+    Initialize();
+    if (OperatorTuneBase::enable_auto_tune_) {
+      std::list<void (*)()> *tl = GetTuningList();
+      const size_t size_save = tl->size();  // For checking if anything asynchronous is
+      // adding or removing items, which is forbidden
+      if (output_tuning_data_ && !tl->empty()) {
+        // Only emit this once, use the most common case, 'float32'.
+        // Printed as a C++ statement so it can be pasted into source as stored tuning data.
+        if (mshadow::DataType<DType>::kFlag == mshadow::kFloat32) {
+          std::cout << "OperatorTuneBase::duration_t "
+                    << "OperatorTuneBase::omp_overhead_ = " << omp_overhead_
+                    << ";" << std::endl << std::flush;
+        }
+      }
+      const Tick start = std::chrono::high_resolution_clock::now();
+      for (auto i : *tl) {
+        (*i)();
+      }
+      if (verbose_tuning_info_) {
+        const duration_t duration = OperatorTune::GetDurationInNanoseconds(start);
+        LOG(INFO) << "Op Tuning  for " << type_name<DType>()
+                  << " took " << (duration / 1000000) << " ms";
+      }
+      CHECK_EQ(size_save, tl->size()) << "Tuning list size should not have changed while tuning";
+      tl->clear();
+    }
+  }
+
+  /*!
+   * \brief Return set of operator names that were registered to be tuned. Does not imply
+   *        that the operator has been tuned.
+   * \return Reference to the static set of demangled operator/kernel type names, as inserted
+   *         by ScheduleTune()
+   */
+  static const std::unordered_set<std::string>& TunedOperatorNames() {
+    return operator_names_;
+  }
+
+  /*!
+   * \brief Set tuning mode
+   * \param tuning_mode The tune::TuningMode tuning mode value to set
+   */
+  static MSHADOW_CINLINE void set_tuning_mode(const tune::TuningMode tuning_mode) {
+    // const_cast strips the cv-qualifiers of tuning_mode_ (declared outside this view, presumably
+    // volatile) to avoid the "assigning non-volatile to volatile" warning.
+    // NOTE(review): writing a volatile object through a non-volatile lvalue is undefined
+    // behavior in C++; confirm tuning_mode_'s declaration.
+    const_cast<tune::TuningMode&>(tuning_mode_) = tuning_mode;
+  }
+
+  /*!
+   * \brief Get the current tuning mode
+   * \return tune::TuningMode value for the current tuning mode
+   */
+  static MSHADOW_CINLINE tune::TuningMode tuning_mode() {
+    // Reads tuning_mode_ through a const_cast reference, which drops any volatile qualifier
+    // from this particular access (see the note in set_tuning_mode)
+    return const_cast<tune::TuningMode&>(tuning_mode_);
+  }
+
+ protected:
+  /*!
+   * \brief Get the list of tuning function calls for the operators
+   * \return Pointer to the list of tuning functions appended by ScheduleTune() and run (then
+   *         cleared) by TuneAll(); defined out of line, outside this view
+   */
+  static std::list<void (*)()> *GetTuningList();
+
+  /*!
+   * \brief Get duration in nanoseconds, saturated to the range of duration_t
+   * \param t1 Start time tick
+   * \param t2 End time tick
+   * \return Nanoseconds from t1 to t2. Returns 0 if t2 is not after t1, and the largest
+   *         duration_t value if the span is too long to represent, instead of a wrapped value.
+   */
+  static MSHADOW_CINLINE duration_t GetDurationInNanoseconds(const Tick &t1, const Tick &t2) {
+    const std::chrono::nanoseconds::rep ns =
+      std::chrono::duration_cast<std::chrono::nanoseconds>(t2 - t1).count();
+    // duration_t is an unsigned int. high_resolution_clock is not guaranteed to be steady, so the
+    // span may be negative, and spans over ~4.29 s (32-bit) do not fit; both would wrap silently.
+    if (ns <= 0) {
+      return 0;
+    }
+    const duration_t max_duration = std::numeric_limits<duration_t>::max();
+    return ns >= static_cast<std::chrono::nanoseconds::rep>(max_duration)
+           ? max_duration
+           : static_cast<duration_t>(ns);
+  }
+
+  /*!
+   * \brief Turn a compiler-mangled name (e.g. from typeid(T).name()) into readable form
+   * \param name Mangled name as produced by the C++ ABI
+   * \return The demangled name, or 'name' itself unchanged if abi::__cxa_demangle fails
+   */
+  static inline std::string demangle(const char *name) {
+    int status = -4;  // non-zero placeholder; __cxa_demangle writes the real status here
+    char *const readable = abi::__cxa_demangle(name, nullptr, nullptr, &status);
+    // __cxa_demangle hands back a malloc()'d buffer (or nullptr) that must be free()'d
+    const std::unique_ptr<char, void (*)(void *)> owner{readable, &std::free};
+    if (status != 0) {
+      return name;
+    }
+    return readable;
+  }
+
+  /*!
+   * \brief Readable (demangled) name of a type
+   * \tparam T Type to name
+   * \return std::string holding T's demangled type name
+   */
+  template<typename T>
+  static inline std::string type_name() {
+    const char *const mangled = typeid(T).name();
+    return demangle(mangled);
+  }
+
+  /*! \brief Measure the overhead of a trivial OMP loop run on omp_thread_count threads
+   * \param omp_thread_count - Number of OMP threads to use in the timing test (must be > 1)
+   * \returns Duration in nanoseconds for the OMP overhead (time to initiate and close the
+   *          OMP session): (OMP run - serial run) >> OUTSIDE_COUNT_SHIFT, clamped at 0 when
+   *          the OMP run measured no slower than the serial baseline
+   * \note NOTE(review): Initialize() stores GetOMPLoopOverhead() (the zero-arg overload, not
+   *       visible here) once into the shared OperatorTuneBase::omp_overhead_, guarded by
+   *       calculated_. The DType-specific 'factor' below therefore seems to matter only for
+   *       whichever DType initializes first -- confirm that this is intended.
+   */
+  static duration_t GetOMPLoopOverhead(const size_t omp_thread_count) {
+    CHECK_GT(omp_thread_count, 1);  // Don't try to use OMP for one thread
+    int wl_count = WORKLOAD_COUNT;
+
+    Tick start = std::chrono::high_resolution_clock::now();
+    // Serial baseline: two nested loops, mirroring the OMP timing below (outer loop outside OMP)
+    for (size_t i = 0; i < OUTSIDE_COUNT; ++i) {
+      for (int x = 0; x < wl_count; ++x) {
+        // trivial operation
+        volatile_int_ += x;
+      }
+    }
+    const duration_t no_omp_duration = GetDurationInNanoseconds(start);
+
+    // Scale OMP iterations by type calculation complexity
+    double factor;
+
+    // if tuning_weight_scale_ is a number that looks valid, use it as the factor
+    if (tuning_weight_scale_ > 0.01) {
+      factor = tuning_weight_scale_;
+    } else {
+      // These are empirically-determined constants found by balancing between
+      // a desktop (8 & 12 cpu's) and large cloud instances (32 & 64 cpu's)
+      switch (mshadow::DataType<DType>::kFlag) {
+        case mshadow::kUint8:
+        case mshadow::kInt8:
+          factor = 8.5;
+          break;
+        case mshadow::kInt32:
+          factor = 4.5;
+          break;
+        case mshadow::kInt64:
+          factor = 2;
+          break;
+        case mshadow::kFloat64:
+          factor = 1.25;
+          break;
+        case mshadow::kFloat32:
+        default:
+          factor = 1.0;
+          break;
+      }
+    }
+
+    // Give each thread roughly factor * WORKLOAD_COUNT iterations
+    wl_count = static_cast<int>(factor * WORKLOAD_COUNT * omp_thread_count);
+    start = std::chrono::high_resolution_clock::now();
+    for (size_t i = 0; i < OUTSIDE_COUNT; ++i) {
+      #pragma omp parallel for num_threads(omp_thread_count)
+      for (int x = 0; x < wl_count; ++x) {
+        // trivial operation. NOTE(review): if volatile_int_ is one shared variable (declaration
+        // not visible), this unsynchronized += from several threads is a data race; only the
+        // timing matters here, but consider a per-thread accumulator.
+        volatile_int_ += x;
+      }
+    }
+    const duration_t omp_total_duration = GetDurationInNanoseconds(start);
+    // duration_t is unsigned: clamp at zero so a noisy, faster OMP run cannot wrap around
+    // into a huge bogus overhead value
+    const duration_t omp_duration = omp_total_duration > no_omp_duration
+                                    ? omp_total_duration - no_omp_duration
+                                    : 0;
+    return omp_duration >> OUTSIDE_COUNT_SHIFT;
+  }
+
+  /*! \brief Measure OMP overhead for a trivial OMP loop over a range of thread counts
+   * \returns Median time in nanoseconds to initialize/clean up when executing an OMP block,
+   *          or INT_MAX when omp_get_num_procs() / 2 is less than 2 (OMP never pays off)
+   */
+  static duration_t GetOMPLoopOverhead() {
+    // It was found empirically that OMP overhead is not heavily tied to the number of cores,
+    // so the measurements for all tested thread counts are summarized by their median.
+    // NOTE(review): the processor count is halved, presumably to approximate physical cores
+    // on hyper-threaded machines -- confirm.
+    const auto max_cores = static_cast<size_t>(omp_get_num_procs()) >> 1;
+    if (max_cores < 2) {
+      return INT_MAX;  // If only one core, then never use OMP (say the overhead is huge)
+    }
+    // Take care of any OMP lazy-init with a throwaway call per thread count
+    for (size_t omp_threads = 2; omp_threads <= max_cores; ++omp_threads) {
+      GetOMPLoopOverhead(omp_threads);
+    }
+    std::vector<duration_t> durations;
+    durations.reserve(max_cores - 1);
+    for (size_t omp_threads = 2; omp_threads <= max_cores; ++omp_threads) {
+      const duration_t duration = GetOMPLoopOverhead(omp_threads);
+      if (verbose_tuning_info_) {
+        LOG(INFO) << "OMP Thread Count: " << omp_threads << ", overhead: " << duration << " ns";
+      }
+      durations.emplace_back(duration);
+    }
+    // Median (the upper of the two middle elements when the count is even)
+    std::sort(durations.begin(), durations.end());
+    return durations[durations.size() >> 1];
+  }
+
+  /*!
+   * \brief Some string utility functions that aren't specific to tuning
+   */
+  struct StringUtil {
+    /*!
+     * \brief Trim whitespace from the beginning and end of a string, in place
+     * \param s String to trim
+     * \return reference to the modified string. This is the same std::string object as what was
+     *         supplied in the parameters
+     */
+    static std::string &trim(std::string *s) {
+      // Take each character as unsigned char: passing a negative char value (a byte >= 0x80
+      // where char is signed) to std::isspace is undefined behavior.
+      const auto not_space = [](unsigned char ch) {
+        return !std::isspace(ch);
+      };
+      s->erase(s->begin(), std::find_if(s->begin(), s->end(), not_space));
+      s->erase(std::find_if(s->rbegin(), s->rend(), not_space).base(), s->end());
+      return *s;
+    }
+
+    /*!
+     * \brief Split a comma-separated string into trimmed, non-empty tokens
+     * \param s String to tokenize
+     * \return std::list of tokens in their original order; empty tokens are dropped
+     */
+    static std::list<std::string> string2list(const std::string &s) {
+      std::list<std::string> res;
+      std::istringstream iss(s);
+      std::string token;
+      while (std::getline(iss, token, ',')) {
+        trim(&token);
+        if (!token.empty()) {
+          res.push_back(token);
+        }
+      }
+      return res;  // plain return permits NRVO; std::move here would inhibit it
+    }
+  };
+
+  /*!
+   * \brief Map a data type name (e.g. "float32") to its mshadow type flag
+   * \param type_string Type name to look up
+   * \return The mshadow type flag, or -1 if the name is not recognized
+   * \warning Do not call from a performance-sensitive area
+   */
+  static int type_from_string(const std::string& type_string) {
+    static const struct {
+      const char *label;
+      int flag;
+    } kTypeTable[] = {
+      {"float32", mshadow::kFloat32},
+      {"float64", mshadow::kFloat64},
+      {"float16", mshadow::kFloat16},
+      {"int8", mshadow::kInt8},
+      {"uint8", mshadow::kUint8},
+      {"int32", mshadow::kInt32},
+      {"int64", mshadow::kInt64},
+    };
+    for (const auto &entry : kTypeTable) {
+      if (type_string == entry.label) {
+        return entry.flag;
+      }
+    }
+    return -1;  // invalid
+  }
+
+  /*!
+   * \brief Parse MXNET_ENABLE_OPERATOR_TUNING environment variable
+   * \param config String representation of MXNET_ENABLE_OPERATOR_TUNING environment variable
+   *        Values:
+   *            0=disable all
+   *            1=enable all
+   *            float32, float16, int8=list of types to enable, and disable those not listed
+   *        "Enable" sets a type's tuning mode to tune::Auto; "disable" sets it to
+   *        tune::AlwaysOMP. An empty or all-whitespace string leaves every mode unchanged.
+   */
+  static void ParseEnablerConfig(std::string config) {
+    // Applies one tuning mode to every supported data type, float16 included
+    const auto set_all_modes = [](const tune::TuningMode mode) {
+      OperatorTune<float>::set_tuning_mode(mode);
+      OperatorTune<double>::set_tuning_mode(mode);
+      OperatorTune<int8_t>::set_tuning_mode(mode);
+      OperatorTune<uint8_t>::set_tuning_mode(mode);
+      OperatorTune<int32_t>::set_tuning_mode(mode);
+      OperatorTune<int64_t>::set_tuning_mode(mode);
+      OperatorTune<mshadow::half::half_t>::set_tuning_mode(mode);
+    };
+    StringUtil::trim(&config);
+    if (config.empty()) {
+      return;
+    }
+    // First disable all. float16 is included so that "0" disables it too, and a type list
+    // enables it only when "float16" is listed (type_from_string() maps it to kFloat16).
+    set_all_modes(tune::AlwaysOMP);
+    // See if it's a non-number (ie type or list of types). The cast keeps isdigit()'s argument
+    // in the unsigned char range, as the C library requires.
+    if (!::isdigit(static_cast<unsigned char>(config[0]))) {
+      const std::list<std::string> tokens = StringUtil::string2list(config);
+      for (const std::string& stype : tokens) {
+        const int typ = type_from_string(stype);
+        if (typ < 0) {
+          // -1 is error
+          LOG(WARNING) << "Unknown data type to be tuned: " << stype;
+          continue;
+        }
+        switch (typ) {
+          case mshadow::kFloat32:
+            OperatorTune<float>::set_tuning_mode(tune::Auto);
+            break;
+          case mshadow::kFloat64:
+            OperatorTune<double>::set_tuning_mode(tune::Auto);
+            break;
+          case mshadow::kFloat16:
+            OperatorTune<mshadow::half::half_t>::set_tuning_mode(tune::Auto);
+            break;
+          case mshadow::kInt8:
+            OperatorTune<int8_t>::set_tuning_mode(tune::Auto);
+            break;
+          case mshadow::kUint8:
+            OperatorTune<uint8_t>::set_tuning_mode(tune::Auto);
+            break;
+          case mshadow::kInt32:
+            OperatorTune<int32_t>::set_tuning_mode(tune::Auto);
+            break;
+          case mshadow::kInt64:
+            OperatorTune<int64_t>::set_tuning_mode(tune::Auto);
+            break;
+          default:
+            CHECK(false) << "Unsupported tuning data type: " << stype;
+            break;
+        }
+      }
+    } else if (std::atoi(config.c_str()) > 0) {
+      set_all_modes(tune::Auto);
+    }
+  }
+
+  /*! \brief Whether this object has been initialized */
+  static bool initialized_;
+  /*! \brief Number of timing passes to obtain an average (2^OUTSIDE_COUNT_SHIFT, so results
+   *         can be divided by it with a right shift) */
+  static constexpr duration_t OUTSIDE_COUNT = (1 << OUTSIDE_COUNT_SHIFT);
+  /*! \brief Loop size to be timed (single op nanos may be too small to store accurately) */
+  static constexpr duration_t WORKLOAD_COUNT = (1 << WORKLOAD_COUNT_SHIFT);
+  /*!
+   * \brief Data (presumably random) used as operands when timing operator calls
+   * \note The workload timers index it with '& 0xFF', so it presumably holds at least 256
+   *       values -- TODO confirm where it is populated
+   */
+  static std::vector<DType> data_set_;
+  /*! \brief Names of the operators tuned */
+  static std::unordered_set<std::string> operator_names_;
+  /*! \brief Tuning mode for this data type (e.g. tune::Auto, tune::AlwaysOMP, tune::NeverOMP) */
+  static volatile tune::TuningMode tuning_mode_;
+  /*! \brief Arbitrary object to modify in the OMP overhead timing loops */
+  static volatile int volatile_int_;
+};
+
+/*!
+ * \brief Class that tunes unary operators
+ * \tparam DType Data type to be used when tuning the kernel operations
+ */
+template<typename DType>
+class UnaryOpTune : public OperatorTune<DType> {
+ protected:
+  typedef OperatorTune<DType> Super;
+  using duration_t = typename Super::duration_t;
+  using Tick = typename Super::Tick;
+
+  /*!
+   * \brief Some output type conversion to mxnet/mshadow types
+   * \param typ Demangled C++ type name
+   * \return "int32_t" for "int", "int64_t" for "long", "uint8_t" for "unsigned char",
+   *         "int8_t" for "char"/"signed char", otherwise \a typ unchanged
+   * \warning Do not call from within a performance-sensitive area
+   */
+  static std::string MakeOutputType(const std::string& typ) {
+    if (typ == "int") {
+      return "int32_t";
+    }
+    if (typ == "long") {
+      return "int64_t";
+    }
+    if (typ == "unsigned char") {
+      return "uint8_t";
+    }
+    if (typ == "char" || typ == "signed char") {
+      return "int8_t";
+    }
+    // Just return the default
+    return typ;
+  }
+
+  /*!
+   * \brief Determine the time it takes a kernel operator to execute WORKLOAD_COUNT iterations
+   *        Used for kernels that take no arguments (ie set_zero)
+   * \tparam OP Kernel operator
+   * \return Duration in nanoseconds for the 'WORKLOAD_COUNT' operations (never 0)
+   */
+  template<typename OP>
+  static duration_t GetBlankWorkload() {
+    // The loop accumulates with '+=', so tmp must start from a defined value; reading an
+    // uninitialized object would be undefined behavior.
+    DType tmp = static_cast<DType>(0);
+    // Accesses through a volatile pointer cannot be optimized away
+    volatile DType *res = &tmp;
+    const Tick start = std::chrono::high_resolution_clock::now();
+    for (size_t i = 0; i < Super::WORKLOAD_COUNT; ++i) {
+      *res += OP::Map();
+    }
+    const duration_t omp_duration = Super::GetDurationInNanoseconds(start);
+    return omp_duration ? omp_duration : 1;
+  }
+
+  /*!
+   * \brief Determine the time it takes a kernel operator to execute WORKLOAD_COUNT iterations
+   *        Used for kernels that take one argument (ie sqrt())
+   * \tparam OP Kernel operator
+   * \return Duration in nanoseconds for the 'WORKLOAD_COUNT' operations (never 0)
+   * \note data_set_ is indexed with '& 0xFF', so it presumably holds >= 256 values -- confirm
+   */
+  template<typename OP>
+  static duration_t GetUnaryWorkload() {
+    DType tmp;
+    volatile DType *res = &tmp;
+    const Tick start = std::chrono::high_resolution_clock::now();
+    for (size_t i = 0; i < Super::WORKLOAD_COUNT; ++i) {
+      // Use a logical AND instead of mod to avoid affecting the timing result with a slow divide
+      *res = OP::Map(Super::data_set_[i & 0xFF]);
+    }
+    const duration_t omp_duration = Super::GetDurationInNanoseconds(start);
+    return omp_duration ? omp_duration : 1;
+  }
+
+  /*!
+   * \brief Determine the time it takes a kernel operator to execute WORKLOAD_COUNT iterations
+   *        Used for kernels that take two arguments (ie elemwise_add())
+   * \tparam OP Kernel operator
+   * \return Duration in nanoseconds for the 'WORKLOAD_COUNT' operations (never 0)
+   */
+  template<typename OP>
+  static inline duration_t GetBinaryWorkload() {
+    DType tmp;
+    volatile DType *res = &tmp;
+    const Tick start = std::chrono::high_resolution_clock::now();
+    for (size_t i = 0; i < Super::WORKLOAD_COUNT; ++i) {
+      // Use a logical AND instead of mod to avoid affecting the timing result with a slow divide
+      *res = OP::Map(Super::data_set_[i & 0xFF], Super::data_set_[(i + 1) & 0xFF]);
+    }
+    const duration_t omp_duration = Super::GetDurationInNanoseconds(start);
+    return omp_duration ? omp_duration : 1;
+  }
+
+  /*!
+   * \brief Determine the time it takes a kernel operator to execute WORKLOAD_COUNT iterations
+   *        Used for kernels that take three arguments (ie backwards_grad<elemwise_add>())
+   * \tparam OP Kernel operator
+   * \return Duration in nanoseconds for the 'WORKLOAD_COUNT' operations (never 0)
+   * \note NOTE(review): the third operand reuses the first operand's element -- confirm this
+   *       is intended rather than (i + 2)
+   */
+  template<typename OP>
+  static duration_t GetTertiaryWorkload() {
+    DType tmp;
+    volatile DType *res = &tmp;
+    const Tick start = std::chrono::high_resolution_clock::now();
+    for (size_t i = 0; i < Super::WORKLOAD_COUNT; ++i) {
+      // Use a logical AND instead of mod to avoid affecting the timing result with a slow divide
+      *res = OP::Map(Super::data_set_[i & 0xFF],
+                     Super::data_set_[(i + 1) & 0xFF],
+                     Super::data_set_[i & 0xFF]);
+    }
+    const duration_t omp_duration = Super::GetDurationInNanoseconds(start);
+    return omp_duration ? omp_duration : 1;
+  }
+
+  /*!
+   * \brief Determine the time it takes a kernel operator to execute WORKLOAD_COUNT iterations
+   *        Used for mxnet_op-style kernels whose Map() takes an index and an output pointer
+   * \tparam OP Kernel operator
+   * \return Duration in nanoseconds for the 'WORKLOAD_COUNT' operations (never 0)
+   */
+  template<typename OP>
+  static duration_t GetBlankWorkloadEx() {
+    // Array specialization so the buffer is released with delete[], matching new[]
+    std::unique_ptr<DType[]> tmp(new DType[Super::WORKLOAD_COUNT]);
+    DType *tmp_ptr = tmp.get();
+    const Tick start = std::chrono::high_resolution_clock::now();
+    for (size_t i = 0; i < Super::WORKLOAD_COUNT; ++i) {
+      OP::Map(i, tmp_ptr);
+    }
+    const duration_t omp_duration = Super::GetDurationInNanoseconds(start);
+    return omp_duration ? omp_duration : 1;
+  }
+
+ public:
+  /*!
+   * \brief Tune the specified kernel operator.  Optionally print out C++ macro that defines the
+   *        tuning data variable and the default tuned value
+   *        This function tunes an operator which takes no arguments
+   * \tparam OP The kernel operator to be tuned
+   */
+  template<typename OP>
+  static void TuneBlankOperator() {
+    mxnet::op::mxnet_op::tuned_op<OP, DType>::workload_ = GetBlankWorkload<OP>();
+    if (Super::output_tuning_data_) {
+      std::cout << "_IMPLEMENT_UNARY_WORKLOAD_FWD("
+                << Super::template type_name<OP>()
+                << ", " << mxnet_op::tuned_op<OP, DType>::workload_
+                << ", " << MakeOutputType(Super::template type_name<DType>())
+                << ");  // NOLINT()" << std::endl << std::flush;  // For long lines
+    }
+  }
+
+  /*!
+   * \brief Tune the specified kernel operator.  Optionally print out C++ macro that defines the
+   *        tuning data variable and the default tuned value
+   *        This function tunes an operator which takes one argument
+   * \tparam OP The kernel operator to be tuned
+   */
+  template<typename OP>
+  static void TuneUnaryOperator() {
+    mxnet::op::mxnet_op::tuned_op<OP, DType>::workload_ = GetUnaryWorkload<OP>();
+    if (Super::output_tuning_data_) {
+      std::cout << "_IMPLEMENT_UNARY_WORKLOAD_FWD("
+                << Super::template type_name<OP>()
+                << ", " << mxnet_op::tuned_op<OP, DType>::workload_
+                << ", " << MakeOutputType(Super::template type_name<DType>())
+                << ");  // NOLINT()" << std::endl << std::flush;  // For long lines
+    }
+  }
+
+  /*!
+   * \brief Tune the specified kernel operator.  Optionally print out C++ macro that defines the
+   *        tuning data variable and the default tuned value
+   *        This function tunes a backward operator which takes one argument
+   *        (timed as the two-argument backward_grad<OP>)
+   * \tparam OP The kernel operator to be tuned
+   */
+  template<typename OP>
+  static void TuneUnaryBackwardOperator() {
+    mxnet::op::mxnet_op::tuned_op<mxnet_op::backward_grad<OP>, DType>::workload_ =
+      GetBinaryWorkload<mxnet::op::mxnet_op::backward_grad<OP>>();
+    if (Super::output_tuning_data_) {
+      std::cout << "_IMPLEMENT_UNARY_WORKLOAD_BWD("
+                << Super::template type_name<OP>()
+                << ", "
+                << mxnet_op::tuned_op<mxnet::op::mxnet_op::backward_grad<OP>, DType>::workload_
+                << ", " << MakeOutputType(Super::template type_name<DType>())
+                << ");  // NOLINT()" << std::endl << std::flush;  // For long lines
+    }
+  }
+
+  /*!
+   * \brief Tune the specified "mxnet_op-type" kernel operator.
+   *        Optionally print out C++ macro that defines the
+   *        tuning data variable and the default tuned value
+   *        This function tunes an operator whose Map() takes an index and an output pointer
+   * \tparam OP The kernel operator to be tuned
+   */
+  template<typename OP>
+  static void TuneBlankOperatorEx() {
+    mxnet::op::mxnet_op::tuned_op<OP, DType>::workload_ = GetBlankWorkloadEx<OP>();
+    if (Super::output_tuning_data_) {
+      std::cout << "_IMPLEMENT_BLANK_WORKLOAD_FWD("
+                << Super::template type_name<OP>()
+                << ", " << mxnet_op::tuned_op<OP, DType>::workload_
+                << ", " << MakeOutputType(Super::template type_name<DType>())
+                << ");  // NOLINT()" << std::endl << std::flush;  // For long lines
+    }
+  }
+
+  /*!
+   * \brief Estimate the time to compute with and without OMP, then return whether to use OMP
+   * \param N - Number of iterations desired
+   * \param thread_count - Number of OMP threads available to perform the iterations
+   * \returns Whether it's faster to use OMP for these iterations. In tune::Auto mode this
+   *          compares the serial time against omp_overhead_ plus the per-thread share of the
+   *          work; tune::NeverOMP always says no; tune::AlwaysOMP (and any other mode) says yes
+   *          when more than one thread is available. Without MXNET_USE_OPERATOR_TUNING it
+   *          always returns true.
+   */
+  template<typename OP>
+  inline static bool UseOMP(size_t N, size_t thread_count) {
+#ifdef MXNET_USE_OPERATOR_TUNING
+    switch (Super::tuning_mode_) {
+      case tune::Auto:
+        if (thread_count >= 2) {
+          // Compute serial time required
+          const uint64_t total_serial_time =
+            (static_cast<uint64_t>(N) * OP::workload_) >> WORKLOAD_COUNT_SHIFT;
+
+          // Compute time required for OMP + # items per thread
+          const uint64_t omp_compute_time =
+            (static_cast<uint64_t>(N) * OP::workload_) / thread_count;
+          const uint64_t total_omp_time =
+            Super::omp_overhead_ + (omp_compute_time >> WORKLOAD_COUNT_SHIFT);
+
+          const bool res = total_omp_time < total_serial_time;
+          return res;
+        }
+        return false;
+      case tune::NeverOMP:
+        return false;
+      case tune::AlwaysOMP:
+      default:
+        return thread_count > 1;
+    }
+#else
+    return true;
+#endif
+  }
+};
+
+/*!
+ * \brief Tuner for binary operators; the unary tuning helpers are inherited from UnaryOpTune
+ * \tparam DType Data type used when timing the kernel operations
+ */
+template<typename DType>
+class BinaryOpTune : public UnaryOpTune<DType> {
+ protected:
+  typedef UnaryOpTune<DType> Super;
+
+ public:
+  /*!
+   * \brief Construct a tuner and apply \a mode as the tuning mode for DType
+   * \param mode Tuning mode to set
+   */
+  explicit BinaryOpTune(op::tune::TuningMode mode) {
+    Super::set_tuning_mode(mode);
+  }
+
+  /*!
+   * \brief Time a two-argument forward kernel operator and store its workload; when tuning
+   *        output is enabled, also print the matching _IMPLEMENT_BINARY_WORKLOAD_FWD line
+   * \tparam OP Operator type
+   */
+  template<typename OP>
+  static void TuneBinaryOperator() {
+    auto &workload = mxnet_op::tuned_op<OP, DType>::workload_;
+    workload = Super::template GetBinaryWorkload<OP>();
+    if (!Super::Super::output_tuning_data_) {
+      return;
+    }
+    const std::string op_name = Super::template type_name<OP>();
+    const std::string dtype_name = Super::MakeOutputType(Super::template type_name<DType>());
+    std::cout << "_IMPLEMENT_BINARY_WORKLOAD_FWD("
+              << op_name
+              << ", " << workload
+              << ", " << dtype_name
+              << ");  // NOLINT()" << std::endl << std::flush;  // For long lines
+  }
+
+  /*!
+   * \brief Time the backward (three-argument backward_grad<OP>) form of a binary operator and
+   *        store its workload; optionally print the _IMPLEMENT_BINARY_WORKLOAD_BWD line
+   * \tparam OP Forward operator whose backward_grad wrapper is tuned
+   */
+  template<typename OP>
+  static void TuneBinaryBackwardOperator() {
+    typedef mxnet::op::mxnet_op::backward_grad<OP> BackwardOP;
+    auto &workload = mxnet_op::tuned_op<BackwardOP, DType>::workload_;
+    workload = Super::template GetTertiaryWorkload<BackwardOP>();
+    if (!Super::Super::output_tuning_data_) {
+      return;
+    }
+    const std::string op_name = Super::template type_name<OP>();
+    const std::string dtype_name = Super::MakeOutputType(Super::template type_name<DType>());
+    std::cout << "_IMPLEMENT_BINARY_WORKLOAD_BWD("
+              << op_name
+              << ", "
+              << workload
+              << ", " << dtype_name
+              << ");  // NOLINT()" << std::endl << std::flush;  // For long lines
+  }
+};
+
+#undef OUTSIDE_COUNT_SHIFT
+#undef WORKLOAD_COUNT_SHIFT
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_OPERATOR_OPERATOR_TUNE_INL_H_
diff --git a/src/operator/operator_tune.cc b/src/operator/operator_tune.cc
new file mode 100644
index 0000000000..eb0b2e9ffb
--- /dev/null
+++ b/src/operator/operator_tune.cc
@@ -0,0 +1,953 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <atomic>
+#include "./operator_tune.h"
+#include "./mshadow_op.h"
+#include "./tensor/init_op.h"
+#include "./operator_tune-inl.h"
+
+namespace mxnet {
+namespace op {
+
+/*!
+ * \brief Static state shared by every OperatorTune<DType> instantiation
+ */
+std::atomic<bool> OperatorTuneBase::calculated_{false};
+bool OperatorTuneBase::output_tuning_data_{false};
+bool OperatorTuneBase::verbose_tuning_info_{false};
+double OperatorTuneBase::tuning_weight_scale_{0.0};
+bool OperatorTuneBase::enable_auto_tune_{true};
+
+/*!
+ * \brief Instantiate static variables for OperatorTune<DType>, where 'DType' is specified
+ * \param typ Data type whose OperatorTune statics (and tuning-list accessor) are defined
+ * \note The parameter is an ordinary identifier: '$' in identifiers is a compiler extension,
+ *       and names beginning with a double underscore are reserved for the implementation.
+ */
+#define IMPLEMENT_OPERATOR_TUNE_STATICS_FOR_TYPE(typ) \
+  template<> bool OperatorTune<typ>::initialized_ = false; \
+  template<> std::vector<typ> OperatorTune<typ>::data_set_ = {}; \
+  template<> volatile tune::TuningMode OperatorTune<typ>::tuning_mode_ = tune::Auto; \
+  template<> volatile int OperatorTune<typ>::volatile_int_ = 9; \
+  template<> std::unordered_set<std::string> OperatorTune<typ>::operator_names_ = {}; \
+  template<> std::list<void (*)()> *OperatorTune<typ>::GetTuningList() { \
+    static std::list<void (*)()> ll; \
+    return &ll; \
+  }
+
+/*!
+ * \brief Static variables for each supported data type (ie OperatorTune<float>,
+ *        OperatorTune<double>, etc.)
+ */
+IMPLEMENT_OPERATOR_TUNE_STATICS_FOR_TYPE(float);
+IMPLEMENT_OPERATOR_TUNE_STATICS_FOR_TYPE(double);
+IMPLEMENT_OPERATOR_TUNE_STATICS_FOR_TYPE(mshadow::half::half_t);
+IMPLEMENT_OPERATOR_TUNE_STATICS_FOR_TYPE(int8_t);
+IMPLEMENT_OPERATOR_TUNE_STATICS_FOR_TYPE(uint8_t);
+IMPLEMENT_OPERATOR_TUNE_STATICS_FOR_TYPE(int32_t);
+IMPLEMENT_OPERATOR_TUNE_STATICS_FOR_TYPE(int64_t);
+
+/*!
+ * \brief Tuning data and default weights in the case that MXNET_ENABLE_OPERATOR_AUTOTUNE is set
+ *        to zero (thus turning off auto-tuning)
+ * \note This code can be automatically generated
+ *       by setting the environment variable MXNET_OUTPUT_TUNING_DATA to a positive
+ *       integer value
+ */
+OperatorTuneBase::duration_t OperatorTuneBase::omp_overhead_ = 2557;
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mshadow::op::identity, 1222, float);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::identity, 887, float);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::identity_grad, 1080, float);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::negation, 1374, float);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::reciprocal, 7634, float);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::reciprocal_grad, 7964, float);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::sigmoid, 51917, float);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::sigmoid_grad, 5939, float);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::relu, 4174, float);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::relu_grad, 4141, float);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::tanh, 98111, float);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::tanh_grad, 3982, float);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::softrelu, 121839, float);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::softrelu_grad, 73387, float);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::exp, 42970, float);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::exp, 44830, float);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::expm1, 68240, float);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::log, 46303, float);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::log_grad, 7858, float);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::log1p, 70883, float);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::log1p_grad, 7871, float);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::log2, 47596, float);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::log2_grad, 7907, float);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::log10, 53523, float);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::log10_grad, 7855, float);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::sin, 35895, float);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::sin_grad, 35643, float);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::sinh, 95576, float);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::sinh_grad, 68257, float);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::arcsin, 44034, float);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::arcsin_grad, 19501, float);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::arcsinh, 103274, float);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::arcsinh_grad, 31581, float);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::cos, 31271, float);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::cos_grad, 37167, float);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::cosh, 64874, float);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::cosh_grad, 94673, float);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::arccos, 52882, float);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::arccos_grad, 19477, float);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::arccosh, 17277, float);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::arccosh_grad, 29268, float);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::tan, 66606, float);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::tan_grad, 6263, float);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::arctan, 52749, float);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::arctan_grad, 7640, float);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::arctanh, 97482, float);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::arctanh_grad, 8007, float);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::square, 2811, float);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::square_grad, 3759, float);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::square_root, 17254, float);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::square_root_grad, 7883, float);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::reciprocal_square_root, 17432, float);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::reciprocal_square_root_grad, 21084, float);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::cube_root, 73535, float);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::cube_root_grad, 7649, float);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::reciprocal_cube_root, 78676, float);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::reciprocal_cube_root_grad, 86259, float);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::abs, 1289, float);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::sign, 4463, float);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::sign, 7284, float);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::sign_grad, 3358, float);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::round, 11673, float);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::floor, 6535, float);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::trunc, 9170, float);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rint, 14273, float);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::fix, 24176, float);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::gamma, 188495, float);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::gamma_grad, 389885, float);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::gammaln, 177646, float);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::gammaln_grad, 198518, float);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::ceil, 5045, float);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::degrees, 3346, float);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::degrees_grad, 3365, float);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::radians, 3355, float);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::radians_grad, 3364, float);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mshadow::op::plus, 3562, float);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mshadow::op::minus, 3557, float);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mshadow::op::mul, 3574, float);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mshadow::op::div, 7503, float);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mshadow::op::right, 976, float);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rminus, 3576, float);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rdiv, 7574, float);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::div_grad, 7562, float);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::div_grad, 9446, float);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::div_rgrad, 7813, float);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::div_rgrad, 13537, float);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::rdiv_grad, 11364, float);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::mod, 39168, float);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::mod_grad, 1063, float);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::mod_rgrad, 14988, float);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rmod, 38068, float);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::rmod_grad, 15769, float);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::left, 1064, float);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::left, 3555, float);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::right, 1027, float);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::right, 2809, float);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::power, 137169, float);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rpower, 135519, float);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::power_grad, 140140, float);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::rpower_grad, 46200, float);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::power_rgrad, 178129, float);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::maximum, 3562, float);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::minimum, 3556, float);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::hypot, 21679, float);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::hypot_grad_left, 28271, float);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::hypot_grad_left, 31139, float);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::hypot_grad_right, 28172, float);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::hypot_grad_right, 31137, float);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::lt, 4202, float);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::lt, 7187, float);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::le, 4187, float);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::le, 7205, float);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::gt, 5601, float);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::gt, 8641, float);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::ge, 8649, float);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::ge, 5582, float);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::ne, 2001, float);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::ne, 3798, float);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::eq, 1991, float);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::eq, 3731, float);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::smooth_l1_loss, 11752, float);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::smooth_l1_gradient, 10923, float);  // NOLINT()
+_IMPLEMENT_BLANK_WORKLOAD_FWD(mxnet::op::mxnet_op::set_to_int<0>, 3659, float);  // NOLINT()
+_IMPLEMENT_BLANK_WORKLOAD_FWD(mxnet::op::PopulateFullIdxRspKernel, 1349, float);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mshadow::op::identity, 1104, double);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::identity, 1108, double);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::identity_grad, 1086, double);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::negation, 1240, double);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::reciprocal, 9153, double);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::reciprocal_grad, 9680, double);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::sigmoid, 53742, double);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::sigmoid_grad, 6089, double);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::relu, 4172, double);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::relu_grad, 4159, double);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::tanh, 95669, double);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::tanh_grad, 4006, double);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::softrelu, 113160, double);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::softrelu_grad, 71708, double);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::exp, 44416, double);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::exp, 45434, double);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::expm1, 63875, double);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::log, 44125, double);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::log_grad, 9581, double);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::log1p, 59681, double);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::log1p_grad, 9627, double);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::log2, 47212, double);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::log2_grad, 9816, double);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::log10, 50422, double);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::log10_grad, 9542, double);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::sin, 41887, double);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::sin_grad, 52671, double);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::sinh, 93011, double);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::sinh_grad, 71570, double);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::arcsin, 64841, double);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::arcsin_grad, 24395, double);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::arcsinh, 94516, double);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::arcsinh_grad, 51546, double);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::cos, 42030, double);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::cos_grad, 45108, double);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::cosh, 68272, double);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::cosh_grad, 92688, double);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::arccos, 62425, double);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::arccos_grad, 24390, double);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::arccosh, 11958, double);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::arccosh_grad, 29153, double);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::tan, 68871, double);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::tan_grad, 6277, double);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::arctan, 63505, double);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::arctan_grad, 9345, double);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::arctanh, 86646, double);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::arctanh_grad, 9528, double);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::square, 2800, double);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::square_grad, 3768, double);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::square_root, 16823, double);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::square_root_grad, 9635, double);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::reciprocal_square_root, 19024, double);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::reciprocal_square_root_grad, 19782, double);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::cube_root, 82565, double);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::cube_root_grad, 9360, double);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::reciprocal_cube_root, 88995, double);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::reciprocal_cube_root_grad, 95500, double);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::abs, 1239, double);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::sign, 4623, double);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::sign, 6760, double);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::sign_grad, 3363, double);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::round, 16425, double);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::floor, 6796, double);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::trunc, 9289, double);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rint, 14419, double);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::fix, 25077, double);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::gamma, 210978, double);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::gamma_grad, 453107, double);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::gammaln, 188625, double);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::gammaln_grad, 237430, double);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::ceil, 5061, double);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::degrees, 3413, double);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::degrees_grad, 3467, double);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::radians, 3406, double);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::radians_grad, 3493, double);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mshadow::op::plus, 3545, double);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mshadow::op::minus, 3613, double);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mshadow::op::mul, 3606, double);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mshadow::op::div, 9330, double);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mshadow::op::right, 988, double);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rminus, 3625, double);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rdiv, 9252, double);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::div_grad, 9306, double);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::div_grad, 11108, double);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::div_rgrad, 9616, double);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::div_rgrad, 15216, double);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::rdiv_grad, 13106, double);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::mod, 34776, double);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::mod_grad, 1076, double);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::mod_rgrad, 16886, double);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rmod, 33631, double);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::rmod_grad, 17668, double);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::left, 1168, double);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::left, 3554, double);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::right, 982, double);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::right, 2799, double);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::power, 106551, double);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rpower, 94821, double);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::power_grad, 97108, double);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::rpower_grad, 48308, double);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::power_rgrad, 136776, double);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::maximum, 3559, double);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::minimum, 3557, double);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::hypot, 38473, double);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::hypot_grad_left, 49383, double);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::hypot_grad_left, 49771, double);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::hypot_grad_right, 46794, double);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::hypot_grad_right, 49956, double);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::lt, 4190, double);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::lt, 7246, double);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::le, 4179, double);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::le, 7243, double);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::gt, 5590, double);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::gt, 8584, double);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::ge, 8615, double);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::ge, 5618, double);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::ne, 2017, double);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::ne, 3753, double);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::eq, 2012, double);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::eq, 3728, double);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::smooth_l1_loss, 13354, double);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::smooth_l1_gradient, 10929, double);  // NOLINT()
+_IMPLEMENT_BLANK_WORKLOAD_FWD(mxnet::op::mxnet_op::set_to_int<0>, 5213, double);  // NOLINT()
+_IMPLEMENT_BLANK_WORKLOAD_FWD(mxnet::op::PopulateFullIdxRspKernel, 734, double);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mshadow::op::identity, 613, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::identity, 603, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::identity_grad, 641, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::negation, 602, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::reciprocal, 604, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::reciprocal_grad, 610, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::sigmoid, 68999, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::sigmoid_grad, 612, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::relu, 612, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::relu_grad, 608, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::tanh, 606, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::tanh_grad, 610, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::softrelu, 79148, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::softrelu_grad, 75681, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::exp, 29758, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::exp, 30936, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::expm1, 33922, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::log, 33546, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::log_grad, 617, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::log1p, 36328, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::log1p_grad, 601, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::log2, 37900, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::log2_grad, 611, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::log10, 37923, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::log10_grad, 611, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::sin, 601, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::sin_grad, 597, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::sinh, 50732, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::sinh_grad, 39194, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::arcsin, 26381, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::arcsin_grad, 9867, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::arcsinh, 608, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::arcsinh_grad, 17625, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::cos, 613, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::cos_grad, 606, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::cosh, 38223, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::cosh_grad, 51348, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::arccos, 29544, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::arccos_grad, 9886, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::arccosh, 18755, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::arccosh_grad, 20831, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::tan, 612, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::tan_grad, 611, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::arctan, 602, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::arctan_grad, 607, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::arctanh, 50320, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::arctanh_grad, 615, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::square, 620, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::square_grad, 611, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::square_root, 22087, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::square_root_grad, 604, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::reciprocal_square_root, 21274, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::reciprocal_square_root_grad, 23637, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::cube_root, 612, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::cube_root_grad, 613, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::reciprocal_cube_root, 600, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::reciprocal_cube_root_grad, 603, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::abs, 597, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::sign, 604, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::sign, 604, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::sign_grad, 608, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::round, 607, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::floor, 598, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::trunc, 609, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rint, 604, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::fix, 604, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::gamma, 129925, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::gamma_grad, 165461, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::gammaln, 91331, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::gammaln_grad, 65551, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::ceil, 615, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::degrees, 610, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::degrees_grad, 598, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::radians, 608, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::radians_grad, 610, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mshadow::op::plus, 606, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mshadow::op::minus, 610, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mshadow::op::mul, 598, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mshadow::op::div, 605, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mshadow::op::right, 599, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rminus, 602, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rdiv, 603, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::div_grad, 604, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::div_grad, 607, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::div_rgrad, 604, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::div_rgrad, 605, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::rdiv_grad, 605, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::mod, 52754, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::mod_grad, 607, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::mod_rgrad, 608, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rmod, 55365, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::rmod_grad, 615, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::left, 610, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::left, 615, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::right, 617, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::right, 607, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::power, 107498, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rpower, 106419, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::power_grad, 143936, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::rpower_grad, 33041, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::power_rgrad, 124413, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::maximum, 1533, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::minimum, 1527, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::hypot, 23558, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::hypot_grad_left, 25021, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::hypot_grad_left, 24313, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::hypot_grad_right, 24582, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::hypot_grad_right, 24371, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::lt, 610, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::lt, 607, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::le, 610, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::le, 598, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::gt, 605, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::gt, 601, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::ge, 617, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::ge, 601, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::ne, 605, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::ne, 600, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::eq, 612, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::eq, 609, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::smooth_l1_loss, 602, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::smooth_l1_gradient, 640, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_BLANK_WORKLOAD_FWD(mxnet::op::mxnet_op::set_to_int<0>, 163, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_BLANK_WORKLOAD_FWD(mxnet::op::PopulateFullIdxRspKernel, 4118, mshadow::half::half_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mshadow::op::identity, 975, int8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::identity, 888, int8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::identity_grad, 886, int8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::negation, 1051, int8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::reciprocal, 1995, int8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::reciprocal_grad, 2646, int8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::sigmoid, 36944, int8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::sigmoid_grad, 2344, int8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::relu, 1172, int8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::relu_grad, 1540, int8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::tanh, 16315, int8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::tanh_grad, 2552, int8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::softrelu, 32325, int8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::softrelu_grad, 23786, int8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::exp, 22048, int8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::exp, 23155, int8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::expm1, 21410, int8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::log, 22525, int8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::log_grad, 2277, int8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::log1p, 19499, int8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::log1p_grad, 2627, int8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::log2, 22440, int8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::log2_grad, 2280, int8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::log10, 25452, int8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::log10_grad, 2281, int8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::sin, 20146, int8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::sin_grad, 21881, int8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::sinh, 33157, int8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::sinh_grad, 24387, int8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::arcsin, 12506, int8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::arcsin_grad, 14324, int8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::arcsinh, 37339, int8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::arcsinh_grad, 11309, int8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::cos, 20897, int8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::cos_grad, 21417, int8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::cosh, 22116, int8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::cosh_grad, 33564, int8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::arccos, 12645, int8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::arccos_grad, 13699, int8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::arccosh, 28162, int8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::arccosh_grad, 9462, int8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::tan, 48177, int8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::tan_grad, 2342, int8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::arctan, 19239, int8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::arctan_grad, 2935, int8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::arctanh, 11691, int8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::arctanh_grad, 3000, int8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::square, 1719, int8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::square_grad, 1998, int8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::square_root, 16500, int8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::square_root_grad, 2333, int8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::reciprocal_square_root, 15865, int8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::reciprocal_square_root_grad, 17088, int8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::cube_root, 34890, int8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::cube_root_grad, 2929, int8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::reciprocal_cube_root, 35550, int8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::reciprocal_cube_root_grad, 37298, int8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::abs, 1718, int8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::sign, 1542, int8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::sign, 1728, int8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::sign_grad, 624, int8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::round, 4614, int8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::floor, 3955, int8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::trunc, 5247, int8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rint, 6774, int8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::fix, 7482, int8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::gamma, 66818, int8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::gamma_grad, 86241, int8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::gammaln, 42016, int8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::gammaln_grad, 22820, int8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::ceil, 3937, int8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::degrees, 1717, int8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::degrees_grad, 1345, int8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::radians, 1718, int8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::radians_grad, 604, int8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mshadow::op::plus, 1172, int8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mshadow::op::minus, 1203, int8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mshadow::op::mul, 1169, int8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mshadow::op::div, 4504, int8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mshadow::op::right, 903, int8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rminus, 1372, int8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rdiv, 2549, int8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::div_grad, 1993, int8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::div_grad, 2267, int8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::div_rgrad, 2617, int8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::div_rgrad, 3195, int8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::rdiv_grad, 3193, int8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::mod, 33600, int8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::mod_grad, 678, int8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::mod_rgrad, 676, int8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rmod, 34110, int8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::rmod_grad, 674, int8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::left, 884, int8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::left, 1351, int8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::right, 890, int8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::right, 1165, int8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::power, 124079, int8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rpower, 117575, int8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::power_grad, 134697, int8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::rpower_grad, 24230, int8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::power_rgrad, 141691, int8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::maximum, 1357, int8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::minimum, 1409, int8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::hypot, 10275, int8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::hypot_grad_left, 11476, int8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::hypot_grad_left, 11767, int8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::hypot_grad_right, 15031, int8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::hypot_grad_right, 21617, int8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::lt, 1558, int8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::lt, 1815, int8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::le, 1601, int8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::le, 2198, int8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::gt, 2996, int8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::gt, 3094, int8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::ge, 2322, int8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::ge, 1350, int8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::ne, 1350, int8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::ne, 1530, int8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::eq, 1356, int8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::eq, 1755, int8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::smooth_l1_loss, 7986, int8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::smooth_l1_gradient, 11963, int8_t);  // NOLINT()
+_IMPLEMENT_BLANK_WORKLOAD_FWD(mxnet::op::mxnet_op::set_to_int<0>, 52, int8_t);  // NOLINT()
+_IMPLEMENT_BLANK_WORKLOAD_FWD(mxnet::op::PopulateFullIdxRspKernel, 752, int8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mshadow::op::identity, 940, uint8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::identity, 938, uint8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::identity_grad, 892, uint8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::negation, 1053, uint8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::reciprocal, 1999, uint8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::reciprocal_grad, 2645, uint8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::sigmoid, 29055, uint8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::sigmoid_grad, 2329, uint8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::relu, 898, uint8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::relu_grad, 1529, uint8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::tanh, 11208, uint8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::tanh_grad, 2549, uint8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::softrelu, 6136, uint8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::softrelu_grad, 11757, uint8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::exp, 23505, uint8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::exp, 24622, uint8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::expm1, 19787, uint8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::log, 20534, uint8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::log_grad, 2283, uint8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::log1p, 26721, uint8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::log1p_grad, 2700, uint8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::log2, 21167, uint8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::log2_grad, 2387, uint8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::log10, 31824, uint8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::log10_grad, 2279, uint8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::sin, 21733, uint8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::sin_grad, 21572, uint8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::sinh, 27736, uint8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::sinh_grad, 24677, uint8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::arcsin, 12190, uint8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::arcsin_grad, 13820, uint8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::arcsinh, 37830, uint8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::arcsinh_grad, 11334, uint8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::cos, 20003, uint8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::cos_grad, 22090, uint8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::cosh, 23327, uint8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::cosh_grad, 28399, uint8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::arccos, 13277, uint8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::arccos_grad, 13781, uint8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::arccosh, 37151, uint8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::arccosh_grad, 9447, uint8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::tan, 106186, uint8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::tan_grad, 2340, uint8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::arctan, 17939, uint8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::arctan_grad, 2930, uint8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::arctanh, 12001, uint8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::arctanh_grad, 3006, uint8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::square, 1718, uint8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::square_grad, 2054, uint8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::square_root, 7237, uint8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::square_root_grad, 2287, uint8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::reciprocal_square_root, 9527, uint8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::reciprocal_square_root_grad, 10555, uint8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::cube_root, 31904, uint8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::cube_root_grad, 2983, uint8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::reciprocal_cube_root, 32557, uint8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::reciprocal_cube_root_grad, 42138, uint8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::abs, 1098, uint8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::sign, 1343, uint8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::sign, 1616, uint8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::sign_grad, 916, uint8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::round, 6877, uint8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::floor, 9913, uint8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::trunc, 8089, uint8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rint, 8011, uint8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::fix, 10013, uint8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::gamma, 51033, uint8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::gamma_grad, 92099, uint8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::gammaln, 35324, uint8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::gammaln_grad, 36065, uint8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::ceil, 3413, uint8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::degrees, 1708, uint8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::degrees_grad, 1343, uint8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::radians, 1716, uint8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::radians_grad, 721, uint8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mshadow::op::plus, 1166, uint8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mshadow::op::minus, 1162, uint8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mshadow::op::mul, 1385, uint8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mshadow::op::div, 4531, uint8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mshadow::op::right, 885, uint8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rminus, 1342, uint8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rdiv, 2556, uint8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::div_grad, 1990, uint8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::div_grad, 2323, uint8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::div_rgrad, 2548, uint8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::div_rgrad, 3133, uint8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::rdiv_grad, 3111, uint8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::mod, 19504, uint8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::mod_grad, 676, uint8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::mod_rgrad, 607, uint8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rmod, 18992, uint8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::rmod_grad, 684, uint8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::left, 895, uint8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::left, 1258, uint8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::right, 888, uint8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::right, 1019, uint8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::power, 121274, uint8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rpower, 116809, uint8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::power_grad, 114964, uint8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::rpower_grad, 22570, uint8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::power_rgrad, 120983, uint8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::maximum, 1344, uint8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::minimum, 1530, uint8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::hypot, 10393, uint8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::hypot_grad_left, 11296, uint8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::hypot_grad_left, 11762, uint8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::hypot_grad_right, 11421, uint8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::hypot_grad_right, 11484, uint8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::lt, 1348, uint8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::lt, 1803, uint8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::le, 1530, uint8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::le, 1592, uint8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::gt, 1575, uint8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::gt, 1546, uint8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::ge, 1761, uint8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::ge, 1350, uint8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::ne, 1345, uint8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::ne, 1585, uint8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::eq, 1343, uint8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::eq, 1534, uint8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::smooth_l1_loss, 3881, uint8_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::smooth_l1_gradient, 3141, uint8_t);  // NOLINT()
+_IMPLEMENT_BLANK_WORKLOAD_FWD(mxnet::op::mxnet_op::set_to_int<0>, 54, uint8_t);  // NOLINT()
+_IMPLEMENT_BLANK_WORKLOAD_FWD(mxnet::op::PopulateFullIdxRspKernel, 764, uint8_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mshadow::op::identity, 921, int32_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::identity, 949, int32_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::identity_grad, 902, int32_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::negation, 1174, int32_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::reciprocal, 1832, int32_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::reciprocal_grad, 2626, int32_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::sigmoid, 41024, int32_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::sigmoid_grad, 2317, int32_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::relu, 1530, int32_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::relu_grad, 1714, int32_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::tanh, 15113, int32_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::tanh_grad, 2314, int32_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::softrelu, 30598, int32_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::softrelu_grad, 23928, int32_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::exp, 20510, int32_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::exp, 21841, int32_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::expm1, 21920, int32_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::log, 21792, int32_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::log_grad, 2316, int32_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::log1p, 19068, int32_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::log1p_grad, 2560, int32_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::log2, 21892, int32_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::log2_grad, 2306, int32_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::log10, 25817, int32_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::log10_grad, 2313, int32_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::sin, 20020, int32_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::sin_grad, 21550, int32_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::sinh, 31030, int32_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::sinh_grad, 23083, int32_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::arcsin, 12072, int32_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::arcsin_grad, 13275, int32_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::arcsinh, 37593, int32_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::arcsinh_grad, 10694, int32_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::cos, 20674, int32_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::cos_grad, 21698, int32_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::cosh, 21881, int32_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::cosh_grad, 31297, int32_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::arccos, 14416, int32_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::arccos_grad, 13896, int32_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::arccosh, 28558, int32_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::arccosh_grad, 9497, int32_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::tan, 47822, int32_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::tan_grad, 2295, int32_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::arctan, 18936, int32_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::arctan_grad, 2957, int32_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::arctanh, 11983, int32_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::arctanh_grad, 2968, int32_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::square, 1539, int32_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::square_grad, 2031, int32_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::square_root, 14818, int32_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::square_root_grad, 2311, int32_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::reciprocal_square_root, 15903, int32_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::reciprocal_square_root_grad, 15263, int32_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::cube_root, 34168, int32_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::cube_root_grad, 3012, int32_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::reciprocal_cube_root, 34343, int32_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::reciprocal_cube_root_grad, 37268, int32_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::abs, 1575, int32_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::sign, 1750, int32_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::sign, 1995, int32_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::sign_grad, 612, int32_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::round, 4559, int32_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::floor, 3981, int32_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::trunc, 4934, int32_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rint, 7355, int32_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::fix, 7943, int32_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::gamma, 60321, int32_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::gamma_grad, 79458, int32_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::gammaln, 42949, int32_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::gammaln_grad, 23218, int32_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::ceil, 3964, int32_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::degrees, 1581, int32_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::degrees_grad, 1158, int32_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::radians, 1534, int32_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::radians_grad, 608, int32_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mshadow::op::plus, 1220, int32_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mshadow::op::minus, 1177, int32_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mshadow::op::mul, 1183, int32_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mshadow::op::div, 4598, int32_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mshadow::op::right, 911, int32_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rminus, 1352, int32_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rdiv, 2323, int32_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::div_grad, 1879, int32_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::div_grad, 2189, int32_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::div_rgrad, 2624, int32_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::div_rgrad, 2918, int32_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::rdiv_grad, 2958, int32_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::mod, 40655, int32_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::mod_grad, 711, int32_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::mod_rgrad, 630, int32_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rmod, 34719, int32_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::rmod_grad, 676, int32_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::left, 905, int32_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::left, 1208, int32_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::right, 905, int32_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::right, 1159, int32_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::power, 122776, int32_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rpower, 117304, int32_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::power_grad, 119978, int32_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::rpower_grad, 23683, int32_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::power_rgrad, 142189, int32_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::maximum, 1348, int32_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::minimum, 1353, int32_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::hypot, 9679, int32_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::hypot_grad_left, 10834, int32_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::hypot_grad_left, 11415, int32_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::hypot_grad_right, 11096, int32_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::hypot_grad_right, 11569, int32_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::lt, 1594, int32_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::lt, 1723, int32_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::le, 1589, int32_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::le, 1722, int32_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::gt, 1589, int32_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::gt, 1718, int32_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::ge, 1720, int32_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::ge, 1535, int32_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::ne, 1534, int32_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::ne, 1771, int32_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::eq, 1537, int32_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::eq, 1721, int32_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::smooth_l1_loss, 8250, int32_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::smooth_l1_gradient, 6789, int32_t);  // NOLINT()
+_IMPLEMENT_BLANK_WORKLOAD_FWD(mxnet::op::mxnet_op::set_to_int<0>, 396, int32_t);  // NOLINT()
+_IMPLEMENT_BLANK_WORKLOAD_FWD(mxnet::op::PopulateFullIdxRspKernel, 498, int32_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mshadow::op::identity, 883, int64_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::identity, 891, int64_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::identity_grad, 898, int64_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::negation, 1161, int64_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::reciprocal, 2088, int64_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::reciprocal_grad, 3009, int64_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::sigmoid, 37587, int64_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::sigmoid_grad, 2691, int64_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::relu, 1442, int64_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::relu_grad, 1773, int64_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::tanh, 16847, int64_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::tanh_grad, 2658, int64_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::softrelu, 32984, int64_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::softrelu_grad, 25283, int64_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::exp, 22831, int64_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::exp, 23884, int64_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::expm1, 22857, int64_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::log, 22334, int64_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::log_grad, 2667, int64_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::log1p, 20560, int64_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::log1p_grad, 3075, int64_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::log2, 22791, int64_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::log2_grad, 2740, int64_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::log10, 26369, int64_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::log10_grad, 2626, int64_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::sin, 22274, int64_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::sin_grad, 22129, int64_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::sinh, 34079, int64_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::sinh_grad, 24176, int64_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::arcsin, 12391, int64_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::arcsin_grad, 13878, int64_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::arcsinh, 39186, int64_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::arcsinh_grad, 11354, int64_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::cos, 20510, int64_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::cos_grad, 22682, int64_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::cosh, 22665, int64_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::cosh_grad, 33219, int64_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::arccos, 13820, int64_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::arccos_grad, 14336, int64_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::arccosh, 28841, int64_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::arccosh_grad, 8995, int64_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::tan, 49310, int64_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::tan_grad, 2688, int64_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::arctan, 19799, int64_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::arctan_grad, 3370, int64_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::arctanh, 12472, int64_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::arctanh_grad, 3393, int64_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::square, 1906, int64_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::square_grad, 2327, int64_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::square_root, 15647, int64_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::square_root_grad, 2628, int64_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::reciprocal_square_root, 15818, int64_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::reciprocal_square_root_grad, 16739, int64_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::cube_root, 34639, int64_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::cube_root_grad, 3370, int64_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::reciprocal_cube_root, 34997, int64_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::reciprocal_cube_root_grad, 37961, int64_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::abs, 1995, int64_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::sign, 1725, int64_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::sign, 2003, int64_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::sign_grad, 706, int64_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::round, 4639, int64_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::floor, 3996, int64_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::trunc, 5377, int64_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rint, 7485, int64_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::fix, 7910, int64_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::gamma, 65634, int64_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::gamma_grad, 84586, int64_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::gammaln, 43821, int64_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::gammaln_grad, 23171, int64_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::ceil, 3449, int64_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::degrees, 1892, int64_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::degrees_grad, 1399, int64_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::radians, 1896, int64_t);  // NOLINT()
+_IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::radians_grad, 613, int64_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mshadow::op::plus, 1172, int64_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mshadow::op::minus, 1193, int64_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mshadow::op::mul, 1162, int64_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mshadow::op::div, 15026, int64_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mshadow::op::right, 883, int64_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rminus, 1352, int64_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rdiv, 2834, int64_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::div_grad, 2092, int64_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::div_grad, 2467, int64_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::div_rgrad, 3112, int64_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::div_rgrad, 3354, int64_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::rdiv_grad, 3422, int64_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::mod, 33267, int64_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::mod_grad, 608, int64_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::mod_rgrad, 638, int64_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rmod, 35780, int64_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::rmod_grad, 607, int64_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::left, 890, int64_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::left, 1164, int64_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::right, 890, int64_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::right, 1032, int64_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::power, 118809, int64_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rpower, 115566, int64_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::power_grad, 121035, int64_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::rpower_grad, 23749, int64_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::power_rgrad, 139325, int64_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::maximum, 1353, int64_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::minimum, 1375, int64_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::hypot, 10542, int64_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::hypot_grad_left, 11615, int64_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::hypot_grad_left, 12057, int64_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::hypot_grad_right, 11644, int64_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::hypot_grad_right, 11931, int64_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::lt, 1540, int64_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::lt, 1725, int64_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::le, 1538, int64_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::le, 1722, int64_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::gt, 1525, int64_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::gt, 1716, int64_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::ge, 1708, int64_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::ge, 1536, int64_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::ne, 1451, int64_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::ne, 1719, int64_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::eq, 1539, int64_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::eq, 1757, int64_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::smooth_l1_loss, 9184, int64_t);  // NOLINT()
+_IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::smooth_l1_gradient, 11223, int64_t);  // NOLINT()
+_IMPLEMENT_BLANK_WORKLOAD_FWD(mxnet::op::mxnet_op::set_to_int<0>, 481, int64_t);  // NOLINT()
+_IMPLEMENT_BLANK_WORKLOAD_FWD(mxnet::op::PopulateFullIdxRspKernel, 506, int64_t);  // NOLINT()
+// /* END AUTOMATICALLY GENERATED DATA */
+
+/*!
+ * \brief Tuner objects, *not* automatically generated (int8/uint8: AlwaysOMP, others: Auto)
+ */
+#ifdef MXNET_USE_OPERATOR_TUNING
+static BinaryOpTune<float>                  binaryOpTuneFloat(op::tune::Auto);
+static BinaryOpTune<double>                 binaryOpTuneDouble(op::tune::Auto);
+static BinaryOpTune<mshadow::half::half_t>  binaryOpTuneHalf(op::tune::Auto);
+static BinaryOpTune<int8_t>                 binaryOpTuneInt8(op::tune::AlwaysOMP);
+static BinaryOpTune<uint8_t>                binaryOpTuneUInt8(op::tune::AlwaysOMP);
+static BinaryOpTune<int32_t>                binaryOpTuneInt32(op::tune::Auto);
+static BinaryOpTune<int64_t>                binaryOpTuneInt64(op::tune::Auto);
+#endif  // MXNET_USE_OPERATOR_TUNING
+}  // namespace op
+}  // namespace mxnet
diff --git a/src/operator/operator_tune.h b/src/operator/operator_tune.h
new file mode 100644
index 0000000000..2e0a305cfd
--- /dev/null
+++ b/src/operator/operator_tune.h
@@ -0,0 +1,228 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef MXNET_OPERATOR_OPERATOR_TUNE_H_
+#define MXNET_OPERATOR_OPERATOR_TUNE_H_
+
+#include <mshadow/base.h>
+
+namespace mxnet {
+namespace op {
+
+/*!
+ * \brief Declare a template specialization for the Kernel::Launch call for the given OP
+ *        wrapped with mxnet_op::op_with_req, using the given OpReqType as the 'req'
+ *        template parameter for 'op_with_req'.  This is useful for the standard mshadow_op
+ *        operators which need to be wrapped with op_with_req in order to be used with the
+ *        Kernel::Launch command.  The specialization forwards to LaunchMShadowOpEx().
+ *
+ * \note Use within the mxnet::op namespace.  Only Kernel<..., ::mshadow::cpu> is specialized.
+ *
+ * For example:
+ *
+ * namespace mxnet_op {
+ * template <>
+ * template <typename... Args>
+ * inline void Kernel<typename mxnet_op::op_with_req<mshadow::op::identity, kNullOp>, cpu>
+ *   ::Launch(mshadow::Stream<cpu>* s, const int N, Args... args) {
+ *   ::mxnet::op::mxnet_op::Kernel<typename mxnet_op::op_with_req<mshadow::op::identity, kNullOp>,
+ *     cpu>::LaunchMShadowOpEx(s, N, args...);
+ * }
+ * }  // namespace mxnet_op
+ *
+ */
+#define MXNET_TUNABLE_MSHADOW_OP_WITH_REQ(__op$, __req$) \
+  namespace mxnet_op { \
+  template<> template<typename ...Args> \
+  inline void Kernel<typename mxnet_op::op_with_req<__op$, __req$>, ::mshadow::cpu>:: \
+    Launch(mshadow::Stream<::mshadow::cpu> *s, const int N, Args... args) { \
+      /* Launch via LaunchMShadowOpEx() */ \
+      Kernel<typename mxnet_op::op_with_req<__op$, __req$>, ::mshadow::cpu>:: \
+        LaunchMShadowOpEx(s, N, args...); \
+  } \
+  }  /* namespace mxnet_op */
+
+/*!
+ * \brief Declare template specializations for the Kernel::Launch call for the given OP
+ *        wrapped with mxnet_op::op_with_req, once for each OpReqType (kNullOp, kWriteTo,
+ *        kWriteInplace, kAddTo) used as the 'req' template parameter.  This is useful for
+ *        the standard mshadow_op operators which need to be wrapped with op_with_req in
+ *        order to be used with the Kernel::Launch command.
+ * \note Use within mxnet::op; expands to namespace blocks only, so no ';' separators
+ */
+#define MXNET_TUNABLE_MSHADOW_OP(__op$) \
+  MXNET_TUNABLE_MSHADOW_OP_WITH_REQ(__op$, kNullOp) \
+  MXNET_TUNABLE_MSHADOW_OP_WITH_REQ(__op$, kWriteTo) \
+  MXNET_TUNABLE_MSHADOW_OP_WITH_REQ(__op$, kWriteInplace) \
+  MXNET_TUNABLE_MSHADOW_OP_WITH_REQ(__op$, kAddTo)
+/*! \brief MXNET_TUNABLE_MSHADOW_OP applied to mxnet::op::mxnet_op::backward_grad<__op$> */
+#define MXNET_TUNABLE_MSHADOW_OP_BACKWARD(__op$) \
+  MXNET_TUNABLE_MSHADOW_OP(mxnet::op::mxnet_op::backward_grad<__op$>)
+/*! \brief MXNET_TUNABLE_MSHADOW_OP for both __op$ and its backward_grad<__op$> wrapper */
+#define MXNET_TUNABLE_MSHADOW_OP_FWD_AND_BWD(__op$) \
+  MXNET_TUNABLE_MSHADOW_OP(__op$) \
+  MXNET_TUNABLE_MSHADOW_OP_BACKWARD(__op$)
+
+/*!
+ * \brief mxnet::op::mxnet_op format ops (work directly with Kernel<>::Launch())
+ *        Used from within mxnet::op::mxnet_op namespace; forwards to LaunchMXNetOpEx()
+ */
+#define _MXNET_TUNABLE_MXNET_OP_FWD(__op$) \
+  template<> template<typename ...Args> inline void Kernel<__op$, ::mshadow::cpu>::Launch( \
+    mshadow::Stream<::mshadow::cpu> *s, const int N, Args... args) { \
+      /* Launch via LaunchMXNetOpEx() */ \
+      Kernel<__op$, ::mshadow::cpu>::LaunchMXNetOpEx(s, N, args...); \
+  }
+
+/*!
+ * \brief mxnet::op::mxnet_op format ops (work directly with Kernel<>::Launch())
+ *        Used from within mxnet::op; wraps _MXNET_TUNABLE_MXNET_OP_FWD in namespace mxnet_op
+ */
+#define MXNET_TUNABLE_MXNET_OP_FWD(__op$) \
+  namespace mxnet_op { _MXNET_TUNABLE_MXNET_OP_FWD(__op$) }  /* namespace mxnet_op */
+
+
+/*!
+ * \brief Holder for a per-(OP, DType) flag whose initializer schedules tuning of a
+ *        tunable operator during static initialization
+ * \tparam OP Operator type
+ * \tparam DType Data type
+ */
+template<class OP, class DType>
+struct static_init_var {
+  static bool init_;  // specialized and defined by the _IMPLEMENT_*_WORKLOAD_* macros
+};
+
+/*!
+ * \brief Expand __macro$(..., T); for T in float, double, half_t, uint8_t, int8_t, int32_t,
+ *        int64_t, i.e. each data type is appended as the last argument of one invocation
+ */
+#define MSHADOW_MACRO_FOREACH_TYPE(__macro$, ...) \
+  __macro$(__VA_ARGS__, float); \
+  __macro$(__VA_ARGS__, double); \
+  __macro$(__VA_ARGS__, mshadow::half::half_t); \
+  __macro$(__VA_ARGS__, uint8_t); \
+  __macro$(__VA_ARGS__, int8_t); \
+  __macro$(__VA_ARGS__, int32_t); \
+  __macro$(__VA_ARGS__, int64_t);
+
+/*! \brief Define tuned_op<__op$, __typ$>::workload_ as __v1$, or INT_MAX / 4 if __v1$ is 0 */
+#define IMPLEMENT_BASIC_WORKLOAD(__op$, __v1$, __typ$) \
+  namespace mxnet_op { \
+  template<> size_t mxnet::op::mxnet_op::tuned_op<__op$, __typ$>::workload_ = (__v1$) == 0 ? \
+     (INT_MAX / 4) : (__v1$); }  /* namespace mxnet_op */
+
+/*!
+ * \brief Implement workload_, UseOMP() and tune scheduling for a forward blank (no-arg) op
+ */
+#define _IMPLEMENT_BLANK_WORKLOAD_FWD(__op$, __v1$, __typ$) \
+  IMPLEMENT_BASIC_WORKLOAD(__op$, __v1$, __typ$) \
+  namespace mxnet_op { \
+  template<> int mxnet::op::mxnet_op::tuned_op<__op$, __typ$>::UseOMP( \
+    size_t N, size_t omp_threads) { \
+    return mxnet::op::UnaryOpTune<__typ$>::UseOMP<mxnet_op::tuned_op<__op$, __typ$>>( \
+      N, omp_threads); \
+  }}  /* namespace mxnet_op */ \
+  template<> bool static_init_var<__op$, __typ$>::init_ = \
+    mxnet::op::OperatorTune<__typ$>::ScheduleTune<__op$>( \
+      mxnet::op::UnaryOpTune<__typ$>::TuneBlankOperatorEx<__op$>)
+
+/*!
+ * \brief Implement workload_, UseOMP() and tune scheduling for a forward unary kernel op
+ */
+#define _IMPLEMENT_UNARY_WORKLOAD_FWD(__op$, __v1$, __typ$) \
+  IMPLEMENT_BASIC_WORKLOAD(__op$, __v1$, __typ$) \
+  namespace mxnet_op { \
+  template<> int mxnet::op::mxnet_op::tuned_op<__op$, __typ$>::UseOMP( \
+    size_t N, size_t omp_threads) { \
+    return mxnet::op::UnaryOpTune<__typ$>::UseOMP<mxnet_op::tuned_op<__op$, __typ$>>( \
+      N, omp_threads); \
+  }}  /* namespace mxnet_op */ \
+  template<> bool static_init_var<__op$, __typ$>::init_ = \
+    mxnet::op::OperatorTune<__typ$>::ScheduleTune<__op$>( \
+      mxnet::op::UnaryOpTune<__typ$>::TuneUnaryOperator<__op$>)
+
+/*!
+ * \brief Tuning objects for unary backward_grad<__op$>; note: uses ScheduleTune<__op$>
+ */
+#define _IMPLEMENT_UNARY_WORKLOAD_BWD(__op$, __v1$, __typ$) \
+  IMPLEMENT_BASIC_WORKLOAD(mxnet::op::mxnet_op::backward_grad<__op$>, __v1$, __typ$) \
+  namespace mxnet_op { \
+  template<> \
+  int mxnet::op::mxnet_op::tuned_op<mxnet::op::mxnet_op::backward_grad<__op$>, __typ$>::UseOMP( \
+    size_t N, size_t omp_threads) { \
+    return mxnet::op::UnaryOpTune<__typ$>::UseOMP<mxnet_op::tuned_op< \
+      mxnet::op::mxnet_op::backward_grad<__op$>, __typ$>>(N, omp_threads); \
+  }}  /* namespace mxnet_op */ \
+  template<> bool static_init_var<mxnet::op::mxnet_op::backward_grad<__op$>, __typ$>::init_ = \
+    mxnet::op::OperatorTune<__typ$>::ScheduleTune<__op$>( \
+      mxnet::op::UnaryOpTune<__typ$>::TuneUnaryBackwardOperator<__op$>)
+
+/*!
+ * \brief Implement workload_, UseOMP() and tune scheduling for a forward binary kernel op
+ */
+#define _IMPLEMENT_BINARY_WORKLOAD_FWD(__op$, __v1$, __typ$) \
+  IMPLEMENT_BASIC_WORKLOAD(__op$, __v1$, __typ$) \
+  namespace mxnet_op { \
+  template<> int mxnet::op::mxnet_op::tuned_op<__op$, __typ$>::UseOMP( \
+    size_t N, size_t omp_threads) { \
+    return mxnet::op::BinaryOpTune<__typ$>::UseOMP<mxnet_op::tuned_op<__op$, __typ$>>( \
+      N, omp_threads); \
+  }}  /* namespace mxnet_op */ \
+  template<> bool static_init_var<__op$, __typ$>::init_ = \
+    mxnet::op::OperatorTune<__typ$>::ScheduleTune<__op$>( \
+      mxnet::op::BinaryOpTune<__typ$>::TuneBinaryOperator<__op$>)
+
+/*!
+ * \brief Tuning objects for binary backward_grad<__op$>; note: uses ScheduleTune<__op$>
+ */
+#define _IMPLEMENT_BINARY_WORKLOAD_BWD(__op$, __v1$, __typ$) \
+  IMPLEMENT_BASIC_WORKLOAD(mxnet::op::mxnet_op::backward_grad<__op$>, __v1$, __typ$) \
+  namespace mxnet_op { \
+  template<> \
+  int mxnet::op::mxnet_op::tuned_op<mxnet::op::mxnet_op::backward_grad<__op$>, __typ$>::UseOMP( \
+    size_t N, size_t omp_threads) { \
+    return mxnet::op::BinaryOpTune<__typ$>::UseOMP<mxnet_op::tuned_op< \
+      mxnet::op::mxnet_op::backward_grad<__op$>, __typ$>>(N, omp_threads); \
+  }}  /* namespace mxnet_op */ \
+  template<> bool static_init_var<mxnet::op::mxnet_op::backward_grad<__op$>, __typ$>::init_ = \
+    mxnet::op::OperatorTune<__typ$>::ScheduleTune<__op$>( \
+      mxnet::op::BinaryOpTune<__typ$>::TuneBinaryBackwardOperator<__op$>)
+
+/*!
+ * \brief Add an operator to the tuning set for every type in MSHADOW_MACRO_FOREACH_TYPE
+ */
+#define IMPLEMENT_BLANK_WORKLOAD_FWD(__op$, __workload$) \
+  MSHADOW_MACRO_FOREACH_TYPE(_IMPLEMENT_BLANK_WORKLOAD_FWD, __op$, __workload$)
+
+#define IMPLEMENT_UNARY_WORKLOAD_FWD(__op$, __workload$) \
+  MSHADOW_MACRO_FOREACH_TYPE(_IMPLEMENT_UNARY_WORKLOAD_FWD, __op$, __workload$)
+
+#define IMPLEMENT_UNARY_WORKLOAD_BWD(__op$, __workload$) \
+  MSHADOW_MACRO_FOREACH_TYPE(_IMPLEMENT_UNARY_WORKLOAD_BWD, __op$, __workload$)
+
+#define IMPLEMENT_BINARY_WORKLOAD_FWD(__op$, __workload$) \
+  MSHADOW_MACRO_FOREACH_TYPE(_IMPLEMENT_BINARY_WORKLOAD_FWD, __op$, __workload$)
+
+#define IMPLEMENT_BINARY_WORKLOAD_BWD(__op$, __workload$) \
+  MSHADOW_MACRO_FOREACH_TYPE(_IMPLEMENT_BINARY_WORKLOAD_BWD, __op$, __workload$)
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_OPERATOR_OPERATOR_TUNE_H_
diff --git a/src/operator/optimizer_op-inl.h b/src/operator/optimizer_op-inl.h
index 1c5e1c62aa..61b97ba60d 100644
--- a/src/operator/optimizer_op-inl.h
+++ b/src/operator/optimizer_op-inl.h
@@ -264,7 +264,7 @@ inline void SGDMomUpdate(const nnvm::NodeAttrs& attrs,
       grad.dptr_, static_cast<DType>(param.clip_gradient), static_cast<DType>(param.momentum),
       static_cast<DType>(param.lr), static_cast<DType>(param.wd),
       static_cast<DType>(param.rescale_grad), req[0]);
-  });
+    });
 }
 
 template<int n_in, int n_out, int total_in>
diff --git a/src/operator/tensor/elemwise_binary_op.h b/src/operator/tensor/elemwise_binary_op.h
index b8b5bd1390..9c8f180116 100644
--- a/src/operator/tensor/elemwise_binary_op.h
+++ b/src/operator/tensor/elemwise_binary_op.h
@@ -33,8 +33,10 @@
 #include <algorithm>
 #include "../mxnet_op.h"
 #include "../mshadow_op.h"
+#include "../../engine/openmp.h"
 #include "elemwise_unary_op.h"
 #include "../../common/utils.h"
+#include "./init_op.h"
 
 namespace mxnet {
 namespace op {
@@ -42,23 +44,6 @@ namespace op {
 /*! Gather binary operator functions into ElemwiseBinaryOp class */
 class ElemwiseBinaryOp : public OpBase {
  public:
-  template<typename OP, int Req>
-  struct BackwardUseNoneOp {
-    template<typename DType>
-    MSHADOW_XINLINE static void Map(int i, DType *igrad, const DType *ograd) {
-      KERNEL_ASSIGN(igrad[i], Req, OP::Map(ograd[i]));
-    }
-  };
-
-  template<typename OP, int Req>
-  struct BackwardUseInOp {
-    template<typename DType>
-    MSHADOW_XINLINE static void Map(int i, DType *igrad,
-                                    const DType *ograd, const DType *lhs, const DType *rhs) {
-      KERNEL_ASSIGN(igrad[i], Req, ograd[i] * OP::Map(lhs[i], rhs[i]));
-    }
-  };
-
   /*! \brief For sparse, assume missing rvalue is 0 */
   template<typename OP, int Req>
   struct MissingRValueOp {
@@ -89,25 +74,22 @@ class ElemwiseBinaryOp : public OpBase {
    * \brief Fill contiguous dense output rows with value computed from 0 lhs and 0 rhs input
    *        CPU-Only version
    */
-  template<typename DType, typename OP>
-  static inline size_t FillDense(mshadow::Stream<cpu> *s,
+  template<typename DType, typename OP, typename xpu>
+  static inline size_t FillDense(mshadow::Stream<xpu> *s,
                                  const size_t idx_l,
                                  const size_t idx_r,
                                  const OpReqType req,
-                                 mshadow::Tensor<cpu, 2, DType> *out,
+                                 mshadow::Tensor<xpu, 2, DType> *out,
                                  const size_t iter_out) {
-    const int index_out_min = std::min(idx_l, idx_r);
+    const int index_out_min = static_cast<int>(std::min(idx_l, idx_r));
     if (static_cast<size_t>(index_out_min) > iter_out) {
-      const size_t size = (*out)[iter_out].shape_.Size();
       const DType zero_input_val = OP::Map(DType(0), DType(0));
-      #pragma omp parallel for
-      for (int i = iter_out; i < index_out_min; ++i) {
-        MXNET_ASSIGN_REQ_SWITCH(req, Req, {
-          SerialLaunchCPU<OpBase::set_to_scalar<Req>>(s, size, (*out)[i].dptr_, zero_input_val);
-        });
+      #pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount())
+      for (int i = static_cast<int>(iter_out); i < index_out_min; ++i) {
+        Fill<false>(s, (*out)[i], req, zero_input_val);
       }
     }
-    return index_out_min;
+    return static_cast<size_t>(index_out_min);  // MSVC wants OMP loops to always use 'int'
   }
 
   static inline bool IsSameArray(const NDArray& a1, const NDArray& a2) {
@@ -135,7 +117,7 @@ class ElemwiseBinaryOp : public OpBase {
     } else if (req[0] != kNullOp) {
       DType *lgrad_dptr = outputs[0].dptr<DType>();
       MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
-        Kernel<BackwardUseNoneOp<LOP, Req>, xpu>::Launch(s, size, lgrad_dptr, ograd_dptr);
+        Kernel<mxnet_op::op_with_req<LOP, Req>, xpu>::Launch(s, size, lgrad_dptr, ograd_dptr);
       });
     }
     if (std::is_same<ROP, mshadow_op::identity>::value && req[1] == kWriteInplace) {
@@ -143,7 +125,7 @@ class ElemwiseBinaryOp : public OpBase {
     } else if (req[1] != kNullOp) {
       DType *rgrad_dptr = outputs[1].dptr<DType>();
       MXNET_ASSIGN_REQ_SWITCH(req[1], Req, {
-        Kernel<BackwardUseNoneOp<ROP, Req>, xpu>::Launch(s, size, rgrad_dptr, ograd_dptr);
+        Kernel<mxnet_op::op_with_req<ROP, Req>, xpu>::Launch(s, size, rgrad_dptr, ograd_dptr);
       });
     }
   }
@@ -165,14 +147,14 @@ class ElemwiseBinaryOp : public OpBase {
         (outputs[0].Size() + mxnet_op::DataType<DType>::kLanes - 1)
         / mxnet_op::DataType<DType>::kLanes);
       DType * lgrad_dptr = outputs[0].dptr<DType>();
-      mxnet_op::Kernel<BackwardUseInOp<LOP, Req>, xpu>::Launch(
+      mxnet_op::Kernel<mxnet_op::op_with_req<mxnet_op::backward_grad<LOP>, Req>, xpu>::Launch(
         s, size, lgrad_dptr, ograd_dptr, lhs_dptr, rhs_dptr);});
     MXNET_ASSIGN_REQ_SWITCH(req[1], Req, {
       const int size = static_cast<int>(
         (outputs[1].Size() + mxnet_op::DataType<DType>::kLanes - 1)
         / mxnet_op::DataType<DType>::kLanes);
       DType * rgrad_dptr = outputs[1].dptr<DType>();
-      mxnet_op::Kernel<BackwardUseInOp<ROP, Req>, xpu>::Launch(
+      mxnet_op::Kernel<mxnet_op::op_with_req<mxnet_op::backward_grad<ROP>, Req>, xpu>::Launch(
         s, size, rgrad_dptr, ograd_dptr, lhs_dptr, rhs_dptr);});
   }
 
@@ -503,10 +485,7 @@ class ElemwiseBinaryOp : public OpBase {
         CHECK_EQ(outputs[0].storage_type(), in_stype);
         // rsp -> rsp, _. op requires 0-input returns 0-output
         DCHECK_LT(fabs(static_cast<float>(LOP::Map(0))), 1e-5f);
-        MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
-          UnaryOp::KernelComputeEx<xpu, BackwardUseNoneOp<LOP, Req>>(attrs, ctx, inputs,
-                                                                     req, {outputs[0]});
-        });
+        UnaryOp::ComputeEx<xpu, LOP>(attrs, ctx, inputs, req, {outputs[0]});
       } else {
         LOG(FATAL) << "Not implemented: " << operator_string(attrs, ctx, inputs, req, outputs);
       }
@@ -517,10 +496,7 @@ class ElemwiseBinaryOp : public OpBase {
         CHECK_EQ(outputs[0].storage_type(), in_stype);
         // rsp -> _, rsp. op requires 0-input returns 0-output
         DCHECK_LT(fabs(static_cast<float>(ROP::Map(0))), 1e-5f);
-        MXNET_ASSIGN_REQ_SWITCH(req[1], Req, {
-          UnaryOp::KernelComputeEx<xpu, BackwardUseNoneOp<ROP, Req>>(attrs, ctx, inputs,
-                                                                     req, {outputs[1]});
-        });
+        UnaryOp::ComputeEx<xpu, ROP>(attrs, ctx, inputs, req, {outputs[1]});
       } else {
         LOG(FATAL) << "Not implemented: " << operator_string(attrs, ctx, inputs, req, outputs);
       }
diff --git a/src/operator/tensor/elemwise_binary_scalar_op.h b/src/operator/tensor/elemwise_binary_scalar_op.h
index b866a296f6..27d8ed343c 100644
--- a/src/operator/tensor/elemwise_binary_scalar_op.h
+++ b/src/operator/tensor/elemwise_binary_scalar_op.h
@@ -66,7 +66,7 @@ class BinaryScalarOp : public UnaryOp {
         const int64_t dense_block_count = next_input_row - output_row;
         if (dense_block_count > 0) {
           MXNET_ASSIGN_REQ_SWITCH(req, Req, {
-            mxnet_op::Kernel<OpBase::set_to_scalar<Req>, cpu>::Launch(
+            mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::identity, Req>, cpu>::Launch(
               stream,
               items_per_row * dense_block_count,
               output_data.dptr_ + items_per_row * output_row,
@@ -237,11 +237,8 @@ class BinaryScalarOp : public UnaryOp {
     const double alpha = nnvm::get<double>(attrs.parsed);
     MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
       MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
-        mxnet_op::Kernel<mxnet_op::op_with_req<OP, Req>, xpu>::Launch(s,
-                                                                      inputs[0].Size(),
-                                                                      outputs[0].dptr<DType>(),
-                                                                      inputs[0].dptr<DType>(),
-                                                                      DType(alpha));
+        mxnet_op::Kernel<mxnet_op::op_with_req<OP, Req>, xpu>::Launch(
+          s, inputs[0].Size(), outputs[0].dptr<DType>(), inputs[0].dptr<DType>(), DType(alpha));
       });
     });
   }
@@ -286,10 +283,13 @@ class BinaryScalarOp : public UnaryOp {
     Stream<xpu> *s = ctx.get_stream<xpu>();
     const double alpha = nnvm::get<double>(attrs.parsed);
     MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
-      Tensor<xpu, 1, DType> igrad = outputs[0].FlatTo1D<xpu, DType>(s);
-      Tensor<xpu, 1, DType> ograd = inputs[0].FlatTo1D<xpu, DType>(s);
-      Tensor<xpu, 1, DType> lhs = inputs[1].FlatTo1D<xpu, DType>(s);
-      ASSIGN_DISPATCH(igrad, req[0], ograd * F<OP>(lhs, scalar<DType>(DType(alpha))));
+      MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
+        mxnet::op::mxnet_op::Kernel<mxnet::op::mxnet_op::op_with_req<
+          mxnet::op::mxnet_op::backward_grad<OP>, Req>, xpu>::
+          Launch(s, inputs[0].Size(), outputs[0].dptr<DType>(),
+                 inputs[0].dptr<DType>(), inputs[1].dptr<DType>(),
+                 DType(alpha));
+      });
     });
   }
 };
diff --git a/src/operator/tensor/elemwise_unary_op.h b/src/operator/tensor/elemwise_unary_op.h
index d455b7e761..6fbde05c46 100644
--- a/src/operator/tensor/elemwise_unary_op.h
+++ b/src/operator/tensor/elemwise_unary_op.h
@@ -274,25 +274,6 @@ class UnaryOp : public OpBase {
   }
 
   template<typename xpu, typename op>
-  static void KernelCompute(const nnvm::NodeAttrs& attrs,
-                            const OpContext& ctx,
-                            const std::vector<TBlob>& inputs,
-                            const std::vector<OpReqType>& req,
-                            const std::vector<TBlob>& outputs) {
-    using namespace mshadow;
-    using namespace mxnet_op;
-    Stream<xpu> *s = ctx.get_stream<xpu>();
-    CHECK_EQ(inputs.size(), 1U);
-    CHECK_EQ(outputs.size(), 1U);
-    if (req[0] != kNullOp) {
-      MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
-        Kernel<op, xpu>::Launch(s, outputs[0].Size(),
-                                outputs[0].dptr<DType>(), inputs[0].dptr<DType>());
-      });
-    }
-  }
-
-  template<typename xpu, typename op>
   static void ComputeWithHalf2(const nnvm::NodeAttrs &attrs,
                                const OpContext &ctx,
                                const std::vector<TBlob> &inputs,
@@ -309,25 +290,6 @@ class UnaryOp : public OpBase {
     });
   }
 
-  template<typename xpu, typename OP>
-  static void KernelComputeEx(const nnvm::NodeAttrs& attrs,
-                              const OpContext& ctx,
-                              const std::vector<NDArray>& inputs,
-                              const std::vector<OpReqType>& req,
-                              const std::vector<NDArray>& outputs) {
-    CHECK_EQ(inputs.size(), 1U);
-    CHECK_EQ(outputs.size(), 1U);
-    const auto in_stype = inputs[0].storage_type();
-    const auto out_stype = outputs[0].storage_type();
-    if (in_stype == out_stype && (in_stype == kRowSparseStorage || in_stype == kCSRStorage)) {
-      if (inputs[0].storage_shape().Size()) {
-        MapToFCompute<xpu>(attrs, ctx, inputs, req, outputs, KernelCompute<xpu, OP>);
-      }
-    } else {
-      LOG(FATAL) << "Not implemented: " << operator_string(attrs, ctx, inputs, req, outputs);
-    }
-  }
-
   template<typename xpu>
   static void IdentityCompute(const nnvm::NodeAttrs& attrs,
                               const OpContext& ctx,
@@ -395,13 +357,9 @@ class UnaryOp : public OpBase {
   }
 };
 
+/*! \brief Map legacy unary_bwd to backward_grad */
 template<typename GRAD_OP>
-struct unary_bwd {
-  template<typename DType>
-  MSHADOW_XINLINE static DType Map(DType a, DType b) {
-    return a * GRAD_OP::Map(b);
-  }
-};
+using unary_bwd = ::mxnet::op::mxnet_op::backward_grad<GRAD_OP>;
 
 struct CastParam : public dmlc::Parameter<CastParam> {
   // use int for enumeration
@@ -445,37 +403,6 @@ void CastCompute(const nnvm::NodeAttrs& attrs,
   });
 }
 
-namespace kernel_launch_op {
-/*! \brief sigmoid unit */
-struct sigmoid {
-  template<typename DType>
-  MSHADOW_XINLINE static void Map(int i, DType *out,
-                                  const DType *in) {
-    out[i] = mshadow_op::sigmoid::Map<DType>(in[i]);
-  }
-};
-struct sigmoid_grad {
-  template<typename DType>
-  MSHADOW_XINLINE static DType Map(DType out_grad, DType in) {
-    return out_grad * mshadow_op::sigmoid_grad::Map<DType>(in);
-  }
-};
-/*! \brief Rectified Linear Operation */
-struct relu {
-  template<typename DType>
-  MSHADOW_XINLINE static void Map(int i, DType *out,
-                                  const DType *in) {
-    out[i] = mshadow_op::relu::Map<DType>(in[i]);
-  }
-};
-struct relu_grad {
-  template<typename DType>
-  MSHADOW_XINLINE static DType Map(DType out_grad, DType in) {
-    return out_grad * mshadow_op::relu_grad::Map<DType>(in);
-  }
-};
-}  // namespace kernel_launch_op
-
 /*! \brief Unary compute */
 #define MXNET_OPERATOR_REGISTER_UNARY(__name$)                      \
   NNVM_REGISTER_OP(__name$)                                         \
diff --git a/src/operator/tensor/elemwise_unary_op_basic.cc b/src/operator/tensor/elemwise_unary_op_basic.cc
index c356c580bc..a69897ebb1 100644
--- a/src/operator/tensor/elemwise_unary_op_basic.cc
+++ b/src/operator/tensor/elemwise_unary_op_basic.cc
@@ -83,13 +83,12 @@ The storage type of ``relu`` output depends upon the input storage type:
 
 )code" ADD_FILELINE)
 .set_attr<FInferStorageType>("FInferStorageType", ElemwiseStorageType<1, 1, false, true, false>)
-.set_attr<FCompute>("FCompute<cpu>", UnaryOp::KernelCompute<
-  cpu, kernel_launch_op::relu>)
-.set_attr<FComputeEx>("FComputeEx<cpu>", UnaryOp::KernelComputeEx<
-  cpu, kernel_launch_op::relu>)
+.set_attr<FCompute>("FCompute<cpu>", UnaryOp::Compute<cpu, mshadow_op::relu>)
+.set_attr<FComputeEx>("FComputeEx<cpu>", UnaryOp::ComputeEx<cpu, mshadow_op::relu>)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_relu"});
 
-MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU(_backward_relu, kernel_launch_op::relu_grad);
+MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU(_backward_relu,
+                                               unary_bwd<mshadow_op::relu_grad>);
 
 // sigmoid
 MXNET_OPERATOR_REGISTER_UNARY(sigmoid)
@@ -102,11 +101,11 @@ MXNET_ADD_SPARSE_OP_ALIAS(sigmoid)
 The storage type of ``sigmoid`` output is always dense
 
 )code" ADD_FILELINE)
-.set_attr<FCompute>("FCompute<cpu>", UnaryOp::KernelCompute<
-  cpu, kernel_launch_op::sigmoid>)
+.set_attr<FCompute>("FCompute<cpu>", UnaryOp::Compute<cpu, mshadow_op::sigmoid>)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseOut{"_backward_sigmoid"});
 
-MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU(_backward_sigmoid, kernel_launch_op::sigmoid_grad);
+MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU(_backward_sigmoid,
+                                               unary_bwd<mshadow_op::sigmoid_grad>);
 
 // copy
 MXNET_OPERATOR_REGISTER_UNARY(_copy)
@@ -406,7 +405,8 @@ The storage type of ``sign`` output depends upon the input storage type:
 )code" ADD_FILELINE)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_sign"});
 
-MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU(_backward_sign, unary_bwd<mshadow_op::sign_grad>);
+MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU(_backward_sign,
+                                               unary_bwd<mshadow_op::sign_grad>);
 
 // round
 MXNET_OPERATOR_REGISTER_UNARY_WITH_RSP(round, cpu, mshadow_op::round)
@@ -719,7 +719,6 @@ The storage type of ``expm1`` output depends upon the input storage type:
 
 MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_DR(_backward_expm1, unary_bwd<mshadow_op::exp>);
 
-
 // gamma
 MXNET_OPERATOR_REGISTER_UNARY_WITH_SPARSE_DR(gamma, cpu, mshadow_op::gamma)
 MXNET_ADD_SPARSE_OP_ALIAS(gamma)
diff --git a/src/operator/tensor/elemwise_unary_op_basic.cu b/src/operator/tensor/elemwise_unary_op_basic.cu
index 3f982a2c56..3ea4137fb8 100644
--- a/src/operator/tensor/elemwise_unary_op_basic.cu
+++ b/src/operator/tensor/elemwise_unary_op_basic.cu
@@ -26,18 +26,19 @@
 namespace mxnet {
 namespace op {
 NNVM_REGISTER_OP(relu)
-.set_attr<FCompute>("FCompute<gpu>", UnaryOp::KernelCompute<gpu, kernel_launch_op::relu>)
-.set_attr<FComputeEx>("FComputeEx<gpu>", UnaryOp::KernelComputeEx<gpu, kernel_launch_op::relu>);
+.set_attr<FCompute>("FCompute<gpu>", UnaryOp::Compute<gpu, mshadow_op::relu>)
+.set_attr<FComputeEx>("FComputeEx<gpu>", UnaryOp::ComputeEx<gpu, mshadow_op::relu>);
 
 NNVM_REGISTER_OP(_backward_relu)
-.set_attr<FCompute>("FCompute<gpu>", ElemwiseBinaryOp::Compute<gpu, kernel_launch_op::relu_grad>);
+.set_attr<FCompute>("FCompute<gpu>", ElemwiseBinaryOp::Compute<
+  gpu, mxnet_op::backward_grad<mshadow_op::relu_grad>);
 
 NNVM_REGISTER_OP(sigmoid)
-.set_attr<FCompute>("FCompute<gpu>", UnaryOp::KernelCompute<gpu, kernel_launch_op::sigmoid>);
+.set_attr<FCompute>("FCompute<gpu>", UnaryOp::Compute<gpu, mshadow_op::sigmoid>);
 
 NNVM_REGISTER_OP(_backward_sigmoid)
 .set_attr<FCompute>("FCompute<gpu>", ElemwiseBinaryOp::Compute<
-  gpu, kernel_launch_op::sigmoid_grad>);
+  gpu, mxnet_op::backward_grad<mshadow_op::sigmoid_grad>);
 
 // copy
 NNVM_REGISTER_OP(_copy)
diff --git a/src/operator/tensor/init_op.h b/src/operator/tensor/init_op.h
index bb6d3c129f..be57d12a9c 100644
--- a/src/operator/tensor/init_op.h
+++ b/src/operator/tensor/init_op.h
@@ -32,6 +32,7 @@
 #include <vector>
 #include <string>
 #include <limits>
+#include "../mshadow_op.h"
 #include "../elemwise_op_common.h"
 #include "../mxnet_op.h"
 #include "../mshadow_op.h"
@@ -225,7 +226,7 @@ void Fill(mshadow::Stream<xpu> *s, const TBlob& b, const OpReqType req, ValueTyp
       // Optimize common use-case of filling with ones
       MSHADOW_TYPE_SWITCH(b.type_flag_, DType, {
         MXNET_ASSIGN_REQ_SWITCH(req, Req, {
-          mxnet_op::Kernel<mxnet_op::op_with_req<mxnet_op::set_to_int<1>, Req>, xpu>::Launch(
+          mxnet_op::Kernel<mxnet_op::op_with_req<mxnet_op::set_one, Req>, xpu>::Launch(
             s, b.Size(), b.dptr<DType>());
         });
       });
@@ -270,6 +271,7 @@ struct PopulateFullIdxRspKernel {
     KERNEL_ASSIGN(out[i], kWriteTo, i);
   }
 };
+MXNET_TUNABLE_MXNET_OP_FWD(PopulateFullIdxRspKernel);
 
 // Fill in the indices and values of a RowSparse NDArray to represent a zeros NDArray,
 // instead of the usual compact representation.
diff --git a/tests/cpp/include/test_core_op.h b/tests/cpp/include/test_core_op.h
index 21d0776fca..c454c95847 100644
--- a/tests/cpp/include/test_core_op.h
+++ b/tests/cpp/include/test_core_op.h
@@ -33,7 +33,24 @@ namespace op {
 // Tried making this a struct w/constexpr, but getting undefined reference on gcc 5.4.1
 #define COREOP_FWD_OP_NAME_KEY          "fwd_op_name"
 #define COREOP_BWD_OP_NAME_KEY          "bwd_op_name"
-#define COREOP_BWD_OP_NAME_VALUE_NONE   "<none>"
+#define COREOP_BWD_OP_NAME_VALUE_NONE   "[none]"
+
+enum TimingDirection {
+  Forward,
+  Backward
+};
+
+inline const char *TimingDirectionAsString(const TimingDirection td) {
+  switch (td) {
+    case Forward:
+      return "Forward";
+    case Backward:
+      return "Backward";
+    default:
+      CHECK(false) << "Unknown timing direction: " << static_cast<int>(td);
+      return "<unknown>";
+  }
+}
 
 /*!
  * Low-noise operator executor
@@ -43,11 +60,6 @@ template<typename DType>
 class CoreOpExecutor : public test::op::OperatorDataInitializer<DType>
   , public test::op::OperatorExecutorTiming {
   /*! \brief Performance timing categories */
-  enum TimingId {
-    Forward,
-    Backward
-  };
-
   /*!
    * \brief Access data blob as if on the CPU via a callback
    * \tparam Type of callback Function to call with CPU-data NDArray
@@ -92,8 +104,8 @@ class CoreOpExecutor : public test::op::OperatorDataInitializer<DType>
     values.reserve(count);
     for (kwargs_t::const_iterator i_iter = args.begin(), e_iter = args.end();
          i_iter != e_iter; ++i_iter) {
-      keys.push_back(i_iter->first.c_str());
-      values.push_back(i_iter->second.c_str());
+      keys.emplace_back(i_iter->first.c_str());
+      values.emplace_back(i_iter->second.c_str());
     }
     return imperative::ParseAttrs(op, op->num_inputs, count, &keys[0], &values[0]);
   }
@@ -108,7 +120,7 @@ class CoreOpExecutor : public test::op::OperatorDataInitializer<DType>
                                                  std::vector<TBlob> *dest) {
     dest->reserve(dest->size() + src.size());
     for (size_t i = 0, n = src.size(); i < n; ++i) {
-      dest->push_back(src[i].data());
+      dest->emplace_back(src[i].data());
     }
     return *dest;
   }
@@ -194,9 +206,9 @@ class CoreOpExecutor : public test::op::OperatorDataInitializer<DType>
       for (const ResourceRequest& req : reqs) {
         if (req.type == ResourceRequest::kTempSpace) {
           Resource r = ResourceManager::Get()->Request(ctx->run_ctx.ctx, req);
-          requested.push_back(r);
+          requested.emplace_back(r);
         } else if (req.type == ResourceRequest::kRandom) {
-          requested.push_back(ResourceManager::Get()->Request(ctx->run_ctx.ctx, req));
+          requested.emplace_back(ResourceManager::Get()->Request(ctx->run_ctx.ctx, req));
         } else {
           LOG(FATAL) << "resource type not yet supported";
         }
@@ -216,7 +228,7 @@ class CoreOpExecutor : public test::op::OperatorDataInitializer<DType>
     new_args.reserve(args.size() + 1);
     for (const auto& a : args) {
       if (a.first != COREOP_FWD_OP_NAME_KEY && a.first != COREOP_BWD_OP_NAME_KEY) {
-        new_args.push_back(a);
+        new_args.emplace_back(a);
       }
     }
     new_args.push_back({ COREOP_FWD_OP_NAME_KEY, fwd_op_name});
@@ -241,7 +253,7 @@ class CoreOpExecutor : public test::op::OperatorDataInitializer<DType>
       } else if (a.first == COREOP_BWD_OP_NAME_KEY) {
         *bwd_op_name_ptr = a.second;
       } else {
-        new_args.push_back(a);
+        new_args.emplace_back(a);
       }
     }
     return new_args;
@@ -317,7 +329,7 @@ class CoreOpExecutor : public test::op::OperatorDataInitializer<DType>
       // operators such as dot
       std::vector<TShape> shapes;
       for (size_t i = 0, n = std::max(num_visible_outputs, num_inputs); i < n; ++i) {
-        shapes.push_back(i < input_shapes_.size() ? input_shapes_[i]
+        shapes.emplace_back(i < input_shapes_.size() ? input_shapes_[i]
                                                   : input_shapes_[input_shapes_.size() - 1]);
       }
       std::vector<NDArray *> inputs_p, outputs_p;
@@ -331,21 +343,21 @@ class CoreOpExecutor : public test::op::OperatorDataInitializer<DType>
       outputs_.reserve(num_visible_outputs);
       outputs_p.reserve(num_visible_outputs);
 
-      for (int i = 0; i < num_inputs; ++i) {
-        CHECK_LT(i, static_cast<int>(shapes.size()));
+      for (size_t i = 0; i < static_cast<size_t>(num_inputs); ++i) {
+        CHECK_LT(i, shapes.size());  // both size_t: no signed/unsigned comparison
-        inputs_.push_back(i < inputs.size() ? inputs[i] : CreateRandArray(shapes[i],
+        inputs_.emplace_back(i < inputs.size() ? inputs[i] : CreateRandArray(shapes[i],
                                                                           ctx_.run_ctx.ctx));
-        inputs_p.push_back(&*inputs_.rbegin());
+        inputs_p.emplace_back(&*inputs_.rbegin());
       }
 
-      for (int i = 0; i < num_visible_outputs; ++i) {
+      for (size_t i = 0; i < static_cast<size_t>(num_visible_outputs); ++i) {
         // If supplied and valid, pass from the supplied outputs vector
         // Otherwise use empty for forward pass, or zero-filled for backward pass
-        outputs_.push_back(i < outputs.size()
-                           ? outputs[i]
-                           : (backward_for_op ? CreateZeroArray(shapes[i], ctx_.run_ctx.ctx)
-                                              : NDArray()));
-        outputs_p.push_back(&*outputs_.rbegin());
+        outputs_.emplace_back(i < outputs.size()
+                              ? outputs[i]
+                              : (backward_for_op ? CreateZeroArray(shapes[i], ctx_.run_ctx.ctx)
+                                                 : NDArray()));
+        outputs_p.emplace_back(&*outputs_.rbegin());
       }
 
       if (!backward_for_op) {
@@ -396,7 +408,7 @@ class CoreOpExecutor : public test::op::OperatorDataInitializer<DType>
             << "Can't automatically determine backward op name. Please specify";
           for (std::pair<std::shared_ptr<CoreOpExecutor>, std::string> &bw_item : bwd) {
             bw_item.first->set_verbose(verbose_);
-            backward_.push_back(bw_item.first);
+            backward_.emplace_back(bw_item.first);
             bw_item.first->Init(ArgsWithOpName(args, bw_item.second), {}, {}, this);
           }
         }
diff --git a/tests/cpp/include/test_legacy_op.h b/tests/cpp/include/test_legacy_op.h
index 30bdf07b8b..6d326fc3c0 100644
--- a/tests/cpp/include/test_legacy_op.h
+++ b/tests/cpp/include/test_legacy_op.h
@@ -135,7 +135,7 @@ class LegacyOperatorExecutor : public OperatorDataInitializer<DType>
         // Get the resource of temporal space
         std::vector<TShape> inputShapes;
         for (size_t x = 0, n = shape_input_vec_.size(); x < n; ++x) {
-          inputShapes.push_back(shape_input_vec_[x]);
+          inputShapes.emplace_back(shape_input_vec_[x]);
         }
         allocateResources(opProp.ForwardResource(inputShapes));
 
@@ -408,11 +408,11 @@ class LegacyOperatorExecutor : public OperatorDataInitializer<DType>
 
     std::vector<std::vector<TBlob> *> all_blob_vects_;
     inline OpData() {
-      all_blob_vects_.push_back(&blob_input_vec_);
-      all_blob_vects_.push_back(&blob_output_vec_);
-      all_blob_vects_.push_back(&blob_aux_states_);
-      all_blob_vects_.push_back(&blob_in_grad_);
-      all_blob_vects_.push_back(&blob_out_grad_);  // Remaining err (loss) pushing back upstream
+      all_blob_vects_.emplace_back(&blob_input_vec_);
+      all_blob_vects_.emplace_back(&blob_output_vec_);
+      all_blob_vects_.emplace_back(&blob_aux_states_);
+      all_blob_vects_.emplace_back(&blob_in_grad_);
+      all_blob_vects_.emplace_back(&blob_out_grad_);  // Remaining err (loss) pushing back upstream
     }
     virtual ~OpData() {}
   };
@@ -495,14 +495,14 @@ class LegacyOperatorExecutor : public OperatorDataInitializer<DType>
     for (const ResourceRequest& req : reqs) {
       if (req.type == ResourceRequest::kTempSpace) {
         if (cached_temp.count(ctx) != 0) {
-          opContext_.requested.push_back(cached_temp.at(ctx));
+          opContext_.requested.emplace_back(cached_temp.at(ctx));
         } else {
           Resource r = ResourceManager::Get()->Request(ctx, req);
-          opContext_.requested.push_back(r);
+          opContext_.requested.emplace_back(r);
           cached_temp[ctx] = r;
         }
       } else if (req.type == ResourceRequest::kRandom) {
-        opContext_.requested.push_back(ResourceManager::Get()->Request(ctx, req));
+        opContext_.requested.emplace_back(ResourceManager::Get()->Request(ctx, req));
       } else {
         LOG(FATAL) << "resource type not yet supported";
       }
@@ -517,8 +517,8 @@ class LegacyOperatorExecutor : public OperatorDataInitializer<DType>
                              const int dtype) {
     test::StandaloneBlob *blob = new test::StandaloneBlob(shape, isGPU, dtype);
     CHECK_NE(blob, static_cast<TBlob *>(nullptr));
-    standalone_blobs->push_back(std::unique_ptr<test::StandaloneBlob>(blob));
-    (*dest).push_back(*blob);
+    standalone_blobs->emplace_back(std::unique_ptr<test::StandaloneBlob>(blob));
+    (*dest).emplace_back(*blob);
     return blob;
   }
 
diff --git a/tests/cpp/include/test_ndarray_utils.h b/tests/cpp/include/test_ndarray_utils.h
index bbc7c05bec..f5ab96794a 100644
--- a/tests/cpp/include/test_ndarray_utils.h
+++ b/tests/cpp/include/test_ndarray_utils.h
@@ -80,7 +80,7 @@ inline NDArray DnsND(const TShape shape, const Context ctx, std::vector<TEST_DTY
   // generate random values
   while (vs.size() < num_val) {
     auto v = RandFloat();
-    vs.push_back(v);
+    vs.emplace_back(v);
   }
   CHECK_EQ(vs.size(), nd.shape().Size());
   MSHADOW_TYPE_SWITCH(nd.dtype(), DType, {
diff --git a/tests/cpp/include/test_op.h b/tests/cpp/include/test_op.h
index 949f2ccdf4..cbafe14152 100644
--- a/tests/cpp/include/test_op.h
+++ b/tests/cpp/include/test_op.h
@@ -100,17 +100,27 @@ class OperatorDataInitializer {
    * \param blob Blob which to fill with random values
    */
   void FillRandom(const TBlob& blob) const {
-    std::uniform_real_distribution<DType> distribution(-1.0, 1.0);
-    test::patternFill<DType>(&blob, [this, &distribution]() -> DType {
-      return distribution(this->generator());
+    std::uniform_real_distribution<> dis_real(-5.0, 5.0);
+    std::uniform_int_distribution<> dis_int(-128, 127);
+    test::patternFill<DType>(&blob, [this, &dis_real, &dis_int]() -> DType {
+      if (!std::is_integral<DType>::value) {
+        DType val;
+        do {
+          val = static_cast<DType>(dis_real(this->generator()));
+        } while (fabs(val) < 1e-5);  // If too close to zero, try again
+        return val;
+      } else {
+        DType val;
+        do {
+          val = static_cast<DType>(dis_int(this->generator()));
+        } while (!val);  // If zero, try again
+        return val;
+      }
     });
   }
 
   void FillZero(const TBlob& blob) const {
-    std::uniform_real_distribution<DType> distribution(-1.0, 1.0);
-    test::patternFill<DType>(&blob, [this, &distribution]() -> DType {
-      return DType(0);
-    });
+    test::patternFill<DType>(&blob, []() -> DType { return DType(0); });
   }
 
  private:
@@ -271,7 +281,7 @@ inline std::vector<TShape> ShapesOf(const std::vector<NDArray>& arrays) {
   std::vector<TShape> res;
   res.reserve(arrays.size());
   for (const NDArray& ar : arrays) {
-    res.push_back(ar.shape());
+    res.emplace_back(ar.shape());
   }
-  return std::move(res);
+  return res;  // plain return enables NRVO; std::move here would pessimize
 }
diff --git a/tests/cpp/include/test_op_runner.h b/tests/cpp/include/test_op_runner.h
index 4c7cd1d774..eb259997cd 100644
--- a/tests/cpp/include/test_op_runner.h
+++ b/tests/cpp/include/test_op_runner.h
@@ -122,15 +122,14 @@ class OperatorRunner {
    * \param dim Data dimensions
    * \param count Number of times to run in each direction
    */
-  void TimingTest(const std::string& label,
-                  const bool isGPU,
-                  const bool stochastic,
-                  const test::op::kwargs_t& kwargs,
-                  int dim = 0,
-                  size_t count = 1,
-                  const std::vector<TShape>& timing_shapes = {}) {
-    std::cout << std::endl << std::flush;
-
+  std::unordered_map<int, perf::TimingInstrument::Info>
+  TimingTest(const std::string& label,
+             const bool isGPU,
+             const bool stochastic,
+             const test::op::kwargs_t& kwargs,
+             int dim = 0,
+             size_t count = 1,
+             const std::vector<TShape>& timing_shapes = {}) {
 #ifdef NDEBUG
     size_t COUNT = 50;
 #else
@@ -160,7 +159,7 @@ class OperatorRunner {
 
       if (timing_shapes.empty()) {
         do {
-          batchSize = stochastic ? test::rangedRand(1U, TES_BATCH_SIZE * 2U) : TIMING_BATCH_SIZE;
+          batchSize = stochastic ? test::rangedRand(1U, TEST_BATCH_SIZE * 2U) : TIMING_BATCH_SIZE;
           channels = stochastic ? test::rangedRand(1U, TEST_CHANNELS * 2U) : TIMING_CHANNELS;
           depth = stochastic ? test::rangedRand(1U, TEST_DEPTH * 2U) : TIMING_DEPTH;
           height = stochastic ? test::rangedRand(1U, TEST_DH * 2U) : TIMING_DH;
@@ -218,12 +217,18 @@ class OperatorRunner {
       }
     }
 
-    timing.print(&std::cout, label);
-    std::cout << std::endl << std::flush;
+    if (verbose_) {
+      timing.print(&std::cout, label);
+      std::cout << std::endl << std::flush;
+    }
+
+    return timing.data();
   }
 
+  void set_verbose(bool verbose) { verbose_ = verbose; }
+
  protected:
-  static constexpr int TES_BATCH_SIZE = 5;
+  static constexpr int TEST_BATCH_SIZE = 5;
   static constexpr int TEST_CHANNELS = 3;
   static constexpr int TEST_DEPTH = 2;
   static constexpr int TEST_DH = 2;
@@ -234,6 +239,8 @@ class OperatorRunner {
   static constexpr int TIMING_DEPTH = 2;
   static constexpr int TIMING_DH = 64;
   static constexpr int TIMING_DW = 64;
+  /*! \brief When true (the default), TimingTest() prints its timing summary to stdout */
+  bool verbose_ = true;
 };
 
 }  // namespace test
diff --git a/tests/cpp/include/test_perf.h b/tests/cpp/include/test_perf.h
index b6f2145767..7971ed7985 100644
--- a/tests/cpp/include/test_perf.h
+++ b/tests/cpp/include/test_perf.h
@@ -45,7 +45,7 @@ namespace perf {
 inline uint64_t getMicroTickCount() {
 #ifndef _WIN32
   struct timeval tv;
-  gettimeofday(&tv, NULL);
+  gettimeofday(&tv, nullptr);
   return uint64_t(tv.tv_sec) * 1000000 + tv.tv_usec;
 #else
   LARGE_INTEGER CurrentTime;
@@ -79,11 +79,6 @@ inline uint64_t getNannoTickCount() {
 #endif
 }
 
-/*! \brief millisecond tick count */
-inline uint64_t getTickCount() {
-  return getMicroTickCount() / 1000;
-}
-
 #define MICRO2MS(__micro$)  (((__micro$) + 500)/1000)
 #define MICRO2MSF(__micro$) (static_cast<float>(__micro$)/1000)
 #define MICRO2MSF(__micro$) (static_cast<float>(__micro$)/1000)
@@ -100,7 +95,7 @@ class TimedScope {
   const size_t    count_;
 
  public:
-  explicit inline TimedScope(const char *msg = NULL, size_t count = 1, const bool start = true)
+  explicit inline TimedScope(const char *msg = nullptr, size_t count = 1, const bool start = true)
     : startTime_(start ? getMicroTickCount() : 0)
       , stopTime_(0)
       , count_(count) {
@@ -164,7 +159,7 @@ class TimingInstrument {
   }
   void startTiming(int id, const char *s) {
     std::unique_lock<std::recursive_mutex>  lk(mutex_);
-    std::unordered_map<int, Info>::iterator i = data_.find(id);
+    auto i = data_.find(id);
     if (i == data_.end()) {
       i = data_.emplace(std::make_pair(id, Info(s))).first;
     }
@@ -174,7 +169,7 @@ class TimingInstrument {
   }
   void stopTiming(int id, const size_t subIterationCount = 1) {
     std::unique_lock<std::recursive_mutex>  lk(mutex_);
-    std::unordered_map<int, Info>::iterator i = data_.find(id);
+    auto i = data_.find(id);
     CHECK_NE(i == data_.end(), true) << "Can't stop timing on an object that we don't know about";
     if (i != data_.end()) {
       CHECK_NE(i->second.nestingCount_, 0U) << "While stopping timing, invalid nesting count of 0";
@@ -188,7 +183,7 @@ class TimingInstrument {
   }
   uint64_t getDuration(int id) {
     std::unique_lock<std::recursive_mutex>  lk(mutex_);
-    std::unordered_map<int, Info>::iterator i = data_.find(id);
+    auto i = data_.find(id);
     if (i != data_.end()) {
       const Info&        info = i->second;
       const uint64_t duration = info.nestingCount_.load()
@@ -202,7 +197,7 @@ class TimingInstrument {
   bool isTiming(int id) {
     std::unordered_map<int, Info>::const_iterator i = data_.find(id);
     if (i != data_.end()) {
-      return !!i->second.nestingCount_.load();
+      return i->second.nestingCount_.load() != 0;
     }
     return false;
   }
@@ -216,7 +211,7 @@ class TimingInstrument {
         i != e; ++i) {
       const Info&        info = i->second;
       const uint64_t duration = getDuration(i->first);
-      *os << /*std::endl <<*/ label_ << ": " << name_ << " Timing [" << info.name_ << "] "
+      *os << label_ << ": " << name_ << " Timing [" << info.name_ << "] "
           << (info.nestingCount_.load() ? "*" : "")
           << MICRO2MSF(duration) << " ms";
         if (info.cycleCount_.load()) {
@@ -233,7 +228,7 @@ class TimingInstrument {
 
   void reset() {
     std::unique_lock<std::recursive_mutex>  lk(mutex_);
-    for (std::unordered_map<int, Info>::iterator i = data_.begin(), e = data_.end();
+    for (auto i = data_.begin(), e = data_.end();
         i != e; ++i) {
       const int id = i->first;
       const bool wasTiming = isTiming(id);
@@ -250,9 +245,9 @@ class TimingInstrument {
   }
 
   TimingInstrument& operator += (const TimingInstrument& o) {
-    for (std::unordered_map<int, Info>::const_iterator i = o.data_.begin(), e = o.data_.end();
+    for (auto i = o.data_.begin(), e = o.data_.end();
         i != e; ++i) {
-      std::unordered_map<int, Info>::iterator j = data_.find(i->first);
+      auto j = data_.find(i->first);
       if (j != data_.end())  {
         const Info &oInfo = i->second;
         CHECK_EQ(oInfo.nestingCount_, 0U);
@@ -265,7 +260,6 @@ class TimingInstrument {
     return *this;
   }
 
- private:
   struct Info {
     explicit inline Info(const char *s)
       : name_(s ? s : "")
@@ -273,6 +267,7 @@ class TimingInstrument {
         , nestingCount_(0)
         , cycleCount_(0)
         , duration_(0) {}
+
     inline Info(const Info& o)
       : name_(o.name_)
         , baseTime_(o.baseTime_.load())
@@ -281,17 +276,36 @@ class TimingInstrument {
         , duration_(o.duration_.load()) {
       CHECK_EQ(o.nestingCount_, 0U);
     }
+
+    /*!
+     * \brief Average time per timed cycle in milliseconds; 0.0 if no cycles were recorded
+     */
+    inline double TimeEach() const {
+      const uint64_t cycles = cycleCount_.load();
+      return cycles ? static_cast<double>(duration_.load()) / cycles / 1000.0 : 0.0;
+    }
+
     std::string           name_;
     std::atomic<uint64_t> baseTime_;
     std::atomic<uint64_t> nestingCount_;
     std::atomic<uint64_t> cycleCount_;  // Note that nesting may skew averages
     std::atomic<uint64_t> duration_;
   };
+
+  using timing_map_t = std::unordered_map<int, Info>;
+  /*! \brief All timing entries, keyed by timing id; note: does not lock mutex_ */
+  const timing_map_t& data() const {
+    return data_;
+  }
+
+ private:
   std::string                   name_;
   mutable std::recursive_mutex  mutex_;
   std::unordered_map<int, Info> data_;
 };
 
+using timing_map_t = TimingInstrument::timing_map_t;
+
 /*! \brief Accumulated scoped timing, indexed by ID */
 class TimingItem {
  public:
diff --git a/tests/cpp/include/test_util.h b/tests/cpp/include/test_util.h
index 95ab141954..33ca3c47d0 100644
--- a/tests/cpp/include/test_util.h
+++ b/tests/cpp/include/test_util.h
@@ -609,6 +609,67 @@ inline ScalarType rangedRand(const ScalarType min, const ScalarType max) {
   return static_cast<ScalarType>(x / bin_size + min);
 }
 
+/*!
+ * \brief Strict weak ordering on TShape so shapes can be keys of std::map / std::set:
+ *        by total element count, then number of dimensions, then first differing dim.
+ * \param s1 Left-hand shape
+ * \param s2 Right-hand shape
+ * \return true if s1 orders before s2
+ */
+inline bool operator < (const nnvm::TShape &s1, const nnvm::TShape &s2) {
+  // Primary key: total number of elements
+  if (s1.Size() != s2.Size()) {
+    return s1.Size() < s2.Size();
+  }
+  // Secondary key: rank
+  if (s1.ndim() != s2.ndim()) {
+    return s1.ndim() < s2.ndim();
+  }
+  // Same size and rank: the first differing dimension decides
+  for (size_t dim = 0, rank = s1.ndim(); dim < rank; ++dim) {
+    if (s1[dim] != s2[dim]) return s1[dim] < s2[dim];
+  }
+  return false;
+}
+
+/*!
+ * \brief Strict weak ordering on vectors of TShape, for use as std::map / std::set keys:
+ *        shorter vectors first; equal-length vectors by the first differing shape.
+ * \param v1 Left-hand vector of shapes
+ * \param v2 Right-hand vector of shapes
+ * \return true if v1 orders before v2
+ */
+inline bool operator < (const std::vector<nnvm::TShape>& v1, const std::vector<nnvm::TShape>& v2) {
+  if (v1.size() != v2.size()) {
+    return v1.size() < v2.size();
+  }
+  // Equal length: the first differing shape decides
+  for (size_t idx = 0, count = v1.size(); idx < count; ++idx) {
+    if (!(v1[idx] == v2[idx])) {
+      return v1[idx] < v2[idx];
+    }
+  }
+  return false;
+}
+
+/*!
+ * \brief std::less-style functor for comparing vectors of shapes, e.g. as the Compare
+ *        argument of a std::map / std::set keyed on std::vector<nnvm::TShape>
+ */
+struct less_shapevect {
+  /*!
+   * \brief Order v1 before v2 using operator<(vector<TShape>, vector<TShape>) above
+   * \param v1 Left-hand vector of shapes
+   * \param v2 Right-hand vector of shapes
+   * \return true if v1 orders before v2
+   */
+  bool operator()(const std::vector<nnvm::TShape>& v1, const std::vector<nnvm::TShape>& v2) const {
+    // Delegate rather than duplicate, so the functor and operator< cannot drift apart.
+    // The non-template overload above is preferred over std's lexicographic vector<.
+    return v1 < v2;
+  }
+};
+
 inline std::string pretty_num(uint64_t val) {
   std::string res, s = std::to_string(val);
   size_t ctr = 0;
diff --git a/tests/cpp/operator/core_op_runner_test.cc b/tests/cpp/operator/runner/core_op_runner_test.cc
similarity index 100%
rename from tests/cpp/operator/core_op_runner_test.cc
rename to tests/cpp/operator/runner/core_op_runner_test.cc
diff --git a/tests/cpp/operator/coreop_perf.cc b/tests/cpp/operator/sgd_mom_perf.cc
similarity index 100%
rename from tests/cpp/operator/coreop_perf.cc
rename to tests/cpp/operator/sgd_mom_perf.cc
diff --git a/tests/cpp/operator/slice_channel_perf.cc b/tests/cpp/operator/slice_channel_perf.cc
new file mode 100644
index 0000000000..dc42d2a5d4
--- /dev/null
+++ b/tests/cpp/operator/slice_channel_perf.cc
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *  \file slice_channel_perf.cc
+ *  \brief Perf/profile run of the legacy SliceChannel operator (SliceChannelProp)
+ *  \author Chris Olivier
+ */
+
+#include <gtest/gtest.h>
+#include <mxnet/tensor_blob.h>
+#include "../include/test_op_runner.h"
+#include "../include/test_legacy_op.h"
+#include "../../src/operator/slice_channel-inl.h"
+
+using namespace mxnet;
+
+typedef std::vector<std::pair<std::string, std::string> > kwargs_t;
+const kwargs_t basic_slice_channel_args = { };
+
+/*!
+ * \brief Bidirectional sanity run of SliceChannel, splitting into 160 outputs
+ */
+TEST(SLICE_CHANNEL_PERF, ExecuteBidirectional) {
+  TShape shape({1, 160, 200});
+  kwargs_t kwargs = basic_slice_channel_args;
+  kwargs.push_back({"num_outputs", "160"});
+  test::op::LegacyOpRunner<mxnet::op::SliceChannelProp, float, float> runner;
+  runner.RunBidirectional(false, { shape }, kwargs, 1);
+}
+
+/*!
+ * \brief SliceChannel timing test for CPU
+ */
+TEST(SLICE_CHANNEL_PERF, TimingCPU) {
+  kwargs_t kwargs = basic_slice_channel_args;
+  // num_outputs matches the 160-wide middle dimension of every shape used below
+  kwargs.push_back({"num_outputs", "160"});
+  test::op::LegacyOpRunner<mxnet::op::SliceChannelProp, float, float> runner;
+  runner.RunBidirectional(false,
+                          { TShape({1, 160, 200}) },
+                          kwargs, 1);  // prime code and cache
+  std::vector<TShape> shapes;
+  if (test::performance_run) {
+    shapes = {
+      {1, 160, 200},
+      {10, 160, 200},
+      {100, 160, 200},
+      {10, 160, 500},
+      {100, 160, 500}
+    };
+  } else {
+    shapes = {
+      {1, 160, 200},
+      {1, 160, 200}
+    };
+  }
+  for (const TShape &shape : shapes) {
+    runner.TimingTest("SliceChannel Operator CPU", false, false, kwargs, 2, 10, { shape });
+  }
+}
+
+#if MXNET_USE_CUDA == 1
+/*!
+ * \brief SliceChannel timing test for GPU
+ */
+TEST(SLICE_CHANNEL_PERF, TimingGPU) {
+  kwargs_t kwargs = basic_slice_channel_args;
+  // num_outputs matches the 160-wide middle dimension of every shape used below
+  kwargs.push_back({"num_outputs", "160"});
+  test::OperatorRunner<mxnet::op::SliceChannelProp,
+    test::op::LegacyOperatorExecutor<float, float>> runner;
+  runner.RunBidirectional(true,
+                          { TShape({1, 160, 200}) },
+                          kwargs, 1);  // prime code and cache
+  std::vector<TShape> shapes = {
+      {1, 160, 200},
+      {1, 160, 200},
+      {1, 160, 200},
+      {1, 160, 200},
+      {1, 160, 200}
+    };
+  for (const TShape &shape : shapes) {
+    runner.TimingTest("SliceChannel Operator GPU", true, false, kwargs, 2, 10, { shape });
+  }
+}
+#endif  // MXNET_USE_CUDA == 1
+
diff --git a/tests/cpp/operator/tune/operator_tune_test.cc b/tests/cpp/operator/tune/operator_tune_test.cc
new file mode 100644
index 0000000000..897b9979ac
--- /dev/null
+++ b/tests/cpp/operator/tune/operator_tune_test.cc
@@ -0,0 +1,413 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <gtest/gtest.h>
+#include <mxnet/tensor_blob.h>
+#include "../../src/operator/activation-inl.h"
+#include "../../src/operator/operator_tune-inl.h"
+#include "../include/test_op_runner.h"
+#include "../include/test_core_op.h"
+
+using namespace mxnet;
+
+/*!
+ * \brief Print, one per line, every operator name reported by
+ *        OperatorTune<float>::TunedOperatorNames()
+ */
+TEST(OMP_TUNING, ShowAllTunedOps) {
+  const std::unordered_set<std::string>& op_names = op::OperatorTune<float>::TunedOperatorNames();
+  for (const std::string &name : op_names) {
+    std::cout << name << std::endl;
+  }
+}
+
+using kwargs_t = test::op::kwargs_t;
+
+/*!
+ * \brief Run a core op forward and, if it has a backward operator, backward
+ * \tparam DType Data type
+ * \param isGPU true if operation is to be run on the GPU
+ * \param op_kwargs Operator parameters
+ * \param op_name Operator name as registered with nnvm
+ * \param backward_op_name Backwards operator name as registered with nnvm
+ *        If blank, the runner will attempt to determine the backwards operator. If it fails,
+ *        an exception will be thrown.
+ *        If the string is [none], then no backward operator will be created or executed
+ * \note The input shape is fixed at 5x5; inputs and outputs are passed to PRINT_NDARRAYS
+ *       before and after each pass
+ */
+template<typename DType = float>
+static void RunCoreOpBidirectional(const bool isGPU,
+                                   const kwargs_t& op_kwargs,
+                                   const char *op_name,
+                                   const char *backward_op_name = "") {
+  const TShape shape({5, 5});
+  test::op::CoreOpExecutor<DType> op(isGPU, { shape });
+  op.set_verbose(false);
+
+  op.Init(op.ArgsWithOpName(op_kwargs, op_name, backward_op_name));
+
+  PRINT_NDARRAYS(op.ctx().run_ctx, op.inputs());
+  PRINT_NDARRAYS(op.ctx().run_ctx, op.outputs());
+  op.Execute();
+  PRINT_NDARRAYS(op.ctx().run_ctx, op.outputs());
+  // The backward pass only runs when the executor reports a backward operator
+  if (op.HasBackward()) {
+    PRINT_NDARRAYS(op.ctx().run_ctx, op.bwd_inputs());
+    PRINT_NDARRAYS(op.ctx().run_ctx, op.bwd_outputs());
+    op.ExecuteBackward();
+    PRINT_NDARRAYS(op.ctx().run_ctx, op.bwd_outputs());
+  }
+}
+
+/*!
+ * \brief Sanity check: run elemwise_add forward and then _backward_add on the CPU,
+ *        with no extra operator parameters
+ */
+TEST(OMP_TUNING, ExecuteBidirectional) {
+  const kwargs_t no_params{};
+  RunCoreOpBidirectional<float>(false, no_params, "elemwise_add", "_backward_add");
+}
+
+/*!
+ * \brief Test harness that times a core operator under each OMP tuning mode
+ *        (AlwaysOMP, Auto, NeverOMP) and judges whether Auto mode's timing is
+ *        acceptably close to the faster of the two fixed modes.
+ * \tparam DType Data type the operator is run with
+ */
+template<typename DType>
+class TuningTester {
+ public:
+  /*! \brief <Auto judged correct?, fixed mode whose timing Auto was closest to> */
+  using bool_mode_pair = std::pair<bool, op::tune::TuningMode>;
+
+  using shape_vect = std::vector<TShape>;
+  using shape_vec_to_bool_map = std::map<shape_vect, bool_mode_pair, test::less_shapevect>;
+
+ private:
+  using ShapesToPerfTimingMap =
+    std::map<shape_vect, test::perf::timing_map_t, test::less_shapevect>;
+
+  /*!
+   * \brief Run timing test on various data shapes and sizes
+   * \param isGPU true if the GPU should be used for the timing test
+   * \param op_kwargs operator parameters
+   * \param op_name The operator's registered name (with nnvm)
+   * \param backward_op_name The backward operator's registered name (with nnvm)
+   * \return ShapesToPerfTimingMap map holding the timing data for each input shape set
+   */
+  static ShapesToPerfTimingMap RunCoreOpTimingTest(const bool isGPU,
+                                                   const kwargs_t &op_kwargs,
+                                                   const char *op_name,
+                                                   const char *backward_op_name = "") {
+    ShapesToPerfTimingMap res;
+    const kwargs_t kwargs = test::op::CoreOpExecutor<DType>::ArgsWithOpName(
+      op_kwargs, op_name, backward_op_name);
+
+    // Prime code and cache before the performance runs. Prime on the same device
+    // that is about to be timed (this used to always prime on the CPU).
+    test::op::CoreOperatorRunner<DType> runner;
+    runner.set_verbose(false);
+    runner.RunBidirectional(isGPU, {{10, 3, 18, 128}}, kwargs, 1);
+
+    // Do the performance runs
+    shape_vect shapes;
+    if (test::performance_run) {
+      shapes = {
+        {1,  1, 28,  28},
+        {1,  3, 28,  28},
+        {50, 1, 18,  32},
+        {25, 3, 64,  64},
+        {10, 3, 128, 128},
+        {20, 3, 256, 256}
+      };
+    } else {
+      shapes = {
+        // Non-performance dataset acts as a sanity test
+        {1,  1, 28, 28},
+        {50, 3, 18, 32}
+      };
+    }
+    const char *pu = isGPU ? "GPU" : "CPU";
+    for (const TShape &shape : shapes) {
+      const shape_vect this_run_shapes = {shape};
+      test::perf::timing_map_t tmap = runner.TimingTest(std::string(op_name) + " Operator " + pu,
+                                                        isGPU, false, kwargs, 2, 10,
+                                                        this_run_shapes);
+      // Each shape set is expected to be timed exactly once
+      CHECK(res.find(this_run_shapes) == res.end());
+      res[this_run_shapes] = std::move(tmap);
+    }
+    // Plain return: std::move on a local would inhibit NRVO
+    return res;
+  }
+
+  /*! \brief shape set -> tuning mode -> timing data */
+  using tuned_timing_t = std::map<
+    shape_vect,
+    std::map<op::tune::TuningMode, test::perf::timing_map_t>, test::less_shapevect>;
+
+  /*! \brief Per-mode TimeEach() values, ordered fastest to slowest */
+  using modesort_t = std::multimap<double, op::tune::TuningMode>;
+
+  /*!
+   * \brief Check if the tuning succeeded
+   * \param mode_sort modesort_t structure produced by 'CalculateModeSort'; must contain
+   *        exactly three entries (AlwaysOMP, Auto and NeverOMP)
+   * \param closeness_factor fraction of the faster standard time (omp, no omp) by which
+   *        Auto may exceed it and still be considered correct
+   * \return a pair <bool, TuningMode> consisting of true or false signifying if the test appears to
+   *         have made the correct decision, and the TuningMode which was closest in timing to
+   *         the Auto mode.
+   */
+  static bool_mode_pair
+  CheckCorrectTuning(const modesort_t &mode_sort,
+                     const double closeness_factor = 0.25) {
+    CHECK_EQ(mode_sort.size(), 3U);
+
+    // mode_sort is ordered by time, so the first non-Auto entry is the fastest standard mode
+    op::tune::TuningMode fastest_standard_mode = op::tune::Auto;
+    for (const auto &entry : mode_sort) {
+      if (entry.second != op::tune::Auto) {
+        fastest_standard_mode = entry.second;
+        break;
+      }
+    }
+    CHECK_NE(fastest_standard_mode, op::tune::Auto);
+
+    // Invert the sort into a mode -> time lookup
+    std::map<op::tune::TuningMode, double> mode2time;
+    for (const auto &entry : mode_sort) {
+      mode2time[entry.second] = entry.first;
+    }
+    const double time_auto = mode2time[op::tune::Auto];
+    const double time_no_omp = mode2time[op::tune::NeverOMP];
+    const double time_omp = mode2time[op::tune::AlwaysOMP];
+
+    // Allow some variance: Auto passes if it is no slower than the faster standard mode
+    // plus closeness_factor of that time. This matters most when NeverOMP and AlwaysOMP
+    // are close together.
+    const double fastest_standard_time = std::min(time_no_omp, time_omp);
+    const double allowed_difference = closeness_factor * fastest_standard_time;
+    const double mustbe_asfast = fastest_standard_time + allowed_difference;
+
+    // Figure out which standard mode Auto is closest to and return that to help in the analysis
+    const op::tune::TuningMode closest_to =
+      fabs(time_auto - time_no_omp) < fabs(time_auto - time_omp) ? op::tune::NeverOMP
+                                                                  : op::tune::AlwaysOMP;
+
+    const bool correct = time_auto <= mustbe_asfast || closest_to == fastest_standard_mode;
+    return { correct, closest_to };
+  }
+
+ public:
+  /*!
+   * \brief Given timing statistics, determine if 'Auto' mode made the correct choice.
+   * \param direction Compute direction for which to check (Forward or Backward)
+   * \param verbose If true, print the statistical info
+   * \return A map of shape vectors to a pair <bool, TuningMode> consisting of true or false
+   *         signifying if the test appears to have made the correct decision, and the TuningMode
+   *         which was closest in timing to the Auto mode. Shape sets with no timing data for
+   *         'direction' are omitted.
+   */
+  shape_vec_to_bool_map CalculateModeSort(const test::op::TimingDirection direction,
+                                          bool verbose = true) {
+    shape_vec_to_bool_map results;
+    // Incredibly inefficient method of grouping the results
+    for (const auto &i : timing_) {
+      const shape_vect &shapes = i.first;
+      if (verbose) {
+        // print shapes
+        for (size_t x = 0, n = shapes.size(); x < n; ++x) {
+          if (x) {
+            std::cout << ", ";
+          }
+          std::cout << shapes[x];
+        }
+        const TShape& lhs_shape = shapes[0];
+        std::cout << " lhs=" << test::pretty_num(lhs_shape.Size()) << " items";
+        std::cout << "\t(" << TimingDirectionAsString(direction) << ")" << std::endl;
+      }
+      const auto &mode2timing = i.second;
+      modesort_t mode_sort;
+      for (const auto &j : mode2timing) {
+        const op::tune::TuningMode mode = j.first;
+        const test::perf::timing_map_t &tm = j.second;
+        const auto found = tm.find(direction);
+        if (found != tm.end()) {
+          const test::perf::TimingInstrument::Info &info = found->second;
+          mode_sort.insert({info.TimeEach(), mode});
+        }
+      }
+      if (!mode_sort.empty()) {
+        // Now we have modes sorted by performance, fastest to slowest
+        const bool_mode_pair result = CheckCorrectTuning(mode_sort);
+        if (verbose) {
+          for (const auto &k : mode_sort) {
+            std::cout << "\t" << op::tune::TuningModeToString(k.second)
+                      << ": " << k.first << " ms";
+            if (k.second == op::tune::Auto) {
+              std::cout << " (" << op::tune::TuningModeToString(result.second) << ")";
+            }
+            std::cout << std::endl;
+          }
+        }
+        std::cout << std::flush;
+        if (!result.first && verbose) {
+          std::cout << "*** WARNING: Wrong OMP state selected ***" << std::endl << std::flush;
+        }
+        if (verbose) {
+          std::cout << std::endl << std::flush;
+        }
+        CHECK(results.find(shapes) == results.end());
+        results[shapes] = result;
+      }
+    }
+    // Plain return: std::move on a local would inhibit NRVO
+    return results;
+  }
+
+  /*!
+   * \brief Perform CPU execution runs for a given forward (and optionally backward) operator
+   *        under each of the AlwaysOMP, Auto and NeverOMP tuning modes, replacing any
+   *        timing data collected by a previous call
+   * \param kwargs Parameters for the operator
+   * \param op_name Name by which the operator is registered with nnvm
+   * \param backward_op_name Backward operator name
+   * \note NOTE(review): the last mode passed to OperatorTune<DType>::set_tuning_mode
+   *       (NeverOMP) stays in effect after return -- confirm later tests don't depend on it
+   */
+  void TestTunedOperator(const kwargs_t &kwargs,
+                         const char *op_name,
+                         const char *backward_op_name = COREOP_BWD_OP_NAME_VALUE_NONE) {
+    timing_.clear();
+    for (auto mode : {op::tune::AlwaysOMP,
+                      op::tune::Auto,
+                      op::tune::NeverOMP}) {
+      std::cout << std::endl << op::tune::TuningModeToString(mode) << std::endl << std::flush;
+      mxnet::op::OperatorTune<DType>::set_tuning_mode(mode);
+      const ShapesToPerfTimingMap shapes2perfmap = RunCoreOpTimingTest(false,
+                                                                       kwargs,
+                                                                       op_name,
+                                                                       backward_op_name);
+      for (const auto &item : shapes2perfmap) {
+        timing_[item.first][mode] = item.second;
+      }
+    }
+  }
+
+ private:
+  /*! \brief Timing data gathered by TestTunedOperator, keyed by shape set then tuning mode */
+  tuned_timing_t  timing_;
+};
+
+/* Some test results:
+ * AWS c4.8xlarge:
+  Success rate for type float: 0.90278
+  Success rate for type double: 0.88889
+  Success rate for type mshadow::half::half_t: 0.83333
+  Success rate for type unsigned char: 0.86111
+  Success rate for type int: 0.95833
+  Success rate for type long: 0.88889
+ * desktop: 12 logical cores (6 physical CPU cores with hyperthreading)
+  Success rate for type float: 0.79167
+  Success rate for type double: 0.75000
+  Success rate for type unsigned char: 0.72222
+  Success rate for type int: 0.94444
+  Success rate for type long: 1.00000
+ *
+ */
+
+/*!
+ * \brief Run a tuning evaluation over a set of core operators
+ * \tparam DType Data type for which to evaluate tuning
+ * \return Fraction of (shape set, direction) results for which Auto mode was judged to have
+ *         made the correct OMP decision, or 1.0 if no results were produced
+ */
+template<typename DType>
+static float EvaluateTune() {
+  // <forward op name, backward op name>; an empty backward name lets the runner
+  // determine the backward operator itself
+  std::vector<std::pair<std::string, std::string>> operators;
+  if (test::performance_run) {
+    operators = {
+      { "relu",    "" },  // Code can figure out what the backward op is for some
+      { "sigmoid", "" },
+      { "sqrt",    "" },
+      { "elemwise_add", "_backward_add" },
+      { "elemwise_mul", "_backward_mul" },
+      { "elemwise_div", "_backward_div" }
+    };
+  } else {
+    operators = {
+      { "elemwise_add", "_backward_add" }
+    };
+  }
+  size_t count = 0, success = 0;
+  // Count every judged shape set, and how many of them Auto mode got right
+  auto tally = [&count, &success](
+    const typename TuningTester<DType>::shape_vec_to_bool_map &results) {
+    for (const auto &item : results) {
+      ++count;
+      if (item.second.first) {
+        ++success;
+      }
+    }
+  };
+  for (const auto &op_pair : operators) {
+    TuningTester<DType> tuningTester;
+    std::cout << "******************************" << std::endl;
+    std::cout << "Operators: " << op_pair.first
+              << ", " << op_pair.second
+              << " for type: " << test::type_name<DType>()
+              << std::endl;
+    std::cout << "******************************" << std::endl;
+    tuningTester.TestTunedOperator({},
+                                   op_pair.first.c_str(),
+                                   op_pair.second.c_str());
+    tally(tuningTester.CalculateModeSort(test::op::Forward));
+    tally(tuningTester.CalculateModeSort(test::op::Backward));
+  }
+  if (count) {
+    return static_cast<float>(success) / static_cast<float>(count);
+  }
+  return 1.0f;  // nothing ventured, nothing failed (glass-is-half-full approach)
+}
+
+/*!
+ * \brief Run EvaluateTune<DType>() and print its success rate
+ * \tparam DType Data type to evaluate
+ */
+template<typename DType>
+static void ReportTuneSuccessRate() {
+  const float result = EvaluateTune<DType>();
+  std::cout << "Success rate for type " << test::type_name<DType>() << ": " << result << std::endl;
+}
+
+/*! \brief OMP tuning evaluation on CPU for float */
+TEST(OMP_TUNING, EvaluateTuneTestFloat) {
+  ReportTuneSuccessRate<float>();
+}
+/*! \brief OMP tuning evaluation on CPU for double */
+TEST(OMP_TUNING, EvaluateTuneTestDouble) {
+  ReportTuneSuccessRate<double>();
+}
+/*! \brief OMP tuning evaluation on CPU for mshadow::half::half_t (float16) */
+TEST(OMP_TUNING, EvaluateTuneTestFloat16) {
+  ReportTuneSuccessRate<mshadow::half::half_t>();
+}
+/*! \brief OMP tuning evaluation on CPU for uint8_t */
+TEST(OMP_TUNING, EvaluateTuneTestInt8) {
+  ReportTuneSuccessRate<uint8_t>();
+}
+/*! \brief OMP tuning evaluation on CPU for int32_t */
+TEST(OMP_TUNING, EvaluateTuneTestInt32) {
+  ReportTuneSuccessRate<int32_t>();
+}
+/*! \brief OMP tuning evaluation on CPU for int64_t */
+TEST(OMP_TUNING, EvaluateTuneTestInt64) {
+  ReportTuneSuccessRate<int64_t>();
+}
+


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services