You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by jx...@apache.org on 2017/11/11 19:58:21 UTC
[incubator-mxnet] branch master updated: optimize broadcast (#8566)
This is an automated email from the ASF dual-hosted git repository.
jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 23a9294 optimize broadcast (#8566)
23a9294 is described below
commit 23a929451ef8f05dfbd49bcd7fe0381115546e8b
Author: Eric Junyuan Xie <pi...@users.noreply.github.com>
AuthorDate: Sat Nov 11 11:58:18 2017 -0800
optimize broadcast (#8566)
* optimize broadcast
* Update elemwise_binary_broadcast_op.h
---
src/operator/mxnet_op.h | 65 ++++++++++++++++++++++
src/operator/tensor/elemwise_binary_broadcast_op.h | 32 ++++++++++-
2 files changed, 94 insertions(+), 3 deletions(-)
diff --git a/src/operator/mxnet_op.h b/src/operator/mxnet_op.h
index 5b8e109..564ad81 100644
--- a/src/operator/mxnet_op.h
+++ b/src/operator/mxnet_op.h
@@ -222,6 +222,37 @@ MSHADOW_XINLINE Shape<ndim> calc_stride(const Shape<ndim>& shape) {
return stride;
}
+/* Increment coordinates and modify index.
+ *
+ * Steps *coord forward by one along the last axis (ndim-1) and adds
+ * stride[ndim-1] to *idx. Then, starting from the last axis and moving towards
+ * axis 0, while axis i > 0 has reached shape[i]: that axis is reduced by
+ * shape[i], axis i-1 is incremented, and *idx is corrected by
+ * stride[i-1] - shape[i] * stride[i]. Axis 0 itself is never wrapped, so
+ * (*coord)[0] can end up equal to shape[0].
+ *
+ * coord   in/out: coordinate to advance
+ * shape   per-axis extent used for the carry test
+ * idx     in/out: linear index kept in step with coord
+ * stride  per-axis step applied to idx
+ */
+template<int ndim>
+MSHADOW_XINLINE void inc(Shape<ndim>* coord, const Shape<ndim>& shape,
+ index_t* idx, const Shape<ndim>& stride) {
+ ++(*coord)[ndim-1];  // step along the last axis
+ *idx += stride[ndim-1];
+ #pragma unroll
+ for (int i = ndim - 1; i > 0 && (*coord)[i] >= shape[i]; --i) {  // carry towards axis 0, stop at first axis in range
+ (*coord)[i] -= shape[i];
+ ++(*coord)[i-1];
+ *idx = *idx + stride[i-1] - shape[i] * stride[i];  // one step on axis i-1, minus shape[i] steps on axis i
+ }
+}
+
+/* Increment coordinates and modify index.
+ *
+ * Two-index overload: advances *coord by one along the last axis (ndim-1) and
+ * updates *idx1 and *idx2 in the same pass, each with its own stride vector.
+ * Starting from the last axis and moving towards axis 0, while axis i > 0 has
+ * reached shape[i]: that axis is reduced by shape[i], axis i-1 is incremented,
+ * and each index is corrected by strideK[i-1] - shape[i] * strideK[i].
+ * Axis 0 itself is never wrapped.
+ *
+ * coord    in/out: coordinate to advance
+ * shape    per-axis extent used for the carry test
+ * idx1     in/out: first linear index, stepped with stride1
+ * stride1  per-axis step applied to idx1
+ * idx2     in/out: second linear index, stepped with stride2
+ * stride2  per-axis step applied to idx2
+ */
+template<int ndim>
+MSHADOW_XINLINE void inc(Shape<ndim>* coord, const Shape<ndim>& shape,
+ index_t* idx1, const Shape<ndim>& stride1,
+ index_t* idx2, const Shape<ndim>& stride2) {
+ ++(*coord)[ndim-1];  // step along the last axis
+ *idx1 += stride1[ndim-1];
+ *idx2 += stride2[ndim-1];
+ #pragma unroll
+ for (int i = ndim - 1; i > 0 && (*coord)[i] >= shape[i]; --i) {  // carry towards axis 0, stop at first axis in range
+ (*coord)[i] -= shape[i];
+ ++(*coord)[i-1];
+ *idx1 = *idx1 + stride1[i-1] - shape[i] * stride1[i];  // one step on axis i-1, minus shape[i] steps on axis i
+ *idx2 = *idx2 + stride2[i-1] - shape[i] * stride2[i];  // same correction using the second stride vector
+ }
+}
+
/*!
* \brief Simple copy data from one blob to another
* \param to Destination blob
@@ -357,6 +388,24 @@ struct Kernel<OP, cpu> {
}
#endif
}
+
+ template<typename ...Args>  // Range-style launch: OP::Map is called as Map(start, count, args...)
+ inline static void LaunchEx(mshadow::Stream<cpu> *s, const int N, Args... args) {  // s is not used in this body
+#ifdef _OPENMP
+ const int omp_cores = Engine::Get()->num_omp_threads_per_worker();
+ if (omp_cores <= 1) {
+ OP::Map(0, N, args...);  // single call covering [0, N)
+ } else {
+ int length = (N + omp_cores - 1) / omp_cores;  // ceil(N / omp_cores) indices per chunk
+ #pragma omp parallel for num_threads(omp_cores)
+ for (int i = 0; i < N; i += length) {
+ OP::Map(i, i + length > N ? N - i : length, args...);  // final chunk is clipped so it ends at N
+ }
+ }
+#else
+ OP::Map(0, N, args...);  // no OpenMP: single call covering [0, N)
+#endif
+ }
};
@@ -368,6 +417,13 @@ __global__ void mxnet_generic_kernel(int N, Args... args) {
}
}
+template<typename OP, typename ...Args>  // CUDA kernel for range-style ops: calls OP::Map(start, count, args...)
+__global__ void mxnet_generic_kernel_ex(int N, Args... args) {
+ for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; i += blockDim.x * gridDim.x) {  // grid-stride loop over [0, N)
+ OP::Map(i, 1, args...);  // count is fixed at 1: each call handles only index i
+ }
+}
+
template<typename OP>
struct Kernel<OP, gpu> {
template<typename ...Args>
@@ -378,6 +434,15 @@ struct Kernel<OP, gpu> {
<<<ngrid, kBaseThreadNum, 0, mshadow::Stream<gpu>::GetStream(s)>>>(
N, args...);
}
+
+ template<typename ...Args>  // Launches mxnet_generic_kernel_ex<OP> over N indices on the CUDA stream obtained from s
+ inline static void LaunchEx(mshadow::Stream<gpu> *s, int N, Args... args) {
+ using namespace mshadow::cuda;  // NOTE(review): presumably the source of kMaxGridNum / kBaseThreadNum -- confirm
+ int ngrid = std::min(kMaxGridNum, (N + kBaseThreadNum - 1) / kBaseThreadNum);  // ceil(N / kBaseThreadNum) blocks, capped at kMaxGridNum
+ mxnet_generic_kernel_ex<OP, Args...>
+ <<<ngrid, kBaseThreadNum, 0, mshadow::Stream<gpu>::GetStream(s)>>>(  // kBaseThreadNum threads per block, 0 bytes dynamic shared memory
+ N, args...);
+ }
};
#endif // __CUDACC__
diff --git a/src/operator/tensor/elemwise_binary_broadcast_op.h b/src/operator/tensor/elemwise_binary_broadcast_op.h
index 7aae9cc..1aab714 100644
--- a/src/operator/tensor/elemwise_binary_broadcast_op.h
+++ b/src/operator/tensor/elemwise_binary_broadcast_op.h
@@ -133,13 +133,34 @@ inline int BinaryBroadcastShapeCompact(const TShape& lshape, const TShape& rshap
return j;
}
+namespace mxnet_op {
+template<int ndim, typename DType, typename OP>  // Range kernel: for i in [0, length), out[base+i] gets OP::Map(lhs[lidx], rhs[ridx]) via KERNEL_ASSIGN
+struct binary_broadcast_kernel {
+ MSHADOW_XINLINE static void Map(int base, int length, OpReqType req,
+ const Shape<ndim>& lstride, const Shape<ndim>& rstride,
+ const Shape<ndim>& oshape, DType* lhs, DType* rhs,
+ DType* out, int lsize, int rsize) {  // lsize and rsize are never read in this body
+ Shape<ndim> coord = unravel(base, oshape);  // NOTE(review): presumably the oshape coordinate of flat index base -- confirm against unravel
+ index_t lidx = dot(coord, lstride);  // offset used to read lhs for the first element
+ index_t ridx = dot(coord, rstride);  // offset used to read rhs for the first element
+ KERNEL_ASSIGN(out[base], req, OP::Map(lhs[lidx], rhs[ridx]));  // first element is written before length is checked
+ // element 0 was handled above; starts from 1 to avoid extra inc at end of loop
+ for (int i = 1; i < length; ++i) {
+ inc(&coord, oshape, &lidx, lstride, &ridx, rstride);  // advance coord and both input offsets together
+ KERNEL_ASSIGN(out[base+i], req, OP::Map(lhs[lidx], rhs[ridx]));
+ }
+ }
+};
+
+} // namespace mxnet_op
+
template<typename xpu, typename OP>
void BinaryBroadcastCompute(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
- using namespace broadcast;
+ using namespace mxnet_op;
TShape new_lshape, new_rshape, new_oshape;
int ndim = BinaryBroadcastShapeCompact(inputs[0].shape_, inputs[1].shape_, outputs[0].shape_,
&new_lshape, &new_rshape, &new_oshape);
@@ -149,8 +170,13 @@ void BinaryBroadcastCompute(const nnvm::NodeAttrs& attrs,
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
BROADCAST_NDIM_SWITCH(ndim, NDim, {
- BinaryBroadcastComputeImpl<NDim, DType, OP>(s, req[0], inputs[0].reshape(new_lshape),
- inputs[1].reshape(new_rshape), outputs[0].reshape(new_oshape));
+ Shape<NDim> oshape = new_oshape.get<NDim>();
+ Shape<NDim> lstride = calc_stride(new_lshape.get<NDim>());
+ Shape<NDim> rstride = calc_stride(new_rshape.get<NDim>());
+ Kernel<binary_broadcast_kernel<NDim, DType, OP>, xpu>::LaunchEx(
+ s, new_oshape.Size(), req[0], lstride, rstride, oshape,
+ inputs[0].dptr<DType>(), inputs[1].dptr<DType>(), outputs[0].dptr<DType>(),
+ inputs[0].Size(), inputs[1].Size());
});
});
}
--
To stop receiving notification emails like this one, please contact
['"commits@mxnet.apache.org" <co...@mxnet.apache.org>'].