You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/03/26 23:32:13 UTC

[GitHub] eric-haibin-lin closed pull request #10081: [MXNET-82] Sparse op tutorial for developers

eric-haibin-lin closed pull request #10081: [MXNET-82] Sparse op tutorial for developers
URL: https://github.com/apache/incubator-mxnet/pull/10081
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/docs/faq/index.md b/docs/faq/index.md
index 099cd509b14..098d37f5fc0 100644
--- a/docs/faq/index.md
+++ b/docs/faq/index.md
@@ -56,6 +56,8 @@ and full working examples, visit the [tutorials section](../tutorials/index.md).
 
 * [How do I create new operators in MXNet?](http://mxnet.io/faq/new_op.html)
 
+* [How do I implement sparse operators in the MXNet backend?](https://cwiki.apache.org/confluence/display/MXNET/A+Guide+to+Implementing+Sparse+Operators+in+MXNet+Backend)
+
 * [How do I contribute an example or tutorial?](https://github.com/apache/incubator-mxnet/tree/master/example#contributing)
 
 * [How do I set MXNet's environmental variables?](http://mxnet.io/faq/env_var.html)
diff --git a/src/operator/contrib/quadratic_op-inl.h b/src/operator/contrib/quadratic_op-inl.h
index 8d73a4286f6..fe477811b06 100644
--- a/src/operator/contrib/quadratic_op-inl.h
+++ b/src/operator/contrib/quadratic_op-inl.h
@@ -32,6 +32,7 @@
 #include "../mxnet_op.h"
 #include "../operator_common.h"
 #include "../elemwise_op_common.h"
+#include "../tensor/init_op.h"
 
 namespace mxnet {
 namespace op {
@@ -73,6 +74,33 @@ inline bool QuadraticOpType(const nnvm::NodeAttrs& attrs,
   return out_attrs->at(0) != -1;
 }
 
+inline bool QuadraticOpStorageType(const nnvm::NodeAttrs& attrs,  // storage-type inference
+                                   const int dev_mask,  // not used by this function
+                                   DispatchMode* dispatch_mode,  // output: selected dispatch mode
+                                   std::vector<int>* in_attrs,  // [data] storage type
+                                   std::vector<int>* out_attrs) {  // [output] storage type
+  CHECK_EQ(in_attrs->size(), 1U);
+  CHECK_EQ(out_attrs->size(), 1U);
+  const QuadraticParam& param = nnvm::get<QuadraticParam>(attrs.parsed);
+  const int in_stype = in_attrs->at(0);
+  int& out_stype = out_attrs->at(0);
+  bool dispatched = false;  // rules below are tried in order; the first match wins
+  if (!dispatched && in_stype == kDefaultStorage) {
+    // dns -> dns
+    dispatched = storage_type_assign(&out_stype, kDefaultStorage,
+                                     dispatch_mode, DispatchMode::kFCompute);
+  }
+  if (!dispatched && in_stype == kCSRStorage && param.c == 0.0) {  // zeros stay zero iff c == 0
+    // csr -> csr
+    dispatched = storage_type_assign(&out_stype, kCSRStorage,
+                                     dispatch_mode, DispatchMode::kFComputeEx);
+  }
+  if (!dispatched) {  // e.g. csr input with c != 0
+    dispatched = dispatch_fallback(out_attrs, dispatch_mode);
+  }
+  return dispatched;
+}
+
 template<int req>
 struct quadratic_forward {
   template<typename DType>
@@ -114,6 +142,61 @@ void QuadraticOpForward(const nnvm::NodeAttrs& attrs,
   });
 }
 
+template<typename xpu>
+void QuadraticOpForwardCsrImpl(const QuadraticParam& param,  // csr -> csr forward pass
+                               const OpContext& ctx,
+                               const NDArray& input,
+                               const OpReqType req,
+                               const NDArray& output) {
+  using namespace mshadow;
+  using namespace mxnet_op;
+  using namespace csr;
+  if (req == kNullOp) return;
+  CHECK_EQ(req, kWriteTo) << "QuadraticOp with CSR only supports kWriteTo";
+  Stream<xpu> *s = ctx.get_stream<xpu>();
+  if (!input.storage_initialized()) {  // uninitialized input -> all-zero csr output
+    FillZerosCsrImpl(s, output);
+    return;
+  }
+  const nnvm::dim_t nnz = input.storage_shape()[0];  // number of stored values
+  const nnvm::dim_t num_rows = output.shape()[0];
+  output.CheckAndAlloc({Shape1(num_rows + 1), Shape1(nnz)});  // indptr: rows + 1, idx: nnz
+  CHECK_EQ(output.aux_type(kIdx), output.aux_type(kIndPtr))
+    << "The dtypes of indices and indptr don't match";  // NOTE(review): input aux dtypes unchecked
+  MSHADOW_TYPE_SWITCH(output.dtype(), DType, {
+    MSHADOW_IDX_TYPE_SWITCH(output.aux_type(kIdx), IType, {
+      MXNET_ASSIGN_REQ_SWITCH(req, req_type, {  // req is kWriteTo here (checked above)
+        Kernel<quadratic_forward<req_type>, xpu>::Launch(  // f applied to stored values only
+            s, nnz, output.data().dptr<DType>(), input.data().dptr<DType>(),
+            param.a, param.b, param.c);
+        Copy(output.aux_data(kIdx).FlatTo1D<xpu, IType>(),  // same sparsity pattern as input
+             input.aux_data(kIdx).FlatTo1D<xpu, IType>(), s);
+        Copy(output.aux_data(kIndPtr).FlatTo1D<xpu, IType>(),
+             input.aux_data(kIndPtr).FlatTo1D<xpu, IType>(), s);
+      });
+    });
+  });
+}
+
+template<typename xpu>
+void QuadraticOpForwardEx(const nnvm::NodeAttrs& attrs,  // FComputeEx forward for csr input
+                          const OpContext& ctx,
+                          const std::vector<NDArray>& inputs,
+                          const std::vector<OpReqType>& req,
+                          const std::vector<NDArray>& outputs) {
+  CHECK_EQ(inputs.size(), 1U);
+  CHECK_EQ(outputs.size(), 1U);
+  CHECK_EQ(req.size(), 1U);
+  const QuadraticParam& param = nnvm::get<QuadraticParam>(attrs.parsed);
+  const auto in_stype = inputs[0].storage_type();
+  const auto out_stype = outputs[0].storage_type();
+  if (in_stype == kCSRStorage && out_stype == kCSRStorage && param.c == 0.0) {  // csr -> csr
+    QuadraticOpForwardCsrImpl<xpu>(param, ctx, inputs[0], req[0], outputs[0]);
+  } else {
+    LogUnimplementedOp(attrs, ctx, inputs, req, outputs);  // no sparse kernel for other cases
+  }
+}
+
 template<typename xpu>
 void QuadraticOpBackward(const nnvm::NodeAttrs& attrs,
                          const OpContext& ctx,
diff --git a/src/operator/contrib/quadratic_op.cc b/src/operator/contrib/quadratic_op.cc
index 5b2d84cfcba..d8b2d785c79 100644
--- a/src/operator/contrib/quadratic_op.cc
+++ b/src/operator/contrib/quadratic_op.cc
@@ -38,6 +38,12 @@ Example::
   x = [[1, 2], [3, 4]]
   y = quadratic(data=x, a=1, b=2, c=3)
   y = [[6, 11], [18, 27]]
+
+The storage type of ``quadratic`` output depends on the storage type of the input
+  - quadratic(csr, a, b, 0) = csr
+  - quadratic(csr, a, b, c) = default, if c is non-zero
+  - quadratic(default, a, b, c) = default
+
 )code" ADD_FILELINE)
 .set_attr_parser(ParamParser<QuadraticParam>)
 .set_num_inputs(1)
@@ -48,6 +54,8 @@ Example::
   })
 .set_attr<nnvm::FInferShape>("FInferShape", QuadraticOpShape)
 .set_attr<nnvm::FInferType>("FInferType", QuadraticOpType)
+.set_attr<FInferStorageType>("FInferStorageType", QuadraticOpStorageType)
+.set_attr<FComputeEx>("FComputeEx<cpu>", QuadraticOpForwardEx<cpu>)
 .set_attr<FCompute>("FCompute<cpu>", QuadraticOpForward<cpu>)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_contrib_backward_quadratic"})
 .set_attr<nnvm::FInplaceOption>("FInplaceOption",
diff --git a/src/operator/contrib/quadratic_op.cu b/src/operator/contrib/quadratic_op.cu
index ede773a7ea3..72d15ab3749 100644
--- a/src/operator/contrib/quadratic_op.cu
+++ b/src/operator/contrib/quadratic_op.cu
@@ -27,6 +27,7 @@ namespace mxnet {
 namespace op {
 
 NNVM_REGISTER_OP(_contrib_quadratic)
+.set_attr<FComputeEx>("FComputeEx<gpu>", QuadraticOpForwardEx<gpu>)  // csr -> csr forward
 .set_attr<FCompute>("FCompute<gpu>", QuadraticOpForward<gpu>);
 
 NNVM_REGISTER_OP(_contrib_backward_quadratic)
diff --git a/tests/python/unittest/test_sparse_operator.py b/tests/python/unittest/test_sparse_operator.py
index a8bf5a5ed3c..eea801eb369 100644
--- a/tests/python/unittest/test_sparse_operator.py
+++ b/tests/python/unittest/test_sparse_operator.py
@@ -1918,6 +1918,28 @@ def test_where_numeric_gradient(shape):
     test_where_helper((5, 9))
     test_where_numeric_gradient((5, 9))
 
+
+@with_seed()
+def test_sparse_quadratic_function():
+    def f(x, a, b, c):
+        return a * x**2 + b * x + c
+
+    def check_sparse_quadratic_function(a, b, c, expected_stype):
+        # check forward and compare the result with a NumPy reference
+        ndim = 2
+        shape = rand_shape_nd(ndim, 5)
+        data = rand_ndarray(shape=shape, stype='csr')
+        data_np = data.asnumpy()
+        expected = f(data_np, a, b, c)
+        output = mx.nd.contrib.quadratic(data, a=a, b=b, c=c)
+        assert output.stype == expected_stype
+        assert_almost_equal(output.asnumpy(), expected)
+
+    a = np.random.random_sample()
+    b = np.random.random_sample()
+    check_sparse_quadratic_function(a, b, 0.0, 'csr')
+    check_sparse_quadratic_function(a, b, 1.0, 'default')
+
 
 if __name__ == '__main__':
     import nose


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services