You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/01/05 05:36:20 UTC

[GitHub] eric-haibin-lin closed pull request #9189: standard update for sparse sgd_mom_update

eric-haibin-lin closed pull request #9189: standard update for sparse sgd_mom_update
URL: https://github.com/apache/incubator-mxnet/pull/9189
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/python/mxnet/optimizer.py b/python/mxnet/optimizer.py
index 59898c9448..feff87e0ba 100644
--- a/python/mxnet/optimizer.py
+++ b/python/mxnet/optimizer.py
@@ -433,14 +433,8 @@ def _get_wd(self, index):
 class SGD(Optimizer):
     """The SGD optimizer with momentum and weight decay.
 
-    The optimizer updates the weight by::
-
-        rescaled_grad = lr * rescale_grad * clip(grad, clip_gradient) + wd * weight
-        state = momentum * state + rescaled_grad
-        weight = weight - state
-
-    If the storage types of weight, state and grad are all ``row_sparse``, \
-    **sparse updates** are applied by::
+    If the storage types of weight and grad are both ``row_sparse``, and ``lazy_update`` is True, \
+    **lazy updates** are applied by::
 
         for row in grad.indices:
             rescaled_grad[row] = lr * rescale_grad * clip(grad[row], clip_gradient) + wd * weight[row]
@@ -454,6 +448,12 @@ class SGD(Optimizer):
     provides slightly different semantics than the original update, and
     may lead to different empirical results.
 
+    Otherwise, **standard updates** are applied by::
+
+        rescaled_grad = lr * rescale_grad * clip(grad, clip_gradient) + wd * weight
+        state = momentum * state + rescaled_grad
+        weight = weight - state
+
     For details of the update algorithm see
     :class:`~mxnet.ndarray.sgd_update` and :class:`~mxnet.ndarray.sgd_mom_update`.
 
@@ -464,6 +464,9 @@ class SGD(Optimizer):
     ----------
     momentum : float, optional
        The momentum value.
+    lazy_update : bool, optional
+       Default is True. If True, lazy updates are applied \
+       if the storage types of weight and grad are both ``row_sparse``.
     multi_precision: bool, optional
        Flag to control the internal precision of the optimizer.
        ``False`` results in using the same precision as the weights (default),
@@ -471,9 +474,10 @@ class SGD(Optimizer):
                 in 32-bit precision even if actual weights used in the model have lower precision.\
                 Turning this on can improve convergence and accuracy when training with float16.
     """
-    def __init__(self, momentum=0.0, **kwargs):
+    def __init__(self, momentum=0.0, lazy_update=True, **kwargs):
         super(SGD, self).__init__(**kwargs)
         self.momentum = momentum
+        self.lazy_update = lazy_update
 
     def create_state_multi_precision(self, index, weight):
         weight_master_copy = None
@@ -489,8 +493,9 @@ def create_state_multi_precision(self, index, weight):
 
     def create_state(self, index, weight):
         momentum = None
+        stype = weight.stype if self.lazy_update else 'default'
         if self.momentum != 0.0:
-            momentum = zeros(weight.shape, weight.context, dtype=weight.dtype, stype=weight.stype)
+            momentum = zeros(weight.shape, weight.context, dtype=weight.dtype, stype=stype)
         return momentum
 
     def _update_impl(self, index, weight, grad, state, multi_precision=False):
diff --git a/src/operator/optimizer_op-inl.h b/src/operator/optimizer_op-inl.h
index a6b32b129e..33b7dd5fe5 100644
--- a/src/operator/optimizer_op-inl.h
+++ b/src/operator/optimizer_op-inl.h
@@ -38,6 +38,7 @@
 #include "./elemwise_op_common.h"
 #include "mxnet_op.h"
 #include "./tensor/init_op.h"
