Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2017/11/08 00:08:31 UTC

[GitHub] piiswrong closed pull request #8264: Operators for mean(csr, axis=0) and mean(csr, axis=1)

URL: https://github.com/apache/incubator-mxnet/pull/8264

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:


diff --git a/src/operator/tensor/broadcast_reduce_op.h b/src/operator/tensor/broadcast_reduce_op.h
index 75f96f9447..8e8b0a1fbb 100644
--- a/src/operator/tensor/broadcast_reduce_op.h
+++ b/src/operator/tensor/broadcast_reduce_op.h
@@ -566,14 +566,15 @@ struct SumCsrKernel<req, 1> {
   }
 };
 
-template <typename xpu>
+/*! \brief If normalize is true, the mean should be computed instead of sum */
+template <typename xpu, bool normalize = false>
 void SumCsrImpl(const nnvm::NodeAttrs& attrs, mshadow::Stream<xpu>* s, const OpContext& ctx,
                 const NDArray& input, const OpReqType req, NDArray* output) {
   if (req == kNullOp) return;
   const ReduceAxesParam& param = nnvm::get<ReduceAxesParam>(attrs.parsed);
-  CHECK_EQ(param.axis.ndim(), 1U) << "sum(csr) only supports axis 0 or 1";
+  CHECK_EQ(param.axis.ndim(), 1U) << "sum(csr)/mean(csr) only supports axis 0 or 1";
   CHECK(param.axis[0] == 0 || param.axis[0] == 1)
-      << "sum(csr) only support axis 0 or 1";
+      << "sum(csr)/mean(csr) only support axis 0 or 1";
   CHECK(!param.keepdims) << "keepdims not supported for sparse";
   CHECK(!param.exclude) << "exclude not supported for sparse";
   int64_t out_data_size = 0;
@@ -586,6 +587,7 @@ void SumCsrImpl(const nnvm::NodeAttrs& attrs, mshadow::Stream<xpu>* s, const OpC
   CHECK_EQ(output->storage_type(), kDefaultStorage);
 
   using namespace mshadow;
+  using namespace mshadow::expr;
   using namespace mxnet_op;
   using namespace csr;
   using nnvm::dim_t;
@@ -630,19 +632,34 @@ void SumCsrImpl(const nnvm::NodeAttrs& attrs, mshadow::Stream<xpu>* s, const OpC
                 s, num_threads, output->data().dptr<DType>(), in_indptr, in_idx,
                 in_data, sum.dptr_, residual.dptr_, num_rows, num_cols,
                 seg_len);
+            if (normalize) {
+              mxnet_op::Kernel<
+                  mxnet_op::op_with_req<mshadow::op::div, req_type>,
+                  xpu>::Launch(s, out_data_size, output->data().dptr<DType>(),
+                               output->data().dptr<DType>(), DType(num_rows));
+            }
           });
         });
       });
     });
   } else if (1 == param.axis[0]) {
     MSHADOW_IDX_TYPE_SWITCH(input.aux_type(kIndPtr), RType, {
-      MSHADOW_TYPE_SWITCH(input.dtype(), DType, {
-        MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
-          const RType* in_indptr = input.aux_data(kIndPtr).dptr<RType>();
-          const DType* in_data = input.data().dptr<DType>();
-          Kernel<SumCsrKernel<req_type, 1>, xpu>::Launch(
-              s, out_data_size, output->data().dptr<DType>(), in_indptr,
-              in_data);
+      MSHADOW_IDX_TYPE_SWITCH(input.aux_type(kIdx), IType, {
+        MSHADOW_TYPE_SWITCH(input.dtype(), DType, {
+          MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
+            const RType* in_indptr = input.aux_data(kIndPtr).dptr<RType>();
+            const DType* in_data = input.data().dptr<DType>();
+            const IType num_cols = input.shape()[1];
+            Kernel<SumCsrKernel<req_type, 1>, xpu>::Launch(
+                s, out_data_size, output->data().dptr<DType>(), in_indptr,
+                in_data);
+            if (normalize) {
+              mxnet_op::Kernel<
+                  mxnet_op::op_with_req<mshadow::op::div, req_type>,
+                  xpu>::Launch(s, out_data_size, output->data().dptr<DType>(),
+                               output->data().dptr<DType>(), DType(num_cols));
+            }
+          });
         });
       });
     });
@@ -661,9 +678,9 @@ void SumOpForwardEx(const nnvm::NodeAttrs& attrs, const OpContext& ctx,
   const NDArrayStorageType istype = inputs[0].storage_type();
   if (istype == kCSRStorage) {
     CHECK_EQ(inputs[0].shape().ndim(), 2U)
-        << "sum(csr) op only supports 2D ndarray as input";
+        << "sum(csr)/mean(csr) op only supports 2D ndarray as input";
     NDArray output = outputs[0];
-    SumCsrImpl(attrs, s, ctx, inputs[0], req[0], &output);
+    SumCsrImpl<xpu, normalize>(attrs, s, ctx, inputs[0], req[0], &output);
   } else {
     LOG(FATAL) << "Not implemented: "
                << operator_string(attrs, ctx, inputs, req, outputs);
diff --git a/src/operator/tensor/broadcast_reduce_op_value.cc b/src/operator/tensor/broadcast_reduce_op_value.cc
index c3644042aa..0d376c31e7 100644
--- a/src/operator/tensor/broadcast_reduce_op_value.cc
+++ b/src/operator/tensor/broadcast_reduce_op_value.cc
@@ -96,8 +96,11 @@ MXNET_OPERATOR_REGISTER_REDUCE_BACKWARD(_backward_sum)
 .set_attr<FCompute>("FCompute<cpu>", ReduceAxesBackwardUseNone<cpu>);
 
 MXNET_OPERATOR_REGISTER_REDUCE(mean)
+MXNET_ADD_SPARSE_OP_ALIAS(mean)
 .describe(get_reduce_axes_description("mean", __LINE__))
 .set_attr<FCompute>("FCompute<cpu>", ReduceAxesCompute<cpu, mshadow::red::sum, true>)
+.set_attr<FComputeEx>("FComputeEx<cpu>", SumOpForwardEx<cpu, mshadow::red::sum, true>)
+.set_attr<FInferStorageType>("FInferStorageType", SumOpForwardInferStorageType)
 .set_attr<FResourceRequest>("FResourceRequest",
   [](const NodeAttrs& attrs) {
     return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
diff --git a/tests/python/unittest/test_sparse_operator.py b/tests/python/unittest/test_sparse_operator.py
index e6af4022f4..0269bae7e8 100644
--- a/tests/python/unittest/test_sparse_operator.py
+++ b/tests/python/unittest/test_sparse_operator.py
@@ -1304,8 +1304,8 @@ def check_sparse_nd_zeros_like(stype, shape):
     check_sparse_nd_zeros_like('row_sparse', shape)
     check_sparse_nd_zeros_like('csr', shape)
 
-def test_sparse_sum_axis():
-    def test_variations():
+def test_sparse_axis_operations():
+    def test_variations(func_name):
         dim0 = 30
         dim1 = 100
         axes = [0, 1]
@@ -1315,21 +1315,23 @@ def test_variations():
             csr_array = rand_ndarray(shape=shape, stype='csr', density=density)
             dns = csr_array.tostype('default')
             for axis in axes:
-                ret = mx.nd.sum(csr_array, axis=axis)
+                ret = func_name(csr_array, axis=axis)
                 assert ret.stype == 'default'
-                ret_expected = mx.nd.sum(dns, axis=axis)
+                ret_expected = func_name(dns, axis=axis)
                 assert_almost_equal(ret.asnumpy(), ret_expected.asnumpy())
 
-    def test_fallback(axis=0, keepdims=True, exclude=True):
+    def test_fallback(func_name, axis=0, keepdims=True, exclude=True):
         dim0 = 30
         dim1 = 100
         shape = rand_shape_2d(dim0, dim1)
         csr_array = rand_ndarray(shape=shape, stype='csr', density=0.01)
-        ret = mx.nd.sum(csr_array, axis=axis, keepdims=keepdims,
-                        exclude=exclude)
+        ret = func_name(csr_array, axis=axis, keepdims=keepdims,
+                        exclude=exclude)
 
-    test_variations()
-    test_fallback(axis=0, keepdims=True, exclude=True)
+    test_variations(mx.nd.sum)
+    test_fallback(mx.nd.sum, axis=0, keepdims=True, exclude=True)
+    test_variations(mx.nd.mean)
+    test_fallback(mx.nd.mean, axis=0, keepdims=True, exclude=True)
 
 def test_sparse_square_sum():
     if default_context().device_type == 'cpu':
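
In short, the patch reuses the existing sum(csr) kernels and, when the
new `normalize` template parameter is true, divides the result by the
length of the reduced axis: num_rows for axis=0 and num_cols for axis=1.
Below is a minimal sketch of that reduction in plain NumPy/SciPy (not
MXNet's kernel code), with scipy.sparse standing in for MXNet's CSR
NDArray:

# A minimal sketch (not MXNet's kernel code) of the reduction the patch
# implements, using SciPy's CSR format as a stand-in for MXNet's CSR NDArray.
import numpy as np
import scipy.sparse as sp

csr = sp.random(30, 100, density=0.2, format='csr')
dense = csr.toarray()

# axis=0: per-column sums over all rows, normalized by num_rows.
mean_axis0 = np.asarray(csr.sum(axis=0)).ravel() / csr.shape[0]
assert np.allclose(mean_axis0, dense.mean(axis=0))

# axis=1: per-row sums (each row is the indptr[i]:indptr[i+1] span of
# the data array), normalized by num_cols.
mean_axis1 = np.asarray(csr.sum(axis=1)).ravel() / csr.shape[1]
assert np.allclose(mean_axis1, dense.mean(axis=1))

With the patch applied, mx.nd.mean(csr_array, axis=0) and
mx.nd.mean(csr_array, axis=1) return a dense NDArray, mirroring
mx.nd.sum; argument combinations the sparse path rejects (keepdims or
exclude, per the CHECKs above) are expected to fall back to the dense
implementation, which is what test_fallback exercises.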


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services