You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by jx...@apache.org on 2017/11/11 20:05:59 UTC
[incubator-mxnet] branch master updated: Pull more optimize and
simplification changes from tuner branch (#8599)
This is an automated email from the ASF dual-hosted git repository.
jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 8ba5de8 Pull more optimize and simplification changes from tuner branch (#8599)
8ba5de8 is described below
commit 8ba5de841ef28354be7aca5347391c0f28441ba5
Author: Chris Olivier <cj...@gmail.com>
AuthorDate: Sat Nov 11 12:05:57 2017 -0800
Pull more optimize and simplification changes from tuner branch (#8599)
* Pull more optimize changes from tuner branch
* remove newline
* Move file
* Added slice_channel_perf.cc
---
docs/faq/new_op.md | 2 +-
include/mxnet/engine.h | 5 -
src/engine/naive_engine.cc | 7 --
src/engine/openmp.cc | 2 +-
src/engine/threaded_engine.h | 26 ------
src/operator/mxnet_op.h | 58 ++++++------
src/operator/optimizer_op-inl.h | 2 +-
src/operator/tensor/elemwise_binary_op.h | 56 ++++-------
src/operator/tensor/elemwise_binary_scalar_op.h | 20 ++--
src/operator/tensor/elemwise_unary_op.h | 77 +--------------
src/operator/tensor/elemwise_unary_op_basic.cc | 15 ++-
src/operator/tensor/elemwise_unary_op_basic.cu | 11 ++-
src/operator/tensor/indexing_op.h | 3 +-
src/operator/tensor/init_op.h | 3 +-
tests/cpp/include/test_core_op.h | 60 +++++++-----
tests/cpp/include/test_legacy_op.h | 22 ++---
tests/cpp/include/test_ndarray_utils.h | 2 +-
tests/cpp/include/test_op.h | 26 ++++--
tests/cpp/include/test_op_runner.h | 33 ++++---
tests/cpp/include/test_perf.h | 46 +++++----
tests/cpp/include/test_util.h | 61 ++++++++++++
.../operator/{ => runner}/core_op_runner_test.cc | 0
tests/cpp/operator/slice_channel_perf.cc | 104 +++++++++++++++++++++
23 files changed, 360 insertions(+), 281 deletions(-)
diff --git a/docs/faq/new_op.md b/docs/faq/new_op.md
index 55b7409..994a2a6 100644
--- a/docs/faq/new_op.md
+++ b/docs/faq/new_op.md
@@ -339,7 +339,7 @@ NNVM_REGISTER_OP(_backward_abs)
[](const NodeAttrs& attrs){
return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}};
})
-.set_attr<FCompute>("FCompute<cpu>", BinaryCompute<cpu, unary_bwd<mshadow_op::sign> >);
+.set_attr<FCompute>("FCompute<cpu>", BinaryCompute<cpu, backward_grad<mshadow_op::sign> >);
```
### Legacy Operators
diff --git a/include/mxnet/engine.h b/include/mxnet/engine.h
index 4048d5a..4c2314e 100644
--- a/include/mxnet/engine.h
+++ b/include/mxnet/engine.h
@@ -267,11 +267,6 @@ class MXNET_API Engine {
}
read_vars->resize(rtop - read_vars->begin());
}
-
- /*! \brief Return the number of OMP threads that should be used per worker
- * \return Number of OMP threads that should be used per worker
- */
- virtual int num_omp_threads_per_worker() const = 0;
}; // class Engine
#endif // DMLC_USE_CXX11
} // namespace mxnet
diff --git a/src/engine/naive_engine.cc b/src/engine/naive_engine.cc
index 7e3554a..4d63749 100644
--- a/src/engine/naive_engine.cc
+++ b/src/engine/naive_engine.cc
@@ -188,13 +188,6 @@ class NaiveEngine final : public Engine {
shutdown_phase_.store(true);
}
- /*! \brief Return the number of OMP threads that should be used per worker
- * \return Number of OMP threads that should be used per worker
- */
- int num_omp_threads_per_worker() const override {
- return OpenMP::Get()->GetRecommendedOMPThreadCount();
- }
-
private:
// callback to oncomplete
static void OnComplete(Engine *engine, void *param) {
diff --git a/src/engine/openmp.cc b/src/engine/openmp.cc
index be7885b..ad0c574 100644
--- a/src/engine/openmp.cc
+++ b/src/engine/openmp.cc
@@ -53,7 +53,7 @@ OpenMP::OpenMP()
omp_set_num_threads(omp_thread_max_);
} else {
omp_thread_max_ = omp_get_max_threads();
- }
+ }
}
omp_set_nested(dmlc::GetEnv("OMP_NESTED", false));
omp_set_dynamic(dmlc::GetEnv("OMP_DYNAMIC", false));
diff --git a/src/engine/threaded_engine.h b/src/engine/threaded_engine.h
index 3cf6653..e000a22 100644
--- a/src/engine/threaded_engine.h
+++ b/src/engine/threaded_engine.h
@@ -297,25 +297,6 @@ class ThreadedEngine : public Engine {
finished_cv_.notify_all();
}
- /*! \brief Return default OMP thread count. Currently, this is whatever OMP shows as number
- * of procs
- * \warning Do not call this in any performance-sensitive use-case since checking the environment
- * is slow
- */
- static int DefaultOMPThreadsPerWorker() {
-#ifdef _OPENMP
- // If OMP_NUM_THREADS is set, use omp_get_max_threads(), which will be the value
- // interpreted by the implemetation from the OMP_NUM_THREADS environment variable.
- // Otherwise, return the number of processors, not counting hyperthreading.
- // Test for set OMP_NUM_THREADS by checking against some nonsensical value
- const int max_threads = dmlc::GetEnv("OMP_NUM_THREADS", INT_MIN) == INT_MIN ?
- omp_get_num_procs() : omp_get_max_threads();
- return max_threads;
-#else
- return 1;
-#endif
- }
-
protected:
/*!
* \brief Push the opr block to execution queue to be executed.
@@ -383,13 +364,6 @@ class ThreadedEngine : public Engine {
}
}
- /*! \brief Return the number of OMP threads that should be used per worker
- * \return Number of OMP threads that should be used per worker
- */
- int num_omp_threads_per_worker() const override {
- return OpenMP::Get()->GetRecommendedOMPThreadCount();
- }
-
private:
/*!
* \brief check if thee is duplication in const_vars and mutable_vars.
diff --git a/src/operator/mxnet_op.h b/src/operator/mxnet_op.h
index 564ad81..b2d5011 100644
--- a/src/operator/mxnet_op.h
+++ b/src/operator/mxnet_op.h
@@ -30,6 +30,7 @@
#include <mxnet/engine.h>
#include <mxnet/op_attr_types.h>
#include <algorithm>
+#include "../engine/openmp.h"
#ifdef __CUDACC__
#include "../common/cuda_utils.h"
@@ -341,50 +342,30 @@ struct op_with_req {
}
};
-/*!
- * \brief Set to immediate scalar value kernel
- * \tparam val Scalar immediate
- */
-template<int val>
-struct set_to_int {
- // mxnet_op version (when used directly with Kernel<>::Launch()) */
- template<typename DType>
- MSHADOW_XINLINE static void Map(int i, DType* out) {
- out[i] = DType(val);
- }
- // mshadow_op version (when used with op_with_req<>)
- MSHADOW_XINLINE static int Map() {
- return val;
- }
-};
-
-/*! \brief Special-case kernel shortcut for setting to zero */
-using set_zero = set_to_int<0>;
-
template<typename OP, typename xpu>
struct Kernel;
-
template<typename OP>
struct Kernel<OP, cpu> {
+ /*! \brief Launch CPU kernel */
template<typename ...Args>
- inline static void Launch(mshadow::Stream<cpu> *s, const int N, Args... args) {
+ inline static void Launch(mshadow::Stream<cpu> *, const int N, Args... args) {
#ifdef _OPENMP
- const int omp_cores = Engine::Get()->num_omp_threads_per_worker();
- if (omp_cores <= 1) {
+ const int omp_threads = engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
+ if (omp_threads < 2) {
// Zero means not to use OMP, but don't interfere with external OMP behavior
for (int i = 0; i < N; ++i) {
OP::Map(i, args...);
}
} else {
- #pragma omp parallel for num_threads(omp_cores)
+ #pragma omp parallel for num_threads(omp_threads)
for (int i = 0; i < N; ++i) {
OP::Map(i, args...);
}
}
#else
for (int i = 0; i < N; ++i) {
- OP::Map(i, args...);
+ OP::Map(i, args...);
}
#endif
}
@@ -408,7 +389,6 @@ struct Kernel<OP, cpu> {
}
};
-
#ifdef __CUDACC__
template<typename OP, typename ...Args>
__global__ void mxnet_generic_kernel(int N, Args... args) {
@@ -426,6 +406,7 @@ __global__ void mxnet_generic_kernel_ex(int N, Args... args) {
template<typename OP>
struct Kernel<OP, gpu> {
+ /*! \brief Launch GPU kernel */
template<typename ...Args>
inline static void Launch(mshadow::Stream<gpu> *s, int N, Args... args) {
using namespace mshadow::cuda;
@@ -446,7 +427,30 @@ struct Kernel<OP, gpu> {
};
#endif // __CUDACC__
+/*!
+ * \brief Set to immediate scalar value kernel
+ * \tparam val Scalar immediate
+ */
+template<int val>
+struct set_to_int {
+ // mxnet_op version (when used directly with Kernel<>::Launch()) */
+ template<typename DType>
+ MSHADOW_XINLINE static void Map(int i, DType *out) {
+ out[i] = DType(val);
+ }
+ // mshadow_op version (when used with op_with_req<>)
+ MSHADOW_XINLINE static int Map() {
+ return val;
+ }
+};
+
+/*!
+ * \brief Special-case kernel shortcut for setting to zero and one
+ */
+using set_zero = set_to_int<0>;
+using set_one = set_to_int<1>;
} // namespace mxnet_op
} // namespace op
} // namespace mxnet
+
#endif // MXNET_OPERATOR_MXNET_OP_H_
diff --git a/src/operator/optimizer_op-inl.h b/src/operator/optimizer_op-inl.h
index 1c5e1c6..61b97ba 100644
--- a/src/operator/optimizer_op-inl.h
+++ b/src/operator/optimizer_op-inl.h
@@ -264,7 +264,7 @@ inline void SGDMomUpdate(const nnvm::NodeAttrs& attrs,
grad.dptr_, static_cast<DType>(param.clip_gradient), static_cast<DType>(param.momentum),
static_cast<DType>(param.lr), static_cast<DType>(param.wd),
static_cast<DType>(param.rescale_grad), req[0]);
- });
+ });
}
template<int n_in, int n_out, int total_in>
diff --git a/src/operator/tensor/elemwise_binary_op.h b/src/operator/tensor/elemwise_binary_op.h
index b8b5bd1..9c8f180 100644
--- a/src/operator/tensor/elemwise_binary_op.h
+++ b/src/operator/tensor/elemwise_binary_op.h
@@ -33,8 +33,10 @@
#include <algorithm>
#include "../mxnet_op.h"
#include "../mshadow_op.h"
+#include "../../engine/openmp.h"
#include "elemwise_unary_op.h"
#include "../../common/utils.h"
+#include "./init_op.h"
namespace mxnet {
namespace op {
@@ -42,23 +44,6 @@ namespace op {
/*! Gather binary operator functions into ElemwiseBinaryOp class */
class ElemwiseBinaryOp : public OpBase {
public:
- template<typename OP, int Req>
- struct BackwardUseNoneOp {
- template<typename DType>
- MSHADOW_XINLINE static void Map(int i, DType *igrad, const DType *ograd) {
- KERNEL_ASSIGN(igrad[i], Req, OP::Map(ograd[i]));
- }
- };
-
- template<typename OP, int Req>
- struct BackwardUseInOp {
- template<typename DType>
- MSHADOW_XINLINE static void Map(int i, DType *igrad,
- const DType *ograd, const DType *lhs, const DType *rhs) {
- KERNEL_ASSIGN(igrad[i], Req, ograd[i] * OP::Map(lhs[i], rhs[i]));
- }
- };
-
/*! \brief For sparse, assume missing rvalue is 0 */
template<typename OP, int Req>
struct MissingRValueOp {
@@ -89,25 +74,22 @@ class ElemwiseBinaryOp : public OpBase {
* \brief Fill contiguous dense output rows with value computed from 0 lhs and 0 rhs input
* CPU-Only version
*/
- template<typename DType, typename OP>
- static inline size_t FillDense(mshadow::Stream<cpu> *s,
+ template<typename DType, typename OP, typename xpu>
+ static inline size_t FillDense(mshadow::Stream<xpu> *s,
const size_t idx_l,
const size_t idx_r,
const OpReqType req,
- mshadow::Tensor<cpu, 2, DType> *out,
+ mshadow::Tensor<xpu, 2, DType> *out,
const size_t iter_out) {
- const int index_out_min = std::min(idx_l, idx_r);
+ const int index_out_min = static_cast<int>(std::min(idx_l, idx_r));
if (static_cast<size_t>(index_out_min) > iter_out) {
- const size_t size = (*out)[iter_out].shape_.Size();
const DType zero_input_val = OP::Map(DType(0), DType(0));
- #pragma omp parallel for
- for (int i = iter_out; i < index_out_min; ++i) {
- MXNET_ASSIGN_REQ_SWITCH(req, Req, {
- SerialLaunchCPU<OpBase::set_to_scalar<Req>>(s, size, (*out)[i].dptr_, zero_input_val);
- });
+ #pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount())
+ for (int i = static_cast<int>(iter_out); i < index_out_min; ++i) {
+ Fill<false>(s, (*out)[i], req, zero_input_val);
}
}
- return index_out_min;
+ return static_cast<size_t>(index_out_min); // MSVC wants OMP loops to always use 'int'
}
static inline bool IsSameArray(const NDArray& a1, const NDArray& a2) {
@@ -135,7 +117,7 @@ class ElemwiseBinaryOp : public OpBase {
} else if (req[0] != kNullOp) {
DType *lgrad_dptr = outputs[0].dptr<DType>();
MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
- Kernel<BackwardUseNoneOp<LOP, Req>, xpu>::Launch(s, size, lgrad_dptr, ograd_dptr);
+ Kernel<mxnet_op::op_with_req<LOP, Req>, xpu>::Launch(s, size, lgrad_dptr, ograd_dptr);
});
}
if (std::is_same<ROP, mshadow_op::identity>::value && req[1] == kWriteInplace) {
@@ -143,7 +125,7 @@ class ElemwiseBinaryOp : public OpBase {
} else if (req[1] != kNullOp) {
DType *rgrad_dptr = outputs[1].dptr<DType>();
MXNET_ASSIGN_REQ_SWITCH(req[1], Req, {
- Kernel<BackwardUseNoneOp<ROP, Req>, xpu>::Launch(s, size, rgrad_dptr, ograd_dptr);
+ Kernel<mxnet_op::op_with_req<ROP, Req>, xpu>::Launch(s, size, rgrad_dptr, ograd_dptr);
});
}
}
@@ -165,14 +147,14 @@ class ElemwiseBinaryOp : public OpBase {
(outputs[0].Size() + mxnet_op::DataType<DType>::kLanes - 1)
/ mxnet_op::DataType<DType>::kLanes);
DType * lgrad_dptr = outputs[0].dptr<DType>();
- mxnet_op::Kernel<BackwardUseInOp<LOP, Req>, xpu>::Launch(
+ mxnet_op::Kernel<mxnet_op::op_with_req<mxnet_op::backward_grad<LOP>, Req>, xpu>::Launch(
s, size, lgrad_dptr, ograd_dptr, lhs_dptr, rhs_dptr);});
MXNET_ASSIGN_REQ_SWITCH(req[1], Req, {
const int size = static_cast<int>(
(outputs[1].Size() + mxnet_op::DataType<DType>::kLanes - 1)
/ mxnet_op::DataType<DType>::kLanes);
DType * rgrad_dptr = outputs[1].dptr<DType>();
- mxnet_op::Kernel<BackwardUseInOp<ROP, Req>, xpu>::Launch(
+ mxnet_op::Kernel<mxnet_op::op_with_req<mxnet_op::backward_grad<ROP>, Req>, xpu>::Launch(
s, size, rgrad_dptr, ograd_dptr, lhs_dptr, rhs_dptr);});
}
@@ -503,10 +485,7 @@ class ElemwiseBinaryOp : public OpBase {
CHECK_EQ(outputs[0].storage_type(), in_stype);
// rsp -> rsp, _. op requires 0-input returns 0-output
DCHECK_LT(fabs(static_cast<float>(LOP::Map(0))), 1e-5f);
- MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
- UnaryOp::KernelComputeEx<xpu, BackwardUseNoneOp<LOP, Req>>(attrs, ctx, inputs,
- req, {outputs[0]});
- });
+ UnaryOp::ComputeEx<xpu, LOP>(attrs, ctx, inputs, req, {outputs[0]});
} else {
LOG(FATAL) << "Not implemented: " << operator_string(attrs, ctx, inputs, req, outputs);
}
@@ -517,10 +496,7 @@ class ElemwiseBinaryOp : public OpBase {
CHECK_EQ(outputs[0].storage_type(), in_stype);
// rsp -> _, rsp. op requires 0-input returns 0-output
DCHECK_LT(fabs(static_cast<float>(ROP::Map(0))), 1e-5f);
- MXNET_ASSIGN_REQ_SWITCH(req[1], Req, {
- UnaryOp::KernelComputeEx<xpu, BackwardUseNoneOp<ROP, Req>>(attrs, ctx, inputs,
- req, {outputs[1]});
- });
+ UnaryOp::ComputeEx<xpu, ROP>(attrs, ctx, inputs, req, {outputs[1]});
} else {
LOG(FATAL) << "Not implemented: " << operator_string(attrs, ctx, inputs, req, outputs);
}
diff --git a/src/operator/tensor/elemwise_binary_scalar_op.h b/src/operator/tensor/elemwise_binary_scalar_op.h
index b866a29..27d8ed3 100644
--- a/src/operator/tensor/elemwise_binary_scalar_op.h
+++ b/src/operator/tensor/elemwise_binary_scalar_op.h
@@ -66,7 +66,7 @@ class BinaryScalarOp : public UnaryOp {
const int64_t dense_block_count = next_input_row - output_row;
if (dense_block_count > 0) {
MXNET_ASSIGN_REQ_SWITCH(req, Req, {
- mxnet_op::Kernel<OpBase::set_to_scalar<Req>, cpu>::Launch(
+ mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::identity, Req>, cpu>::Launch(
stream,
items_per_row * dense_block_count,
output_data.dptr_ + items_per_row * output_row,
@@ -237,11 +237,8 @@ class BinaryScalarOp : public UnaryOp {
const double alpha = nnvm::get<double>(attrs.parsed);
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
- mxnet_op::Kernel<mxnet_op::op_with_req<OP, Req>, xpu>::Launch(s,
- inputs[0].Size(),
- outputs[0].dptr<DType>(),
- inputs[0].dptr<DType>(),
- DType(alpha));
+ mxnet_op::Kernel<mxnet_op::op_with_req<OP, Req>, xpu>::Launch(
+ s, inputs[0].Size(), outputs[0].dptr<DType>(), inputs[0].dptr<DType>(), DType(alpha));
});
});
}
@@ -286,10 +283,13 @@ class BinaryScalarOp : public UnaryOp {
Stream<xpu> *s = ctx.get_stream<xpu>();
const double alpha = nnvm::get<double>(attrs.parsed);
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
- Tensor<xpu, 1, DType> igrad = outputs[0].FlatTo1D<xpu, DType>(s);
- Tensor<xpu, 1, DType> ograd = inputs[0].FlatTo1D<xpu, DType>(s);
- Tensor<xpu, 1, DType> lhs = inputs[1].FlatTo1D<xpu, DType>(s);
- ASSIGN_DISPATCH(igrad, req[0], ograd * F<OP>(lhs, scalar<DType>(DType(alpha))));
+ MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
+ mxnet::op::mxnet_op::Kernel<mxnet::op::mxnet_op::op_with_req<
+ mxnet::op::mxnet_op::backward_grad<OP>, Req>, xpu>::
+ Launch(s, inputs[0].Size(), outputs[0].dptr<DType>(),
+ inputs[0].dptr<DType>(), inputs[1].dptr<DType>(),
+ DType(alpha));
+ });
});
}
};
diff --git a/src/operator/tensor/elemwise_unary_op.h b/src/operator/tensor/elemwise_unary_op.h
index d455b7e..6fbde05 100644
--- a/src/operator/tensor/elemwise_unary_op.h
+++ b/src/operator/tensor/elemwise_unary_op.h
@@ -274,25 +274,6 @@ class UnaryOp : public OpBase {
}
template<typename xpu, typename op>
- static void KernelCompute(const nnvm::NodeAttrs& attrs,
- const OpContext& ctx,
- const std::vector<TBlob>& inputs,
- const std::vector<OpReqType>& req,
- const std::vector<TBlob>& outputs) {
- using namespace mshadow;
- using namespace mxnet_op;
- Stream<xpu> *s = ctx.get_stream<xpu>();
- CHECK_EQ(inputs.size(), 1U);
- CHECK_EQ(outputs.size(), 1U);
- if (req[0] != kNullOp) {
- MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
- Kernel<op, xpu>::Launch(s, outputs[0].Size(),
- outputs[0].dptr<DType>(), inputs[0].dptr<DType>());
- });
- }
- }
-
- template<typename xpu, typename op>
static void ComputeWithHalf2(const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const std::vector<TBlob> &inputs,
@@ -309,25 +290,6 @@ class UnaryOp : public OpBase {
});
}
- template<typename xpu, typename OP>
- static void KernelComputeEx(const nnvm::NodeAttrs& attrs,
- const OpContext& ctx,
- const std::vector<NDArray>& inputs,
- const std::vector<OpReqType>& req,
- const std::vector<NDArray>& outputs) {
- CHECK_EQ(inputs.size(), 1U);
- CHECK_EQ(outputs.size(), 1U);
- const auto in_stype = inputs[0].storage_type();
- const auto out_stype = outputs[0].storage_type();
- if (in_stype == out_stype && (in_stype == kRowSparseStorage || in_stype == kCSRStorage)) {
- if (inputs[0].storage_shape().Size()) {
- MapToFCompute<xpu>(attrs, ctx, inputs, req, outputs, KernelCompute<xpu, OP>);
- }
- } else {
- LOG(FATAL) << "Not implemented: " << operator_string(attrs, ctx, inputs, req, outputs);
- }
- }
-
template<typename xpu>
static void IdentityCompute(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
@@ -395,13 +357,9 @@ class UnaryOp : public OpBase {
}
};
+/*! \brief Map legacy unary_bwd to backward_grad */
template<typename GRAD_OP>
-struct unary_bwd {
- template<typename DType>
- MSHADOW_XINLINE static DType Map(DType a, DType b) {
- return a * GRAD_OP::Map(b);
- }
-};
+using unary_bwd = ::mxnet::op::mxnet_op::backward_grad<GRAD_OP>;
struct CastParam : public dmlc::Parameter<CastParam> {
// use int for enumeration
@@ -445,37 +403,6 @@ void CastCompute(const nnvm::NodeAttrs& attrs,
});
}
-namespace kernel_launch_op {
-/*! \brief sigmoid unit */
-struct sigmoid {
- template<typename DType>
- MSHADOW_XINLINE static void Map(int i, DType *out,
- const DType *in) {
- out[i] = mshadow_op::sigmoid::Map<DType>(in[i]);
- }
-};
-struct sigmoid_grad {
- template<typename DType>
- MSHADOW_XINLINE static DType Map(DType out_grad, DType in) {
- return out_grad * mshadow_op::sigmoid_grad::Map<DType>(in);
- }
-};
-/*! \brief Rectified Linear Operation */
-struct relu {
- template<typename DType>
- MSHADOW_XINLINE static void Map(int i, DType *out,
- const DType *in) {
- out[i] = mshadow_op::relu::Map<DType>(in[i]);
- }
-};
-struct relu_grad {
- template<typename DType>
- MSHADOW_XINLINE static DType Map(DType out_grad, DType in) {
- return out_grad * mshadow_op::relu_grad::Map<DType>(in);
- }
-};
-} // namespace kernel_launch_op
-
/*! \brief Unary compute */
#define MXNET_OPERATOR_REGISTER_UNARY(__name$) \
NNVM_REGISTER_OP(__name$) \
diff --git a/src/operator/tensor/elemwise_unary_op_basic.cc b/src/operator/tensor/elemwise_unary_op_basic.cc
index c356c58..916c385 100644
--- a/src/operator/tensor/elemwise_unary_op_basic.cc
+++ b/src/operator/tensor/elemwise_unary_op_basic.cc
@@ -83,13 +83,12 @@ The storage type of ``relu`` output depends upon the input storage type:
)code" ADD_FILELINE)
.set_attr<FInferStorageType>("FInferStorageType", ElemwiseStorageType<1, 1, false, true, false>)
-.set_attr<FCompute>("FCompute<cpu>", UnaryOp::KernelCompute<
- cpu, kernel_launch_op::relu>)
-.set_attr<FComputeEx>("FComputeEx<cpu>", UnaryOp::KernelComputeEx<
- cpu, kernel_launch_op::relu>)
+.set_attr<FCompute>("FCompute<cpu>", UnaryOp::Compute<cpu, mshadow_op::relu>)
+.set_attr<FComputeEx>("FComputeEx<cpu>", UnaryOp::ComputeEx<cpu, mshadow_op::relu>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_relu"});
-MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU(_backward_relu, kernel_launch_op::relu_grad);
+MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU(_backward_relu,
+ unary_bwd<mshadow_op::relu_grad>);
// sigmoid
MXNET_OPERATOR_REGISTER_UNARY(sigmoid)
@@ -102,11 +101,11 @@ MXNET_ADD_SPARSE_OP_ALIAS(sigmoid)
The storage type of ``sigmoid`` output is always dense
)code" ADD_FILELINE)
-.set_attr<FCompute>("FCompute<cpu>", UnaryOp::KernelCompute<
- cpu, kernel_launch_op::sigmoid>)
+.set_attr<FCompute>("FCompute<cpu>", UnaryOp::Compute<cpu, mshadow_op::sigmoid>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseOut{"_backward_sigmoid"});
-MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU(_backward_sigmoid, kernel_launch_op::sigmoid_grad);
+MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU(_backward_sigmoid,
+ unary_bwd<mshadow_op::sigmoid_grad>);
// copy
MXNET_OPERATOR_REGISTER_UNARY(_copy)
diff --git a/src/operator/tensor/elemwise_unary_op_basic.cu b/src/operator/tensor/elemwise_unary_op_basic.cu
index 3f982a2..41eef90 100644
--- a/src/operator/tensor/elemwise_unary_op_basic.cu
+++ b/src/operator/tensor/elemwise_unary_op_basic.cu
@@ -26,18 +26,19 @@
namespace mxnet {
namespace op {
NNVM_REGISTER_OP(relu)
-.set_attr<FCompute>("FCompute<gpu>", UnaryOp::KernelCompute<gpu, kernel_launch_op::relu>)
-.set_attr<FComputeEx>("FComputeEx<gpu>", UnaryOp::KernelComputeEx<gpu, kernel_launch_op::relu>);
+.set_attr<FCompute>("FCompute<gpu>", UnaryOp::Compute<gpu, mshadow_op::relu>)
+.set_attr<FComputeEx>("FComputeEx<gpu>", UnaryOp::ComputeEx<gpu, mshadow_op::relu>);
NNVM_REGISTER_OP(_backward_relu)
-.set_attr<FCompute>("FCompute<gpu>", ElemwiseBinaryOp::Compute<gpu, kernel_launch_op::relu_grad>);
+.set_attr<FCompute>("FCompute<gpu>", ElemwiseBinaryOp::Compute<
+ gpu, unary_bwd<mshadow_op::relu_grad>>);
NNVM_REGISTER_OP(sigmoid)
-.set_attr<FCompute>("FCompute<gpu>", UnaryOp::KernelCompute<gpu, kernel_launch_op::sigmoid>);
+.set_attr<FCompute>("FCompute<gpu>", UnaryOp::Compute<gpu, mshadow_op::sigmoid>);
NNVM_REGISTER_OP(_backward_sigmoid)
.set_attr<FCompute>("FCompute<gpu>", ElemwiseBinaryOp::Compute<
- gpu, kernel_launch_op::sigmoid_grad>);
+ gpu, unary_bwd<mshadow_op::sigmoid_grad>>);
// copy
NNVM_REGISTER_OP(_copy)
diff --git a/src/operator/tensor/indexing_op.h b/src/operator/tensor/indexing_op.h
index 6a27b72..7624f2d 100644
--- a/src/operator/tensor/indexing_op.h
+++ b/src/operator/tensor/indexing_op.h
@@ -44,6 +44,7 @@
#include "./dot-inl.h"
#include "./init_op.h"
#include "./matrix_op-inl.h"
+#include "../../engine/openmp.h"
namespace mxnet {
namespace op {
@@ -657,7 +658,7 @@ inline void SparseEmbeddingOpBackwardRspImpl(const OpContext& ctx,
DType* grad_data = output.data().dptr<DType>();
Kernel<set_zero, cpu>::Launch(s, nnr * row_length, grad_data);
// add the final gradients
- int num_threads = Engine::Get()->num_omp_threads_per_worker();
+ const int num_threads = engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
dim_t segment_len = (nnr + num_threads - 1) / num_threads;
Kernel<AddTakeGradRspKernel, cpu>::Launch(s, num_threads, grad_data, prefix_sum,
ograd.dptr<DType>(), row_length,
diff --git a/src/operator/tensor/init_op.h b/src/operator/tensor/init_op.h
index bb6d3c1..c621f6e 100644
--- a/src/operator/tensor/init_op.h
+++ b/src/operator/tensor/init_op.h
@@ -32,6 +32,7 @@
#include <vector>
#include <string>
#include <limits>
+#include "../mshadow_op.h"
#include "../elemwise_op_common.h"
#include "../mxnet_op.h"
#include "../mshadow_op.h"
@@ -225,7 +226,7 @@ void Fill(mshadow::Stream<xpu> *s, const TBlob& b, const OpReqType req, ValueTyp
// Optimize common use-case of filling with ones
MSHADOW_TYPE_SWITCH(b.type_flag_, DType, {
MXNET_ASSIGN_REQ_SWITCH(req, Req, {
- mxnet_op::Kernel<mxnet_op::op_with_req<mxnet_op::set_to_int<1>, Req>, xpu>::Launch(
+ mxnet_op::Kernel<mxnet_op::op_with_req<mxnet_op::set_one, Req>, xpu>::Launch(
s, b.Size(), b.dptr<DType>());
});
});
diff --git a/tests/cpp/include/test_core_op.h b/tests/cpp/include/test_core_op.h
index 21d0776..c454c95 100644
--- a/tests/cpp/include/test_core_op.h
+++ b/tests/cpp/include/test_core_op.h
@@ -33,7 +33,24 @@ namespace op {
// Tried making this a struct w/constexpr, but getting undefined reference on gcc 5.4.1
#define COREOP_FWD_OP_NAME_KEY "fwd_op_name"
#define COREOP_BWD_OP_NAME_KEY "bwd_op_name"
-#define COREOP_BWD_OP_NAME_VALUE_NONE "<none>"
+#define COREOP_BWD_OP_NAME_VALUE_NONE "[none]"
+
+enum TimingDirection {
+ Forward,
+ Backward
+};
+
+inline const char *TimingDirectionAsString(const TimingDirection td) {
+ switch (td) {
+ case Forward:
+ return "Forward";
+ case Backward:
+ return "Backward";
+ default:
+ CHECK(false) << "Unknown timing direction: " << static_cast<int>(td);
+ return "<unknown>";
+ }
+}
/*!
* Low-noise operator executor
@@ -43,11 +60,6 @@ template<typename DType>
class CoreOpExecutor : public test::op::OperatorDataInitializer<DType>
, public test::op::OperatorExecutorTiming {
/*! \brief Performance timing categories */
- enum TimingId {
- Forward,
- Backward
- };
-
/*!
* \brief Access data blob as if on the CPU via a callback
* \tparam Type of callback Function to call with CPU-data NDArray
@@ -92,8 +104,8 @@ class CoreOpExecutor : public test::op::OperatorDataInitializer<DType>
values.reserve(count);
for (kwargs_t::const_iterator i_iter = args.begin(), e_iter = args.end();
i_iter != e_iter; ++i_iter) {
- keys.push_back(i_iter->first.c_str());
- values.push_back(i_iter->second.c_str());
+ keys.emplace_back(i_iter->first.c_str());
+ values.emplace_back(i_iter->second.c_str());
}
return imperative::ParseAttrs(op, op->num_inputs, count, &keys[0], &values[0]);
}
@@ -108,7 +120,7 @@ class CoreOpExecutor : public test::op::OperatorDataInitializer<DType>
std::vector<TBlob> *dest) {
dest->reserve(dest->size() + src.size());
for (size_t i = 0, n = src.size(); i < n; ++i) {
- dest->push_back(src[i].data());
+ dest->emplace_back(src[i].data());
}
return *dest;
}
@@ -194,9 +206,9 @@ class CoreOpExecutor : public test::op::OperatorDataInitializer<DType>
for (const ResourceRequest& req : reqs) {
if (req.type == ResourceRequest::kTempSpace) {
Resource r = ResourceManager::Get()->Request(ctx->run_ctx.ctx, req);
- requested.push_back(r);
+ requested.emplace_back(r);
} else if (req.type == ResourceRequest::kRandom) {
- requested.push_back(ResourceManager::Get()->Request(ctx->run_ctx.ctx, req));
+ requested.emplace_back(ResourceManager::Get()->Request(ctx->run_ctx.ctx, req));
} else {
LOG(FATAL) << "resource type not yet supported";
}
@@ -216,7 +228,7 @@ class CoreOpExecutor : public test::op::OperatorDataInitializer<DType>
new_args.reserve(args.size() + 1);
for (const auto& a : args) {
if (a.first != COREOP_FWD_OP_NAME_KEY && a.first != COREOP_BWD_OP_NAME_KEY) {
- new_args.push_back(a);
+ new_args.emplace_back(a);
}
}
new_args.push_back({ COREOP_FWD_OP_NAME_KEY, fwd_op_name});
@@ -241,7 +253,7 @@ class CoreOpExecutor : public test::op::OperatorDataInitializer<DType>
} else if (a.first == COREOP_BWD_OP_NAME_KEY) {
*bwd_op_name_ptr = a.second;
} else {
- new_args.push_back(a);
+ new_args.emplace_back(a);
}
}
return new_args;
@@ -317,7 +329,7 @@ class CoreOpExecutor : public test::op::OperatorDataInitializer<DType>
// operators such as dot
std::vector<TShape> shapes;
for (size_t i = 0, n = std::max(num_visible_outputs, num_inputs); i < n; ++i) {
- shapes.push_back(i < input_shapes_.size() ? input_shapes_[i]
+ shapes.emplace_back(i < input_shapes_.size() ? input_shapes_[i]
: input_shapes_[input_shapes_.size() - 1]);
}
std::vector<NDArray *> inputs_p, outputs_p;
@@ -331,21 +343,21 @@ class CoreOpExecutor : public test::op::OperatorDataInitializer<DType>
outputs_.reserve(num_visible_outputs);
outputs_p.reserve(num_visible_outputs);
- for (int i = 0; i < num_inputs; ++i) {
+ for (size_t i = 0; i < static_cast<size_t>(num_inputs); ++i) {
CHECK_LT(i, static_cast<int>(shapes.size()));
- inputs_.push_back(i < inputs.size() ? inputs[i] : CreateRandArray(shapes[i],
+ inputs_.emplace_back(i < inputs.size() ? inputs[i] : CreateRandArray(shapes[i],
ctx_.run_ctx.ctx));
- inputs_p.push_back(&*inputs_.rbegin());
+ inputs_p.emplace_back(&*inputs_.rbegin());
}
- for (int i = 0; i < num_visible_outputs; ++i) {
+ for (size_t i = 0; i < static_cast<size_t>(num_visible_outputs); ++i) {
// If supplied and valid, pass from the supplied outputs vector
// Otherwise use empty for forward pass, or zero-filled for backward pass
- outputs_.push_back(i < outputs.size()
- ? outputs[i]
- : (backward_for_op ? CreateZeroArray(shapes[i], ctx_.run_ctx.ctx)
- : NDArray()));
- outputs_p.push_back(&*outputs_.rbegin());
+ outputs_.emplace_back(i < outputs.size()
+ ? outputs[i]
+ : (backward_for_op ? CreateZeroArray(shapes[i], ctx_.run_ctx.ctx)
+ : NDArray()));
+ outputs_p.emplace_back(&*outputs_.rbegin());
}
if (!backward_for_op) {
@@ -396,7 +408,7 @@ class CoreOpExecutor : public test::op::OperatorDataInitializer<DType>
<< "Can't automatically determine backward op name. Please specify";
for (std::pair<std::shared_ptr<CoreOpExecutor>, std::string> &bw_item : bwd) {
bw_item.first->set_verbose(verbose_);
- backward_.push_back(bw_item.first);
+ backward_.emplace_back(bw_item.first);
bw_item.first->Init(ArgsWithOpName(args, bw_item.second), {}, {}, this);
}
}
diff --git a/tests/cpp/include/test_legacy_op.h b/tests/cpp/include/test_legacy_op.h
index 30bdf07..6d326fc 100644
--- a/tests/cpp/include/test_legacy_op.h
+++ b/tests/cpp/include/test_legacy_op.h
@@ -135,7 +135,7 @@ class LegacyOperatorExecutor : public OperatorDataInitializer<DType>
// Get the resource of temporal space
std::vector<TShape> inputShapes;
for (size_t x = 0, n = shape_input_vec_.size(); x < n; ++x) {
- inputShapes.push_back(shape_input_vec_[x]);
+ inputShapes.emplace_back(shape_input_vec_[x]);
}
allocateResources(opProp.ForwardResource(inputShapes));
@@ -408,11 +408,11 @@ class LegacyOperatorExecutor : public OperatorDataInitializer<DType>
std::vector<std::vector<TBlob> *> all_blob_vects_;
inline OpData() {
- all_blob_vects_.push_back(&blob_input_vec_);
- all_blob_vects_.push_back(&blob_output_vec_);
- all_blob_vects_.push_back(&blob_aux_states_);
- all_blob_vects_.push_back(&blob_in_grad_);
- all_blob_vects_.push_back(&blob_out_grad_); // Remaining err (loss) pushing back upstream
+ all_blob_vects_.emplace_back(&blob_input_vec_);
+ all_blob_vects_.emplace_back(&blob_output_vec_);
+ all_blob_vects_.emplace_back(&blob_aux_states_);
+ all_blob_vects_.emplace_back(&blob_in_grad_);
+ all_blob_vects_.emplace_back(&blob_out_grad_); // Remaining err (loss) pushing back upstream
}
virtual ~OpData() {}
};
@@ -495,14 +495,14 @@ class LegacyOperatorExecutor : public OperatorDataInitializer<DType>
for (const ResourceRequest& req : reqs) {
if (req.type == ResourceRequest::kTempSpace) {
if (cached_temp.count(ctx) != 0) {
- opContext_.requested.push_back(cached_temp.at(ctx));
+ opContext_.requested.emplace_back(cached_temp.at(ctx));
} else {
Resource r = ResourceManager::Get()->Request(ctx, req);
- opContext_.requested.push_back(r);
+ opContext_.requested.emplace_back(r);
cached_temp[ctx] = r;
}
} else if (req.type == ResourceRequest::kRandom) {
- opContext_.requested.push_back(ResourceManager::Get()->Request(ctx, req));
+ opContext_.requested.emplace_back(ResourceManager::Get()->Request(ctx, req));
} else {
LOG(FATAL) << "resource type not yet supported";
}
@@ -517,8 +517,8 @@ class LegacyOperatorExecutor : public OperatorDataInitializer<DType>
const int dtype) {
test::StandaloneBlob *blob = new test::StandaloneBlob(shape, isGPU, dtype);
CHECK_NE(blob, static_cast<TBlob *>(nullptr));
- standalone_blobs->push_back(std::unique_ptr<test::StandaloneBlob>(blob));
- (*dest).push_back(*blob);
+ standalone_blobs->emplace_back(std::unique_ptr<test::StandaloneBlob>(blob));
+ (*dest).emplace_back(*blob);
return blob;
}
diff --git a/tests/cpp/include/test_ndarray_utils.h b/tests/cpp/include/test_ndarray_utils.h
index bbc7c05..f5ab967 100644
--- a/tests/cpp/include/test_ndarray_utils.h
+++ b/tests/cpp/include/test_ndarray_utils.h
@@ -80,7 +80,7 @@ inline NDArray DnsND(const TShape shape, const Context ctx, std::vector<TEST_DTY
// generate random values
while (vs.size() < num_val) {
auto v = RandFloat();
- vs.push_back(v);
+ vs.emplace_back(v);
}
CHECK_EQ(vs.size(), nd.shape().Size());
MSHADOW_TYPE_SWITCH(nd.dtype(), DType, {
diff --git a/tests/cpp/include/test_op.h b/tests/cpp/include/test_op.h
index 949f2cc..cbafe14 100644
--- a/tests/cpp/include/test_op.h
+++ b/tests/cpp/include/test_op.h
@@ -100,17 +100,27 @@ class OperatorDataInitializer {
* \param blob Blob which to fill with random values
*/
void FillRandom(const TBlob& blob) const {
- std::uniform_real_distribution<DType> distribution(-1.0, 1.0);
- test::patternFill<DType>(&blob, [this, &distribution]() -> DType {
- return distribution(this->generator());
+ std::uniform_real_distribution<> dis_real(-5.0, 5.0);
+ std::uniform_int_distribution<> dis_int(-128, 127);
+ test::patternFill<DType>(&blob, [this, &dis_real, &dis_int]() -> DType {
+ if (!std::is_integral<DType>::value) {
+ DType val;
+ do {
+ val = static_cast<DType>(dis_real(this->generator()));
+ } while (fabs(val) < 1e-5); // If too close to zero, try again
+ return val;
+ } else {
+ DType val;
+ do {
+ val = static_cast<DType>(dis_int(this->generator()));
+ } while (!val); // If zero, try again
+ return val;
+ }
});
}
void FillZero(const TBlob& blob) const {
- std::uniform_real_distribution<DType> distribution(-1.0, 1.0);
- test::patternFill<DType>(&blob, [this, &distribution]() -> DType {
- return DType(0);
- });
+ test::patternFill<DType>(&blob, []() -> DType { return DType(0); });
}
private:
@@ -271,7 +281,7 @@ inline std::vector<TShape> ShapesOf(const std::vector<NDArray>& arrays) {
std::vector<TShape> res;
res.reserve(arrays.size());
for (const NDArray& ar : arrays) {
- res.push_back(ar.shape());
+ res.emplace_back(ar.shape());
}
return std::move(res);
}
diff --git a/tests/cpp/include/test_op_runner.h b/tests/cpp/include/test_op_runner.h
index 4c7cd1d..eb25999 100644
--- a/tests/cpp/include/test_op_runner.h
+++ b/tests/cpp/include/test_op_runner.h
@@ -122,15 +122,14 @@ class OperatorRunner {
* \param dim Data dimensions
* \param count Number of times to run in each direction
*/
- void TimingTest(const std::string& label,
- const bool isGPU,
- const bool stochastic,
- const test::op::kwargs_t& kwargs,
- int dim = 0,
- size_t count = 1,
- const std::vector<TShape>& timing_shapes = {}) {
- std::cout << std::endl << std::flush;
-
+ std::unordered_map<int, perf::TimingInstrument::Info>
+ TimingTest(const std::string& label,
+ const bool isGPU,
+ const bool stochastic,
+ const test::op::kwargs_t& kwargs,
+ int dim = 0,
+ size_t count = 1,
+ const std::vector<TShape>& timing_shapes = {}) {
#ifdef NDEBUG
size_t COUNT = 50;
#else
@@ -160,7 +159,7 @@ class OperatorRunner {
if (timing_shapes.empty()) {
do {
- batchSize = stochastic ? test::rangedRand(1U, TES_BATCH_SIZE * 2U) : TIMING_BATCH_SIZE;
+ batchSize = stochastic ? test::rangedRand(1U, TEST_BATCH_SIZE * 2U) : TIMING_BATCH_SIZE;
channels = stochastic ? test::rangedRand(1U, TEST_CHANNELS * 2U) : TIMING_CHANNELS;
depth = stochastic ? test::rangedRand(1U, TEST_DEPTH * 2U) : TIMING_DEPTH;
height = stochastic ? test::rangedRand(1U, TEST_DH * 2U) : TIMING_DH;
@@ -218,12 +217,18 @@ class OperatorRunner {
}
}
- timing.print(&std::cout, label);
- std::cout << std::endl << std::flush;
+ if (verbose_) {
+ timing.print(&std::cout, label);
+ std::cout << std::endl << std::flush;
+ }
+
+ return timing.data();
}
+ void set_verbose(bool verbose) { verbose_ = verbose; }
+
protected:
- static constexpr int TES_BATCH_SIZE = 5;
+ static constexpr int TEST_BATCH_SIZE = 5;
static constexpr int TEST_CHANNELS = 3;
static constexpr int TEST_DEPTH = 2;
static constexpr int TEST_DH = 2;
@@ -234,6 +239,8 @@ class OperatorRunner {
static constexpr int TIMING_DEPTH = 2;
static constexpr int TIMING_DH = 64;
static constexpr int TIMING_DW = 64;
+ /*! \brief verbose output */
+ bool verbose_ = true;
};
} // namespace test
diff --git a/tests/cpp/include/test_perf.h b/tests/cpp/include/test_perf.h
index b6f2145..7971ed7 100644
--- a/tests/cpp/include/test_perf.h
+++ b/tests/cpp/include/test_perf.h
@@ -45,7 +45,7 @@ namespace perf {
inline uint64_t getMicroTickCount() {
#ifndef _WIN32
struct timeval tv;
- gettimeofday(&tv, NULL);
+ gettimeofday(&tv, nullptr);
return uint64_t(tv.tv_sec) * 1000000 + tv.tv_usec;
#else
LARGE_INTEGER CurrentTime;
@@ -79,11 +79,6 @@ inline uint64_t getNannoTickCount() {
#endif
}
-/*! \brief millisecond tick count */
-inline uint64_t getTickCount() {
- return getMicroTickCount() / 1000;
-}
-
#define MICRO2MS(__micro$) (((__micro$) + 500)/1000)
#define MICRO2MSF(__micro$) (static_cast<float>(__micro$)/1000)
#define MICRO2MSF(__micro$) (static_cast<float>(__micro$)/1000)
@@ -100,7 +95,7 @@ class TimedScope {
const size_t count_;
public:
- explicit inline TimedScope(const char *msg = NULL, size_t count = 1, const bool start = true)
+ explicit inline TimedScope(const char *msg = nullptr, size_t count = 1, const bool start = true)
: startTime_(start ? getMicroTickCount() : 0)
, stopTime_(0)
, count_(count) {
@@ -164,7 +159,7 @@ class TimingInstrument {
}
void startTiming(int id, const char *s) {
std::unique_lock<std::recursive_mutex> lk(mutex_);
- std::unordered_map<int, Info>::iterator i = data_.find(id);
+ auto i = data_.find(id);
if (i == data_.end()) {
i = data_.emplace(std::make_pair(id, Info(s))).first;
}
@@ -174,7 +169,7 @@ class TimingInstrument {
}
void stopTiming(int id, const size_t subIterationCount = 1) {
std::unique_lock<std::recursive_mutex> lk(mutex_);
- std::unordered_map<int, Info>::iterator i = data_.find(id);
+ auto i = data_.find(id);
CHECK_NE(i == data_.end(), true) << "Can't stop timing on an object that we don't know about";
if (i != data_.end()) {
CHECK_NE(i->second.nestingCount_, 0U) << "While stopping timing, invalid nesting count of 0";
@@ -188,7 +183,7 @@ class TimingInstrument {
}
uint64_t getDuration(int id) {
std::unique_lock<std::recursive_mutex> lk(mutex_);
- std::unordered_map<int, Info>::iterator i = data_.find(id);
+ auto i = data_.find(id);
if (i != data_.end()) {
const Info& info = i->second;
const uint64_t duration = info.nestingCount_.load()
@@ -202,7 +197,7 @@ class TimingInstrument {
bool isTiming(int id) {
std::unordered_map<int, Info>::const_iterator i = data_.find(id);
if (i != data_.end()) {
- return !!i->second.nestingCount_.load();
+ return i->second.nestingCount_.load() != 0;
}
return false;
}
@@ -216,7 +211,7 @@ class TimingInstrument {
i != e; ++i) {
const Info& info = i->second;
const uint64_t duration = getDuration(i->first);
- *os << /*std::endl <<*/ label_ << ": " << name_ << " Timing [" << info.name_ << "] "
+ *os << label_ << ": " << name_ << " Timing [" << info.name_ << "] "
<< (info.nestingCount_.load() ? "*" : "")
<< MICRO2MSF(duration) << " ms";
if (info.cycleCount_.load()) {
@@ -233,7 +228,7 @@ class TimingInstrument {
void reset() {
std::unique_lock<std::recursive_mutex> lk(mutex_);
- for (std::unordered_map<int, Info>::iterator i = data_.begin(), e = data_.end();
+ for (auto i = data_.begin(), e = data_.end();
i != e; ++i) {
const int id = i->first;
const bool wasTiming = isTiming(id);
@@ -250,9 +245,9 @@ class TimingInstrument {
}
TimingInstrument& operator += (const TimingInstrument& o) {
- for (std::unordered_map<int, Info>::const_iterator i = o.data_.begin(), e = o.data_.end();
+ for (auto i = o.data_.begin(), e = o.data_.end();
i != e; ++i) {
- std::unordered_map<int, Info>::iterator j = data_.find(i->first);
+ auto j = data_.find(i->first);
if (j != data_.end()) {
const Info &oInfo = i->second;
CHECK_EQ(oInfo.nestingCount_, 0U);
@@ -265,7 +260,6 @@ class TimingInstrument {
return *this;
}
- private:
struct Info {
explicit inline Info(const char *s)
: name_(s ? s : "")
@@ -273,6 +267,7 @@ class TimingInstrument {
, nestingCount_(0)
, cycleCount_(0)
, duration_(0) {}
+
inline Info(const Info& o)
: name_(o.name_)
, baseTime_(o.baseTime_.load())
@@ -281,17 +276,36 @@ class TimingInstrument {
, duration_(o.duration_.load()) {
CHECK_EQ(o.nestingCount_, 0U);
}
+
+ /*!
+ * \brief Return time for each operation in milliseconds
+ * \return Time for each operation in milliseconds
+ */
+ inline double TimeEach() const {
+ return static_cast<double>(duration_) / cycleCount_.load() / 1000.0f;
+ }
+
std::string name_;
std::atomic<uint64_t> baseTime_;
std::atomic<uint64_t> nestingCount_;
std::atomic<uint64_t> cycleCount_; // Note that nesting may skew averages
std::atomic<uint64_t> duration_;
};
+
+ typedef std::unordered_map<int, TimingInstrument::Info> timing_map_t;
+
+ const timing_map_t& data() const {
+ return data_;
+ }
+
+ private:
std::string name_;
mutable std::recursive_mutex mutex_;
std::unordered_map<int, Info> data_;
};
+using timing_map_t = TimingInstrument::timing_map_t;
+
/*! \brief Accumulated scoped timing, indexed by ID */
class TimingItem {
public:
diff --git a/tests/cpp/include/test_util.h b/tests/cpp/include/test_util.h
index 95ab141..33ca3c4 100644
--- a/tests/cpp/include/test_util.h
+++ b/tests/cpp/include/test_util.h
@@ -609,6 +609,67 @@ inline ScalarType rangedRand(const ScalarType min, const ScalarType max) {
return static_cast<ScalarType>(x / bin_size + min);
}
+/*!
+ * \brief Deterministically compare TShape objects as less-than,
+ * for use in stl sorted key such as map and set
+ * \param s1 First shape
+ * \param s2 Second shape
+ * \return true if s1 is less than s2
+ */
+inline bool operator < (const nnvm::TShape &s1, const nnvm::TShape &s2) {
+ if (s1.Size() == s2.Size()) {
+ if (s1.ndim() == s2.ndim()) {
+ for (size_t i = 0, n = s1.ndim(); i < n; ++i) {
+ if (s1[i] == s2[i]) {
+ continue;
+ }
+ return s1[i] < s2[i];
+ }
+ return false;
+ }
+ return s1.ndim() < s2.ndim();
+ }
+ return s1.Size() < s2.Size();
+}
+
+/*!
+ * \brief Deterministically compare a vector of TShape objects as less-than,
+ * for use in stl sorted key such as map and set
+ * \param v1 First vector of shapes
+ * \param v2 Second vector of shapes
+ * \return true if v1 is less than v2
+ */
+inline bool operator < (const std::vector<nnvm::TShape>& v1, const std::vector<nnvm::TShape>& v2) {
+ if (v1.size() == v2.size()) {
+ for (size_t i = 0, n = v1.size(); i < n; ++i) {
+ if (v1[i] == v2[i]) {
+ continue;
+ }
+ return v1[i] < v2[i];
+ }
+ return false;
+ }
+ return v1.size() < v2.size();
+}
+
+/*!
+ * \brief std::less compare structure for comparing vectors of shapes for stl sorted containers
+ */
+struct less_shapevect {
+ bool operator()(const std::vector<nnvm::TShape>& v1, const std::vector<nnvm::TShape>& v2) const {
+ if (v1.size() == v2.size()) {
+ for (size_t i = 0, n = v1.size(); i < n; ++i) {
+ if (v1[i] == v2[i]) {
+ continue;
+ }
+ return v1[i] < v2[i];
+ }
+ return false;
+ }
+ return v1.size() < v2.size();
+ }
+};
+
inline std::string pretty_num(uint64_t val) {
std::string res, s = std::to_string(val);
size_t ctr = 0;
diff --git a/tests/cpp/operator/core_op_runner_test.cc b/tests/cpp/operator/runner/core_op_runner_test.cc
similarity index 100%
rename from tests/cpp/operator/core_op_runner_test.cc
rename to tests/cpp/operator/runner/core_op_runner_test.cc
diff --git a/tests/cpp/operator/slice_channel_perf.cc b/tests/cpp/operator/slice_channel_perf.cc
new file mode 100644
index 0000000..dc42d2a
--- /dev/null
+++ b/tests/cpp/operator/slice_channel_perf.cc
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file slice_channel_perf.cc
+ * \brief Perf/profile run of the SliceChannel operator
+ * \author Chris Olivier
+ */
+
+#include <gtest/gtest.h>
+#include <mxnet/tensor_blob.h>
+#include "../include/test_op_runner.h"
+#include "../include/test_legacy_op.h"
+#include "../../src/operator/slice_channel-inl.h"
+
+using namespace mxnet;
+
+typedef std::vector<std::pair<std::string, std::string> > kwargs_t;
+const kwargs_t basic_activation_args = { };
+
+/*!
+ * \brief Generic bidirectional sanity test
+ */
+TEST(SLICE_CHANNEL_PERF, ExecuteBidirectional) {
+ TShape shape({1, 160, 200});
+ kwargs_t kwargs = basic_activation_args;
+ kwargs.push_back({"num_outputs", "160"});
+ test::op::LegacyOpRunner<mxnet::op::SliceChannelProp, float, float> runner;
+ runner.RunBidirectional(false, { shape }, kwargs, 1);
+}
+
+/*!
+ * \brief SliceChannel operator timing test for CPU
+ */
+TEST(SLICE_CHANNEL_PERF, TimingCPU) {
+ kwargs_t kwargs = basic_activation_args;
+ // Request 160 outputs; every shape used below has 160 as its second dimension
+ kwargs.push_back({"num_outputs", "160"});
+ test::op::LegacyOpRunner<mxnet::op::SliceChannelProp, float, float> runner;
+ runner.RunBidirectional(false,
+ { TShape({1, 160, 200}) },
+ kwargs, 1); // prime code and cache
+ std::vector <TShape> shapes;
+ if (test::performance_run) {
+ shapes = {
+ {1, 160, 200},
+ {10, 160, 200},
+ {100, 160, 200},
+ {10, 160, 500},
+ {100, 160, 500}
+ };
+ } else {
+ shapes = {
+ {1, 160, 200},
+ {1, 160, 200}
+ };
+ }
+ for (const TShape &shape : shapes) {
+ runner.TimingTest("SliceChannel Operator CPU", false, false, kwargs, 2, 10, { shape });
+ }
+}
+
+#if MXNET_USE_CUDA == 1
+/*!
+ * \brief SliceChannel operator timing test for GPU
+ */
+TEST(SLICE_CHANNEL_PERF, TimingGPU) {
+ kwargs_t kwargs = basic_activation_args;
+ // Request 160 outputs; every shape used below has 160 as its second dimension
+ kwargs.push_back({"num_outputs", "160"});
+ test::OperatorRunner<mxnet::op::SliceChannelProp,
+ test::op::LegacyOperatorExecutor<float, float>> runner;
+ runner.RunBidirectional(true,
+ { TShape({1, 160, 200}) },
+ kwargs, 1); // prime code and cache
+ std::vector <TShape> shapes = {
+ {1, 160, 200},
+ {1, 160, 200},
+ {1, 160, 200},
+ {1, 160, 200},
+ {1, 160, 200}
+ };
+ for (const TShape &shape : shapes) {
+ runner.TimingTest("SliceChannel Operator GPU", true, false, kwargs, 2, 10, { shape });
+ }
+}
+#endif // MXNET_USE_CUDA == 1
+
--
To stop receiving notification emails like this one, please contact
['"commits@mxnet.apache.org" <co...@mxnet.apache.org>'].