You are viewing a plain-text version of this content. The canonical link for it is not preserved in this copy.
Posted to commits@mxnet.apache.org by jx...@apache.org on 2018/06/15 17:45:55 UTC

[incubator-mxnet] branch master updated: leaky relu speed (#11012)

This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new fb084cc  leaky relu speed (#11012)
fb084cc is described below

commit fb084cc7bb97176e24db7c426b39cf504d440adb
Author: Sheng Zha <sz...@users.noreply.github.com>
AuthorDate: Fri Jun 15 13:45:40 2018 -0400

    leaky relu speed (#11012)
    
    * leaky relu forward speed
    
    * leaky relu backward speed
    
    * fix infer shape
    
    * fix shape
---
 src/operator/leaky_relu-inl.h          | 132 +++++++++++++++++++++++----------
 src/operator/mshadow_op.h              |   2 +
 src/operator/operator_tune.cc          |   1 +
 tests/python/unittest/test_operator.py |  11 ++-
 4 files changed, 104 insertions(+), 42 deletions(-)

diff --git a/src/operator/leaky_relu-inl.h b/src/operator/leaky_relu-inl.h
index c99280a..8b93e83 100644
--- a/src/operator/leaky_relu-inl.h
+++ b/src/operator/leaky_relu-inl.h
@@ -39,6 +39,7 @@
 #include "./mshadow_op.h"
 #include "./random/sampler.h"
 #include "./random/sample_op.h"
+#include "./tensor/elemwise_binary_broadcast_op.h"
 
 namespace mxnet {
 namespace op {
@@ -72,12 +73,6 @@ struct LeakyReLUParam : public dmlc::Parameter<LeakyReLUParam> {
   }
 };
 
-struct prelu_grad {
-  MSHADOW_XINLINE static real_t Map(real_t a) {
-    return a > 0.0f ? 0.0f : a;
-  }
-};
-
 template<typename xpu, typename DType>
 class LeakyReLUOp : public Operator {
  public:
@@ -98,28 +93,51 @@ class LeakyReLUOp : public Operator {
     Tensor<xpu, 3, DType> data;
     Tensor<xpu, 3, DType> out;
     Tensor<xpu, 3, DType> mask;
-    Tensor<xpu, 1, DType> weight;
     int n = in_data[leakyrelu::kData].shape_[0];
     int k = in_data[leakyrelu::kData].shape_[1];
     Shape<3> dshape = Shape3(n, k, in_data[leakyrelu::kData].Size()/n/k);
     data = in_data[leakyrelu::kData].get_with_shape<xpu, 3, DType>(dshape, s);
     out = out_data[leakyrelu::kOut].get_with_shape<xpu, 3, DType>(dshape, s);
+    if (req[leakyrelu::kOut] == kNullOp) {
+      return;
+    }
     switch (param_.act_type) {
       case leakyrelu::kLeakyReLU: {
         MXNET_ASSIGN_REQ_SWITCH(req[leakyrelu::kOut], Req, {
-          mxnet_op::Kernel<mxnet_op::op_with_req<mxnet::op::mshadow_op::xelu, Req>, xpu>::Launch(
+          mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::xelu, Req>, xpu>::Launch(
             s, out.size(0) * out.size(1) * out.size(2), out.dptr_, data.dptr_, DType(param_.slope));
         });
         break;
       }
       case leakyrelu::kPReLU: {
-        weight = in_data[leakyrelu::kGamma].get<xpu, 1, DType>(s);
-        if (weight.shape_.Size() == 1) {
-          Assign(out, req[leakyrelu::kOut],
-                 F<mshadow_op::xelu>(data, mshadow::expr::broadcast_scalar(weight, out.shape_)));
+        TShape gshape = expand_shape(in_data[leakyrelu::kGamma].shape_,
+                                     in_data[leakyrelu::kData].shape_);
+        TShape new_lshape, new_rshape, new_oshape;
+        const int ndim = op::BinaryBroadcastShapeCompact(in_data[leakyrelu::kData].shape_,
+                                                         gshape,
+                                                         out_data[leakyrelu::kOut].shape_,
+                                                         &new_lshape, &new_rshape, &new_oshape);
+        if (!ndim) {
+          MXNET_ASSIGN_REQ_SWITCH(req[leakyrelu::kOut], Req, {
+            const size_t size = (minthree(out_data[leakyrelu::kOut].Size(),
+                                          in_data[leakyrelu::kData].Size(),
+                                          in_data[leakyrelu::kGamma].Size())
+            + DataType<DType>::kLanes - 1) / DataType<DType>::kLanes;
+            mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::xelu, Req>, xpu>::Launch(
+                s, size, out_data[leakyrelu::kOut].dptr<DType>(),
+            in_data[leakyrelu::kData].dptr<DType>(), in_data[leakyrelu::kGamma].dptr<DType>());
+          });
         } else {
-          Assign(out, req[leakyrelu::kOut],
-                 F<mshadow_op::xelu>(data, mshadow::expr::broadcast<1>(weight, out.shape_)));
+          BROADCAST_NDIM_SWITCH(ndim, NDim, {
+            mshadow::Shape<NDim> oshape = new_oshape.get<NDim>();
+            mshadow::Shape<NDim> lstride = mxnet_op::calc_stride(new_lshape.get<NDim>());
+            mshadow::Shape<NDim> rstride = mxnet_op::calc_stride(new_rshape.get<NDim>());
+            mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, DType,
+                                                               mshadow_op::xelu>, xpu>::
+            template LaunchEx(s, new_oshape.Size(), req[leakyrelu::kOut], lstride, rstride, oshape,
+            in_data[leakyrelu::kData].dptr<DType>(), in_data[leakyrelu::kGamma].dptr<DType>(),
+            out_data[leakyrelu::kOut].dptr<DType>());
+          });
         }
         break;
       }
@@ -134,23 +152,23 @@ class LeakyReLUOp : public Operator {
           Tensor<xpu, 1, DType> out = mask.FlatTo1D();
           sampler.Sample(low, high, out, pgen, s);
           MXNET_ASSIGN_REQ_SWITCH(req[leakyrelu::kMask], Req, {
-            mxnet_op::Kernel<mxnet_op::op_with_req<mxnet::op::mshadow_op::mul, Req>, xpu>::Launch(
+            mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::mul, Req>, xpu>::Launch(
               s, mask.size(0) * mask.size(1) * mask.size(2), mask.dptr_, mask.dptr_,
               DType(param_.upper_bound - param_.lower_bound));
           });
           MXNET_ASSIGN_REQ_SWITCH(req[leakyrelu::kMask], Req, {
-            mxnet_op::Kernel<mxnet_op::op_with_req<mxnet::op::mshadow_op::plus, Req>, xpu>::Launch(
+            mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::plus, Req>, xpu>::Launch(
               s, mask.size(0) * mask.size(1) * mask.size(2), mask.dptr_, mask.dptr_,
               DType(param_.lower_bound));
           });
           MXNET_ASSIGN_REQ_SWITCH(req[leakyrelu::kOut], Req, {
-            mxnet_op::Kernel<mxnet_op::op_with_req<mxnet::op::mshadow_op::xelu, Req>, xpu>::Launch(
+            mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::xelu, Req>, xpu>::Launch(
               s, mask.size(0) * mask.size(1) * mask.size(2), out.dptr_, data.dptr_, mask.dptr_);
           });
         } else {
           const float slope = (param_.lower_bound + param_.upper_bound) / 2.0f;
           MXNET_ASSIGN_REQ_SWITCH(req[leakyrelu::kOut], Req, {
-            mxnet_op::Kernel<mxnet_op::op_with_req<mxnet::op::mshadow_op::xelu, Req>, xpu>::Launch(
+            mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::xelu, Req>, xpu>::Launch(
               s, out.size(0) * out.size(1) * out.size(2), out.dptr_, data.dptr_, DType(slope));
           });
         }
@@ -158,7 +176,7 @@ class LeakyReLUOp : public Operator {
       }
       case leakyrelu::kELU: {
         MXNET_ASSIGN_REQ_SWITCH(req[leakyrelu::kOut], Req, {
-          mxnet_op::Kernel<mxnet_op::op_with_req<mxnet::op::mshadow_op::elu, Req>, xpu>::Launch(
+          mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::elu, Req>, xpu>::Launch(
             s, out.size(0) * out.size(1) * out.size(2), out.dptr_, data.dptr_,
             DType(param_.slope));
         });
@@ -188,8 +206,6 @@ class LeakyReLUOp : public Operator {
     Tensor<xpu, 3, DType> gdata;
     Tensor<xpu, 3, DType> grad;
     Tensor<xpu, 3, DType> mask;
-    Tensor<xpu, 1, DType> weight;
-    Tensor<xpu, 1, DType> grad_weight;
     int n = out_grad[leakyrelu::kOut].shape_[0];
     int k = out_grad[leakyrelu::kOut].shape_[1];
     Shape<3> dshape = Shape3(n, k, out_grad[leakyrelu::kOut].Size()/n/k);
@@ -206,29 +222,38 @@ class LeakyReLUOp : public Operator {
       case leakyrelu::kLeakyReLU: {
         MXNET_ASSIGN_REQ_SWITCH(req[leakyrelu::kData], Req, {
           mxnet_op::Kernel<mxnet_op::op_with_req<
-            mxnet_op::backward_grad_tuned<mxnet::op::mshadow_op::xelu_grad>, Req>, xpu>::Launch(
+            mxnet_op::backward_grad_tuned<mshadow_op::xelu_grad>, Req>, xpu>::Launch(
               s, gdata.size(0) * gdata.size(1) * gdata.size(2), gdata.dptr_, grad.dptr_,
               output.dptr_, DType(param_.slope));
         });
         break;
       }
       case leakyrelu::kPReLU: {
-        weight = in_data[leakyrelu::kGamma].get<xpu, 1, DType>(s);
-        grad_weight = in_grad[leakyrelu::kGamma].get<xpu, 1, DType>(s);
-        if (weight.shape_.Size() == 1) {
-          Shape<4> gshape = Shape4(1, grad.shape_[0], grad.shape_[1], grad.shape_[2]);
-          Assign(grad_weight, req[leakyrelu::kGamma],
-                 sumall_except_dim<0>(reshape(F<prelu_grad>(data) * grad, gshape)));
-          Assign(gdata, req[leakyrelu::kData],
-                 F<mshadow_op::xelu_grad>(data,
-                                          mshadow::expr::broadcast_scalar(weight, data.shape_))
-                 * grad);
+        TShape gshape = expand_shape(in_grad[leakyrelu::kGamma].shape_,
+                                     in_grad[leakyrelu::kData].shape_);
+        TShape new_lshape, new_rshape, new_oshape;
+        const bool need_bc = BinaryBroadcastShapeCompact(in_grad[leakyrelu::kData].shape_,
+                                                         gshape,
+                                                         out_grad[leakyrelu::kOut].shape_,
+                                                         &new_lshape,
+                                                         &new_rshape,
+                                                         &new_oshape) != 0;
+        if (!need_bc) {
+          ElemwiseBinaryOp::BackwardUseIn<xpu,
+                                          mshadow_op::xelu_grad,
+                                          mshadow_op::prelu_grad>(
+            nnvm::NodeAttrs(), ctx, {out_grad[leakyrelu::kOut],
+                                     in_data[leakyrelu::kData],
+                                     in_data[leakyrelu::kGamma]}, req, in_grad);
         } else {
-          Assign(grad_weight, req[leakyrelu::kGamma],
-                 sumall_except_dim<1>(F<prelu_grad>(data) * grad));
-          Assign(gdata, req[leakyrelu::kData],
-                 F<mshadow_op::xelu_grad>(data, mshadow::expr::broadcast<1>(weight, data.shape_))
-                 * grad);
+          BROADCAST_NDIM_SWITCH(new_oshape.ndim(), NDim, {
+            BinaryBroadcastBackwardUseInImpl<xpu, NDim, DType,
+              mshadow_op::xelu_grad, mshadow_op::prelu_grad>(
+                ctx, {out_grad[leakyrelu::kOut],
+                      in_data[leakyrelu::kData],
+                      in_data[leakyrelu::kGamma]}, req, in_grad,
+                new_lshape, new_rshape, new_oshape);
+          });
         }
         break;
       }
@@ -239,7 +264,7 @@ class LeakyReLUOp : public Operator {
       case leakyrelu::kELU: {
         MXNET_ASSIGN_REQ_SWITCH(req[leakyrelu::kData], Req, {
           mxnet_op::Kernel<mxnet_op::op_with_req<
-            mxnet_op::backward_grad_tuned<mxnet::op::mshadow_op::elu_grad>, Req>, xpu>::Launch(
+            mxnet_op::backward_grad_tuned<mshadow_op::elu_grad>, Req>, xpu>::Launch(
               s, gdata.size(0) * gdata.size(1) * gdata.size(2), gdata.dptr_, grad.dptr_,
               output.dptr_, DType(param_.slope));
         });
@@ -251,6 +276,24 @@ class LeakyReLUOp : public Operator {
   }
 
  private:
+  /*! \brief Minimum of three */
+  static MSHADOW_XINLINE size_t minthree(const size_t a, const size_t b, const size_t c) {
+    return a < b ? (a < c ? a : c) : (b < c ? b : c);
+  }
+  static inline TShape expand_shape(const TShape& src, const TShape& dst) {
+    TShape result(dst.ndim());
+    int s = src.ndim() - 1;
+    for (int i = dst.ndim() - 1; i >= 0; i--) {
+      if (s >= 0 && (dst[i] == src[s] || src[s] == 1)) {
+        result[i] = src[s];
+        s--;
+      } else {
+        result[i] = 1;
+      }
+    }
+    CHECK(s == -1) << "Cannot broadcast gamma to data. gamma: " << src << ", data: " << dst;
+    return result;
+  }
   LeakyReLUParam param_;
 };  // class LeakyReLUOp
 
@@ -281,10 +324,12 @@ class LeakyReLUProp : public OperatorProperty {
     if (dshape.ndim() == 0) return false;
     if (param_.act_type == leakyrelu::kPReLU) {
       const TShape &gshape = in_shape->at(leakyrelu::kGamma);
-      if (gshape.ndim() == 1 && gshape.Size() == 1)
-        in_shape->at(leakyrelu::kGamma) = TShape(Shape1(1));
-      else
+      if (gshape.ndim() == 0) {
         in_shape->at(leakyrelu::kGamma) = TShape(Shape1(dshape[1]));
+      }
+      if (dshape == gshape) {
+        SHAPE_ASSIGN_CHECK(*out_shape, 0, dshape);
+      }
     }
     out_shape->clear();
     out_shape->push_back(dshape);
@@ -396,6 +441,11 @@ class LeakyReLUProp : public OperatorProperty {
     }
   }
 
+  std::vector<ResourceRequest> BackwardResource(
+      const std::vector<TShape> &in_shape) const override {
+    return {ResourceRequest::kTempSpace};
+  }
+
   Operator* CreateOperator(Context ctx) const override {
     LOG(FATAL) << "Not Implemented.";
     return NULL;
diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h
index 19fa4f8..5953568 100644
--- a/src/operator/mshadow_op.h
+++ b/src/operator/mshadow_op.h
@@ -126,6 +126,8 @@ MXNET_UNARY_MATH_OP_NC(relu, a > DType(0) ? a : DType(0));
 
 MXNET_UNARY_MATH_OP_NC(relu_grad, a > DType(0) ? DType(1) : DType(0));
 
+MXNET_BINARY_MATH_OP_NC(prelu_grad, a > DType(0) ? DType(0) : a);
+
 MXNET_BINARY_MATH_OP_NC(xelu, a > DType(0) ? a :
                         DType(static_cast<float>(a) * static_cast<float>(b)));
 
diff --git a/src/operator/operator_tune.cc b/src/operator/operator_tune.cc
index de3c742..0953cba 100644
--- a/src/operator/operator_tune.cc
+++ b/src/operator/operator_tune.cc
@@ -322,6 +322,7 @@ IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::power_grad);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::rpower_grad);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::power_rgrad);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::xelu_grad); // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::prelu_grad); // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::elu_grad); // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::maximum);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::minimum);  // NOLINT()
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 0c68ae2..f287c19 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -677,7 +677,9 @@ def test_prelu():
         copy_x = x.copy()
         copy_x[pos_indices] = 0.0
         grad_x[pos_indices] = 1.0
-        if gamma.shape[0] == 1:
+        if len(gamma.shape) > 1:
+            grad_gam = copy_x
+        elif gamma.shape[0] == 1:
             grad_gam = np.sum(np.sum(copy_x))
         elif gamma.shape[0] > 1:
             grad_gam = np.sum(copy_x, axis=0)
@@ -687,6 +689,7 @@ def test_prelu():
     gamma = mx.symbol.Variable("gamma")
     for dtype in [np.float16, np.float32, np.float64]:
         for gam in [np.array([0.1, 0.2, 0.3, 0.4], dtype=dtype)]:
+            gam_full = np.array([gam, gam, gam])
             xa = np.random.uniform(low=-1.0,high=1.0,size=shape).astype(dtype)
             rtol = 1e-2
             atol = 1e-3
@@ -694,12 +697,18 @@ def test_prelu():
             xa[abs(xa) < eps] = 1.0
             y = mx.symbol.LeakyReLU(data=x, gamma=gamma, act_type='prelu')
             ya = fprelu(xa, gam)
+            ya_full = fprelu(xa, gam_full)
             g_xa, g_gam = fprelu_grad(xa, ya, gamma=gam)
+            g_xa_full, g_gam_full = fprelu_grad(xa, ya_full, gamma=gam_full)
             # Skip numeric check for float16 type to get rid of flaky behavior
             if dtype is not np.float16:
                 check_numeric_gradient(y, [xa, gam], numeric_eps=eps, rtol=rtol, atol=atol, dtype=dtype)
+                check_numeric_gradient(y, [xa, gam_full], numeric_eps=eps, rtol=rtol, atol=atol, dtype=dtype)
             check_symbolic_forward(y, [xa, gam], [ya], rtol=rtol, atol=atol, dtype=dtype)
             check_symbolic_backward(y, [xa, gam], [np.ones(shape), np.ones(gam.shape)], [g_xa, g_gam], rtol=rtol, atol=atol, dtype=dtype)
+            check_symbolic_forward(y, [xa, gam_full], [ya_full], rtol=rtol, atol=atol, dtype=dtype)
+            check_symbolic_backward(y, [xa, gam_full], [np.ones(shape), np.ones(gam_full.shape)],
+                                    [g_xa_full, g_gam_full], rtol=rtol, atol=atol, dtype=dtype)
 
 @with_seed()
 def test_sigmoid():

-- 
To stop receiving notification emails like this one, please contact
jxie@apache.org.