You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/06/15 17:45:42 UTC

[GitHub] piiswrong closed pull request #11012: leaky relu speed

piiswrong closed pull request #11012: leaky relu speed
URL: https://github.com/apache/incubator-mxnet/pull/11012
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as GitHub would not otherwise show it once the request is merged):

diff --git a/src/operator/leaky_relu-inl.h b/src/operator/leaky_relu-inl.h
index c99280ac7ea..8b93e8358d5 100644
--- a/src/operator/leaky_relu-inl.h
+++ b/src/operator/leaky_relu-inl.h
@@ -39,6 +39,7 @@
 #include "./mshadow_op.h"
 #include "./random/sampler.h"
 #include "./random/sample_op.h"
+#include "./tensor/elemwise_binary_broadcast_op.h"
 
 namespace mxnet {
 namespace op {
@@ -72,12 +73,6 @@ struct LeakyReLUParam : public dmlc::Parameter<LeakyReLUParam> {
   }
 };
 
-struct prelu_grad {
-  MSHADOW_XINLINE static real_t Map(real_t a) {
-    return a > 0.0f ? 0.0f : a;
-  }
-};
-
 template<typename xpu, typename DType>
 class LeakyReLUOp : public Operator {
  public:
@@ -98,28 +93,51 @@ class LeakyReLUOp : public Operator {
     Tensor<xpu, 3, DType> data;
     Tensor<xpu, 3, DType> out;
     Tensor<xpu, 3, DType> mask;
-    Tensor<xpu, 1, DType> weight;
     int n = in_data[leakyrelu::kData].shape_[0];
     int k = in_data[leakyrelu::kData].shape_[1];
     Shape<3> dshape = Shape3(n, k, in_data[leakyrelu::kData].Size()/n/k);
     data = in_data[leakyrelu::kData].get_with_shape<xpu, 3, DType>(dshape, s);
     out = out_data[leakyrelu::kOut].get_with_shape<xpu, 3, DType>(dshape, s);
+    if (req[leakyrelu::kOut] == kNullOp) {
+      return;
+    }
     switch (param_.act_type) {
       case leakyrelu::kLeakyReLU: {
         MXNET_ASSIGN_REQ_SWITCH(req[leakyrelu::kOut], Req, {
-          mxnet_op::Kernel<mxnet_op::op_with_req<mxnet::op::mshadow_op::xelu, Req>, xpu>::Launch(
+          mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::xelu, Req>, xpu>::Launch(
             s, out.size(0) * out.size(1) * out.size(2), out.dptr_, data.dptr_, DType(param_.slope));
         });
         break;
       }
       case leakyrelu::kPReLU: {
-        weight = in_data[leakyrelu::kGamma].get<xpu, 1, DType>(s);
-        if (weight.shape_.Size() == 1) {
-          Assign(out, req[leakyrelu::kOut],
-                 F<mshadow_op::xelu>(data, mshadow::expr::broadcast_scalar(weight, out.shape_)));
+        TShape gshape = expand_shape(in_data[leakyrelu::kGamma].shape_,
+                                     in_data[leakyrelu::kData].shape_);
+        TShape new_lshape, new_rshape, new_oshape;
+        const int ndim = op::BinaryBroadcastShapeCompact(in_data[leakyrelu::kData].shape_,
+                                                         gshape,
+                                                         out_data[leakyrelu::kOut].shape_,
+                                                         &new_lshape, &new_rshape, &new_oshape);
+        if (!ndim) {
+          MXNET_ASSIGN_REQ_SWITCH(req[leakyrelu::kOut], Req, {
+            const size_t size = (minthree(out_data[leakyrelu::kOut].Size(),
+                                          in_data[leakyrelu::kData].Size(),
+                                          in_data[leakyrelu::kGamma].Size())
+            + DataType<DType>::kLanes - 1) / DataType<DType>::kLanes;
+            mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::xelu, Req>, xpu>::Launch(
+                s, size, out_data[leakyrelu::kOut].dptr<DType>(),
+            in_data[leakyrelu::kData].dptr<DType>(), in_data[leakyrelu::kGamma].dptr<DType>());
+          });
         } else {
-          Assign(out, req[leakyrelu::kOut],
-                 F<mshadow_op::xelu>(data, mshadow::expr::broadcast<1>(weight, out.shape_)));
+          BROADCAST_NDIM_SWITCH(ndim, NDim, {
+            mshadow::Shape<NDim> oshape = new_oshape.get<NDim>();
+            mshadow::Shape<NDim> lstride = mxnet_op::calc_stride(new_lshape.get<NDim>());
+            mshadow::Shape<NDim> rstride = mxnet_op::calc_stride(new_rshape.get<NDim>());
+            mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, DType,
+                                                               mshadow_op::xelu>, xpu>::
+            template LaunchEx(s, new_oshape.Size(), req[leakyrelu::kOut], lstride, rstride, oshape,
+            in_data[leakyrelu::kData].dptr<DType>(), in_data[leakyrelu::kGamma].dptr<DType>(),
+            out_data[leakyrelu::kOut].dptr<DType>());
+          });
         }
         break;
       }
@@ -134,23 +152,23 @@ class LeakyReLUOp : public Operator {
           Tensor<xpu, 1, DType> out = mask.FlatTo1D();
           sampler.Sample(low, high, out, pgen, s);
           MXNET_ASSIGN_REQ_SWITCH(req[leakyrelu::kMask], Req, {
-            mxnet_op::Kernel<mxnet_op::op_with_req<mxnet::op::mshadow_op::mul, Req>, xpu>::Launch(
+            mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::mul, Req>, xpu>::Launch(
               s, mask.size(0) * mask.size(1) * mask.size(2), mask.dptr_, mask.dptr_,
               DType(param_.upper_bound - param_.lower_bound));
           });
           MXNET_ASSIGN_REQ_SWITCH(req[leakyrelu::kMask], Req, {
-            mxnet_op::Kernel<mxnet_op::op_with_req<mxnet::op::mshadow_op::plus, Req>, xpu>::Launch(
+            mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::plus, Req>, xpu>::Launch(
               s, mask.size(0) * mask.size(1) * mask.size(2), mask.dptr_, mask.dptr_,
               DType(param_.lower_bound));
           });
           MXNET_ASSIGN_REQ_SWITCH(req[leakyrelu::kOut], Req, {
-            mxnet_op::Kernel<mxnet_op::op_with_req<mxnet::op::mshadow_op::xelu, Req>, xpu>::Launch(
+            mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::xelu, Req>, xpu>::Launch(
               s, mask.size(0) * mask.size(1) * mask.size(2), out.dptr_, data.dptr_, mask.dptr_);
           });
         } else {
           const float slope = (param_.lower_bound + param_.upper_bound) / 2.0f;
           MXNET_ASSIGN_REQ_SWITCH(req[leakyrelu::kOut], Req, {
-            mxnet_op::Kernel<mxnet_op::op_with_req<mxnet::op::mshadow_op::xelu, Req>, xpu>::Launch(
+            mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::xelu, Req>, xpu>::Launch(
               s, out.size(0) * out.size(1) * out.size(2), out.dptr_, data.dptr_, DType(slope));
           });
         }
@@ -158,7 +176,7 @@ class LeakyReLUOp : public Operator {
       }
       case leakyrelu::kELU: {
         MXNET_ASSIGN_REQ_SWITCH(req[leakyrelu::kOut], Req, {
-          mxnet_op::Kernel<mxnet_op::op_with_req<mxnet::op::mshadow_op::elu, Req>, xpu>::Launch(
+          mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::elu, Req>, xpu>::Launch(
             s, out.size(0) * out.size(1) * out.size(2), out.dptr_, data.dptr_,
             DType(param_.slope));
         });
@@ -188,8 +206,6 @@ class LeakyReLUOp : public Operator {
     Tensor<xpu, 3, DType> gdata;
     Tensor<xpu, 3, DType> grad;
     Tensor<xpu, 3, DType> mask;
-    Tensor<xpu, 1, DType> weight;
-    Tensor<xpu, 1, DType> grad_weight;
     int n = out_grad[leakyrelu::kOut].shape_[0];
     int k = out_grad[leakyrelu::kOut].shape_[1];
     Shape<3> dshape = Shape3(n, k, out_grad[leakyrelu::kOut].Size()/n/k);
@@ -206,29 +222,38 @@ class LeakyReLUOp : public Operator {
       case leakyrelu::kLeakyReLU: {
         MXNET_ASSIGN_REQ_SWITCH(req[leakyrelu::kData], Req, {
           mxnet_op::Kernel<mxnet_op::op_with_req<
-            mxnet_op::backward_grad_tuned<mxnet::op::mshadow_op::xelu_grad>, Req>, xpu>::Launch(
+            mxnet_op::backward_grad_tuned<mshadow_op::xelu_grad>, Req>, xpu>::Launch(
               s, gdata.size(0) * gdata.size(1) * gdata.size(2), gdata.dptr_, grad.dptr_,
               output.dptr_, DType(param_.slope));
         });
         break;
       }
       case leakyrelu::kPReLU: {
-        weight = in_data[leakyrelu::kGamma].get<xpu, 1, DType>(s);
-        grad_weight = in_grad[leakyrelu::kGamma].get<xpu, 1, DType>(s);
-        if (weight.shape_.Size() == 1) {
-          Shape<4> gshape = Shape4(1, grad.shape_[0], grad.shape_[1], grad.shape_[2]);
-          Assign(grad_weight, req[leakyrelu::kGamma],
-                 sumall_except_dim<0>(reshape(F<prelu_grad>(data) * grad, gshape)));
-          Assign(gdata, req[leakyrelu::kData],
-                 F<mshadow_op::xelu_grad>(data,
-                                          mshadow::expr::broadcast_scalar(weight, data.shape_))
-                 * grad);
+        TShape gshape = expand_shape(in_grad[leakyrelu::kGamma].shape_,
+                                     in_grad[leakyrelu::kData].shape_);
+        TShape new_lshape, new_rshape, new_oshape;
+        const bool need_bc = BinaryBroadcastShapeCompact(in_grad[leakyrelu::kData].shape_,
+                                                         gshape,
+                                                         out_grad[leakyrelu::kOut].shape_,
+                                                         &new_lshape,
+                                                         &new_rshape,
+                                                         &new_oshape) != 0;
+        if (!need_bc) {
+          ElemwiseBinaryOp::BackwardUseIn<xpu,
+                                          mshadow_op::xelu_grad,
+                                          mshadow_op::prelu_grad>(
+            nnvm::NodeAttrs(), ctx, {out_grad[leakyrelu::kOut],
+                                     in_data[leakyrelu::kData],
+                                     in_data[leakyrelu::kGamma]}, req, in_grad);
         } else {
-          Assign(grad_weight, req[leakyrelu::kGamma],
-                 sumall_except_dim<1>(F<prelu_grad>(data) * grad));
-          Assign(gdata, req[leakyrelu::kData],
-                 F<mshadow_op::xelu_grad>(data, mshadow::expr::broadcast<1>(weight, data.shape_))
-                 * grad);
+          BROADCAST_NDIM_SWITCH(new_oshape.ndim(), NDim, {
+            BinaryBroadcastBackwardUseInImpl<xpu, NDim, DType,
+              mshadow_op::xelu_grad, mshadow_op::prelu_grad>(
+                ctx, {out_grad[leakyrelu::kOut],
+                      in_data[leakyrelu::kData],
+                      in_data[leakyrelu::kGamma]}, req, in_grad,
+                new_lshape, new_rshape, new_oshape);
+          });
         }
         break;
       }
@@ -239,7 +264,7 @@ class LeakyReLUOp : public Operator {
       case leakyrelu::kELU: {
         MXNET_ASSIGN_REQ_SWITCH(req[leakyrelu::kData], Req, {
           mxnet_op::Kernel<mxnet_op::op_with_req<
-            mxnet_op::backward_grad_tuned<mxnet::op::mshadow_op::elu_grad>, Req>, xpu>::Launch(
+            mxnet_op::backward_grad_tuned<mshadow_op::elu_grad>, Req>, xpu>::Launch(
               s, gdata.size(0) * gdata.size(1) * gdata.size(2), gdata.dptr_, grad.dptr_,
               output.dptr_, DType(param_.slope));
         });
@@ -251,6 +276,24 @@ class LeakyReLUOp : public Operator {
   }
 
  private:
+  /*! \brief Minimum of three */
+  static MSHADOW_XINLINE size_t minthree(const size_t a, const size_t b, const size_t c) {
+    return a < b ? (a < c ? a : c) : (b < c ? b : c);
+  }
+  static inline TShape expand_shape(const TShape& src, const TShape& dst) {
+    TShape result(dst.ndim());
+    int s = src.ndim() - 1;
+    for (int i = dst.ndim() - 1; i >= 0; i--) {
+      if (s >= 0 && (dst[i] == src[s] || src[s] == 1)) {
+        result[i] = src[s];
+        s--;
+      } else {
+        result[i] = 1;
+      }
+    }
+    CHECK(s == -1) << "Cannot broadcast gamma to data. gamma: " << src << ", data: " << dst;
+    return result;
+  }
   LeakyReLUParam param_;
 };  // class LeakyReLUOp
 
@@ -281,10 +324,12 @@ class LeakyReLUProp : public OperatorProperty {
     if (dshape.ndim() == 0) return false;
     if (param_.act_type == leakyrelu::kPReLU) {
       const TShape &gshape = in_shape->at(leakyrelu::kGamma);
-      if (gshape.ndim() == 1 && gshape.Size() == 1)
-        in_shape->at(leakyrelu::kGamma) = TShape(Shape1(1));
-      else
+      if (gshape.ndim() == 0) {
         in_shape->at(leakyrelu::kGamma) = TShape(Shape1(dshape[1]));
+      }
+      if (dshape == gshape) {
+        SHAPE_ASSIGN_CHECK(*out_shape, 0, dshape);
+      }
     }
     out_shape->clear();
     out_shape->push_back(dshape);
@@ -396,6 +441,11 @@ class LeakyReLUProp : public OperatorProperty {
     }
   }
 
+  std::vector<ResourceRequest> BackwardResource(
+      const std::vector<TShape> &in_shape) const override {
+    return {ResourceRequest::kTempSpace};
+  }
+
   Operator* CreateOperator(Context ctx) const override {
     LOG(FATAL) << "Not Implemented.";
     return NULL;
diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h
index 19fa4f8ead8..5953568c7fa 100644
--- a/src/operator/mshadow_op.h
+++ b/src/operator/mshadow_op.h
@@ -126,6 +126,8 @@ MXNET_UNARY_MATH_OP_NC(relu, a > DType(0) ? a : DType(0));
 
 MXNET_UNARY_MATH_OP_NC(relu_grad, a > DType(0) ? DType(1) : DType(0));
 
+MXNET_BINARY_MATH_OP_NC(prelu_grad, a > DType(0) ? DType(0) : a);
+
 MXNET_BINARY_MATH_OP_NC(xelu, a > DType(0) ? a :
                         DType(static_cast<float>(a) * static_cast<float>(b)));
 
diff --git a/src/operator/operator_tune.cc b/src/operator/operator_tune.cc
index de3c7422c5f..0953cbaf519 100644
--- a/src/operator/operator_tune.cc
+++ b/src/operator/operator_tune.cc
@@ -322,6 +322,7 @@ IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::power_grad);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::rpower_grad);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::power_rgrad);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::xelu_grad); // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::prelu_grad); // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::elu_grad); // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::maximum);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::minimum);  // NOLINT()
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index e7976e01f9d..497170a2224 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -630,7 +630,9 @@ def fprelu_grad(x, y, gamma):
         copy_x = x.copy()
         copy_x[pos_indices] = 0.0
         grad_x[pos_indices] = 1.0
-        if gamma.shape[0] == 1:
+        if len(gamma.shape) > 1:
+            grad_gam = copy_x
+        elif gamma.shape[0] == 1:
             grad_gam = np.sum(np.sum(copy_x))
         elif gamma.shape[0] > 1:
             grad_gam = np.sum(copy_x, axis=0)
@@ -640,6 +642,7 @@ def fprelu_grad(x, y, gamma):
     gamma = mx.symbol.Variable("gamma")
     for dtype in [np.float16, np.float32, np.float64]:
         for gam in [np.array([0.1, 0.2, 0.3, 0.4], dtype=dtype)]:
+            gam_full = np.array([gam, gam, gam])
             xa = np.random.uniform(low=-1.0,high=1.0,size=shape).astype(dtype)
             rtol = 1e-2
             atol = 1e-3
@@ -647,12 +650,18 @@ def fprelu_grad(x, y, gamma):
             xa[abs(xa) < eps] = 1.0
             y = mx.symbol.LeakyReLU(data=x, gamma=gamma, act_type='prelu')
             ya = fprelu(xa, gam)
+            ya_full = fprelu(xa, gam_full)
             g_xa, g_gam = fprelu_grad(xa, ya, gamma=gam)
+            g_xa_full, g_gam_full = fprelu_grad(xa, ya_full, gamma=gam_full)
             # Skip numeric check for float16 type to get rid of flaky behavior
             if dtype is not np.float16:
                 check_numeric_gradient(y, [xa, gam], numeric_eps=eps, rtol=rtol, atol=atol, dtype=dtype)
+                check_numeric_gradient(y, [xa, gam_full], numeric_eps=eps, rtol=rtol, atol=atol, dtype=dtype)
             check_symbolic_forward(y, [xa, gam], [ya], rtol=rtol, atol=atol, dtype=dtype)
             check_symbolic_backward(y, [xa, gam], [np.ones(shape), np.ones(gam.shape)], [g_xa, g_gam], rtol=rtol, atol=atol, dtype=dtype)
+            check_symbolic_forward(y, [xa, gam_full], [ya_full], rtol=rtol, atol=atol, dtype=dtype)
+            check_symbolic_backward(y, [xa, gam_full], [np.ones(shape), np.ones(gam_full.shape)],
+                                    [g_xa_full, g_gam_full], rtol=rtol, atol=atol, dtype=dtype)
 
 @with_seed()
 def test_sigmoid():
@@ -5745,7 +5754,7 @@ def py_bilinear_resize(x, outputHeight, outputWidth):
         batch, channel, inputHeight, inputWidth = x.shape
         if outputHeight == inputHeight and outputWidth == inputWidth:
             return x
-        y = np.empty([batch, channel, outputHeight, outputWidth]) 
+        y = np.empty([batch, channel, outputHeight, outputWidth])
         rheight = 1.0 * (inputHeight - 1) / (outputHeight - 1) if outputHeight > 1 else 0.0
         rwidth = 1.0 * (inputWidth - 1) / (outputWidth - 1) if outputWidth > 1 else 0.0
         for h2 in range(outputHeight):


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services