You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by jx...@apache.org on 2018/03/22 17:52:48 UTC

[incubator-mxnet] branch master updated: [MXNET-101] Support float16 in LeakyReLU operator (#10169)

This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 10d7458  [MXNET-101] Support float16 in LeakyReLU operator (#10169)
10d7458 is described below

commit 10d745843477b8277e46e93173c6f7bfde8eda63
Author: Hao Jin <ha...@users.noreply.github.com>
AuthorDate: Thu Mar 22 10:52:44 2018 -0700

    [MXNET-101] Support float16 in LeakyReLU operator (#10169)
    
    * support for any datatype in leaky ReLU
    
    * test for LeakyReLU operators
    
    * make lint
    
    * clean up unnecessary prints
    
    * fix for amalgamation build failure
    
    * add InferType for Leaky ReLU and slight modification to the tests
---
 src/operator/leaky_relu-inl.h          | 136 ++++++++++++++++++++++++---------
 src/operator/leaky_relu.cc             |  13 +++-
 src/operator/leaky_relu.cu             |   8 +-
 src/operator/mshadow_op.h              |  15 +++-
 src/operator/operator_tune.cc          |   4 +
 tests/python/unittest/test_operator.py |  71 +++++++++++++++++
 6 files changed, 201 insertions(+), 46 deletions(-)

diff --git a/src/operator/leaky_relu-inl.h b/src/operator/leaky_relu-inl.h
index 77eba43..c99280a 100644
--- a/src/operator/leaky_relu-inl.h
+++ b/src/operator/leaky_relu-inl.h
@@ -34,8 +34,11 @@
 #include <string>
 #include <vector>
 #include <utility>
+#include "../common/random_generator.h"
 #include "./operator_common.h"
 #include "./mshadow_op.h"
+#include "./random/sampler.h"
+#include "./random/sample_op.h"
 
 namespace mxnet {
 namespace op {
@@ -75,7 +78,7 @@ struct prelu_grad {
   }
 };
 
-template<typename xpu>
+template<typename xpu, typename DType>
 class LeakyReLUOp : public Operator {
  public:
   explicit LeakyReLUOp(LeakyReLUParam param) {
@@ -92,25 +95,25 @@ class LeakyReLUOp : public Operator {
     size_t expected = param_.act_type == leakyrelu::kPReLU ? 2 : 1;
     CHECK_EQ(in_data.size(), expected);
     Stream<xpu> *s = ctx.get_stream<xpu>();
-    Tensor<xpu, 3> data;
-    Tensor<xpu, 3> out;
-    Tensor<xpu, 3> mask;
-    Tensor<xpu, 1> weight;
+    Tensor<xpu, 3, DType> data;
+    Tensor<xpu, 3, DType> out;
+    Tensor<xpu, 3, DType> mask;
+    Tensor<xpu, 1, DType> weight;
     int n = in_data[leakyrelu::kData].shape_[0];
     int k = in_data[leakyrelu::kData].shape_[1];
     Shape<3> dshape = Shape3(n, k, in_data[leakyrelu::kData].Size()/n/k);
-    data = in_data[leakyrelu::kData].get_with_shape<xpu, 3, real_t>(dshape, s);
-    out = out_data[leakyrelu::kOut].get_with_shape<xpu, 3, real_t>(dshape, s);
-    if (param_.act_type == leakyrelu::kRReLU) {
-      mask = out_data[leakyrelu::kMask].get_with_shape<xpu, 3, real_t>(dshape, s);
-    }
+    data = in_data[leakyrelu::kData].get_with_shape<xpu, 3, DType>(dshape, s);
+    out = out_data[leakyrelu::kOut].get_with_shape<xpu, 3, DType>(dshape, s);
     switch (param_.act_type) {
       case leakyrelu::kLeakyReLU: {
-        Assign(out, req[leakyrelu::kOut], F<mshadow_op::xelu>(data, param_.slope));
+        MXNET_ASSIGN_REQ_SWITCH(req[leakyrelu::kOut], Req, {
+          mxnet_op::Kernel<mxnet_op::op_with_req<mxnet::op::mshadow_op::xelu, Req>, xpu>::Launch(
+            s, out.size(0) * out.size(1) * out.size(2), out.dptr_, data.dptr_, DType(param_.slope));
+        });
         break;
       }
       case leakyrelu::kPReLU: {
-        weight = in_data[leakyrelu::kGamma].get<xpu, 1, real_t>(s);
+        weight = in_data[leakyrelu::kGamma].get<xpu, 1, DType>(s);
         if (weight.shape_.Size() == 1) {
           Assign(out, req[leakyrelu::kOut],
                  F<mshadow_op::xelu>(data, mshadow::expr::broadcast_scalar(weight, out.shape_)));
@@ -122,18 +125,43 @@ class LeakyReLUOp : public Operator {
       }
       case leakyrelu::kRReLU: {
         if (ctx.is_train) {
-          Random<xpu>* prnd = ctx.requested[leakyrelu::kRandom].get_random<xpu, real_t>(s);
-          mask = prnd->uniform(mask.shape_);
-          mask = mask * (param_.upper_bound - param_.lower_bound) + param_.lower_bound;
-          Assign(out, req[leakyrelu::kOut], F<mshadow_op::xelu>(data, mask));
+          mask = out_data[leakyrelu::kMask].get_with_shape<xpu, 3, DType>(dshape, s);
+          mxnet::op::UniformSampler<xpu> sampler;
+          Tensor<xpu, 1, DType> low, high;
+          mxnet::op::GetSamplingTempData<xpu, DType>(DType(0.0f), DType(1.0f), ctx, &low, &high);
+          mxnet::common::random::RandGenerator<xpu, DType> *pgen =
+            ctx.requested[0].get_parallel_random<xpu, DType>();
+          Tensor<xpu, 1, DType> out = mask.FlatTo1D();
+          sampler.Sample(low, high, out, pgen, s);
+          MXNET_ASSIGN_REQ_SWITCH(req[leakyrelu::kMask], Req, {
+            mxnet_op::Kernel<mxnet_op::op_with_req<mxnet::op::mshadow_op::mul, Req>, xpu>::Launch(
+              s, mask.size(0) * mask.size(1) * mask.size(2), mask.dptr_, mask.dptr_,
+              DType(param_.upper_bound - param_.lower_bound));
+          });
+          MXNET_ASSIGN_REQ_SWITCH(req[leakyrelu::kMask], Req, {
+            mxnet_op::Kernel<mxnet_op::op_with_req<mxnet::op::mshadow_op::plus, Req>, xpu>::Launch(
+              s, mask.size(0) * mask.size(1) * mask.size(2), mask.dptr_, mask.dptr_,
+              DType(param_.lower_bound));
+          });
+          MXNET_ASSIGN_REQ_SWITCH(req[leakyrelu::kOut], Req, {
+            mxnet_op::Kernel<mxnet_op::op_with_req<mxnet::op::mshadow_op::xelu, Req>, xpu>::Launch(
+              s, mask.size(0) * mask.size(1) * mask.size(2), out.dptr_, data.dptr_, mask.dptr_);
+          });
         } else {
           const float slope = (param_.lower_bound + param_.upper_bound) / 2.0f;
-          Assign(out, req[leakyrelu::kOut], F<mshadow_op::xelu>(data, slope));
+          MXNET_ASSIGN_REQ_SWITCH(req[leakyrelu::kOut], Req, {
+            mxnet_op::Kernel<mxnet_op::op_with_req<mxnet::op::mshadow_op::xelu, Req>, xpu>::Launch(
+              s, out.size(0) * out.size(1) * out.size(2), out.dptr_, data.dptr_, DType(slope));
+          });
         }
         break;
       }
       case leakyrelu::kELU: {
-        Assign(out, req[leakyrelu::kOut], F<mshadow_op::elu>(data, param_.slope));
+        MXNET_ASSIGN_REQ_SWITCH(req[leakyrelu::kOut], Req, {
+          mxnet_op::Kernel<mxnet_op::op_with_req<mxnet::op::mshadow_op::elu, Req>, xpu>::Launch(
+            s, out.size(0) * out.size(1) * out.size(2), out.dptr_, data.dptr_,
+            DType(param_.slope));
+        });
         break;
       }
       default:
@@ -155,33 +183,38 @@ class LeakyReLUOp : public Operator {
     CHECK_EQ(req.size(), expected);
     CHECK_EQ(in_data.size(), expected);
     Stream<xpu> *s = ctx.get_stream<xpu>();
-    Tensor<xpu, 3> output;
-    Tensor<xpu, 3> data;
-    Tensor<xpu, 3> gdata;
-    Tensor<xpu, 3> grad;
-    Tensor<xpu, 3> mask;
-    Tensor<xpu, 1> weight;
-    Tensor<xpu, 1> grad_weight;
+    Tensor<xpu, 3, DType> output;
+    Tensor<xpu, 3, DType> data;
+    Tensor<xpu, 3, DType> gdata;
+    Tensor<xpu, 3, DType> grad;
+    Tensor<xpu, 3, DType> mask;
+    Tensor<xpu, 1, DType> weight;
+    Tensor<xpu, 1, DType> grad_weight;
     int n = out_grad[leakyrelu::kOut].shape_[0];
     int k = out_grad[leakyrelu::kOut].shape_[1];
     Shape<3> dshape = Shape3(n, k, out_grad[leakyrelu::kOut].Size()/n/k);
-    grad = out_grad[leakyrelu::kOut].get_with_shape<xpu, 3, real_t>(dshape, s);
-    gdata = in_grad[leakyrelu::kData].get_with_shape<xpu, 3, real_t>(dshape, s);
-    output = out_data[leakyrelu::kOut].get_with_shape<xpu, 3, real_t>(dshape, s);
+    grad = out_grad[leakyrelu::kOut].get_with_shape<xpu, 3, DType>(dshape, s);
+    gdata = in_grad[leakyrelu::kData].get_with_shape<xpu, 3, DType>(dshape, s);
+    output = out_data[leakyrelu::kOut].get_with_shape<xpu, 3, DType>(dshape, s);
     if (param_.act_type == leakyrelu::kRReLU) {
-      mask = out_data[leakyrelu::kMask].get_with_shape<xpu, 3, real_t>(dshape, s);
+      mask = out_data[leakyrelu::kMask].get_with_shape<xpu, 3, DType>(dshape, s);
     }
     if (param_.act_type == leakyrelu::kPReLU) {
-      data = in_data[leakyrelu::kData].get_with_shape<xpu, 3, real_t>(dshape, s);
+      data = in_data[leakyrelu::kData].get_with_shape<xpu, 3, DType>(dshape, s);
     }
     switch (param_.act_type) {
       case leakyrelu::kLeakyReLU: {
-        Assign(gdata, req[leakyrelu::kData], F<mshadow_op::xelu_grad>(output, param_.slope) * grad);
+        MXNET_ASSIGN_REQ_SWITCH(req[leakyrelu::kData], Req, {
+          mxnet_op::Kernel<mxnet_op::op_with_req<
+            mxnet_op::backward_grad_tuned<mxnet::op::mshadow_op::xelu_grad>, Req>, xpu>::Launch(
+              s, gdata.size(0) * gdata.size(1) * gdata.size(2), gdata.dptr_, grad.dptr_,
+              output.dptr_, DType(param_.slope));
+        });
         break;
       }
       case leakyrelu::kPReLU: {
-        weight = in_data[leakyrelu::kGamma].get<xpu, 1, real_t>(s);
-        grad_weight = in_grad[leakyrelu::kGamma].get<xpu, 1, real_t>(s);
+        weight = in_data[leakyrelu::kGamma].get<xpu, 1, DType>(s);
+        grad_weight = in_grad[leakyrelu::kGamma].get<xpu, 1, DType>(s);
         if (weight.shape_.Size() == 1) {
           Shape<4> gshape = Shape4(1, grad.shape_[0], grad.shape_[1], grad.shape_[2]);
           Assign(grad_weight, req[leakyrelu::kGamma],
@@ -204,7 +237,12 @@ class LeakyReLUOp : public Operator {
         break;
       }
       case leakyrelu::kELU: {
-        Assign(gdata, req[leakyrelu::kData], F<mshadow_op::elu_grad>(output, param_.slope) * grad);
+        MXNET_ASSIGN_REQ_SWITCH(req[leakyrelu::kData], Req, {
+          mxnet_op::Kernel<mxnet_op::op_with_req<
+            mxnet_op::backward_grad_tuned<mxnet::op::mshadow_op::elu_grad>, Req>, xpu>::Launch(
+              s, gdata.size(0) * gdata.size(1) * gdata.size(2), gdata.dptr_, grad.dptr_,
+              output.dptr_, DType(param_.slope));
+        });
         break;
       }
       default:
@@ -217,7 +255,7 @@ class LeakyReLUOp : public Operator {
 };  // class LeakyReLUOp
 
 template<typename xpu>
-Operator* CreateOp(LeakyReLUParam type);
+Operator* CreateOp(LeakyReLUParam type, int dtype);
 
 #if DMLC_USE_CXX11
 class LeakyReLUProp : public OperatorProperty {
@@ -256,6 +294,26 @@ class LeakyReLUProp : public OperatorProperty {
     return true;
   }
 
+  bool InferType(std::vector<int> *in_type,
+                 std::vector<int> *out_type,
+                 std::vector<int> *aux_type) const override {
+    int dtype = -1;
+    for (const int& type : *in_type) {
+      type_assign(&dtype, type);
+    }
+    for (const int& type : *out_type) {
+      type_assign(&dtype, type);
+    }
+
+    for (size_t i = 0; i < in_type->size(); ++i) {
+      TYPE_ASSIGN_CHECK(*in_type, i, dtype);
+    }
+    for (size_t i = 0; i < out_type->size(); ++i) {
+      TYPE_ASSIGN_CHECK(*out_type, i, dtype);
+    }
+    return dtype != -1;
+  }
+
   OperatorProperty* Copy() const override {
     auto ptr = new LeakyReLUProp();
     ptr->param_ = param_;
@@ -338,7 +396,13 @@ class LeakyReLUProp : public OperatorProperty {
     }
   }
 
-  Operator* CreateOperator(Context ctx) const override;
+  Operator* CreateOperator(Context ctx) const override {
+    LOG(FATAL) << "Not Implemented.";
+    return NULL;
+  }
+
+  Operator* CreateOperatorEx(Context ctx, std::vector<TShape> *in_shape,
+                           std::vector<int> *in_type) const override;
 
  private:
   LeakyReLUParam param_;
diff --git a/src/operator/leaky_relu.cc b/src/operator/leaky_relu.cc
index 6e6fa53..99b6ba3 100644
--- a/src/operator/leaky_relu.cc
+++ b/src/operator/leaky_relu.cc
@@ -30,12 +30,17 @@
 namespace mxnet {
 namespace op {
 template<>
-Operator *CreateOp<cpu>(LeakyReLUParam param) {
-  return new LeakyReLUOp<cpu>(param);
+Operator *CreateOp<cpu>(LeakyReLUParam param, int dtype) {
+  Operator* op = NULL;
+  MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
+    op = new LeakyReLUOp<cpu, DType>(param);
+  });
+  return op;
 }
 
-Operator *LeakyReLUProp::CreateOperator(Context ctx) const {
-  DO_BIND_DISPATCH(CreateOp, param_);
+Operator *LeakyReLUProp::CreateOperatorEx(Context ctx, std::vector<TShape> *in_shape,
+                                          std::vector<int> *in_type) const {
+  DO_BIND_DISPATCH(CreateOp, param_, in_type->at(0));
 }
 
 DMLC_REGISTER_PARAMETER(LeakyReLUParam);
diff --git a/src/operator/leaky_relu.cu b/src/operator/leaky_relu.cu
index 9de237c..74b444d 100644
--- a/src/operator/leaky_relu.cu
+++ b/src/operator/leaky_relu.cu
@@ -29,8 +29,12 @@
 namespace mxnet {
 namespace op {
 template<>
-Operator *CreateOp<gpu>(LeakyReLUParam param) {
-  return new LeakyReLUOp<gpu>(param);
+Operator *CreateOp<gpu>(LeakyReLUParam param, int dtype) {
+  Operator* op = NULL;
+  MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
+    op = new LeakyReLUOp<gpu, DType>(param);
+  });
+  return op;
 }
 
 }  // namespace op
diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h
index 1d4284e..5606c64 100644
--- a/src/operator/mshadow_op.h
+++ b/src/operator/mshadow_op.h
@@ -89,6 +89,13 @@ MXNET_UNARY_MATH_OP_NC(identity, a);
 
 MXNET_UNARY_MATH_OP(identity_grad, 1);
 
+struct identity_with_cast {
+  template<typename DTypeIn, typename DTypeOut>
+  MSHADOW_XINLINE static void Map(int i, DTypeOut *out, DTypeIn *in) {
+    out[i] = DTypeOut(in[i]);
+  }
+};
+
 MXNET_BINARY_MATH_OP_NC(left, a);
 
 MXNET_BINARY_MATH_OP_NC(right, b);
@@ -119,13 +126,13 @@ MXNET_UNARY_MATH_OP_NC(relu, a > DType(0) ? a : DType(0));
 
 MXNET_UNARY_MATH_OP_NC(relu_grad, a > DType(0) ? DType(1) : DType(0));
 
-MXNET_BINARY_MATH_OP(xelu, a > DType(0) ? math::id(a) :
-                     math::id(a) * math::id(b));
+MXNET_BINARY_MATH_OP_NC(xelu, a > DType(0) ? a :
+                        DType(static_cast<float>(a) * static_cast<float>(b)));
 
 MXNET_BINARY_MATH_OP_NC(xelu_grad, a > DType(0) ? DType(1) : b);
 
-MXNET_BINARY_MATH_OP(elu, a > DType(0) ? math::id(a) :
-                     math::id(b) * math::expm1(a));
+MXNET_BINARY_MATH_OP_NC(elu, a > DType(0) ? a :
+                        DType(math::id(b) * math::expm1(a)));
 
 MXNET_BINARY_MATH_OP_NC(elu_grad, a > DType(0) ? DType(1) : DType(b + a));
 
diff --git a/src/operator/operator_tune.cc b/src/operator/operator_tune.cc
index c13f1ac..c48d83a 100644
--- a/src/operator/operator_tune.cc
+++ b/src/operator/operator_tune.cc
@@ -314,9 +314,13 @@ IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::right);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::right);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::power);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rpower);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::xelu); // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::elu); // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::power_grad);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::rpower_grad);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::power_rgrad);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::xelu_grad); // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::elu_grad); // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::maximum);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::minimum);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::hypot);  // NOLINT()
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index c1df291..240c06a 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -489,6 +489,77 @@ def test_relu():
     check_symbolic_backward(y, [xa], [np.ones(shape)], [ga])
 
 
+@with_seed(1234)
+def test_leaky_relu():
+    def fleaky_relu(x, act_type, slope=0.25):
+        neg_indices = x < 0
+        out = x.copy()
+        if act_type == 'elu':
+            out[neg_indices] = slope * (np.exp(out[neg_indices]) - 1.)
+        elif act_type == 'leaky':
+            out[neg_indices] = slope * out[neg_indices]
+        return out
+    def fleaky_relu_grad(grad, x, y, act_type, slope=0.25):
+        neg_indices = x < 0
+        out = np.ones(x.shape)
+        if act_type == 'elu':
+            out[neg_indices] = y[neg_indices] + slope
+        elif act_type == 'leaky':
+            out[neg_indices] = slope
+        return out * grad
+    shape = (3, 4)
+    x = mx.symbol.Variable("x")
+    slp = 0.0625
+    for dtype in [np.float16, np.float32, np.float64]:
+        xa = np.random.uniform(low=-1.0,high=-0.2,size=shape).astype(dtype)
+        eps = 1e-4
+        xa[abs(xa) < eps] = 1.0
+        # eps = 1e-2 if dtype is np.float16 else 1e-4
+        for act_type in ['leaky']:
+            y = mx.symbol.LeakyReLU(data=x, slope=slp, act_type=act_type)
+            ya = fleaky_relu(xa, slope=slp, act_type=act_type)
+            ga = fleaky_relu_grad(np.ones(shape), xa, ya, slope=slp, act_type=act_type)
+            check_numeric_gradient(y, [xa], numeric_eps=eps, rtol=1e-4, atol=1e-4)
+            check_symbolic_forward(y, [xa], [ya], rtol=eps, atol=1e-5, dtype=dtype)
+            check_symbolic_backward(y, [xa], [np.ones(shape)], [ga], rtol=eps, atol=1e-5, dtype=dtype)
+
+
+@with_seed(1234)
+def test_prelu():
+    def fprelu(x, gamma):
+        pos_indices = x > 0
+        out = x.copy()
+        out = np.multiply(out, gamma)
+        out[pos_indices] = x[pos_indices]
+        return out
+    def fprelu_grad(x, y, gamma):
+        pos_indices = x > 0
+        grad_x = np.multiply(np.ones(x.shape), gamma)
+        grad_gam = np.zeros(gamma.shape)
+        copy_x = x.copy()
+        copy_x[pos_indices] = 0.0
+        grad_x[pos_indices] = 1.0
+        if gamma.shape[0] == 1:
+            grad_gam = np.sum(np.sum(copy_x))
+        elif gamma.shape[0] > 1:
+            grad_gam = np.sum(copy_x, axis=0)
+        return (grad_x, grad_gam)
+    shape = (3,4)
+    x = mx.symbol.Variable("x")
+    gamma = mx.symbol.Variable("gamma")
+    for dtype in [np.float16, np.float32, np.float64]:
+        for gam in [np.array([0.1], dtype=dtype), np.array([0.1, 0.2, 0.3, 0.4], dtype=dtype)]:
+            xa = np.random.uniform(low=-1.0,high=1.0,size=shape).astype(dtype)
+            eps = 1e-4
+            xa[abs(xa) < eps] = 1.0
+            y = mx.symbol.LeakyReLU(data=x, gamma=gamma, act_type='prelu')
+            ya = fprelu(xa, gam)
+            g_xa, g_gam = fprelu_grad(xa, ya, gamma=gam)
+            check_numeric_gradient(y, [xa, gam], numeric_eps=eps, rtol=1e-3, atol=1e-4)
+            check_symbolic_forward(y, [xa, gam], [ya], rtol=1e-3, atol=1e-20)
+            check_symbolic_backward(y, [xa, gam], [np.ones(shape)], [g_xa], rtol=1e-3, atol=1e-20)
+
+
 @with_seed()
 def test_sigmoid():
     def fsigmoid(a):

-- 
To stop receiving notification emails like this one, please contact
jxie@apache.org.