Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/03/20 20:34:47 UTC

[GitHub] piiswrong closed pull request #10078: [MXNET-92] Support float16 in L2Normalization operator

URL: https://github.com/apache/incubator-mxnet/pull/10078
 
 
   

This is a PR merged from a forked repository. Because GitHub hides the original
diff once a foreign (fork) pull request is merged, the diff is reproduced below
for the sake of provenance:
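
For context, here is a minimal usage sketch of what this change enables (a hypothetical example, assuming a build of MXNet that includes this PR): L2Normalization can now be bound and executed on float16 data, whereas previously only float32 (real_t) was accepted.

    import mxnet as mx
    import numpy as np

    # Hypothetical float16 example enabled by this PR.
    data = mx.symbol.Variable('data')
    out = mx.symbol.L2Normalization(data=data, mode='channel', eps=1e-10)

    x = mx.nd.array(np.random.uniform(-1, 1, (2, 3, 4)), dtype=np.float16)
    exe = out.bind(ctx=mx.cpu(), args={'data': x})
    y = exe.forward()[0]
    print(y.dtype)  # the output dtype now follows the input dtype (float16)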

diff --git a/src/operator/l2_normalization-inl.h b/src/operator/l2_normalization-inl.h
index cb8e740d7ff..d53e0c5caf9 100644
--- a/src/operator/l2_normalization-inl.h
+++ b/src/operator/l2_normalization-inl.h
@@ -66,7 +66,7 @@ struct L2NormalizationParam : public dmlc::Parameter<L2NormalizationParam> {
  * \brief This is the implementation of l2 normalization operator.
  * \tparam xpu The device that the op will be executed on.
  */
-template<typename xpu>
+template<typename xpu, typename DType>
 class L2NormalizationOp : public Operator {
  public:
   explicit L2NormalizationOp(L2NormalizationParam p) {
@@ -89,41 +89,53 @@ class L2NormalizationOp : public Operator {
     if (param_.mode == l2_normalization::kInstance) {
       Shape<2> dshape = Shape2(orig_shape[0],
         orig_shape.ProdShape(1, orig_shape.ndim()));
-      Tensor<xpu, 2> data = in_data[l2_normalization::kData]
-        .get_with_shape<xpu, 2, real_t>(dshape, s);
-      Tensor<xpu, 2> out = out_data[l2_normalization::kOut]
-        .get_with_shape<xpu, 2, real_t>(dshape, s);
-      Tensor<xpu, 1> norm = out_data[l2_normalization::kNorm].get<xpu, 1, real_t>(s);
+      Tensor<xpu, 2, DType> data = in_data[l2_normalization::kData]
+        .get_with_shape<xpu, 2, DType>(dshape, s);
+      Tensor<xpu, 2, DType> out = out_data[l2_normalization::kOut]
+        .get_with_shape<xpu, 2, DType>(dshape, s);
+      Tensor<xpu, 1, DType> norm = out_data[l2_normalization::kNorm].get<xpu, 1, DType>(s);
       norm = sumall_except_dim<0>(F<mxnet::op::mshadow_op::square>(data));
-      norm = F<mxnet::op::mshadow_op::square_root>(norm + param_.eps);
+      MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
+        mxnet_op::Kernel<mxnet_op::op_with_req<mxnet::op::mshadow_op::plus, Req>, xpu>::Launch(
+          s, norm.size(0), norm.dptr_, norm.dptr_, DType(param_.eps));
+      });
+      norm = F<mxnet::op::mshadow_op::square_root>(norm);
       out = data / broadcast<0>(norm, out.shape_);
     } else if (param_.mode == l2_normalization::kChannel) {
       CHECK_GE(orig_shape.ndim(), 3U);
       Shape<3> dshape = Shape3(orig_shape[0], orig_shape[1],
         orig_shape.ProdShape(2, orig_shape.ndim()));
-      Tensor<xpu, 3> data = in_data[l2_normalization::kData]
-        .get_with_shape<xpu, 3, real_t>(dshape, s);
-      Tensor<xpu, 3> out = out_data[l2_normalization::kOut]
-        .get_with_shape<xpu, 3, real_t>(dshape, s);
+      Tensor<xpu, 3, DType> data = in_data[l2_normalization::kData]
+        .get_with_shape<xpu, 3, DType>(dshape, s);
+      Tensor<xpu, 3, DType> out = out_data[l2_normalization::kOut]
+        .get_with_shape<xpu, 3, DType>(dshape, s);
       Shape<2> norm_shape = Shape2(dshape[0], dshape[2]);
-      Tensor<xpu, 2> norm = out_data[l2_normalization::kNorm]
-        .get_with_shape<xpu, 2, real_t>(norm_shape, s);
+      Tensor<xpu, 2, DType> norm = out_data[l2_normalization::kNorm]
+        .get_with_shape<xpu, 2, DType>(norm_shape, s);
       norm = reduce_with_axis<red::sum, false>(F<mxnet::op::mshadow_op::square>(data), 1);
-      norm = F<mxnet::op::mshadow_op::square_root>(norm + param_.eps);
+      MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
+        mxnet_op::Kernel<mxnet_op::op_with_req<mxnet::op::mshadow_op::plus, Req>, xpu>::Launch(
+          s, norm.size(0) * norm.size(1), norm.dptr_, norm.dptr_, DType(param_.eps));
+      });
+      norm = F<mxnet::op::mshadow_op::square_root>(norm);
       out = data / broadcast_with_axis(norm, 0, orig_shape[1]);
     } else if (param_.mode == l2_normalization::kSpatial) {
       CHECK_GE(orig_shape.ndim(), 3U);
       Shape<3> dshape = Shape3(orig_shape[0], orig_shape[1],
         orig_shape.ProdShape(2, orig_shape.ndim()));
-      Tensor<xpu, 3> data = in_data[l2_normalization::kData]
-        .get_with_shape<xpu, 3, real_t>(dshape, s);
-      Tensor<xpu, 3> out = out_data[l2_normalization::kOut]
-        .get_with_shape<xpu, 3, real_t>(dshape, s);
+      Tensor<xpu, 3, DType> data = in_data[l2_normalization::kData]
+        .get_with_shape<xpu, 3, DType>(dshape, s);
+      Tensor<xpu, 3, DType> out = out_data[l2_normalization::kOut]
+        .get_with_shape<xpu, 3, DType>(dshape, s);
       Shape<2> norm_shape = Shape2(dshape[0], dshape[1]);
-      Tensor<xpu, 2> norm = out_data[l2_normalization::kNorm]
-        .get_with_shape<xpu, 2, real_t>(norm_shape, s);
+      Tensor<xpu, 2, DType> norm = out_data[l2_normalization::kNorm]
+        .get_with_shape<xpu, 2, DType>(norm_shape, s);
       norm = reduce_with_axis<red::sum, false>(F<mxnet::op::mshadow_op::square>(data), 2);
-      norm = F<mxnet::op::mshadow_op::square_root>(norm + param_.eps);
+      MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
+        mxnet_op::Kernel<mxnet_op::op_with_req<mxnet::op::mshadow_op::plus, Req>, xpu>::Launch(
+          s, norm.size(0) * norm.size(1), norm.dptr_, norm.dptr_, DType(param_.eps));
+      });
+      norm = F<mxnet::op::mshadow_op::square_root>(norm);
       out = data / broadcast_with_axis(norm, 1, dshape[2]);
     } else {
       LOG(FATAL) << "Unexpected mode in l2 normalization";
@@ -148,15 +160,15 @@ class L2NormalizationOp : public Operator {
     if (param_.mode == l2_normalization::kInstance) {
       Shape<2> dshape = Shape2(orig_shape[0],
         orig_shape.ProdShape(1, orig_shape.ndim()));
-      Tensor<xpu, 2> data = out_data[l2_normalization::kOut]
-        .get_with_shape<xpu, 2, real_t>(dshape, s);
-      Tensor<xpu, 2> grad_in = in_grad[l2_normalization::kData]
-        .get_with_shape<xpu, 2, real_t>(dshape, s);
-      Tensor<xpu, 2> grad_out = out_grad[l2_normalization::kOut]
-        .get_with_shape<xpu, 2, real_t>(dshape, s);
-      Tensor<xpu, 1> norm = out_data[l2_normalization::kNorm].get<xpu, 1, real_t>(s);
-      Tensor<xpu, 1> temp = ctx.requested[l2_normalization::kTempSpace]
-        .get_space<xpu>(mshadow::Shape1(data.shape_[0]), s);
+      Tensor<xpu, 2, DType> data = out_data[l2_normalization::kOut]
+        .get_with_shape<xpu, 2, DType>(dshape, s);
+      Tensor<xpu, 2, DType> grad_in = in_grad[l2_normalization::kData]
+        .get_with_shape<xpu, 2, DType>(dshape, s);
+      Tensor<xpu, 2, DType> grad_out = out_grad[l2_normalization::kOut]
+        .get_with_shape<xpu, 2, DType>(dshape, s);
+      Tensor<xpu, 1, DType> norm = out_data[l2_normalization::kNorm].get<xpu, 1, DType>(s);
+      Tensor<xpu, 1, DType> temp = ctx.requested[l2_normalization::kTempSpace]
+        .get_space_typed<xpu, 1, DType>(mshadow::Shape1(data.shape_[0]), s);
       temp = sumall_except_dim<0>(grad_out * data);
       Assign(grad_in, req[l2_normalization::kData],
         (grad_out - data * broadcast<0>(temp, data.shape_)) /
@@ -165,17 +177,17 @@ class L2NormalizationOp : public Operator {
       CHECK_GE(orig_shape.ndim(), 3U);
       Shape<3> dshape = Shape3(orig_shape[0], orig_shape[1],
         orig_shape.ProdShape(2, orig_shape.ndim()));
-      Tensor<xpu, 3> data = out_data[l2_normalization::kOut]
-        .get_with_shape<xpu, 3, real_t>(dshape, s);
-      Tensor<xpu, 3> grad_in = in_grad[l2_normalization::kData]
-        .get_with_shape<xpu, 3, real_t>(dshape, s);
-      Tensor<xpu, 3> grad_out = out_grad[l2_normalization::kOut]
-        .get_with_shape<xpu, 3, real_t>(dshape, s);
+      Tensor<xpu, 3, DType> data = out_data[l2_normalization::kOut]
+        .get_with_shape<xpu, 3, DType>(dshape, s);
+      Tensor<xpu, 3, DType> grad_in = in_grad[l2_normalization::kData]
+        .get_with_shape<xpu, 3, DType>(dshape, s);
+      Tensor<xpu, 3, DType> grad_out = out_grad[l2_normalization::kOut]
+        .get_with_shape<xpu, 3, DType>(dshape, s);
       Shape<2> norm_shape = Shape2(dshape[0], dshape[2]);
-      Tensor<xpu, 2> norm = out_data[l2_normalization::kNorm]
-        .get_with_shape<xpu, 2, real_t>(norm_shape, s);
-      Tensor<xpu, 2> temp = ctx.requested[l2_normalization::kTempSpace]
-        .get_space<xpu>(mshadow::Shape2(data.shape_[0], data.shape_[2]), s);
+      Tensor<xpu, 2, DType> norm = out_data[l2_normalization::kNorm]
+        .get_with_shape<xpu, 2, DType>(norm_shape, s);
+      Tensor<xpu, 2, DType> temp = ctx.requested[l2_normalization::kTempSpace]
+        .get_space_typed<xpu, 2, DType>(mshadow::Shape2(data.shape_[0], data.shape_[2]), s);
       temp = reduce_with_axis<red::sum, false>(grad_out * data, 1);
       Assign(grad_in, req[l2_normalization::kData],
         (grad_out - data * broadcast_with_axis(temp, 0, orig_shape[1])) /
@@ -184,17 +196,17 @@ class L2NormalizationOp : public Operator {
       CHECK_GE(orig_shape.ndim(), 3U);
       Shape<3> dshape = Shape3(orig_shape[0], orig_shape[1],
         orig_shape.ProdShape(2, orig_shape.ndim()));
-      Tensor<xpu, 3> data = out_data[l2_normalization::kOut]
-        .get_with_shape<xpu, 3, real_t>(dshape, s);
-      Tensor<xpu, 3> grad_in = in_grad[l2_normalization::kData]
-        .get_with_shape<xpu, 3, real_t>(dshape, s);
-      Tensor<xpu, 3> grad_out = out_grad[l2_normalization::kOut]
-        .get_with_shape<xpu, 3, real_t>(dshape, s);
+      Tensor<xpu, 3, DType> data = out_data[l2_normalization::kOut]
+        .get_with_shape<xpu, 3, DType>(dshape, s);
+      Tensor<xpu, 3, DType> grad_in = in_grad[l2_normalization::kData]
+        .get_with_shape<xpu, 3, DType>(dshape, s);
+      Tensor<xpu, 3, DType> grad_out = out_grad[l2_normalization::kOut]
+        .get_with_shape<xpu, 3, DType>(dshape, s);
       Shape<2> norm_shape = Shape2(dshape[0], dshape[1]);
-      Tensor<xpu, 2> norm = out_data[l2_normalization::kNorm]
-        .get_with_shape<xpu, 2, real_t>(norm_shape, s);
-      Tensor<xpu, 2> temp = ctx.requested[l2_normalization::kTempSpace]
-        .get_space<xpu>(mshadow::Shape2(data.shape_[0], data.shape_[1]), s);
+      Tensor<xpu, 2, DType> norm = out_data[l2_normalization::kNorm]
+        .get_with_shape<xpu, 2, DType>(norm_shape, s);
+      Tensor<xpu, 2, DType> temp = ctx.requested[l2_normalization::kTempSpace]
+        .get_space_typed<xpu, 2, DType>(mshadow::Shape2(data.shape_[0], data.shape_[1]), s);
       temp = reduce_with_axis<red::sum, false>(grad_out * data, 2);
       Assign(grad_in, req[l2_normalization::kData],
         (grad_out - data * broadcast_with_axis(temp, 1, dshape[2])) /
@@ -210,7 +222,7 @@ class L2NormalizationOp : public Operator {
 
 // Declare factory function, used for dispatch specialization
 template<typename xpu>
-Operator* CreateOp(L2NormalizationParam param);
+Operator* CreateOp(L2NormalizationParam param, int dtype);
 
 #if DMLC_USE_CXX11
 class L2NormalizationProp : public OperatorProperty {
@@ -235,6 +247,19 @@ class L2NormalizationProp : public OperatorProperty {
     return param_.__DICT__();
   }
 
+  bool InferType(std::vector<int> *in_type,
+                 std::vector<int> *out_type,
+                 std::vector<int> *aux_type) const override {
+    int dtype = (*in_type)[0];
+    type_assign(&dtype, (*out_type)[0]);
+    type_assign(&dtype, (*out_type)[1]);
+
+    TYPE_ASSIGN_CHECK(*in_type, 0, dtype);
+    TYPE_ASSIGN_CHECK(*out_type, 0, dtype);
+    TYPE_ASSIGN_CHECK(*out_type, 1, dtype);
+    return dtype != -1;
+  }
+
   bool InferShape(std::vector<TShape> *in_shape,
                   std::vector<TShape> *out_shape,
                   std::vector<TShape> *aux_shape) const override {
@@ -294,7 +319,13 @@ class L2NormalizationProp : public OperatorProperty {
     return {ResourceRequest::kTempSpace};
   }
 
-  Operator* CreateOperator(Context ctx) const override;
+  Operator* CreateOperator(Context ctx) const override {
+    LOG(FATAL) << "Not Implemented.";
+    return NULL;
+  }
+
+  Operator* CreateOperatorEx(Context ctx, std::vector<TShape> *in_shape,
+                             std::vector<int> *in_type) const override;
 
  private:
   L2NormalizationParam param_;
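
The InferType override added above is what lets the dtype propagate from the input to both outputs (out and norm). A quick way to observe the effect from Python (a sketch, assuming a build that includes this PR) is to query type inference directly on the symbol:

    import mxnet as mx
    import numpy as np

    data = mx.symbol.Variable('data')
    out = mx.symbol.L2Normalization(data=data, mode='instance')
    # float16 on the input is now accepted and propagated to both outputs.
    arg_types, out_types, aux_types = out.infer_type(data=np.float16)
    print(arg_types)  # the single argument resolves to numpy.float16
    print(out_types)  # both outputs resolve to numpy.float16 as well
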
diff --git a/src/operator/l2_normalization.cc b/src/operator/l2_normalization.cc
index 76e64c8d350..c313b442442 100644
--- a/src/operator/l2_normalization.cc
+++ b/src/operator/l2_normalization.cc
@@ -26,13 +26,18 @@
 namespace mxnet {
 namespace op {
 template<>
-Operator* CreateOp<cpu>(L2NormalizationParam param) {
-  return new L2NormalizationOp<cpu>(param);
+Operator* CreateOp<cpu>(L2NormalizationParam param, int dtype) {
+  Operator* op = NULL;
+  MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
+    op = new L2NormalizationOp<cpu, DType>(param);
+  });
+  return op;
 }
 
 // DO_BIND_DISPATCH comes from static_operator_common.h
-Operator* L2NormalizationProp::CreateOperator(Context ctx) const {
-  DO_BIND_DISPATCH(CreateOp, param_);
+Operator* L2NormalizationProp::CreateOperatorEx(Context ctx, std::vector<TShape> *in_shape,
+                                                std::vector<int> *in_type) const {
+  DO_BIND_DISPATCH(CreateOp, param_, in_type->at(0));
 }
 
 DMLC_REGISTER_PARAMETER(L2NormalizationParam);
diff --git a/src/operator/l2_normalization.cu b/src/operator/l2_normalization.cu
index 1c1c0e5ed09..2034f984174 100644
--- a/src/operator/l2_normalization.cu
+++ b/src/operator/l2_normalization.cu
@@ -26,8 +26,12 @@
 namespace mxnet {
 namespace op {
 template<>
-Operator* CreateOp<gpu>(L2NormalizationParam param) {
-  return new L2NormalizationOp<gpu>(param);
+Operator* CreateOp<gpu>(L2NormalizationParam param, int dtype) {
+  Operator* op = NULL;
+  MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
+    op = new L2NormalizationOp<gpu, DType>(param);
+  });
+  return op;
 }
 }  // namespace op
 }  // namespace mxnet
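
On the GPU side the same MSHADOW_REAL_TYPE_SWITCH dispatch is applied, so the operator can be instantiated for float16, float32 and float64 on both devices. One common way such a change is exercised (a sketch, assuming a GPU build and the mx.test_utils helpers) is a CPU/GPU mixed-precision consistency check:

    import mxnet as mx
    import numpy as np

    sym = mx.sym.L2Normalization(data=mx.sym.Variable('data'), mode='channel')
    # Compare a float16 run on the GPU against a float32 reference on the CPU.
    ctx_list = [{'ctx': mx.gpu(0), 'data': (2, 3, 4), 'type_dict': {'data': np.float16}},
                {'ctx': mx.cpu(0), 'data': (2, 3, 4), 'type_dict': {'data': np.float32}}]
    mx.test_utils.check_consistency(sym, ctx_list)
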
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 405283bdf31..ae42dbf3e60 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -2391,11 +2391,11 @@ def test_instance_normalization():
     check_instance_norm_with_shape((3,3,2,3,2,1,1), default_context())
 
 
-def check_l2_normalization(in_shape, mode, norm_eps=1e-10):
+def check_l2_normalization(in_shape, mode, dtype, norm_eps=1e-10):
     ctx = default_context()
     data = mx.symbol.Variable('data')
     out = mx.symbol.L2Normalization(data=data, mode=mode, eps=norm_eps)
-    in_data = np.random.uniform(-1, 1, in_shape)
+    in_data = np.random.uniform(-1, 1, in_shape).astype(dtype)
     # calculate numpy results
     if mode == 'channel':
         assert in_data.ndim > 2
@@ -2419,7 +2419,7 @@ def check_l2_normalization(in_shape, mode, norm_eps=1e-10):
     exe = out.simple_bind(ctx=ctx, data=in_data.shape)
     output = exe.forward(is_train=True, data=in_data)
     # compare numpy + mxnet
-    assert_almost_equal(exe.outputs[0].asnumpy(), np_out, rtol=1e-5)
+    assert_almost_equal(exe.outputs[0].asnumpy(), np_out, rtol=1e-2 if dtype == 'float16' else 1e-5, atol=1e-5)
     # check gradient
     check_numeric_gradient(out, [in_data], numeric_eps=1e-3, rtol=1e-2, atol=1e-3)
 
@@ -2427,13 +2427,14 @@ def check_l2_normalization(in_shape, mode, norm_eps=1e-10):
 # TODO(szha): Seeding this masks failures. We need to do a deep dive for failures without this seed.
 @with_seed(1234)
 def test_l2_normalization():
-    for mode in ['channel', 'spatial', 'instance']:
-        for nbatch in [1, 4]:
-            for nchannel in [3, 5]:
-                for height in [4, 6]:
-                    check_l2_normalization((nbatch, nchannel, height), mode)
-                    for width in [5, 7]:
-                        check_l2_normalization((nbatch, nchannel, height, width), mode)
+    for dtype in ['float16', 'float32', 'float64']:
+        for mode in ['channel', 'spatial', 'instance']:
+            for nbatch in [1, 4]:
+                for nchannel in [3, 5]:
+                    for height in [4, 6]:
+                        check_l2_normalization((nbatch, nchannel, height), mode, dtype)
+                        for width in [5, 7]:
+                            check_l2_normalization((nbatch, nchannel, height, width), mode, dtype)
 
 
 def check_layer_normalization(in_shape, axis, eps, dtype=np.float32, forward_check_eps=1E-3):
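
For reference, the NumPy computation that check_l2_normalization compares against can be written compactly for all three modes; this is a paraphrase of the check above rather than additional code from the PR. The relaxed rtol for float16 in the assertion is expected, since half precision accumulates noticeably more rounding error in the sum of squares.

    import numpy as np

    def np_l2_normalize(x, mode, eps=1e-10):
        # Reference L2 normalization mirroring the operator's three modes.
        if mode == 'instance':
            axes = tuple(range(1, x.ndim))    # everything except the batch axis
        elif mode == 'channel':
            axes = (1,)                       # the channel axis only
        elif mode == 'spatial':
            axes = tuple(range(2, x.ndim))    # the spatial axes
        else:
            raise ValueError('unknown mode %s' % mode)
        norm = np.sqrt(np.sum(np.square(x), axis=axes, keepdims=True) + eps)
        return x / norm

    x = np.random.uniform(-1, 1, (2, 3, 4, 5)).astype(np.float16)
    y = np_l2_normalize(x, 'channel')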


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services