Posted to commits@mxnet.apache.org by ha...@apache.org on 2018/05/15 16:55:59 UTC

[incubator-mxnet] branch master updated: Fix _backward_norm op registration (#10913)

This is an automated email from the ASF dual-hosted git repository.

haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new a572aa4  Fix _backward_norm op registration (#10913)
a572aa4 is described below

commit a572aa43400f882f254b5fef0d9afda443bb40c3
Author: Haibin Lin <li...@gmail.com>
AuthorDate: Tue May 15 09:55:52 2018 -0700

    Fix _backward_norm op registration (#10913)
    
    * Update test_operator_gpu.py
    
    * Update broadcast_reduce_op_value.cc
    
    * Update broadcast_reduce_op_value.cu
    
    * Update broadcast_reduce_op.h
    
    * Update broadcast_reduce_op.h
    
    * Update optimizer_op.cc
    
    * fix
    
    * update test
    
    * remove keep dims check
    
    * fix build
    
    * fix reshape
---
 src/operator/optimizer_op.cc                     |  2 +-
 src/operator/tensor/broadcast_reduce_op.h        | 77 +++++++++---------------
 src/operator/tensor/broadcast_reduce_op_value.cc | 42 +++++++++++--
 src/operator/tensor/broadcast_reduce_op_value.cu | 17 +++---
 tests/python/gpu/test_operator_gpu.py            |  6 +-
 tests/python/unittest/test_sparse_ndarray.py     | 16 +++--
 6 files changed, 83 insertions(+), 77 deletions(-)
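
In brief, the patch registers _backward_norm as a standalone op parsed with
NormParam (previously it reused the reduce-op backward macro and
ReduceAxesParam), moves the sparse L2NormComputeEx implementation out of the
header into the .cc/.cu files, renames SumOpForwardEx to ReduceAxesOpForwardEx,
and adds a CPU-only FComputeEx path for the L2 norm of a CSR array along axis
0 or 1. The sketch below is not part of the patch; it mirrors the updated unit
test (shapes and values are illustrative) to show the user-facing behavior
this enables:

    import mxnet as mx
    import numpy as np

    dense = np.random.uniform(size=(5, 5)).astype('float32')
    csr = mx.nd.array(dense).tostype('csr')

    # L2 norm of the entire array: handled by the sparse kernel.
    whole = csr.norm(ord=2, axis=None, keepdims=False)

    # L2 norm along axis 0 or 1: the new CPU FComputeEx path for CSR inputs.
    per_col = csr.norm(ord=2, axis=0, keepdims=False)

    # Both should agree with the dense computation.
    np.testing.assert_allclose(whole.asnumpy(), np.linalg.norm(dense), rtol=1e-5)
    np.testing.assert_allclose(per_col.asnumpy(), np.linalg.norm(dense, axis=0), rtol=1e-5)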

diff --git a/src/operator/optimizer_op.cc b/src/operator/optimizer_op.cc
index 935c92a..cf126ed 100644
--- a/src/operator/optimizer_op.cc
+++ b/src/operator/optimizer_op.cc
@@ -305,7 +305,7 @@ inline bool SGDStorageType(const nnvm::NodeAttrs& attrs,
     dispatched = storage_type_assign(out_attrs, static_cast<NDArrayStorageType>(weight_stype),
                                      dispatch_mode, DispatchMode::kFComputeEx);
     // warn users if lazy_update is turned on
-    if (dispatched && param.lazy_update) LogLazyUpdate();
+    if (dispatched && param.wd != 0 && param.lazy_update) LogLazyUpdate();
   }
   if (!dispatched) {
     dispatched = dispatch_fallback(out_attrs, dispatch_mode);
diff --git a/src/operator/tensor/broadcast_reduce_op.h b/src/operator/tensor/broadcast_reduce_op.h
index 1d4ef0a..e50071b 100644
--- a/src/operator/tensor/broadcast_reduce_op.h
+++ b/src/operator/tensor/broadcast_reduce_op.h
@@ -385,11 +385,11 @@ inline bool ReduceAxesOpForwardStorage(const nnvm::NodeAttrs& attrs,
   const ReduceAxesParam& param = nnvm::get<ReduceAxesParam>(attrs.parsed);
   const int in_stype = in_attrs->at(0);
   int& out_stype = out_attrs->at(0);
-  bool dispatched = false;
-  // sum only supported for CPU for now. TODO: Remove when support for GPU added
+  // sum and reduce are only supported on CPU for now.
   const bool invalid_ctx = dev_mask != mshadow::cpu::kDevMask;
   const auto dispatch_ex =
       invalid_ctx ? DispatchMode::kFComputeFallback : DispatchMode::kFComputeEx;
+  bool dispatched = false;
   if (!dispatched && in_stype == kDefaultStorage) {
     // When input is dense output storage is set as dense and dispatched to
     // dense operator
@@ -707,18 +707,16 @@ void ReduceCsr(const nnvm::NodeAttrs& attrs, mshadow::Stream<xpu>* s, const OpCo
 }
 
 template <typename xpu, typename reducer, bool normalize = false>
-void SumOpForwardEx(const nnvm::NodeAttrs& attrs, const OpContext& ctx,
-                    const std::vector<NDArray>& inputs,
-                    const std::vector<OpReqType>& req,
-                    const std::vector<NDArray>& outputs) {
+void ReduceAxesOpForwardEx(const nnvm::NodeAttrs& attrs, const OpContext& ctx,
+                           const std::vector<NDArray>& inputs,
+                           const std::vector<OpReqType>& req,
+                           const std::vector<NDArray>& outputs) {
   CHECK_EQ(inputs.size(), 1U);
   CHECK_EQ(outputs.size(), 1U);
   CHECK_EQ(req.size(), 1U);
   mshadow::Stream<xpu>* s = ctx.get_stream<xpu>();
   const NDArrayStorageType istype = inputs[0].storage_type();
   if (istype == kCSRStorage) {
-    CHECK_EQ(inputs[0].shape().ndim(), 2U)
-        << "sum(csr)/mean(csr) op only supports 2D ndarray as input";
     NDArray output = outputs[0];
     ReduceCsr<xpu, mshadow::red::sum, normalize>(attrs, s, ctx, inputs[0],
                                                  req[0], &output);
@@ -880,17 +878,30 @@ inline bool L2NormStorageType(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(out_attrs->size(), 1U);
   const int in_stype = in_attrs->at(0);
   int& out_stype = out_attrs->at(0);
+  const NormParam& param = nnvm::get<NormParam>(attrs.parsed);
   bool dispatched = false;
+  // l2 norm along a particular axis is only supported on CPU
+  const bool invalid_ctx = dev_mask != mshadow::cpu::kDevMask;
+  const auto dispatch_ex =
+      invalid_ctx ? DispatchMode::kFComputeFallback : DispatchMode::kFComputeEx;
   if (!dispatched && in_stype == kDefaultStorage) {
     // dns -> dns
     dispatched = storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode,
                                      DispatchMode::kFCompute);
   }
-  if (!dispatched && (in_stype == kCSRStorage || in_stype == kRowSparseStorage)) {
-    // csr/rsp -> dns
+  const TShape axis = param.axis.has_value() ? param.axis.value() : TShape();
+  if (!dispatched && (in_stype == kRowSparseStorage || in_stype == kCSRStorage) &&
+      axis.ndim() == 0 && param.ord == 2) {
+    // l2 norm: rsp/csr, axis = () -> dns
     dispatched = storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode,
                                      DispatchMode::kFComputeEx);
   }
+  if (!dispatched && in_stype == kCSRStorage && axis.ndim() == 1 && !param.keepdims &&
+      (axis[0] == 0 || axis[0] == 1) && param.ord == 2) {
+    // l2 norm: csr, axis = 0/1 -> dns
+    dispatched = storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode,
+                                     dispatch_ex);
+  }
   if (!dispatched) {
     dispatched = dispatch_fallback(out_attrs, dispatch_mode);
   }
@@ -943,10 +954,11 @@ void L2NormComputeImpl(mshadow::Stream<xpu> *s,
                        const TBlob& input,
                        const OpReqType req,
                        const TBlob& output) {
-  if (req == kNullOp) return;
   MSHADOW_REAL_TYPE_SWITCH(output.type_flag_, DType, {
+    // assign_req switch exits immediately for null req
     MXNET_ASSIGN_REQ_SWITCH(req, Req, {
-      mshadow::Tensor<xpu, 1, DType> out = output.get<xpu, 1, DType>(s);
+      mshadow::Tensor<xpu, 1, DType> out = output.get_with_shape<xpu, 1, DType>(
+        mshadow::Shape1(output.shape_.Size()), s);
       mshadow::Tensor<xpu, 1, DType> in = input.get_with_shape<xpu, 1, DType>(
         mshadow::Shape1(input.shape_.Size()), s);
       mshadow::VectorDot(out, in, in);
@@ -1002,12 +1014,12 @@ void L2NormGradCompute(const nnvm::NodeAttrs& attrs,
   using namespace mshadow::expr;
   if (req[0] == kNullOp) return;
 
-  const ReduceAxesParam& param = nnvm::get<ReduceAxesParam>(attrs.parsed);
+  const NormParam& param = nnvm::get<NormParam>(attrs.parsed);
   TShape small;
   if (param.keepdims) {
     small = inputs[0].shape_;
   } else {
-    small = ReduceAxesShapeImpl(outputs[0].shape_, param.axis, true, param.exclude);
+    small = ReduceAxesShapeImpl(outputs[0].shape_, param.axis, true, false);
   }
   ReduceAxesBackwardUseInOutImpl<xpu, mshadow_op::div, false>(ctx, small, inputs,
                                                               req, outputs);
@@ -1034,42 +1046,7 @@ void L2NormComputeEx(const nnvm::NodeAttrs& attrs,
                      const OpContext& ctx,
                      const std::vector<NDArray>& inputs,
                      const std::vector<OpReqType>& req,
-                     const std::vector<NDArray>& outputs) {
-  CHECK_EQ(inputs.size(), 1U);
-  CHECK_EQ(outputs.size(), 1U);
-  CHECK_EQ(req.size(), 1U);
-  const NormParam& param = nnvm::get<NormParam>(attrs.parsed);
-  CHECK_EQ(param.ord, 2) << "norm only support ord=2";
-  mshadow::Stream<xpu>* s = ctx.get_stream<xpu>();
-  const NDArrayStorageType istype = inputs[0].storage_type();
-  const TShape axis = param.axis.has_value() ? param.axis.value() : TShape();
-  if ((istype == kRowSparseStorage || istype == kCSRStorage) && axis.ndim() == 0) {
-    // We only support norm on the entire array for now.
-    L2NormComputeSparseImpl<xpu>(s, inputs[0], req[0], outputs[0].data());
-  } else if (istype == kCSRStorage) {
-    CHECK_EQ(inputs[0].shape().ndim(), 2U)
-        << "norm(csr) op only supports 2D ndarray as input";
-    CHECK_EQ(axis.ndim(), 1U) << "sum(csr)/mean(csr) only supports axis 0 or 1";
-    CHECK(axis[0] == 0 || axis[0] == 1)
-        << "sum(csr)/mean(csr) only support axis 0 or 1";
-    CHECK(!param.keepdims) << "keepdims not supported for sparse";
-    NDArray output = outputs[0];
-    ReduceCsrImpl<xpu, sq_sum, false>(s, ctx, inputs[0], req[0], &output, axis);
-    CHECK_EQ(outputs[0].storage_type(), kDefaultStorage);
-    SqRootForL2<xpu>(ctx, req[0], outputs[0].data());
-  } else {
-    LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
-  }
-}
-
-template<typename xpu>
-void L2NormGradComputeEx(const nnvm::NodeAttrs& attrs,
-                         const OpContext& ctx,
-                         const std::vector<NDArray>& inputs,
-                         const std::vector<OpReqType>& req,
-                         const std::vector<NDArray>& outputs) {
-  LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
-}
+                     const std::vector<NDArray>& outputs);
 
 /*! \brief index element from array along axes */
 template<int ndim>
diff --git a/src/operator/tensor/broadcast_reduce_op_value.cc b/src/operator/tensor/broadcast_reduce_op_value.cc
index 5f433cc..7bcc3e9 100644
--- a/src/operator/tensor/broadcast_reduce_op_value.cc
+++ b/src/operator/tensor/broadcast_reduce_op_value.cc
@@ -46,6 +46,35 @@ Defined in )code";
   return doc;
 }
 
+template<>
+void L2NormComputeEx<cpu>(const nnvm::NodeAttrs& attrs,
+                          const OpContext& ctx,
+                          const std::vector<NDArray>& inputs,
+                          const std::vector<OpReqType>& req,
+                          const std::vector<NDArray>& outputs) {
+  CHECK_EQ(inputs.size(), 1U);
+  CHECK_EQ(outputs.size(), 1U);
+  CHECK_EQ(req.size(), 1U);
+  const NormParam& param = nnvm::get<NormParam>(attrs.parsed);
+  mshadow::Stream<cpu>* s = ctx.get_stream<cpu>();
+  const NDArrayStorageType istype = inputs[0].storage_type();
+  const TShape axis = param.axis.has_value() ? param.axis.value() : TShape();
+  if ((istype == kRowSparseStorage || istype == kCSRStorage) && axis.ndim() == 0 &&
+       param.ord == 2) {
+    // l2 norm on the entire array
+    L2NormComputeSparseImpl<cpu>(s, inputs[0], req[0], outputs[0].data());
+  } else if (istype == kCSRStorage && axis.ndim() == 1 && (axis[0] == 0 || axis[0] == 1) &&
+             !param.keepdims && param.ord == 2) {
+    // l2 norm on a particular axis
+    NDArray output = outputs[0];
+    ReduceCsrImpl<cpu, sq_sum, false>(s, ctx, inputs[0], req[0], &output, axis);
+    CHECK_EQ(outputs[0].storage_type(), kDefaultStorage);
+    SqRootForL2<cpu>(ctx, req[0], outputs[0].data());
+  } else {
+    LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
+  }
+}
+
 MXNET_OPERATOR_REGISTER_REDUCE(sum)
 MXNET_ADD_SPARSE_OP_ALIAS(sum)
 .add_alias("sum_axis")
@@ -85,7 +114,7 @@ Example::
 
 )code" ADD_FILELINE)
 .set_attr<FCompute>("FCompute<cpu>", ReduceAxesCompute<cpu, mshadow::red::sum>)
-.set_attr<FComputeEx>("FComputeEx<cpu>", SumOpForwardEx<cpu, mshadow::red::sum>)
+.set_attr<FComputeEx>("FComputeEx<cpu>", ReduceAxesOpForwardEx<cpu, mshadow::red::sum>)
 .set_attr<FInferStorageType>("FInferStorageType", ReduceAxesOpForwardStorage)
 .set_attr<FResourceRequest>("FResourceRequest",
   [](const NodeAttrs& attrs) {
@@ -101,7 +130,7 @@ MXNET_OPERATOR_REGISTER_REDUCE(mean)
 MXNET_ADD_SPARSE_OP_ALIAS(mean)
 .describe(get_reduce_axes_description("mean", __LINE__))
 .set_attr<FCompute>("FCompute<cpu>", ReduceAxesCompute<cpu, mshadow::red::sum, true>)
-.set_attr<FComputeEx>("FComputeEx<cpu>", SumOpForwardEx<cpu, mshadow::red::sum, true>)
+.set_attr<FComputeEx>("FComputeEx<cpu>", ReduceAxesOpForwardEx<cpu, mshadow::red::sum, true>)
 .set_attr<FInferStorageType>("FInferStorageType", ReduceAxesOpForwardStorage)
 .set_attr<FResourceRequest>("FResourceRequest",
   [](const NodeAttrs& attrs) {
@@ -285,14 +314,15 @@ Examples::
 .add_argument("data", "NDArray-or-Symbol", "The input")
 .add_arguments(NormParam::__FIELDS__());
 
-MXNET_OPERATOR_REGISTER_REDUCE_BACKWARD(_backward_norm)
-.set_num_inputs(1)
+NNVM_REGISTER_OP(_backward_norm)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<NormParam>)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
 .set_attr<FResourceRequest>("FResourceRequest",
   [](const NodeAttrs& attrs) {
     return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
   })
-.set_attr<FCompute>("FCompute<cpu>", L2NormGradCompute<cpu>)
-.set_attr<FComputeEx>("FComputeEx<cpu>", L2NormGradComputeEx<cpu>);
+.set_attr<FCompute>("FCompute<cpu>", L2NormGradCompute<cpu>);
 
 
 }  // namespace op
diff --git a/src/operator/tensor/broadcast_reduce_op_value.cu b/src/operator/tensor/broadcast_reduce_op_value.cu
index 5065b9f..f7fba68 100644
--- a/src/operator/tensor/broadcast_reduce_op_value.cu
+++ b/src/operator/tensor/broadcast_reduce_op_value.cu
@@ -36,14 +36,14 @@ void L2NormComputeEx<gpu>(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(inputs.size(), 1U);
   CHECK_EQ(outputs.size(), 1U);
   CHECK_EQ(req.size(), 1U);
+  const NormParam& param = nnvm::get<NormParam>(attrs.parsed);
   mshadow::Stream<gpu>* s = ctx.get_stream<gpu>();
-  const ReduceAxesParam& param = nnvm::get<ReduceAxesParam>(attrs.parsed);
-  const NDArrayStorageType in_stype = inputs[0].storage_type();
-  nnvm::TShape axis = param.axis.has_value() ? param.axis.value() : TShape();
-  // CSR and RowSparse only works on the entire array.
-  if ((in_stype == kCSRStorage || in_stype == kRowSparseStorage)
-      && axis.ndim() == 0) {
-    L2NormComputeSparseImpl(s, inputs[0], req[0], outputs[0].data());
+  const NDArrayStorageType istype = inputs[0].storage_type();
+  const TShape axis = param.axis.has_value() ? param.axis.value() : TShape();
+  if ((istype == kRowSparseStorage || istype == kCSRStorage) && axis.ndim() == 0 &&
+       param.ord == 2) {
+    // l2 norm on the entire array
+    L2NormComputeSparseImpl<gpu>(s, inputs[0], req[0], outputs[0].data());
   } else {
     LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
   }
@@ -105,8 +105,7 @@ NNVM_REGISTER_OP(norm)
 .set_attr<FComputeEx>("FComputeEx<gpu>", L2NormComputeEx<gpu>);
 
 NNVM_REGISTER_OP(_backward_norm)
-.set_attr<FCompute>("FCompute<gpu>", L2NormGradCompute<gpu>)
-.set_attr<FComputeEx>("FComputeEx<gpu>", L2NormGradComputeEx<gpu>);
+.set_attr<FCompute>("FCompute<gpu>", L2NormGradCompute<gpu>);
 
 }  // namespace op
 }  // namespace mxnet
diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py
index d356e78..090773c 100644
--- a/tests/python/gpu/test_operator_gpu.py
+++ b/tests/python/gpu/test_operator_gpu.py
@@ -41,11 +41,7 @@ from test_loss import *
 from test_exc_handling import *
 #from test_rnn import *
 from test_gluon_rnn import *
-from test_sparse_ndarray import test_create_csr, test_create_row_sparse, test_sparse_nd_slice
-from test_sparse_ndarray import test_create_sparse_nd_empty, test_create_sparse_nd_from_sparse
-from test_sparse_ndarray import test_create_sparse_nd_from_dense, test_create_sparse_nd_infer_shape
-from test_sparse_ndarray import test_sparse_nd_check_format, test_sparse_nd_copy
-from test_sparse_ndarray import test_sparse_nd_setitem, test_sparse_nd_binary_scalar_op
+from test_sparse_ndarray import *
 from test_sparse_operator import *
 from test_ndarray import *
 
diff --git a/tests/python/unittest/test_sparse_ndarray.py b/tests/python/unittest/test_sparse_ndarray.py
index a710038..1ed5080 100644
--- a/tests/python/unittest/test_sparse_ndarray.py
+++ b/tests/python/unittest/test_sparse_ndarray.py
@@ -918,18 +918,22 @@ def test_sparse_nd_check_format():
 
 @with_seed()
 def test_sparse_nd_norm():
-    def check_sparse_nd_norm(stype, shape, density):
+    def check_sparse_nd_norm(stype, shape, density, **kwargs):
         data, _ = rand_sparse_ndarray(shape, stype, density)
-        norm = data.norm()
-        expected_norm = np.linalg.norm(data.asnumpy())
-        assert_almost_equal(norm.asnumpy(), expected_norm)
+        norm = data.norm(**kwargs)
+        expected_norm = data.tostype('default').norm(**kwargs)
+        assert_almost_equal(norm.asnumpy(), expected_norm.asnumpy())
 
     shape = (5, 5)
     stypes = ['row_sparse', 'csr']
-    densities = [0, 0.5]
+    densities = [0, 0.5, 1]
     for stype in stypes:
         for density in densities:
-            check_sparse_nd_norm(stype, shape, density)
+            check_sparse_nd_norm(stype, shape, density, axis=None, keepdims=False, ord=2)
+
+    # test fallback
+    check_sparse_nd_norm(stype, shape, density, axis=0, keepdims=False, ord=2)
+    check_sparse_nd_norm(stype, shape, density, axis=None, keepdims=True, ord=2)
 
 @with_seed()
 def test_sparse_fc():
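
Cases not covered by the sparse kernels (for example keepdims=True with an
axis, or the axis-wise CSR norm on GPU) are dispatched as a storage fallback:
the sparse input is densified and the dense FCompute kernel runs, which the
"# test fallback" calls at the end of the updated test are meant to exercise.
A minimal sketch of that behavior (again illustrative, not from the patch):

    import mxnet as mx
    import numpy as np

    csr = mx.nd.array(np.arange(25).reshape(5, 5).astype('float32')).tostype('csr')

    # keepdims=True is not handled by the sparse L2-norm kernels, so MXNet
    # falls back to the dense path and still returns the correct result.
    out = csr.norm(ord=2, axis=1, keepdims=True)
    ref = csr.tostype('default').norm(ord=2, axis=1, keepdims=True)
    np.testing.assert_allclose(out.asnumpy(), ref.asnumpy(), rtol=1e-5)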