+#include "./tensor/util/tensor_util-inl.h"
 
 namespace mxnet {
 namespace op {
@@ -460,6 +461,106 @@ inline void SGDMomUpdateRspRspRspImpl(const SGDMomParam& param,
                                  mom.data(), req, &out_blob);
 }
 
+/*!
+ * \brief Storage type inference function for optimizer operators.
+ * \tparam n_rsp     The number of leading inputs that should be of row_sparse storage type
+ *                   if kFComputeEx is dispatched
+ * \tparam n_rsp_dns The number of trailing inputs that should be of row_sparse or default
+ *                   storage type if kFComputeEx is dispatched
+ */
+template<int n_rsp, int n_rsp_dns>
+inline bool StdOptStorageType(const nnvm::NodeAttrs& attrs,
+                              const int dev_mask,
+                              DispatchMode* dispatch_mode,
+                              std::vector<int>* in_attrs,
+                              std::vector<int>* out_attrs) {
+  CHECK_EQ(in_attrs->size(), static_cast<size_t>(n_rsp + n_rsp_dns));
+  CHECK_EQ(out_attrs->size(), 1U);
+  bool dispatched = false;
+
+  if (!dispatched && common::ContainsOnlyStorage(*in_attrs, kDefaultStorage)) {
+    // all inputs dense: dns, ... -> dns (kFCompute)
+    dispatched = storage_type_assign(out_attrs, kDefaultStorage,
+                                     dispatch_mode, DispatchMode::kFCompute);
+  }
+  const std::vector<int> rsp_stypes(in_attrs->begin(), in_attrs->begin() + n_rsp);
+  const std::vector<int> rsp_dns_stypes(in_attrs->begin() + n_rsp, in_attrs->end());
+  if (!dispatched && common::ContainsOnlyStorage(rsp_stypes, kRowSparseStorage) &&
+      (common::ContainsOnlyStorage(rsp_dns_stypes, kRowSparseStorage) ||
+       common::ContainsOnlyStorage(rsp_dns_stypes, kDefaultStorage))) {
+    // first n_rsp inputs rsp, remaining inputs rsp or dns: rsp, ..., rsp/dns, ... -> rsp
+    dispatched = storage_type_assign(out_attrs, kRowSparseStorage,
+                                     dispatch_mode, DispatchMode::kFComputeEx);
+  }
+
+  if (!dispatched) {  // no supported stype combination matched
+    dispatch_fallback(out_attrs, dispatch_mode);
+    LogStorageFallback(attrs, dev_mask, in_attrs, out_attrs);
+  }
+  return true;
+}
+
+template<int req>
+struct SGDMomStdDnsRspDnsKernel {  // one Map call updates every element of row i
+  template<typename DType, typename IType, typename RType>
+  MSHADOW_XINLINE static void Map(int i, index_t row_length, DType* out_data,
+    DType* mom_data, const DType* weight_data, const IType* grad_idx,  // grad_idx is not read
+    const DType* grad_data, const RType* prefix_sum, const DType clip_gradient,
+    const DType momentum, const DType lr, const DType wd, const DType rescale_grad) {
+    const DType rate = lr * wd;  // weight-decay coefficient scaled by the learning rate
+    const bool non_zero = (i == 0) ? prefix_sum[0] > 0
+                                   : prefix_sum[i] > prefix_sum[i-1];  // prefix_sum rises at i
+
+    const index_t row_i = i * row_length;
+    const RType grad_i = (prefix_sum[i]-1) * row_length;  // only read when non_zero
+    for (index_t j = 0; j < row_length; j++) {
+      const index_t data_i = row_i + j;
+      const DType grad = non_zero ? grad_data[grad_i + j]
+                                  : static_cast<DType>(0);  // otherwise use a zero gradient
+      if (clip_gradient >= 0.0f) {  // a negative clip_gradient skips clipping
+        mom_data[data_i] = momentum * mom_data[data_i]
+                - rate * weight_data[data_i]
+                - lr *
+                mshadow_op::clip::Map(rescale_grad * grad,
+                                      clip_gradient);
+      } else {
+        mom_data[data_i] = momentum * mom_data[data_i]
+                  - rate * weight_data[data_i]
+                  - lr * rescale_grad * grad;
+      }
+      KERNEL_ASSIGN(out_data[data_i], req, weight_data[data_i] + mom_data[data_i]);
+    }
+  }
+};
+
+template<typename xpu>
+void SGDMomStdUpdateDnsRspDnsImpl(const SGDMomParam& param,
+                                  const OpContext& ctx,
+                                  const TBlob& weight,
+                                  const NDArray& grad,
+                                  const TBlob& mom,
+                                  const OpReqType& req,
+                                  TBlob *out);  // declaration only; no generic body here
+
+template<typename xpu>
+inline void SGDMomStdUpdateRspRspDnsImpl(const SGDMomParam& param,
+                                         const OpContext& ctx,
+                                         const NDArray& weight,
+                                         const NDArray& grad,
+                                         const NDArray& mom,
+                                         const OpReqType& req,
+                                         NDArray *out) {
+  using namespace mshadow;
+  using namespace mshadow::expr;
+  using namespace mxnet_op;
+  using namespace rowsparse;
+  CHECK_RSP_ALL_ROWS_NON_ZERO(weight, "SGDMomUpdate", "weights");
+  Stream<xpu>* s = ctx.get_stream<xpu>();  // NOTE(review): s is not referenced below
+  TBlob out_blob = out->data();  // the Dns/Rsp/Dns impl takes the underlying data blobs
+  SGDMomStdUpdateDnsRspDnsImpl<xpu>(param, ctx, weight.data(), grad,
+                                    mom.data(), req, &out_blob);
+}
+
 template<typename xpu>
 inline void SGDMomUpdateEx(const nnvm::NodeAttrs& attrs,
                            const OpContext &ctx,
@@ -474,12 +575,15 @@ inline void SGDMomUpdateEx(const nnvm::NodeAttrs& attrs,
   const auto weight_stype = weight.storage_type();
   const auto mom_stype = mom.storage_type();
   const auto out_stype = outputs[0].storage_type();
-  CHECK_EQ(weight_stype, mom_stype) << "Inconsistent storage type detected between mom.stype = "
-           << mom_stype << " and weight.stype = " << weight_stype;
+  NDArray out = outputs[0];
   if (common::ContainsOnlyStorage(inputs, kRowSparseStorage) &&
       out_stype == kRowSparseStorage) {
-     NDArray out = outputs[0];
-     SGDMomUpdateRspRspRspImpl<xpu>(param, ctx, weight, grad, mom, req[0], &out);
+    SGDMomUpdateRspRspRspImpl<xpu>(param, ctx, weight, grad, mom, req[0], &out);
+  } else if (weight.storage_type() == kRowSparseStorage &&
+             grad.storage_type() == kRowSparseStorage &&
+             mom.storage_type() == kDefaultStorage &&
+             out_stype == kRowSparseStorage) {
+    SGDMomStdUpdateRspRspDnsImpl<xpu>(param, ctx, weight, grad, mom, req[0], &out);
   } else {
     LOG(FATAL) << "Not implemented: " << operator_string(attrs, ctx, inputs, req, outputs);
   }
diff --git a/src/operator/optimizer_op.cc b/src/operator/optimizer_op.cc
index 4de94e52ce..dda809255d 100644
--- a/src/operator/optimizer_op.cc
+++ b/src/operator/optimizer_op.cc
@@ -37,6 +37,57 @@ DMLC_REGISTER_PARAMETER(RMSPropParam);
 DMLC_REGISTER_PARAMETER(RMSPropAlexParam);
 DMLC_REGISTER_PARAMETER(FtrlParam);
 
+template<>
+void SGDMomStdUpdateDnsRspDnsImpl<cpu>(const SGDMomParam& param,
+                                       const OpContext& ctx,
+                                       const TBlob& weight,
+                                       const NDArray& grad,
+                                       const TBlob& mom,
+                                       const OpReqType& req,
+                                       TBlob *out) {
+  using namespace mxnet_op;
+  using namespace rowsparse;
+  using namespace mshadow;
+  Stream<cpu>* s = ctx.get_stream<cpu>();
+  if (req == kNullOp) return;  // no output requested
+  CHECK_EQ(req, kWriteInplace) << "kWriteInplace is expected for sparse sgd_mom_update";
+  CHECK_GT(weight.shape_.Size(), 0);
+  CHECK_GT(mom.shape_.Size(), 0);
+  MSHADOW_REAL_TYPE_SWITCH(weight.type_flag_, DType, {
+    MSHADOW_IDX_TYPE_SWITCH(grad.aux_type(kIdx), IType, {
+      MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
+        DType* weight_data = weight.dptr<DType>();
+        IType* grad_idx = grad.aux_data(kIdx).dptr<IType>();
+        DType* grad_val = grad.data().dptr<DType>();
+        DType* mom_data = mom.dptr<DType>();
+        DType* out_data = out->dptr<DType>();
+        nnvm::dim_t num_rows = weight.shape_[0];
+        auto row_length = weight.shape_.ProdShape(1, weight.ndim());
+        Tensor<cpu, 1, char> workspace = ctx.requested[0]  // temp space: one dim_t per weight row
+          .get_space_typed<cpu, 1, char>(Shape1(num_rows * sizeof(nnvm::dim_t)), s);
+
+        nnvm::dim_t* prefix_sum = reinterpret_cast<nnvm::dim_t*>(workspace.dptr_);
+        // zero the per-row flags, then mark the rows listed in grad_idx
+        Kernel<set_zero, cpu>::Launch(s, num_rows, prefix_sum);
+        if (grad.storage_initialized()) {  // otherwise the flags stay as zeroed above
+          Kernel<MarkRowFlgKernel, cpu>::Launch(s, grad.aux_shape(kIdx)[0],
+            prefix_sum, grad_idx);
+          // calculate inclusive prefix sum in place (sequential loop)
+          for (nnvm::dim_t i = 1; i < num_rows; i++) {
+            prefix_sum[i] += prefix_sum[i - 1];
+          }
+        }
+        Kernel<SGDMomStdDnsRspDnsKernel<req_type>, cpu>::Launch(s, num_rows, row_length,
+          out_data, mom_data, weight_data, grad_idx, grad_val, prefix_sum,
+          static_cast<DType>(param.clip_gradient), static_cast<DType>(param.momentum),
+          static_cast<DType>(param.lr), static_cast<DType>(param.wd),
+          static_cast<DType>(param.rescale_grad));
+      });
+    });
+  });
+}
+
+
 NNVM_REGISTER_OP(sgd_update)
 MXNET_ADD_SPARSE_OP_ALIAS(sgd_update)
 .describe(R"code(Update function for Stochastic Gradient Descent (SDG) optimizer.
@@ -84,7 +135,10 @@ It updates the weights using::
 
 Where the parameter ``momentum`` is the decay rate of momentum estimates at each epoch.
 
-If weight and momentum are both of ``row_sparse`` storage type,
+If weight and grad are both of ``row_sparse`` storage type and momentum is of ``default`` storage type,
+standard update is applied.
+
+If weight, grad and momentum are all of ``row_sparse`` storage type,
 only the row slices whose indices appear in grad.indices are updated (for both weight and momentum)::
 
   for row in gradient.indices:
@@ -97,11 +151,15 @@ only the row slices whose indices appear in grad.indices are updated (for both w
 .set_attr_parser(ParamParser<SGDMomParam>)
 .set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<3, 1>)
 .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<3, 1>)
-.set_attr<FInferStorageType>("FInferStorageType", ElemwiseStorageType<3, 1, false, true, false>)
+.set_attr<FInferStorageType>("FInferStorageType", StdOptStorageType<2, 1>)
 .set_attr<nnvm::FMutateInputs>("FMutateInputs",
   [](const nnvm::NodeAttrs& attrs) {
     return std::vector<uint32_t>{2};
   })
+.set_attr<FResourceRequest>("FResourceRequest",
+  [](const NodeAttrs& attrs) {
+    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+  })
 .set_attr<FCompute>("FCompute<cpu>", SGDMomUpdate<cpu>)
 .set_attr<FComputeEx>("FComputeEx<cpu>", SGDMomUpdateEx<cpu>)
 .add_argument("weight", "NDArray-or-Symbol", "Weight")
diff --git a/src/operator/optimizer_op.cu b/src/operator/optimizer_op.cu
index 4306b32e6e..9512e92a80 100644
--- a/src/operator/optimizer_op.cu
+++ b/src/operator/optimizer_op.cu
@@ -24,10 +24,76 @@
  * \author Junyuan Xie
  */
 #include "./optimizer_op-inl.h"
+#include <cub/cub.cuh>
 
 namespace mxnet {
 namespace op {
 
+template<>
+void SGDMomStdUpdateDnsRspDnsImpl<gpu>(const SGDMomParam& param,
+                                       const OpContext& ctx,
+                                       const TBlob& weight,
+                                       const NDArray& grad,
+                                       const TBlob& mom,
+                                       const OpReqType& req,
+                                       TBlob *out) {
+  using namespace mxnet_op;
+  using namespace rowsparse;
+  using namespace mshadow;
+  Stream<gpu>* s = ctx.get_stream<gpu>();
+  if (req == kNullOp) return;  // no output requested
+  CHECK_EQ(req, kWriteInplace) << "kWriteInplace is expected for sparse sgd_mom_update";
+  CHECK_GT(weight.shape_.Size(), 0);
+  CHECK_GT(mom.shape_.Size(), 0);
+
+  MSHADOW_REAL_TYPE_SWITCH(weight.type_flag_, DType, {
+    MSHADOW_IDX_TYPE_SWITCH(grad.aux_type(kIdx), IType, {
+      MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
+        DType* weight_data = weight.dptr<DType>();
+        IType* grad_idx = grad.aux_data(kIdx).dptr<IType>();
+        DType* grad_val = grad.data().dptr<DType>();
+        DType* mom_data = mom.dptr<DType>();
+        DType* out_data = out->dptr<DType>();
+        nnvm::dim_t num_rows = weight.shape_[0];
+        nnvm::dim_t row_length = weight.shape_.ProdShape(1, weight.ndim());
+
+        nnvm::dim_t* prefix_sum = NULL;
+        void* d_temp_storage = NULL;  // NULL makes the first InclusiveSum call size-only
+        size_t temp_storage_bytes = 0;  // receives the temp storage size cub needs
+        cub::DeviceScan::InclusiveSum(d_temp_storage,
+                                      temp_storage_bytes,
+                                      prefix_sum,
+                                      prefix_sum,
+                                      num_rows,
+                                      Stream<gpu>::GetStream(s));
+        Tensor<gpu, 1, char> workspace = ctx.requested[0]  // layout: [prefix_sum | cub temp]
+          .get_space_typed<gpu, 1, char>(Shape1(num_rows * sizeof(nnvm::dim_t) +
+                                         temp_storage_bytes), s);
+        prefix_sum = reinterpret_cast<nnvm::dim_t*>(workspace.dptr_);
+        d_temp_storage = workspace.dptr_ + num_rows*sizeof(nnvm::dim_t);
+        // zero the per-row flags, then mark the rows listed in grad_idx
+        Fill<false>(s, TBlob(prefix_sum, Shape1(num_rows), gpu::kDevMask), kWriteTo, 0);
+        if (grad.storage_initialized()) {  // otherwise the flags stay as filled above
+          Kernel<MarkRowFlgKernel, gpu>::Launch(s, grad.aux_shape(kIdx)[0],
+            prefix_sum, grad_idx);
+          // calculate inclusive prefix sum in place with cub
+          cub::DeviceScan::InclusiveSum(d_temp_storage,
+                                        temp_storage_bytes,
+                                        prefix_sum,
+                                        prefix_sum,
+                                        num_rows,
+                                        mshadow::Stream<gpu>::GetStream(s));
+        }
+        Kernel<SGDMomStdDnsRspDnsKernel<req_type>, gpu>::Launch(s, num_rows, row_length,
+          out_data, mom_data, weight_data, grad_idx, grad_val, prefix_sum,
+          static_cast<DType>(param.clip_gradient), static_cast<DType>(param.momentum),
+          static_cast<DType>(param.lr), static_cast<DType>(param.wd),
+          static_cast<DType>(param.rescale_grad));
+      });
+    });
+  });
+}
+
 NNVM_REGISTER_OP(sgd_update)
 .set_attr<FCompute>("FCompute<gpu>", SGDUpdate<gpu>)
 .set_attr<FComputeEx>("FComputeEx<gpu>", SGDUpdateEx<gpu>);
diff --git a/tests/python/unittest/test_optimizer.py b/tests/python/unittest/test_optimizer.py
index 655e157ddf..ae248b0d0b 100644
--- a/tests/python/unittest/test_optimizer.py
+++ b/tests/python/unittest/test_optimizer.py
@@ -334,6 +334,29 @@ def test_sparse_sgd():
                                               w_stype='row_sparse', g_stype='row_sparse')
 
 
