Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/01/21 00:04:00 UTC

[GitHub] szha closed pull request #9479: add norm operator for sparse ndarray

URL: https://github.com/apache/incubator-mxnet/pull/9479
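
In user-facing terms, the change makes the existing norm operator work directly on CSRNDArray and RowSparseNDArray inputs, reducing only the stored values instead of densifying the array first. A minimal usage sketch, assuming an MXNet build that includes this change:

    import mxnet as mx

    x = mx.nd.array([[1, 2], [3, 4]])
    csr = x.tostype('csr')
    rsp = x.tostype('row_sparse')

    print(x.norm().asscalar())    # ~5.4772258, dense path
    print(csr.norm().asscalar())  # same value, computed via the sparse (FComputeEx) path
    print(rsp.norm().asscalar())  # same value, computed via the sparse (FComputeEx) path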
 
 
   

This is a PR merged from a forked repository. Since GitHub hides the original diff of a fork-based pull request once it is merged, the diff is reproduced below for the sake of provenance:

diff --git a/docs/api/python/ndarray/sparse.md b/docs/api/python/ndarray/sparse.md
index 3e6021e3a3..5c8db0c9c1 100644
--- a/docs/api/python/ndarray/sparse.md
+++ b/docs/api/python/ndarray/sparse.md
@@ -157,6 +157,7 @@ We summarize the interface for each class in the following sections.
 
     CSRNDArray.sum
     CSRNDArray.mean
+    CSRNDArray.norm
 ```
 
 ### Powers
@@ -237,6 +238,15 @@ We summarize the interface for each class in the following sections.
     RowSparseNDArray.zeros_like
 ```
 
+### Array reduction
+
+```eval_rst
+.. autosummary::
+    :nosignatures:
+
+    RowSparseNDArray.norm
+```
+
 ### Array rounding
 
 ```eval_rst
@@ -414,6 +424,7 @@ We summarize the interface for each class in the following sections.
 
     sum
     mean
+    norm
 ```
 
 ### Rounding
@@ -492,10 +503,10 @@ We summarize the interface for each class in the following sections.
 ```eval_rst
 
 .. autoclass:: mxnet.ndarray.sparse.CSRNDArray
-    :members: shape, context, dtype, stype, data, indices, indptr, copy, copyto, as_in_context, asscipy, asnumpy, asscalar, astype, tostype, slice, wait_to_read, zeros_like, __neg__, sum, mean, square, __getitem__, __setitem__, check_format
+    :members: shape, context, dtype, stype, data, indices, indptr, copy, copyto, as_in_context, asscipy, asnumpy, asscalar, astype, tostype, slice, wait_to_read, zeros_like, __neg__, sum, mean, norm, square, __getitem__, __setitem__, check_format
 
 .. autoclass:: mxnet.ndarray.sparse.RowSparseNDArray
-    :members: shape, context, dtype, stype, data, indices, copy, copyto, as_in_context, asnumpy, asscalar, astype, tostype, wait_to_read, zeros_like, round, rint, fix, floor, ceil, trunc, sin, tan, arcsin, arctan, degrees, radians, sinh, tanh, arcsinh, arctanh, expm1, log1p, sqrt, square, __negative__, __getitem__, __setitem__, check_format, retain, clip, sign
+    :members: shape, context, dtype, stype, data, indices, copy, copyto, as_in_context, asnumpy, asscalar, astype, tostype, wait_to_read, zeros_like, round, rint, fix, floor, ceil, trunc, sin, tan, arcsin, arctan, degrees, radians, sinh, tanh, arcsinh, arctanh, expm1, log1p, sqrt, square, __negative__, norm, __getitem__, __setitem__, check_format, retain, clip, sign
 
 .. automodule:: mxnet.ndarray.sparse
     :members:
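
With the documentation entries above, norm appears both as a method on CSRNDArray and RowSparseNDArray and as a function in the mxnet.ndarray.sparse namespace (via the sparse alias registered in the operator change further below). A brief sketch of the two forms, assuming the alias is exposed as mx.nd.sparse.norm:

    import mxnet as mx

    csr = mx.nd.array([[1, 2], [3, 4]]).tostype('csr')
    method_result = csr.norm()            # CSRNDArray.norm
    func_result = mx.nd.sparse.norm(csr)  # namespace-level form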
diff --git a/src/operator/tensor/broadcast_reduce_op.h b/src/operator/tensor/broadcast_reduce_op.h
index 2ae409f2ff..76c78be35a 100644
--- a/src/operator/tensor/broadcast_reduce_op.h
+++ b/src/operator/tensor/broadcast_reduce_op.h
@@ -816,6 +816,53 @@ struct ReduceGrad {
   }
 };
 
+inline bool L2NormStorageType(const nnvm::NodeAttrs& attrs,
+                              const int dev_mask,
+                              DispatchMode* dispatch_mode,
+                              std::vector<int>* in_attrs,
+                              std::vector<int>* out_attrs) {
+  CHECK_EQ(in_attrs->size(), 1U);
+  CHECK_EQ(out_attrs->size(), 1U);
+  const int in_stype = in_attrs->at(0);
+  int& out_stype = out_attrs->at(0);
+  bool dispatched = false;
+  if (!dispatched && in_stype == kDefaultStorage) {
+    // dns -> dns
+    dispatched = storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode,
+                                     DispatchMode::kFCompute);
+  }
+  if (!dispatched && (in_stype == kCSRStorage || in_stype == kRowSparseStorage)) {
+    // csr/rsp -> dns
+    dispatched = storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode,
+                                     DispatchMode::kFComputeEx);
+  }
+  if (!dispatched) {
+    dispatched = dispatch_fallback(out_attrs, dispatch_mode);
+  }
+  return dispatched;
+}
+
+template<typename xpu>
+void L2NormComputeImpl(mshadow::Stream<xpu> *s,
+                       const TBlob& input,
+                       const OpReqType req,
+                       const TBlob& output) {
+  if (req == kNullOp) return;
+  MSHADOW_REAL_TYPE_SWITCH(output.type_flag_, DType, {
+    MXNET_ASSIGN_REQ_SWITCH(req, Req, {
+      mshadow::Tensor<xpu, 1, DType> out = output.get<xpu, 1, DType>(s);
+      mshadow::Tensor<xpu, 1, DType> in = input.get_with_shape<xpu, 1, DType>(
+        mshadow::Shape1(input.shape_.Size()), s);
+      mshadow::VectorDot(out, in, in);
+      DType* out_data = output.dptr<DType>();
+      using namespace mxnet_op;
+      Kernel<op_with_req<mshadow_op::square_root, Req>, xpu>::Launch(
+        s, output.Size(), out_data, out_data);
+    });
+  });
+}
+
+
 template<typename xpu>
 void L2NormCompute(const nnvm::NodeAttrs& attrs,
                    const OpContext& ctx,
@@ -823,13 +870,41 @@ void L2NormCompute(const nnvm::NodeAttrs& attrs,
                    const std::vector<OpReqType>& req,
                    const std::vector<TBlob>& outputs) {
   mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
-  MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
-    mshadow::Tensor<xpu, 1, DType> out = outputs[0].get<xpu, 1, DType>(s);
-    mshadow::Tensor<xpu, 1, DType> in = inputs[0].get_with_shape<xpu, 1, DType>(
-      mshadow::Shape1(inputs[0].shape_.Size()), s);
-    mshadow::VectorDot(out, in, in);
-    ASSIGN_DISPATCH(out, req[0], mshadow::expr::F<mxnet::op::mshadow_op::square_root>(out));
-  });
+  L2NormComputeImpl(s, inputs[0], req[0], outputs[0]);
+}
+
+template<typename xpu>
+void L2NormComputeSparseImpl(mshadow::Stream<xpu> *s,
+                             const NDArray& input,
+                             const OpReqType req,
+                             const TBlob& output) {
+  if (req == kNullOp) return;
+  // input is zeros
+  if (!input.storage_initialized()) {
+    // Add zeros. No op.
+    if (req == kAddTo) return;
+    Fill<false>(s, output, req, 0);
+  } else {
+    L2NormComputeImpl(s, input.data(), req, output);
+  }
+}
+
+template<typename xpu>
+void L2NormComputeEx(const nnvm::NodeAttrs& attrs,
+                     const OpContext& ctx,
+                     const std::vector<NDArray>& inputs,
+                     const std::vector<OpReqType>& req,
+                     const std::vector<NDArray>& outputs) {
+  CHECK_EQ(inputs.size(), 1U);
+  CHECK_EQ(outputs.size(), 1U);
+  CHECK_EQ(req.size(), 1U);
+  mshadow::Stream<xpu>* s = ctx.get_stream<xpu>();
+  const NDArrayStorageType in_stype = inputs[0].storage_type();
+  if (in_stype == kCSRStorage || in_stype == kRowSparseStorage) {
+    L2NormComputeSparseImpl(s, inputs[0], req[0], outputs[0].data());
+  } else {
+    LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
+  }
 }
 
 /*! \brief index element from array along axes */
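
The compute path above flattens the input values, takes the dot product of that vector with itself, and applies a square root; for CSR and row_sparse inputs only the stored values are reduced, which matches the dense result because zero entries contribute nothing to an L2 norm (an all-zero sparse input is handled separately by writing 0). A NumPy sketch of the same computation, for reference:

    import numpy as np

    dense = np.array([[1., 2.], [0., 0.], [3., 4.]])
    stored = dense[dense != 0]          # what the sparse path reduces
    assert np.isclose(np.sqrt(stored.dot(stored)),
                      np.linalg.norm(dense.ravel()))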
diff --git a/src/operator/tensor/broadcast_reduce_op_value.cc b/src/operator/tensor/broadcast_reduce_op_value.cc
index 70cf778a58..40624e54ab 100644
--- a/src/operator/tensor/broadcast_reduce_op_value.cc
+++ b/src/operator/tensor/broadcast_reduce_op_value.cc
@@ -245,6 +245,7 @@ NNVM_REGISTER_OP(_broadcast_backward)
   });
 
 NNVM_REGISTER_OP(norm)
+MXNET_ADD_SPARSE_OP_ALIAS(norm)
 .describe(R"code(Flattens the input array and then computes the l2 norm.
 
 Examples::
@@ -254,6 +255,14 @@ Examples::
 
   norm(x) = [5.47722578]
 
+  rsp = x.cast_storage('row_sparse')
+
+  norm(rsp) = [5.47722578]
+
+  csr = x.cast_storage('csr')
+
+  norm(csr) = [5.47722578]
+
 )code" ADD_FILELINE)
 .set_num_inputs(1)
 .set_num_outputs(1)
@@ -268,7 +277,9 @@ Examples::
     return true;
   })
 .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
+.set_attr<FInferStorageType>("FInferStorageType", L2NormStorageType)
 .set_attr<FCompute>("FCompute<cpu>", L2NormCompute<cpu>)
+.set_attr<FComputeEx>("FComputeEx<cpu>", L2NormComputeEx<cpu>)
 .add_argument("data", "NDArray-or-Symbol", "Source input");
 
 }  // namespace op
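
The value shown in the docstring examples above is the flattened L2 norm of x = [[1, 2], [3, 4]]:

    norm(x) = sqrt(1^2 + 2^2 + 3^2 + 4^2) = sqrt(30) ≈ 5.4772256  (printed as 5.47722578 at float32 precision)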
diff --git a/src/operator/tensor/broadcast_reduce_op_value.cu b/src/operator/tensor/broadcast_reduce_op_value.cu
index 73c32f09cc..5fd7cbfc89 100644
--- a/src/operator/tensor/broadcast_reduce_op_value.cu
+++ b/src/operator/tensor/broadcast_reduce_op_value.cu
@@ -78,7 +78,8 @@ NNVM_REGISTER_OP(_broadcast_backward)
 .set_attr<FCompute>("FCompute<gpu>", ReduceAxesCompute<gpu, mshadow::red::sum>);
 
 NNVM_REGISTER_OP(norm)
-.set_attr<FCompute>("FCompute<gpu>", L2NormCompute<gpu>);
+.set_attr<FCompute>("FCompute<gpu>", L2NormCompute<gpu>)
+.set_attr<FComputeEx>("FComputeEx<gpu>", L2NormComputeEx<gpu>);
 
 }  // namespace op
 }  // namespace mxnet
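
The GPU registration mirrors the CPU one, so the same calls work on a CUDA context; a sketch, assuming a CUDA-enabled MXNet build that includes this change:

    import mxnet as mx

    csr_gpu = mx.nd.array([[1, 2], [3, 4]], ctx=mx.gpu(0)).tostype('csr')
    print(csr_gpu.norm().asscalar())   # dispatched to L2NormComputeEx<gpu>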
diff --git a/tests/python/unittest/test_sparse_ndarray.py b/tests/python/unittest/test_sparse_ndarray.py
index 185ce7f4de..ab389b6d03 100644
--- a/tests/python/unittest/test_sparse_ndarray.py
+++ b/tests/python/unittest/test_sparse_ndarray.py
@@ -829,6 +829,19 @@ def test_sparse_nd_check_format():
     a = mx.nd.sparse.row_sparse_array((data_list, indices_list), shape=shape)
     assertRaises(mx.base.MXNetError, a.check_format)
 
+def test_sparse_nd_norm():
+    def check_sparse_nd_norm(stype, shape, density):
+        data, _ = rand_sparse_ndarray(shape, stype, density)
+        norm = data.norm()
+        expected_norm = np.linalg.norm(data.asnumpy())
+        assert_almost_equal(norm.asnumpy(), expected_norm)
+
+    shape = (5, 5)
+    stypes = ['row_sparse', 'csr']
+    densities = [0, 0.5]
+    for stype in stypes:
+        for density in densities:
+            check_sparse_nd_norm(stype, shape, density)
 
 if __name__ == '__main__':
     import nose
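
The test above compares the sparse norm against numpy.linalg.norm on the dense view for both storage types, including density 0, which exercises the uninitialized-storage branch that writes a zero output. Assuming a source checkout with nose installed, it can be run on its own with something like nosetests tests/python/unittest/test_sparse_ndarray.py:test_sparse_nd_norm.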


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services