You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2017/11/15 18:43:34 UTC

[GitHub] cjolivier01 closed pull request #8670: Refreshed branch bc_tune

cjolivier01 closed pull request #8670: Refreshed branch bc_tune
URL: https://github.com/apache/incubator-mxnet/pull/8670
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/CMakeLists.txt b/CMakeLists.txt
index af681d00aa..058fb6c1af 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -35,6 +35,7 @@ mxnet_option(USE_LAPACK           "Build with lapack support" ON IF NOT MSVC)
 mxnet_option(USE_MKL_IF_AVAILABLE "Use MKL if found" ON)
 mxnet_option(USE_MKLML_MKL        "Use MKLML variant of MKL (if MKL found)" ON IF USE_MKL_IF_AVAILABLE AND UNIX AND (NOT APPLE))
 mxnet_option(USE_MKL_EXPERIMENTAL "Use experimental MKL (if MKL enabled and found)" OFF)
+mxnet_option(USE_OPERATOR_TUNING  "Enable auto-tuning of operators" ON)
 mxnet_option(USE_GPERFTOOLS       "Build with GPerfTools support (if found)" ON)
 mxnet_option(USE_JEMALLOC         "Build with Jemalloc support"   ON)
 mxnet_option(USE_PROFILER         "Build with Profiler support"   OFF)
@@ -353,6 +354,10 @@ if(USE_PLUGINS_WARPCTC)
 	list(APPEND CUDA ${PLUGINS_CUSRC})
 endif()
 
+if(USE_OPERATOR_TUNING)
+  add_definitions(-DMXNET_USE_OPERATOR_TUNING=1)
+endif()
+
 if(USE_PLUGIN_CAFFE)
   if(NOT USE_CUDA)
     set(CPU_ONLY ON)
@@ -471,6 +476,11 @@ else()
   set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /EHsc")
 endif()
 
+if(USE_OPERATOR_TUNING)
+  add_executable(mxnet-tblgen tools/tblgen/tblgen.cc)
+  target_link_libraries(mxnet-tblgen ${mxnet_LINKER_LIBS})
+endif()
+
 set(MXNET_INSTALL_TARGETS mxnet)
 if(${CMAKE_SYSTEM_NAME} STREQUAL "Darwin" AND USE_MXNET_LIB_NAMING)
   add_library(mxnet MODULE ${SOURCE})
diff --git a/Makefile b/Makefile
index 8c7ae6e6fd..8659482f26 100644
--- a/Makefile
+++ b/Makefile
@@ -131,6 +131,10 @@ ifeq ($(USE_MKL2017), 1)
 	LDFLAGS +=  -liomp5
 endif
 
+ifeq ($(USE_OPERATOR_TUNING), 1)
+	CFLAGS += -DMXNET_USE_OPERATOR_TUNING=1
+endif
+
 # verify existence of separate lapack library when using blas/openblas/atlas
 # switch off lapack support in case it can't be found
 # issue covered with this
diff --git a/make/config.mk b/make/config.mk
index a4774f0da8..eeda36b365 100644
--- a/make/config.mk
+++ b/make/config.mk
@@ -153,6 +153,12 @@ LIBJVM=$(JAVA_HOME)/jre/lib/amd64/server
 # sudo apt-get install -y libcurl4-openssl-dev
 USE_S3 = 0
 
+#----------------------------
+# performance settings
+#----------------------------
+# Use operator tuning
+USE_OPERATOR_TUNING = 1
+
 # Use gperftools if found
 USE_GPERFTOOLS = 1
 
diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h
index 04db326496..cc6a95d6eb 100644
--- a/src/operator/mshadow_op.h
+++ b/src/operator/mshadow_op.h
@@ -29,6 +29,7 @@
 #include "math.h"
 #include "math_functions-inl.h"
 #include "special_functions-inl.h"
+#include "./mxnet_op.h"
 
 #ifdef __CUDACC__
 #include <cuda_fp16.h>
@@ -38,6 +39,24 @@ namespace mxnet {
 namespace op {
 namespace mshadow_op {
 
+/*!
+ * \brief Use the 'MXNET_TUNABLE_MSHADOW_OP_FWD_AND_BWD' macro outside of the mshadow_op namespace
+ *        See mxnet_op.h for a description of 'MXNET_TUNABLE_MSHADOW_OP_FWD_AND_BWD'
+ *
+ * \note An entry for the operator must also be added in operator_tune.cc, which will register it
+ *       for auto-tuning and also hold its workload weight
+ */
+#define MSHADOW_OP_DECLARE_TUNABLE_FWD_AND_BWD(__op$) \
+  } MXNET_TUNABLE_MSHADOW_OP_FWD_AND_BWD(mshadow_op::__op$) namespace mshadow_op {  // NOLINT(*)
+/*!
+ * \brief Use the 'MXNET_TUNABLE_MSHADOW_OP_BACKWARD' macro outside of the mshadow_op namespace
+ *        See mxnet_op.h for a description of 'MXNET_TUNABLE_MSHADOW_OP_BACKWARD'
+ *
+ * \note An entry for the operator must also be added in operator_tune.cc, which will register it
+ *       for auto-tuning and also hold its workload weight
+ */
+#define MSHADOW_OP_DECLARE_TUNABLE_BACKWARD(__op$) \
+  }  MXNET_TUNABLE_MSHADOW_OP_BACKWARD(mshadow_op::__op$) namespace mshadow_op {  // NOLINT(*)
 #ifdef __CUDA_ARCH__
 __constant__ const float PI = 3.14159265358979323846;
 #else
@@ -48,36 +67,41 @@ using std::enable_if;
 using std::is_unsigned;
 
 #define MXNET_UNARY_MATH_OP(name, expr) \
-struct name { \
-  template<typename DType> \
-  MSHADOW_XINLINE static DType Map(DType a) { \
-    return DType(expr); \
-  } \
-}
+  struct name { \
+    template<typename DType> \
+    MSHADOW_XINLINE static DType Map(DType a) { \
+      return DType(expr); \
+    } \
+  }; \
+  MSHADOW_OP_DECLARE_TUNABLE_FWD_AND_BWD(name)
+
 
 #define MXNET_UNARY_MATH_OP_NC(name, expr) \
-struct name { \
-  template<typename DType> \
-  MSHADOW_XINLINE static DType Map(DType a) { \
-    return (expr); \
-  } \
-}
+  struct name { \
+    template<typename DType> \
+    MSHADOW_XINLINE static DType Map(DType a) { \
+      return (expr); \
+    } \
+  }; \
+  MSHADOW_OP_DECLARE_TUNABLE_FWD_AND_BWD(name)
 
 #define MXNET_BINARY_MATH_OP(name, expr) \
-struct name { \
-  template<typename DType> \
-  MSHADOW_XINLINE static DType Map(DType a, DType b) { \
-    return DType(expr); \
-  } \
-}
+  struct name { \
+    template<typename DType> \
+    MSHADOW_XINLINE static DType Map(DType a, DType b) { \
+      return DType(expr); \
+    } \
+  }; \
+  MSHADOW_OP_DECLARE_TUNABLE_FWD_AND_BWD(name)
 
 #define MXNET_BINARY_MATH_OP_NC(name, expr) \
-struct name { \
-  template<typename DType> \
-  MSHADOW_XINLINE static DType Map(DType a, DType b) { \
-    return (expr); \
-  } \
-}
+  struct name { \
+    template<typename DType> \
+    MSHADOW_XINLINE static DType Map(DType a, DType b) { \
+      return (expr); \
+    } \
+  }; \
+  MSHADOW_OP_DECLARE_TUNABLE_FWD_AND_BWD(name)
 
 #define MXNET_SIMPLE_UNARY_MATH_OP(name) MXNET_UNARY_MATH_OP(name, math::name(a))
 
@@ -133,6 +157,7 @@ struct softrelu {
     }
   }
 };
+MSHADOW_OP_DECLARE_TUNABLE_FWD_AND_BWD(softrelu)
 
 MXNET_UNARY_MATH_OP(softrelu_grad, -math::expm1(-a));
 
@@ -153,6 +178,7 @@ struct log10_grad {
     return DType(0.4342944819f / static_cast<float>(a));
   }
 };
+MSHADOW_OP_DECLARE_TUNABLE_BACKWARD(log10_grad)
 
 template<>
 MSHADOW_XINLINE double log10_grad::Map<double>(double a) {
@@ -168,6 +194,7 @@ struct log2_grad {
     return DType(1.442695041f / static_cast<float>(a));
   }
 };
+MSHADOW_OP_DECLARE_TUNABLE_BACKWARD(log2_grad)
 
 template<>
 MSHADOW_XINLINE double log2_grad::Map<double>(double a) {
@@ -262,6 +289,7 @@ struct sign {
     return DType(0);
   }
 };
+MSHADOW_OP_DECLARE_TUNABLE_FWD_AND_BWD(sign)
 
 MXNET_UNARY_MATH_OP_NC(sign_grad, DType(0));
 
@@ -332,6 +360,7 @@ struct rint {
     return DType((af - floor) <= (ceil - af) ? floor : ceil);
   }
 };
+MSHADOW_OP_DECLARE_TUNABLE_FWD_AND_BWD(rint)
 
 /*! \brief used to round number to integer nearest to 0 */
 struct fix {
@@ -342,6 +371,7 @@ struct fix {
     return DType((floor > 0 ? floor : -floor) < (ceil > 0 ? ceil : -ceil) ? floor : ceil);
   }
 };
+MSHADOW_OP_DECLARE_TUNABLE_FWD_AND_BWD(fix)
 
 /*! \brief used for generate gradient of MAE loss*/
 MXNET_BINARY_MATH_OP_NC(minus_sign, a - b > DType(0) ? DType(1) : -DType(1));
@@ -404,6 +434,7 @@ struct mod {
     }
   }
 };
+MSHADOW_OP_DECLARE_TUNABLE_FWD_AND_BWD(mod)
 
 template<>
 MSHADOW_XINLINE mshadow::half::half2_t mod::Map<mshadow::half::half2_t>
@@ -418,6 +449,8 @@ struct mod_grad {
     return DType(0);
   }
 };
+MSHADOW_OP_DECLARE_TUNABLE_BACKWARD(mod_grad)
+
 template<>
 MSHADOW_XINLINE double mod_grad::Map<double>(double a, double b) {
   return 1.0;
@@ -453,6 +486,8 @@ struct mod_rgrad {
     return DType(0);
   }
 };
+MSHADOW_OP_DECLARE_TUNABLE_BACKWARD(mod_rgrad)
+
 template<>
 MSHADOW_XINLINE double mod_rgrad::Map<double>(double a, double b) {
   return -::floor(a/b);
@@ -516,6 +551,7 @@ struct rmod {
     }
   }
 };
+MSHADOW_OP_DECLARE_TUNABLE_FWD_AND_BWD(rmod)
 
 template<>
 MSHADOW_XINLINE mshadow::half::half2_t rmod::Map<mshadow::half::half2_t>
@@ -530,6 +566,8 @@ struct rmod_grad {
     return DType(0);
   }
 };
+MSHADOW_OP_DECLARE_TUNABLE_BACKWARD(rmod_grad)
+
 template<>
 MSHADOW_XINLINE double rmod_grad::Map<double>(double a, double b) {
   return -::floor(b/a);
@@ -571,6 +609,7 @@ struct clip {
     }
   }
 };
+MSHADOW_OP_DECLARE_TUNABLE_FWD_AND_BWD(clip)
 
 /***** gamma ******/
 
@@ -584,6 +623,7 @@ struct gamma_grad {
     return DType(math::tgamma(af) * special_functions::cephes::psi<float>(af));
   }
 };
+MSHADOW_OP_DECLARE_TUNABLE_BACKWARD(gamma_grad)
 
 template<>
 MSHADOW_XINLINE double gamma_grad::Map<double>(double a) {
@@ -601,6 +641,7 @@ struct gammaln_grad {
     return DType(special_functions::cephes::psi<float>(a));
   }
 };
+MSHADOW_OP_DECLARE_TUNABLE_BACKWARD(gammaln_grad)
 
 template<>
 MSHADOW_XINLINE double gammaln_grad::Map<double>(double a) {
@@ -632,6 +673,7 @@ struct smooth_l1_loss {
     }
   }
 };  // struct smooth_l1_loss
+MSHADOW_OP_DECLARE_TUNABLE_FWD_AND_BWD(smooth_l1_loss)
 
 /* The derivative of smooth l1 loss is
  * f'(x) = sigma^2 * x, |x| < 1 / sigma^2
@@ -653,6 +695,7 @@ struct smooth_l1_gradient {
     }
   }
 };  // struct smooth_l1_derivative
+MSHADOW_OP_DECLARE_TUNABLE_BACKWARD(smooth_l1_gradient)
 
 /*! \brief product reducer */
 struct product {
@@ -754,6 +797,7 @@ struct nansum_grad {
     return isnan_typed::IsNan(a) ? DType(0) : DType(1);
   }
 };
+MSHADOW_OP_DECLARE_TUNABLE_BACKWARD(nansum_grad)
 
 /*! \brief product reducer that ignores NaN values in the input */
 struct nanprod {
@@ -790,7 +834,7 @@ struct nanprod_grad {
     return isnan_typed::IsNan(a) ? DType(0) : b / a;
   }
 };
-
+MSHADOW_OP_DECLARE_TUNABLE_BACKWARD(nanprod_grad)
 }  // namespace mshadow_op
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/mxnet_op.h b/src/operator/mxnet_op.h
index e5c3b51410..aaf33f07ae 100644
--- a/src/operator/mxnet_op.h
+++ b/src/operator/mxnet_op.h
@@ -30,6 +30,7 @@
 #include <mxnet/engine.h>
 #include <mxnet/op_attr_types.h>
 #include <algorithm>
+#include "./operator_tune.h"
 #include "../engine/openmp.h"
 
 #ifdef __CUDACC__
@@ -189,8 +190,9 @@ template<int ndim>
 MSHADOW_XINLINE int dot(const Shape<ndim>& coord, const Shape<ndim>& stride) {
   int ret = 0;
   #pragma unroll
-  for (int i = 0; i < ndim; ++i)
+  for (int i = 0; i < ndim; ++i) {
     ret += coord[i] * stride[i];
+  }
   return ret;
 }
 
