You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2017/11/11 19:58:22 UTC

[GitHub] piiswrong closed pull request #8566: optimize broadcast

piiswrong closed pull request #8566: optimize broadcast
URL: https://github.com/apache/incubator-mxnet/pull/8566
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

Because this pull request comes from a fork, GitHub does not display its diff
after the merge, so the diff is reproduced below:

diff --git a/src/operator/mxnet_op.h b/src/operator/mxnet_op.h
index 5b8e109d7d..564ad81440 100644
--- a/src/operator/mxnet_op.h
+++ b/src/operator/mxnet_op.h
@@ -222,6 +222,37 @@ MSHADOW_XINLINE Shape<ndim> calc_stride(const Shape<ndim>& shape) {
   return stride;
 }
 
+/*!
+ * \brief Advance a coordinate by one element in row-major order and keep a
+ *        linear offset consistent with it.
+ * The last dimension is bumped first; whenever a dimension reaches its extent
+ * in \p shape it wraps back to zero and carries into the next-outer dimension.
+ * \p idx moves by the matching entry of \p stride, and each wrap removes the
+ * shape[d] * stride[d] that the wrapped dimension had accumulated.
+ * Dimension 0 is never wrapped, so stepping past the final element leaves
+ * (*coord)[0] == shape[0].
+ */
+template<int ndim>
+MSHADOW_XINLINE void inc(Shape<ndim>* coord, const Shape<ndim>& shape,
+                         index_t* idx, const Shape<ndim>& stride) {
+  Shape<ndim>& pos = *coord;
+  ++pos[ndim-1];
+  *idx += stride[ndim-1];
+  #pragma unroll
+  for (int d = ndim - 1; d > 0; --d) {
+    if (pos[d] < shape[d]) break;  // no carry needed from here outward
+    pos[d] -= shape[d];
+    ++pos[d-1];
+    *idx += stride[d-1] - shape[d] * stride[d];
+  }
+}
+
+/*!
+ * \brief Two-offset variant of inc(): advance \p coord by one element in
+ *        row-major order while keeping two linear offsets in sync, one per
+ *        stride (e.g. for walking two differently-strided inputs together).
+ * Carry behaviour is identical to the single-offset overload; both offsets
+ * are adjusted on every wrap using their own stride.
+ */
+template<int ndim>
+MSHADOW_XINLINE void inc(Shape<ndim>* coord, const Shape<ndim>& shape,
+                         index_t* idx1, const Shape<ndim>& stride1,
+                         index_t* idx2, const Shape<ndim>& stride2) {
+  Shape<ndim>& pos = *coord;
+  ++pos[ndim-1];
+  *idx1 += stride1[ndim-1];
+  *idx2 += stride2[ndim-1];
+  #pragma unroll
+  for (int d = ndim - 1; d > 0; --d) {
+    if (pos[d] < shape[d]) break;  // no carry needed from here outward
+    pos[d] -= shape[d];
+    ++pos[d-1];
+    *idx1 += stride1[d-1] - shape[d] * stride1[d];
+    *idx2 += stride2[d-1] - shape[d] * stride2[d];
+  }
+}
+
 /*!
  * \brief Simple copy data from one blob to another
  * \param to Destination blob
@@ -357,6 +388,24 @@ struct Kernel<OP, cpu> {
     }
 #endif
   }
+
+  /*!
+   * \brief Run OP over the index range [0, N) in contiguous chunks.
+   *
+   * OP::Map is invoked as OP::Map(base, length, args...) and is expected to
+   * process elements [base, base + length), so per-chunk setup can be
+   * amortized over consecutive elements.
+   * With OpenMP and more than one thread per worker, [0, N) is split into at
+   * most num_omp_threads_per_worker() chunks processed in parallel; otherwise
+   * OP::Map is called once for the whole range.
+   * Nothing is called when N <= 0, so OP::Map always receives length >= 1.
+   * \param s    Stream (not used on CPU).
+   * \param N    Total number of elements.
+   * \param args Extra arguments forwarded unchanged to OP::Map.
+   */
+  template<typename ...Args>
+  inline static void LaunchEx(mshadow::Stream<cpu> *s, const int N, Args... args) {
+    // Map implementations may unconditionally touch element `base`.
+    if (N <= 0) return;
+#ifdef _OPENMP
+    const int omp_cores = Engine::Get()->num_omp_threads_per_worker();
+    if (omp_cores <= 1) {
+      OP::Map(0, N, args...);
+    } else {
+      // ceil(N / omp_cores) without forming N + omp_cores - 1 (int overflow).
+      const int length = N / omp_cores + (N % omp_cores != 0 ? 1 : 0);
+      #pragma omp parallel for num_threads(omp_cores)
+      for (int i = 0; i < N; i += length) {
+        // Compare against the remaining count instead of computing
+        // i + length, which could overflow when N is close to INT_MAX.
+        OP::Map(i, N - i < length ? N - i : length, args...);
+      }
+    }
+#else
+    OP::Map(0, N, args...);
+#endif
+  }
 };
 
 
@@ -368,6 +417,13 @@ __global__ void mxnet_generic_kernel(int N, Args... args) {
   }
 }
 
