You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2020/07/28 23:59:22 UTC

[incubator-mxnet] branch v1.x updated: Back port optimization to broadcast_axis to MXNet1.x (#18773)

This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.x by this push:
     new 7bef9cb  Back port optimization to broadcast_axis to MXNet1.x (#18773)
7bef9cb is described below

commit 7bef9cb23b72c3b5b93c10d87e09db19f442d12e
Author: Rohit Kumar Srivastava <sr...@osu.edu>
AuthorDate: Tue Jul 28 16:58:07 2020 -0700

    Back port optimization to broadcast_axis to MXNet1.x (#18773)
    
    * Improving performance of broadcast_axis on GPU (#18168)
    
    * adding separate int32_t kernel for GPU in broadcast_axis/to/like operators
    
    * using structure instead of temp workspace to pass stride and shape
    
    * replacing hardcoded int32_t with generic index_t
    
    * combining CPU and GPU kernels to leverage cached stride calculation and fast access shape data in both
    
    Co-authored-by: Rohit Kumar Srivastava <sr...@buckeyemail.osu.edu>
    
    * Improve performance of broadcast_axis on CPU (#17882)
    
    * adding comments explaining code optimizations
    
    * fixing broadcast_axis kernel to int32
    
    * fixing slice_axis kernel to int32
    
    * combining CPU and GPU implementation method signatures and cleaning up
    code
    
    * adding new broadcast_axis to np_matmul
    
    Co-authored-by: Rohit Kumar Srivastava <sr...@buckeyemail.osu.edu>
    
    Co-authored-by: Rohit Kumar Srivastava <sr...@buckeyemail.osu.edu>
---
 src/operator/numpy/np_matmul_op-inl.h     |  40 +++++-
 src/operator/tensor/broadcast_reduce_op.h | 208 +++++++++++++++++++++++++++---
 2 files changed, 224 insertions(+), 24 deletions(-)

diff --git a/src/operator/numpy/np_matmul_op-inl.h b/src/operator/numpy/np_matmul_op-inl.h
index 89560f6..8f1b4f9 100644
--- a/src/operator/numpy/np_matmul_op-inl.h
+++ b/src/operator/numpy/np_matmul_op-inl.h
@@ -138,6 +138,8 @@ inline void MatmulImpl(const OpContext& ctx,
   mshadow::Tensor<xpu, 1, DType*> workspace;
   mshadow::Tensor<xpu, 3, DType> ans, mlhs, mrhs;
   mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
+  bool isCPU = std::is_same<xpu, cpu>::value;
+  // Is true if either a or b requires broadcast or not
   if (MatmulNeedBroadcast(a_shape, b_shape)) {
     // e.g. a.shape = (2, 3, 1, 4, 2)
     //      b.shape =       (5, 2, 4)
@@ -157,12 +159,38 @@ inline void MatmulImpl(const OpContext& ctx,
     DType* bc_b_ptr = bc_a_ptr + bc_size_a;
     MSHADOW_TYPE_SWITCH_WITH_BOOL(input_a.type_flag_, IType, {
       MSHADOW_TYPE_SWITCH_WITH_BOOL(input_b.type_flag_, OType, {
-        Kernel<broadcast_kernel<mshadow_op::identity>, xpu>::Launch(
-          s, bc_size_a, input_a.dptr<IType>(), bc_a_ptr,
-          k_a_shape, k_a_shape_bc, OpReqType::kWriteTo, ndim);
-        Kernel<broadcast_kernel<mshadow_op::identity>, xpu>::Launch(
-          s, bc_size_b, input_b.dptr<IType>(), bc_b_ptr,
-          k_b_shape, k_b_shape_bc, OpReqType::kWriteTo, ndim);
+        struct ShapeAndStride aux_data_a, aux_data_b;
+        PrepareAUXData(&aux_data_a, k_a_shape, k_a_shape_bc, ndim);
+        PrepareAUXData(&aux_data_b, k_b_shape, k_b_shape_bc, ndim);
+        if (isCPU) {
+          if (!aux_data_a.shape_changed) {
+            Kernel<direct_copy<mshadow_op::identity>, xpu>::Launch(
+              s, bc_size_a, input_a.dptr<IType>(), bc_a_ptr, OpReqType::kWriteTo);
+            Kernel<broadcast_kernel_cpu<mshadow_op::identity>, xpu>::Launch(
+              s, input_b.Size(), input_b.dptr<IType>(), bc_b_ptr,
+              aux_data_b, OpReqType::kWriteTo, ndim);
+          } else if (!aux_data_b.shape_changed) {
+            Kernel<direct_copy<mshadow_op::identity>, xpu>::Launch(
+              s, bc_size_b, input_b.dptr<IType>(), bc_b_ptr, OpReqType::kWriteTo);
+            Kernel<broadcast_kernel_cpu<mshadow_op::identity>, xpu>::Launch(
+              s, input_a.Size(), input_a.dptr<IType>(), bc_a_ptr,
+              aux_data_a, OpReqType::kWriteTo, ndim);
+          } else {
+            Kernel<broadcast_kernel_cpu<mshadow_op::identity>, xpu>::Launch(
+              s, input_a.Size(), input_a.dptr<IType>(), bc_a_ptr,
+              aux_data_a, OpReqType::kWriteTo, ndim);
+            Kernel<broadcast_kernel_cpu<mshadow_op::identity>, xpu>::Launch(
+              s, input_b.Size(), input_b.dptr<IType>(), bc_b_ptr,
+              aux_data_b, OpReqType::kWriteTo, ndim);
+          }
+        } else {
+          Kernel<broadcast_kernel_gpu<mshadow_op::identity>, xpu>::Launch(
+            s, bc_size_a, input_a.dptr<IType>(), bc_a_ptr,
+            aux_data_a, OpReqType::kWriteTo, ndim);
+          Kernel<broadcast_kernel_gpu<mshadow_op::identity>, xpu>::Launch(
+            s, bc_size_b, input_b.dptr<IType>(), bc_b_ptr,
+            aux_data_b, OpReqType::kWriteTo, ndim);
+        }
       });
     });
     ans = mshadow::Tensor<xpu, 3, DType>(output.dptr<DType>(),
diff --git a/src/operator/tensor/broadcast_reduce_op.h b/src/operator/tensor/broadcast_reduce_op.h
index 5eb0c41..82b4f7d 100644
--- a/src/operator/tensor/broadcast_reduce_op.h
+++ b/src/operator/tensor/broadcast_reduce_op.h
@@ -25,6 +25,7 @@
 #ifndef MXNET_OPERATOR_TENSOR_BROADCAST_REDUCE_OP_H_
 #define MXNET_OPERATOR_TENSOR_BROADCAST_REDUCE_OP_H_
 
+#include <assert.h>
 #include <mxnet/operator_util.h>
 #include <string>
 #include <vector>
@@ -1037,34 +1038,182 @@ void ReduceAxesBackwardUseInOut(const nnvm::NodeAttrs& attrs,
   ReduceAxesBackwardUseInOutImpl<xpu, OP, normalize>(ctx, small, inputs, req, outputs);
 }
 
+namespace {  // unnamed namespace in a header: each including TU gets its own distinct copy
+struct ShapeAndStride {  // per-call broadcast metadata, filled in by PrepareAUXData
+  index_t in_stride[MXNET_SPECIAL_MAX_NDIM];  // row-major strides of the input
+  index_t out_stride[MXNET_SPECIAL_MAX_NDIM];  // row-major strides of the output
+  index_t input_shape[MXNET_SPECIAL_MAX_NDIM];  // input dims copied out of mshadow::Shape
+  index_t output_shape[MXNET_SPECIAL_MAX_NDIM];  // output dims copied out of mshadow::Shape
+  // axes: indices of the broadcast axes (input dim != output dim), innermost first
+  index_t axes[MXNET_SPECIAL_MAX_NDIM];
+  int num_broadcast_axes = -1;  // number of valid entries in axes; -1 until filled
+  bool shape_changed = false;  // true if any input dim differs from the output dim
+};
+}  // unnamed namespace
+
+/*!
+ * \brief Computes row-major strides of the input and output tensors, copies
+ *        both shapes into plain index_t arrays for faster access, and records
+ *        the broadcast axes (the axes where input and output dimensions differ).
+ * \param aux_data receives the results; must be freshly constructed, since shape_changed is only ever set
+ * \param in_shape input shape (read at indices [0, ndim))
+ * \param out_shape output shape (read at indices [0, ndim))
+ * \param ndim number of leading dims to use, 1 <= ndim <= MXNET_SPECIAL_MAX_NDIM
+ */
+inline void PrepareAUXData(ShapeAndStride *aux_data,
+                    mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> in_shape,
+                    mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> out_shape,
+                    int ndim) {
+  int iter = ndim - 1, i = 0;  // i: number of broadcast axes recorded so far
+  aux_data->out_stride[iter] = 1;  // innermost dimension is contiguous (row-major)
+  aux_data->in_stride[iter] = 1;
+  aux_data->input_shape[iter] = in_shape[iter];
+  aux_data->output_shape[iter] = out_shape[iter];
+  if (in_shape[iter] != out_shape[iter]) {
+    aux_data->axes[i++] = iter;  // axes[] is filled innermost axis first
+    aux_data->shape_changed = true;
+  }
+  iter--;  // remaining dims, from ndim - 2 down to 0
+  for (; iter >= 0; --iter) {
+    aux_data->out_stride[iter] = aux_data->out_stride[iter + 1] * out_shape[iter + 1];
+    aux_data->in_stride[iter] = aux_data->in_stride[iter + 1] * in_shape[iter + 1];
+    aux_data->input_shape[iter] = in_shape[iter];
+    aux_data->output_shape[iter] = out_shape[iter];
+    if (in_shape[iter] != out_shape[iter]) {
+      aux_data->axes[i++] = iter;
+      aux_data->shape_changed = true;
+    }
+  }
+  aux_data->num_broadcast_axes = i;  // broadcast_kernel_cpu handles at most 3 (assert below; off under NDEBUG)
+  assert(aux_data->num_broadcast_axes > -1 && aux_data->num_broadcast_axes < 4);
+}
+
 template<typename OP>
-struct broadcast_kernel {
+struct broadcast_kernel_gpu {  // Map(i) writes output[i] from the input element it maps to
   template<typename IType, typename OType>
   MSHADOW_XINLINE static void Map(index_t i,
                                   IType *input,
                                   OType *output,
-                                  mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> in_shape,
-                                  mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> out_shape,
+                                  const ShapeAndStride& aux_data,  // strides/shapes from PrepareAUXData
                                   const OpReqType req,
-                                  const uint32_t ndim) {
-    size_t in_stride = 1;
-    size_t out_stride = 1;
+                                  const int ndim) {  // number of leading entries of aux_data in use
     index_t idx = i;
     index_t in_idx = i;
+#pragma unroll 4
     for (int iter = ndim - 1; iter >= 0; --iter) {
-      size_t dim_idx = idx % out_shape[iter];
-      in_idx -= dim_idx * out_stride;
-      if (in_shape[iter] != 1) {
-        in_idx += dim_idx * in_stride;
+      index_t out_dim_shape = aux_data.output_shape[iter];
+      index_t out_dim_stride = aux_data.out_stride[iter];
+      // x % y = x - (x / y) * y
+      // speeds up modulo(%) operation in GPU
+      index_t dim_idx = idx - (idx / out_dim_shape) * out_dim_shape;  // coordinate along iter
+      if (aux_data.input_shape[iter] != 1) {  // kept dim: swap output stride for input stride
+        in_idx += dim_idx * (aux_data.in_stride[iter] - out_dim_stride);
+      } else {
+        in_idx -= dim_idx * out_dim_stride;  // broadcast dim: input coordinate is 0
       }
-      idx /= out_shape[iter];
-      in_stride *= in_shape[iter];
-      out_stride *= out_shape[iter];
+      idx /= out_dim_shape;
     }
     KERNEL_ASSIGN(output[i], req, OP::Map(input[in_idx]));
   }
 };
 
+/**
+ * CPU broadcast kernel: Map(i) handles input element i and writes it to every
+ * output location it is broadcast to (one work item per input element, not per output).
+ * This is intended to let the compiler vectorize the innermost loop when the
+ * fastest-varying output index (stride 1) is a broadcast axis. Only 1-3 broadcast
+ * axes are handled; with num_broadcast_axes == 0 nothing is written (use direct_copy).
+ */
+template<typename OP>
+struct broadcast_kernel_cpu {
+  template<typename IType, typename OType>
+  MSHADOW_XINLINE static void Map(index_t i,
+                                  IType *input,
+                                  OType *output,
+                                  const ShapeAndStride& aux_data,  // from PrepareAUXData
+                                  const OpReqType req,
+                                  const int ndim) {
+    index_t idx = i;
+    index_t init_off = 0;  // output offset of input[i] with all broadcast coordinates at 0
+    for (int iter = ndim - 1; idx > 0 && iter >= 0; --iter) {  // unravel i over the input shape
+      size_t dim_idx = idx % aux_data.input_shape[iter];
+      init_off += dim_idx * aux_data.out_stride[iter];
+      idx /= aux_data.input_shape[iter];
+    }
+    index_t stride_0, stride_1, stride_2;
+    // One case per number of broadcast axes (1, 2 or 3), counted after any
+    // axis merging done by the caller; axes[0] is the innermost broadcast axis.
+    switch (aux_data.num_broadcast_axes) {
+      // One broadcast axis, e.g. input shapes (x_1,1), (x_1,1,x_2) or (1,x_1),
+      // where x_1, x_2 are the sizes of the dimensions that are not broadcast.
+      // For (x_1,1) the write loop is contiguous and can be vectorized; in the
+      // other two forms the gain comes from not recomputing strides for each
+      // output location that input[i] is written to.
+      // stride_0 is the output stride of the broadcast axis.
+      case 1 :
+        stride_0 = aux_data.out_stride[aux_data.axes[0]];
+        for (index_t l = 0; l < aux_data.output_shape[aux_data.axes[0]]; l++) {
+          KERNEL_ASSIGN(output[init_off + l * stride_0],
+              req, OP::Map(input[i]));
+        }
+        break;
+      // Two broadcast axes, e.g. input shapes (x_1,1,x_2,1), (1,x_1,1,x_2) or
+      // (x_1,1,x_2,1,x_3), where x_1, x_2, x_3 are the non-broadcast sizes.
+      // The innermost loop (over axes[0]) may be vectorized by the compiler;
+      // the outer loop avoids recomputing strides for each output location
+      // that input[i] is written to.
+      // stride_1 / stride_0 are the output strides of axes[1] / axes[0].
+      case 2:
+        stride_1 = aux_data.out_stride[aux_data.axes[1]];
+        stride_0 = aux_data.out_stride[aux_data.axes[0]];
+        for (index_t k = 0; k < aux_data.output_shape[aux_data.axes[1]]; k++) {
+          for (index_t l = 0; l < aux_data.output_shape[aux_data.axes[0]]; l++) {
+            KERNEL_ASSIGN(output[init_off + k * stride_1 + l * stride_0],
+                req, OP::Map(input[i]));
+          }
+        }
+        break;
+      // Three broadcast axes, e.g. input shape (1,x_1,1,x_2,1),
+      // where x_1, x_2 are the non-broadcast sizes.
+      // The innermost loop runs over the last axis (index 4), which the
+      // compiler may vectorize; the two outer loops avoid recomputing
+      // strides for each output location that input[i] is written to.
+      // stride_2 / stride_1 / stride_0 are the output strides of axes[2] / [1] / [0].
+      case 3:
+        stride_2 = aux_data.out_stride[aux_data.axes[2]];
+        stride_1 = aux_data.out_stride[aux_data.axes[1]];
+        stride_0 = aux_data.out_stride[aux_data.axes[0]];
+        for (index_t j = 0; j < aux_data.output_shape[aux_data.axes[2]]; j++) {
+          for (index_t k = 0; k < aux_data.output_shape[aux_data.axes[1]]; k++) {
+            for (index_t l = 0; l < aux_data.output_shape[aux_data.axes[0]]; l++) {
+              KERNEL_ASSIGN(output[init_off + j * stride_2 + k * stride_1 + l * stride_0],
+                  req, OP::Map(input[i]));
+            }
+          }
+        }
+        break;
+    }  // no default: num_broadcast_axes of 0 (or > 3) writes nothing
+  }
+};
+
+template<typename OP>
+struct direct_copy {  // elementwise output[i] = OP::Map(input[i]); used when no broadcast is needed
+  template<typename IType, typename OType>
+  MSHADOW_XINLINE static void Map(index_t i,
+                                  IType *input,
+                                  OType *output,
+                                  const OpReqType req) {
+    KERNEL_ASSIGN(output[i], req, OP::Map(input[i]));
+  }
+};
+
+/**
+ * On CPU the broadcast kernel is launched with one work item per input element,
+ * which lets the compiler vectorize the writes where possible.
+ * On GPU it is launched with one work item per output element, so consecutive
+ * threads write consecutive output locations (coalesced writes), which also
+ * improves the coalescing of input reads.
+ */
 template<typename xpu>
 inline void BroadcastComputeImpl(const nnvm::NodeAttrs& attrs,
                                  const OpContext& ctx,
@@ -1076,8 +1225,14 @@ inline void BroadcastComputeImpl(const nnvm::NodeAttrs& attrs,
   using namespace mshadow::expr;
   using namespace mxnet_op;
   mxnet::TShape src_shape, dst_shape;
+  // combines 2 or more consecutive broadcast/non-broadcast axes together
+  // e.g. (3,4,1,1,5,1,6,7) (2,3,5) (5,10,9) -> (3*4,1*1,5,1,6*7) (1,3) (5*10, 9)
+  //      -> (12,1,5,1,42) (1,3) (50, 9)
+  //      and this is the new input for broadcast_kernel whose total
+  //      num of dimensions cannot be greater than 5(throws an error otherwise).
   BroadcastReduceShapeCompact(outputs[0].shape_, small, &dst_shape, &src_shape);
   Stream<xpu> *s = ctx.get_stream<xpu>();
+  bool isCPU = std::is_same<xpu, cpu>::value;
   MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, IType, {
     MSHADOW_TYPE_SWITCH_WITH_BOOL(outputs[0].type_flag_, OType, {
       mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> in_shape;
@@ -1091,21 +1246,38 @@ inline void BroadcastComputeImpl(const nnvm::NodeAttrs& attrs,
           out_shape[i] = 1;
         }
       }
-      if (dst_shape.ndim() == 2) {
+      struct ShapeAndStride aux_data;
+      PrepareAUXData(&aux_data, in_shape, out_shape, dst_shape.ndim());
+      if (!aux_data.shape_changed) {
+        // If no broadcast is required (i.e. input_shape == output_shape)
+        // then simply copy input to outout.
+        Kernel<direct_copy<mshadow_op::identity>, xpu>::Launch(
+          s, outputs[0].Size(), inputs[0].dptr<IType>(), outputs[0].dptr<OType>(), req[0]);
+      } else if (dst_shape.ndim() == 2) {
         Tensor<xpu, 2, OType> out =
           outputs[0].get_with_shape<xpu, 2, OType>(dst_shape.get<2>(), s);
         Tensor<xpu, 2, IType> data =
           inputs[0].get_with_shape<xpu, 2, IType>(src_shape.get<2>(), s);
-        Kernel<broadcast_kernel<mshadow_op::identity>, xpu>::Launch(
-          s, out.shape_.Size(), data.dptr_, out.dptr_, in_shape, out_shape, req[0], 2);
+        if (isCPU) {
+          Kernel<broadcast_kernel_cpu<mshadow_op::identity>, xpu>::Launch(
+            s, data.shape_.Size(), data.dptr_, out.dptr_, aux_data, req[0], 2);
+        } else {
+          Kernel<broadcast_kernel_gpu<mshadow_op::identity>, xpu>::Launch(
+            s, out.shape_.Size(), data.dptr_, out.dptr_, aux_data, req[0], 2);
+        }
       } else {
         const int ndim = MXNET_SPECIAL_MAX_NDIM;
         Tensor<xpu, ndim, OType> out =
           outputs[0].get_with_shape<xpu, ndim, OType>(dst_shape.get<ndim>(), s);
         Tensor<xpu, ndim, IType> data =
           inputs[0].get_with_shape<xpu, ndim, IType>(src_shape.get<ndim>(), s);
-        Kernel<broadcast_kernel<mshadow_op::identity>, xpu>::Launch(
-          s, out.shape_.Size(), data.dptr_, out.dptr_, in_shape, out_shape, req[0], ndim);
+        if (isCPU) {
+          Kernel<broadcast_kernel_cpu<mshadow_op::identity>, xpu>::Launch(
+            s, data.shape_.Size(), data.dptr_, out.dptr_, aux_data, req[0], ndim);
+        } else {
+          Kernel<broadcast_kernel_gpu<mshadow_op::identity>, xpu>::Launch(
+            s, out.shape_.Size(), data.dptr_, out.dptr_, aux_data, req[0], ndim);
+        }
       }
     });
   });