Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2017/11/25 22:00:43 UTC

[GitHub] piiswrong closed pull request #8719: Tune without Launch specialization macros

URL: https://github.com/apache/incubator-mxnet/pull/8719

This pull request was merged from a forked repository. Because GitHub hides the original diff once a fork's pull request is merged, it is reproduced below for the sake of provenance:
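
For readers skimming the diff: this change replaces MXNet's macro-generated Kernel<OP>::Launch specializations with a tag-dispatch scheme. Operators that should be auto-tuned now inherit from an empty mxnet_op::tunable marker struct, and Kernel<OP, cpu>::Launch picks the tuned code path through std::enable_if/std::is_base_of. The basic arithmetic ops formerly taken from mshadow::op (identity, plus, minus, mul, div) are re-declared in mxnet::op::mshadow_op so they can carry the tag, which is why so many call sites below switch namespaces. A minimal, self-contained sketch of the dispatch pattern, with deliberately simplified signatures (names such as plus_kernel are illustrative, not MXNet's):

    // Minimal sketch of tag-dispatch tuning (simplified; not the real MXNet API).
    #include <cstdio>
    #include <type_traits>

    struct tunable {};  // empty marker type, as added to operator_tune.h below

    template <typename OP>
    struct Kernel {
      // Untagged ops fall through to a plain serial loop.
      template <typename T = OP, typename... Args>
      static typename std::enable_if<!std::is_base_of<tunable, T>::value, bool>::type
      Launch(const int N, Args... args) {
        for (int i = 0; i < N; ++i) OP::Map(i, args...);
        return true;
      }
      // Tagged ops take the tuned path (the real code decides serial vs. OpenMP here).
      template <typename T = OP, typename... Args>
      static typename std::enable_if<std::is_base_of<tunable, T>::value, bool>::type
      Launch(const int N, Args... args) {
        std::printf("tuned launch of %d iterations\n", N);
        for (int i = 0; i < N; ++i) OP::Map(i, args...);
        return true;
      }
    };

    // A tagged op, analogous to mxnet::op::mshadow_op::plus in this PR:
    struct plus_kernel : public tunable {
      static void Map(int i, float* out, const float* a, const float* b) {
        out[i] = a[i] + b[i];
      }
    };

    int main() {
      float a[4] = {1, 2, 3, 4}, b[4] = {4, 3, 2, 1}, out[4];
      Kernel<plus_kernel>::Launch(4, out, a, b);  // resolves to the tuned overload
      return 0;
    }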

diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc
index ec0b9c2530..5717327f87 100644
--- a/src/imperative/cached_op.cc
+++ b/src/imperative/cached_op.cc
@@ -154,7 +154,7 @@ std::vector<nnvm::NodeEntry> Imperative::CachedOp::Gradient(
     auto nop = Node::Create();
     nop->attrs.op = _NoGrad;
     nop->attrs.name = "NoGradient";
-    uint32_t j = 0, k = 0;
+    uint32_t k = 0;
     for (const auto& i : fwd_graph_.indexed_graph().input_nodes()) {
       if (auxs.count(i)) {
         ret.emplace_back(NodeEntry{nop, 0, 0});
diff --git a/src/kvstore/kvstore_dist_server.h b/src/kvstore/kvstore_dist_server.h
index de94c8669a..49aa001910 100644
--- a/src/kvstore/kvstore_dist_server.h
+++ b/src/kvstore/kvstore_dist_server.h
@@ -301,7 +301,7 @@ class KVStoreDistServer {
           using namespace mshadow;
           Engine::Get()->PushAsync(
             [recved, merged, out](RunContext ctx, Engine::CallbackOnComplete on_complete) {
-              op::ElemwiseBinaryOp::ComputeEx<cpu, mshadow::op::plus>(
+              op::ElemwiseBinaryOp::ComputeEx<cpu, op::mshadow_op::plus>(
                 {}, {}, {recved, merged.array}, {kWriteTo}, {out});
               on_complete();
             }, recved.ctx(), const_vars, {out.var()},
diff --git a/src/ndarray/ndarray_function-inl.h b/src/ndarray/ndarray_function-inl.h
index 821ef2c129..a80d9db363 100644
--- a/src/ndarray/ndarray_function-inl.h
+++ b/src/ndarray/ndarray_function-inl.h
@@ -414,7 +414,7 @@ void ElementwiseSum<DEVICE>(const std::vector<TBlob> source,
       }
       default: {
         Tensor<xpu, 2, DType> in_0 = source[0].FlatTo2D<xpu, DType>(s);
-        out = F<mshadow::op::identity>(in_0);
+        out = F<op::mshadow_op::identity>(in_0);
         for (size_t i = 1; i < source.size(); ++i) {
           out += source[i].FlatTo2D<xpu, DType>(s);
         }
diff --git a/src/ndarray/ndarray_function.h b/src/ndarray/ndarray_function.h
index 98ad3e9257..6e6df3954c 100644
--- a/src/ndarray/ndarray_function.h
+++ b/src/ndarray/ndarray_function.h
@@ -46,19 +46,19 @@ struct BinaryBase {
 
 // operators
 struct Plus : public BinaryBase {
-  typedef mshadow::op::plus mshadow_op;
+  typedef op::mshadow_op::plus mshadow_op;
 };
 
 struct Minus : public BinaryBase {
-  typedef mshadow::op::minus mshadow_op;
+  typedef op::mshadow_op::minus mshadow_op;
 };
 
 struct Mul : public BinaryBase {
-  typedef mshadow::op::mul mshadow_op;
+  typedef op::mshadow_op::mul mshadow_op;
 };
 
 struct Div : public BinaryBase {
-  typedef mshadow::op::div mshadow_op;
+  typedef op::mshadow_op::div mshadow_op;
 };
 
 struct Mod : public BinaryBase {
diff --git a/src/operator/activation-inl.h b/src/operator/activation-inl.h
index a39fe9ab11..82800793c0 100644
--- a/src/operator/activation-inl.h
+++ b/src/operator/activation-inl.h
@@ -110,7 +110,7 @@ class ActivationOp : public Operator {
     if (sz) {
       MXNET_ASSIGN_REQ_SWITCH(req[activation::kData], Req, {
         mxnet_op::Kernel<mxnet_op::op_with_req<
-          mxnet::op::mxnet_op::backward_grad<BackwardOp>, Req>, xpu>::Launch(
+          mxnet::op::mxnet_op::backward_grad_tuned<BackwardOp>, Req>, xpu>::Launch(
           s, sz,
           m_in_grad.dptr<DType>(),
           m_out_grad.dptr<DType>(),
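
The backward_grad_tuned<BackwardOp> op above reaches the tuned launch path through the mxnet_op::op_with_req wrapper, which binds a scalar op and a write-request mode into a kernel-style Map(int i, ...). A hedged, simplified sketch of the wrapper's shape (the real one in mxnet_op.h supports several Map arities; this signature is illustrative):

    // Simplified op_with_req sketch.
    enum OpReqType { kNullOp, kWriteTo, kWriteInplace, kAddTo };

    template <typename OP, int req>
    struct op_with_req {
      typedef OP Operation;  // exposes the inner op so the tuned Launch overload
                             // can test std::is_base_of<tunable, T::Operation>
      template <typename DType>
      static void Map(int i, DType* out, const DType* lhs, const DType* rhs) {
        const DType val = OP::Map(lhs[i], rhs[i]);
        if (req == kAddTo) out[i] += val;        // accumulate into the output
        else if (req != kNullOp) out[i] = val;   // kWriteTo / kWriteInplace
      }
    };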
diff --git a/src/operator/contrib/ctc_loss-inl.h b/src/operator/contrib/ctc_loss-inl.h
index ad1d1ec91f..ef58c519aa 100644
--- a/src/operator/contrib/ctc_loss-inl.h
+++ b/src/operator/contrib/ctc_loss-inl.h
@@ -426,7 +426,7 @@ class CTCLossOp : public Operator {
                             workspace_bytes));
 
     if (req_grad) {
-      mxnet_op::SoftmaxGrad<mshadow::op::mul, mxnet_op::softmax_bwd>(s,
+      mxnet_op::SoftmaxGrad<mshadow_op::mul, mxnet_op::softmax_bwd>(s,
           prob.dptr_, grad.dptr_, grad.dptr_, data.shape_, 2);
       Assign(grad, mxnet::kWriteInplace, grad * alphabet_size);
     }
diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h
index 10be627ee7..af7ef513f1 100644
--- a/src/operator/mshadow_op.h
+++ b/src/operator/mshadow_op.h
@@ -30,7 +30,7 @@
 #include "math.h"
 #include "math_functions-inl.h"
 #include "special_functions-inl.h"
-#include "./mxnet_op.h"
+#include "./operator_tune.h"
 
 #ifdef __CUDACC__
 #include <cuda_fp16.h>
@@ -40,24 +40,6 @@ namespace mxnet {
 namespace op {
 namespace mshadow_op {
 
-/*!
- * \brief Use the 'MXNET_TUNABLE_MSHADOW_OP_FWD_AND_BWD' macro outside of the mshadow_op namespace
- *        See mxnet_op.h for a description of 'MXNET_TUNABLE_MSHADOW_OP_FWD_AND_BWD'
- *
- * \note An entry for the operator must also be added in operator_tune.cc, which will register it
- *       for auto-tuning and also hold its workload weight
- */
-#define MSHADOW_OP_DECLARE_TUNABLE_FWD_AND_BWD(__op$) \
-  } MXNET_TUNABLE_MSHADOW_OP_FWD_AND_BWD(mshadow_op::__op$) namespace mshadow_op {  // NOLINT(*)
-/*!
- * \brief Use the 'MXNET_TUNABLE_MSHADOW_OP_BACKWARD' macro outside of the mshadow_op namespace
- *        See mxnet_op.h for a description of 'MXNET_TUNABLE_MSHADOW_OP_BACKWARD'
- *
- * \note An entry for the operator must also be added in operator_tune.cc, which will register it
- *       for auto-tuning and also hold its workload weight
- */
-#define MSHADOW_OP_DECLARE_TUNABLE_BACKWARD(__op$) \
-  }  MXNET_TUNABLE_MSHADOW_OP_BACKWARD(mshadow_op::__op$) namespace mshadow_op {  // NOLINT(*)
 #ifdef __CUDA_ARCH__
 __constant__ const float PI = 3.14159265358979323846;
 #else
@@ -68,41 +50,36 @@ using std::enable_if;
 using std::is_unsigned;
 
 #define MXNET_UNARY_MATH_OP(name, expr) \
-  struct name { \
+  struct name : public mxnet_op::tunable { \
     template<typename DType> \
     MSHADOW_XINLINE static DType Map(DType a) { \
       return DType(expr); \
     } \
-  }; \
-  MSHADOW_OP_DECLARE_TUNABLE_FWD_AND_BWD(name)
-
+  }
 
 #define MXNET_UNARY_MATH_OP_NC(name, expr) \
-  struct name { \
+  struct name : public mxnet_op::tunable { \
     template<typename DType> \
     MSHADOW_XINLINE static DType Map(DType a) { \
       return (expr); \
     } \
-  }; \
-  MSHADOW_OP_DECLARE_TUNABLE_FWD_AND_BWD(name)
+  }
 
 #define MXNET_BINARY_MATH_OP(name, expr) \
-  struct name { \
+  struct name : public mxnet_op::tunable { \
     template<typename DType> \
     MSHADOW_XINLINE static DType Map(DType a, DType b) { \
       return DType(expr); \
     } \
-  }; \
-  MSHADOW_OP_DECLARE_TUNABLE_FWD_AND_BWD(name)
+  }
 
 #define MXNET_BINARY_MATH_OP_NC(name, expr) \
-  struct name { \
+  struct name : public mxnet_op::tunable  { \
     template<typename DType> \
     MSHADOW_XINLINE static DType Map(DType a, DType b) { \
       return (expr); \
     } \
-  }; \
-  MSHADOW_OP_DECLARE_TUNABLE_FWD_AND_BWD(name)
+  }
 
 #define MXNET_SIMPLE_UNARY_MATH_OP(name) MXNET_UNARY_MATH_OP(name, math::name(a))
 
@@ -116,6 +93,14 @@ MXNET_BINARY_MATH_OP_NC(left, a);
 
 MXNET_BINARY_MATH_OP_NC(right, b);
 
+MXNET_BINARY_MATH_OP_NC(mul, a * b);
+
+MXNET_BINARY_MATH_OP_NC(div, a / b);
+
+MXNET_BINARY_MATH_OP_NC(plus, a + b);
+
+MXNET_BINARY_MATH_OP_NC(minus, a - b);
+
 MXNET_UNARY_MATH_OP(negation, -a);
 
 MXNET_UNARY_MATH_OP(reciprocal, 1.0f / math::id(a));
@@ -145,7 +130,7 @@ MXNET_SIMPLE_UNARY_MATH_OP(tanh);
 MXNET_UNARY_MATH_OP(tanh_grad, 1.0f - math::sqr(a));
 
 /*! \brief SoftReLU, also known as softplus activation */
-struct softrelu {
+struct softrelu : public mxnet_op::tunable {
   template<typename DType>
   MSHADOW_XINLINE static DType Map(DType a) {
     // Avoid overflow of exp for large inputs.
@@ -158,7 +143,6 @@ struct softrelu {
     }
   }
 };
-MSHADOW_OP_DECLARE_TUNABLE_FWD_AND_BWD(softrelu)
 
 MXNET_UNARY_MATH_OP(softrelu_grad, -math::expm1(-a));
 
@@ -173,13 +157,12 @@ MXNET_UNARY_MATH_OP(log_grad, 1.0f / math::id(a));
 MXNET_SIMPLE_UNARY_MATH_OP(log10);
 
 // Constant is 1 / log(10)
-struct log10_grad {
+struct log10_grad : public mxnet_op::tunable {
   template<typename DType>
   MSHADOW_XINLINE static DType Map(DType a) {
     return DType(0.4342944819f / static_cast<float>(a));
   }
 };
-MSHADOW_OP_DECLARE_TUNABLE_BACKWARD(log10_grad)
 
 template<>
 MSHADOW_XINLINE double log10_grad::Map<double>(double a) {
@@ -189,13 +172,12 @@ MSHADOW_XINLINE double log10_grad::Map<double>(double a) {
 MXNET_SIMPLE_UNARY_MATH_OP(log2);
 
 // Constant is 1 / log(2)
-struct log2_grad {
+struct log2_grad : public mxnet_op::tunable {
   template<typename DType>
   MSHADOW_XINLINE static DType Map(DType a) {
     return DType(1.442695041f / static_cast<float>(a));
   }
 };
-MSHADOW_OP_DECLARE_TUNABLE_BACKWARD(log2_grad)
 
 template<>
 MSHADOW_XINLINE double log2_grad::Map<double>(double a) {
@@ -275,7 +257,7 @@ MXNET_BINARY_MATH_OP_NC(threshold, a < b ? DType(1) : DType(0));
 MXNET_UNARY_MATH_OP(abs, math::fabs(a)); // NOLINT(*)
 
 /*! \brief used for generate element of sign */
-struct sign {
+struct sign : public mxnet_op::tunable {
   template<typename DType>
   MSHADOW_XINLINE static typename enable_if<!is_unsigned<DType>::value, DType>::type
   Map(DType a) {
@@ -290,7 +272,6 @@ struct sign {
     return DType(0);
   }
 };
-MSHADOW_OP_DECLARE_TUNABLE_FWD_AND_BWD(sign)
 
 MXNET_UNARY_MATH_OP_NC(sign_grad, DType(0));
 
@@ -352,7 +333,7 @@ MXNET_SIMPLE_UNARY_MATH_OP(floor);
 MXNET_SIMPLE_UNARY_MATH_OP(trunc);
 
 /*! \brief used to round number to nearest integer */
-struct rint {
+struct rint : public mxnet_op::tunable {
   template<typename DType>
   MSHADOW_XINLINE static DType Map(DType a) {
     auto floor = math::floor(a);
@@ -361,10 +342,9 @@ struct rint {
     return DType((af - floor) <= (ceil - af) ? floor : ceil);
   }
 };
-MSHADOW_OP_DECLARE_TUNABLE_FWD_AND_BWD(rint)
 
 /*! \brief used to round number to integer nearest to 0 */
-struct fix {
+struct fix : public mxnet_op::tunable {
   template<typename DType>
   MSHADOW_XINLINE static DType Map(DType a) {
     auto floor = math::floor(a);
@@ -372,7 +352,6 @@ struct fix {
     return DType((floor > 0 ? floor : -floor) < (ceil > 0 ? ceil : -ceil) ? floor : ceil);
   }
 };
-MSHADOW_OP_DECLARE_TUNABLE_FWD_AND_BWD(fix)
 
 /*! \brief used for generate gradient of MAE loss*/
 MXNET_BINARY_MATH_OP_NC(minus_sign, a - b > DType(0) ? DType(1) : -DType(1));
@@ -401,7 +380,7 @@ MXNET_BINARY_MATH_OP(rdiv, math::id(b) / math::id(a));
 
 MXNET_BINARY_MATH_OP(rdiv_grad, -math::id(b) / math::sqr(a));
 
-struct mod {
+struct mod : public mxnet_op::tunable {
   template<typename DType>
   MSHADOW_XINLINE static typename enable_if<!is_unsigned<DType>::value, DType>::type
   Map(DType a, DType b) {
@@ -435,7 +414,6 @@ struct mod {
     }
   }
 };
-MSHADOW_OP_DECLARE_TUNABLE_FWD_AND_BWD(mod)
 
 template<>
 MSHADOW_XINLINE mshadow::half::half2_t mod::Map<mshadow::half::half2_t>
@@ -444,14 +422,12 @@ MSHADOW_XINLINE mshadow::half::half2_t mod::Map<mshadow::half::half2_t>
   return a%b;
 }
 
-struct mod_grad {
+struct mod_grad : public mxnet_op::tunable  {
   template<typename DType>
   MSHADOW_XINLINE static DType Map(DType a, DType b) {
     return DType(0);
   }
 };
-MSHADOW_OP_DECLARE_TUNABLE_BACKWARD(mod_grad)
-
 template<>
 MSHADOW_XINLINE double mod_grad::Map<double>(double a, double b) {
   return 1.0;
@@ -481,14 +457,12 @@ MSHADOW_XINLINE mshadow::half::half2_t mod_grad::Map<mshadow::half::half2_t>
   return result;
 }
 
-struct mod_rgrad {
+struct mod_rgrad : public mxnet_op::tunable {
   template<typename DType>
   MSHADOW_XINLINE static DType Map(DType a, DType b) {
     return DType(0);
   }
 };
-MSHADOW_OP_DECLARE_TUNABLE_BACKWARD(mod_rgrad)
-
 template<>
 MSHADOW_XINLINE double mod_rgrad::Map<double>(double a, double b) {
   return -::floor(a/b);
@@ -518,7 +492,7 @@ MSHADOW_XINLINE mshadow::half::half2_t mod_rgrad::Map<mshadow::half::half2_t>
 #endif
 }
 
-struct rmod {
+struct rmod : public mxnet_op::tunable {
   template<typename DType>
   MSHADOW_XINLINE static typename enable_if<!is_unsigned<DType>::value, DType>::type
   Map(DType a, DType b) {
@@ -552,7 +526,6 @@ struct rmod {
     }
   }
 };
-MSHADOW_OP_DECLARE_TUNABLE_FWD_AND_BWD(rmod)
 
 template<>
 MSHADOW_XINLINE mshadow::half::half2_t rmod::Map<mshadow::half::half2_t>
@@ -567,8 +540,6 @@ struct rmod_grad {
     return DType(0);
   }
 };
-MSHADOW_OP_DECLARE_TUNABLE_BACKWARD(rmod_grad)
-
 template<>
 MSHADOW_XINLINE double rmod_grad::Map<double>(double a, double b) {
   return -::floor(b/a);
@@ -598,7 +569,7 @@ MSHADOW_XINLINE mshadow::half::half2_t rmod_grad::Map<mshadow::half::half2_t>
 #endif
 }
 
-struct clip {
+struct clip : public mxnet_op::tunable {
   template<typename DType>
   MSHADOW_XINLINE static DType Map(DType x, DType bound) {
     if (x > bound) {
@@ -610,13 +581,12 @@ struct clip {
     }
   }
 };
-MSHADOW_OP_DECLARE_TUNABLE_FWD_AND_BWD(clip)
 
 /***** gamma ******/
 
 MXNET_UNARY_MATH_OP(gamma, math::tgamma(a));
 
-struct gamma_grad {
+struct gamma_grad : public mxnet_op::tunable {
   template<typename DType>
   MSHADOW_XINLINE static DType Map(DType a) {
     // default implementation using floating precision
@@ -624,7 +594,6 @@ struct gamma_grad {
     return DType(math::tgamma(af) * special_functions::cephes::psi<float>(af));
   }
 };
-MSHADOW_OP_DECLARE_TUNABLE_BACKWARD(gamma_grad)
 
 template<>
 MSHADOW_XINLINE double gamma_grad::Map<double>(double a) {
@@ -635,14 +604,13 @@ MSHADOW_XINLINE double gamma_grad::Map<double>(double a) {
 
 MXNET_UNARY_MATH_OP(gammaln, math::lgamma(a));
 
-struct gammaln_grad {
+struct gammaln_grad : public mxnet_op::tunable {
   template<typename DType>
   MSHADOW_XINLINE static DType Map(DType a) {
     // default implementation using floating precision
     return DType(special_functions::cephes::psi<float>(a));
   }
 };
-MSHADOW_OP_DECLARE_TUNABLE_BACKWARD(gammaln_grad)
 
 template<>
 MSHADOW_XINLINE double gammaln_grad::Map<double>(double a) {
@@ -658,7 +626,7 @@ MSHADOW_XINLINE double gammaln_grad::Map<double>(double a) {
  * smooth_l1_loss = w_out * f(w_in * x)
  * with w_in, w_out provided by input_data.
  */
-struct smooth_l1_loss {
+struct smooth_l1_loss : public mxnet_op::tunable {
   // a is x, b is sigma
   template<typename DType>
   MSHADOW_XINLINE static DType Map(DType a, DType b) {
@@ -674,13 +642,12 @@ struct smooth_l1_loss {
     }
   }
 };  // struct smooth_l1_loss
-MSHADOW_OP_DECLARE_TUNABLE_FWD_AND_BWD(smooth_l1_loss)
 
 /* The derivative of smooth l1 loss is
  * f'(x) = sigma^2 * x, |x| < 1 / sigma^2
  *       = sign(x),     otherwise
  */
-struct smooth_l1_gradient {
+struct smooth_l1_gradient : public mxnet_op::tunable {
   // a is x, b is sigma2
   template<typename DType>
   MSHADOW_XINLINE static DType Map(DType a, DType b) {
@@ -696,7 +663,6 @@ struct smooth_l1_gradient {
     }
   }
 };  // struct smooth_l1_derivative
-MSHADOW_OP_DECLARE_TUNABLE_BACKWARD(smooth_l1_gradient)
 
 /*! \brief product reducer */
 struct product {
@@ -792,13 +758,12 @@ struct nansum {
   }
 };
 
-struct nansum_grad {
+struct nansum_grad : public mxnet_op::tunable {
   template<typename DType>
   MSHADOW_XINLINE static DType Map(DType a, DType b) {
     return isnan_typed::IsNan(a) ? DType(0) : DType(1);
   }
 };
-MSHADOW_OP_DECLARE_TUNABLE_BACKWARD(nansum_grad)
 
 /*! \brief product reducer that ignores NaN values in the input */
 struct nanprod {
@@ -829,13 +794,13 @@ struct nanprod {
   }
 };
 
-struct nanprod_grad {
+struct nanprod_grad : public mxnet_op::tunable {
   template<typename DType>
   MSHADOW_XINLINE static DType Map(DType a, DType b) {
     return isnan_typed::IsNan(a) ? DType(0) : b / a;
   }
 };
-MSHADOW_OP_DECLARE_TUNABLE_BACKWARD(nanprod_grad)
+
 }  // namespace mshadow_op
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/mxnet_op.h b/src/operator/mxnet_op.h
index 1d47943082..15ad59f552 100644
--- a/src/operator/mxnet_op.h
+++ b/src/operator/mxnet_op.h
@@ -85,7 +85,7 @@ inline int get_num_threads<gpu>(const int N) {
 
 template<>
 inline int get_num_threads<cpu>(const int N) {
-  return omp_get_max_threads();
+  return engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
 }
 
 /*! \brief operator request type switch */
@@ -291,6 +291,12 @@ struct backward_grad {
   }
 };
 
+/*! \brief Binary op backward gradient OP wrapper (tuned) */
+template<typename GRAD_OP>
+struct backward_grad_tuned : public backward_grad<GRAD_OP>, public tunable {
+  using backward_grad<GRAD_OP>::Map;
+};
+
 /*! \brief Select assignment operation based upon the req value
  * Also useful for mapping mshadow Compute (F<OP>) to Kernel<OP>::Launch
  */
@@ -360,11 +366,10 @@ struct Kernel<OP, cpu> {
    * operator_tune.cc
    * \tparam Args Varargs type to eventually pass to the OP::Map() functoion
    * \param N Number of iterations
-   * \param dest Destination pointer (used to infer DType)
    * \param args Varargs to eventually pass to the OP::Map() functoion
    */
   template<typename ...Args>
-  inline static void Launch(mshadow::Stream<cpu> *, const int N, Args... args) {
+  inline static bool Launch(mshadow::Stream<cpu> *, const int N, Args... args) {
 #ifdef _OPENMP
     const int omp_threads = engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
     if (omp_threads < 2) {
@@ -382,6 +387,7 @@ struct Kernel<OP, cpu> {
       OP::Map(i, args...);
     }
 #endif
+    return true;
   }
 
   /*!
@@ -441,8 +447,46 @@ struct Kernel<OP, cpu> {
     OP::Map(0, N, args...);
 #endif
   }
+
+  /*!
+   * \brief Launch a tunable OP with implicitly-supplied data type
+   * \tparam DType Data type
+   * \tparam T OP type
+   * \tparam Args Varargs type to eventually pass to the OP::Map() functoion
+   * \param s Stream (usually null for CPU)
+   * \param N Number of iterations
+   * \param args Varargs to eventually pass to the OP::Map() functoion
+   * \return Always true
+   */
+  template<typename DType, typename T = OP, typename ...Args>
+  static MSHADOW_CINLINE
+  typename std::enable_if<std::is_base_of<tunable, T>::value, bool>::type
+  Launch(mshadow::Stream<cpu> *s, const int N, DType *dest, Args... args) {
+    LaunchTuned<T, DType>(s, N, dest, args...);
+    return true;
+  }
+
+  /*!
+   * \brief Launch a tunable OP wrapper with explicitly-supplied data type (ie op_with_req)
+   * \tparam DType Data type
+   * \tparam T Wrapper type
+   * \tparam Args Varargs type to eventually pass to the OP::Map() functoion
+   * \param s Stream (usually null for CPU)
+   * \param N Number of iterations
+   * \param args Varargs to eventually pass to the OP::Map() functoion
+   * \return Always true
+   */
+  template<typename DType, typename T = OP, typename ...Args>
+  static MSHADOW_CINLINE
+  typename std::enable_if<std::is_base_of<tunable, typename T::Operation>::value, bool>::type
+  Launch(mshadow::Stream<cpu> *s, const int N, DType *dest, Args... args) {
+    LaunchTuned<typename T::Operation, DType>(s, N, dest, args...);
+    return true;
+  }
 };
 
+
+
 #ifdef __CUDACC__
 template<typename OP, typename ...Args>
 __global__ void mxnet_generic_kernel(int N, Args... args) {
@@ -481,49 +525,12 @@ struct Kernel<OP, gpu> {
 };
 #endif  // __CUDACC__
 
-/*!
- * \brief Wrap Kernel<OP, xpu>::Launch* with some special-case helpers
- */
-template<typename OP, typename xpu>
-struct KernelWrapper {
-  /*!
-   * \brief Launch 'mshadow_op-type' op (i.e. DType (*)( ... ) { return <operation> }
-   * \tparam Args Varargs type to eventually pass to the OP::Map() function
-   * \param s Stream object pointer (unused)
-   * \param N Number of iterations
-   * \param args Varargs to eventually pass to the OP::Map() functoion
-   */
-  template<typename DType, typename ...Args>
-  MSHADOW_CINLINE static void LaunchMShadowOpEx(mshadow::Stream<xpu> *s,
-                                                const int N,
-                                                DType *dest,
-                                                Args... args) {
-    mxnet::op::mxnet_op::Kernel<OP, xpu>::template LaunchTuned<
-      typename OP::Operation, DType>(s, N, dest, args...);
-  }
-
-  /*!
-   * \brief Launch 'mxnet_op-type' op (i.e. void (*)(int N, DType *out, ... )
-   * \tparam Args Varargs type to eventually pass to the OP::Map() function
-   * \param s Stream object pointer (unused)
-   * \param N Number of iterations
-   * \param args Varargs to eventually pass to the OP::Map() functoion
-   */
-  template<typename DType, typename ...Args>
-  MSHADOW_CINLINE static void LaunchMXNetOpEx(mshadow::Stream<xpu> *s,
-                                              const int N,
-                                              DType *dest,
-                                              Args... args) {
-    mxnet::op::mxnet_op::Kernel<OP, xpu>::template LaunchTuned<OP, DType>(s, N, dest, args...);
-  }
-};
-
 /*!
  * \brief Set to immediate scalar value kernel
  * \tparam val Scalar immediate
  */
 template<int val>
-struct set_to_int {
+struct set_to_int : public tunable {
   // mxnet_op version (when used directly with Kernel<>::Launch()) */
   template<typename DType>
   MSHADOW_XINLINE static void Map(int i, DType *out) {
@@ -540,23 +547,8 @@ struct set_to_int {
  */
 using set_zero = set_to_int<0>;
 using set_one  = set_to_int<1>;
-_MXNET_TUNABLE_MXNET_OP_FWD(set_zero);  // _ prefix denotes "already in mxnet_op namespace"
-_MXNET_TUNABLE_MXNET_OP_FWD(set_one);
 }  // namespace mxnet_op
 
-/*!
- * \brief Tuning specializations for the simple ops in <mshadow/base.h>
- *        Basically, this overrides mxnet::op::mxnet_op::Kernel<OP, cpu>::Launch() and
- *        redirects to mxnet::op::mxnet_op::KernelWrapper<OP, cpu>::Launch????OpEx(),
- *        which eventually leads back to mxnet::op::mxnet_op::Kernel<OP, cpu>::LaunchTuned()
- */
-MXNET_TUNABLE_MSHADOW_OP_FWD_AND_BWD(mshadow::op::identity)
-MXNET_TUNABLE_MSHADOW_OP_FWD_AND_BWD(mshadow::op::plus)
-MXNET_TUNABLE_MSHADOW_OP_FWD_AND_BWD(mshadow::op::minus)
-MXNET_TUNABLE_MSHADOW_OP_FWD_AND_BWD(mshadow::op::mul)
-MXNET_TUNABLE_MSHADOW_OP_FWD_AND_BWD(mshadow::op::div)
-MXNET_TUNABLE_MSHADOW_OP_FWD_AND_BWD(mshadow::op::right)
-
 }  // namespace op
 }  // namespace mxnet
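
Two pieces of this file carry most of the design. backward_grad_tuned grafts the tunable tag onto the existing backward_grad wrapper via multiple inheritance, so gradient wrappers become tunable without duplicating any logic; and the pair of enable_if-gated Launch overloads routes tagged ops (or wrappers whose Operation typedef is tagged) into LaunchTuned. A hedged sketch of the inheritance trick in isolation, with a simplified Map signature:

    struct tunable {};

    template <typename GRAD_OP>
    struct backward_grad {
      // chain rule: dL/dx = dL/dy * f'(x)
      template <typename DType>
      static DType Map(DType ograd, DType input) {
        return DType(ograd * GRAD_OP::Map(input));
      }
    };

    template <typename GRAD_OP>
    struct backward_grad_tuned : public backward_grad<GRAD_OP>, public tunable {
      using backward_grad<GRAD_OP>::Map;  // re-export Map; only the tag is new
    };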
 
diff --git a/src/operator/nn/softmax.cc b/src/operator/nn/softmax.cc
index e804c67c07..4686fb8c0d 100644
--- a/src/operator/nn/softmax.cc
+++ b/src/operator/nn/softmax.cc
@@ -59,7 +59,7 @@ Example::
 
 MXNET_OPERATOR_REGISTER_BINARY(_backward_softmax)
 .set_attr_parser(ParamParser<SoftmaxParam>)
-.set_attr<FCompute>("FCompute<cpu>", SoftmaxGradCompute<cpu, mshadow::op::mul,
+.set_attr<FCompute>("FCompute<cpu>", SoftmaxGradCompute<cpu, op::mshadow_op::mul,
                                                         mxnet_op::softmax_bwd>);
 
 MXNET_OPERATOR_REGISTER_UNARY(log_softmax)
diff --git a/src/operator/nn/softmax.cu b/src/operator/nn/softmax.cu
index 4b9c04cdbe..8274642c81 100644
--- a/src/operator/nn/softmax.cu
+++ b/src/operator/nn/softmax.cu
@@ -32,7 +32,7 @@ NNVM_REGISTER_OP(softmax)
 .set_attr<FCompute>("FCompute<gpu>", SoftmaxCompute<gpu, mxnet_op::softmax_fwd>);
 
 NNVM_REGISTER_OP(_backward_softmax)
-.set_attr<FCompute>("FCompute<gpu>", SoftmaxGradCompute<gpu, mshadow::op::mul,
+.set_attr<FCompute>("FCompute<gpu>", SoftmaxGradCompute<gpu, op::mshadow_op::mul,
                                                         mxnet_op::softmax_bwd>);
 
 NNVM_REGISTER_OP(log_softmax)
diff --git a/src/operator/operator_tune-inl.h b/src/operator/operator_tune-inl.h
index d0cf7e7139..4a2025b94a 100644
--- a/src/operator/operator_tune-inl.h
+++ b/src/operator/operator_tune-inl.h
@@ -635,7 +635,7 @@ class UnaryOpTune : public OperatorTune<DType> {
    */
   template<typename OP>
   static void TuneBlankOperator() {
-    mxnet::op::mxnet_op::tuned_op<OP, DType>::workload_ = GetBlankWorkload<OP>();
+    mxnet::op::mxnet_op::tuned_op<OP, DType>::workload_[0] = GetBlankWorkload<OP>();
     if (Super::output_tuning_data_) {
       std::cout << "IMPLEMENT_UNARY_WORKLOAD_FWD("
                 << Super::template type_name<OP>()
@@ -651,7 +651,7 @@ class UnaryOpTune : public OperatorTune<DType> {
    */
   template<typename OP>
   static void TuneUnaryOperator() {
-    mxnet::op::mxnet_op::tuned_op<OP, DType>::workload_ = GetUnaryWorkload<OP>();
+    mxnet::op::mxnet_op::tuned_op<OP, DType>::workload_[0] = GetUnaryWorkload<OP>();
     if (Super::output_tuning_data_) {
       std::cout << "IMPLEMENT_UNARY_WORKLOAD_FWD("
                 << Super::template type_name<OP>()
@@ -667,8 +667,8 @@ class UnaryOpTune : public OperatorTune<DType> {
    */
   template<typename OP>
   static void TuneUnaryBackwardOperator() {
-    mxnet::op::mxnet_op::tuned_op<mxnet_op::backward_grad<OP>, DType>::workload_ =
-      GetBinaryWorkload<mxnet::op::mxnet_op::backward_grad<OP>>();
+    mxnet::op::mxnet_op::tuned_op<mxnet_op::backward_grad_tuned<OP>, DType>::workload_[0] =
+      GetBinaryWorkload<mxnet::op::mxnet_op::backward_grad_tuned<OP>>();
     if (Super::output_tuning_data_) {
       std::cout << "IMPLEMENT_UNARY_WORKLOAD_BWD("
                 << Super::template type_name<OP>()
@@ -685,7 +685,7 @@ class UnaryOpTune : public OperatorTune<DType> {
    */
   template<typename OP>
   static void TuneBlankOperatorEx() {
-    mxnet::op::mxnet_op::tuned_op<OP, DType>::workload_ = GetBlankWorkloadEx<OP>();
+    mxnet::op::mxnet_op::tuned_op<OP, DType>::workload_[0] = GetBlankWorkloadEx<OP>();
     if (Super::output_tuning_data_) {
       std::cout << "IMPLEMENT_BLANK_WORKLOAD_FWD("
                 << Super::template type_name<OP>()
@@ -696,7 +696,7 @@ class UnaryOpTune : public OperatorTune<DType> {
   /*!
    * \brief Determine whether to use OMP based upon both timing and configuration using the
    *        given (templated) operator's workload
-   * \tparam OP Operator whose workload to use (tuned_op::workload_)
+   * \tparam OP Operator whose workload to use (tuned_op::workload_[0])
    * \param N Number of iterations desired
    * \param thread_count Number of OMP threads available to perform the iterations
    * \returns Whether it's faster to use OMP for these iterations
@@ -705,7 +705,7 @@ class UnaryOpTune : public OperatorTune<DType> {
   inline static bool UseOMP(size_t N, size_t thread_count) {
       return OperatorTune<DType>::UseOMP(N,
                                          thread_count,
-                                         static_cast<uint64_t>(N) * OP::workload_);
+                                         static_cast<uint64_t>(N) * OP::workload_[0]);
   }
 };
 
@@ -725,7 +725,7 @@ class BinaryOpTune : public UnaryOpTune<DType> {
    */
   template<typename OP>
   static void TuneBinaryOperator() {
-    mxnet_op::tuned_op<OP, DType>::workload_ = Super::template GetBinaryWorkload<OP>();
+    mxnet_op::tuned_op<OP, DType>::workload_[0] = Super::template GetBinaryWorkload<OP>();
     if (Super::Super::output_tuning_data_) {
       std::cout << "IMPLEMENT_BINARY_WORKLOAD_FWD("
                 << Super::template type_name<OP>()
@@ -739,8 +739,8 @@ class BinaryOpTune : public UnaryOpTune<DType> {
    */
   template<typename OP>
   static void TuneBinaryBackwardOperator() {
-    mxnet::op::mxnet_op::tuned_op<mxnet_op::backward_grad<OP>, DType>::workload_ =
-      Super::template GetTertiaryWorkload<mxnet::op::mxnet_op::backward_grad<OP>>();
+    mxnet::op::mxnet_op::tuned_op<mxnet_op::backward_grad_tuned<OP>, DType>::workload_[0] =
+      Super::template GetTertiaryWorkload<mxnet::op::mxnet_op::backward_grad_tuned<OP>>();
     if (Super::Super::output_tuning_data_) {
       std::cout << "IMPLEMENT_BINARY_WORKLOAD_BWD("
                 << Super::template type_name<OP>()
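
The mechanical change throughout this file, workload_ becoming workload_[0], reflects the scalar workload turning into a std::vector<float> (see the tuned_op hunk in operator_tune.h below). A sketch of the decision the tuner makes with that value; the threshold semantics here are an assumption for illustration, and the real logic lives in OperatorTune<DType>::UseOMP:

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    const uint64_t kOmpOverheadNs = 5000;  // mirrors omp_overhead_ns_ below

    inline bool UseOMPSketch(size_t N, size_t thread_count,
                             const std::vector<float>& workload) {
      // workload[0]: measured cost (ns) of a reference batch of this op
      const auto serial_ns = static_cast<uint64_t>(N * workload[0]);
      // Parallelize only when the serial cost dwarfs the fork/join overhead.
      return thread_count > 1 && serial_ns > kOmpOverheadNs * thread_count;
    }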
diff --git a/src/operator/operator_tune.cc b/src/operator/operator_tune.cc
index 525a66b6f8..8263c44eb4 100644
--- a/src/operator/operator_tune.cc
+++ b/src/operator/operator_tune.cc
@@ -1,4 +1,3 @@
-
 /*
  * Licensed to the Apache Software Foundation (ASF) under one
  * or more contributor license agreements.  See the NOTICE file
@@ -17,6 +16,7 @@
  * specific language governing permissions and limitations
  * under the License.
  */
+#include <float.h>
 #include <atomic>
 #include "./mxnet_op.h"
 #include "./mshadow_op.h"
@@ -84,27 +84,25 @@ struct static_init_var {
   __macro$(__VA_ARGS__, int32_t); \
   __macro$(__VA_ARGS__, int64_t);
 
-
 #define IMPLEMENT_WORKLOAD_VALUE_FOR_TYPE(__op$, __typ$) \
   namespace mxnet_op { \
-  template<> size_t mxnet::op::mxnet_op::tuned_op<__op$, __typ$>::workload_ = INT_MAX / 4; \
-  template<> std::vector<float> mxnet::op::mxnet_op::tuned_op<__op$, __typ$>::workload_ex_ = {}; \
+  template<> std::vector<float> mxnet::op::mxnet_op::tuned_op<__op$, __typ$>::workload_ = \
+    { static_cast<float>(INT_MAX >> 3) }; \
   }  /* namespace mxnet_op */
-
 /*!
  * \brief Implement tuning objects for a forward blank (no arguments) kernel operator
  */
 #define _IMPLEMENT_BLANK_WORKLOAD_FWD(__op$, __typ$) \
   IMPLEMENT_WORKLOAD_VALUE_FOR_TYPE(__op$, __typ$); \
   namespace mxnet_op { \
-  template<> bool mxnet::op::mxnet_op::tuned_op<__op$, __typ$>::UseOMP( \
+  template<> bool ::mxnet::op::mxnet_op::tuned_op<__op$, __typ$>::UseOMP( \
     size_t N, size_t omp_threads) { \
-    return mxnet::op::UnaryOpTune<__typ$>::UseOMP<mxnet_op::tuned_op<__op$, __typ$>>( \
+    return ::mxnet::op::UnaryOpTune<__typ$>::UseOMP<mxnet_op::tuned_op<__op$, __typ$>>( \
       N, omp_threads); \
   }}  /* namespace mxnet_op */ \
   template<> bool static_init_var<__op$, __typ$>::init_ = \
-    mxnet::op::OperatorTune<__typ$>::ScheduleTune<__op$>( \
-      mxnet::op::UnaryOpTune<__typ$>::TuneBlankOperatorEx<__op$>)
+    ::mxnet::op::OperatorTune<__typ$>::ScheduleTune<__op$>( \
+      ::mxnet::op::UnaryOpTune<__typ$>::TuneBlankOperatorEx<__op$>)
 
 /*!
  * \brief Implement tuning objects for a forward unary kernel operator
@@ -112,30 +110,30 @@ struct static_init_var {
 #define _IMPLEMENT_UNARY_WORKLOAD_FWD(__op$, __typ$) \
   IMPLEMENT_WORKLOAD_VALUE_FOR_TYPE(__op$, __typ$); \
   namespace mxnet_op { \
-  template<> bool mxnet::op::mxnet_op::tuned_op<__op$, __typ$>::UseOMP( \
+  template<> bool ::mxnet::op::mxnet_op::tuned_op<__op$, __typ$>::UseOMP( \
     size_t N, size_t omp_threads) { \
-    return mxnet::op::UnaryOpTune<__typ$>::UseOMP<mxnet_op::tuned_op<__op$, __typ$>>( \
+    return ::mxnet::op::UnaryOpTune<__typ$>::UseOMP<mxnet_op::tuned_op<__op$, __typ$>>( \
       N, omp_threads); \
   }}  /* namespace mxnet_op */ \
   template<> bool static_init_var<__op$, __typ$>::init_ = \
-    mxnet::op::OperatorTune<__typ$>::ScheduleTune<__op$>( \
-      mxnet::op::UnaryOpTune<__typ$>::TuneUnaryOperator<__op$>)
+    ::mxnet::op::OperatorTune<__typ$>::ScheduleTune<__op$>( \
+      ::mxnet::op::UnaryOpTune<__typ$>::TuneUnaryOperator<__op$>)
 
 /*!
  * \brief Implement tuning objects for a backward unary kernel operator
  */
 #define _IMPLEMENT_UNARY_WORKLOAD_BWD(__op$, __typ$) \
-  IMPLEMENT_WORKLOAD_VALUE_FOR_TYPE(mxnet::op::mxnet_op::backward_grad<__op$>, __typ$); \
+  IMPLEMENT_WORKLOAD_VALUE_FOR_TYPE(::mxnet::op::mxnet_op::backward_grad_tuned<__op$>, __typ$); \
   namespace mxnet_op { \
   template<> \
-  bool mxnet::op::mxnet_op::tuned_op<mxnet::op::mxnet_op::backward_grad<__op$>, __typ$>::UseOMP( \
-    size_t N, size_t omp_threads) { \
-    return mxnet::op::UnaryOpTune<__typ$>::UseOMP<mxnet_op::tuned_op< \
-      mxnet::op::mxnet_op::backward_grad<__op$>, __typ$>>(N, omp_threads); \
+  bool ::mxnet::op::mxnet_op::tuned_op<::mxnet::op::mxnet_op::backward_grad_tuned<__op$>, __typ$>::\
+    UseOMP(size_t N, size_t omp_threads) { \
+    return ::mxnet::op::UnaryOpTune<__typ$>::UseOMP<mxnet_op::tuned_op< \
+      ::mxnet::op::mxnet_op::backward_grad_tuned<__op$>, __typ$>>(N, omp_threads); \
   }}  /* namespace mxnet_op */ \
-  template<> bool static_init_var<mxnet::op::mxnet_op::backward_grad<__op$>, __typ$>::init_ = \
-    mxnet::op::OperatorTune<__typ$>::ScheduleTune<__op$>( \
-      mxnet::op::UnaryOpTune<__typ$>::TuneUnaryBackwardOperator<__op$>)
+  template<> bool static_init_var<::mxnet::op::mxnet_op::backward_grad_tuned<__op$>, __typ$>:: \
+    init_ = ::mxnet::op::OperatorTune<__typ$>::ScheduleTune<__op$>( \
+      ::mxnet::op::UnaryOpTune<__typ$>::TuneUnaryBackwardOperator<__op$>)
 
 /*!
  * \brief Implement tuning objects for a forward binary kernel operator
@@ -143,30 +141,32 @@ struct static_init_var {
 #define _IMPLEMENT_BINARY_WORKLOAD_FWD(__op$, __typ$) \
   IMPLEMENT_WORKLOAD_VALUE_FOR_TYPE(__op$, __typ$); \
   namespace mxnet_op { \
-  template<> bool mxnet::op::mxnet_op::tuned_op<__op$, __typ$>::UseOMP( \
+  template<> bool ::mxnet::op::mxnet_op::tuned_op<__op$, __typ$>::UseOMP( \
     size_t N, size_t omp_threads) { \
-    return mxnet::op::BinaryOpTune<__typ$>::UseOMP<mxnet_op::tuned_op<__op$, __typ$>>( \
+    return ::mxnet::op::BinaryOpTune<__typ$>::UseOMP<mxnet_op::tuned_op<__op$, __typ$>>( \
       N, omp_threads); \
   }}  /* namespace mxnet_op */ \
   template<> bool static_init_var<__op$, __typ$>::init_ = \
-    mxnet::op::OperatorTune<__typ$>::ScheduleTune<__op$>( \
-      mxnet::op::BinaryOpTune<__typ$>::TuneBinaryOperator<__op$>)
+    ::mxnet::op::OperatorTune<__typ$>::ScheduleTune<__op$>( \
+      ::mxnet::op::BinaryOpTune<__typ$>::TuneBinaryOperator<__op$>)
 
 /*!
  * \brief Implement tuning objects for a backward binary kernel operator
  */
 #define _IMPLEMENT_BINARY_WORKLOAD_BWD(__op$, __typ$) \
-  IMPLEMENT_WORKLOAD_VALUE_FOR_TYPE(mxnet::op::mxnet_op::backward_grad<__op$>, __typ$); \
+  IMPLEMENT_WORKLOAD_VALUE_FOR_TYPE(::mxnet::op::mxnet_op::backward_grad_tuned<__op$>, __typ$); \
   namespace mxnet_op { \
   template<> \
-    bool mxnet::op::mxnet_op::tuned_op<mxnet::op::mxnet_op::backward_grad<__op$>, __typ$>::UseOMP( \
-    size_t N, size_t omp_threads) { \
-    return mxnet::op::BinaryOpTune<__typ$>::UseOMP<mxnet_op::tuned_op< \
-      mxnet::op::mxnet_op::backward_grad<__op$>, __typ$>>(N, omp_threads); \
+    bool ::mxnet::op::mxnet_op::tuned_op< \
+      ::mxnet::op::mxnet_op::backward_grad_tuned<__op$>, __typ$>:: \
+      UseOMP(size_t N, size_t omp_threads) { \
+    return ::mxnet::op::BinaryOpTune<__typ$>::UseOMP<mxnet_op::tuned_op< \
+      ::mxnet::op::mxnet_op::backward_grad_tuned<__op$>, __typ$>>(N, omp_threads); \
   }}  /* namespace mxnet_op */ \
-  template<> bool static_init_var<mxnet::op::mxnet_op::backward_grad<__op$>, __typ$>::init_ = \
-    mxnet::op::OperatorTune<__typ$>::ScheduleTune<__op$>(  \
-      mxnet::op::BinaryOpTune<__typ$>::TuneBinaryBackwardOperator<__op$>)
+  template<> bool static_init_var<::mxnet::op::mxnet_op::backward_grad_tuned<__op$>, \
+    __typ$>::init_ = \
+    ::mxnet::op::OperatorTune<__typ$>::ScheduleTune<__op$>(  \
+      ::mxnet::op::BinaryOpTune<__typ$>::TuneBinaryBackwardOperator<__op$>)
 
 /*!
  * \brief Implement tuning objects for a custom forward kernel operator
@@ -174,7 +174,7 @@ struct static_init_var {
 #define _IMPLEMENT_CUSTOM_WORKLOAD_FWD(__op$, __typ$) \
   IMPLEMENT_WORKLOAD_VALUE_FOR_TYPE(__op$<__typ$>, __typ$); \
   template<> bool static_init_var<__op$<__typ$>, __typ$>::init_ = \
-    mxnet::op::OperatorTune<__typ$>::ScheduleTune<__op$<__typ$>>(\
+    ::mxnet::op::OperatorTune<__typ$>::ScheduleTune<__op$<__typ$>>(\
       __op$<__typ$>::Tune)
 
 /*!
@@ -206,7 +206,6 @@ struct static_init_var {
  *       integer value
  */
 OperatorTuneBase::duration_t OperatorTuneBase::omp_overhead_ns_ = 5000;
-IMPLEMENT_UNARY_WORKLOAD_FWD(mshadow::op::identity);  // NOLINT()
 IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::identity);  // NOLINT()
 IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::identity_grad);  // NOLINT()
 IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::negation);  // NOLINT()
@@ -281,12 +280,19 @@ IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::degrees);  // NOLINT()
 IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::degrees_grad);  // NOLINT()
 IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::radians);  // NOLINT()
 IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::radians_grad);  // NOLINT()
-IMPLEMENT_BINARY_WORKLOAD_FWD(mshadow::op::plus);  // NOLINT()
-IMPLEMENT_BINARY_WORKLOAD_FWD(mshadow::op::minus);  // NOLINT()
-IMPLEMENT_BINARY_WORKLOAD_FWD(mshadow::op::mul);  // NOLINT()
-IMPLEMENT_BINARY_WORKLOAD_FWD(mshadow::op::div);  // NOLINT()
-IMPLEMENT_BINARY_WORKLOAD_FWD(mshadow::op::right);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::clip);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::clip);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::plus);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::minus);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::mul);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::div);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rminus);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::rdiv);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::plus);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::minus);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::mul);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::div);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::rminus);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rdiv);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::div_grad);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::div_grad);  // NOLINT()
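
Each IMPLEMENT_*_WORKLOAD_* line above expands to, among other things, a template specialization whose static bool initializer schedules a tuning function before main() runs, once per (op, dtype) pair. A hedged, stripped-down sketch of that registration pattern (names here are illustrative):

    #include <functional>
    #include <vector>

    std::vector<std::function<void()>>& tune_queue() {
      static std::vector<std::function<void()>> q;
      return q;  // filled during static initialization, drained at startup
    }

    inline bool schedule_tune(void (*fn)()) {
      tune_queue().push_back(fn);
      return true;
    }

    template <typename Op, typename DType>
    struct static_init_var { static bool init_; };

    struct my_op {};                  // stand-in for an mshadow_op operator
    void tune_my_op_float() { /* time the op, store result in workload_[0] */ }

    // Roughly what one macro expansion boils down to per (op, dtype) pair:
    template <>
    bool static_init_var<my_op, float>::init_ = schedule_tune(&tune_my_op_float);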
diff --git a/src/operator/operator_tune.h b/src/operator/operator_tune.h
index 622f0af9ff..6e73ed3711 100644
--- a/src/operator/operator_tune.h
+++ b/src/operator/operator_tune.h
@@ -24,6 +24,36 @@
 #include <vector>
 #include <set>
 #include <atomic>
+#include <string>
+
+// #define MXNET_DEBUG_TUNING_LAUNCH
+
+#ifdef MXNET_DEBUG_TUNING_LAUNCH
+#include <cxxabi.h>
+template<typename T> inline std::string type_name() {
+  const char *name = typeid(T).name();
+  int status = -4;  // some arbitrary value to eliminate the compiler warning
+  std::unique_ptr<char, void (*)(void *)> res {
+    abi::__cxa_demangle(name, nullptr, nullptr, &status),
+    &std::free
+  };
+  if (!status) {
+    return res.get();
+  }
+  return std::move(name);
+}
+#define MXNET_DEBUG_PRINT_UNIQUE_OP(__label$, __op$) \
+  { \
+    static std::mutex cs; \
+    static std::unordered_set<std::string> ops; \
+    const std::string name = type_name<__op$>(); \
+    if (ops.emplace(name).second) { \
+      std::cout << (__label$) << ": " << name << std::endl << std::flush; \
+    } \
+  }
+#else
+#define MXNET_DEBUG_PRINT_UNIQUE_OP(__label$, __op$) /* */
+#endif
 
 namespace mxnet {
 namespace op {
@@ -191,20 +221,15 @@ namespace mxnet_op {
  */
 template<typename Operation, typename DType>
 struct tuned_op : public Operation {
-  /*! \brief nanoseconds to perform WORKLOAD_COUNT operations
-   *  \note It is conceivable that a vector of values could be used for more complex tuning,
-   *        but the need hasn't yet arisen
+  /*! \brief Runtime workload calculation values. Generally, nanoseconds to perform WORKLOAD_COUNT
+   *        operations (for unary and binary ops), although they can be anything if the UseOMP()
+   *        function is written elsewhere for that op (other than in operator_tune-inl.h)
    *  \remarks This variable generally needs to be implemented somewhere.  Currently this is mostly
    *           done via macros in operator_tune.cc.  If you get undefined reference errors when
    *           linking, then try to use one of the macros in that file to instantiate the required
    *           data/functions
    */
-  static size_t workload_;
-
-  /*!
-   * \brief Extra workload-calculating information (ie times for sub-portions of the calculation)
-   */
-  static std::vector<float> workload_ex_;
+  static std::vector<float> workload_;
 
   /*!
    * \brief Calls parent class (Operation)'s UseOMP
@@ -231,7 +256,6 @@ struct tuned_op : public Operation {
    */
   static bool UseOMP(size_t N, size_t thread_count);
 };
-}  // namespace mxnet_op
 
 /*!
  * \brief Calculate workload for a given lambda function
@@ -253,78 +277,9 @@ inline int64_t get_workload(Function function) {
   return *++durations.begin();  // return median value
 }
 
-/*!
- * \brief Declare a template specialization for the Kernel::Launch call for the given OP
- *        wrapped with mxnet_op::op_with_req, using the given OpReqType as the 'req'
- *        template parameter for 'op_with_req'.  This is useful for the standard mshadow_op
- *        operators which need to be wrapped with op_with_req in order to be used with the
- *        Kernel::Launch command.
- *
- * \note Expects to be used within the mxnet::op namespace
- *
- * For example:
- *
- * namespace mxnet_op {
- * template <>
- * template <typename... Args>
- * inline void Kernel<typename mxnet_op::op_with_req<mshadow::op::identity, kNullOp>, cpu>
- *   ::Launch(mshadow::Stream<cpu>* s, const int N, Args... args) {
- *   ::mxnet::op::mxnet_op::Kernel<typename mxnet_op::op_with_req<mshadow::op::identity, kNullOp>,
- *     cpu>::LaunchMShadowOpEx(s, N, args...);
- *   }
- * }
- *
- */
-#define MXNET_TUNABLE_MSHADOW_OP_WITH_REQ(__op$, __req$) \
-  namespace mxnet_op { \
-  template<> template<typename ...Args> \
-  inline void Kernel<typename mxnet_op::op_with_req<__op$, __req$>, ::mshadow::cpu>:: \
-    Launch(mshadow::Stream<::mshadow::cpu> *s, const int N, Args... args) { \
-      /* Launch via LaunchMShadowOpEx() */ \
-      KernelWrapper<typename mxnet_op::op_with_req<__op$, __req$>, ::mshadow::cpu>:: \
-        LaunchMShadowOpEx(s, N, args...); \
-  } \
-  }  /* namespace mxnet_op */
-
-/*!
- * \brief Declare template specializations for the Kernel::Launch call for the given OP
- *        wrapped with mxnet_op::op_with_req, using the all supported OpReqType as the 'req'
- *        template parameter for 'op_with_req'.  This is useful for the standard mshadow_op
- *        operators which need to be wrapped with op_with_req in order to be used with the
- *        Kernel::Launch command.
- * \note Expects to be used within the mxnet::op namespace
- */
-#define MXNET_TUNABLE_MSHADOW_OP(__op$) \
-  MXNET_TUNABLE_MSHADOW_OP_WITH_REQ(__op$, kNullOp); \
-  MXNET_TUNABLE_MSHADOW_OP_WITH_REQ(__op$, kWriteTo); \
-  MXNET_TUNABLE_MSHADOW_OP_WITH_REQ(__op$, kWriteInplace); \
-  MXNET_TUNABLE_MSHADOW_OP_WITH_REQ(__op$, kAddTo);
-
-#define MXNET_TUNABLE_MSHADOW_OP_BACKWARD(__op$) \
-  MXNET_TUNABLE_MSHADOW_OP(mxnet::op::mxnet_op::backward_grad<__op$>)
-
-#define MXNET_TUNABLE_MSHADOW_OP_FWD_AND_BWD(__op$) \
-  MXNET_TUNABLE_MSHADOW_OP(__op$) \
-  MXNET_TUNABLE_MSHADOW_OP_BACKWARD(__op$)
-
-/*!
- * \brief mxnet::op::mxnet_op format ops (work directly with Kernel<>::Launch()
- *        Used from within mxnet::op::mxnet_op namespace
- */
-#define _MXNET_TUNABLE_MXNET_OP_FWD(__op$) \
-  template<> template<typename ...Args> inline void Kernel<__op$, ::mshadow::cpu>::Launch( \
-    mshadow::Stream<::mshadow::cpu> *s, const int N, Args... args) { \
-      /* Launch via LaunchMXNetOpEx() */ \
-      KernelWrapper<__op$, ::mshadow::cpu>::LaunchMXNetOpEx(s, N, args...); \
-  }
-
-/*!
- * \brief mxnet::op::mxnet_op format ops (work directly with Kernel<>::Launch()
- *        Used from within mxnet::op
- */
-#define MXNET_TUNABLE_MXNET_OP_FWD(__op$) \
-  namespace mxnet_op { _MXNET_TUNABLE_MXNET_OP_FWD(__op$) }  /* namespace mxnet_op */
+struct tunable {};
 
+}  // namespace mxnet_op
 }  // namespace op
 }  // namespace mxnet
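
tuned_op's workload_ vector gets its values from a timing helper; the context lines above ("return *++durations.begin();  // return median value") indicate a small sorted collection of timed runs whose median is kept. A sketch of that idea, assuming three runs (the actual run count is not shown in this hunk):

    #include <algorithm>
    #include <chrono>
    #include <cstdint>

    template <typename Function>
    int64_t GetWorkloadSketch(Function fn) {
      int64_t ns[3];
      for (int i = 0; i < 3; ++i) {
        const auto start = std::chrono::high_resolution_clock::now();
        fn();
        const auto stop = std::chrono::high_resolution_clock::now();
        ns[i] = std::chrono::duration_cast<std::chrono::nanoseconds>(
                    stop - start).count();
      }
      std::sort(ns, ns + 3);
      return ns[1];  // the median rejects a single outlier in either direction
    }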
 
diff --git a/src/operator/regression_output.cc b/src/operator/regression_output.cc
index 0c70a86b26..2f8042e9e8 100644
--- a/src/operator/regression_output.cc
+++ b/src/operator/regression_output.cc
@@ -33,11 +33,11 @@ Operator *CreateRegressionOutputOp<cpu>(reg_enum::RegressionOutputType type,
                                         RegressionOutputParam param) {
   switch (type) {
     case reg_enum::kLinear:
-      return new RegressionOutputOp<cpu, mshadow::op::identity, mshadow::op::minus>(param);
+      return new RegressionOutputOp<cpu, op::mshadow_op::identity, op::mshadow_op::minus>(param);
     case reg_enum::kLogistic:
-      return new RegressionOutputOp<cpu, mshadow_op::sigmoid, mshadow::op::minus>(param);
+      return new RegressionOutputOp<cpu, mshadow_op::sigmoid, op::mshadow_op::minus>(param);
     case reg_enum::kMAE:
-      return new RegressionOutputOp<cpu, mshadow::op::identity, mshadow_op::minus_sign>(param);
+      return new RegressionOutputOp<cpu, op::mshadow_op::identity, mshadow_op::minus_sign>(param);
     default:
       LOG(FATAL) << "unknown activation type " << type;
   }
diff --git a/src/operator/regression_output.cu b/src/operator/regression_output.cu
index 255b020d20..cb951f1fd2 100644
--- a/src/operator/regression_output.cu
+++ b/src/operator/regression_output.cu
@@ -33,11 +33,11 @@ Operator *CreateRegressionOutputOp<gpu>(reg_enum::RegressionOutputType type,
                                         RegressionOutputParam param) {
   switch (type) {
     case reg_enum::kLinear:
-      return new RegressionOutputOp<gpu, mshadow::op::identity, mshadow::op::minus>(param);
+      return new RegressionOutputOp<gpu, op::mshadow_op::identity, op::mshadow_op::minus>(param);
     case reg_enum::kLogistic:
-      return new RegressionOutputOp<gpu, mshadow_op::sigmoid, mshadow::op::minus>(param);
+      return new RegressionOutputOp<gpu, mshadow_op::sigmoid, op::mshadow_op::minus>(param);
     case reg_enum::kMAE:
-      return new RegressionOutputOp<gpu, mshadow::op::identity, mshadow_op::minus_sign>(param);
+      return new RegressionOutputOp<gpu, op::mshadow_op::identity, mshadow_op::minus_sign>(param);
     default:
       LOG(FATAL) << "unknown activation type " << type;
   }
diff --git a/src/operator/tensor/broadcast_reduce_op.h b/src/operator/tensor/broadcast_reduce_op.h
index 79f7c39c87..6e92c8ac91 100644
--- a/src/operator/tensor/broadcast_reduce_op.h
+++ b/src/operator/tensor/broadcast_reduce_op.h
@@ -428,7 +428,7 @@ void ReduceAxesComputeImpl(const nnvm::NodeAttrs& attrs,
           s, out_data, req[0], in_data);
       Tensor<xpu, 1, char> workspace =
           ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
-      broadcast::Reduce<reducer, NDim, DType, mshadow::op::identity>(
+      broadcast::Reduce<reducer, NDim, DType, op::mshadow_op::identity>(
           s, out_data, req[0], workspace, in_data);
       if (normalize) {
         auto out = out_data.FlatTo2D<xpu, DType>(s);
@@ -635,7 +635,7 @@ void SumCsrImpl(const nnvm::NodeAttrs& attrs, mshadow::Stream<xpu>* s, const OpC
                 seg_len);
             if (normalize) {
               mxnet_op::Kernel<
-                  mxnet_op::op_with_req<mshadow::op::div, req_type>,
+                  mxnet_op::op_with_req<op::mshadow_op::div, req_type>,
                   xpu>::Launch(s, out_data_size, output->data().dptr<DType>(),
                                output->data().dptr<DType>(), DType(num_rows));
             }
@@ -656,7 +656,7 @@ void SumCsrImpl(const nnvm::NodeAttrs& attrs, mshadow::Stream<xpu>* s, const OpC
                 in_data);
             if (normalize) {
               mxnet_op::Kernel<
-                  mxnet_op::op_with_req<mshadow::op::div, req_type>,
+                  mxnet_op::op_with_req<op::mshadow_op::div, req_type>,
                   xpu>::Launch(s, out_data_size, output->data().dptr<DType>(),
                                output->data().dptr<DType>(), DType(num_cols));
             }
diff --git a/src/operator/tensor/elemwise_binary_broadcast_op.h b/src/operator/tensor/elemwise_binary_broadcast_op.h
index 2317c98285..af5f5ce3af 100644
--- a/src/operator/tensor/elemwise_binary_broadcast_op.h
+++ b/src/operator/tensor/elemwise_binary_broadcast_op.h
@@ -137,22 +137,24 @@ inline int BinaryBroadcastShapeCompact(const TShape& lshape, const TShape& rshap
 namespace mxnet_op {
 template<int ndim, typename DType, typename OP>
 struct binary_broadcast_kernel {
+  /*! \brief Map function for binary_broadcast_kernel */
   MSHADOW_XINLINE static void Map(int base, int length, OpReqType req,
-                                  const Shape<ndim>& lstride, const Shape<ndim>& rstride,
-                                  const Shape<ndim>& oshape, DType* lhs, DType* rhs,
-                                  DType* out, int lsize, int rsize) {
-      Shape <ndim> coord = unravel(base, oshape);
-    index_t lidx = dot(coord, lstride);
-    index_t ridx = dot(coord, rstride);
-      KERNEL_ASSIGN(out[base], req, OP::Map(lhs[lidx], rhs[ridx]));
-      // starts from 1 to avoid extra inc at end of loop
-      for (int i = 1; i < length; ++i) {
-        inc(&coord, oshape, &lidx, lstride, &ridx, rstride);
-        KERNEL_ASSIGN(out[base + i], req, OP::Map(lhs[lidx], rhs[ridx]));
-      }
+                                  const Shape <ndim> &lstride, const Shape <ndim> &rstride,
+                                  const Shape <ndim> &oshape, DType *lhs, DType *rhs,
+                                  DType *out) {
+    Shape <ndim> coord = unravel(base, oshape);
+    auto lidx = static_cast<index_t>(dot(coord, lstride));
+    auto ridx = static_cast<index_t>(dot(coord, rstride));
+    KERNEL_ASSIGN(out[base], req, OP::Map(lhs[lidx], rhs[ridx]));
+    // starts from 1 to avoid extra inc at end of loop
+    for (int i = 1; i < length; ++i) {
+      inc(&coord, oshape, &lidx, lstride, &ridx, rstride);
+      // When tuning, don't actually run the op, since it's not going to be tuned against
+      // the actual op we'll eventually be using
+      KERNEL_ASSIGN(out[base + i], req, OP::Map(lhs[lidx], rhs[ridx]));
     }
+  }
 };
-
 }  // namespace mxnet_op
 
 template<typename xpu, typename OP>
@@ -161,25 +163,25 @@ void BinaryBroadcastCompute(const nnvm::NodeAttrs& attrs,
                             const std::vector<TBlob>& inputs,
                             const std::vector<OpReqType>& req,
                             const std::vector<TBlob>& outputs) {
-  using namespace mxnet_op;
   TShape new_lshape, new_rshape, new_oshape;
   int ndim = BinaryBroadcastShapeCompact(inputs[0].shape_, inputs[1].shape_, outputs[0].shape_,
                                          &new_lshape, &new_rshape, &new_oshape);
   if (!ndim) {
     ElemwiseBinaryOp::Compute<xpu, OP>(attrs, ctx, inputs, req, outputs);
   } else {
-    mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
-    MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
-      BROADCAST_NDIM_SWITCH(ndim, NDim, {
-        Shape<NDim> oshape = new_oshape.get<NDim>();
-        Shape<NDim> lstride = calc_stride(new_lshape.get<NDim>());
-        Shape<NDim> rstride = calc_stride(new_rshape.get<NDim>());
-        Kernel<binary_broadcast_kernel<NDim, DType, OP>, xpu>::LaunchEx(
-            s, new_oshape.Size(), req[0], lstride, rstride, oshape,
-            inputs[0].dptr<DType>(), inputs[1].dptr<DType>(), outputs[0].dptr<DType>(),
-            inputs[0].Size(), inputs[1].Size());
+    if (req[0] != kNullOp) {
+      mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
+      MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+        BROADCAST_NDIM_SWITCH(ndim, NDim, {
+          mshadow::Shape<NDim> oshape = new_oshape.get<NDim>();
+          mshadow::Shape<NDim> lstride = mxnet_op::calc_stride(new_lshape.get<NDim>());
+          mshadow::Shape<NDim> rstride = mxnet_op::calc_stride(new_rshape.get<NDim>());
+          mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, DType, OP>, xpu>::
+          template LaunchEx(s, new_oshape.Size(), req[0], lstride, rstride, oshape,
+          inputs[0].dptr<DType>(), inputs[1].dptr<DType>(), outputs[0].dptr<DType>());
+        });
       });
-    });
+    }
   }
 }
 
@@ -237,9 +239,9 @@ inline void BinaryBroadcastBackwardUseInImpl(const OpContext& ctx,
   size_t workspace_size = std::max(workspace_size_l, workspace_size_r);
   Tensor<xpu, 1, char> workspace =
     ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
-  Reduce<red::sum, ndim, DType, mshadow::op::mul, LOP>(s, lgrad, req[0], workspace,
+  Reduce<red::sum, ndim, DType, op::mshadow_op::mul, LOP>(s, lgrad, req[0], workspace,
     ograd, lhs, rhs);
-  Reduce<red::sum, ndim, DType, mshadow::op::mul, ROP>(s, rgrad, req[1], workspace,
+  Reduce<red::sum, ndim, DType, op::mshadow_op::mul, ROP>(s, rgrad, req[1], workspace,
     ograd, lhs, rhs);
 }
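
The reworked binary_broadcast_kernel walks the output linearly while tracking per-operand indices through strides that are zero along broadcast axes, so dot(coord, stride) maps each output coordinate to the correct (possibly repeated) input element. A concrete 2D illustration of that indexing, with shapes chosen just for the example:

    // lhs shape (3,1) broadcast against rhs shape (1,4) -> out shape (3,4)
    void broadcast_add_2d(const float* lhs, const float* rhs, float* out) {
      const int lstride[2] = {1, 0};  // stride 0 on each broadcast axis
      const int rstride[2] = {0, 1};
      for (int i = 0; i < 3; ++i) {
        for (int j = 0; j < 4; ++j) {
          const int lidx = i * lstride[0] + j * lstride[1];  // dot(coord, lstride)
          const int ridx = i * rstride[0] + j * rstride[1];  // dot(coord, rstride)
          out[i * 4 + j] = lhs[lidx] + rhs[ridx];
        }
      }
    }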
 
diff --git a/src/operator/tensor/elemwise_binary_broadcast_op_basic.cc b/src/operator/tensor/elemwise_binary_broadcast_op_basic.cc
index 04281087f0..634e90557e 100644
--- a/src/operator/tensor/elemwise_binary_broadcast_op_basic.cc
+++ b/src/operator/tensor/elemwise_binary_broadcast_op_basic.cc
@@ -49,7 +49,7 @@ Example::
                            [ 2.,  2.,  2.]]
 
 )code" ADD_FILELINE)
-.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, mshadow::op::plus>)
+.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, op::mshadow_op::plus>)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_broadcast_add"});
 
 NNVM_REGISTER_OP(_backward_broadcast_add)
@@ -88,7 +88,7 @@ Example::
                             [ 0.,  0.,  0.]]
 
 )code" ADD_FILELINE)
-.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, mshadow::op::minus>)
+.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, op::mshadow_op::minus>)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_broadcast_sub"});
 
 NNVM_REGISTER_OP(_backward_broadcast_sub)
@@ -121,7 +121,7 @@ Example::
                           [ 1.,  1.,  1.]]
 
 )code" ADD_FILELINE)
-.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, mshadow::op::mul>)
+.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, op::mshadow_op::mul>)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_broadcast_mul"});
 
 
@@ -155,7 +155,7 @@ Example::
                           [ 2.,  2.,  2.]]
 
 )code" ADD_FILELINE)
-.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, mshadow::op::div>)
+.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, op::mshadow_op::div>)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_broadcast_div"});
 
 NNVM_REGISTER_OP(_backward_broadcast_div)
diff --git a/src/operator/tensor/elemwise_binary_broadcast_op_basic.cu b/src/operator/tensor/elemwise_binary_broadcast_op_basic.cu
index dd3c1b2e12..dc0ba021f5 100644
--- a/src/operator/tensor/elemwise_binary_broadcast_op_basic.cu
+++ b/src/operator/tensor/elemwise_binary_broadcast_op_basic.cu
@@ -29,28 +29,28 @@
 namespace mxnet {
 namespace op {
 NNVM_REGISTER_OP(broadcast_add)
-.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, mshadow::op::plus>);
+.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, op::mshadow_op::plus>);
 
 NNVM_REGISTER_OP(_backward_broadcast_add)
 .set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastBackwardUseNone<gpu, mshadow_op::identity,
                                                                 mshadow_op::identity>);
 
 NNVM_REGISTER_OP(broadcast_sub)
-.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, mshadow::op::minus>);
+.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, op::mshadow_op::minus>);
 
 NNVM_REGISTER_OP(_backward_broadcast_sub)
 .set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastBackwardUseNone<gpu, mshadow_op::identity,
                                                                 mshadow_op::negation>);
 
 NNVM_REGISTER_OP(broadcast_mul)
-.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, mshadow::op::mul>);
+.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, op::mshadow_op::mul>);
 
 NNVM_REGISTER_OP(_backward_broadcast_mul)
 .set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastBackwardUseIn<gpu, mshadow_op::right,
                                                                 mshadow_op::left>);
 
 NNVM_REGISTER_OP(broadcast_div)
-.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, mshadow::op::div>);
+.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, op::mshadow_op::div>);
 
 NNVM_REGISTER_OP(_backward_broadcast_div)
 .set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastBackwardUseIn<gpu, mshadow_op::div_grad,
diff --git a/src/operator/tensor/elemwise_binary_op-inl.h b/src/operator/tensor/elemwise_binary_op-inl.h
index 8f8bcddea6..5cd3314531 100644
--- a/src/operator/tensor/elemwise_binary_op-inl.h
+++ b/src/operator/tensor/elemwise_binary_op-inl.h
@@ -164,7 +164,7 @@ void ElemwiseBinaryOp::RspRspOp(mshadow::Stream<cpu> *s,
       Tensor<cpu, 1, DType> rvalue = !rhs_is_dense ? data_r[iter_r++] : data_r[idx_r];
       DCHECK_EQ(lvalue.shape_.Size(), rvalue.shape_.Size());
       MXNET_ASSIGN_REQ_SWITCH(req, Req, {
-        SerialLaunchCPU<mxnet_op::op_with_req<OP, Req>>(
+        mxnet_op::Kernel<mxnet_op::op_with_req<OP, Req>, cpu>::Launch(
           s, lvalue.shape_.Size(), out[iter_out].dptr_, lvalue.dptr_, rvalue.dptr_);
       });
       num_common_rows++;
@@ -175,7 +175,7 @@ void ElemwiseBinaryOp::RspRspOp(mshadow::Stream<cpu> *s,
       }
       Tensor<cpu, 1, DType> lvalue = !lhs_is_dense ? data_l[iter_l++] : data_l[idx_l];
       MXNET_ASSIGN_REQ_SWITCH(req, Req, {
-        SerialLaunchCPU<MissingRValueOp<OP, Req>>(
+        mxnet_op::Kernel<MissingRValueOp<OP, Req>, cpu>::Launch(
           s, lvalue.shape_.Size(), out[iter_out].dptr_, lvalue.dptr_);
       });
     } else {
@@ -189,7 +189,7 @@ void ElemwiseBinaryOp::RspRspOp(mshadow::Stream<cpu> *s,
       }
       Tensor<cpu, 1, DType> rvalue = !rhs_is_dense ? data_r[iter_r++] : data_r[idx_r];
       MXNET_ASSIGN_REQ_SWITCH(req, Req, {
-        SerialLaunchCPU<MissingLValueOp<OP, Req>>(
+        mxnet_op::Kernel<MissingLValueOp<OP, Req>, cpu>::Launch(
           s, rvalue.shape_.Size(), out[iter_out].dptr_, rvalue.dptr_);
       });
     }
@@ -205,7 +205,7 @@ void ElemwiseBinaryOp::RspRspOp(mshadow::Stream<cpu> *s,
     }
     Tensor<cpu, 1, DType> lvalue = data_l[iter_l++];
     MXNET_ASSIGN_REQ_SWITCH(req, Req, {
-      SerialLaunchCPU<MissingRValueOp<OP, Req>>(
+      mxnet_op::Kernel<MissingRValueOp<OP, Req>, cpu>::Launch(
         s, lvalue.shape_.Size(), out[iter_out++].dptr_, lvalue.dptr_);
     });
   }
@@ -218,7 +218,7 @@ void ElemwiseBinaryOp::RspRspOp(mshadow::Stream<cpu> *s,
     }
     Tensor<cpu, 1, DType> rvalue = data_r[iter_r++];
     MXNET_ASSIGN_REQ_SWITCH(req, Req, {
-      SerialLaunchCPU<MissingLValueOp<OP, Req>>(
+      mxnet_op::Kernel<MissingLValueOp<OP, Req>, cpu>::Launch(
         s, rvalue.shape_.Size(), out[iter_out++].dptr_, rvalue.dptr_);
     });
   }
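
The call sites in this file swap the serial-only SerialLaunchCPU helper
(removed from OpBase later in this diff) for the standard
mxnet_op::Kernel<OP, cpu>::Launch entry point. Conceptually the CPU
launcher is a loop over OP::Map that may be parallelized; a rough sketch,
assuming the tuned launcher uses per-operator cost data to decide whether
OpenMP pays off for a given N (the fixed threshold shown is invented, a
stand-in for that decision):

    template<typename OP, typename ...Args>
    inline static bool Launch(mshadow::Stream<cpu> *s, const int N,
                              Args... args) {
      if (N < 1000) {  // placeholder for the tuner's real decision
        for (int i = 0; i < N; ++i) {
          OP::Map(i, args...);   // serial path, same as the old helper
        }
      } else {
        #pragma omp parallel for
        for (int i = 0; i < N; ++i) {
          OP::Map(i, args...);   // parallel path
        }
      }
      return true;
    }
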
diff --git a/src/operator/tensor/elemwise_binary_op.h b/src/operator/tensor/elemwise_binary_op.h
index d54636c055..6fc4107571 100644
--- a/src/operator/tensor/elemwise_binary_op.h
+++ b/src/operator/tensor/elemwise_binary_op.h
@@ -48,6 +48,7 @@ class ElemwiseBinaryOp : public OpBase {
   /*! \brief For sparse, assume missing rvalue is 0 */
   template<typename OP, int Req>
   struct MissingRValueOp {
+    typedef OP Operation;
     template<typename DType>
     MSHADOW_XINLINE static void Map(int i, DType *out, const DType *lhs) {
       KERNEL_ASSIGN(out[i], Req, OP::Map(lhs[i], DType(0)));
@@ -57,6 +58,7 @@ class ElemwiseBinaryOp : public OpBase {
   /*! \brief For sparse, assume missing lvalue is 0 */
   template<typename OP, int Req>
   struct MissingLValueOp {
+    typedef OP Operation;
     template<typename DType>
     MSHADOW_XINLINE static void Map(int i, DType *out, const DType *rhs) {
       KERNEL_ASSIGN(out[i], Req, OP::Map(DType(0), rhs[i]));
@@ -148,14 +150,14 @@ class ElemwiseBinaryOp : public OpBase {
         (outputs[0].Size() + mxnet_op::DataType<DType>::kLanes - 1)
         / mxnet_op::DataType<DType>::kLanes);
       DType * lgrad_dptr = outputs[0].dptr<DType>();
-      mxnet_op::Kernel<mxnet_op::op_with_req<mxnet_op::backward_grad<LOP>, Req>, xpu>::Launch(
+      mxnet_op::Kernel<mxnet_op::op_with_req<mxnet_op::backward_grad_tuned<LOP>, Req>, xpu>::Launch(
         s, size, lgrad_dptr, ograd_dptr, lhs_dptr, rhs_dptr);});
     MXNET_ASSIGN_REQ_SWITCH(req[1], Req, {
       const int size = static_cast<int>(
         (outputs[1].Size() + mxnet_op::DataType<DType>::kLanes - 1)
         / mxnet_op::DataType<DType>::kLanes);
       DType * rgrad_dptr = outputs[1].dptr<DType>();
-      mxnet_op::Kernel<mxnet_op::op_with_req<mxnet_op::backward_grad<ROP>, Req>, xpu>::Launch(
+      mxnet_op::Kernel<mxnet_op::op_with_req<mxnet_op::backward_grad_tuned<ROP>, Req>, xpu>::Launch(
         s, size, rgrad_dptr, ograd_dptr, lhs_dptr, rhs_dptr);});
   }
 
@@ -185,7 +187,7 @@ class ElemwiseBinaryOp : public OpBase {
       });
       // lhs in-place
       MSHADOW_IDX_TYPE_SWITCH(inputs[0].aux_type(rowsparse::kIdx), IType, {
-        RspRspOp<DType, IType, mshadow::op::mul>(
+        RspRspOp<DType, IType, op::mshadow_op::mul>(
           s, attrs, ctx, outputs[0], inputs[0], req[0], outputs[0],
           false, false, true, false);
       });
@@ -199,7 +201,7 @@ class ElemwiseBinaryOp : public OpBase {
       });
       // rhs in-place
       MSHADOW_IDX_TYPE_SWITCH(inputs[0].aux_type(rowsparse::kIdx), IType, {
-        RspRspOp<DType, IType, mshadow::op::mul>(
+        RspRspOp<DType, IType, op::mshadow_op::mul>(
           s, attrs, ctx, inputs[0], outputs[1], req[1], outputs[1],
           false, false, true, false);
       });
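
The typedef OP Operation lines added to MissingRValueOp and
MissingLValueOp above let the tuning machinery see through wrapper
kernels: the wrapper itself has no measured cost, but the core operation
it delegates to does. A hypothetical, self-contained illustration of how
a launcher might unwrap it (the trait names here are invented):

    template<typename T> struct void_t_impl { typedef void type; };

    // Plain kernels tune as themselves...
    template<typename Kernel, typename Enable = void>
    struct core_op { typedef Kernel type; };

    // ...while wrappers that expose ::Operation tune as the wrapped op.
    template<typename Kernel>
    struct core_op<Kernel,
                   typename void_t_impl<typename Kernel::Operation>::type> {
      typedef typename Kernel::Operation type;
    };
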
diff --git a/src/operator/tensor/elemwise_binary_op_basic.cc b/src/operator/tensor/elemwise_binary_op_basic.cc
index 10e7fac5e9..d7e5e04ce8 100644
--- a/src/operator/tensor/elemwise_binary_op_basic.cc
+++ b/src/operator/tensor/elemwise_binary_op_basic.cc
@@ -28,7 +28,7 @@
 namespace mxnet {
 namespace op {
 
-MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU(elemwise_add, mshadow::op::plus)
+MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU(elemwise_add, op::mshadow_op::plus)
 MXNET_ADD_SPARSE_OP_ALIAS(elemwise_add)
 .add_alias("_add").add_alias("_plus").add_alias("_Plus")
 .describe(R"code(Adds arguments element-wise.
@@ -44,7 +44,7 @@ The storage type of ``elemwise_add`` output depends on storage types of inputs
 
 // specialized gradient add function to do add to optimization
 // this must differ from elemwise_add to prevent add to optimization in forward pass.
-MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU(_grad_add, mshadow::op::plus);
+MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU(_grad_add, op::mshadow_op::plus);
 
 NNVM_REGISTER_OP(_backward_add)
 .set_num_inputs(1)
@@ -63,7 +63,7 @@ NNVM_REGISTER_OP(_backward_add)
 .set_attr<FInferStorageType>("FInferStorageType",
                              ElemwiseStorageType<1, 2, true, true, true>);
 
-MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU(elemwise_sub, mshadow::op::minus)
+MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU(elemwise_sub, op::mshadow_op::minus)
 MXNET_ADD_SPARSE_OP_ALIAS(elemwise_sub)
 .add_alias("_sub").add_alias("_minus").add_alias("_Minus")
 .describe(R"code(Subtracts arguments element-wise.
@@ -110,9 +110,9 @@ The storage type of ``elemwise_mul`` output depends on storage types of inputs
 .set_attr<FInferStorageType>("FInferStorageType",
                              ElemwiseBinaryOp::AllowLRDenseInputWithSparseOutputStorageType<
                                false, false>)  // 0 * nan or nan * 0 -> nan, so rsp * dns -> dns
-.set_attr<FCompute>("FCompute<cpu>", ElemwiseBinaryOp::Compute<cpu, mshadow::op::mul>)
+.set_attr<FCompute>("FCompute<cpu>", ElemwiseBinaryOp::Compute<cpu, op::mshadow_op::mul>)
 .set_attr<FComputeEx>("FComputeEx<cpu>",
-                      ElemwiseBinaryOp::ComputeDnsLRValueEx<cpu, mshadow::op::mul, true, true>)
+                      ElemwiseBinaryOp::ComputeDnsLRValueEx<cpu, op::mshadow_op::mul, true, true>)
 .set_attr<FResourceRequest>("FResourceRequest",  /* For Sparse CSR */
                               [](const NodeAttrs& attrs) {
                                 return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
@@ -138,7 +138,7 @@ NNVM_REGISTER_OP(_backward_mul)
 .set_attr<FComputeEx>("FComputeEx<cpu>", ElemwiseBinaryOp::BackwardUseInEx<
   cpu, mshadow_op::right, mshadow_op::left>);
 
-MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_DR(elemwise_div, mshadow::op::div)
+MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_DR(elemwise_div, op::mshadow_op::div)
 MXNET_ADD_SPARSE_OP_ALIAS(elemwise_div)
 .describe(R"code(Divides arguments element-wise.
 
diff --git a/src/operator/tensor/elemwise_binary_op_basic.cu b/src/operator/tensor/elemwise_binary_op_basic.cu
index 9b55e2fd76..c8e208e924 100644
--- a/src/operator/tensor/elemwise_binary_op_basic.cu
+++ b/src/operator/tensor/elemwise_binary_op_basic.cu
@@ -27,10 +27,10 @@
 namespace mxnet {
 namespace op {
 NNVM_REGISTER_OP(elemwise_add)
-.set_attr<FCompute>("FCompute<gpu>", ElemwiseBinaryOp::ComputeWithHalf2<gpu, mshadow::op::plus>);
+.set_attr<FCompute>("FCompute<gpu>", ElemwiseBinaryOp::ComputeWithHalf2<gpu, op::mshadow_op::plus>);
 
 NNVM_REGISTER_OP(_grad_add)
-.set_attr<FCompute>("FCompute<gpu>", ElemwiseBinaryOp::ComputeWithHalf2<gpu, mshadow::op::plus>);
+.set_attr<FCompute>("FCompute<gpu>", ElemwiseBinaryOp::ComputeWithHalf2<gpu, op::mshadow_op::plus>);
 
 NNVM_REGISTER_OP(_backward_add)
 .set_attr<FCompute>("FCompute<gpu>",
@@ -38,7 +38,8 @@ NNVM_REGISTER_OP(_backward_add)
                     mshadow_op::identity>);
 
 NNVM_REGISTER_OP(elemwise_sub)
-.set_attr<FCompute>("FCompute<gpu>", ElemwiseBinaryOp::ComputeWithHalf2<gpu, mshadow::op::minus>);
+.set_attr<FCompute>("FCompute<gpu>", ElemwiseBinaryOp::ComputeWithHalf2<
+  gpu, op::mshadow_op::minus>);
 
 NNVM_REGISTER_OP(_backward_sub)
 .set_attr<FCompute>("FCompute<gpu>",
@@ -46,7 +47,7 @@ NNVM_REGISTER_OP(_backward_sub)
                     mshadow_op::negation>);
 
 NNVM_REGISTER_OP(elemwise_mul)
-.set_attr<FCompute>("FCompute<gpu>", ElemwiseBinaryOp::ComputeWithHalf2<gpu, mshadow::op::mul>);
+.set_attr<FCompute>("FCompute<gpu>", ElemwiseBinaryOp::ComputeWithHalf2<gpu, op::mshadow_op::mul>);
 
 NNVM_REGISTER_OP(_backward_mul)
 .set_attr<FCompute>("FCompute<gpu>",
@@ -55,7 +56,7 @@ NNVM_REGISTER_OP(_backward_mul)
 
 NNVM_REGISTER_OP(elemwise_div)
 .set_attr<FCompute>("FCompute<gpu>",
-                    ElemwiseBinaryOp::ElemwiseBinaryOp::ComputeWithHalf2<gpu, mshadow::op::div>);
+                    ElemwiseBinaryOp::ElemwiseBinaryOp::ComputeWithHalf2<gpu, op::mshadow_op::div>);
 
 NNVM_REGISTER_OP(_backward_div)
 .set_attr<FCompute>("FCompute<gpu>",
diff --git a/src/operator/tensor/elemwise_binary_scalar_op.h b/src/operator/tensor/elemwise_binary_scalar_op.h
index cdf14055cf..0419e9938a 100644
--- a/src/operator/tensor/elemwise_binary_scalar_op.h
+++ b/src/operator/tensor/elemwise_binary_scalar_op.h
@@ -286,7 +286,7 @@ class BinaryScalarOp : public UnaryOp {
     MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
       MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
         mxnet::op::mxnet_op::Kernel<mxnet::op::mxnet_op::op_with_req<
-          mxnet::op::mxnet_op::backward_grad<OP>, Req>, xpu>::
+          mxnet::op::mxnet_op::backward_grad_tuned<OP>, Req>, xpu>::
           Launch(s, inputs[0].Size(), outputs[0].dptr<DType>(),
                  inputs[0].dptr<DType>(), inputs[1].dptr<DType>(),
                  DType(alpha));
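
backward_grad_tuned, substituted for backward_grad here and in the two
hunks in elemwise_binary_op.h above, computes the same thing (output
gradient times the local gradient functor) but mixes in the tunable base
so these backward kernels are also eligible for tuned launches. An
approximate sketch, assuming the pre-existing backward_grad shape in
src/operator/mxnet_op.h:

    template<typename GRAD_OP>
    struct backward_grad {
      template<typename DType, typename ...Args>
      MSHADOW_XINLINE static DType Map(DType a, DType b, Args... args) {
        return DType(a * GRAD_OP::Map(b, args...));  // ograd * local grad
      }
    };

    // Same computation, opted into tuning via the tunable base class.
    template<typename GRAD_OP>
    struct backward_grad_tuned
      : public backward_grad<GRAD_OP>, public mxnet_op::tunable {
      using backward_grad<GRAD_OP>::Map;
    };
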
diff --git a/src/operator/tensor/elemwise_binary_scalar_op_basic.cc b/src/operator/tensor/elemwise_binary_scalar_op_basic.cc
index 2d6662ef2b..9a278d8c97 100644
--- a/src/operator/tensor/elemwise_binary_scalar_op_basic.cc
+++ b/src/operator/tensor/elemwise_binary_scalar_op_basic.cc
@@ -114,14 +114,14 @@ static bool BinaryScalarStorageType(const nnvm::NodeAttrs& attrs,
 }
 
 MXNET_OPERATOR_REGISTER_BINARY_WITH_SCALAR_SUPPORT_WITH_DENSE_RESULT(_plus_scalar)
-.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, mshadow::op::plus>)
-.set_attr<FComputeEx>("FComputeEx<cpu>", BinaryScalarOp::ComputeEx<cpu, mshadow::op::plus>)
+.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, op::mshadow_op::plus>)
+.set_attr<FComputeEx>("FComputeEx<cpu>", BinaryScalarOp::ComputeEx<cpu, op::mshadow_op::plus>)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_copy"})
 .add_alias("_PlusScalar");
 
 MXNET_OPERATOR_REGISTER_BINARY_WITH_SCALAR_SUPPORT_WITH_DENSE_RESULT(_minus_scalar)
-.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, mshadow::op::minus>)
-.set_attr<FComputeEx>("FComputeEx<cpu>", BinaryScalarOp::ComputeEx<cpu, mshadow::op::minus>)
+.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, op::mshadow_op::minus>)
+.set_attr<FComputeEx>("FComputeEx<cpu>", BinaryScalarOp::ComputeEx<cpu, op::mshadow_op::minus>)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_copy"})
 .add_alias("_MinusScalar");
 
@@ -141,16 +141,16 @@ it will result output.data = [nan, nan] instead of 10000 nans.
 
 )doc" ADD_FILELINE)
 .set_attr<FInferStorageType>("FInferStorageType", BinaryScalarStorageType)
-.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, mshadow::op::mul>)
-.set_attr<FComputeEx>("FComputeEx<cpu>", BinaryScalarOp::ComputeEx<cpu, mshadow::op::mul>)
+.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, op::mshadow_op::mul>)
+.set_attr<FComputeEx>("FComputeEx<cpu>", BinaryScalarOp::ComputeEx<cpu, op::mshadow_op::mul>)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_mul_scalar"})
 .add_alias("_MulScalar");
 
 MXNET_OPERATOR_REGISTER_BINARY_SCALAR(_backward_mul_scalar)
 .set_attr<nnvm::TIsBackward>("TIsBackward", true)
 .set_attr<FInferStorageType>("FInferStorageType", BinaryScalarStorageType)
-.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, mshadow::op::mul>)
-.set_attr<FComputeEx>("FComputeEx<cpu>", BinaryScalarOp::ComputeEx<cpu, mshadow::op::mul>);
+.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, op::mshadow_op::mul>)
+.set_attr<FComputeEx>("FComputeEx<cpu>", BinaryScalarOp::ComputeEx<cpu, op::mshadow_op::mul>);
 
 MXNET_OPERATOR_REGISTER_BINARY_SCALAR(_div_scalar)
 .describe(R"doc(Divide an array with a scalar.
@@ -163,16 +163,16 @@ it will result output.data = [nan, nan] instead of 10000 nans.
 
 )doc" ADD_FILELINE)
 .set_attr<FInferStorageType>("FInferStorageType", BinaryScalarStorageType)
-.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, mshadow::op::div>)
-.set_attr<FComputeEx>("FComputeEx<cpu>", BinaryScalarOp::ComputeEx<cpu, mshadow::op::div>)
+.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, op::mshadow_op::div>)
+.set_attr<FComputeEx>("FComputeEx<cpu>", BinaryScalarOp::ComputeEx<cpu, op::mshadow_op::div>)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_div_scalar"})
 .add_alias("_DivScalar");
 
 MXNET_OPERATOR_REGISTER_BINARY_SCALAR(_backward_div_scalar)
 .set_attr<nnvm::TIsBackward>("TIsBackward", true)
 .set_attr<FInferStorageType>("FInferStorageType", BinaryScalarStorageType)
-.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, mshadow::op::div>)
-.set_attr<FComputeEx>("FComputeEx<cpu>", BinaryScalarOp::ComputeEx<cpu, mshadow::op::div>);
+.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, op::mshadow_op::div>)
+.set_attr<FComputeEx>("FComputeEx<cpu>", BinaryScalarOp::ComputeEx<cpu, op::mshadow_op::div>);
 
 
 MXNET_OPERATOR_REGISTER_BINARY_SCALAR(_rdiv_scalar)
diff --git a/src/operator/tensor/elemwise_binary_scalar_op_basic.cu b/src/operator/tensor/elemwise_binary_scalar_op_basic.cu
index 21be0a0e12..51b3866c8e 100644
--- a/src/operator/tensor/elemwise_binary_scalar_op_basic.cu
+++ b/src/operator/tensor/elemwise_binary_scalar_op_basic.cu
@@ -29,29 +29,29 @@
 namespace mxnet {
 namespace op {
 NNVM_REGISTER_OP(_plus_scalar)
-.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow::op::plus>);
+.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, op::mshadow_op::plus>);
 
 NNVM_REGISTER_OP(_minus_scalar)
-.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow::op::minus>);
+.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, op::mshadow_op::minus>);
 
 NNVM_REGISTER_OP(_rminus_scalar)
 .set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::rminus>);
 
 NNVM_REGISTER_OP(_mul_scalar)
-.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow::op::mul>)
-.set_attr<FComputeEx>("FComputeEx<gpu>", BinaryScalarOp::ComputeEx<gpu, mshadow::op::mul>);
+.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, op::mshadow_op::mul>)
+.set_attr<FComputeEx>("FComputeEx<gpu>", BinaryScalarOp::ComputeEx<gpu, op::mshadow_op::mul>);
 
 NNVM_REGISTER_OP(_backward_mul_scalar)
-.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow::op::mul>)
-.set_attr<FComputeEx>("FComputeEx<gpu>", BinaryScalarOp::ComputeEx<gpu, mshadow::op::mul>);
+.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, op::mshadow_op::mul>)
+.set_attr<FComputeEx>("FComputeEx<gpu>", BinaryScalarOp::ComputeEx<gpu, op::mshadow_op::mul>);
 
 NNVM_REGISTER_OP(_div_scalar)
-.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow::op::div>)
-.set_attr<FComputeEx>("FComputeEx<gpu>", BinaryScalarOp::ComputeEx<gpu, mshadow::op::div>);
+.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, op::mshadow_op::div>)
+.set_attr<FComputeEx>("FComputeEx<gpu>", BinaryScalarOp::ComputeEx<gpu, op::mshadow_op::div>);
 
 NNVM_REGISTER_OP(_backward_div_scalar)
-.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow::op::div>)
-.set_attr<FComputeEx>("FComputeEx<gpu>", BinaryScalarOp::ComputeEx<gpu, mshadow::op::div>);
+.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, op::mshadow_op::div>)
+.set_attr<FComputeEx>("FComputeEx<gpu>", BinaryScalarOp::ComputeEx<gpu, op::mshadow_op::div>);
 
 NNVM_REGISTER_OP(_rdiv_scalar)
 .set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::rdiv>);
diff --git a/src/operator/tensor/elemwise_scatter_op.cc b/src/operator/tensor/elemwise_scatter_op.cc
index ec5df9b4e1..2f0883d9de 100644
--- a/src/operator/tensor/elemwise_scatter_op.cc
+++ b/src/operator/tensor/elemwise_scatter_op.cc
@@ -79,8 +79,9 @@ static bool StorageTypeScatteredScalarOp(const NodeAttrs& attrs,
 
 /*! \brief _scatter_elemwise_div */
 MXNET_OPERATOR_REGISTER_BINARY(_scatter_elemwise_div)
-.set_attr<FCompute>("FCompute<cpu>", ElemwiseScatterBinaryOp::Compute<cpu, mshadow::op::div>)
-.set_attr<FComputeEx>("FComputeEx<cpu>", ElemwiseScatterBinaryOp::ComputeEx<cpu, mshadow::op::div>)
+.set_attr<FCompute>("FCompute<cpu>", ElemwiseScatterBinaryOp::Compute<cpu, op::mshadow_op::div>)
+.set_attr<FComputeEx>("FComputeEx<cpu>", ElemwiseScatterBinaryOp::ComputeEx<
+  cpu, op::mshadow_op::div>)
 .describe(R"code(Divides arguments element-wise.  If the left-hand-side input is 'row_sparse', then
 only the values which exist in the left-hand sparse array are computed.  The 'missing' values
 are ignored.
@@ -117,9 +118,9 @@ with default storage
 )code")
 .set_attr<FInferStorageType>("FInferStorageType", StorageTypeScatteredScalarOp)
 .set_attr<FCompute>("FCompute<cpu>",
-                    ElemwiseScatterBinaryScalarOp::Compute<cpu, mshadow::op::plus>)
+                    ElemwiseScatterBinaryScalarOp::Compute<cpu, op::mshadow_op::plus>)
 .set_attr<FComputeEx>("FComputeEx<cpu>",
-                      ElemwiseScatterBinaryScalarOp::ComputeEx<cpu, mshadow::op::plus>)
+                      ElemwiseScatterBinaryScalarOp::ComputeEx<cpu, op::mshadow_op::plus>)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_copy"});
 
 /*! \brief _scatter_minus_scalar */
@@ -138,9 +139,9 @@ with default storage
 )code")
 .set_attr<FInferStorageType>("FInferStorageType", StorageTypeScatteredScalarOp)
 .set_attr<FCompute>("FCompute<cpu>",
-                    ElemwiseScatterBinaryScalarOp::Compute<cpu, mshadow::op::minus>)
+                    ElemwiseScatterBinaryScalarOp::Compute<cpu, op::mshadow_op::minus>)
 .set_attr<FComputeEx>("FComputeEx<cpu>",
-                      ElemwiseScatterBinaryScalarOp::ComputeEx<cpu, mshadow::op::minus>)
+                      ElemwiseScatterBinaryScalarOp::ComputeEx<cpu, op::mshadow_op::minus>)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_copy"});
 
 }  // namespace op
diff --git a/src/operator/tensor/elemwise_scatter_op.cu b/src/operator/tensor/elemwise_scatter_op.cu
index 28c8df3de6..acb7edf21c 100644
--- a/src/operator/tensor/elemwise_scatter_op.cu
+++ b/src/operator/tensor/elemwise_scatter_op.cu
@@ -23,18 +23,19 @@ namespace mxnet {
 namespace op {
 
 NNVM_REGISTER_OP(_scatter_elemwise_div)
-.set_attr<FCompute>("FCompute<gpu>", ElemwiseScatterBinaryOp::Compute<gpu, mshadow::op::div>)
-.set_attr<FComputeEx>("FComputeEx<gpu>", ElemwiseScatterBinaryOp::ComputeEx<gpu, mshadow::op::div>);
+.set_attr<FCompute>("FCompute<gpu>", ElemwiseScatterBinaryOp::Compute<gpu, op::mshadow_op::div>)
+.set_attr<FComputeEx>("FComputeEx<gpu>", ElemwiseScatterBinaryOp::ComputeEx<gpu,
+  op::mshadow_op::div>);
 
 NNVM_REGISTER_OP(_scatter_plus_scalar)
 .set_attr<FCompute>("FCompute<gpu>",
-                    ElemwiseScatterBinaryScalarOp::Compute<gpu, mshadow::op::plus>)
+                    ElemwiseScatterBinaryScalarOp::Compute<gpu, op::mshadow_op::plus>)
 .set_attr<FComputeEx>("FComputeEx<gpu>",
-                      ElemwiseScatterBinaryScalarOp::ComputeEx<gpu, mshadow::op::plus>);
+                      ElemwiseScatterBinaryScalarOp::ComputeEx<gpu, op::mshadow_op::plus>);
 
 NNVM_REGISTER_OP(_scatter_minus_scalar)
-.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow::op::minus>)
-.set_attr<FComputeEx>("FComputeEx<gpu>", BinaryScalarOp::ComputeEx<gpu, mshadow::op::minus>);
+.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, op::mshadow_op::minus>)
+.set_attr<FComputeEx>("FComputeEx<gpu>", BinaryScalarOp::ComputeEx<gpu, op::mshadow_op::minus>);
 
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/tensor/elemwise_unary_op.h b/src/operator/tensor/elemwise_unary_op.h
index 82ecf4f5ad..3472d87e54 100644
--- a/src/operator/tensor/elemwise_unary_op.h
+++ b/src/operator/tensor/elemwise_unary_op.h
@@ -39,21 +39,6 @@ namespace op {
 
 class OpBase {
  protected:
-  /*!
-   * \brief Launch CPU-only kernel without OMP (temporary solution until OMP-tuned kernels arrive)
-   * \tparam OP Kernel operation type
-   * \tparam Args Argument types to be passed to kernel
-   * \param s CPU stream
-   * \param N Number of iterations
-   * \param args Arguments to be passed to kernel
-   */
-  template <typename OP, typename ...Args>
-  static inline void SerialLaunchCPU(mshadow::Stream<cpu> *s, const int N, Args... args) {
-    for (int i = 0; i < N; ++i) {
-      OP::Map(i, args...);
-    }
-  }
-
   /*! \brief simple kernel to set to a scalar value of arbitrary type */
   template<int req>
   using set_to_scalar = mxnet_op::op_with_req<mshadow_op::identity, req>;
@@ -172,7 +157,7 @@ class OpBase {
                                const OpReqType req,
                                DType *out) {
     MXNET_ASSIGN_REQ_SWITCH(req, Req, {
-      SerialLaunchCPU<OpBase::set_to_scalar<Req>>(s, size, out, val);
+      mxnet_op::Kernel<OpBase::set_to_scalar<Req>, cpu>::Launch(s, size, out, val);
     });
   }
 };  // OpBase
@@ -360,7 +345,7 @@ class UnaryOp : public OpBase {
 
 /*! \brief Map legacy unary_bwd to backward_grad */
 template<typename GRAD_OP>
-using unary_bwd = ::mxnet::op::mxnet_op::backward_grad<GRAD_OP>;
+using unary_bwd = ::mxnet::op::mxnet_op::backward_grad_tuned<GRAD_OP>;
 
 struct CastParam : public dmlc::Parameter<CastParam> {
   // use int for enumeration
diff --git a/src/operator/tensor/init_op.h b/src/operator/tensor/init_op.h
index 95e8184f8a..4d899704a1 100644
--- a/src/operator/tensor/init_op.h
+++ b/src/operator/tensor/init_op.h
@@ -263,16 +263,15 @@ void InitFillWithScalarCompute(const nnvm::NodeAttrs &attrs,
   CHECK_EQ(inputs.size(), 0);
   CHECK_EQ(outputs.size(), 1U);
   const auto& param = nnvm::get<InitOpWithScalarParam>(attrs.parsed);
-  Fill<true>(ctx.get_stream<xpu>(), outputs[0], req[0], param.value);
+  Fill<false>(ctx.get_stream<xpu>(), outputs[0], req[0], param.value);
 }
 
-struct PopulateFullIdxRspKernel {
+struct PopulateFullIdxRspKernel : public mxnet_op::tunable {
   template<typename IType>
   MSHADOW_XINLINE static void Map(int i, IType* out) {
     KERNEL_ASSIGN(out[i], kWriteTo, i);
   }
 };
-MXNET_TUNABLE_MXNET_OP_FWD(PopulateFullIdxRspKernel);
 
 // Fill in the indices and values of a RowSparse NDArray to represent a zeros NDArray,
 // instead of the usual compact representation.
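
This last hunk shows the pattern that gives the PR its title: instead of
registering a kernel for tuning with a trailing
MXNET_TUNABLE_MXNET_OP_FWD(...) specialization macro, the kernel now
simply derives from mxnet_op::tunable. A sketch of the opt-in for a new
kernel (the kernel name and payload are invented for illustration):

    // Deriving from mxnet_op::tunable is the whole opt-in; no macro needed.
    struct FillIota : public mxnet_op::tunable {
      template<typename DType>
      MSHADOW_XINLINE static void Map(int i, DType *out, const DType start) {
        out[i] = start + DType(i);
      }
    };

    // Launched like any other kernel; the launcher consults tuning data:
    //   mxnet_op::Kernel<FillIota, cpu>::Launch(s, N, out_ptr, DType(0));
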


 
