You are viewing a plain-text version of this content. (The link to the canonical version is not preserved in plain text.)
Posted to commits@mxnet.apache.org by jx...@apache.org on 2017/11/11 20:05:59 UTC

[incubator-mxnet] branch master updated: Pull more optimization and simplification changes from the tuner branch (#8599)

This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 8ba5de8  Pull more optimization and simplification changes from the tuner branch (#8599)
8ba5de8 is described below

commit 8ba5de841ef28354be7aca5347391c0f28441ba5
Author: Chris Olivier <cj...@gmail.com>
AuthorDate: Sat Nov 11 12:05:57 2017 -0800

    Pull more optimization and simplification changes from the tuner branch (#8599)
    
    * Pull more optimization changes from the tuner branch
    
    * remove newline
    
    * Move file
    
    * Added slice_channel_perf.cc
---
 docs/faq/new_op.md                                 |   2 +-
 include/mxnet/engine.h                             |   5 -
 src/engine/naive_engine.cc                         |   7 --
 src/engine/openmp.cc                               |   2 +-
 src/engine/threaded_engine.h                       |  26 ------
 src/operator/mxnet_op.h                            |  58 ++++++------
 src/operator/optimizer_op-inl.h                    |   2 +-
 src/operator/tensor/elemwise_binary_op.h           |  56 ++++-------
 src/operator/tensor/elemwise_binary_scalar_op.h    |  20 ++--
 src/operator/tensor/elemwise_unary_op.h            |  77 +--------------
 src/operator/tensor/elemwise_unary_op_basic.cc     |  15 ++-
 src/operator/tensor/elemwise_unary_op_basic.cu     |  11 ++-
 src/operator/tensor/indexing_op.h                  |   3 +-
 src/operator/tensor/init_op.h                      |   3 +-
 tests/cpp/include/test_core_op.h                   |  60 +++++++-----
 tests/cpp/include/test_legacy_op.h                 |  22 ++---
 tests/cpp/include/test_ndarray_utils.h             |   2 +-
 tests/cpp/include/test_op.h                        |  26 ++++--
 tests/cpp/include/test_op_runner.h                 |  33 ++++---
 tests/cpp/include/test_perf.h                      |  46 +++++----
 tests/cpp/include/test_util.h                      |  61 ++++++++++++
 .../operator/{ => runner}/core_op_runner_test.cc   |   0
 tests/cpp/operator/slice_channel_perf.cc           | 104 +++++++++++++++++++++
 23 files changed, 360 insertions(+), 281 deletions(-)

diff --git a/docs/faq/new_op.md b/docs/faq/new_op.md
index 55b7409..994a2a6 100644
--- a/docs/faq/new_op.md
+++ b/docs/faq/new_op.md
@@ -339,7 +339,7 @@ NNVM_REGISTER_OP(_backward_abs)
 [](const NodeAttrs& attrs){
   return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}};
 })