@@ -239,7 +241,7 @@ MSHADOW_XINLINE void inc(Shape<ndim>* coord, const Shape<ndim>& shape,
 
 /* Increment coordinates and modify index */
 template<int ndim>
-MSHADOW_XINLINE void inc(Shape<ndim>* coord, const Shape<ndim>& shape,
+/*MSHADOW_XINLINE*/ void inc(Shape<ndim>* coord, const Shape<ndim>& shape,
                          index_t* idx1, const Shape<ndim>& stride1,
                          index_t* idx2, const Shape<ndim>& stride2) {
   ++(*coord)[ndim-1];
@@ -345,15 +347,26 @@ struct op_with_req {
 template<typename OP, typename xpu>
 struct Kernel;
 
+/*!
+ * \brief CPU Kernel launcher
+ * \tparam OP Operator to launch
+ */
 template<typename OP>
 struct Kernel<OP, cpu> {
-  /*! \brief Launch CPU kernel */
+  /*!
+   * \brief Launch a generic CPU kernel.
+   * When using this for a new kernel op, add declaration and tuning objects to
+   * operator_tune.cc
+   * \tparam Args Varargs type to eventually pass to the OP::Map() function
+   * \param N Number of iterations
+   * \param dest Destination pointer (used to infer DType)
+   * \param args Varargs to eventually pass to the OP::Map() function
+   */
   template<typename ...Args>
   inline static void Launch(mshadow::Stream<cpu> *, const int N, Args... args) {
 #ifdef _OPENMP
     const int omp_threads = engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
     if (omp_threads < 2) {
-      // Zero means not to use OMP, but don't interfere with external OMP behavior
       for (int i = 0; i < N; ++i) {
         OP::Map(i, args...);
       }
@@ -370,14 +383,54 @@ struct Kernel<OP, cpu> {
 #endif
   }
 
+  /*!
+   * \brief Launch CPU kernel which has OMP tuning data available.
+   * When using this for a new kernel op, add declaration and tuning objects to
+   * operator_tune.cc
+   * \tparam PRIMITIVE_OP The primitive operation to use for tuning
+   * \tparam DType Data type
+   * \tparam Args Varargs type to eventually pass to the OP::Map() function
+   * \param N Number of iterations
+   * \param dest Destination pointer (used to infer DType)
+   * \param args Varargs to eventually pass to the OP::Map() function
+   */
+  template<typename PRIMITIVE_OP, typename DType, typename ...Args>
+  static void LaunchTuned(mshadow::Stream<cpu> *, const int N, Args... args) {
+#ifdef _OPENMP
+    const int omp_threads = engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
+    if (omp_threads < 2 || !tuned_op<PRIMITIVE_OP, DType>::UseOMP(
+      static_cast<size_t>(N), static_cast<size_t>(omp_threads))) {
+      for (int i = 0; i < N; ++i) {
+        OP::Map(i, args...);
+      }
+    } else {
+      #pragma omp parallel for num_threads(omp_threads)
+      for (int i = 0; i < N; ++i) {
+        OP::Map(i, args...);
+      }
+    }
+#else
+    for (int i = 0; i < N; ++i) {
+      OP::Map(i, args...);
+    }
+#endif
+  }
+
+  /*!
+   * \brief Launch custom-tuned kernel where each thread is set to
+   *        operate on a contiguous partition
+   * \tparam Args Varargs type to eventually pass to the OP::Map() function
+   * \param N Number of iterations
+   * \param args Varargs to eventually pass to the UseOMP() and OP::Map() functions
+   */
   template<typename ...Args>
   inline static void LaunchEx(mshadow::Stream<cpu> *s, const int N, Args... args) {
 #ifdef _OPENMP
     const int omp_threads = engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
-    if (omp_threads <= 1) {
+    if (omp_threads < 2 || !tuned_op<OP, typename OP::DataType>::UseOMP(N, omp_threads, args...)) {
       OP::Map(0, N, args...);
     } else {
-      int length = (N + omp_threads - 1) / omp_threads;
+      const int length = (N + omp_threads - 1) / omp_threads;
       #pragma omp parallel for num_threads(omp_threads)
       for (int i = 0; i < N; i += length) {
         OP::Map(i, i + length > N ? N - i : length, args...);
@@ -417,7 +470,7 @@ struct Kernel<OP, gpu> {
   }
 
   template<typename ...Args>
-  inline static void LaunchEx(mshadow::Stream<gpu> *s, int N, Args... args) {
+  inline static void LaunchEx(mshadow::Stream<gpu> *s, const int N, Args... args) {
     using namespace mshadow::cuda;
     int ngrid = std::min(kMaxGridNum, (N + kBaseThreadNum - 1) / kBaseThreadNum);
     mxnet_generic_kernel_ex<OP, Args...>
@@ -428,6 +481,43 @@ struct Kernel<OP, gpu> {
 #endif  // __CUDACC__
 
 /*!
+ * \brief Wrap Kernel<OP, xpu>::Launch* with some special-case helpers
+ */
+template<typename OP, typename xpu>
+struct KernelWrapper {
+  /*!
+   * \brief Launch 'mshadow_op-type' op (i.e. DType (*)( ... ) { return <operation> })
+   * \tparam Args Varargs type to eventually pass to the OP::Map() function
+   * \param s Stream object pointer (unused)
+   * \param N Number of iterations
+   * \param args Varargs to eventually pass to the OP::Map() function
+   */
+  template<typename DType, typename ...Args>
+  MSHADOW_CINLINE static void LaunchMShadowOpEx(mshadow::Stream<xpu> *s,
+                                                const int N,
+                                                DType *dest,
+                                                Args... args) {
+    mxnet::op::mxnet_op::Kernel<OP, xpu>::template LaunchTuned<
+      typename OP::Operation, DType>(s, N, dest, args...);
+  }
+
+  /*!
+   * \brief Launch 'mxnet_op-type' op (i.e. void (*)(int N, DType *out, ... ))
+   * \tparam Args Varargs type to eventually pass to the OP::Map() function
+   * \param s Stream object pointer (unused)
+   * \param N Number of iterations
+   * \param args Varargs to eventually pass to the OP::Map() function
+   */
+  template<typename DType, typename ...Args>
+  MSHADOW_CINLINE static void LaunchMXNetOpEx(mshadow::Stream<xpu> *s,
+                                              const int N,
+                                              DType *dest,
+                                              Args... args) {
+    mxnet::op::mxnet_op::Kernel<OP, xpu>::template LaunchTuned<OP, DType>(s, N, dest, args...);
+  }
+};
+
+/*!
  * \brief Set to immediate scalar value kernel
  * \tparam val Scalar immediate
  */
@@ -449,7 +539,23 @@ struct set_to_int {
  */
 using set_zero = set_to_int<0>;
 using set_one  = set_to_int<1>;
+_MXNET_TUNABLE_MXNET_OP_FWD(set_zero);  // _ prefix denotes "already in mxnet_op namespace"
+_MXNET_TUNABLE_MXNET_OP_FWD(set_one);
 }  // namespace mxnet_op
+
+/*!
+ * \brief Tuning specializations for the simple ops in <mshadow/base.h>
+ *        Basically, this overrides mxnet::op::mxnet_op::Kernel<OP, cpu>::Launch() and
+ *        redirects to mxnet::op::mxnet_op::KernelWrapper<OP, cpu>::Launch????OpEx(),
+ *        which eventually leads back to mxnet::op::mxnet_op::Kernel<OP, cpu>::LaunchTuned()
+ */
+MXNET_TUNABLE_MSHADOW_OP_FWD_AND_BWD(mshadow::op::identity)
+MXNET_TUNABLE_MSHADOW_OP_FWD_AND_BWD(mshadow::op::plus)
+MXNET_TUNABLE_MSHADOW_OP_FWD_AND_BWD(mshadow::op::minus)
+MXNET_TUNABLE_MSHADOW_OP_FWD_AND_BWD(mshadow::op::mul)
+MXNET_TUNABLE_MSHADOW_OP_FWD_AND_BWD(mshadow::op::div)
+MXNET_TUNABLE_MSHADOW_OP_FWD_AND_BWD(mshadow::op::right)
+
 }  // namespace op
 }  // namespace mxnet
 
diff --git a/src/operator/operator_tune-inl.h b/src/operator/operator_tune-inl.h
new file mode 100644
index 0000000000..d94fa08c8e
--- /dev/null
+++ b/src/operator/operator_tune-inl.h
@@ -0,0 +1,745 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef MXNET_OPERATOR_OPERATOR_TUNE_INL_H_
+#define MXNET_OPERATOR_OPERATOR_TUNE_INL_H_
+
+#include <dmlc/base.h>
+#include <dmlc/logging.h>
+#include <mshadow/base.h>
+#include <cxxabi.h>
+#include <atomic>
+#include <cstdint>
+#include <chrono>
+#include <thread>
+#include <string>
+#include <vector>
+#include <algorithm>
+#include <list>
+#include <random>
+#include <unordered_set>
+#include "./mxnet_op.h"
+#include "./operator_tune.h"
+
+namespace mxnet {
+namespace op {
+
+#ifndef MXNET_NO_INLINE
+#ifdef _MSC_VER
+#define MXNET_NO_INLINE __declspec(noinline)
+#else
+#define MXNET_NO_INLINE __attribute__((noinline))
+#endif
+#endif  // MXNET_NO_INLINE
+
+#define OUTSIDE_COUNT_SHIFT    9
+
+namespace tune {
+
+/*!
+ * \brief Convert TuningMode value to a string representation
+ * \param tm  Scalar TuningMode value
+ * \return Character pointer to a string representing the TuningMode value
+ */
+inline const char *TuningModeToString(const TuningMode tm) {
+  switch (tm) {
+    case Auto:
+      return "Auto";
+    case NeverOMP:
+      return "NeverOMP";
+    case AlwaysOMP:
+      return "AlwaysOMP";
+    default:
+      CHECK(false) << "Unknown TuningMode type: " << static_cast<int>(tm);
+      return "<unknown>";
+  }
+}
+}  // namespace tune
+
+/*!
+ * \brief Engine to tune kernel operations
+ * \tparam DType Data type to be used when tuning the kernel operations
+ * \remarks The basic concept here is that we time how long a trivial loop takes with and without
+ * OMP, subtracting the non-OMP run from the OMP run, which gives us the time
+ * that the OMP overhead takes.  Times were found to be relatively invariant with
+ * regard to the number of threads/cores on a given machine.
+ * Secondly, supplied operators are run and timed (for each data type) in order to determine
+ * their individual time cost.
+ *
+ * Knowing the following items, we can determine how long the OMP and non-OMP run
+ * is expected to take:
+ *  1) OMP overhead time
+ *  2) Number of iterations required
+ *  3) Number of threads to be used if we choose the OMP method
+ *  4) The data type
+ *
+ * Therefore, at Kernel::Launch() time, we can estimate whether it is faster to use OMP or not
+ * for the given kernel operator.
+ *
+ * Results and efficiency of the tuning is tested in the gtest OMP_TUNING test suite
+ */
+template<typename DType>
+class OperatorTune : public OperatorTuneByType<DType> {
+ public:
+  using Tick = OperatorTuneBase::Tick;
+  using duration_t = OperatorTuneBase::duration_t;
+  using OperatorTuneByType<DType>::tuning_mode_;
+
+  /*!
+   * \brief Constructor
+   */
+  OperatorTune() {
+    TuneAll();
+  }
+
+  /*!
+   * \brief Initialize the OperatorTune object
+   * \return Whether the OperatorTune object was successfully initialized
+   */
+  static bool Initialize() {
+    if (!initialized_) {
+      initialized_ = true;
+      // Generate some random data for calling the operator kernels
+      data_set_.reserve(0x100);
+      std::random_device rd;
+      std::mt19937 gen(rd());
+      if (!std::is_integral<DType>::value) {
+        std::uniform_real_distribution<> dis(-1, 1);
+        for (int n = 0; n < 0x100; ++n) {
+          const auto val = static_cast<DType>(dis(gen));
+          // If too close to zero, try again
+          if (fabs(val) < 1e-5) {
+            --n;
+            continue;
+          }
+          data_set_.emplace_back(val);
+        }
+      } else {
+        std::uniform_int_distribution<> dis(-128, 127);
+        for (int n = 0; n < 0x100; ++n) {
+          const auto val = static_cast<DType>(dis(gen));
+          // If zero, try again
+          if (!val) {
+            --n;
+            continue;
+          }
+          data_set_.emplace_back(val);
+        }
+      }
+      // Use this environment variable to generate new tuning statistics
+      // In order to avoid printing too many copies, only the float32 object prints
+      output_tuning_data_ = mshadow::DataType<DType>::kFlag == mshadow::kFloat32
+                            && dmlc::GetEnv("MXNET_OUTPUT_TUNING_DATA", false);
+      // Verbose tuning log output is enabled separately via MXNET_VERBOSE_TUNING_INFO
+      OperatorTuneBase::verbose_tuning_info_ = dmlc::GetEnv("MXNET_VERBOSE_TUNING_INFO", false);
+
+      OperatorTuneBase::tuning_weight_scale_ = dmlc::GetEnv("MXNET_TUNING_WEIGHT_SCALE", 0.0);
+
+      // This isn't actually supposed to be multithreaded init, but just to be sure the change is
+      // seen everywhere, using atomic bool.
+      if (!OperatorTuneBase::calculated_.load()) {
+        // Not especially concerned with a race condition, since this should
+        // run when only one thread is active (static init), just don't cache this variable
+        OperatorTuneBase::calculated_.store(true);
+        OperatorTuneBase::omp_overhead_ns_ = GetOMPLoopOverhead();
+        std::string config = dmlc::GetEnv("MXNET_USE_OPERATOR_TUNING", std::string());
+        ParseEnablerConfig(config);
+      }
+
+      if (OperatorTuneBase::verbose_tuning_info_) {
+        LOG(INFO) << "OMP overhead: " << OperatorTuneBase::omp_overhead_ns_ << " nanoseconds";
+      }
+    }
+    return true;
+  }
+
+  /*!
+   * \brief Schedule a tuning run
+   * \tparam OP Operator to tune
+   * \param tune_func Function to call which tunes the operator
+   * \return true if the tune operation was scheduled
+   */
+  template<typename OP>
+  static bool ScheduleTune(void (*tune_func)()) {
+#ifdef MXNET_USE_OPERATOR_TUNING
+    if (tune_func) {
+      GetTuningList()->push_back(tune_func);
+      operator_names_.insert(demangle(typeid(OP).name()));
+      return true;
+    }
+    return false;
+#else
+    return true;
+#endif
+  }
+
+  /*!
+   * \brief Is the template parameter type a tuned kernel?
+   * \tparam OP kernel operator type
+   * \return true if the operator/kernel is tuned
+   */
+  template<typename OP>
+  static bool IsTuned() {
+    return operator_names_.find(demangle(typeid(OP).name())) != operator_names_.end();
+  }
+
+  /*!
+   * \brief Tune all registered kernel operators that haven't already been tuned
+   */
+  static bool TuneAll() {
+    Initialize();
+    std::list<void (*)()> *tl = GetTuningList();
+    const size_t size_save = tl->size();  // For checking if anything asynchronous is
+    // adding or removing items, which is forbidden
+    if (output_tuning_data_ && !tl->empty()) {
+      // Only emit this once, use the most common case, 'float32'
+      if (mshadow::DataType<DType>::kFlag == mshadow::kFloat32) {
+        std::cout << "OperatorTuneBase::duration_t "
+                  << "OperatorTuneBase::omp_overhead_ns_ = " << OperatorTuneBase::omp_overhead_ns_
+                  << ";" << std::endl << std::flush;
+      }
+    }
+    const Tick start = std::chrono::high_resolution_clock::now();
+    for (auto i : *tl) {
+      (*i)();
+    }
+    if (OperatorTuneBase::verbose_tuning_info_) {
+      const duration_t duration = OperatorTune::GetDurationInNanoseconds(start);
+      LOG(INFO) << "Op Tuning  for " << type_name<DType>()
+                << " took " << (duration / 1000000) << " ms";
+    }
+    CHECK_EQ(size_save, tl->size()) << "Tuning list size should not have changed while tuning";
+    tl->clear();
+    return true;
+  }
+
+  /*!
+   * \brief Return set of operator names that were registered to be tuned. Does not imply
+   *        that the operator has been tuned.
+   * \return Set of operator/kernel names that were registered for tuning
+   */
+  static const std::unordered_set<std::string>& TunedOperatorNames() {
+    return operator_names_;
+  }
+
+ protected:
+  /*!
+   * \brief Get the list of tuning function calls for the operators
+   * \return Pointer to list of tuning function calls
+   */
+  static std::list<void (*)()> *GetTuningList();
+
+  /*!
+   * \brief Demangle typeid::name() in order to generate source macros
+   * \param name C++ Mangled name
+   * \return Demangled name as string
+   */
+  static inline std::string demangle(const char *name) {
+    int status = -4;  // some arbitrary value to eliminate the compiler warning
+    std::unique_ptr<char, void (*)(void *)> res{
+      abi::__cxa_demangle(name, nullptr, nullptr, &status),
+      &std::free
+    };
+    return status ? name : res.get();
+  }
+
+  /*!
+   * \brief Type name as string
+   * \tparam T Type
+   * \return std::string representing the human-readable demangled type name
+   */
+  template<typename T> static inline std::string type_name() {
+    return demangle(typeid(T).name());
+  }
+
+  /*! \brief Measure OMP overhead for a trivial OMP loop using a given number of threads
+   * \param omp_thread_count - Number of OMP threads to use in the timing test
+   * \returns Duration in nanoseconds for the OMP overhead (time to initiate and close the
+   *          OMP session)
+   */
+  static duration_t GetOMPLoopOverhead(const size_t omp_thread_count) {
+    CHECK_GT(omp_thread_count, 1);  // Don't try to use OMP for one thread
+    int wl_count = OperatorTuneBase::WORKLOAD_COUNT;
+
+    Tick start = std::chrono::high_resolution_clock::now();
+    // Use two loops in order to simulate OMP outside timing
+    for (size_t i = 0; i < OUTSIDE_COUNT; ++i) {
+      for (int x = 0; x < wl_count; ++x) {
+        // trivial operation
+        volatile_int_ += x;
+      }
+    }
+    const OperatorTuneBase::duration_t no_omp_duration =
+      OperatorTuneBase::GetDurationInNanoseconds(start);
+
+    // Scale OMP iterations by type calculation complexity
+    double factor;
+
+    // if tuning_weight_scale_ is a number that looks valid, use it as the factor
+    if (OperatorTuneBase::tuning_weight_scale_ > 0.01) {
+      factor = OperatorTuneBase::tuning_weight_scale_;
+    } else {
+      // These are empirically-determined constants found by balancing between
+      // a desktop (8 & 12 cpu's) and large cloud instances (32 & 64 cpu's)
+      switch (mshadow::DataType<DType>::kFlag) {
+        case mshadow::kUint8:
+        case mshadow::kInt8:
+          factor = 8.5;
+          break;
+        case mshadow::kInt32:
+          factor = 4.5;
+          break;
+        case mshadow::kInt64:
+          factor = 2;
+          break;
+        case mshadow::kFloat64:
+          factor = 1.25;
+          break;
+        case mshadow::kFloat32:
+        default:
+          factor = 1.0;
+          break;
+      }
+    }
+
+    wl_count = static_cast<int>(factor * OperatorTuneBase::WORKLOAD_COUNT * omp_thread_count);
+    start = std::chrono::high_resolution_clock::now();
+    for (size_t i = 0; i < OUTSIDE_COUNT; ++i) {
+      #pragma omp parallel for num_threads(omp_thread_count)
+      for (int x = 0; x < wl_count; ++x) {
+        // trivial operation
+        volatile_int_ += x;
+      }
+    }
+    const duration_t omp_duration = OperatorTuneBase::GetDurationInNanoseconds(start)
+                                    - no_omp_duration;
+    return omp_duration >> OUTSIDE_COUNT_SHIFT;
+  }
+
+  /*! \brief Measure OMP overhead for a trivial OMP loop using all cores
+   * \returns Time in nanoseconds to initialize/cleanup when executing an OMP block
+   */
+  static duration_t GetOMPLoopOverhead() {
+    // It was found empirically that OMP times were not heavily tied to number of cores,
+    // so take an average across all core counts
+    const auto max_cores = static_cast<size_t>(omp_get_num_procs()) >> 1;
+    if (max_cores >= 2) {
+      std::vector<duration_t> core_times;
+      // Take care of any OMP lazy-init with a throwaway call
+      for (size_t omp_threads = 2; omp_threads <= max_cores; ++omp_threads) {
+        GetOMPLoopOverhead(omp_threads);
+      }
+      std::vector<duration_t> durations;
+      durations.reserve(max_cores - 1);
+      for (size_t omp_threads = 2; omp_threads <= max_cores; ++omp_threads) {
+        const duration_t duration = GetOMPLoopOverhead(omp_threads);
+        if (OperatorTuneBase::verbose_tuning_info_) {
+          LOG(INFO) << "OMP Thread Count: " << omp_threads << ", overhead: " << duration << " ns";
+        }
+        durations.emplace_back(duration);
+      }
+      // return median
+      std::sort(durations.begin(), durations.end());
+      return durations[durations.size() >> 1];
+    }
+    return INT_MAX;  // If only one core, then never use OMP (say the overhead is huge)
+  }
+
+  /*!
+   * \brief Some string utility functions that aren't specific to tuning
+   */
+  struct StringUtil {
+    /*!
+     * \brief Trim whitespace from beginning and end of string
+     * \param s String to trim
+     * \return reference to the modified string. This is the same std::string object as what was
+     *         supplied in the parameters
+     */
+    static std::string &trim(std::string *s) {
+      s->erase(s->begin(), std::find_if(s->begin(), s->end(), [](int ch) {
+        return !std::isspace(ch);
+      }));
+      s->erase(std::find_if(s->rbegin(), s->rend(), [](int ch) {
+        return !std::isspace(ch);
+      }).base(), s->end());
+      return *s;
+    }
+
+    /*!
+     * \brief Tokenize a string into a list of tokens
+     * \param s String to tokenize
+     * \return std::list of tokens
+     */
+    static std::list<std::string> string2list(const std::string &s) {
+      std::list<std::string> res;
+      std::istringstream iss(s);
+      std::string token;
+      while (std::getline(iss, token, ',')) {
+        trim(&token);
+        if (!token.empty()) {
+          res.push_back(token);
+        }
+      }
+      return std::move(res);
+    }
+  };
+
+  /*!
+   * \brief Map a type name to its mshadow type flag (-1 if the name is not recognized)
+   * \warning Do not call from a performance-sensitive area
+   */
+  static int type_from_string(const std::string& type_string) {
+    const struct { const char *name; int flag; } known_types[] = {
+      {"float32", mshadow::kFloat32},
+      {"float64", mshadow::kFloat64},
+      {"float16", mshadow::kFloat16},
+      {"int8", mshadow::kInt8},
+      {"uint8", mshadow::kUint8},
+      {"int32", mshadow::kInt32},
+      {"int64", mshadow::kInt64},
+    };
+    for (const auto& entry : known_types) {
+      if (type_string == entry.name) {
+        return entry.flag;
+      }
+    }
+    return -1;  // invalid
+  }
+
+  /*!
+   * \brief Parse MXNET_ENABLE_OPERATOR_TUNING environment variable
+   * \param config String representation of MXNET_ENABLE_OPERATOR_TUNING environment variable
+   *        Values:
+   *            0=disable all
+   *            1=enable all
+   *            float32, float16, float64=list of types to enable, and disable those not listed
+   */
+  static void ParseEnablerConfig(std::string config) {
+    StringUtil::trim(&config);
+    if (!config.empty()) {
+      // First disable all (half_t included, so that "0" and unlisted types really are disabled)
+      OperatorTuneByType<float>::set_tuning_mode(tune::AlwaysOMP);
+      OperatorTuneByType<double>::set_tuning_mode(tune::AlwaysOMP);
+      OperatorTuneByType<mshadow::half::half_t>::set_tuning_mode(tune::AlwaysOMP);
+      OperatorTuneByType<int8_t>::set_tuning_mode(tune::AlwaysOMP);
+      OperatorTuneByType<uint8_t>::set_tuning_mode(tune::AlwaysOMP);
+      OperatorTuneByType<int32_t>::set_tuning_mode(tune::AlwaysOMP);
+      OperatorTuneByType<int64_t>::set_tuning_mode(tune::AlwaysOMP);
+      // See if it's a non-number (ie type or list of types)
+      if (!::isdigit(static_cast<unsigned char>(config[0]))) {
+        std::list<std::string> tokens = StringUtil::string2list(config);
+        for (const std::string& stype : tokens) {
+          // Re-enable tuning only for the types that were explicitly listed
+          const int typ = type_from_string(stype);
+          if (typ >= 0) {
+            switch (typ) {
+              case mshadow::kFloat32:
+                OperatorTuneByType<float>::set_tuning_mode(tune::Auto);
+                break;
+              case mshadow::kFloat64:
+                OperatorTuneByType<double>::set_tuning_mode(tune::Auto);
+                break;
+              case mshadow::kFloat16:
+                OperatorTuneByType<mshadow::half::half_t>::set_tuning_mode(tune::Auto);
+                break;
+              case mshadow::kInt8:
+                OperatorTuneByType<int8_t>::set_tuning_mode(tune::Auto);
+                break;
+              case mshadow::kUint8:
+                OperatorTuneByType<uint8_t>::set_tuning_mode(tune::Auto);
+                break;
+              case mshadow::kInt32:
+                OperatorTuneByType<int32_t>::set_tuning_mode(tune::Auto);
+                break;
+              case mshadow::kInt64:
+                OperatorTuneByType<int64_t>::set_tuning_mode(tune::Auto);
+                break;
+              default:
+                CHECK(false) << "Unsupported tuning data type: " << stype;
+                break;
+            }
+          } else {
+            // -1 is error
+            LOG(WARNING) << "Unknown data type to be tuned: " << stype;
+          }
+        }
+      } else {
+        if (std::atoi(config.c_str()) > 0) {
+          OperatorTuneByType<float>::set_tuning_mode(tune::Auto);
+          OperatorTuneByType<double>::set_tuning_mode(tune::Auto);
+          OperatorTuneByType<int8_t>::set_tuning_mode(tune::Auto);
+          OperatorTuneByType<uint8_t>::set_tuning_mode(tune::Auto);
+          OperatorTuneByType<int32_t>::set_tuning_mode(tune::Auto);
+          OperatorTuneByType<int64_t>::set_tuning_mode(tune::Auto);
+          OperatorTuneByType<mshadow::half::half_t>::set_tuning_mode(tune::Auto);
+        }
+      }
+    }
+  }
+
+  /*! \brief Whether this object has been initialized */
+  static bool initialized_;
+  /*! \brief Number of passes to obtain an average */
+  static constexpr duration_t OUTSIDE_COUNT = (1 << OUTSIDE_COUNT_SHIFT);
+  /*! \brief Random data for timing operator calls */
+  static std::vector<DType> data_set_;
+  /*! \brief Operators tuned */
+  static std::unordered_set<std::string> operator_names_;
+  /*! \brief Arbitrary object to modify in OMP loop */
+  static volatile int volatile_int_;
+  /*! \brief Output insertable (into code) instantiation+default-value macros */
+  static bool output_tuning_data_;
+};
+
+/*!
+ * \brief Class that tunes unary operators
+ * \tparam DType Data type to be used when tuning the kernel operations
+ */
+template<typename DType>
+class UnaryOpTune : public OperatorTune<DType> {
+ protected:
+  typedef OperatorTune<DType> Super;
+  using duration_t = typename Super::duration_t;
+  using Tick = typename Super::Tick;
+
+  /*!
+   * \brief Determine the time it takes a kernel operator to execute WORKLOAD_COUNT iterations
+   *        Used for kernels that take no arguments (ie set_zero)
+   * \tparam OP Kernel operator
+   * \return Duration in nanoseconds for the 'WORKLOAD_COUNT' operations (at least 1)
+   */
+  template<typename OP>
+  static duration_t GetBlankWorkload() {
+    DType tmp = DType(0);  // must be initialized: the loop below reads it through '+='
+    volatile DType *res = &tmp;
+    const Tick start = std::chrono::high_resolution_clock::now();
+    for (size_t i = 0; i < Super::WORKLOAD_COUNT; ++i) {
+      // Accumulate through the volatile pointer so that every OP::Map() result is stored
+      *res += OP::Map();
+    }
+    const duration_t duration = Super::GetDurationInNanoseconds(start);
+    return duration ? duration : 1;
+  }
+
+  /*!
+   * \brief Determine the time it takes a kernel operator to execute WORKLOAD_COUNT iterations
+   *        Used for kernels that take one argument (ie sqrt())
+   * \tparam OP Kernel operator
+   * \return Duration in nanoseconds for the 'WORKLOAD_COUNT' operations (at least 1)
+   */
+  template<typename OP>
+  static duration_t GetUnaryWorkload() {
+    DType tmp;
+    volatile DType *res = &tmp;
+    const Tick start = std::chrono::high_resolution_clock::now();
+    for (size_t i = 0; i < Super::WORKLOAD_COUNT; ++i) {
+      // Use a logical AND instead of mod to avoid affecting the timing result with a slow divide
+      *res = OP::Map(Super::data_set_[i & 0xFF]);
+    }
+    const duration_t duration = Super::GetDurationInNanoseconds(start);
+    return duration ? duration : 1;
+  }
+
+  /*!
+   * \brief Determine the time it takes a kernel operator to execute WORKLOAD_COUNT iterations
+   *        Used for kernels that take two arguments (ie elemwise_add())
+   * \tparam OP Kernel operator
+   * \return Duration in nanoseconds for the 'WORKLOAD_COUNT' operations (at least 1)
+   */
+  template<typename OP>
+  static inline duration_t GetBinaryWorkload() {
+    DType tmp;
+    volatile DType *res = &tmp;
+    const Tick start = std::chrono::high_resolution_clock::now();
+    for (size_t i = 0; i < Super::WORKLOAD_COUNT; ++i) {
+      // Use a logical AND instead of mod to avoid affecting the timing result with a slow divide
+      *res = OP::Map(Super::data_set_[i & 0xFF], Super::data_set_[(i + 1) & 0xFF]);
+    }
+    const duration_t duration = Super::GetDurationInNanoseconds(start);
+    return duration ? duration : 1;
+  }
+
+  /*!
+   * \brief Determine the time it takes a kernel operator to execute WORKLOAD_COUNT iterations
+   *        Used for kernels that take three arguments (ie backwards_grad<elemwise_add>())
+   * \tparam OP Kernel operator
+   * \return Duration in nanoseconds for the 'WORKLOAD_COUNT' operations (at least 1)
+   */
+  template<typename OP>
+  static duration_t GetTertiaryWorkload() {
+    DType tmp;
+    volatile DType *res = &tmp;
+    const Tick start = std::chrono::high_resolution_clock::now();
+    for (size_t i = 0; i < Super::WORKLOAD_COUNT; ++i) {
+      // Use a logical AND instead of mod to avoid affecting the timing result with a slow divide
+      *res = OP::Map(Super::data_set_[i & 0xFF],
+                     Super::data_set_[(i + 1) & 0xFF],
+                     Super::data_set_[i & 0xFF]);
+    }
+    const duration_t duration = Super::GetDurationInNanoseconds(start);
+    return duration ? duration : 1;
+  }
+
+  /*!
+   * \brief Determine the time it takes a kernel operator to execute WORKLOAD_COUNT iterations
+   *        Used for mxnet_op-style kernels whose Map() takes an index and an output pointer
+   * \tparam OP Kernel operator
+   * \return Duration in nanoseconds for the 'WORKLOAD_COUNT' operations (at least 1)
+   */
+  template<typename OP>
+  static duration_t GetBlankWorkloadEx() {
+    std::unique_ptr<DType[]> tmp(new DType[Super::WORKLOAD_COUNT]);  // DType[] => delete[]
+    DType *tmp_ptr = tmp.get();
+    const Tick start = std::chrono::high_resolution_clock::now();
+    for (size_t i = 0; i < Super::WORKLOAD_COUNT; ++i) {
+      OP::Map(i, tmp_ptr);
+    }
+    const duration_t duration = Super::GetDurationInNanoseconds(start);
+    return duration ? duration : 1;
+  }
+
+ public:
+  /*!
+   * \brief Tune the specified kernel operator.  Optionally print out C++ macro that defines the
+   *        tuning data variable and the default tuned value
+   *        This function tunes an operator which takes no arguments
+   * \tparam OP The kernel operator to be tuned
+   */
+  template<typename OP>
+  static void TuneBlankOperator() {
+    mxnet::op::mxnet_op::tuned_op<OP, DType>::workload_ = GetBlankWorkload<OP>();
+    if (Super::output_tuning_data_) {
+      std::cout << "IMPLEMENT_UNARY_WORKLOAD_FWD("
+                << Super::template type_name<OP>()
+                << ");  // NOLINT()" << std::endl << std::flush;  // For long lines
+    }
+  }
+
+  /*!
+   * \brief Tune the specified kernel operator.  Optionally print out C++ macro that defines the
+   *        tuning data variable and the default tuned value
+   *        This function tunes an operator which takes one argument
+   * \tparam OP The kernel operator to be tuned
+   */
+  template<typename OP>
+  static void TuneUnaryOperator() {
+    mxnet::op::mxnet_op::tuned_op<OP, DType>::workload_ = GetUnaryWorkload<OP>();
+    if (Super::output_tuning_data_) {
+      std::cout << "IMPLEMENT_UNARY_WORKLOAD_FWD("
+                << Super::template type_name<OP>()
+                << ");  // NOLINT()" << std::endl << std::flush;  // For long lines
+    }
+  }
+
+  /*!
+   * \brief Tune the specified kernel operator.  Optionally print out C++ macro that defines the
+   *        tuning data variable and the default tuned value
+   *        This function tunes a backward operator which takes one argument
+   * \tparam OP The kernel operator to be tuned
+   */
+  template<typename OP>
+  static void TuneUnaryBackwardOperator() {
+    mxnet::op::mxnet_op::tuned_op<mxnet_op::backward_grad<OP>, DType>::workload_ =
+      GetBinaryWorkload<mxnet::op::mxnet_op::backward_grad<OP>>();
+    if (Super::output_tuning_data_) {
+      std::cout << "IMPLEMENT_UNARY_WORKLOAD_BWD("
+                << Super::template type_name<OP>()
+                << ");  // NOLINT()" << std::endl << std::flush;  // For long lines
+    }
+  }
+
+  /*!
+   * \brief Tune the specified "mxnet_op-type" kernel operator.
+   *        Optionally print out C++ macro that defines the
+   *        tuning data variable and the default tuned value
+   *        This function tunes an operator which takes no arguments
+   * \tparam OP The kernel operator to be tuned
+   */
+  template<typename OP>
+  static void TuneBlankOperatorEx() {
+    mxnet::op::mxnet_op::tuned_op<OP, DType>::workload_ = GetBlankWorkloadEx<OP>();
+    if (Super::output_tuning_data_) {
+      std::cout << "IMPLEMENT_BLANK_WORKLOAD_FWD("
+                << Super::template type_name<OP>()
+                << ");  // NOLINT()" << std::endl << std::flush;  // For long lines
+    }
+  }
+
+  /*!
+   * \brief Determine whether to use OMP based upon both timing and configuration using the
+   *        given (templated) operator's workload
+   * \tparam OP Operator whose workload to use (tuned_op::workload_)
+   * \param N Number of iterations desired
+   * \param thread_count Number of OMP threads available to perform the iterations
+   * \returns Whether it's faster to use OMP for these iterations
+   */
+  template<typename OP>
+  inline static bool UseOMP(size_t N, size_t thread_count) {
+    return OperatorTune<DType>::UseOMP(N,
+                                       thread_count,
+                                       static_cast<uint64_t>(N) * OP::workload_);
+  }
+};
+
+/*!
+ * \brief Tuner for binary kernel operators; unary tuning is inherited from UnaryOpTune
+ * \tparam DType Data type to be used when tuning the kernel operations
+ */
+template<typename DType>
+class BinaryOpTune : public UnaryOpTune<DType> {
+ protected:
+  using Super = UnaryOpTune<DType>;
+
+ public:
+  /*!
+   * \brief Time a forward binary operator and store the result as its tuned workload
+   * \tparam OP Kernel operator type
+   */
+  template<typename OP>
+  static void TuneBinaryOperator() {
+    mxnet_op::tuned_op<OP, DType>::workload_ = Super::template GetBinaryWorkload<OP>();
+    if (!Super::Super::output_tuning_data_) {
+      return;
+    }
+    std::cout << "IMPLEMENT_BINARY_WORKLOAD_FWD(" << Super::template type_name<OP>()
+              << ");  // NOLINT()" << std::endl << std::flush;  // For long lines
+  }
+
+  /*!
+   * \brief Time backward_grad<OP> with three inputs and store it as that operator's workload
+   * \tparam OP Kernel operator type (wrapped in backward_grad<> before timing)
+   */
+  template<typename OP>
+  static void TuneBinaryBackwardOperator() {
+    using GradOp = mxnet::op::mxnet_op::backward_grad<OP>;
+    mxnet_op::tuned_op<GradOp, DType>::workload_ = Super::template GetTertiaryWorkload<GradOp>();
+    if (!Super::Super::output_tuning_data_) {
+      return;
+    }
+    std::cout << "IMPLEMENT_BINARY_WORKLOAD_BWD(" << Super::template type_name<OP>()
+              << ");  // NOLINT()" << std::endl << std::flush;  // For long lines
+  }
+};
+
+#undef OUTSIDE_COUNT_SHIFT
+#undef WORKLOAD_COUNT_SHIFT
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_OPERATOR_OPERATOR_TUNE_INL_H_
diff --git a/src/operator/operator_tune.cc b/src/operator/operator_tune.cc
new file mode 100644
index 0000000000..a1e177b7a8
--- /dev/null
+++ b/src/operator/operator_tune.cc
@@ -0,0 +1,349 @@
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <atomic>
+#include "./mxnet_op.h"
+#include "./mshadow_op.h"
+#include "./tensor/init_op.h"
+#include "./operator_tune-inl.h"
+#include "./tensor/elemwise_binary_broadcast_op.h"
+
+namespace mxnet {
+namespace op {
+
+/*!
+ * \brief Definitions of the OperatorTuneBase static members (shared by all data types)
+ */
+std::atomic<bool> OperatorTuneBase::calculated_(false);
+bool OperatorTuneBase::verbose_tuning_info_ = false;
+double OperatorTuneBase::tuning_weight_scale_ = 0.0;
+
+/*!
+ * \brief Define the OperatorTune<TYPE_> / OperatorTuneByType<TYPE_> static members for one type
+ */
+#define IMPLEMENT_OPERATOR_TUNE_STATICS_FOR_TYPE(TYPE_) \
+  template<> bool OperatorTune<TYPE_>::initialized_ = false; \
+  template<> std::vector<TYPE_> OperatorTune<TYPE_>::data_set_ = {}; \
+  template<> volatile tune::TuningMode OperatorTuneByType<TYPE_>::tuning_mode_ = tune::Auto; \
+  template<> volatile int OperatorTune<TYPE_>::volatile_int_ = 9;  /* arbitrary number */ \
+  template<> std::unordered_set<std::string> OperatorTune<TYPE_>::operator_names_ = {}; \
+  template<> bool OperatorTune<TYPE_>::output_tuning_data_ = false; \
+  template<> std::list<void (*)()> *OperatorTune<TYPE_>::GetTuningList() { \
+    static std::list<void (*)()> ll; \
+    return &ll; \
+  }
+
+/*!
+ * \brief Static variables for each tuned type (OperatorTune<float>, OperatorTune<double>, etc.)
+ */
+IMPLEMENT_OPERATOR_TUNE_STATICS_FOR_TYPE(float);
+IMPLEMENT_OPERATOR_TUNE_STATICS_FOR_TYPE(double);
+IMPLEMENT_OPERATOR_TUNE_STATICS_FOR_TYPE(mshadow::half::half_t);
+IMPLEMENT_OPERATOR_TUNE_STATICS_FOR_TYPE(int8_t);
+IMPLEMENT_OPERATOR_TUNE_STATICS_FOR_TYPE(uint8_t);
+IMPLEMENT_OPERATOR_TUNE_STATICS_FOR_TYPE(int32_t);
+IMPLEMENT_OPERATOR_TUNE_STATICS_FOR_TYPE(int64_t);
+
+/*!
+ * \brief Holder whose init_ specializations are initialized by a ScheduleTune() call, which
+ *        lets a tunable operator be registered during static initialization
+ * \tparam OP Operator type
+ * \tparam DType Data type
+ */
+template<typename OP, typename DType>
+struct static_init_var {
+  static bool init_;
+};
+
+/*!
+ * \brief Invoke MACRO_ once per tuned data type, passing the given arguments followed by
+ *        the data type as the last argument
+ */
+#define MSHADOW_MACRO_FOREACH_TYPE(MACRO_, ...) \
+  MACRO_(__VA_ARGS__, float); \
+  MACRO_(__VA_ARGS__, double); \
+  MACRO_(__VA_ARGS__, mshadow::half::half_t); \
+  MACRO_(__VA_ARGS__, uint8_t); \
+  MACRO_(__VA_ARGS__, int8_t); \
+  MACRO_(__VA_ARGS__, int32_t); \
+  MACRO_(__VA_ARGS__, int64_t);
+
+/*! \brief Define the workload_ and workload_ex_ statics of tuned_op<OP_, TYPE_> */
+#define IMPLEMENT_WORKLOAD_VALUE_FOR_TYPE(OP_, TYPE_) \
+  namespace mxnet_op { \
+  template<> size_t mxnet::op::mxnet_op::tuned_op<OP_, TYPE_>::workload_ = INT_MAX / 4; \
+  template<> std::vector<float> mxnet::op::mxnet_op::tuned_op<OP_, TYPE_>::workload_ex_; \
+  }  /* namespace mxnet_op */
+
+/*!
+ * \brief Define workload statics, UseOMP() and the scheduled tuner for a blank forward kernel op
+ */
+#define _IMPLEMENT_BLANK_WORKLOAD_FWD(OP_, TYPE_) \
+  IMPLEMENT_WORKLOAD_VALUE_FOR_TYPE(OP_, TYPE_); \
+  namespace mxnet_op { \
+  template<> bool mxnet::op::mxnet_op::tuned_op<OP_, TYPE_>::UseOMP( \
+    size_t N, size_t omp_threads) { \
+    return mxnet::op::UnaryOpTune<TYPE_>::UseOMP<mxnet_op::tuned_op<OP_, TYPE_>>( \
+      N, omp_threads); \
+  }}  /* namespace mxnet_op */ \
+  template<> bool static_init_var<OP_, TYPE_>::init_ = \
+    mxnet::op::OperatorTune<TYPE_>::ScheduleTune<OP_>( \
+      mxnet::op::UnaryOpTune<TYPE_>::TuneBlankOperatorEx<OP_>)
+
+/*!
+ * \brief Define workload statics, UseOMP() and the scheduled tuner for a unary forward kernel op
+ */
+#define _IMPLEMENT_UNARY_WORKLOAD_FWD(OP_, TYPE_) \
+  IMPLEMENT_WORKLOAD_VALUE_FOR_TYPE(OP_, TYPE_); \
+  namespace mxnet_op { \
+  template<> bool mxnet::op::mxnet_op::tuned_op<OP_, TYPE_>::UseOMP( \
+    size_t N, size_t omp_threads) { \
+    return mxnet::op::UnaryOpTune<TYPE_>::UseOMP<mxnet_op::tuned_op<OP_, TYPE_>>( \
+      N, omp_threads); \
+  }}  /* namespace mxnet_op */ \
+  template<> bool static_init_var<OP_, TYPE_>::init_ = \
+    mxnet::op::OperatorTune<TYPE_>::ScheduleTune<OP_>( \
+      mxnet::op::UnaryOpTune<TYPE_>::TuneUnaryOperator<OP_>)
+
+/*!
+ * \brief Same as the unary forward variant, but applied to backward_grad<OP_>
+ */
+#define _IMPLEMENT_UNARY_WORKLOAD_BWD(OP_, TYPE_) \
+  IMPLEMENT_WORKLOAD_VALUE_FOR_TYPE(mxnet::op::mxnet_op::backward_grad<OP_>, TYPE_); \
+  namespace mxnet_op { \
+  template<> \
+  bool mxnet::op::mxnet_op::tuned_op<mxnet::op::mxnet_op::backward_grad<OP_>, TYPE_>::UseOMP( \
+    size_t N, size_t omp_threads) { \
+    return mxnet::op::UnaryOpTune<TYPE_>::UseOMP<mxnet_op::tuned_op< \
+      mxnet::op::mxnet_op::backward_grad<OP_>, TYPE_>>(N, omp_threads); \
+  }}  /* namespace mxnet_op */ \
+  template<> bool static_init_var<mxnet::op::mxnet_op::backward_grad<OP_>, TYPE_>::init_ = \
+    mxnet::op::OperatorTune<TYPE_>::ScheduleTune<OP_>( \
+      mxnet::op::UnaryOpTune<TYPE_>::TuneUnaryBackwardOperator<OP_>)
+
+/*!
+ * \brief Define workload statics, UseOMP() and the scheduled tuner for a binary forward kernel op
+ */
+#define _IMPLEMENT_BINARY_WORKLOAD_FWD(OP_, TYPE_) \
+  IMPLEMENT_WORKLOAD_VALUE_FOR_TYPE(OP_, TYPE_); \
+  namespace mxnet_op { \
+  template<> bool mxnet::op::mxnet_op::tuned_op<OP_, TYPE_>::UseOMP( \
+    size_t N, size_t omp_threads) { \
+    return mxnet::op::BinaryOpTune<TYPE_>::UseOMP<mxnet_op::tuned_op<OP_, TYPE_>>( \
+      N, omp_threads); \
+  }}  /* namespace mxnet_op */ \
+  template<> bool static_init_var<OP_, TYPE_>::init_ = \
+    mxnet::op::OperatorTune<TYPE_>::ScheduleTune<OP_>( \
+      mxnet::op::BinaryOpTune<TYPE_>::TuneBinaryOperator<OP_>)
+
+/*!
+ * \brief Same as the binary forward variant, but applied to backward_grad<OP_>
+ */
+#define _IMPLEMENT_BINARY_WORKLOAD_BWD(OP_, TYPE_) \
+  IMPLEMENT_WORKLOAD_VALUE_FOR_TYPE(mxnet::op::mxnet_op::backward_grad<OP_>, TYPE_); \
+  namespace mxnet_op { \
+  template<> \
+  bool mxnet::op::mxnet_op::tuned_op<mxnet::op::mxnet_op::backward_grad<OP_>, TYPE_>::UseOMP( \
+    size_t N, size_t omp_threads) { \
+    return mxnet::op::BinaryOpTune<TYPE_>::UseOMP<mxnet_op::tuned_op< \
+      mxnet::op::mxnet_op::backward_grad<OP_>, TYPE_>>(N, omp_threads); \
+  }}  /* namespace mxnet_op */ \
+  template<> bool static_init_var<mxnet::op::mxnet_op::backward_grad<OP_>, TYPE_>::init_ = \
+    mxnet::op::OperatorTune<TYPE_>::ScheduleTune<OP_>( \
+      mxnet::op::BinaryOpTune<TYPE_>::TuneBinaryBackwardOperator<OP_>)
+
+/*!
+ * \brief Define workload statics for OP_<TYPE_> and schedule its own static Tune() function
+ */
+#define _IMPLEMENT_CUSTOM_WORKLOAD_FWD(OP_, TYPE_) \
+  IMPLEMENT_WORKLOAD_VALUE_FOR_TYPE(OP_<TYPE_>, TYPE_); \
+  template<> bool static_init_var<OP_<TYPE_>, TYPE_>::init_ = \
+    mxnet::op::OperatorTune<TYPE_>::ScheduleTune<OP_<TYPE_>>(\
+      OP_<TYPE_>::Tune)
+
+/*!
+ * \brief Public macros that add an operator to the tuning set for every tuned data type
+ */
+#define IMPLEMENT_UNARY_WORKLOAD_FWD(OP_) \
+  MSHADOW_MACRO_FOREACH_TYPE(_IMPLEMENT_UNARY_WORKLOAD_FWD, OP_)
+
+#define IMPLEMENT_BLANK_WORKLOAD_FWD(OP_) \
+  MSHADOW_MACRO_FOREACH_TYPE(_IMPLEMENT_BLANK_WORKLOAD_FWD, OP_)
+
+#define IMPLEMENT_UNARY_WORKLOAD_BWD(OP_) \
+  MSHADOW_MACRO_FOREACH_TYPE(_IMPLEMENT_UNARY_WORKLOAD_BWD, OP_)
+
+#define IMPLEMENT_BINARY_WORKLOAD_FWD(OP_) \
+  MSHADOW_MACRO_FOREACH_TYPE(_IMPLEMENT_BINARY_WORKLOAD_FWD, OP_)
+
+#define IMPLEMENT_BINARY_WORKLOAD_BWD(OP_) \
+  MSHADOW_MACRO_FOREACH_TYPE(_IMPLEMENT_BINARY_WORKLOAD_BWD, OP_)
+
+#define IMPLEMENT_CUSTOM_WORKLOAD_FWD(OP_) \
+  MSHADOW_MACRO_FOREACH_TYPE(_IMPLEMENT_CUSTOM_WORKLOAD_FWD, OP_)
+
+IMPLEMENT_CUSTOM_WORKLOAD_FWD(mxnet::op::mxnet_op::tunable_binary_broadcast_kernel);  // NOLINT()
+
+/*!
+ * \brief Tuning data and default weights in the case that MXNET_ENABLE_OPERATOR_TUNING is set
+ *        to zero (thus turning off auto-tuning)
+ * \note This code can be automatically generated
+ *       by setting the environment variable MXNET_OUTPUT_TUNING_DATA to a positive
+ *       integer value
+ */
+OperatorTuneBase::duration_t OperatorTuneBase::omp_overhead_ns_ = 5000;
+IMPLEMENT_UNARY_WORKLOAD_FWD(mshadow::op::identity);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::identity);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::identity_grad);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::negation);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::reciprocal);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::reciprocal_grad);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::sigmoid);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::sigmoid_grad);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::relu);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::relu_grad);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::tanh);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::tanh_grad);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::softrelu);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::softrelu_grad);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::exp);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::exp);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::expm1);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::log);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::log_grad);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::log1p);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::log1p_grad);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::log2);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::log2_grad);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::log10);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::log10_grad);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::sin);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::sin_grad);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::sinh);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::sinh_grad);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::arcsin);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::arcsin_grad);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::arcsinh);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::arcsinh_grad);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::cos);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::cos_grad);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::cosh);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::cosh_grad);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::arccos);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::arccos_grad);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::arccosh);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::arccosh_grad);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::tan);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::tan_grad);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::arctan);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::arctan_grad);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::arctanh);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::arctanh_grad);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::square);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::square_grad);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::square_root);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::square_root_grad);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::reciprocal_square_root);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::reciprocal_square_root_grad);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::cube_root);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::cube_root_grad);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::reciprocal_cube_root);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::reciprocal_cube_root_grad);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::abs);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::sign);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::sign);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::sign_grad);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::round);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::floor);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::trunc);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rint);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::fix);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::gamma);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::gamma_grad);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::gammaln);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::gammaln_grad);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::ceil);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::degrees);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::degrees_grad);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::radians);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::radians_grad);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD(mshadow::op::plus);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD(mshadow::op::minus);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD(mshadow::op::mul);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD(mshadow::op::div);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD(mshadow::op::right);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rminus);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rdiv);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::div_grad);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::div_grad);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::div_rgrad);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::div_rgrad);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::rdiv_grad);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::mod);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::mod_grad);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::mod_rgrad);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rmod);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::rmod_grad);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::left);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::left);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::right);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::right);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::power);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rpower);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::power_grad);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::rpower_grad);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::power_rgrad);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::maximum);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::minimum);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::hypot);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::hypot_grad_left);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::hypot_grad_left);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::hypot_grad_right);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::hypot_grad_right);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::lt);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::lt);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::le);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::le);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::gt);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::gt);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::ge);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::ge);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::ne);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::ne);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::eq);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::eq);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::smooth_l1_loss);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::smooth_l1_gradient);  // NOLINT()
+IMPLEMENT_BLANK_WORKLOAD_FWD(mxnet::op::mxnet_op::set_to_int<0>);  // NOLINT()
+IMPLEMENT_BLANK_WORKLOAD_FWD(mxnet::op::mxnet_op::set_to_int<1>);  // NOLINT()
+IMPLEMENT_BLANK_WORKLOAD_FWD(mxnet::op::PopulateFullIdxRspKernel);  // NOLINT()
+/*!
+ * \brief One static tuner object per tuned data type, *not* automatically generated
+ */
+#ifdef MXNET_USE_OPERATOR_TUNING
+static BinaryOpTune<float>                  binaryOpTuneFloat;
+static BinaryOpTune<double>                 binaryOpTuneDouble;
+static BinaryOpTune<mshadow::half::half_t>  binaryOpTuneHalf;
+static BinaryOpTune<int8_t>                 binaryOpTuneInt8;
+static BinaryOpTune<uint8_t>                binaryOpTuneUInt8;
+static BinaryOpTune<int32_t>                binaryOpTuneInt32;
+static BinaryOpTune<int64_t>                binaryOpTuneInt64;
+#endif  // MXNET_USE_OPERATOR_TUNING
+}  // namespace op
+}  // namespace mxnet
diff --git a/src/operator/operator_tune.h b/src/operator/operator_tune.h
new file mode 100644
index 0000000000..87d8435d8d
--- /dev/null
+++ b/src/operator/operator_tune.h
@@ -0,0 +1,309 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef MXNET_OPERATOR_OPERATOR_TUNE_H_
+#define MXNET_OPERATOR_OPERATOR_TUNE_H_
+
+#include <mshadow/base.h>
+#include <mshadow/tensor.h>
+
+namespace mxnet {
+namespace op {
+
+#define WORKLOAD_COUNT_SHIFT  11
+
+/*!
+ * \brief Shared data for all data types being tuned, acts as a base class for the higher-level
+ *        templated tuning classes
+ */
+class OperatorTuneBase {
+ public:
+  typedef int64_t duration_t;
+
+ protected:
+  /*! \brief Has omp_overhead_ns_ been calculated yet? */
+  static std::atomic<bool> calculated_;
+  /*! \brief Time in nanoseconds for OMP overhead */
+  static duration_t omp_overhead_ns_;
+  /*! \brief Print debug/trace output for tuning info */
+  static bool verbose_tuning_info_;
+  /*! \brief Tuning scale factor */
+  static double tuning_weight_scale_;
+
+ public:
+  typedef std::chrono::high_resolution_clock::time_point Tick;
+
+  /*!
+   * \brief Get timestamp for "now"
+   * \return Tick object representing the current timestamp
+   */
+  static MSHADOW_CINLINE Tick Now() {
+    return std::chrono::high_resolution_clock::now();  // prvalue: no std::move needed
+  }
+
+  /*!
+   * \brief Get duration in nanoseconds
+   * \param t1 Start time tick
+   * \param t2 End time tick
+   * \return duration in nanoseconds between t1 and t2
+   */
+  static MSHADOW_CINLINE duration_t GetDurationInNanoseconds(const Tick &t1, const Tick &t2) {
+    return static_cast<duration_t>(
+      std::chrono::duration_cast<std::chrono::nanoseconds>(t2 - t1).count());
+  }
+
+  /*!
+   * \brief Get duration in nanoseconds between the given 'since' value and now
+   * \param since Reference time which to calculate the duration
+   * \return Duration in nanoseconds between the given 'since' value and now
+   */
+  static MSHADOW_CINLINE duration_t GetDurationInNanoseconds(const Tick &since) {
+    return GetDurationInNanoseconds(since, Now());
+  }
+
+  /*! \brief Loop size to be timed (single op nanos may be too small to store accurately) */
+  static constexpr duration_t WORKLOAD_COUNT = (1 << WORKLOAD_COUNT_SHIFT);
+
+  /*!
+   * \brief Timer convenience class, sets start time as "now" in the constructor
+   */
+  struct Timer {
+    /*!
+     * \brief Constructor, sets start time
+     */
+    MSHADOW_CINLINE Timer()
+      : start_(OperatorTuneBase::Now()) {}
+    /*!
+     * \brief Get duration in nanoseconds since construction
+     * \return Duration in nanoseconds since construction
+     */
+    MSHADOW_CINLINE int64_t duration() const {
+      return OperatorTuneBase::GetDurationInNanoseconds(start_);
+    }
+
+    /*!
+     * \brief Reference start time, set in constructor
+     */
+    const OperatorTuneBase::Tick start_;
+  };
+
+ protected:
+  /*!
+   * \brief Estimate the time to compute with and without OMP, then return whether OMP is faster
+   * \param N - Number of iterations desired
+   * \param thread_count - Number of OMP threads available to perform the iterations
+   * \returns Whether it's faster to use OMP for these iterations
+   */
+  inline static bool IsOMPFaster(size_t N, size_t thread_count, const uint64_t serial_workload) {
+    if (thread_count >= 2) {
+      // Compute serial time required
+      const uint64_t total_serial_time_ns = serial_workload >> WORKLOAD_COUNT_SHIFT;
+
+      // Compute time required for OMP + # items per thread
+      const uint64_t omp_compute_time_ns = (serial_workload / thread_count) >> WORKLOAD_COUNT_SHIFT;
+      const uint64_t total_omp_time_ns = omp_overhead_ns_ + omp_compute_time_ns;
+
+      const bool rc = total_omp_time_ns < total_serial_time_ns;
+      return rc;
+    }
+    return false;
+  }
+};
+
+namespace tune {
+/*!
+ * \brief Tuning mode for registered kernel operators
+ */
+enum TuningMode {
+  Auto,         // Based upon tuning data, choose whether to use OMP for kernel CPU Launch() loops
+  NeverOMP,     // Don't use OMP for parallelism (legacy behavior for GPU builds)
+  AlwaysOMP     // Always use OMP for parallelism (legacy behavior for CPU builds)
+};
+}  // namespace tune
+
+template<typename DType>
+class OperatorTuneByType : public OperatorTuneBase {
+ public:
+  /*!
+   * \brief Set tuning mode
+   * \param tuning_mode The tune::TuningMode tuning mode value to set
+   */
+  static MSHADOW_CINLINE void set_tuning_mode(const tune::TuningMode tuning_mode) {
+    // Store via the volatile lvalue; writing it with volatile cast away is undefined behavior
+    tuning_mode_ = tuning_mode;
+  }
+
+  /*!
+   * \brief Get the current tuning mode
+   * \return tune::TuningMode value for the current tuning mode
+   */
+  static MSHADOW_CINLINE volatile tune::TuningMode tuning_mode() {
+    return tuning_mode_;
+  }
+
+  /*!
+   * \brief Determine whether to use OMP based upon both timing and configuration
+   * \param N - Number of iterations desired
+   * \param thread_count - Number of OMP threads available to perform the iterations
+   * \returns Whether it's faster to use OMP for these iterations
+   */
+  inline static bool UseOMP(size_t N, size_t thread_count, const uint64_t serial_workload) {
+#ifdef MXNET_USE_OPERATOR_TUNING
+    switch (tuning_mode()) {
+      case tune::Auto:
+        return OperatorTuneBase::IsOMPFaster(N, thread_count, serial_workload);
+      case tune::NeverOMP:
+        return false;
+      case tune::AlwaysOMP:
+      default:
+        return thread_count > 1;
+    }
+#else
+    return true;
+#endif
+  }
+
+ protected:
+  /*! \brief Tuning mode */
+  static volatile tune::TuningMode tuning_mode_;
+};
+
+namespace mxnet_op {
+/*!
+ * \brief Kernel operator wrapper used for tuning data
+ */
+template<typename Operation, typename DType>
+struct tuned_op : public Operation {
+  /*! \brief nanoseconds to perform WORKLOAD_COUNT operations
+   *  \note It is conceivable that a vector of values could be used for more complex tuning,
+   *        but the need hasn't yet arisen
+   *  \remarks This variable generally needs to be implemented somewhere.  Currently this is mostly
+   *           done via macros in operator_tune.cc.  If you get undefined reference errors when
+   *           linking, then try to use one of the macros in that file to instantiate the required
+   *           data/functions
+   */
+  static size_t workload_;
+
+  /*!
+   * \brief Extra workload-calculating information (ie times for sub-portions of the calculation)
+   */
+  static std::vector<float> workload_ex_;
+
+  /*!
+   * \brief Calls parent class (Operation)'s UseOMP
+   * \tparam Args Variable arguments passed
+   * \param N Number of iterations
+   * \param thread_count Number of threads available
+   * \param args Variable arguments passed
+   * \return true if OMP parallelism is recommended
+   */
+  template<typename ...Args>
+  static MSHADOW_CINLINE bool UseOMP(size_t N, size_t thread_count, Args... args) {
+    return Operation::UseOMP(N, thread_count, args...);
+  }
+
+  /*!
+   * \brief Call a standard UseOMP() implementation (if it exists). Currently, these
+   *        are implemented in operator_tune.cc for standard unary, binary,
+   *        and argumentless kernels (i.e. mshadow_op::sqrt)
+   * \note This overload is non-variadic: it takes only the iteration and thread counts,
+   *       and is only declared here (the definition must be supplied elsewhere)
+   * \param N Number of iterations
+   * \param thread_count Number of threads available
+   * \return true if OMP parallelism is recommended
+   */
+  static bool UseOMP(size_t N, size_t thread_count);
+};
+}  // namespace mxnet_op
+
+/*!
+ * \brief Declare a template specialization for the Kernel::Launch call for the given OP
+ *        wrapped with mxnet_op::op_with_req, using the given OpReqType as the 'req'
+ *        template parameter for 'op_with_req'.  This is useful for the standard mshadow_op
+ *        operators which need to be wrapped with op_with_req in order to be used with the
+ *        Kernel::Launch command.
+ *
+ * \note Expects to be used within the mxnet::op namespace
+ *
+ * For example:
+ *
+ * namespace mxnet_op {
+ * template <>
+ * template <typename... Args>
+ * inline void Kernel<typename mxnet_op::op_with_req<mshadow::op::identity, kNullOp>, cpu>
+ *   ::Launch(mshadow::Stream<cpu>* s, const int N, Args... args) {
+ *   ::mxnet::op::mxnet_op::Kernel<typename mxnet_op::op_with_req<mshadow::op::identity, kNullOp>,
+ *     cpu>::LaunchMShadowOpEx(s, N, args...);
+ *   }
+ * }
+ *
+ */
+#define MXNET_TUNABLE_MSHADOW_OP_WITH_REQ(__op$, __req$) \
+  namespace mxnet_op { \
+  template<> template<typename ...Args> \
+  inline void Kernel<typename mxnet_op::op_with_req<__op$, __req$>, ::mshadow::cpu>:: \
+    Launch(mshadow::Stream<::mshadow::cpu> *s, const int N, Args... args) { \
+      /* Launch via LaunchMShadowOpEx() */ \
+      KernelWrapper<typename mxnet_op::op_with_req<__op$, __req$>, ::mshadow::cpu>:: \
+        LaunchMShadowOpEx(s, N, args...); \
+  } \
+  }  /* namespace mxnet_op */
+
+/*!
+ * \brief Declare template specializations for the Kernel::Launch call for the given OP
+ *        wrapped with mxnet_op::op_with_req, using all supported OpReqType values as the 'req'
+ *        template parameter for 'op_with_req'.  This is useful for the standard mshadow_op
+ *        operators which need to be wrapped with op_with_req in order to be used with the
+ *        Kernel::Launch command.
+ * \note Expects to be used within the mxnet::op namespace
+ */
+#define MXNET_TUNABLE_MSHADOW_OP(__op$) \
+  MXNET_TUNABLE_MSHADOW_OP_WITH_REQ(__op$, kNullOp); \
+  MXNET_TUNABLE_MSHADOW_OP_WITH_REQ(__op$, kWriteTo); \
+  MXNET_TUNABLE_MSHADOW_OP_WITH_REQ(__op$, kWriteInplace); \
+  MXNET_TUNABLE_MSHADOW_OP_WITH_REQ(__op$, kAddTo);
+
+#define MXNET_TUNABLE_MSHADOW_OP_BACKWARD(__op$) \
+  MXNET_TUNABLE_MSHADOW_OP(mxnet::op::mxnet_op::backward_grad<__op$>)
+
+#define MXNET_TUNABLE_MSHADOW_OP_FWD_AND_BWD(__op$) \
+  MXNET_TUNABLE_MSHADOW_OP(__op$) \
+  MXNET_TUNABLE_MSHADOW_OP_BACKWARD(__op$)
+
+/*!
+ * \brief mxnet::op::mxnet_op format ops (work directly with Kernel<>::Launch())
+ *        Used from within mxnet::op::mxnet_op namespace
+ */
+#define _MXNET_TUNABLE_MXNET_OP_FWD(__op$) \
+  template<> template<typename ...Args> inline void Kernel<__op$, ::mshadow::cpu>::Launch( \
+    mshadow::Stream<::mshadow::cpu> *s, const int N, Args... args) { \
+      /* Launch via LaunchMXNetOpEx() */ \
+      KernelWrapper<__op$, ::mshadow::cpu>::LaunchMXNetOpEx(s, N, args...); \
+  }
+
+/*!
+ * \brief mxnet::op::mxnet_op format ops (work directly with Kernel<>::Launch())
+ *        Used from within mxnet::op
+ */
+#define MXNET_TUNABLE_MXNET_OP_FWD(__op$) \
+  namespace mxnet_op { _MXNET_TUNABLE_MXNET_OP_FWD(__op$) }  /* namespace mxnet_op */
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_OPERATOR_OPERATOR_TUNE_H_
diff --git a/src/operator/tensor/elemwise_binary_broadcast_op.h b/src/operator/tensor/elemwise_binary_broadcast_op.h
index 1aab714625..ed1fba1fc3 100644
--- a/src/operator/tensor/elemwise_binary_broadcast_op.h
+++ b/src/operator/tensor/elemwise_binary_broadcast_op.h
@@ -29,11 +29,13 @@
 #include <vector>
 #include <string>
 #include <utility>
+#include <set>
 #include "../mshadow_op.h"
 #include "../elemwise_op_common.h"
 #include "./elemwise_binary_op.h"
 #include "../operator_common.h"
 #include "broadcast_reduce-inl.h"
+#include "../mxnet_op.h"
 
 namespace mxnet {
 namespace op {
@@ -134,24 +136,138 @@ inline int BinaryBroadcastShapeCompact(const TShape& lshape, const TShape& rshap
 }
 
 namespace mxnet_op {
-template<int ndim, typename DType, typename OP>
-struct binary_broadcast_kernel {
-  MSHADOW_XINLINE static void Map(int base, int length, OpReqType req,
+
+/*!
+ * \brief Type-level specialization base class for binary_broadcast_kernel
+ * \tparam DType
+ */
+
+template <typename DType> struct tunable_binary_broadcast_kernel;
+
+template<int ndim, typename DType, typename OP, bool is_tuning = false>
+struct binary_broadcast_kernel /*: public tunable_binary_broadcast_kernel<DType>*/ {
+  /*! \brief Map function for binary_broadcast_kernel */
+  /*MSHADOW_XINLINE*/ static void Map(int base, int length, OpReqType req,
                                   const Shape<ndim>& lstride, const Shape<ndim>& rstride,
                                   const Shape<ndim>& oshape, DType* lhs, DType* rhs,
-                                  DType* out, int lsize, int rsize) {
-    Shape<ndim> coord = unravel(base, oshape);
-    index_t lidx = dot(coord, lstride);
-    index_t ridx = dot(coord, rstride);
-    KERNEL_ASSIGN(out[base], req, OP::Map(lhs[lidx], rhs[ridx]));
-    // starts from 1 to avoid extra inc at end of loop
-    for (int i = 1; i < length; ++i) {
-      inc(&coord, oshape, &lidx, lstride, &ridx, rstride);
-      KERNEL_ASSIGN(out[base+i], req, OP::Map(lhs[lidx], rhs[ridx]));
+                                  DType* out) {
+    if(req != kNullOp) {
+      Shape <ndim> coord = unravel(base, oshape);
+      index_t lidx = static_cast<index_t>(dot(coord, lstride));
+      index_t ridx = static_cast<index_t>(dot(coord, rstride));
+      if (!is_tuning) {
+        KERNEL_ASSIGN(out[base], req, OP::Map(lhs[lidx], rhs[ridx]));
+      }
+      // starts from 1 to avoid extra inc at end of loop
+      for (int i = 1; i < length; ++i) {
+        inc(&coord, oshape, &lidx, lstride, &ridx, rstride);
+        // When tuning, don't actually run the op, since it's not going
+        // to be the op we eventually are using
+        if (!is_tuning) {
+          KERNEL_ASSIGN(out[base + i], req, OP::Map(lhs[lidx], rhs[ridx]));
+        }
+      }
     }
   }
+
+  /*!
+   * \brief Decide whether to use OpenMP parallelization
+   * \tparam Args Variable number and types of arguments (passed same args as Map())
+   * \param N Number of iterations
+   * \param thread_count Number of OMP threads available
+   * \param req Req type (i.e. kNullOp, kWriteTo, kAddTo, etc)
+   * \param args remaining (unused) arguments
+   * \return true if OMP parallelization should be used for the N iterations
+   */
+  template<typename ...Args>
+  static bool UseOMP(const size_t N, const size_t thread_count, OpReqType req, Args... args) {
+    if (req != kNullOp) {
+      CHECK_GT(thread_count, 0) << "Invalid thread count: " << thread_count;
+      const uint64_t length = (N + thread_count - 1) / thread_count;
+      // Name the specialization explicitly: outside its own scope the bare template name
+      // is not a type, and this must be the same tuned_op that Tune() fills in
+      typedef tuned_op<tunable_binary_broadcast_kernel<DType>, DType> kernel_tune;
+      float wl = kernel_tune::workload_ex_[0] + kernel_tune::workload_ex_[1] * length;
+      // OP::Map() is called 'length' times for each map call
+      int64_t subop_actual_workload = static_cast<int64_t>(tuned_op<OP, DType>::workload_)
+        - static_cast<int64_t>(tuned_op<mxnet_op::set_zero, DType>::workload_);
+      if (subop_actual_workload < 0) {
+        subop_actual_workload = 1;
+      }
+      wl += subop_actual_workload * length;
+      return OperatorTuneByType<DType>::UseOMP(N, thread_count, static_cast<uint64_t>(wl));
+    }
+    return false;
+  }
 };
 
+/*!
+ * \brief Calculate workload for a given lambda function
+ * \tparam Function Lambda type to time for WORKLOAD_COUNT calls
+ * \param function Lambda to time for WORKLOAD_COUNT calls
+ * \return median workload for function call (nanoseconds for WORKLOAD_COUNT calls)
+ */
+template<typename Function>
+inline int64_t get_workload(Function function) {
+  std::multiset<int64_t> durations;
+  for (int pass = 0; pass < 3; ++pass) {
+    typename OperatorTuneBase::Timer timer;  // restarted so each pass is timed on its own
+    for (int i = 0; i < OperatorTuneBase::WORKLOAD_COUNT; ++i) {
+      function();
+    }
+    const OperatorTuneBase::duration_t dd = timer.duration();
+    durations.insert(dd);  // one sample per pass; the multiset keeps them ordered
+  }
+  return *++durations.begin();  // second of three ordered samples, i.e. the median
+}
+
+/*!
+ * \brief Type-specific tuning
+ * \tparam DType Element type for which the binary_broadcast_kernel overhead is measured
+ */
+template<typename DType>
+struct tunable_binary_broadcast_kernel {
+  /*! \brief Allows LaunchEx to know the data type for tuning selection */
+  typedef DType DataType;
+  /*!
+   * \brief Run-time tuning of sub-op-independent binary_broadcast_kernel
+   */
+  static void Tune() {
+    constexpr int dim = 2;  // Have to pick one to represent all
+    Shape<dim> oshape, lstride, rstride;
+    for (index_t i = 0; i < dim; ++i) {
+      oshape[i]  = 28U * (i + 1);  // (i + 1): with i alone the first extent/stride would be 0
+      lstride[i] = 2U * (i + 1);
+      rstride[i] = 3U * (i + 1);
+    }
+    const int base = 2;
+    // The tuning instantiation of Map() (is_tuning == true) never dereferences lhs/rhs/out,
+    // so only an input buffer is provided and nullptr is passed for the output
+    const size_t data_size = lstride.Size() * rstride.Size() * oshape.Size();
+    typedef tuned_op<tunable_binary_broadcast_kernel<DType>, DType> tune_data;
+    std::unique_ptr<DType[]> data(new DType[data_size]);  // array form: released with delete[]
+    // memset() takes a byte count, not an element count
+    memset(data.get(), 0, data_size * sizeof(DType));  // get into cache
+    // workload_ex_[0]: time for Map() calls of length 1 (fixed per-call cost)
+    tune_data::workload_ex_.push_back(
+      get_workload([&]() { binary_broadcast_kernel<2, DType, mshadow_op::left, true>::Map(
+        base, 1, kWriteTo, lstride, rstride, oshape, data.get(), data.get(), nullptr);
+      })
+    );
+    tune_data::workload_ex_.push_back(
+      get_workload([&]() { binary_broadcast_kernel<2, DType, mshadow_op::left, true>::Map(
+        base, 1000, kWriteTo, lstride, rstride, oshape, data.get(), data.get(), nullptr);
+      })
+    );
+    // Record base time for function
+    tune_data::workload_ex_[1] -= tune_data::workload_ex_[0];
+    // Record per-length item adder
+    tune_data::workload_ex_[1] /= 1000;
+    if (tune_data::workload_ex_[1] <= 0) {
+      tune_data::workload_ex_[1] = 1;
+    }
+  }
+};
 }  // namespace mxnet_op
 
 template<typename xpu, typename OP>
@@ -160,7 +276,6 @@ void BinaryBroadcastCompute(const nnvm::NodeAttrs& attrs,
                             const std::vector<TBlob>& inputs,
                             const std::vector<OpReqType>& req,
                             const std::vector<TBlob>& outputs) {
-  using namespace mxnet_op;
   TShape new_lshape, new_rshape, new_oshape;
   int ndim = BinaryBroadcastShapeCompact(inputs[0].shape_, inputs[1].shape_, outputs[0].shape_,
                                          &new_lshape, &new_rshape, &new_oshape);
@@ -170,13 +285,12 @@ void BinaryBroadcastCompute(const nnvm::NodeAttrs& attrs,
     mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
     MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
       BROADCAST_NDIM_SWITCH(ndim, NDim, {
-        Shape<NDim> oshape = new_oshape.get<NDim>();
-        Shape<NDim> lstride = calc_stride(new_lshape.get<NDim>());
-        Shape<NDim> rstride = calc_stride(new_rshape.get<NDim>());
-        Kernel<binary_broadcast_kernel<NDim, DType, OP>, xpu>::LaunchEx(
-            s, new_oshape.Size(), req[0], lstride, rstride, oshape,
-            inputs[0].dptr<DType>(), inputs[1].dptr<DType>(), outputs[0].dptr<DType>(),
-            inputs[0].Size(), inputs[1].Size());
+        mshadow::Shape<NDim> oshape = new_oshape.get<NDim>();
+        mshadow::Shape<NDim> lstride = mxnet_op::calc_stride(new_lshape.get<NDim>());
+        mshadow::Shape<NDim> rstride = mxnet_op::calc_stride(new_rshape.get<NDim>());
+        mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, DType, OP>, xpu>::
+        template LaunchEx(s, new_oshape.Size(), req[0], lstride, rstride, oshape,
+        inputs[0].dptr<DType>(), inputs[1].dptr<DType>(), outputs[0].dptr<DType>());
       });
     });
   }
@@ -249,8 +363,9 @@ void BinaryBroadcastBackwardUseIn(const nnvm::NodeAttrs& attrs,
                                   const std::vector<OpReqType>& req,
                                   const std::vector<TBlob>& outputs) {
   TShape new_lshape, new_rshape, new_oshape;
-  bool need_bc = BinaryBroadcastShapeCompact(outputs[0].shape_, outputs[1].shape_, inputs[0].shape_,
-                                             &new_lshape, &new_rshape, &new_oshape);
+  const bool need_bc = BinaryBroadcastShapeCompact(outputs[0].shape_,
+                                                   outputs[1].shape_, inputs[0].shape_,
+                                                   &new_lshape, &new_rshape, &new_oshape) != 0;
   if (!need_bc) {
     ElemwiseBinaryOp::BackwardUseIn<xpu, LOP, ROP>(attrs, ctx, inputs, req, outputs);
   } else {
diff --git a/src/operator/tensor/elemwise_binary_broadcast_op_basic.cc b/src/operator/tensor/elemwise_binary_broadcast_op_basic.cc
index 8c97849e20..991c10a425 100644
--- a/src/operator/tensor/elemwise_binary_broadcast_op_basic.cc
+++ b/src/operator/tensor/elemwise_binary_broadcast_op_basic.cc
@@ -27,6 +27,7 @@
 
 namespace mxnet {
 namespace op {
+
 MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(broadcast_add)
 .add_alias("broadcast_plus")
 .describe(R"code(Returns element-wise sum of the input arrays with broadcasting.
diff --git a/src/operator/tensor/init_op.h b/src/operator/tensor/init_op.h
index c621f6e9a6..be57d12a9c 100644
--- a/src/operator/tensor/init_op.h
+++ b/src/operator/tensor/init_op.h
@@ -271,6 +271,7 @@ struct PopulateFullIdxRspKernel {
     KERNEL_ASSIGN(out[i], kWriteTo, i);
   }
 };
+MXNET_TUNABLE_MXNET_OP_FWD(PopulateFullIdxRspKernel);
 
 // Fill in the indices and values of a RowSparse NDArray to represent a zeros NDArray,
 // instead of the usual compact representation.
diff --git a/tests/cpp/include/test_core_op.h b/tests/cpp/include/test_core_op.h
index c454c95847..1bcd0e2df4 100644
--- a/tests/cpp/include/test_core_op.h
+++ b/tests/cpp/include/test_core_op.h
@@ -611,6 +611,50 @@ class CoreOpProp {
 template<typename DType>
 using CoreOperatorRunner = test::OperatorRunner<CoreOpProp, CoreOpExecutor<DType>>;
 
+
+/*!
+ * \brief Run a core op forward and backward
+ * \tparam DType Data type
+ * \param isGPU true if operation is to be run on the GPU
+ * \param op_kwargs Operator parameters
+ * \param op_name Operator name as registered with nnvm
+ * \param backward_op_name Backwards operator name as registered with nnvm
+ *        If blank, the runner will attempt to determine the backwards operator. If it fails,
+ *        an exception will be thrown.
+ *        If the string is [none], then no backward operator will be created or executed
+ */
+template<typename DType = float>
+inline void BasicRunCoreOpBidirectional(const bool isGPU,
+                                        bool verbose,
+                                        const kwargs_t& op_kwargs,
+                                        const std::vector<TShape>& shapes,
+                                        const char *op_name,
+                                        const char *backward_op_name = "") {
+  test::op::CoreOpExecutor<DType> op(isGPU, shapes);
+  op.set_verbose(false);
+
+  op.Init(op.ArgsWithOpName(op_kwargs, op_name, backward_op_name));
+
+  if (verbose) {
+    PRINT_NDARRAYS(op.ctx().run_ctx, op.inputs());
+    PRINT_NDARRAYS(op.ctx().run_ctx, op.outputs());
+  }
+  op.Execute();
+  if (verbose) {
+    PRINT_NDARRAYS(op.ctx().run_ctx, op.outputs());
+  }
+  if (op.HasBackward()) {
+    if (verbose) {
+      PRINT_NDARRAYS(op.ctx().run_ctx, op.bwd_inputs());
+      PRINT_NDARRAYS(op.ctx().run_ctx, op.bwd_outputs());
+    }
+    op.ExecuteBackward();
+    if (verbose) {
+      PRINT_NDARRAYS(op.ctx().run_ctx, op.bwd_outputs());
+    }
+  }
+}
+
 }  // namespace op
 }  // namespace test
 }  // namespace mxnet
diff --git a/tests/cpp/include/test_op_runner.h b/tests/cpp/include/test_op_runner.h
index eb259997cd..25f86364db 100644
--- a/tests/cpp/include/test_op_runner.h
+++ b/tests/cpp/include/test_op_runner.h
@@ -145,10 +145,22 @@ class OperatorRunner {
     std::stringstream ss;
     ss << "Timing: " << COUNT << " iterations of " << count << " calls";
     if (timing_shapes[0].ndim()) {
-      // TODO(cjolivier01): Print all shapes (if they differ)
-      ss << ", shape = " << timing_shapes[0] << std::endl << std::flush;
+      size_t lhs_total = 0;
+      ss << ", shape = ";
+      for (size_t i = 0, n = timing_shapes.size(); i < n; ++i) {
+        if (i) {
+          ss << ", ";
+        }
+        ss << timing_shapes[i];
+        if (!i) {
+          lhs_total = timing_shapes[i].Size();
+        }
+      }
+      ss << " = " << test::pretty_num(lhs_total) << " items " << std::endl << std::flush;
+    }
+    if (!mxnet::test::csv) {
+      std::cout << ss.str();
     }
-    std::cout << ss.str();
 
     for (size_t i = 0; i < COUNT; ++i) {
       index_t batchSize = 1;
@@ -217,11 +229,10 @@ class OperatorRunner {
       }
     }
 
-    if (verbose_) {
+    if (verbose_ && !mxnet::test::csv) {
       timing.print(&std::cout, label);
       std::cout << std::endl << std::flush;
     }
-
     return timing.data();
   }
 
diff --git a/tests/cpp/include/test_tune.h b/tests/cpp/include/test_tune.h
new file mode 100644
index 0000000000..5f72ca1bd9
--- /dev/null
+++ b/tests/cpp/include/test_tune.h
@@ -0,0 +1,318 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file test_tune.h
+ * \brief operator tuning tester
+ * \author Chris Olivier
+*/
+
+#ifndef TEST_TUNE_H_
+#define TEST_TUNE_H_
+
+#include <sys/time.h>
+#include <dmlc/logging.h>
+#include <iomanip>
+#include <iostream>
+#include <atomic>
+#include <unordered_set>
+#include <unordered_map>
+#include <mutex>
+#include <vector>
+#include <utility>
+#include <algorithm>
+#include <string>
+#include <map>
+#include "../../src/operator/operator_tune-inl.h"
+#include "./test_util.h"
+#include "./test_op.h"
+#include "./test_core_op.h"
+
+namespace mxnet {
+namespace test {
+namespace tune {
+
+/*!
+ * \brief Tuning tests, which check whether the correct tuning mode is selected by Auto
+ * \note This class makes no attempt at being performant (i.e. it does all sorts of slow
+ *       deep copies and that sort of thing), so don't insert any of this code in the main
+ *       trunk unless you've verified the performance characteristics for that chunk of code
+ * \tparam DType Data type to test
+ */
+template<typename DType>
+class TuningTester {
+ public:
+  using kwargs_t = test::op::kwargs_t;
+
+  using bool_mode_pair = std::pair<bool, ::mxnet::op::tune::TuningMode>;
+
+  using shape_vect = std::vector<TShape>;
+  using shape_vec_to_bool_map = std::map<shape_vect, bool_mode_pair, test::less_shapevect>;
+
+ private:
+  using ShapesToPerfTimingMap =
+  std::map<shape_vect, test::perf::timing_map_t, test::less_shapevect>;
+
+  /*!
+   * \brief Run timing test on various data shapes and sizes
+   * \param isGPU true if the GPU should be used for the timing test
+   * \param op_kwargs operator parameters
+   * \param shapes List of input-shape sets; one timing run is performed for each set
+   * \param op_name The operator's registered name (with nnvm)
+   * \param backward_op_name The backward operator's registered name (with nnvm)
+   * \return ShapesToPerfTimingMap map holding timing data for each set of shapes
+   */
+  static ShapesToPerfTimingMap RunCoreOpTimingTest(const bool isGPU,
+                                                   const kwargs_t &op_kwargs,
+                                                   const std::vector<shape_vect>& shapes,
+                                                   const char *op_name,
+                                                   const char *backward_op_name = "") {
+    ShapesToPerfTimingMap res;
+    const kwargs_t kwargs = test::op::CoreOpExecutor<DType>::ArgsWithOpName(
+      op_kwargs, op_name, backward_op_name);
+
+    // prime code and cache before the performance runs
+    test::op::CoreOperatorRunner<DType> runner;
+    runner.set_verbose(false);
+    runner.RunBidirectional(false, {{10, 3, 18, 128}}, kwargs, 1);
+
+    // Do the performance runs (the label is the same for every run, so build it once)
+    const std::string label = std::string(op_name) + " Operator " + (isGPU ? "GPU" : "CPU");
+    for (const shape_vect &this_run_shapes : shapes) {
+      test::perf::timing_map_t tmap = runner.TimingTest(label, isGPU, false, kwargs, 2, 10,
+                                                        this_run_shapes);
+      // Each set of shapes may appear only once in the results
+      CHECK(res.find(this_run_shapes) == res.end());
+      res[this_run_shapes] = tmap;
+    }
+    // Return by name: wrapping a local in std::move() would inhibit copy elision
+    return res;
+  }
+
+  using tuned_timing_t = std::map<
+    shape_vect,
+    std::map<::mxnet::op::tune::TuningMode, test::perf::timing_map_t>, test::less_shapevect>;
+
+  using modesort_t = std::multimap<double, ::mxnet::op::tune::TuningMode>;
+
+  /*!
+   * \brief Check if the tuning succeeded
+   * \param mode_sort modesort_t structure (duration -> mode) built by 'CalculateModeSort';
+   *        it must hold exactly three entries
+   * \param closeness_factor fraction of the fastest standard time (omp, no omp) by which Auto
+   *        may be slower and still be counted as a correct decision
+   * \return a pair <bool, TuningMode> consisting of true or false signifying if the test appears to
+   *         have made the correct decision, and the TuningMode which was closest in timing to
+   *         the Auto mode.
+   */
+  static bool_mode_pair CheckCorrectTuning(const modesort_t &mode_sort,
+                                           const double closeness_factor = 0.25) {
+    CHECK_EQ(mode_sort.size(), 3U);
+
+    // Entries are ordered by duration, so the first non-Auto entry is the fastest normal mode
+    ::mxnet::op::tune::TuningMode fastest_standard_mode = ::mxnet::op::tune::Auto;
+    for (const auto &timed_mode : mode_sort) {
+      if (timed_mode.second != ::mxnet::op::tune::Auto) {
+        fastest_standard_mode = timed_mode.second;
+        break;
+      }
+    }
+    CHECK_NE(fastest_standard_mode, ::mxnet::op::tune::Auto);
+
+    // Flip the (duration -> mode) pairs around so that durations can be looked up by mode
+    std::map<::mxnet::op::tune::TuningMode, double> mode2time;
+    for (const auto &timed_mode : mode_sort) {
+      mode2time[timed_mode.second] = timed_mode.first;
+    }
+    const double time_auto = mode2time[::mxnet::op::tune::Auto];
+    const double time_no_omp = mode2time[::mxnet::op::tune::NeverOMP];
+    const double time_omp = mode2time[::mxnet::op::tune::AlwaysOMP];
+
+    // Auto is acceptable outright when it is no slower than the faster of NeverOMP and
+    // AlwaysOMP plus a tolerance of closeness_factor times that faster time
+    const double fastest_standard_time = std::min(time_no_omp, time_omp);
+    const double allowed_difference = closeness_factor * fastest_standard_time;
+    const double mustbe_asfast = fastest_standard_time + allowed_difference;
+
+    // Figure out which one we are closest to and return that to help in the analysis
+    const ::mxnet::op::tune::TuningMode closest_to =
+      (fabs(time_auto - time_no_omp) < fabs(time_auto - time_omp))
+      ? ::mxnet::op::tune::NeverOMP
+      : ::mxnet::op::tune::AlwaysOMP;
+
+    // Otherwise Auto must at least track the faster of the two standard modes
+    const bool correct = time_auto <= mustbe_asfast || closest_to == fastest_standard_mode;
+    return { correct, closest_to };
+  }
+
+ public:
+  /*!
+   * \brief Given timing statistics, determine if 'Auto' mode made the correct choice.
+   * \param direction Compute direction for which to check (Forward or Backward)
+   * \param verbose If true, print the statistical info. Ignored when test::csv is set, in which
+   *        case one csv line per set of shapes is printed instead.
+   * \return A map of shape vectors to a pair <bool, TuningMode> consisting of true or false
+   *         signifying if the test appears to have made the correct decision, and the TuningMode
+   *         which was closest in timing to the Auto mode. Sets of shapes which have no timing
+   *         data for 'direction' are left out of the map.
+   */
+  shape_vec_to_bool_map CalculateModeSort(const test::op::TimingDirection direction,
+                                          bool verbose = true) const {
+    // csv output replaces the human-readable report, it never accompanies it
+    const bool csv_mode = test::csv;
+    if (csv_mode) {
+      verbose = false;
+    }
+    shape_vec_to_bool_map results;
+    for (const auto &shapes_and_modes : timing_) {
+      const shape_vect &shapes = shapes_and_modes.first;
+      if (csv_mode) {
+        // First csv field: item count of the first (lhs) shape
+        std::cout << test::pretty_num(shapes[0].Size()) << ",";
+      } else if (verbose) {
+        // print shapes
+        for (size_t x = 0, n = shapes.size(); x < n; ++x) {
+          if (x) {
+            std::cout << ", ";
+          }
+          std::cout << shapes[x];
+        }
+        std::cout << " lhs=" << test::pretty_num(shapes[0].Size()) << " items";
+        std::cout << "\t(" << TimingDirectionAsString(direction) << ")" << std::endl;
+      }
+
+      // Gather (duration -> mode) for every mode which has timing data for this direction
+      modesort_t mode_sort;
+      for (const auto &mode_and_timing : shapes_and_modes.second) {
+        const ::mxnet::op::tune::TuningMode mode = mode_and_timing.first;
+        const test::perf::timing_map_t &tm = mode_and_timing.second;
+        const auto found = tm.find(direction);  // single lookup, reused below
+        if (found == tm.end()) {
+          continue;
+        }
+        const double duration = found->second.TimeEach();
+        mode_sort.insert({duration, mode});
+        if (csv_mode) {
+          std::cout << TimingDirectionAsString(direction) << ","
+                    << ::mxnet::op::tune::TuningModeToString(mode) << ","
+                    << duration << ",";
+        }
+      }
+      if (csv_mode) {
+        std::cout << std::endl << std::flush;
+      }
+      if (mode_sort.empty()) {
+        continue;
+      }
+
+      // Now we have modes sorted by performance, fastest to slowest
+      const bool_mode_pair result = CheckCorrectTuning(mode_sort);
+      if (verbose) {
+        for (const auto &k : mode_sort) {
+          std::cout << "\t" << ::mxnet::op::tune::TuningModeToString(k.second)
+                    << ": " << k.first << " ms";
+          if (k.second == ::mxnet::op::tune::Auto) {
+            // Show which standard mode Auto's timing was closest to
+            std::cout << " (" << ::mxnet::op::tune::TuningModeToString(result.second) << ")";
+          }
+          std::cout << std::endl;
+        }
+        std::cout << std::flush;
+        if (!result.first) {
+          std::cout << "*** WARNING: Wrong OMP state selected ***" << std::endl << std::flush;
+        }
+      }
+      CHECK(results.find(shapes) == results.end()) << "Duplicate entry for set of shapes";
+      results[shapes] = result;
+    }
+    // Return by name: wrapping a local in std::move() would inhibit copy elision
+    return results;
+  }
+
+  /*!
+   * \brief Perform execution runs for a given forward (and optionally backward) operator
+   *        under each of the AlwaysOMP, Auto and NeverOMP tuning modes
+   * \note Timing data from any previous call is discarded. The tuning mode for DType is left
+   *       set to the last mode tested (NeverOMP) when this function returns.
+   * \param kwargs Parameters for the operator
+   * \param verbose If true (and test::csv is not set), print the name of each mode as it is run
+   * \param shapevec_vectors List of input-shape sets to be timed under each mode
+   * \param op_name Name by which the operator is registered with nnvm
+   * \param backward_op_name Backward operator name
+   */
+  void TestTunedOperator(const kwargs_t &kwargs,
+                         const bool verbose,
+                         const std::vector<shape_vect>& shapevec_vectors,
+                         const char *op_name,
+                         const char *backward_op_name = COREOP_BWD_OP_NAME_VALUE_NONE) {
+    timing_.clear();
+    for (const auto mode : {::mxnet::op::tune::AlwaysOMP,
+                            ::mxnet::op::tune::Auto,
+                            ::mxnet::op::tune::NeverOMP}) {
+      if (verbose && !test::csv) {
+        std::cout << std::endl << ::mxnet::op::tune::TuningModeToString(mode)
+                  << std::endl << std::flush;
+      }
+
+      mxnet::op::OperatorTune<DType>::set_tuning_mode(mode);
+      const ShapesToPerfTimingMap shapes2perfmap = RunCoreOpTimingTest(false,
+                                                                       kwargs,
+                                                                       shapevec_vectors,
+                                                                       op_name,
+                                                                       backward_op_name);
+      // File the results by shape set first, then by the mode they were measured under
+      for (const auto &item : shapes2perfmap) {
+        timing_[item.first][mode] = item.second;
+      }
+    }
+  }
+
+  /*!
+   * \brief Calculate the success rate of the run based upon Auto being close to the faster
+   *        OMP/non-OMP attempt
+   * \param directions List of directions to use in calculation (Forward, Backward).
+   *        Empty list means all
+   * \param verbose Whether to print info
+   * \return Success rate ratio (#success/#TOTAL) (0.0-1.0); 1.0 if there was nothing to count
+   */
+  float CalculateSuccessRate(std::vector<test::op::TimingDirection> directions = {},
+                             bool verbose = true) const {
+    if (directions.empty()) {
+      directions = {test::op::Forward, test::op::Backward};
+    }
+    size_t total = 0, passed = 0;
+    for (const test::op::TimingDirection dir : directions) {
+      const shape_vec_to_bool_map outcome = CalculateModeSort(dir, verbose);
+      for (const auto &shapes_and_result : outcome) {
+        ++total;
+        if (shapes_and_result.second.first) {
+          ++passed;
+        }
+      }
+    }
+    // nothing ventured, nothing failed (glass-is-half-full angle)
+    return total ? static_cast<float>(passed) / static_cast<float>(total) : 1.0f;
+  }
+
+ private:
+  tuned_timing_t  timing_;
+};
+
+}  // namespace tune
+}  // namespace test
+}  // namespace mxnet
+
+#endif  // TEST_TUNE_H_
diff --git a/tests/cpp/include/test_util.h b/tests/cpp/include/test_util.h
index 33ca3c47d0..7b75691997 100644
--- a/tests/cpp/include/test_util.h
+++ b/tests/cpp/include/test_util.h
@@ -43,6 +43,7 @@ extern bool unitTestsWithCuda;
 extern bool debug_output;
 extern bool quick_test;
 extern bool performance_run;
+extern bool csv;
 
 /*! \brief Pause VTune analysis */
 struct VTunePause {
@@ -671,16 +672,20 @@ struct less_shapevect {
 };
 
 inline std::string pretty_num(uint64_t val) {
-  std::string res, s = std::to_string(val);
-  size_t ctr = 0;
-  for (int i = static_cast<int>(s.size()) - 1; i >= 0; --i, ++ctr) {
-    if (ctr && (ctr % 3) == 0) {
-      res += ",";
+  if (!test::csv) {
+    std::string res, s = std::to_string(val);
+    size_t ctr = 0;
+    for (int i = static_cast<int>(s.size()) - 1; i >= 0; --i, ++ctr) {
+      if (ctr && (ctr % 3) == 0) {
+        res += ",";
+      }
+      res.push_back(s[i]);
     }
-    res.push_back(s[i]);
+    std::reverse(res.begin(), res.end());
+    return res;
+  } else {
+    return std::to_string(val);
   }
-  std::reverse(res.begin(), res.end());
-  return res;
 }
 
 /*! \brief Change a value during the scope of this declaration */
diff --git a/tests/cpp/misc/memory_test.cc b/tests/cpp/misc/memory_test.cc
index a36f7f93ae..8f4e8c25e8 100644
--- a/tests/cpp/misc/memory_test.cc
+++ b/tests/cpp/misc/memory_test.cc
@@ -79,7 +79,7 @@ TEST(MEMORY_TEST, MemsetAndMemcopyPerformance) {
 
       start = test::perf::getNannoTickCount();
       #pragma omp parallel for num_threads(GetOMPThreadCount())
-      for (int i = 0; i < test_size; ++i) {
+      for (int i = 0; i < static_cast<int>(test_size); ++i) {
         src[i] = 42;
       }
       const uint64_t omp_set_time = test::perf::getNannoTickCount() - start;
@@ -94,7 +94,7 @@ TEST(MEMORY_TEST, MemsetAndMemcopyPerformance) {
 
       start = test::perf::getNannoTickCount();
       #pragma omp parallel for num_threads(GetOMPThreadCount())
-      for (int i = 0; i < test_size; ++i) {
+      for (int i = 0; i < static_cast<int>(test_size); ++i) {
         dest[i] = src[i];
       }
       const uint64_t omp_copy_time = test::perf::getNannoTickCount() - start;
diff --git a/tests/cpp/operator/batchnorm_test.cc b/tests/cpp/operator/batchnorm_test.cc
index 24b5600a71..80efbd51a9 100644
--- a/tests/cpp/operator/batchnorm_test.cc
+++ b/tests/cpp/operator/batchnorm_test.cc
@@ -1424,7 +1424,7 @@ static TShape MakeShape(const std::vector<index_t>& shape,
   CHECK_LT(channelAxis, shape.size() + 1);
   const index_t dim = index_t(shape.size()) + 1;
   TShape newShape(dim);
-  for (size_t x = 0; x < channelAxis; ++x) {
+  for (size_t x = 0; x < static_cast<size_t>(channelAxis); ++x) {
     newShape[x] = index_t(shape[x]);
   }
   newShape[channelAxis] = index_t(channelCount);
diff --git a/tests/cpp/operator/broadcast_perf.cc b/tests/cpp/operator/broadcast_perf.cc
index 6986c4d27e..f4ac06228e 100644
--- a/tests/cpp/operator/broadcast_perf.cc
+++ b/tests/cpp/operator/broadcast_perf.cc
@@ -26,39 +26,42 @@
 #include <mxnet/tensor_blob.h>
 #include "../include/test_op_runner.h"
 #include "../include/test_core_op.h"
+#include "../include/test_tune.h"
 
 using namespace mxnet;
 
 using kwargs_t = test::op::kwargs_t;
 
-template<typename DType = float>
-static void RunCoreOpBidirectional(const bool isGPU,
-                                   const kwargs_t& op_kwargs,
-                                   const char *op_name,
-                                   const char *backward_op_name = "") {
-  const std::vector<TShape> shapes = { {2, 3}, {2, 1} };
-  test::op::CoreOpExecutor<DType> op(isGPU, shapes);
-  op.set_verbose(false);
-
-  op.Init(op.ArgsWithOpName(op_kwargs, op_name, backward_op_name));
-
-  PRINT_NDARRAYS(op.ctx().run_ctx, op.inputs());
-  PRINT_NDARRAYS(op.ctx().run_ctx, op.outputs());
-  op.Execute();
-  PRINT_NDARRAYS(op.ctx().run_ctx, op.outputs());
-  if (op.HasBackward()) {
-    PRINT_NDARRAYS(op.ctx().run_ctx, op.bwd_inputs());
-    PRINT_NDARRAYS(op.ctx().run_ctx, op.bwd_outputs());
-    op.ExecuteBackward();
-    PRINT_NDARRAYS(op.ctx().run_ctx, op.bwd_outputs());
+/*!
+ * \brief Sets of input shapes used by the broadcast tests
+ * \return The larger performance dataset when test::performance_run is set, otherwise a small
+ *         dataset which acts as a sanity test
+ */
+static std::vector<std::vector<TShape>> broadcast_shapes() {
+  std::vector<std::vector<TShape>> shapes;
+  if (test::performance_run) {
+    shapes = {
+      { {28,  28},  {28, 1} },
+      { {64,  28},  {1, 28} },
+      { {28,  28, 28},  {28, 28, 1} },
+      { {128, 128}, {1, 128} },
+      { {1024, 64}, {1, 64} },
+      { {1024, 12, 256}, {1024, 1, 1} },
+      { {2560, 1280}, {2560, 1} }
+    };
+  } else {
+    shapes = {
+      // Non-performance dataset acts as a sanity test
+      { {28,  28},  {28, 1} },
+      { {128, 128}, {128, 1} },
+      { {28,  28, 28},  {28, 28, 1} }
+    };
   }
+  // Return by name (and as a non-const value) so the result can be elided or moved
+  return shapes;
 }
 
 /*!
  * \brief Generic bidirectional sanity test
  */
 TEST(BROADCAST_PERF, ExecuteBidirectional) {
-  RunCoreOpBidirectional(false, {}, "broadcast_add", "_backward_broadcast_add");
+  test::op::BasicRunCoreOpBidirectional(false, true, {},
+                                        { broadcast_shapes()[0] },
+                                        "broadcast_add", "_backward_broadcast_add");
 }
 
 template<typename DType = float>
@@ -71,23 +74,10 @@ static void RunCoreOpTimingTest(const bool isGPU,
 
   // prime code and cache before the performance runs
   test::op::CoreOperatorRunner<DType> runner;
-  runner.RunBidirectional(false, { {2, 3}, {2, 1} }, kwargs, 1);
+  runner.RunBidirectional(false, { broadcast_shapes()[0] }, kwargs, 1);
 
   // Do the performance runs
-  std::vector<std::vector<TShape>> shapes;
-  if (test::performance_run) {
-    shapes = {
-      { {28,  28},  {28, 1} },
-      { {18,  32} , {18, 1} },
-      { {128, 128}, {128, 1} },
-      { {2560, 1280}, {2560, 1} }
-    };
-  } else {
-    shapes = {
-      { {28,  28},  {28, 1} },
-      { {128, 128}, {128, 1} }
-    };
-  }
+  std::vector<std::vector<TShape>> shapes = broadcast_shapes();
   const char *pu = isGPU ? "GPU" : "CPU";
   for (const std::vector<TShape> &shape : shapes) {
     runner.TimingTest(std::string(op_name) + " Operator " + pu, isGPU, false, kwargs,
@@ -99,7 +89,11 @@ static void RunCoreOpTimingTest(const bool isGPU,
  * \brief ActivationOp timing test for CPU
  */
 TEST(BROADCAST_PERF, TimingCPU) {
-  RunCoreOpTimingTest(false, {}, "broadcast_add", "_backward_broadcast_add");
+  if (!test::csv) {
+    RunCoreOpTimingTest(false, {}, "broadcast_add", "_backward_broadcast_add");
+  } else {
+    RunCoreOpTimingTest(false, {}, "broadcast_add", COREOP_BWD_OP_NAME_VALUE_NONE);
+  }
 }
 
 #if MXNET_USE_CUDA == 1
@@ -111,3 +105,84 @@ TEST(BROADCAST_PERF, TimingGPU) {
 }
 #endif  // MXNET_USE_CUDA == 1
 
+/*!
+ * \brief Run a tuning evaluation
+ * \tparam DType Data type for which to evaluate tuning
+ * \param verbose If true, print per-mode progress and per-shape timing details
+ * \return Average of the success rates calculated for each operator tested
+ */
+template<typename DType>
+static float EvaluateTune(bool verbose = true) {
+  std::vector<std::pair<std::string, std::string>> binary_operators;
+  if (test::performance_run) {
+    binary_operators = {
+      {"broadcast_add", COREOP_BWD_OP_NAME_VALUE_NONE /*"_backward_broadcast_add"*/},
+      {"broadcast_mul", COREOP_BWD_OP_NAME_VALUE_NONE /*"_backward_broadcast_mul"*/},
+      {"broadcast_div", COREOP_BWD_OP_NAME_VALUE_NONE /*"_backward_broadcast_div"*/}
+    };
+  } else {
+    binary_operators = {
+      {"broadcast_add", COREOP_BWD_OP_NAME_VALUE_NONE /*"_backward_broadcast_add"*/}
+    };
+  }
+  std::vector<float> rates;
+  for (const std::pair<std::string, std::string> &op_names : binary_operators) {
+    test::tune::TuningTester<DType> tuningTester;
+    std::cout << "******************************" << std::endl;
+    std::cout << "Operators: " << op_names.first
+              << ", " << op_names.second
+              << " for type: " << test::type_name<DType>()
+              << std::endl;
+    std::cout << "******************************" << std::endl;
+
+    // Prime code and cache
+    test::op::BasicRunCoreOpBidirectional(false, false, {},
+                                          { broadcast_shapes()[0] },
+                                          op_names.first.c_str(),
+                                          op_names.second.c_str());
+
+    // Do the performance runs
+    std::vector<std::vector<TShape>> shapes = broadcast_shapes();
+
+    // Honor the caller's verbosity for the runs as well as for the success-rate report
+    tuningTester.TestTunedOperator({}, verbose, shapes,
+                                   op_names.first.c_str(),
+                                   op_names.second.c_str());
+    rates.push_back(tuningTester.CalculateSuccessRate({}, verbose));
+  }
+  return std::accumulate(rates.begin(), rates.end(), 0.0f) / rates.size();
+}
+
+/*! \brief Broadcast operator tuning evaluation on CPU for float */
+TEST(BROADCAST_PERF, EvaluateTuneTestFloat) {
+  using DType = float;
+  const float rate = EvaluateTune<DType>();
+  std::cout << "Success rate for type " << test::type_name<DType>() << ": " << rate << std::endl;
+}
+/*! \brief Broadcast operator tuning evaluation on CPU for double */
+TEST(BROADCAST_PERF, EvaluateTuneTestDouble) {
+  using DType = double;
+  const float rate = EvaluateTune<DType>();
+  std::cout << "Success rate for type " << test::type_name<DType>() << ": " << rate << std::endl;
+}
+/*! \brief Broadcast operator tuning evaluation on CPU for float16 */
+TEST(BROADCAST_PERF, EvaluateTuneTestFloat16) {
+  using DType = mshadow::half::half_t;
+  const float rate = EvaluateTune<DType>();
+  std::cout << "Success rate for type " << test::type_name<DType>() << ": " << rate << std::endl;
+}
+/*! \brief Broadcast operator tuning evaluation on CPU for uint8_t */
+TEST(BROADCAST_PERF, EvaluateTuneTestInt8) {
+  using DType = uint8_t;
+  const float rate = EvaluateTune<DType>();
+  std::cout << "Success rate for type " << test::type_name<DType>() << ": " << rate << std::endl;
+}
+/*! \brief Broadcast operator tuning evaluation on CPU for int32_t */
+TEST(BROADCAST_PERF, EvaluateTuneTestInt32) {
+  using DType = int32_t;
+  const float rate = EvaluateTune<DType>();
+  std::cout << "Success rate for type " << test::type_name<DType>() << ": " << rate << std::endl;
+}
+/*! \brief Broadcast operator tuning evaluation on CPU for int64_t */
+TEST(BROADCAST_PERF, EvaluateTuneTestInt64) {
+  using DType = int64_t;
+  const float rate = EvaluateTune<DType>();
+  std::cout << "Success rate for type " << test::type_name<DType>() << ": " << rate << std::endl;
+}
+
diff --git a/tests/cpp/operator/tune/operator_tune_test.cc b/tests/cpp/operator/tune/operator_tune_test.cc
new file mode 100644
index 0000000000..68444ec3f7
--- /dev/null
+++ b/tests/cpp/operator/tune/operator_tune_test.cc
@@ -0,0 +1,165 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <gtest/gtest.h>
+#include <mxnet/tensor_blob.h>
+#include "../../src/operator/activation-inl.h"
+#include "../../src/operator/operator_tune-inl.h"
+#include "../include/test_op_runner.h"
+#include "../include/test_core_op.h"
+#include "../include/test_tune.h"
+
+using namespace mxnet;
+
+/*!
+ * \brief Print every name returned by OperatorTune<float>::TunedOperatorNames(), one per line
+ */
+TEST(OMP_TUNING, ShowAllTunedOps) {
+  const std::unordered_set<std::string>& op_names = op::OperatorTune<float>::TunedOperatorNames();
+  for (const std::string &name : op_names) {
+    std::cout << name << std::endl;
+  }
+}
+
+using kwargs_t = test::op::kwargs_t;
+
+/*!
+ * \brief Generic bidirectional sanity test
+ * \note Runs "elemwise_add" paired with "_backward_add" on a single {5, 5} shape.
+ *       NOTE(review): the leading (false, true) arguments look like (isGPU, verbose) -- confirm
+ *       against the declaration of BasicRunCoreOpBidirectional.
+ */
+TEST(OMP_TUNING, ExecuteBidirectional) {
+  test::op::BasicRunCoreOpBidirectional(false, true, {}, {{5, 5}}, "elemwise_add", "_backward_add");
+}
+
+/* Some test results:
+ * AWS c4.8xlarge:
+  Success rate for type float: 0.90278
+  Success rate for type double: 0.88889
+  Success rate for type mshadow::half::half_t: 0.83333
+  Success rate for type unsigned char: 0.86111
+  Success rate for type int: 0.95833
+  Success rate for type long: 0.88889
+ * desktop: 12-core (6 real CPU cores + hyperthreading)
+  Success rate for type float: 0.78125
+  Success rate for type double: 0.85417
+  Success rate for type mshadow::half::half_t: 0.84375
+  Success rate for type unsigned char: 0.80208
+  Success rate for type int: 0.94444
+  Success rate for type long: 1.00000
+ */
+
+/*!
+ * \brief Run a tuning evaluation
+ * \tparam DType Data type for which to evaluate tuning
+ * \param verbose If true, print per-mode progress and per-shape timing details
+ * \return Average of the success rates calculated for each operator tested
+ */
+template<typename DType>
+static float EvaluateTune(const bool verbose = true) {
+  std::vector<std::pair<std::string, std::string>> binary_operators;
+  if (test::csv) {
+    binary_operators = {
+      {"elemwise_add", COREOP_BWD_OP_NAME_VALUE_NONE}
+    };
+  } else if (test::performance_run) {
+    binary_operators = {
+      {"relu",         ""},  // Code can figure out what the backward op is for some
+      {"sigmoid",      ""},
+      {"sqrt",         ""},
+      {"elemwise_add", "_backward_add"},
+      {"elemwise_mul", "_backward_mul"},
+      {"elemwise_div", "_backward_div"}
+    };
+  } else {
+    binary_operators = {
+      {"elemwise_add", "_backward_add"}
+    };
+  }
+
+  // The shapes to time are the same for every operator, so build the list once up front
+  std::vector<std::vector<TShape>> shapes;
+  if (test::performance_run || test::csv) {
+    shapes = {
+      {{1,  1, 28,  28}},
+      {{1,  3, 28,  28}},
+      {{50, 1, 18,  32}},
+      {{25, 3, 64,  64}},
+      {{10, 3, 128, 128}},
+      {{20, 3, 128, 128}},
+      {{30, 3, 128, 128}},
+      {{30, 3, 256, 128}},
+    };
+  } else {
+    shapes = {
+      // Non-performance dataset acts as a sanity test
+      {{1,  1, 28, 28}},
+      {{50, 3, 18, 32}}
+    };
+  }
+
+  std::vector<float> rates;
+  for (const std::pair<std::string, std::string> &op_names : binary_operators) {
+    test::tune::TuningTester<DType> tuningTester;
+    std::cout << "******************************" << std::endl;
+    std::cout << "Operators: " << op_names.first
+              << ", " << op_names.second
+              << " for type: " << test::type_name<DType>()
+              << std::endl;
+    std::cout << "******************************" << std::endl;
+
+    // Do the performance runs
+    tuningTester.TestTunedOperator({}, verbose, shapes,
+                                   op_names.first.c_str(),
+                                   op_names.second.c_str());
+    // Honor the caller's verbosity for the success-rate report as well
+    rates.push_back(tuningTester.CalculateSuccessRate({}, verbose));
+  }
+  return std::accumulate(rates.begin(), rates.end(), 0.0f) / rates.size();
+}
+
+/*! \brief Operator tuning evaluation on CPU for float */
+TEST(OMP_TUNING, EvaluateTuneTestFloat) {
+  using DType = float;
+  const float rate = EvaluateTune<DType>();
+  std::cout << "Success rate for type " << test::type_name<DType>() << ": " << rate << std::endl;
+}
+/*! \brief Operator tuning evaluation on CPU for double */
+TEST(OMP_TUNING, EvaluateTuneTestDouble) {
+  using DType = double;
+  const float rate = EvaluateTune<DType>();
+  std::cout << "Success rate for type " << test::type_name<DType>() << ": " << rate << std::endl;
+}
+/*! \brief Operator tuning evaluation on CPU for float16 */
+TEST(OMP_TUNING, EvaluateTuneTestFloat16) {
+  using DType = mshadow::half::half_t;
+  const float rate = EvaluateTune<DType>();
+  std::cout << "Success rate for type " << test::type_name<DType>() << ": " << rate << std::endl;
+}
+/*! \brief Operator tuning evaluation on CPU for uint8_t */
+TEST(OMP_TUNING, EvaluateTuneTestInt8) {
+  using DType = uint8_t;
+  const float rate = EvaluateTune<DType>();
+  std::cout << "Success rate for type " << test::type_name<DType>() << ": " << rate << std::endl;
+}
+/*! \brief Operator tuning evaluation on CPU for int32_t */
+TEST(OMP_TUNING, EvaluateTuneTestInt32) {
+  using DType = int32_t;
+  const float rate = EvaluateTune<DType>();
+  std::cout << "Success rate for type " << test::type_name<DType>() << ": " << rate << std::endl;
+}
+/*! \brief Operator tuning evaluation on CPU for int64_t */
+TEST(OMP_TUNING, EvaluateTuneTestInt64) {
+  using DType = int64_t;
+  const float rate = EvaluateTune<DType>();
+  std::cout << "Success rate for type " << test::type_name<DType>() << ": " << rate << std::endl;
+}
+
diff --git a/tests/cpp/test_main.cc b/tests/cpp/test_main.cc
index eaf9e3c219..93e646f682 100644
--- a/tests/cpp/test_main.cc
+++ b/tests/cpp/test_main.cc
@@ -35,7 +35,8 @@ static bool dumpCallback(const google_breakpad::MinidumpDescriptor& descriptor,
 }
 #endif
 
-namespace mxnet { namespace test {
+namespace mxnet {
+namespace test {
 bool unitTestsWithCuda = false;
 #ifdef NDEBUG
 bool debug_output = false;
@@ -44,7 +45,9 @@ bool debug_output = false;
 #endif
 bool quick_test = false;
 bool performance_run = false;
-}}
+bool csv = false;
+}  // namespace test
+}  // namespace mxnet
 
 #if MXNET_USE_CUDA
 
@@ -89,6 +92,8 @@ int main(int argc, char ** argv) {
       mxnet::test::debug_output = true;
     } else if (!strcmp(argv[x], "--perf")) {
       mxnet::test::performance_run = true;
+    } else if (!strcmp(argv[x], "--csv")) {
+      mxnet::test::csv = true;
     } else if (!strcmp(argv[x], "--quick") || !strcmp(argv[x], "-q")) {
       mxnet::test::quick_test = true;
     }
diff --git a/tools/tblgen/tblgen.cc b/tools/tblgen/tblgen.cc
new file mode 100644
index 0000000000..a70a2cb1a1
--- /dev/null
+++ b/tools/tblgen/tblgen.cc
@@ -0,0 +1,176 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <algorithm>
+#include <cctype>
+#include <cerrno>
+#include <cstring>
+#include <fstream>
+#include <iostream>
+#include <list>
+#include <sstream>
+#include <string>
+
+/*!
+ * \brief Some string utility functions
+ *        These are meant to be functional and simple rather than performant
+ */
+struct StringUtil {
+  /*!
+   * \brief Trim whitespace from the beginning and end of a string, in place
+   * \param s String to trim (must not be null)
+   * \return reference to the modified string. This is the same std::string object as what was
+   *         supplied in the parameters
+   */
+  static std::string &trim(std::string *s) {
+    // The character is taken as unsigned char: calling std::isspace with a negative value
+    // (which a plain char with the high bit set converts to) is undefined behavior.
+    const auto not_space = [](unsigned char ch) {
+      return !std::isspace(ch);
+    };
+    s->erase(s->begin(), std::find_if(s->begin(), s->end(), not_space));
+    s->erase(std::find_if(s->rbegin(), s->rend(), not_space).base(), s->end());
+    return *s;
+  }
+  /*!
+   * \brief Tokenize a comma-separated string into a list of tokens
+   *        Each token is trimmed of surrounding whitespace; empty tokens are dropped
+   * \param s String to tokenize
+   * \return std::list of tokens, in order of appearance
+   */
+  static std::list<std::string> string2list(const std::string &s) {
+    std::list<std::string> res;
+    std::istringstream iss(s);
+    std::string token;
+    while (std::getline(iss, token, ',')) {
+      trim(&token);
+      if (!token.empty()) {
+        res.push_back(token);
+      }
+    }
+    // Returned by name (no std::move) so the compiler can elide the copy
+    return res;
+  }
+  /*!
+   * \brief Replace all occurrences, in a string, of a substring with another substring
+   * \param str String to modify (modified in place)
+   * \param from Substring to replace; if empty, the string is left unchanged
+   * \param to New substring to appear in modified string
+   * \return reference to the modified string (the same object as \a str)
+   */
+  static std::string& replaceAll(std::string& str, const std::string& from, const std::string& to) {
+    if (!from.empty()) {
+      size_t start_pos = 0;
+      while ((start_pos = str.find(from, start_pos)) != std::string::npos) {
+        str.replace(start_pos, from.length(), to);
+        // Resume after the inserted text, in case 'to' contains 'from'
+        // (like replacing 'x' with 'yx'), which would otherwise loop forever
+        start_pos += to.length();
+      }
+    }
+    return str;
+  }
+};
+
+/*!
+ * \brief Prefix and suffix of a template placeholder
+ *        A placeholder is REPLACE_TOKEN_PRE, followed by a zero-based item index,
+ *        followed by REPLACE_TOKEN_POST (for example "%ITEM_0%")
+ */
+#define REPLACE_TOKEN_PRE   "%ITEM_"
+#define REPLACE_TOKEN_POST  "%"
+
+/*!
+ * \brief Load the contents of a template file into a string
+ * \param file_name Path of the template file to read
+ * \return The file contents with leading and trailing whitespace trimmed, or an empty string
+ *         if the file could not be opened (in which case an error is written to std::cerr
+ *         and errno is left holding the value observed right after the failed open)
+ */
+static std::string load_template_file(const std::string& file_name) {
+  std::ifstream tmpl(file_name);
+  if (tmpl.fail()) {
+    // Capture errno before any stream output has a chance to overwrite it
+    const int err = errno;
+    std::cerr << "Error opening file: \"" << file_name << "\": " << strerror(err) << std::endl;
+    errno = err;
+    return "";
+  }
+  // Copy the whole stream buffer so that whitespace inside the template is preserved
+  // (word-by-word extraction with operator>> would strip every space and newline)
+  std::stringstream tmpl_buff;
+  tmpl_buff << tmpl.rdbuf();
+  tmpl.close();
+  std::string tmpl_string = tmpl_buff.str();
+  StringUtil::trim(&tmpl_string);
+  return tmpl_string;
+}
+
+/*!
+ * \brief Expand a template once for every data line of the input file
+ *
+ *        Input lines are processed as follows:
+ *          - blank lines and lines starting with '#' are skipped
+ *          - a line starting with '>' names a new template file to load
+ *          - any other line is a comma-separated item list; item N replaces every
+ *            REPLACE_TOKEN_PRE N REPLACE_TOKEN_POST placeholder in the current template
+ *            and the result is appended to the output file
+ *
+ * \param argc Argument count; must be 3 or 4
+ * \param argv argv[1] = input file, argv[2] = output file, argv[3] = optional initial template
+ * \return 0 on success, an errno value (or -1) on failure
+ */
+int main(int argc, char **argv) {
+  if (argc < 3 || argc > 4) {
+    std::cerr << "Usage:" << std::endl
+              << (argc > 0 ? argv[0] : "tblgen")
+              << " <input file> <output file> [template file]" << std::endl;
+    // Bail out: continuing would dereference missing argv entries
+    return -1;
+  }
+
+  std::ifstream input(argv[1]);
+  if (input.fail()) {
+    const int err = errno;  // save before stream output can clobber it
+    std::cerr << "Error opening file: \"" << argv[1] << "\": " << strerror(err) << std::endl;
+    return err ? err : -1;
+  }
+
+  std::ofstream output(argv[2], std::ios_base::trunc | std::ios_base::out);
+  if (output.fail()) {
+    const int err = errno;
+    std::cerr << "Error opening file: \"" << argv[2] << "\": " << strerror(err) << std::endl;
+    return err ? err : -1;
+  }
+
+  // The initial template is optional; the input file may select one with a '>' directive
+  std::string tmpl_string;
+  if (argc > 3) {
+    tmpl_string = load_template_file(argv[3]);
+    if (tmpl_string.empty()) {
+      return errno ? errno : -1;
+    }
+  }
+
+  output << "/* This file is automatically generated. DO NOT EDIT! */" << std::endl << std::endl;
+
+  int rc = 0;
+  std::string token_line;
+  // Read whole lines: item lists and '>' directives may contain embedded whitespace
+  while (std::getline(input, token_line)) {
+    StringUtil::trim(&token_line);
+    if (token_line.empty() || token_line[0] == '#') {
+      continue;  // blank line or comment
+    }
+
+    // Look for new template directive
+    if (token_line[0] == '>') {
+      std::string new_template_file = token_line.substr(1);
+      StringUtil::trim(&new_template_file);
+      tmpl_string = load_template_file(new_template_file);
+      continue;
+    }
+    if (tmpl_string.empty()) {
+      std::cerr << "No valid template" << std::endl;
+      rc = -1;
+      break;
+    }
+
+    // string2list already trims each token and drops empty ones
+    const std::list<std::string> tokens = StringUtil::string2list(token_line);
+    std::string replaced = tmpl_string;
+    size_t item = 0;
+    for (const std::string &token : tokens) {
+      std::string to_replace = REPLACE_TOKEN_PRE;
+      to_replace += std::to_string(item++);
+      to_replace += REPLACE_TOKEN_POST;
+      StringUtil::replaceAll(replaced, to_replace, token);
+    }
+    // std::endl flushes, so a write failure is detected on the line that caused it
+    output << replaced << std::endl;
+    if (output.fail()) {
+      const int err = errno;
+      std::cerr << "Error writing file: \"" << argv[2] << "\": " << strerror(err) << std::endl;
+      return err ? err : -1;
+    }
+  }
+  output.close();
+  input.close();
+
+  return rc;
+}
+


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services