Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/05/15 16:55:54 UTC

[GitHub] eric-haibin-lin closed pull request #10913: Fix _backward_norm op registration

URL: https://github.com/apache/incubator-mxnet/pull/10913
 
 
   

This is a pull request merged from a forked repository. Because GitHub hides
the original diff once a forked pull request is merged, it is reproduced below
for the sake of provenance:

diff --git a/src/operator/optimizer_op.cc b/src/operator/optimizer_op.cc
index 935c92a034c..cf126ed58ea 100644
--- a/src/operator/optimizer_op.cc
+++ b/src/operator/optimizer_op.cc
@@ -305,7 +305,7 @@ inline bool SGDStorageType(const nnvm::NodeAttrs& attrs,
     dispatched = storage_type_assign(out_attrs, static_cast<NDArrayStorageType>(weight_stype),
                                      dispatch_mode, DispatchMode::kFComputeEx);
     // warn users if lazy_update is turned on
-    if (dispatched && param.lazy_update) LogLazyUpdate();
+    if (dispatched && param.wd != 0 && param.lazy_update) LogLazyUpdate();
   }
   if (!dispatched) {
     dispatched = dispatch_fallback(out_attrs, dispatch_mode);
diff --git a/src/operator/tensor/broadcast_reduce_op.h b/src/operator/tensor/broadcast_reduce_op.h
index 1d4ef0adfb6..e50071bdab7 100644
--- a/src/operator/tensor/broadcast_reduce_op.h
+++ b/src/operator/tensor/broadcast_reduce_op.h
@@ -385,11 +385,11 @@ inline bool ReduceAxesOpForwardStorage(const nnvm::NodeAttrs& attrs,
   const ReduceAxesParam& param = nnvm::get<ReduceAxesParam>(attrs.parsed);
   const int in_stype = in_attrs->at(0);
   int& out_stype = out_attrs->at(0);
-  bool dispatched = false;
-  // sum only supported for CPU for now. TODO: Remove when support for GPU added
+  // sum and reduce are only supported on CPU for now.
   const bool invalid_ctx = dev_mask != mshadow::cpu::kDevMask;
   const auto dispatch_ex =
       invalid_ctx ? DispatchMode::kFComputeFallback : DispatchMode::kFComputeEx;
+  bool dispatched = false;
   if (!dispatched && in_stype == kDefaultStorage) {
     // When input is dense output storage is set as dense and dispatched to
     // dense operator
@@ -707,18 +707,16 @@ void ReduceCsr(const nnvm::NodeAttrs& attrs, mshadow::Stream<xpu>* s, const OpCo
 }
 
 template <typename xpu, typename reducer, bool normalize = false>
-void SumOpForwardEx(const nnvm::NodeAttrs& attrs, const OpContext& ctx,
-                    const std::vector<NDArray>& inputs,
-                    const std::vector<OpReqType>& req,
-                    const std::vector<NDArray>& outputs) {
+void ReduceAxesOpForwardEx(const nnvm::NodeAttrs& attrs, const OpContext& ctx,
+                           const std::vector<NDArray>& inputs,
+                           const std::vector<OpReqType>& req,
+                           const std::vector<NDArray>& outputs) {
   CHECK_EQ(inputs.size(), 1U);
   CHECK_EQ(outputs.size(), 1U);
   CHECK_EQ(req.size(), 1U);
   mshadow::Stream<xpu>* s = ctx.get_stream<xpu>();
   const NDArrayStorageType istype = inputs[0].storage_type();
   if (istype == kCSRStorage) {
-    CHECK_EQ(inputs[0].shape().ndim(), 2U)
-        << "sum(csr)/mean(csr) op only supports 2D ndarray as input";
     NDArray output = outputs[0];
     ReduceCsr<xpu, mshadow::red::sum, normalize>(attrs, s, ctx, inputs[0],
                                                  req[0], &output);
@@ -880,17 +878,30 @@ inline bool L2NormStorageType(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(out_attrs->size(), 1U);
   const int in_stype = in_attrs->at(0);
   int& out_stype = out_attrs->at(0);
+  const NormParam& param = nnvm::get<NormParam>(attrs.parsed);
   bool dispatched = false;
+  // l2 norm on a particular axis is only supported on CPU
+  const bool invalid_ctx = dev_mask != mshadow::cpu::kDevMask;
+  const auto dispatch_ex =
+      invalid_ctx ? DispatchMode::kFComputeFallback : DispatchMode::kFComputeEx;
   if (!dispatched && in_stype == kDefaultStorage) {
     // dns -> dns
     dispatched = storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode,
                                      DispatchMode::kFCompute);
   }
-  if (!dispatched && (in_stype == kCSRStorage || in_stype == kRowSparseStorage)) {
-    // csr/rsp -> dns
+  const TShape axis = param.axis.has_value() ? param.axis.value() : TShape();
+  if (!dispatched && (in_stype == kRowSparseStorage || in_stype == kCSRStorage) &&
+      axis.ndim() == 0 && param.ord == 2) {
+    // l2 norm: rsp/csr, axis = () -> dns
     dispatched = storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode,
                                      DispatchMode::kFComputeEx);
   }
+  if (!dispatched && in_stype == kCSRStorage && axis.ndim() == 1 && !param.keepdims &&
+      (axis[0] == 0 || axis[0] == 1) && param.ord == 2) {
+    // l2 norm: csr, axis = 0/1 -> dns
+    dispatched = storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode,
+                                     dispatch_ex);
+  }
   if (!dispatched) {
     dispatched = dispatch_fallback(out_attrs, dispatch_mode);
   }
@@ -943,10 +954,11 @@ void L2NormComputeImpl(mshadow::Stream<xpu> *s,
                        const TBlob& input,
                        const OpReqType req,
                        const TBlob& output) {
-  if (req == kNullOp) return;
   MSHADOW_REAL_TYPE_SWITCH(output.type_flag_, DType, {
+    // assign_req switch exits immediately for null req
     MXNET_ASSIGN_REQ_SWITCH(req, Req, {
-      mshadow::Tensor<xpu, 1, DType> out = output.get<xpu, 1, DType>(s);
+      mshadow::Tensor<xpu, 1, DType> out = output.get_with_shape<xpu, 1, DType>(
+        mshadow::Shape1(output.shape_.Size()), s);
       mshadow::Tensor<xpu, 1, DType> in = input.get_with_shape<xpu, 1, DType>(
         mshadow::Shape1(input.shape_.Size()), s);
       mshadow::VectorDot(out, in, in);
@@ -1002,12 +1014,12 @@ void L2NormGradCompute(const nnvm::NodeAttrs& attrs,
   using namespace mshadow::expr;
   if (req[0] == kNullOp) return;
 
-  const ReduceAxesParam& param = nnvm::get<ReduceAxesParam>(attrs.parsed);
+  const NormParam& param = nnvm::get<NormParam>(attrs.parsed);
   TShape small;
   if (param.keepdims) {
     small = inputs[0].shape_;
   } else {
-    small = ReduceAxesShapeImpl(outputs[0].shape_, param.axis, true, param.exclude);
+    small = ReduceAxesShapeImpl(outputs[0].shape_, param.axis, true, false);
   }
   ReduceAxesBackwardUseInOutImpl<xpu, mshadow_op::div, false>(ctx, small, inputs,
                                                               req, outputs);
@@ -1034,42 +1046,7 @@ void L2NormComputeEx(const nnvm::NodeAttrs& attrs,
                      const OpContext& ctx,
                      const std::vector<NDArray>& inputs,
                      const std::vector<OpReqType>& req,
-                     const std::vector<NDArray>& outputs) {
-  CHECK_EQ(inputs.size(), 1U);
-  CHECK_EQ(outputs.size(), 1U);
-  CHECK_EQ(req.size(), 1U);
-  const NormParam& param = nnvm::get<NormParam>(attrs.parsed);
-  CHECK_EQ(param.ord, 2) << "norm only support ord=2";
-  mshadow::Stream<xpu>* s = ctx.get_stream<xpu>();
-  const NDArrayStorageType istype = inputs[0].storage_type();
-  const TShape axis = param.axis.has_value() ? param.axis.value() : TShape();
-  if ((istype == kRowSparseStorage || istype == kCSRStorage) && axis.ndim() == 0) {
-    // We only support norm on the entire array for now.
-    L2NormComputeSparseImpl<xpu>(s, inputs[0], req[0], outputs[0].data());
-  } else if (istype == kCSRStorage) {
-    CHECK_EQ(inputs[0].shape().ndim(), 2U)
-        << "norm(csr) op only supports 2D ndarray as input";
-    CHECK_EQ(axis.ndim(), 1U) << "sum(csr)/mean(csr) only supports axis 0 or 1";
-    CHECK(axis[0] == 0 || axis[0] == 1)
-        << "sum(csr)/mean(csr) only support axis 0 or 1";
-    CHECK(!param.keepdims) << "keepdims not supported for sparse";
-    NDArray output = outputs[0];
-    ReduceCsrImpl<xpu, sq_sum, false>(s, ctx, inputs[0], req[0], &output, axis);
-    CHECK_EQ(outputs[0].storage_type(), kDefaultStorage);
-    SqRootForL2<xpu>(ctx, req[0], outputs[0].data());
-  } else {
-    LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
-  }
-}
-
-template<typename xpu>
-void L2NormGradComputeEx(const nnvm::NodeAttrs& attrs,
-                         const OpContext& ctx,
-                         const std::vector<NDArray>& inputs,
-                         const std::vector<OpReqType>& req,
-                         const std::vector<NDArray>& outputs) {
-  LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
-}
+                     const std::vector<NDArray>& outputs);
 
 /*! \brief index element from array along axes */
 template<int ndim>
diff --git a/src/operator/tensor/broadcast_reduce_op_value.cc b/src/operator/tensor/broadcast_reduce_op_value.cc
index 5f433cc8001..7bcc3e97e8e 100644
--- a/src/operator/tensor/broadcast_reduce_op_value.cc
+++ b/src/operator/tensor/broadcast_reduce_op_value.cc
@@ -46,6 +46,35 @@ Defined in )code";
   return doc;
 }
 
+template<>
+void L2NormComputeEx<cpu>(const nnvm::NodeAttrs& attrs,
+                          const OpContext& ctx,
+                          const std::vector<NDArray>& inputs,
+                          const std::vector<OpReqType>& req,
+                          const std::vector<NDArray>& outputs) {
+  CHECK_EQ(inputs.size(), 1U);
+  CHECK_EQ(outputs.size(), 1U);
+  CHECK_EQ(req.size(), 1U);
+  const NormParam& param = nnvm::get<NormParam>(attrs.parsed);
+  mshadow::Stream<cpu>* s = ctx.get_stream<cpu>();
+  const NDArrayStorageType istype = inputs[0].storage_type();
+  const TShape axis = param.axis.has_value() ? param.axis.value() : TShape();
+  if ((istype == kRowSparseStorage || istype == kCSRStorage) && axis.ndim() == 0 &&
+       param.ord == 2) {
+    // l2 norm on the entire array
+    L2NormComputeSparseImpl<cpu>(s, inputs[0], req[0], outputs[0].data());
+  } else if (istype == kCSRStorage && axis.ndim() == 1 && (axis[0] == 0 || axis[0] == 1) &&
+             !param.keepdims && param.ord == 2) {
+    // l2 norm on a particular axis
+    NDArray output = outputs[0];
+    ReduceCsrImpl<cpu, sq_sum, false>(s, ctx, inputs[0], req[0], &output, axis);
+    CHECK_EQ(outputs[0].storage_type(), kDefaultStorage);
+    SqRootForL2<cpu>(ctx, req[0], outputs[0].data());
+  } else {
+    LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
+  }
+}
+
 MXNET_OPERATOR_REGISTER_REDUCE(sum)
 MXNET_ADD_SPARSE_OP_ALIAS(sum)
 .add_alias("sum_axis")
@@ -85,7 +114,7 @@ Example::
 
 )code" ADD_FILELINE)
 .set_attr<FCompute>("FCompute<cpu>", ReduceAxesCompute<cpu, mshadow::red::sum>)
-.set_attr<FComputeEx>("FComputeEx<cpu>", SumOpForwardEx<cpu, mshadow::red::sum>)
+.set_attr<FComputeEx>("FComputeEx<cpu>", ReduceAxesOpForwardEx<cpu, mshadow::red::sum>)
 .set_attr<FInferStorageType>("FInferStorageType", ReduceAxesOpForwardStorage)
 .set_attr<FResourceRequest>("FResourceRequest",
   [](const NodeAttrs& attrs) {
@@ -101,7 +130,7 @@ MXNET_OPERATOR_REGISTER_REDUCE(mean)
 MXNET_ADD_SPARSE_OP_ALIAS(mean)
 .describe(get_reduce_axes_description("mean", __LINE__))
 .set_attr<FCompute>("FCompute<cpu>", ReduceAxesCompute<cpu, mshadow::red::sum, true>)
-.set_attr<FComputeEx>("FComputeEx<cpu>", SumOpForwardEx<cpu, mshadow::red::sum, true>)
+.set_attr<FComputeEx>("FComputeEx<cpu>", ReduceAxesOpForwardEx<cpu, mshadow::red::sum, true>)
 .set_attr<FInferStorageType>("FInferStorageType", ReduceAxesOpForwardStorage)
 .set_attr<FResourceRequest>("FResourceRequest",
   [](const NodeAttrs& attrs) {
@@ -285,14 +314,15 @@ Examples::
 .add_argument("data", "NDArray-or-Symbol", "The input")
 .add_arguments(NormParam::__FIELDS__());
 
-MXNET_OPERATOR_REGISTER_REDUCE_BACKWARD(_backward_norm)
-.set_num_inputs(1)
+NNVM_REGISTER_OP(_backward_norm)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<NormParam>)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
 .set_attr<FResourceRequest>("FResourceRequest",
   [](const NodeAttrs& attrs) {
     return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
   })
-.set_attr<FCompute>("FCompute<cpu>", L2NormGradCompute<cpu>)
-.set_attr<FComputeEx>("FComputeEx<cpu>", L2NormGradComputeEx<cpu>);
+.set_attr<FCompute>("FCompute<cpu>", L2NormGradCompute<cpu>);
 
 
 }  // namespace op
diff --git a/src/operator/tensor/broadcast_reduce_op_value.cu b/src/operator/tensor/broadcast_reduce_op_value.cu
index 5065b9fcc7b..f7fba682878 100644
--- a/src/operator/tensor/broadcast_reduce_op_value.cu
+++ b/src/operator/tensor/broadcast_reduce_op_value.cu
@@ -36,14 +36,14 @@ void L2NormComputeEx<gpu>(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(inputs.size(), 1U);
   CHECK_EQ(outputs.size(), 1U);
   CHECK_EQ(req.size(), 1U);
+  const NormParam& param = nnvm::get<NormParam>(attrs.parsed);
   mshadow::Stream<gpu>* s = ctx.get_stream<gpu>();
-  const ReduceAxesParam& param = nnvm::get<ReduceAxesParam>(attrs.parsed);
-  const NDArrayStorageType in_stype = inputs[0].storage_type();
-  nnvm::TShape axis = param.axis.has_value() ? param.axis.value() : TShape();
-  // CSR and RowSparse only works on the entire array.
-  if ((in_stype == kCSRStorage || in_stype == kRowSparseStorage)
-      && axis.ndim() == 0) {
-    L2NormComputeSparseImpl(s, inputs[0], req[0], outputs[0].data());
+  const NDArrayStorageType istype = inputs[0].storage_type();
+  const TShape axis = param.axis.has_value() ? param.axis.value() : TShape();
+  if ((istype == kRowSparseStorage || istype == kCSRStorage) && axis.ndim() == 0 &&
+       param.ord == 2) {
+    // l2 norm on the entire array
+    L2NormComputeSparseImpl<gpu>(s, inputs[0], req[0], outputs[0].data());
   } else {
     LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
   }
@@ -105,8 +105,7 @@ NNVM_REGISTER_OP(norm)
 .set_attr<FComputeEx>("FComputeEx<gpu>", L2NormComputeEx<gpu>);
 
 NNVM_REGISTER_OP(_backward_norm)
-.set_attr<FCompute>("FCompute<gpu>", L2NormGradCompute<gpu>)
-.set_attr<FComputeEx>("FComputeEx<gpu>", L2NormGradComputeEx<gpu>);
+.set_attr<FCompute>("FCompute<gpu>", L2NormGradCompute<gpu>);
 
 }  // namespace op
 }  // namespace mxnet
diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py
index d356e789289..090773c7787 100644
--- a/tests/python/gpu/test_operator_gpu.py
+++ b/tests/python/gpu/test_operator_gpu.py
@@ -41,11 +41,7 @@
 from test_exc_handling import *
 #from test_rnn import *
 from test_gluon_rnn import *
-from test_sparse_ndarray import test_create_csr, test_create_row_sparse, test_sparse_nd_slice
-from test_sparse_ndarray import test_create_sparse_nd_empty, test_create_sparse_nd_from_sparse
-from test_sparse_ndarray import test_create_sparse_nd_from_dense, test_create_sparse_nd_infer_shape
-from test_sparse_ndarray import test_sparse_nd_check_format, test_sparse_nd_copy
-from test_sparse_ndarray import test_sparse_nd_setitem, test_sparse_nd_binary_scalar_op
+from test_sparse_ndarray import *
 from test_sparse_operator import *
 from test_ndarray import *
 
diff --git a/tests/python/unittest/test_sparse_ndarray.py b/tests/python/unittest/test_sparse_ndarray.py
index a710038d5d8..1ed5080e1c1 100644
--- a/tests/python/unittest/test_sparse_ndarray.py
+++ b/tests/python/unittest/test_sparse_ndarray.py
@@ -918,18 +918,22 @@ def test_sparse_nd_check_format():
 
 @with_seed()
 def test_sparse_nd_norm():
-    def check_sparse_nd_norm(stype, shape, density):
+    def check_sparse_nd_norm(stype, shape, density, **kwargs):
         data, _ = rand_sparse_ndarray(shape, stype, density)
-        norm = data.norm()
-        expected_norm = np.linalg.norm(data.asnumpy())
-        assert_almost_equal(norm.asnumpy(), expected_norm)
+        norm = data.norm(**kwargs)
+        expected_norm = data.tostype('default').norm(**kwargs)
+        assert_almost_equal(norm.asnumpy(), expected_norm.asnumpy())
 
     shape = (5, 5)
     stypes = ['row_sparse', 'csr']
-    densities = [0, 0.5]
+    densities = [0, 0.5, 1]
     for stype in stypes:
         for density in densities:
-            check_sparse_nd_norm(stype, shape, density)
+            check_sparse_nd_norm(stype, shape, density, axis=None, keepdims=False, ord=2)
+
+    # test fallback
+    check_sparse_nd_norm(stype, shape, density, axis=0, keepdims=False, ord=2)
+    check_sparse_nd_norm(stype, shape, density, axis=None, keepdims=True, ord=2)
 
 @with_seed()
 def test_sparse_fc():
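
For reference, the behaviour covered by the updated unit test above can be
exercised with a short script along these lines. This is a minimal sketch,
assuming an MXNet build that includes this change; the shape, density, and
variable names are illustrative and not taken from the PR:

import mxnet as mx
import numpy as np

# Build a small CSR NDArray from a partially-zeroed dense matrix (illustrative data).
dense = np.random.rand(5, 5) * (np.random.rand(5, 5) > 0.5)
csr = mx.nd.array(dense).tostype('csr')

# L2 norm over the whole array: csr/row_sparse input is handled directly by the
# sparse implementation.
whole_norm = csr.norm(ord=2)
np.testing.assert_almost_equal(whole_norm.asscalar(), np.linalg.norm(dense), decimal=5)

# L2 norm along axis 0 (or 1) of a csr array with keepdims=False: dispatched to the
# CPU FComputeEx path registered in this PR; other combinations fall back to the
# dense implementation.
axis_norm = csr.norm(ord=2, axis=0, keepdims=False)
expected = csr.tostype('default').norm(ord=2, axis=0, keepdims=False)
np.testing.assert_almost_equal(axis_norm.asnumpy(), expected.asnumpy(), decimal=5)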


 