-.set_attr<FCompute>("FCompute<cpu>", BinaryCompute<cpu, unary_bwd<mshadow_op::sign> >);
+.set_attr<FCompute>("FCompute<cpu>", BinaryCompute<cpu, backward_grad<mshadow_op::sign> >);
 ```
 
 ### Legacy Operators
diff --git a/include/mxnet/engine.h b/include/mxnet/engine.h
index 4048d5a..4c2314e 100644
--- a/include/mxnet/engine.h
+++ b/include/mxnet/engine.h
@@ -267,11 +267,6 @@ class MXNET_API Engine {
     }
     read_vars->resize(rtop - read_vars->begin());
   }
-
-  /*! \brief Return the number of OMP threads that should be used per worker
-   * \return Number of OMP threads that should be used per worker
-   */
-  virtual int num_omp_threads_per_worker() const = 0;
 };  // class Engine
 #endif  // DMLC_USE_CXX11
 }  // namespace mxnet
diff --git a/src/engine/naive_engine.cc b/src/engine/naive_engine.cc
index 7e3554a..4d63749 100644
--- a/src/engine/naive_engine.cc
+++ b/src/engine/naive_engine.cc
@@ -188,13 +188,6 @@ class NaiveEngine final : public Engine {
     shutdown_phase_.store(true);
   }
 
-  /*! \brief Return the number of OMP threads that should be used per worker
-   * \return Number of OMP threads that should be used per worker
-   */
-  int num_omp_threads_per_worker() const override {
-    return OpenMP::Get()->GetRecommendedOMPThreadCount();
-  }
-
  private:
   // callback to oncomplete
   static void OnComplete(Engine *engine, void *param) {
diff --git a/src/engine/openmp.cc b/src/engine/openmp.cc
index be7885b..ad0c574 100644
--- a/src/engine/openmp.cc
+++ b/src/engine/openmp.cc
@@ -53,7 +53,7 @@ OpenMP::OpenMP()
       omp_set_num_threads(omp_thread_max_);
     } else {
       omp_thread_max_ = omp_get_max_threads();
-  }
+    }
   }
   omp_set_nested(dmlc::GetEnv("OMP_NESTED", false));
   omp_set_dynamic(dmlc::GetEnv("OMP_DYNAMIC", false));
diff --git a/src/engine/threaded_engine.h b/src/engine/threaded_engine.h
index 3cf6653..e000a22 100644
--- a/src/engine/threaded_engine.h
+++ b/src/engine/threaded_engine.h
@@ -297,25 +297,6 @@ class ThreadedEngine : public Engine {
     finished_cv_.notify_all();
   }
 
-  /*! \brief Return default OMP thread count. Currently, this is whatever OMP shows as number
-   * of procs
-   * \warning Do not call this in any performance-sensitive use-case since checking the environment
-   * is slow
-   */
-  static int DefaultOMPThreadsPerWorker() {
-#ifdef _OPENMP
-    // If OMP_NUM_THREADS is set, use omp_get_max_threads(), which will be the value
-    // interpreted by the implemetation from the OMP_NUM_THREADS environment variable.
-    // Otherwise, return the number of processors, not counting hyperthreading.
-    // Test for set OMP_NUM_THREADS by checking against some nonsensical value
-    const int max_threads = dmlc::GetEnv("OMP_NUM_THREADS", INT_MIN) == INT_MIN ?
-                            omp_get_num_procs() : omp_get_max_threads();
-    return max_threads;
-#else
-    return 1;
-#endif
-  }
-
  protected:
   /*!
    * \brief Push the opr block to execution queue to be executed.
@@ -383,13 +364,6 @@ class ThreadedEngine : public Engine {
     }
   }
 
-  /*! \brief Return the number of OMP threads that should be used per worker
-   * \return Number of OMP threads that should be used per worker
-   */
-  int num_omp_threads_per_worker() const override {
-    return OpenMP::Get()->GetRecommendedOMPThreadCount();
-  }
-
  private:
   /*!
    * \brief check if thee is duplication in const_vars and mutable_vars.
diff --git a/src/operator/mxnet_op.h b/src/operator/mxnet_op.h
index 564ad81..b2d5011 100644
--- a/src/operator/mxnet_op.h
+++ b/src/operator/mxnet_op.h
@@ -30,6 +30,7 @@
 #include <mxnet/engine.h>
 #include <mxnet/op_attr_types.h>
 #include <algorithm>
+#include "../engine/openmp.h"
 
 #ifdef __CUDACC__
 #include "../common/cuda_utils.h"
@@ -341,50 +342,30 @@ struct op_with_req {
   }
 };
 
-/*!
- * \brief Set to immediate scalar value kernel
- * \tparam val Scalar immediate
- */
-template<int val>
-struct set_to_int {
-  // mxnet_op version (when used directly with Kernel<>::Launch()) */
-  template<typename DType>
-  MSHADOW_XINLINE static void Map(int i, DType* out) {
-    out[i] = DType(val);
-  }
-  // mshadow_op version (when used with op_with_req<>)
-  MSHADOW_XINLINE static int Map() {
-    return val;
-  }
-};
-
-/*! \brief Special-case kernel shortcut for setting to zero */
-using set_zero = set_to_int<0>;
-
 template<typename OP, typename xpu>
 struct Kernel;
 
-
 template<typename OP>
 struct Kernel<OP, cpu> {
+  /*! \brief Launch CPU kernel */
   template<typename ...Args>
-  inline static void Launch(mshadow::Stream<cpu> *s, const int N, Args... args) {
+  inline static void Launch(mshadow::Stream<cpu> *, const int N, Args... args) {
 #ifdef _OPENMP
-    const int omp_cores = Engine::Get()->num_omp_threads_per_worker();
-    if (omp_cores <= 1) {
+    const int omp_threads = engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
+    if (omp_threads < 2) {
       // Zero means not to use OMP, but don't interfere with external OMP behavior
       for (int i = 0; i < N; ++i) {
         OP::Map(i, args...);
       }
     } else {
-      #pragma omp parallel for num_threads(omp_cores)
+      #pragma omp parallel for num_threads(omp_threads)
       for (int i = 0; i < N; ++i) {
         OP::Map(i, args...);
       }
     }
 #else
     for (int i = 0; i < N; ++i) {
-        OP::Map(i, args...);
+      OP::Map(i, args...);
     }
 #endif
   }
@@ -408,7 +389,6 @@ struct Kernel<OP, cpu> {
   }
 };
 
-
 #ifdef __CUDACC__
 template<typename OP, typename ...Args>
 __global__ void mxnet_generic_kernel(int N, Args... args) {
@@ -426,6 +406,7 @@ __global__ void mxnet_generic_kernel_ex(int N, Args... args) {
 
 template<typename OP>
 struct Kernel<OP, gpu> {
+  /*! \brief Launch GPU kernel */
   template<typename ...Args>
   inline static void Launch(mshadow::Stream<gpu> *s, int N, Args... args) {
     using namespace mshadow::cuda;
@@ -446,7 +427,30 @@ struct Kernel<OP, gpu> {
 };
 #endif  // __CUDACC__
 
+/*!
+ * \brief Set to immediate scalar value kernel
+ * \tparam val Scalar immediate
+ */
+template<int val>
+struct set_to_int {
+  // mxnet_op version (when used directly with Kernel<>::Launch()) */
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(int i, DType *out) {
+    out[i] = DType(val);
+  }
+  // mshadow_op version (when used with op_with_req<>)
+  MSHADOW_XINLINE static int Map() {
+    return val;
+  }
+};
+
+/*!
+ * \brief Special-case kernel shortcut for setting to zero and one
+ */
+using set_zero = set_to_int<0>;
+using set_one  = set_to_int<1>;
 }  // namespace mxnet_op
 }  // namespace op
 }  // namespace mxnet
+
 #endif  // MXNET_OPERATOR_MXNET_OP_H_
diff --git a/src/operator/optimizer_op-inl.h b/src/operator/optimizer_op-inl.h
index 1c5e1c6..61b97ba 100644
--- a/src/operator/optimizer_op-inl.h
+++ b/src/operator/optimizer_op-inl.h
@@ -264,7 +264,7 @@ inline void SGDMomUpdate(const nnvm::NodeAttrs& attrs,
       grad.dptr_, static_cast<DType>(param.clip_gradient), static_cast<DType>(param.momentum),
       static_cast<DType>(param.lr), static_cast<DType>(param.wd),
       static_cast<DType>(param.rescale_grad), req[0]);
-  });
+    });
 }
 
 template<int n_in, int n_out, int total_in>
diff --git a/src/operator/tensor/elemwise_binary_op.h b/src/operator/tensor/elemwise_binary_op.h
index b8b5bd1..9c8f180 100644
--- a/src/operator/tensor/elemwise_binary_op.h
+++ b/src/operator/tensor/elemwise_binary_op.h
@@ -33,8 +33,10 @@
 #include <algorithm>
 #include "../mxnet_op.h"
 #include "../mshadow_op.h"
+#include "../../engine/openmp.h"
 #include "elemwise_unary_op.h"
 #include "../../common/utils.h"
+#include "./init_op.h"
 
 namespace mxnet {
 namespace op {
@@ -42,23 +44,6 @@ namespace op {
 /*! Gather binary operator functions into ElemwiseBinaryOp class */
 class ElemwiseBinaryOp : public OpBase {
  public:
-  template<typename OP, int Req>
-  struct BackwardUseNoneOp {
-    template<typename DType>
-    MSHADOW_XINLINE static void Map(int i, DType *igrad, const DType *ograd) {
-      KERNEL_ASSIGN(igrad[i], Req, OP::Map(ograd[i]));
-    }
-  };
-
-  template<typename OP, int Req>
-  struct BackwardUseInOp {
-    template<typename DType>
-    MSHADOW_XINLINE static void Map(int i, DType *igrad,
-                                    const DType *ograd, const DType *lhs, const DType *rhs) {
-      KERNEL_ASSIGN(igrad[i], Req, ograd[i] * OP::Map(lhs[i], rhs[i]));
-    }
-  };
-
   /*! \brief For sparse, assume missing rvalue is 0 */
   template<typename OP, int Req>
   struct MissingRValueOp {
@@ -89,25 +74,22 @@ class ElemwiseBinaryOp : public OpBase {
    * \brief Fill contiguous dense output rows with value computed from 0 lhs and 0 rhs input
    *        CPU-Only version
    */
-  template<typename DType, typename OP>
-  static inline size_t FillDense(mshadow::Stream<cpu> *s,
+  template<typename DType, typename OP, typename xpu>
+  static inline size_t FillDense(mshadow::Stream<xpu> *s,
                                  const size_t idx_l,
                                  const size_t idx_r,
                                  const OpReqType req,
-                                 mshadow::Tensor<cpu, 2, DType> *out,
+                                 mshadow::Tensor<xpu, 2, DType> *out,
                                  const size_t iter_out) {
-    const int index_out_min = std::min(idx_l, idx_r);
+    const int index_out_min = static_cast<int>(std::min(idx_l, idx_r));
     if (static_cast<size_t>(index_out_min) > iter_out) {
-      const size_t size = (*out)[iter_out].shape_.Size();
       const DType zero_input_val = OP::Map(DType(0), DType(0));
-      #pragma omp parallel for
-      for (int i = iter_out; i < index_out_min; ++i) {
-        MXNET_ASSIGN_REQ_SWITCH(req, Req, {
-          SerialLaunchCPU<OpBase::set_to_scalar<Req>>(s, size, (*out)[i].dptr_, zero_input_val);
-        });
+      #pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount())
+      for (int i = static_cast<int>(iter_out); i < index_out_min; ++i) {
+        Fill<false>(s, (*out)[i], req, zero_input_val);
       }
     }
-    return index_out_min;
+    return static_cast<size_t>(index_out_min);  // MSVC wants OMP loops to always use 'int'
   }
 
   static inline bool IsSameArray(const NDArray& a1, const NDArray& a2) {
@@ -135,7 +117,7 @@ class ElemwiseBinaryOp : public OpBase {
     } else if (req[0] != kNullOp) {
       DType *lgrad_dptr = outputs[0].dptr<DType>();
       MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
-        Kernel<BackwardUseNoneOp<LOP, Req>, xpu>::Launch(s, size, lgrad_dptr, ograd_dptr);
+        Kernel<mxnet_op::op_with_req<LOP, Req>, xpu>::Launch(s, size, lgrad_dptr, ograd_dptr);
       });
     }
     if (std::is_same<ROP, mshadow_op::identity>::value && req[1] == kWriteInplace) {
@@ -143,7 +125,7 @@ class ElemwiseBinaryOp : public OpBase {
     } else if (req[1] != kNullOp) {
       DType *rgrad_dptr = outputs[1].dptr<DType>();
       MXNET_ASSIGN_REQ_SWITCH(req[1], Req, {
-        Kernel<BackwardUseNoneOp<ROP, Req>, xpu>::Launch(s, size, rgrad_dptr, ograd_dptr);
+        Kernel<mxnet_op::op_with_req<ROP, Req>, xpu>::Launch(s, size, rgrad_dptr, ograd_dptr);
       });
     }
   }
@@ -165,14 +147,14 @@ class ElemwiseBinaryOp : public OpBase {
         (outputs[0].Size() + mxnet_op::DataType<DType>::kLanes - 1)
         / mxnet_op::DataType<DType>::kLanes);
       DType * lgrad_dptr = outputs[0].dptr<DType>();
-      mxnet_op::Kernel<BackwardUseInOp<LOP, Req>, xpu>::Launch(
+      mxnet_op::Kernel<mxnet_op::op_with_req<mxnet_op::backward_grad<LOP>, Req>, xpu>::Launch(
         s, size, lgrad_dptr, ograd_dptr, lhs_dptr, rhs_dptr);});
     MXNET_ASSIGN_REQ_SWITCH(req[1], Req, {
       const int size = static_cast<int>(
         (outputs[1].Size() + mxnet_op::DataType<DType>::kLanes - 1)
         / mxnet_op::DataType<DType>::kLanes);
       DType * rgrad_dptr = outputs[1].dptr<DType>();
-      mxnet_op::Kernel<BackwardUseInOp<ROP, Req>, xpu>::Launch(
+      mxnet_op::Kernel<mxnet_op::op_with_req<mxnet_op::backward_grad<ROP>, Req>, xpu>::Launch(
         s, size, rgrad_dptr, ograd_dptr, lhs_dptr, rhs_dptr);});
   }
 
@@ -503,10 +485,7 @@ class ElemwiseBinaryOp : public OpBase {
         CHECK_EQ(outputs[0].storage_type(), in_stype);
         // rsp -> rsp, _. op requires 0-input returns 0-output
         DCHECK_LT(fabs(static_cast<float>(LOP::Map(0))), 1e-5f);
-        MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
-          UnaryOp::KernelComputeEx<xpu, BackwardUseNoneOp<LOP, Req>>(attrs, ctx, inputs,
-                                                                     req, {outputs[0]});
-        });
+        UnaryOp::ComputeEx<xpu, LOP>(attrs, ctx, inputs, req, {outputs[0]});
       } else {
         LOG(FATAL) << "Not implemented: " << operator_string(attrs, ctx, inputs, req, outputs);
       }
@@ -517,10 +496,7 @@ class ElemwiseBinaryOp : public OpBase {
         CHECK_EQ(outputs[0].storage_type(), in_stype);
         // rsp -> _, rsp. op requires 0-input returns 0-output
         DCHECK_LT(fabs(static_cast<float>(ROP::Map(0))), 1e-5f);
-        MXNET_ASSIGN_REQ_SWITCH(req[1], Req, {
-          UnaryOp::KernelComputeEx<xpu, BackwardUseNoneOp<ROP, Req>>(attrs, ctx, inputs,
-                                                                     req, {outputs[1]});
-        });
+        UnaryOp::ComputeEx<xpu, ROP>(attrs, ctx, inputs, req, {outputs[1]});
       } else {
         LOG(FATAL) << "Not implemented: " << operator_string(attrs, ctx, inputs, req, outputs);
       }
diff --git a/src/operator/tensor/elemwise_binary_scalar_op.h b/src/operator/tensor/elemwise_binary_scalar_op.h
index b866a29..27d8ed3 100644
--- a/src/operator/tensor/elemwise_binary_scalar_op.h
+++ b/src/operator/tensor/elemwise_binary_scalar_op.h
@@ -66,7 +66,7 @@ class BinaryScalarOp : public UnaryOp {
         const int64_t dense_block_count = next_input_row - output_row;
         if (dense_block_count > 0) {
           MXNET_ASSIGN_REQ_SWITCH(req, Req, {
-            mxnet_op::Kernel<OpBase::set_to_scalar<Req>, cpu>::Launch(
+            mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::identity, Req>, cpu>::Launch(
               stream,
               items_per_row * dense_block_count,
               output_data.dptr_ + items_per_row * output_row,
@@ -237,11 +237,8 @@ class BinaryScalarOp : public UnaryOp {
     const double alpha = nnvm::get<double>(attrs.parsed);
     MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
       MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
-        mxnet_op::Kernel<mxnet_op::op_with_req<OP, Req>, xpu>::Launch(s,
-                                                                      inputs[0].Size(),
-                                                                      outputs[0].dptr<DType>(),
-                                                                      inputs[0].dptr<DType>(),
-                                                                      DType(alpha));
+        mxnet_op::Kernel<mxnet_op::op_with_req<OP, Req>, xpu>::Launch(
+          s, inputs[0].Size(), outputs[0].dptr<DType>(), inputs[0].dptr<DType>(), DType(alpha));
       });
     });
   }
@@ -286,10 +283,13 @@ class BinaryScalarOp : public UnaryOp {
     Stream<xpu> *s = ctx.get_stream<xpu>();
     const double alpha = nnvm::get<double>(attrs.parsed);
     MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
-      Tensor<xpu, 1, DType> igrad = outputs[0].FlatTo1D<xpu, DType>(s);
-      Tensor<xpu, 1, DType> ograd = inputs[0].FlatTo1D<xpu, DType>(s);
-      Tensor<xpu, 1, DType> lhs = inputs[1].FlatTo1D<xpu, DType>(s);
-      ASSIGN_DISPATCH(igrad, req[0], ograd * F<OP>(lhs, scalar<DType>(DType(alpha))));
+      MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
+        mxnet::op::mxnet_op::Kernel<mxnet::op::mxnet_op::op_with_req<
+          mxnet::op::mxnet_op::backward_grad<OP>, Req>, xpu>::
+          Launch(s, inputs[0].Size(), outputs[0].dptr<DType>(),
+                 inputs[0].dptr<DType>(), inputs[1].dptr<DType>(),
+                 DType(alpha));
+      });
     });
   }
 };
diff --git a/src/operator/tensor/elemwise_unary_op.h b/src/operator/tensor/elemwise_unary_op.h
index d455b7e..6fbde05 100644
--- a/src/operator/tensor/elemwise_unary_op.h
+++ b/src/operator/tensor/elemwise_unary_op.h
@@ -274,25 +274,6 @@ class UnaryOp : public OpBase {
   }
 
   template<typename xpu, typename op>
-  static void KernelCompute(const nnvm::NodeAttrs& attrs,
-                            const OpContext& ctx,
-                            const std::vector<TBlob>& inputs,
-                            const std::vector<OpReqType>& req,
-                            const std::vector<TBlob>& outputs) {
-    using namespace mshadow;
-    using namespace mxnet_op;
-    Stream<xpu> *s = ctx.get_stream<xpu>();
-    CHECK_EQ(inputs.size(), 1U);
-    CHECK_EQ(outputs.size(), 1U);
-    if (req[0] != kNullOp) {
-      MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
-        Kernel<op, xpu>::Launch(s, outputs[0].Size(),
-                                outputs[0].dptr<DType>(), inputs[0].dptr<DType>());
-      });
-    }
-  }
-
-  template<typename xpu, typename op>
   static void ComputeWithHalf2(const nnvm::NodeAttrs &attrs,
                                const OpContext &ctx,
                                const std::vector<TBlob> &inputs,
@@ -309,25 +290,6 @@ class UnaryOp : public OpBase {
     });
   }
 
-  template<typename xpu, typename OP>
-  static void KernelComputeEx(const nnvm::NodeAttrs& attrs,
-                              const OpContext& ctx,
-                              const std::vector<NDArray>& inputs,
-                              const std::vector<OpReqType>& req,
-                              const std::vector<NDArray>& outputs) {
-    CHECK_EQ(inputs.size(), 1U);
-    CHECK_EQ(outputs.size(), 1U);
-    const auto in_stype = inputs[0].storage_type();
-    const auto out_stype = outputs[0].storage_type();
-    if (in_stype == out_stype && (in_stype == kRowSparseStorage || in_stype == kCSRStorage)) {
-      if (inputs[0].storage_shape().Size()) {
-        MapToFCompute<xpu>(attrs, ctx, inputs, req, outputs, KernelCompute<xpu, OP>);
-      }
-    } else {
-      LOG(FATAL) << "Not implemented: " << operator_string(attrs, ctx, inputs, req, outputs);
-    }
-  }
-
   template<typename xpu>
   static void IdentityCompute(const nnvm::NodeAttrs& attrs,
                               const OpContext& ctx,
@@ -395,13 +357,9 @@ class UnaryOp : public OpBase {
   }
 };
 
+/*! \brief Map legacy unary_bwd to backward_grad */
 template<typename GRAD_OP>
-struct unary_bwd {
-  template<typename DType>
-  MSHADOW_XINLINE static DType Map(DType a, DType b) {
-    return a * GRAD_OP::Map(b);
-  }
-};
+using unary_bwd = ::mxnet::op::mxnet_op::backward_grad<GRAD_OP>;
 
 struct CastParam : public dmlc::Parameter<CastParam> {
   // use int for enumeration
@@ -445,37 +403,6 @@ void CastCompute(const nnvm::NodeAttrs& attrs,
   });
 }
 
-namespace kernel_launch_op {
-/*! \brief sigmoid unit */
-struct sigmoid {
-  template<typename DType>
-  MSHADOW_XINLINE static void Map(int i, DType *out,
-                                  const DType *in) {
-    out[i] = mshadow_op::sigmoid::Map<DType>(in[i]);
-  }
-};
-struct sigmoid_grad {
-  template<typename DType>
-  MSHADOW_XINLINE static DType Map(DType out_grad, DType in) {
-    return out_grad * mshadow_op::sigmoid_grad::Map<DType>(in);
-  }
-};
-/*! \brief Rectified Linear Operation */
-struct relu {
-  template<typename DType>
-  MSHADOW_XINLINE static void Map(int i, DType *out,
-                                  const DType *in) {
-    out[i] = mshadow_op::relu::Map<DType>(in[i]);
-  }
-};
-struct relu_grad {
-  template<typename DType>
-  MSHADOW_XINLINE static DType Map(DType out_grad, DType in) {
-    return out_grad * mshadow_op::relu_grad::Map<DType>(in);
-  }
-};
-}  // namespace kernel_launch_op
-
 /*! \brief Unary compute */
 #define MXNET_OPERATOR_REGISTER_UNARY(__name$)                      \
   NNVM_REGISTER_OP(__name$)                                         \
diff --git a/src/operator/tensor/elemwise_unary_op_basic.cc b/src/operator/tensor/elemwise_unary_op_basic.cc
index c356c58..916c385 100644
--- a/src/operator/tensor/elemwise_unary_op_basic.cc
+++ b/src/operator/tensor/elemwise_unary_op_basic.cc
@@ -83,13 +83,12 @@ The storage type of ``relu`` output depends upon the input storage type:
 
 )code" ADD_FILELINE)
 .set_attr<FInferStorageType>("FInferStorageType", ElemwiseStorageType<1, 1, false, true, false>)
-.set_attr<FCompute>("FCompute<cpu>", UnaryOp::KernelCompute<
-  cpu, kernel_launch_op::relu>)
-.set_attr<FComputeEx>("FComputeEx<cpu>", UnaryOp::KernelComputeEx<
-  cpu, kernel_launch_op::relu>)
+.set_attr<FCompute>("FCompute<cpu>", UnaryOp::Compute<cpu, mshadow_op::relu>)
+.set_attr<FComputeEx>("FComputeEx<cpu>", UnaryOp::ComputeEx<cpu, mshadow_op::relu>)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_relu"});
 
-MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU(_backward_relu, kernel_launch_op::relu_grad);
+MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU(_backward_relu,
+                                               unary_bwd<mshadow_op::relu_grad>);
 
 // sigmoid
 MXNET_OPERATOR_REGISTER_UNARY(sigmoid)
@@ -102,11 +101,11 @@ MXNET_ADD_SPARSE_OP_ALIAS(sigmoid)
 The storage type of ``sigmoid`` output is always dense
 
 )code" ADD_FILELINE)
-.set_attr<FCompute>("FCompute<cpu>", UnaryOp::KernelCompute<
-  cpu, kernel_launch_op::sigmoid>)
+.set_attr<FCompute>("FCompute<cpu>", UnaryOp::Compute<cpu, mshadow_op::sigmoid>)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseOut{"_backward_sigmoid"});
 
-MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU(_backward_sigmoid, kernel_launch_op::sigmoid_grad);
+MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU(_backward_sigmoid,
+                                               unary_bwd<mshadow_op::sigmoid_grad>);
 
 // copy
 MXNET_OPERATOR_REGISTER_UNARY(_copy)
diff --git a/src/operator/tensor/elemwise_unary_op_basic.cu b/src/operator/tensor/elemwise_unary_op_basic.cu
index 3f982a2..41eef90 100644
--- a/src/operator/tensor/elemwise_unary_op_basic.cu
+++ b/src/operator/tensor/elemwise_unary_op_basic.cu
@@ -26,18 +26,19 @@
 namespace mxnet {
 namespace op {
 NNVM_REGISTER_OP(relu)
-.set_attr<FCompute>("FCompute<gpu>", UnaryOp::KernelCompute<gpu, kernel_launch_op::relu>)
-.set_attr<FComputeEx>("FComputeEx<gpu>", UnaryOp::KernelComputeEx<gpu, kernel_launch_op::relu>);
+.set_attr<FCompute>("FCompute<gpu>", UnaryOp::Compute<gpu, mshadow_op::relu>)
+.set_attr<FComputeEx>("FComputeEx<gpu>", UnaryOp::ComputeEx<gpu, mshadow_op::relu>);
 
 NNVM_REGISTER_OP(_backward_relu)
-.set_attr<FCompute>("FCompute<gpu>", ElemwiseBinaryOp::Compute<gpu, kernel_launch_op::relu_grad>);
+.set_attr<FCompute>("FCompute<gpu>", ElemwiseBinaryOp::Compute<
+  gpu, unary_bwd<mshadow_op::relu_grad>>);
 
 NNVM_REGISTER_OP(sigmoid)
-.set_attr<FCompute>("FCompute<gpu>", UnaryOp::KernelCompute<gpu, kernel_launch_op::sigmoid>);
+.set_attr<FCompute>("FCompute<gpu>", UnaryOp::Compute<gpu, mshadow_op::sigmoid>);
 
 NNVM_REGISTER_OP(_backward_sigmoid)
 .set_attr<FCompute>("FCompute<gpu>", ElemwiseBinaryOp::Compute<
-  gpu, kernel_launch_op::sigmoid_grad>);
+  gpu, unary_bwd<mshadow_op::sigmoid_grad>>);
 
 // copy
 NNVM_REGISTER_OP(_copy)
diff --git a/src/operator/tensor/indexing_op.h b/src/operator/tensor/indexing_op.h
index 6a27b72..7624f2d 100644
--- a/src/operator/tensor/indexing_op.h
+++ b/src/operator/tensor/indexing_op.h
@@ -44,6 +44,7 @@
 #include "./dot-inl.h"
 #include "./init_op.h"
 #include "./matrix_op-inl.h"
+#include "../../engine/openmp.h"
 
 namespace mxnet {
 namespace op {
@@ -657,7 +658,7 @@ inline void SparseEmbeddingOpBackwardRspImpl(const OpContext& ctx,
         DType* grad_data = output.data().dptr<DType>();
         Kernel<set_zero, cpu>::Launch(s, nnr * row_length, grad_data);
         // add the final gradients
-        int num_threads = Engine::Get()->num_omp_threads_per_worker();
+        const int num_threads = engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
         dim_t segment_len = (nnr + num_threads - 1) / num_threads;
         Kernel<AddTakeGradRspKernel, cpu>::Launch(s, num_threads, grad_data, prefix_sum,
                                                   ograd.dptr<DType>(), row_length,
diff --git a/src/operator/tensor/init_op.h b/src/operator/tensor/init_op.h
index bb6d3c1..c621f6e 100644
--- a/src/operator/tensor/init_op.h
+++ b/src/operator/tensor/init_op.h
@@ -32,6 +32,7 @@
 #include <vector>
 #include <string>
 #include <limits>
+#include "../mshadow_op.h"
 #include "../elemwise_op_common.h"
 #include "../mxnet_op.h"
 #include "../mshadow_op.h"
@@ -225,7 +226,7 @@ void Fill(mshadow::Stream<xpu> *s, const TBlob& b, const OpReqType req, ValueTyp
       // Optimize common use-case of filling with ones
       MSHADOW_TYPE_SWITCH(b.type_flag_, DType, {
         MXNET_ASSIGN_REQ_SWITCH(req, Req, {
-          mxnet_op::Kernel<mxnet_op::op_with_req<mxnet_op::set_to_int<1>, Req>, xpu>::Launch(
+          mxnet_op::Kernel<mxnet_op::op_with_req<mxnet_op::set_one, Req>, xpu>::Launch(
             s, b.Size(), b.dptr<DType>());
         });
       });
diff --git a/tests/cpp/include/test_core_op.h b/tests/cpp/include/test_core_op.h
index 21d0776..c454c95 100644
--- a/tests/cpp/include/test_core_op.h
+++ b/tests/cpp/include/test_core_op.h
@@ -33,7 +33,24 @@ namespace op {
 // Tried making this a struct w/constexpr, but getting undefined reference on gcc 5.4.1
 #define COREOP_FWD_OP_NAME_KEY          "fwd_op_name"
 #define COREOP_BWD_OP_NAME_KEY          "bwd_op_name"
-#define COREOP_BWD_OP_NAME_VALUE_NONE   "<none>"
+#define COREOP_BWD_OP_NAME_VALUE_NONE   "[none]"
+
+enum TimingDirection {
+  Forward,
+  Backward
+};
+
+inline const char *TimingDirectionAsString(const TimingDirection td) {
+  switch (td) {
+    case Forward:
+      return "Forward";
+    case Backward:
+      return "Backward";
+    default:
+      CHECK(false) << "Unknown timing direction: " << static_cast<int>(td);
+      return "<unknown>";
+  }
+}
 
 /*!
  * Low-noise operator executor
@@ -43,11 +60,6 @@ template<typename DType>
 class CoreOpExecutor : public test::op::OperatorDataInitializer<DType>
   , public test::op::OperatorExecutorTiming {
   /*! \brief Performance timing categories */
-  enum TimingId {
-    Forward,
-    Backward
-  };
-
   /*!
    * \brief Access data blob as if on the CPU via a callback
    * \tparam Type of callback Function to call with CPU-data NDArray
@@ -92,8 +104,8 @@ class CoreOpExecutor : public test::op::OperatorDataInitializer<DType>
     values.reserve(count);
     for (kwargs_t::const_iterator i_iter = args.begin(), e_iter = args.end();
          i_iter != e_iter; ++i_iter) {
-      keys.push_back(i_iter->first.c_str());
-      values.push_back(i_iter->second.c_str());
+      keys.emplace_back(i_iter->first.c_str());
+      values.emplace_back(i_iter->second.c_str());
     }
     return imperative::ParseAttrs(op, op->num_inputs, count, &keys[0], &values[0]);
   }
@@ -108,7 +120,7 @@ class CoreOpExecutor : public test::op::OperatorDataInitializer<DType>
                                                  std::vector<TBlob> *dest) {
     dest->reserve(dest->size() + src.size());
     for (size_t i = 0, n = src.size(); i < n; ++i) {
-      dest->push_back(src[i].data());
+      dest->emplace_back(src[i].data());
     }
     return *dest;
   }
@@ -194,9 +206,9 @@ class CoreOpExecutor : public test::op::OperatorDataInitializer<DType>
       for (const ResourceRequest& req : reqs) {
         if (req.type == ResourceRequest::kTempSpace) {
           Resource r = ResourceManager::Get()->Request(ctx->run_ctx.ctx, req);
-          requested.push_back(r);
+          requested.emplace_back(r);
         } else if (req.type == ResourceRequest::kRandom) {
-          requested.push_back(ResourceManager::Get()->Request(ctx->run_ctx.ctx, req));
+          requested.emplace_back(ResourceManager::Get()->Request(ctx->run_ctx.ctx, req));
         } else {
           LOG(FATAL) << "resource type not yet supported";
         }
@@ -216,7 +228,7 @@ class CoreOpExecutor : public test::op::OperatorDataInitializer<DType>
     new_args.reserve(args.size() + 1);
     for (const auto& a : args) {
       if (a.first != COREOP_FWD_OP_NAME_KEY && a.first != COREOP_BWD_OP_NAME_KEY) {
-        new_args.push_back(a);
+        new_args.emplace_back(a);
       }
     }
     new_args.push_back({ COREOP_FWD_OP_NAME_KEY, fwd_op_name});
@@ -241,7 +253,7 @@ class CoreOpExecutor : public test::op::OperatorDataInitializer<DType>
       } else if (a.first == COREOP_BWD_OP_NAME_KEY) {
         *bwd_op_name_ptr = a.second;
       } else {
-        new_args.push_back(a);
+        new_args.emplace_back(a);
       }
     }
     return new_args;
@@ -317,7 +329,7 @@ class CoreOpExecutor : public test::op::OperatorDataInitializer<DType>
       // operators such as dot
       std::vector<TShape> shapes;
       for (size_t i = 0, n = std::max(num_visible_outputs, num_inputs); i < n; ++i) {
-        shapes.push_back(i < input_shapes_.size() ? input_shapes_[i]
+        shapes.emplace_back(i < input_shapes_.size() ? input_shapes_[i]
                                                   : input_shapes_[input_shapes_.size() - 1]);
       }
       std::vector<NDArray *> inputs_p, outputs_p;
@@ -331,21 +343,21 @@ class CoreOpExecutor : public test::op::OperatorDataInitializer<DType>
       outputs_.reserve(num_visible_outputs);
       outputs_p.reserve(num_visible_outputs);
 
-      for (int i = 0; i < num_inputs; ++i) {
+      for (size_t i = 0; i < static_cast<size_t>(num_inputs); ++i) {
         CHECK_LT(i, static_cast<int>(shapes.size()));
-        inputs_.push_back(i < inputs.size() ? inputs[i] : CreateRandArray(shapes[i],
+        inputs_.emplace_back(i < inputs.size() ? inputs[i] : CreateRandArray(shapes[i],
                                                                           ctx_.run_ctx.ctx));
-        inputs_p.push_back(&*inputs_.rbegin());
+        inputs_p.emplace_back(&*inputs_.rbegin());
       }
 
-      for (int i = 0; i < num_visible_outputs; ++i) {
+      for (size_t i = 0; i < static_cast<size_t>(num_visible_outputs); ++i) {
         // If supplied and valid, pass from the supplied outputs vector
         // Otherwise use empty for forward pass, or zero-filled for backward pass
-        outputs_.push_back(i < outputs.size()
-                           ? outputs[i]
-                           : (backward_for_op ? CreateZeroArray(shapes[i], ctx_.run_ctx.ctx)
-                                              : NDArray()));
-        outputs_p.push_back(&*outputs_.rbegin());
+        outputs_.emplace_back(i < outputs.size()
+                              ? outputs[i]
+                              : (backward_for_op ? CreateZeroArray(shapes[i], ctx_.run_ctx.ctx)
+                                                 : NDArray()));
+        outputs_p.emplace_back(&*outputs_.rbegin());
       }
 
       if (!backward_for_op) {
@@ -396,7 +408,7 @@ class CoreOpExecutor : public test::op::OperatorDataInitializer<DType>
             << "Can't automatically determine backward op name. Please specify";
           for (std::pair<std::shared_ptr<CoreOpExecutor>, std::string> &bw_item : bwd) {
             bw_item.first->set_verbose(verbose_);
-            backward_.push_back(bw_item.first);
+            backward_.emplace_back(bw_item.first);
             bw_item.first->Init(ArgsWithOpName(args, bw_item.second), {}, {}, this);
           }
         }
diff --git a/tests/cpp/include/test_legacy_op.h b/tests/cpp/include/test_legacy_op.h
index 30bdf07..6d326fc 100644
--- a/tests/cpp/include/test_legacy_op.h
+++ b/tests/cpp/include/test_legacy_op.h
@@ -135,7 +135,7 @@ class LegacyOperatorExecutor : public OperatorDataInitializer<DType>
         // Get the resource of temporal space
         std::vector<TShape> inputShapes;
         for (size_t x = 0, n = shape_input_vec_.size(); x < n; ++x) {
-          inputShapes.push_back(shape_input_vec_[x]);
+          inputShapes.emplace_back(shape_input_vec_[x]);
         }
         allocateResources(opProp.ForwardResource(inputShapes));
 
@@ -408,11 +408,11 @@ class LegacyOperatorExecutor : public OperatorDataInitializer<DType>
 
     std::vector<std::vector<TBlob> *> all_blob_vects_;
     inline OpData() {
-      all_blob_vects_.push_back(&blob_input_vec_);
-      all_blob_vects_.push_back(&blob_output_vec_);
-      all_blob_vects_.push_back(&blob_aux_states_);
-      all_blob_vects_.push_back(&blob_in_grad_);
-      all_blob_vects_.push_back(&blob_out_grad_);  // Remaining err (loss) pushing back upstream
+      all_blob_vects_.emplace_back(&blob_input_vec_);
+      all_blob_vects_.emplace_back(&blob_output_vec_);
+      all_blob_vects_.emplace_back(&blob_aux_states_);
+      all_blob_vects_.emplace_back(&blob_in_grad_);
+      all_blob_vects_.emplace_back(&blob_out_grad_);  // Remaining err (loss) pushing back upstream
     }
     virtual ~OpData() {}
   };
@@ -495,14 +495,14 @@ class LegacyOperatorExecutor : public OperatorDataInitializer<DType>
     for (const ResourceRequest& req : reqs) {
       if (req.type == ResourceRequest::kTempSpace) {
         if (cached_temp.count(ctx) != 0) {
-          opContext_.requested.push_back(cached_temp.at(ctx));
+          opContext_.requested.emplace_back(cached_temp.at(ctx));
         } else {
           Resource r = ResourceManager::Get()->Request(ctx, req);
-          opContext_.requested.push_back(r);
+          opContext_.requested.emplace_back(r);
           cached_temp[ctx] = r;
         }
       } else if (req.type == ResourceRequest::kRandom) {
-        opContext_.requested.push_back(ResourceManager::Get()->Request(ctx, req));
+        opContext_.requested.emplace_back(ResourceManager::Get()->Request(ctx, req));
       } else {
         LOG(FATAL) << "resource type not yet supported";
       }
@@ -517,8 +517,8 @@ class LegacyOperatorExecutor : public OperatorDataInitializer<DType>
                              const int dtype) {
     test::StandaloneBlob *blob = new test::StandaloneBlob(shape, isGPU, dtype);
     CHECK_NE(blob, static_cast<TBlob *>(nullptr));
-    standalone_blobs->push_back(std::unique_ptr<test::StandaloneBlob>(blob));
-    (*dest).push_back(*blob);
+    standalone_blobs->emplace_back(std::unique_ptr<test::StandaloneBlob>(blob));
+    (*dest).emplace_back(*blob);
     return blob;
   }
 
diff --git a/tests/cpp/include/test_ndarray_utils.h b/tests/cpp/include/test_ndarray_utils.h
index bbc7c05..f5ab967 100644
--- a/tests/cpp/include/test_ndarray_utils.h
+++ b/tests/cpp/include/test_ndarray_utils.h
@@ -80,7 +80,7 @@ inline NDArray DnsND(const TShape shape, const Context ctx, std::vector<TEST_DTY
   // generate random values
   while (vs.size() < num_val) {
     auto v = RandFloat();
-    vs.push_back(v);
+    vs.emplace_back(v);
   }
   CHECK_EQ(vs.size(), nd.shape().Size());
   MSHADOW_TYPE_SWITCH(nd.dtype(), DType, {
diff --git a/tests/cpp/include/test_op.h b/tests/cpp/include/test_op.h
index 949f2cc..cbafe14 100644
--- a/tests/cpp/include/test_op.h
+++ b/tests/cpp/include/test_op.h
@@ -100,17 +100,27 @@ class OperatorDataInitializer {
    * \param blob Blob which to fill with random values
    */
   void FillRandom(const TBlob& blob) const {
-    std::uniform_real_distribution<DType> distribution(-1.0, 1.0);
-    test::patternFill<DType>(&blob, [this, &distribution]() -> DType {
-      return distribution(this->generator());
+    // Integral DTypes draw from [-128, 127]; floating-point DTypes from [-5, 5)
+    std::uniform_real_distribution<> dis_real(-5.0, 5.0);
+    std::uniform_int_distribution<> dis_int(-128, 127);
+    test::patternFill<DType>(&blob, [this, &dis_real, &dis_int]() -> DType {
+      // Keep drawing while the sample is zero (integral DType) or has
+      // magnitude below 1e-5 (floating-point DType)
+      const bool integral = std::is_integral<DType>::value;
+      DType val;
+      do {
+        if (integral) {
+          val = static_cast<DType>(dis_int(this->generator()));
+        } else {
+          val = static_cast<DType>(dis_real(this->generator()));
+        }
+      } while (integral ? !val : fabs(val) < 1e-5);
+      return val;
     });
   }
 
   void FillZero(const TBlob& blob) const {
-    std::uniform_real_distribution<DType> distribution(-1.0, 1.0);
-    test::patternFill<DType>(&blob, [this, &distribution]() -> DType {
-      return DType(0);
-    });
+    test::patternFill<DType>(&blob, [] { return DType(0); });  // every element becomes 0
   }
 
  private:
@@ -271,7 +281,7 @@ inline std::vector<TShape> ShapesOf(const std::vector<NDArray>& arrays) {
   std::vector<TShape> res;
   res.reserve(arrays.size());
   for (const NDArray& ar : arrays) {
-    res.push_back(ar.shape());
+    res.emplace_back(ar.shape());
   }
   return std::move(res);
 }
diff --git a/tests/cpp/include/test_op_runner.h b/tests/cpp/include/test_op_runner.h
index 4c7cd1d..eb25999 100644
--- a/tests/cpp/include/test_op_runner.h
+++ b/tests/cpp/include/test_op_runner.h
@@ -122,15 +122,14 @@ class OperatorRunner {
    * \param dim Data dimensions
    * \param count Number of times to run in each direction
    */
-  void TimingTest(const std::string& label,
-                  const bool isGPU,
-                  const bool stochastic,
-                  const test::op::kwargs_t& kwargs,
-                  int dim = 0,
-                  size_t count = 1,
-                  const std::vector<TShape>& timing_shapes = {}) {
-    std::cout << std::endl << std::flush;
-
+  std::unordered_map<int, perf::TimingInstrument::Info>
+  TimingTest(const std::string& label,
+             const bool isGPU,
+             const bool stochastic,
+             const test::op::kwargs_t& kwargs,
+             int dim = 0,
+             size_t count = 1,
+             const std::vector<TShape>& timing_shapes = {}) {
 #ifdef NDEBUG
     size_t COUNT = 50;
 #else
@@ -160,7 +159,7 @@ class OperatorRunner {
 
       if (timing_shapes.empty()) {
         do {
-          batchSize = stochastic ? test::rangedRand(1U, TES_BATCH_SIZE * 2U) : TIMING_BATCH_SIZE;
+          batchSize = stochastic ? test::rangedRand(1U, TEST_BATCH_SIZE * 2U) : TIMING_BATCH_SIZE;
           channels = stochastic ? test::rangedRand(1U, TEST_CHANNELS * 2U) : TIMING_CHANNELS;
           depth = stochastic ? test::rangedRand(1U, TEST_DEPTH * 2U) : TIMING_DEPTH;
           height = stochastic ? test::rangedRand(1U, TEST_DH * 2U) : TIMING_DH;
@@ -218,12 +217,18 @@ class OperatorRunner {
       }
     }
 
-    timing.print(&std::cout, label);
-    std::cout << std::endl << std::flush;
+    if (verbose_) {
+      timing.print(&std::cout, label);
+      std::cout << std::endl << std::flush;
+    }
+
+    return timing.data();
   }
 
+  void set_verbose(bool verbose) { verbose_ = verbose; }
+
  protected:
-  static constexpr int TES_BATCH_SIZE = 5;
+  static constexpr int TEST_BATCH_SIZE = 5;
   static constexpr int TEST_CHANNELS = 3;
   static constexpr int TEST_DEPTH = 2;
   static constexpr int TEST_DH = 2;
@@ -234,6 +239,8 @@ class OperatorRunner {
   static constexpr int TIMING_DEPTH = 2;
   static constexpr int TIMING_DH = 64;
   static constexpr int TIMING_DW = 64;
+  /*! \brief verbose output */
+  bool verbose_ = true;
 };
 
 }  // namespace test
diff --git a/tests/cpp/include/test_perf.h b/tests/cpp/include/test_perf.h
index b6f2145..7971ed7 100644
--- a/tests/cpp/include/test_perf.h
+++ b/tests/cpp/include/test_perf.h
@@ -45,7 +45,7 @@ namespace perf {
 inline uint64_t getMicroTickCount() {
 #ifndef _WIN32
   struct timeval tv;
-  gettimeofday(&tv, NULL);
+  gettimeofday(&tv, nullptr);
   return uint64_t(tv.tv_sec) * 1000000 + tv.tv_usec;
 #else
   LARGE_INTEGER CurrentTime;
@@ -79,11 +79,6 @@ inline uint64_t getNannoTickCount() {
 #endif
 }
 
-/*! \brief millisecond tick count */
-inline uint64_t getTickCount() {
-  return getMicroTickCount() / 1000;
-}
-
 #define MICRO2MS(__micro$)  (((__micro$) + 500)/1000)
 #define MICRO2MSF(__micro$) (static_cast<float>(__micro$)/1000)
 #define MICRO2MSF(__micro$) (static_cast<float>(__micro$)/1000)
@@ -100,7 +95,7 @@ class TimedScope {
   const size_t    count_;
 
  public:
-  explicit inline TimedScope(const char *msg = NULL, size_t count = 1, const bool start = true)
+  explicit inline TimedScope(const char *msg = nullptr, size_t count = 1, const bool start = true)
     : startTime_(start ? getMicroTickCount() : 0)
       , stopTime_(0)
       , count_(count) {
@@ -164,7 +159,7 @@ class TimingInstrument {
   }
   void startTiming(int id, const char *s) {
     std::unique_lock<std::recursive_mutex>  lk(mutex_);
-    std::unordered_map<int, Info>::iterator i = data_.find(id);
+    auto i = data_.find(id);
     if (i == data_.end()) {
       i = data_.emplace(std::make_pair(id, Info(s))).first;
     }
@@ -174,7 +169,7 @@ class TimingInstrument {
   }
   void stopTiming(int id, const size_t subIterationCount = 1) {
     std::unique_lock<std::recursive_mutex>  lk(mutex_);
-    std::unordered_map<int, Info>::iterator i = data_.find(id);
+    auto i = data_.find(id);
     CHECK_NE(i == data_.end(), true) << "Can't stop timing on an object that we don't know about";
     if (i != data_.end()) {
       CHECK_NE(i->second.nestingCount_, 0U) << "While stopping timing, invalid nesting count of 0";
@@ -188,7 +183,7 @@ class TimingInstrument {
   }
   uint64_t getDuration(int id) {
     std::unique_lock<std::recursive_mutex>  lk(mutex_);
-    std::unordered_map<int, Info>::iterator i = data_.find(id);
+    auto i = data_.find(id);
     if (i != data_.end()) {
       const Info&        info = i->second;
       const uint64_t duration = info.nestingCount_.load()
@@ -202,7 +197,7 @@ class TimingInstrument {
   bool isTiming(int id) {
     std::unordered_map<int, Info>::const_iterator i = data_.find(id);
     if (i != data_.end()) {
-      return !!i->second.nestingCount_.load();
+      return i->second.nestingCount_.load() > 0U;  // nonzero nesting count => timing in progress
     }
     return false;
   }
@@ -216,7 +211,7 @@ class TimingInstrument {
         i != e; ++i) {
       const Info&        info = i->second;
       const uint64_t duration = getDuration(i->first);
-      *os << /*std::endl <<*/ label_ << ": " << name_ << " Timing [" << info.name_ << "] "
+      *os << label_ << ": " << name_ << " Timing [" << info.name_ << "] "
           << (info.nestingCount_.load() ? "*" : "")
           << MICRO2MSF(duration) << " ms";
         if (info.cycleCount_.load()) {
@@ -233,7 +228,7 @@ class TimingInstrument {
 
   void reset() {
     std::unique_lock<std::recursive_mutex>  lk(mutex_);
-    for (std::unordered_map<int, Info>::iterator i = data_.begin(), e = data_.end();
+    for (auto i = data_.begin(), e = data_.end();
         i != e; ++i) {
       const int id = i->first;
       const bool wasTiming = isTiming(id);
@@ -250,9 +245,9 @@ class TimingInstrument {
   }
 
   TimingInstrument& operator += (const TimingInstrument& o) {
-    for (std::unordered_map<int, Info>::const_iterator i = o.data_.begin(), e = o.data_.end();
+    for (auto i = o.data_.begin(), e = o.data_.end();
         i != e; ++i) {
-      std::unordered_map<int, Info>::iterator j = data_.find(i->first);
+      auto j = data_.find(i->first);
       if (j != data_.end())  {
         const Info &oInfo = i->second;
         CHECK_EQ(oInfo.nestingCount_, 0U);
@@ -265,7 +260,6 @@ class TimingInstrument {
     return *this;
   }
 
- private:
   struct Info {
     explicit inline Info(const char *s)
       : name_(s ? s : "")
@@ -273,6 +267,7 @@ class TimingInstrument {
         , nestingCount_(0)
         , cycleCount_(0)
         , duration_(0) {}
+
     inline Info(const Info& o)
       : name_(o.name_)
         , baseTime_(o.baseTime_.load())
@@ -281,17 +276,36 @@ class TimingInstrument {
         , duration_(o.duration_.load()) {
       CHECK_EQ(o.nestingCount_, 0U);
     }
+
+    /*!
+     * \brief Average ms per cycle (duration_ is in microseconds); returns 0 if no cycles recorded
+     */
+    inline double TimeEach() const {
+      const uint64_t cycles = cycleCount_.load();
+      return cycles ? static_cast<double>(duration_.load()) / cycles / 1000.0 : 0.0;
+    }
+
     std::string           name_;
     std::atomic<uint64_t> baseTime_;
     std::atomic<uint64_t> nestingCount_;
     std::atomic<uint64_t> cycleCount_;  // Note that nesting may skew averages
     std::atomic<uint64_t> duration_;
   };
+
+  /*! \brief Map from timing id to its accumulated Info record */
+  using timing_map_t = std::unordered_map<int, Info>;
+
+  /*! \brief Read-only access to all timing records (note: does not lock mutex_) */
+  const timing_map_t& data() const { return data_; }
+
+ private:
   std::string                   name_;
   mutable std::recursive_mutex  mutex_;
   std::unordered_map<int, Info> data_;
 };
 
+using timing_map_t = TimingInstrument::timing_map_t;
+
 /*! \brief Accumulated scoped timing, indexed by ID */
 class TimingItem {
  public:
diff --git a/tests/cpp/include/test_util.h b/tests/cpp/include/test_util.h
index 95ab141..33ca3c4 100644
--- a/tests/cpp/include/test_util.h
+++ b/tests/cpp/include/test_util.h
@@ -609,6 +609,67 @@ inline ScalarType rangedRand(const ScalarType min, const ScalarType max) {
   return static_cast<ScalarType>(x / bin_size + min);
 }
 
+/*!
+ * \brief Strict weak ordering on TShape for sorted stl keys (map, set): by total
+ *        element count, then number of dimensions, then per-axis extents in order.
+ * \param s1 Left-hand shape
+ * \param s2 Right-hand shape
+ * \return true if s1 orders before s2
+ */
+inline bool operator < (const nnvm::TShape &s1, const nnvm::TShape &s2) {
+  const auto size1 = s1.Size(), size2 = s2.Size();
+  if (size1 != size2) {
+    return size1 < size2;
+  }
+  if (s1.ndim() != s2.ndim()) {
+    return s1.ndim() < s2.ndim();
+  }
+  for (size_t axis = 0; axis < s1.ndim(); ++axis) {
+    if (s1[axis] != s2[axis]) {
+      return s1[axis] < s2[axis];
+    }
+  }
+  return false;
+}
+
+/*!
+ * \brief Strict weak ordering on vectors of TShape for sorted stl keys (map, set):
+ *        shorter vectors first, then the first differing shape decides.
+ * \param v1 Left-hand vector of shapes
+ * \param v2 Right-hand vector of shapes
+ * \return true if v1 orders before v2
+ */
+inline bool operator < (const std::vector<nnvm::TShape>& v1, const std::vector<nnvm::TShape>& v2) {
+  if (v1.size() != v2.size()) {
+    return v1.size() < v2.size();
+  }
+  // Equal lengths: compare shape by shape
+  for (size_t idx = 0; idx < v1.size(); ++idx) {
+    if (!(v1[idx] == v2[idx])) {
+      return v1[idx] < v2[idx];
+    }
+  }
+  return false;
+}
+
+/*!
+ * \brief Comparator for vectors of shapes in stl sorted containers (same order as operator<)
+ */
+struct less_shapevect {
+  bool operator()(const std::vector<nnvm::TShape>& v1, const std::vector<nnvm::TShape>& v2) const {
+    if (v1.size() == v2.size()) {
+      for (size_t i = 0, n = v1.size(); i < n; ++i) {
+        if (v1[i] == v2[i]) {
+          continue;
+        }
+        return v1[i] < v2[i];
+      }
+      return false;
+    }
+    return v1.size() < v2.size();
+  }
+};
+
 inline std::string pretty_num(uint64_t val) {
   std::string res, s = std::to_string(val);
   size_t ctr = 0;
diff --git a/tests/cpp/operator/core_op_runner_test.cc b/tests/cpp/operator/runner/core_op_runner_test.cc
similarity index 100%
rename from tests/cpp/operator/core_op_runner_test.cc
rename to tests/cpp/operator/runner/core_op_runner_test.cc
diff --git a/tests/cpp/operator/slice_channel_perf.cc b/tests/cpp/operator/slice_channel_perf.cc
new file mode 100644
index 0000000..dc42d2a
--- /dev/null
+++ b/tests/cpp/operator/slice_channel_perf.cc
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *  \file slice_channel_perf.cc
+ *  \brief Perf/profile run of the SliceChannel operator
+ *  \author Chris Olivier
+ */
+
+#include <gtest/gtest.h>
+#include <mxnet/tensor_blob.h>
+#include "../include/test_op_runner.h"
+#include "../include/test_legacy_op.h"
+#include "../../src/operator/slice_channel-inl.h"
+
+using namespace mxnet;
+
+typedef std::vector<std::pair<std::string, std::string> > kwargs_t;
+/*! \brief Operator arguments common to every test in this file (currently none) */
+const kwargs_t basic_slice_channel_args = { };
+
+/*!
+ * \brief Generic bidirectional sanity test
+ */
+TEST(SLICE_CHANNEL_PERF, ExecuteBidirectional) {
+  TShape shape({1, 160, 200});
+  kwargs_t kwargs = basic_slice_channel_args;
+  kwargs.push_back({"num_outputs", "160"});
+  test::op::LegacyOpRunner<mxnet::op::SliceChannelProp, float, float> runner;
+  runner.RunBidirectional(false, { shape }, kwargs, 1);
+}
+
+/*!
+ * \brief SliceChannel timing test for CPU
+ */
+TEST(SLICE_CHANNEL_PERF, TimingCPU) {
+  kwargs_t kwargs = basic_slice_channel_args;
+  kwargs.push_back({"num_outputs", "160"});
+  test::op::LegacyOpRunner<mxnet::op::SliceChannelProp, float, float> runner;
+  runner.RunBidirectional(false,
+                          { TShape({1, 160, 200}) },
+                          kwargs, 1);  // prime code and cache
+  std::vector<TShape> shapes;
+  if (test::performance_run) {
+    shapes = {
+      {1, 160, 200},
+      {10, 160, 200},
+      {100, 160, 200},
+      {10, 160, 500},
+      {100, 160, 500}
+    };
+  } else {
+    shapes = {
+      {1, 160, 200},
+      {1, 160, 200}
+    };
+  }
+  for (const TShape &shape : shapes) {
+    runner.TimingTest("SliceChannel Operator CPU", false, false, kwargs, 2, 10, { shape });
+  }
+}
+
+#if MXNET_USE_CUDA == 1
+/*!
+ * \brief SliceChannel timing test for GPU
+ */
+TEST(SLICE_CHANNEL_PERF, TimingGPU) {
+  kwargs_t kwargs = basic_slice_channel_args;
+  kwargs.push_back({"num_outputs", "160"});
+  test::op::LegacyOpRunner<mxnet::op::SliceChannelProp, float, float> runner;
+  runner.RunBidirectional(true,
+                          { TShape({1, 160, 200}) },
+                          kwargs, 1);  // prime code and cache
+  std::vector<TShape> shapes = {
+    {1, 160, 200},
+    {1, 160, 200},
+    {1, 160, 200},
+    {1, 160, 200},
+    {1, 160, 200}
+  };
+  for (const TShape &shape : shapes) {
+    runner.TimingTest("SliceChannel Operator GPU", true, false, kwargs, 2, 10, { shape });
+  }
+}
+#endif  // MXNET_USE_CUDA == 1
+

-- 
To stop receiving notification emails like this one, please contact
"commits@mxnet.apache.org" <co...@mxnet.apache.org>.