+/*!
+ * \brief Device entry point used by Kernel<OP, gpu>::LaunchEx.
+ * Each thread starts at its global thread index and advances by the total
+ * number of threads in the grid, calling OP::Map on a single element
+ * (length == 1) at every index below N.
+ */
+template<typename OP, typename ...Args>
+__global__ void mxnet_generic_kernel_ex(int N, Args... args) {
+  const int step = blockDim.x * gridDim.x;
+  int idx = blockIdx.x * blockDim.x + threadIdx.x;
+  while (idx < N) {
+    OP::Map(idx, 1, args...);
+    idx += step;
+  }
+}
+
 template<typename OP>
 struct Kernel<OP, gpu> {
   template<typename ...Args>
@@ -378,6 +434,15 @@ struct Kernel<OP, gpu> {
       <<<ngrid, kBaseThreadNum, 0, mshadow::Stream<gpu>::GetStream(s)>>>(
         N, args...);
   }
+
+  /*!
+   * \brief Run OP over [0, N) on the GPU via mxnet_generic_kernel_ex, which
+   *        calls OP::Map(i, 1, args...) once per index.
+   * The grid is sized to cover N with kBaseThreadNum threads per block,
+   * capped at kMaxGridNum blocks.
+   * Returns without launching when N <= 0: that would produce a zero-block
+   * grid, which CUDA rejects as an invalid launch configuration.
+   * \param s    Stream whose CUDA stream the kernel is enqueued on.
+   * \param N    Total number of elements.
+   * \param args Extra arguments forwarded unchanged to OP::Map.
+   */
+  template<typename ...Args>
+  inline static void LaunchEx(mshadow::Stream<gpu> *s, int N, Args... args) {
+    using namespace mshadow::cuda;
+    if (N <= 0) return;
+    int ngrid = std::min(kMaxGridNum, (N + kBaseThreadNum - 1) / kBaseThreadNum);
+    mxnet_generic_kernel_ex<OP, Args...>
+      <<<ngrid, kBaseThreadNum, 0, mshadow::Stream<gpu>::GetStream(s)>>>(
+        N, args...);
+  }
 };
 #endif  // __CUDACC__
 
diff --git a/src/operator/tensor/elemwise_binary_broadcast_op.h b/src/operator/tensor/elemwise_binary_broadcast_op.h
index 7aae9cc87b..1aab714625 100644
--- a/src/operator/tensor/elemwise_binary_broadcast_op.h
+++ b/src/operator/tensor/elemwise_binary_broadcast_op.h
@@ -133,13 +133,34 @@ inline int BinaryBroadcastShapeCompact(const TShape& lshape, const TShape& rshap
   return j;
 }
 
+namespace mxnet_op {
+/*!
+ * \brief Elementwise binary OP with broadcasting, for Kernel<...>::LaunchEx.
+ *
+ * Map(base, length, ...) fills out[base, base + length) with
+ * OP::Map(lhs[lidx], rhs[ridx]), where lidx / ridx are the dot products of the
+ * element's coordinate in \p oshape with \p lstride / \p rstride.
+ * Offsets for the first element are computed from scratch (unravel + dot);
+ * later elements are advanced incrementally with inc(), avoiding a full
+ * unravel per element.
+ * \param req   write mode, applied through KERNEL_ASSIGN.
+ * \param lsize not read here (kept so existing call sites stay valid).
+ * \param rsize not read here (kept so existing call sites stay valid).
+ */
+template<int ndim, typename DType, typename OP>
+struct binary_broadcast_kernel {
+  MSHADOW_XINLINE static void Map(int base, int length, OpReqType req,
+                                  const Shape<ndim>& lstride, const Shape<ndim>& rstride,
+                                  const Shape<ndim>& oshape, const DType* lhs, const DType* rhs,
+                                  DType* out, int lsize, int rsize) {
+    // An empty chunk must not write anything; without this guard out[base]
+    // would be assigned even when length == 0.
+    if (length <= 0) return;
+    Shape<ndim> coord = unravel(base, oshape);
+    index_t lidx = dot(coord, lstride);
+    index_t ridx = dot(coord, rstride);
+    KERNEL_ASSIGN(out[base], req, OP::Map(lhs[lidx], rhs[ridx]));
+    // starts from 1 to avoid an extra inc after the last element
+    for (int i = 1; i < length; ++i) {
+      inc(&coord, oshape, &lidx, lstride, &ridx, rstride);
+      KERNEL_ASSIGN(out[base+i], req, OP::Map(lhs[lidx], rhs[ridx]));
+    }
+  }
+};
+
+}  // namespace mxnet_op
+
 template<typename xpu, typename OP>
 void BinaryBroadcastCompute(const nnvm::NodeAttrs& attrs,
                             const OpContext& ctx,
                             const std::vector<TBlob>& inputs,
                             const std::vector<OpReqType>& req,
                             const std::vector<TBlob>& outputs) {
-  using namespace broadcast;
+  using namespace mxnet_op;
   TShape new_lshape, new_rshape, new_oshape;
   int ndim = BinaryBroadcastShapeCompact(inputs[0].shape_, inputs[1].shape_, outputs[0].shape_,
                                          &new_lshape, &new_rshape, &new_oshape);
@@ -149,8 +170,13 @@ void BinaryBroadcastCompute(const nnvm::NodeAttrs& attrs,
     mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
     MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
       BROADCAST_NDIM_SWITCH(ndim, NDim, {
-        BinaryBroadcastComputeImpl<NDim, DType, OP>(s, req[0], inputs[0].reshape(new_lshape),
-          inputs[1].reshape(new_rshape), outputs[0].reshape(new_oshape));
+        Shape<NDim> oshape = new_oshape.get<NDim>();
+        Shape<NDim> lstride = calc_stride(new_lshape.get<NDim>());
+        Shape<NDim> rstride = calc_stride(new_rshape.get<NDim>());
+        Kernel<binary_broadcast_kernel<NDim, DType, OP>, xpu>::LaunchEx(
+            s, new_oshape.Size(), req[0], lstride, rstride, oshape,
+            inputs[0].dptr<DType>(), inputs[1].dptr<DType>(), outputs[0].dptr<DType>(),
+            inputs[0].Size(), inputs[1].Size());
       });
     });
   }


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log in to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services