Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2017/12/15 01:43:13 UTC

[GitHub] eric-haibin-lin closed pull request #9014: square sum gpu impl

URL: https://github.com/apache/incubator-mxnet/pull/9014
 
 
   

This is a PR merged from a forked repository. Because GitHub hides the original
diff once such a pull request is merged, the full diff is reproduced below for
the sake of provenance:

diff --git a/src/operator/tensor/square_sum-inl.h b/src/operator/tensor/square_sum-inl.h
index a052ad96cf..fcc0215c12 100644
--- a/src/operator/tensor/square_sum-inl.h
+++ b/src/operator/tensor/square_sum-inl.h
@@ -53,18 +53,15 @@ inline bool SquareSumForwardInferStorageType(const nnvm::NodeAttrs& attrs,
   const auto& in_stype = in_attrs->at(0);
   auto& out_stype = out_attrs->at(0);
   bool dispatched = false;
-  // current impl is only available on cpu
-  if (dev_mask == mshadow::cpu::kDevMask) {
-    if (!dispatched && in_stype == kRowSparseStorage && param.axis[0] == 1 && param.keepdims) {
-      // sum per row and keep dims
-      dispatched = storage_type_assign(&out_stype, kRowSparseStorage,
+  if (!dispatched && in_stype == kRowSparseStorage && param.axis[0] == 1 && param.keepdims) {
+    // sum per row and keep dims
+    dispatched = storage_type_assign(&out_stype, kRowSparseStorage,
+                                     dispatch_mode, DispatchMode::kFComputeEx);
+  }
+  if (!dispatched && in_stype == kRowSparseStorage &&
+      (param.axis[0] == 0 || (param.axis[0] == 1 && !param.keepdims))) {
+      dispatched = storage_type_assign(&out_stype, kDefaultStorage,
                                        dispatch_mode, DispatchMode::kFComputeEx);
-    }
-    if (!dispatched && in_stype == kRowSparseStorage &&
-        (param.axis[0] == 0 || (param.axis[0] == 1 && !param.keepdims))) {
-        dispatched = storage_type_assign(&out_stype, kDefaultStorage,
-                                         dispatch_mode, DispatchMode::kFComputeEx);
-    }
   }
   if (!dispatched) {
     // nothing to fallback on
@@ -86,13 +83,10 @@ inline bool SquareSumBackwardInferStorageType(const nnvm::NodeAttrs& attrs,
   const auto& in_stype = in_attrs->at(1);
   auto& grad_stype = out_attrs->at(0);
   bool dispatched = false;
-  // only implemented on cpu
-  if (dev_mask == mshadow::cpu::kDevMask) {
-    if (!dispatched && (ograd_stype == kDefaultStorage || ograd_stype == kRowSparseStorage) &&
-        in_stype == kRowSparseStorage) {
-      dispatched = storage_type_assign(&grad_stype, kRowSparseStorage,
-                                       dispatch_mode, DispatchMode::kFComputeEx);
-    }
+  if (!dispatched && (ograd_stype == kDefaultStorage || ograd_stype == kRowSparseStorage) &&
+      in_stype == kRowSparseStorage) {
+    dispatched = storage_type_assign(&grad_stype, kRowSparseStorage,
+                                     dispatch_mode, DispatchMode::kFComputeEx);
   }
   if (!dispatched) {
     // nothing to fallback on
@@ -359,6 +353,25 @@ void SquareSumRspImpl(const nnvm::NodeAttrs& attrs,
   }
 }
 
+/*!
+ * \brief Check that the indices of ograd and input are the same.
+ */
+struct CheckSameIdxKernel {
+  template<typename IType>
+  MSHADOW_XINLINE static void Map(int i, IType* ograd_idx,
+                                  IType* in_idx, int32_t* is_diff) {
+    if (ograd_idx[i] != in_idx[i]) {
+      *is_diff = 1;
+    }
+  }
+};
+
+
+template<typename xpu>
+void CheckSameIdx(const OpContext& ctx,
+                  const TBlob& ograd_row_idx,
+                  const TBlob& in_row_idx);
+
 /*!\brief
  * This function only supports the following three situations:
  * 1. ograd is a dns and input is an rsp
@@ -367,7 +380,7 @@ void SquareSumRspImpl(const nnvm::NodeAttrs& attrs,
  */
 template<typename xpu>
 void SquareSumRspGradImpl(const nnvm::NodeAttrs& attrs,
-                          mshadow::Stream<xpu>* s,
+                          const OpContext& ctx,
                           const NDArray& ograd,
                           const NDArray& input,
                           const OpReqType req,
@@ -381,6 +394,7 @@ void SquareSumRspGradImpl(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(input.storage_type(), kRowSparseStorage);
   CHECK_EQ(igrad->storage_type(), kRowSparseStorage);
   CHECK_EQ(req, kWriteTo);
+  mshadow::Stream<xpu>* s = ctx.get_stream<xpu>();
   if (!input.storage_initialized()
       || (ograd.storage_type() == kRowSparseStorage && !ograd.storage_initialized())) {
     FillZerosRspImpl(s, *igrad);
@@ -429,28 +443,16 @@ void SquareSumRspGradImpl(const nnvm::NodeAttrs& attrs,
     const TBlob& igrad_data = igrad->data();
     const TBlob igrad_row_idx = igrad->aux_data(rowsparse::kIdx);
     MSHADOW_IDX_TYPE_SWITCH(igrad_row_idx.type_flag_, IType, {
-      if (std::is_same<xpu, cpu>::value) {
-        // when ograd_row_idx and in_row_idx have the same size and input is not a full rsp
-        // ograd_row_idx and in_row_idx are expected to have the same elements
-        if (in_row_idx.Size() != input.shape()[0]) {  // if input data is not a full rsp
-          CHECK_EQ(ograd_row_idx.Size(), in_row_idx.Size()) << "SquareSumRspGradImpl only supports"
-                                                               " equal ograd_row_idx and"
-                                                               " input_row_idx when ograd and"
-                                                               " input are both row-sparse and"
-                                                               " input data is not a full"
-                                                               " row-sparse matrix";
-          const IType* first1 = ograd_row_idx.dptr<IType>();
-          const IType* last1 = first1 + ograd_row_idx.Size();
-          const IType* first2 = in_row_idx.dptr<IType>();
-          CHECK(std::equal(first1, last1, first2)) << "SquareSumRspGradImpl only supports"
-                                                      " equal ograd_row_idx and input_row_idx"
-                                                      " when ograd and input are both"
-                                                      " row-sparse and input data is not a full"
-                                                      " row-sparse matrix";
-        }
-      } else {
-        LOG(FATAL) << "SquareSumRspGradImpl has not implemented GPU version when"
-                      " ograd and input are both row-sparse";
+      // when ograd_row_idx and in_row_idx have the same size and input is not a full rsp
+      // ograd_row_idx and in_row_idx are expected to have the same elements
+      if (in_row_idx.Size() != input.shape()[0]) {  // if input data is not a full rsp
+        CHECK_EQ(ograd_row_idx.Size(), in_row_idx.Size()) << "SquareSumRspGradImpl only supports"
+                                                             " equal ograd_row_idx and"
+                                                             " input_row_idx when ograd and"
+                                                             " input are both row-sparse and"
+                                                             " input data is not a full"
+                                                             " row-sparse matrix";
+        CheckSameIdx<xpu>(ctx, ograd_row_idx, in_row_idx);
       }
       MSHADOW_TYPE_SWITCH(igrad_data.type_flag_, DType, {
         MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
@@ -504,7 +506,6 @@ void SquareSumOpBackwardEx(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(inputs.size(), 2U);
   CHECK_EQ(outputs.size(), 1U);
   CHECK_EQ(req.size(), 1U);
-  mshadow::Stream<xpu>* s = ctx.get_stream<xpu>();
   const NDArrayStorageType ograd_stype = inputs[0].storage_type();
   const NDArrayStorageType input_stype = inputs[1].storage_type();
   if (input_stype == kRowSparseStorage &&
@@ -512,7 +513,7 @@ void SquareSumOpBackwardEx(const nnvm::NodeAttrs& attrs,
     CHECK_EQ(inputs[1].shape().ndim(), 2U) << "_square_sum op only supports"
                                               " 2D ndarray as input";
     NDArray output = outputs[0];
-    SquareSumRspGradImpl(attrs, s, inputs[0], inputs[1], req[0], &output);
+    SquareSumRspGradImpl<xpu>(attrs, ctx, inputs[0], inputs[1], req[0], &output);
   } else {
     LOG(FATAL) << "Not implemented: " << operator_string(attrs, ctx, inputs, req, outputs);
   }
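
A note on the storage-type change above: with the dev_mask == mshadow::cpu::kDevMask
guard removed, the row-sparse dispatch applies on GPU as well as CPU. The sketch below
is illustrative only; it mirrors the stype assertions in the updated unit test at the
end of this diff, and the input values are made up:

    import numpy as np
    import mxnet as mx

    # A small 2-D matrix with some all-zero rows, cast to row_sparse storage.
    dense = np.zeros((4, 3), dtype=np.float32)
    dense[1] = [1.0, 2.0, 3.0]
    dense[3] = [4.0, 5.0, 6.0]
    rsp = mx.nd.array(dense).tostype('row_sparse')

    # axis=1 with keepdims=True keeps a row-sparse output (FComputeEx path).
    out = mx.nd._internal._square_sum(rsp, axis=1, keepdims=True)
    assert out.stype == 'row_sparse'

    # Any other axis/keepdims combination falls back to a dense output.
    out = mx.nd._internal._square_sum(rsp, axis=0, keepdims=False)
    assert out.stype == 'default'
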
diff --git a/src/operator/tensor/square_sum.cc b/src/operator/tensor/square_sum.cc
index e4b49d7f7f..af365bae05 100644
--- a/src/operator/tensor/square_sum.cc
+++ b/src/operator/tensor/square_sum.cc
@@ -25,6 +25,28 @@
 
 namespace mxnet {
 namespace op {
+
+template<>
+void CheckSameIdx<cpu>(const OpContext& ctx,
+                       const TBlob& ograd_row_idx,
+                       const TBlob& in_row_idx) {
+  MSHADOW_IDX_TYPE_SWITCH(ograd_row_idx.type_flag_, IType, {
+    mshadow::Stream<cpu>* s = ctx.get_stream<cpu>();
+    const IType* ograd_idx = ograd_row_idx.dptr<IType>();
+    const IType* in_idx = in_row_idx.dptr<IType>();
+    const nnvm::dim_t idx_size = ograd_row_idx.Size();
+    int32_t is_different = 0;
+    mxnet_op::Kernel<CheckSameIdxKernel, cpu>::Launch(s, idx_size,
+      ograd_idx, in_idx, &is_different);
+    CHECK_EQ(is_different, 0) << "SquareSumRspGradImpl only supports"
+                                 " equal ograd_row_idx and input_row_idx"
+                                 " when ograd and input are both"
+                                 " row-sparse and input data is not a full"
+                                 " row-sparse matrix";
+  })
+}
+
+
 MXNET_OPERATOR_REGISTER_REDUCE(_square_sum)
 .describe(R"code(Computes the square sum of array elements over a given axis
 for row-sparse matrix. This is a temporary solution for fusing ops square and
@@ -45,6 +67,10 @@ Example::
 
 MXNET_OPERATOR_REGISTER_REDUCE_BACKWARD(_backward_square_sum)
 .set_num_inputs(2)
+.set_attr<FResourceRequest>("FResourceRequest",
+  [](const NodeAttrs& attrs) {
+    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+  })
 .set_attr<FInferStorageType>("FInferStorageType", SquareSumBackwardInferStorageType)
 .set_attr<FComputeEx>("FComputeEx<cpu>", SquareSumOpBackwardEx<cpu>);
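
The CheckSameIdx<cpu> specialization above enforces the restriction stated in its error
message: when both ograd and the input are row-sparse and the input is not a full
row-sparse matrix, their row indices must be identical. That code path is reached when
the backward pass receives a row-sparse ograd, for example when driving an executor
directly as the unit test does. The following is a minimal illustrative sketch (shapes
and values are made up; it reuses the executor-based pattern from the test):

    import numpy as np
    import mxnet as mx

    ctx = mx.cpu()
    dense = np.random.uniform(size=(30, 30)).astype(np.float32)
    dense[::3] = 0.0                        # zero out rows so the rsp is not "full"
    rsp = mx.nd.array(dense, ctx=ctx).tostype('row_sparse')

    data = mx.sym.Variable('data', stype='row_sparse')
    net = mx.symbol._internal._square_sum(data, axis=1, keepdims=True)

    igrad = mx.nd.zeros(rsp.shape, ctx=ctx).tostype('row_sparse')
    exe = net.bind(ctx, args=[rsp], args_grad=[igrad])
    ograd = exe.forward(is_train=True)[0]   # row-sparse forward output
    exe.backward([ograd])                   # row-sparse ograd: the row-index check runs here
    assert igrad.stype == 'row_sparse'
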
 
diff --git a/src/operator/tensor/square_sum.cu b/src/operator/tensor/square_sum.cu
new file mode 100644
index 0000000000..0b40786dbd
--- /dev/null
+++ b/src/operator/tensor/square_sum.cu
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file square_sum.cu
+ * \brief GPU Implementation of square_sum op.
+ */
+#include "./square_sum-inl.h"
+
+namespace mxnet {
+namespace op {
+
+template<>
+void CheckSameIdx<gpu>(const OpContext& ctx,
+                       const TBlob& ograd_row_idx,
+                       const TBlob& in_row_idx) {
+  MSHADOW_IDX_TYPE_SWITCH(ograd_row_idx.type_flag_, IType, {
+    mshadow::Stream<gpu>* s = ctx.get_stream<gpu>();
+    const IType* ograd_idx = ograd_row_idx.dptr<IType>();
+    const IType* in_idx = in_row_idx.dptr<IType>();
+    const nnvm::dim_t idx_size = ograd_row_idx.Size();
+    int32_t is_diff = 0;
+    mshadow::Tensor<gpu, 1, char> workspace = ctx.requested[0]
+        .get_space_typed<gpu, 1, char>(mshadow::Shape1(sizeof(int32_t)), s);
+    int32_t* is_diff_ptr = reinterpret_cast<int32_t*>(workspace.dptr_);
+    mxnet_op::Kernel<mxnet_op::set_zero, gpu>::Launch(s, 1, is_diff_ptr);
+    mxnet_op::Kernel<CheckSameIdxKernel, gpu>::Launch(s, idx_size,
+      ograd_idx, in_idx, is_diff_ptr);
+    CUDA_CALL(cudaMemcpy(&is_diff, is_diff_ptr, sizeof(int32_t), cudaMemcpyDeviceToHost));
+    CHECK_EQ(is_diff, 0) << "SquareSumRspGradImpl only supports"
+                            " equal ograd_row_idx and input_row_idx"
+                            " when ograd and input are both"
+                            " row-sparse and input data is not a full"
+                            " row-sparse matrix";
+  })
+}
+
+
+NNVM_REGISTER_OP(_square_sum)
+.set_attr<FComputeEx>("FComputeEx<gpu>", SquareSumOpForwardEx<gpu>);
+
+NNVM_REGISTER_OP(_backward_square_sum)
+.set_attr<FComputeEx>("FComputeEx<gpu>", SquareSumOpBackwardEx<gpu>);
+
+}  // namespace op
+}  // namespace mxnet
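
With the two NNVM_REGISTER_OP registrations above, the same frontend calls run on a CUDA
device. The following is an illustrative sketch only, assuming a CUDA-enabled MXNet build;
the reference computation mirrors ret_expected in the unit test below:

    import numpy as np
    import mxnet as mx

    ctx = mx.gpu(0)                         # assumes a CUDA-enabled MXNet build
    dense = np.random.uniform(size=(30, 30)).astype(np.float32)
    dense[::3] = 0.0
    rsp = mx.nd.array(dense, ctx=ctx).tostype('row_sparse')

    # Forward on GPU: row-sparse in, row-sparse out for axis=1 / keepdims=True.
    out = mx.nd._internal._square_sum(rsp, axis=1, keepdims=True)
    assert out.stype == 'row_sparse'

    expected = mx.nd.sum(mx.nd.array(dense * dense, ctx=ctx), axis=1, keepdims=True)
    np.testing.assert_allclose(out.asnumpy(), expected.asnumpy(), rtol=1e-5, atol=1e-6)
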
diff --git a/tests/python/unittest/test_sparse_operator.py b/tests/python/unittest/test_sparse_operator.py
index a08b6187bc..a56677c5b0 100644
--- a/tests/python/unittest/test_sparse_operator.py
+++ b/tests/python/unittest/test_sparse_operator.py
@@ -1429,63 +1429,62 @@ def test_fallback(func_name, axis=0, keepdims=True, exclude=True):
 
 
 def test_sparse_square_sum():
-    if default_context().device_type == 'cpu':
-        dim0 = 30
-        dim1 = 30
-        axes = [0, 1]
-        keepdims = [False, True]
-        densities = [0, 0.01, 0.2, 0.5, 1.0]
-        for density in densities:
-            shape = rand_shape_2d(dim0, dim1)
-            rsp = rand_ndarray(shape, 'row_sparse', density)
-            dns = rsp.tostype('default')
-            for axis in axes:
-                for keepdim in keepdims:
-                    ret = mx.nd._internal._square_sum(rsp, axis=axis, keepdims=keepdim)
-                    if axis == 1 and keepdim:
-                        assert ret.stype == 'row_sparse'
-                    else:
-                        assert ret.stype == 'default'
-                    ret_expected = mx.nd.sum(dns*dns, axis=axis, keepdims=keepdim)
-                    # check forward result
-                    assert_almost_equal(ret.asnumpy(), ret_expected.asnumpy())
-
-                    rsp_data = mx.sym.Variable('data', stype='row_sparse')
-                    test = mx.symbol._internal._square_sum(rsp_data, axis=axis, keepdims=keepdim)
-
-                    # check symbolic backward since ograd can be an rsp
-                    # and cannot be checked through check_numeric_gradient
-                    # because it will add a loss layer as the output layer
-                    # which makes ograd of the square_sum dense
-                    if axis == 1 and keepdim:
-                        dns_data = mx.sym.Variable('data')
-                        baseline = mx.sym.sum(mx.sym.square(dns_data), axis=axis, keepdims=keepdim)
-                        igrad_expected = mx.nd.empty(dns.shape)
-                        baseline_exec = baseline.bind(default_context(), args=[dns],
-                                                      args_grad=[igrad_expected])
-                        baseline_exec.forward(is_train=True)
-                        baseline_exec.backward([ret_expected])
-                        # check backward when ograd is row sparse
-                        check_symbolic_backward(test, [rsp], [ret_expected.tostype('row_sparse')],
-                                                [igrad_expected.asnumpy()], grad_stypes={'data': 'row_sparse'})
-
-                        # check backward when ograd is dense
-                        # the stype of output of the square_sum is deteremined in symbol binding stage.
-                        # The ograd stype of the last layer is the same as the output stype of the last layer.
-                        # Need to add one more layer after square_sum to trigger the kernel for ograd
-                        # with default stype in square_sum op.
-                        baseline1 = baseline + 1
-                        baseline_exec1 = baseline1.bind(default_context(), args=[dns],
-                                                        args_grad=[igrad_expected])
-                        baseline_exec1.forward(is_train=True)
-                        baseline_exec1.backward([ret_expected])
-                        test1 = test + 1
-                        check_symbolic_backward(test1, [rsp], [ret_expected], [igrad_expected.asnumpy()],
-                                                grad_stypes={'data': 'row_sparse'})
-
-                    # check numeric gradient
-                    check_numeric_gradient(test, [rsp], grad_stype_dict={'data': 'row_sparse'},
-                                           atol=1e-2, rtol=0.1)
+    dim0 = 30
+    dim1 = 30
+    axes = [0, 1]
+    keepdims = [False, True]
+    densities = [0, 0.01, 0.2, 0.5, 1.0]
+    for density in densities:
+        shape = rand_shape_2d(dim0, dim1)
+        rsp = rand_ndarray(shape, 'row_sparse', density)
+        dns = rsp.tostype('default')
+        for axis in axes:
+            for keepdim in keepdims:
+                ret = mx.nd._internal._square_sum(rsp, axis=axis, keepdims=keepdim)
+                if axis == 1 and keepdim:
+                    assert ret.stype == 'row_sparse'
+                else:
+                    assert ret.stype == 'default'
+                ret_expected = mx.nd.sum(dns*dns, axis=axis, keepdims=keepdim)
+                # check forward result
+                assert_almost_equal(ret.asnumpy(), ret_expected.asnumpy())
+
+                rsp_data = mx.sym.Variable('data', stype='row_sparse')
+                test = mx.symbol._internal._square_sum(rsp_data, axis=axis, keepdims=keepdim)
+
+                # check symbolic backward since ograd can be an rsp
+                # and cannot be checked through check_numeric_gradient
+                # because it will add a loss layer as the output layer
+                # which makes ograd of the square_sum dense
+                if axis == 1 and keepdim:
+                    dns_data = mx.sym.Variable('data')
+                    baseline = mx.sym.sum(mx.sym.square(dns_data), axis=axis, keepdims=keepdim)
+                    igrad_expected = mx.nd.empty(dns.shape)
+                    baseline_exec = baseline.bind(default_context(), args=[dns],
+                                                  args_grad=[igrad_expected])
+                    baseline_exec.forward(is_train=True)
+                    baseline_exec.backward([ret_expected])
+                    # check backward when ograd is row sparse
+                    check_symbolic_backward(test, [rsp], [ret_expected.tostype('row_sparse')],
+                                            [igrad_expected.asnumpy()], grad_stypes={'data': 'row_sparse'})
+
+                    # check backward when ograd is dense
+                    # the stype of the output of square_sum is determined in the symbol binding stage.
+                    # The ograd stype of the last layer is the same as the output stype of the last layer.
+                    # Need to add one more layer after square_sum to trigger the kernel for ograd
+                    # with default stype in square_sum op.
+                    baseline1 = baseline + 1
+                    baseline_exec1 = baseline1.bind(default_context(), args=[dns],
+                                                    args_grad=[igrad_expected])
+                    baseline_exec1.forward(is_train=True)
+                    baseline_exec1.backward([ret_expected])
+                    test1 = test + 1
+                    check_symbolic_backward(test1, [rsp], [ret_expected], [igrad_expected.asnumpy()],
+                                            grad_stypes={'data': 'row_sparse'})
+
+                # check numeric gradient
+                check_numeric_gradient(test, [rsp], grad_stype_dict={'data': 'row_sparse'},
+                                       atol=1e-2, rtol=0.1)
 
 
 def test_sparse_storage_fallback():


 