+def test_std_sparse_sgd():  # row_sparse weight/grad, lazy_update=False, compared with PySGD
+    mx.random.seed(0)
+    opt1 = PySGD
+    opt2 = mx.optimizer.SGD
+    shape = (3, 4, 5)
+    mom_options = [{'momentum': 0.9}]  # only a non-zero momentum is covered
+    cg_options = [{}, {'clip_gradient': 0.4}, {'clip_gradient': 0.5}]
+    rg_options = [{}, {'rescale_grad': 0.14}, {'rescale_grad': 0.8}]
+    wd_options = [{}, {'wd': 0.03}, {'wd': 0.05}, {'wd': 0.07}]
+    for dtype in [np.float32]:
+        for mom_option in mom_options:
+            for cg_option in cg_options:
+                for rg_option in rg_options:
+                    for wd_option in wd_options:
+                        kwarg = {}  # merged from one option of each group
+                        kwarg.update(mom_option)
+                        kwarg.update(cg_option)
+                        kwarg.update(rg_option)
+                        kwarg.update(wd_option)
+                        compare_optimizer(opt1(**kwarg), opt2(lazy_update=False, **kwarg), shape, dtype,
+                                          w_stype='row_sparse', g_stype='row_sparse')
+
+
 # FTML
 
 class PyFTML(mx.optimizer.Optimizer):
@@ -400,7 +423,6 @@ def test_ftml():
                             compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype)
 
 
-
 # ADAM
 
 class PyAdam(mx.optimizer.Optimizer):


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services