Posted to commits@mxnet.apache.org by ha...@apache.org on 2018/05/01 05:51:21 UTC

[incubator-mxnet] branch master updated: [MXNET-358] support dense weight & sparse grad for adam, sgd and sgd_momentum (#10664)

This is an automated email from the ASF dual-hosted git repository.

haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 9f8f042  [MXNET-358] support dense weight & sparse grad for adam, sgd and sgd_momentum (#10664)
9f8f042 is described below

commit 9f8f042d6a603d7d5119de2aeab726bcdcfde78c
Author: Haibin Lin <li...@gmail.com>
AuthorDate: Mon Apr 30 22:51:15 2018 -0700

    [MXNET-358] support dense weight & sparse grad for adam, sgd and sgd_momentum (#10664)
    
    * + support for dense weight with sparse grad for adam & sgd mom
    
    * fix test
    
    * sgd passes
    
    * fix typo
    
    * support adam
    
    * update doc
---
 python/mxnet/optimizer.py               |  10 +-
 src/operator/operator_common.h          |  18 +-
 src/operator/optimizer_op-inl.h         | 333 ++++++++++++++++++++------------
 src/operator/optimizer_op.cc            |  71 +++++--
 tests/python/unittest/test_optimizer.py |  36 ++--
 5 files changed, 301 insertions(+), 167 deletions(-)
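
As a rough sketch of the combination this patch enables (dense weight with a
row_sparse gradient), assuming an MXNet build that includes this commit; shapes
and values below are illustrative only:

    import mxnet as mx

    # dense ('default' stype) weight, row_sparse gradient
    weight = mx.nd.ones((4, 2))
    grad = mx.nd.zeros((4, 2))
    grad[1] = 0.5                        # only row 1 has a non-zero gradient
    grad = grad.tostype('row_sparse')    # so only row 1 is stored

    # lazy_update=True touches only the rows present in grad.indices;
    # lazy_update=False applies the standard (dense) update instead
    mx.nd.sgd_update(weight, grad, out=weight, lr=0.1, wd=0.01, lazy_update=True)
    print(weight.asnumpy())              # rows 0, 2, 3 are unchanged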

diff --git a/python/mxnet/optimizer.py b/python/mxnet/optimizer.py
index 65540a9..1d2fd2e 100644
--- a/python/mxnet/optimizer.py
+++ b/python/mxnet/optimizer.py
@@ -434,7 +434,7 @@ register = Optimizer.register   # pylint: disable=invalid-name
 class SGD(Optimizer):
     """The SGD optimizer with momentum and weight decay.
 
-    If the storage types of weight and grad are both ``row_sparse``, and ``lazy_update`` is True, \
+    If the storage type of grad is ``row_sparse`` and ``lazy_update`` is True, \
     **lazy updates** are applied by::
 
         for row in grad.indices:
@@ -494,8 +494,8 @@ class SGD(Optimizer):
 
     def create_state(self, index, weight):
         momentum = None
-        stype = weight.stype if self.lazy_update else 'default'
         if self.momentum != 0.0:
+            stype = weight.stype if self.lazy_update else 'default'
             momentum = zeros(weight.shape, weight.context, dtype=weight.dtype, stype=stype)
         return momentum
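
With this change the momentum state's stype is tied to the weight only when
momentum is actually used and lazy_update is on; a small sketch of the expected
state stypes (assuming a build with this patch):

    import mxnet as mx

    opt = mx.optimizer.SGD(momentum=0.9, lazy_update=True)
    w_dense = mx.nd.ones((4, 2))
    w_sparse = w_dense.tostype('row_sparse')

    print(opt.create_state(0, w_dense).stype)     # 'default'    -> dense momentum
    print(opt.create_state(0, w_sparse).stype)    # 'row_sparse' -> sparse momentum

    # with lazy_update=False the momentum stays dense even for a sparse weight,
    # which selects the standard update path
    opt_std = mx.optimizer.SGD(momentum=0.9, lazy_update=False)
    print(opt_std.create_state(0, w_sparse).stype)  # 'default'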
 
@@ -515,7 +515,7 @@ class SGD(Optimizer):
         if not multi_precision:
             if state is not None:
                 sgd_mom_update(weight, grad, state, out=weight,
-                               lr=lr, wd=wd, **kwargs)
+                               lazy_update=self.lazy_update, lr=lr, wd=wd, **kwargs)
             else:
                 sgd_update(weight, grad, out=weight, lazy_update=self.lazy_update,
                            lr=lr, wd=wd, **kwargs)
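
A sketch of exercising the momentum path from the Python optimizer after this
change (hypothetical shapes and values; wd defaults to 0 here):

    import mxnet as mx

    opt = mx.optimizer.SGD(learning_rate=0.1, momentum=0.9, lazy_update=True)
    weight = mx.nd.ones((4, 2))              # dense weight
    grad = mx.nd.zeros((4, 2))
    grad[2] = 1.0
    grad = grad.tostype('row_sparse')        # row_sparse gradient
    state = opt.create_state(0, weight)      # dense momentum for a dense weight

    opt.update(0, weight, grad, state)       # dispatches to sgd_mom_update
    print(weight.asnumpy())                  # only row 2 changes under lazy update
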
@@ -986,7 +986,7 @@ class Adam(Optimizer):
     This class implements the optimizer described in *Adam: A Method for
     Stochastic Optimization*, available at http://arxiv.org/abs/1412.6980.
 
-    If the storage types of weight and grad are both ``row_sparse``, and ``lazy_update`` is True, \
+    If the storage type of grad is ``row_sparse`` and ``lazy_update`` is True, \
     **lazy updates** are applied by::
 
         for row in grad.indices:
@@ -1059,7 +1059,7 @@ class Adam(Optimizer):
 
         mean, var = state
         adam_update(weight, grad, mean, var, out=weight,
-                    lr=lr, wd=wd, **kwargs)
+                    lazy_update=self.lazy_update, lr=lr, wd=wd, **kwargs)
 
 @register
 class AdaGrad(Optimizer):
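
Similarly for Adam, the flag now reaches the operator, so a standard update with
a row_sparse gradient can be requested from Python. A sketch (values are
arbitrary; lazy_update and wd are set explicitly to make the difference visible):

    import mxnet as mx

    opt = mx.optimizer.Adam(learning_rate=0.01, wd=0.01, lazy_update=False)
    weight = mx.nd.ones((4, 2))
    grad = mx.nd.zeros((4, 2))
    grad[0] = 0.3
    grad = grad.tostype('row_sparse')
    state = opt.create_state(0, weight)      # (mean, var), dense here

    opt.update(0, weight, grad, state)
    # with wd > 0 and lazy_update=False every row changes; a lazy update would
    # only touch the rows listed in grad.indices
    print(weight.asnumpy())
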
diff --git a/src/operator/operator_common.h b/src/operator/operator_common.h
index a629ba5..0a9cd08 100644
--- a/src/operator/operator_common.h
+++ b/src/operator/operator_common.h
@@ -471,14 +471,18 @@ inline void ParamParser(nnvm::NodeAttrs* attrs) {
   attrs->parsed = std::move(param);
 }
 
-#define CHECK_RSP_ALL_ROWS_NON_ZERO(rsp, func, param)                              \
-  {                                                                                \
-    CHECK(rsp.storage_shape()[0] == rsp.shape()[0]) << func                        \
-          << " for RowSparse " << param << " is only implemented for "             \
-          << "RowSparse " << param << " with all rows containing non-zeros. "      \
-          << "Expects " << param << ".values.shape[0] (" << rsp.storage_shape()[0] \
-          << ") == " << param << ".shape[0] (" << rsp.shape()[0] << ").";          \
+inline void CheckAllRowsPresent(const NDArray& arr, const std::string& func,
+                                const std::string& param) {
+  if (arr.storage_type() == kRowSparseStorage) {
+    CHECK(arr.storage_shape()[0] == arr.shape()[0]) << func
+          << " for RowSparse " << param << " is only implemented for "
+          << "RowSparse " << param << " with all rows containing non-zeros. "
+          << "Expects " << param << ".data.shape[0] (" << arr.storage_shape()[0]
+          << ") == " << param << ".shape[0] (" << arr.shape()[0] << ").";
+  } else {
+    CHECK(arr.storage_type() == kDefaultStorage);
   }
+}
 
 inline void LogUnimplementedOp(const nnvm::NodeAttrs& attrs,
                                const OpContext &ctx,
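
The macro is replaced by a helper that also accepts dense arrays; for a
row_sparse array it still requires every row to be materialized. In Python
terms (a sketch):

    import mxnet as mx

    w = mx.nd.ones((4, 2)).tostype('row_sparse')
    # all rows present: the values array covers every row of the logical shape
    assert w.data.shape[0] == w.shape[0]

    g = mx.nd.zeros((4, 2))
    g[2] = 1.0
    g = g.tostype('row_sparse')
    # a typical sparse gradient stores only a few rows and would fail the check
    assert g.data.shape[0] != g.shape[0]
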
diff --git a/src/operator/optimizer_op-inl.h b/src/operator/optimizer_op-inl.h
index dfc7bef..28b382c 100644
--- a/src/operator/optimizer_op-inl.h
+++ b/src/operator/optimizer_op-inl.h
@@ -42,6 +42,18 @@
 
 namespace mxnet {
 namespace op {
+
+/*
+ * \brief log message for optimizers with lazy update.
+ */
+inline void LogLazyUpdate() {
+  common::LogOnce("Optimizer with lazy_update = True detected. "
+                  "Be aware that lazy update with row_sparse gradient is different from "
+                  "standard update, and may lead to different empirical results. See "
+                  "https://mxnet.incubator.apache.org/api/python/optimization/optimization.html "
+                  "for more details.");
+}
+
 struct SGDParam : public dmlc::Parameter<SGDParam> {
   float lr;
   float wd;
@@ -66,7 +78,7 @@ struct SGDParam : public dmlc::Parameter<SGDParam> {
               "grad = max(min(grad, clip_gradient), -clip_gradient).");
     DMLC_DECLARE_FIELD(lazy_update)
     .set_default(true)
-    .describe("If true, lazy updates are applied.");
+    .describe("If true, lazy updates are applied if gradient's stype is row_sparse.");
   }
 };
 
@@ -167,6 +179,10 @@ struct SGDDnsRspKernel<req, cpu> {
   }
 };
 
+/*
+ * \brief SGD update implementation for dense weight and row_sparse grad.
+ *        Both standard update and lazy update are supported.
+ */
 template<typename xpu>
 inline void SGDUpdateDnsRspImpl(const SGDParam& param,
                                 const OpContext &ctx,
@@ -190,6 +206,7 @@ inline void SGDUpdateDnsRspImpl(const SGDParam& param,
       MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
         DType* weight_data = weight.dptr<DType>();
         float wd = param.wd;
+        // apply standard weight decay if not lazy update
         if (!param.lazy_update) {
           Kernel<op_with_req<mshadow_op::mul, req_type>, xpu>::Launch(s, weight.Size(),
             weight_data, weight_data, static_cast<DType>(1 - param.lr * param.wd));
@@ -214,14 +231,18 @@ inline void SGDUpdateDnsRspImpl(const SGDParam& param,
   });
 }
 
+/*
+ * \brief SGD update implementation for row_sparse grad.
+ *        Both standard update and lazy update are supported.
+ */
 template<typename xpu>
-inline void SGDUpdateRspRspImpl(const SGDParam& param,
-                                const OpContext& ctx,
-                                const NDArray& weight,
-                                const NDArray& grad,
-                                const OpReqType& req,
-                                NDArray *out) {
-  CHECK_RSP_ALL_ROWS_NON_ZERO(weight, "SGDUpdate", "weights");
+inline void SGDUpdateRspImpl(const SGDParam& param,
+                             const OpContext& ctx,
+                             const NDArray& weight,
+                             const NDArray& grad,
+                             const OpReqType& req,
+                             NDArray *out) {
+  CheckAllRowsPresent(weight, "SGDUpdate", "weights");
   // reuse dns rsp implementation when storage_shape == shape
   TBlob out_blob = out->data();
   SGDUpdateDnsRspImpl<xpu>(param, ctx, weight.data(), grad, req, &out_blob);
@@ -233,15 +254,15 @@ inline void SGDUpdateEx(const nnvm::NodeAttrs& attrs,
                         const std::vector<NDArray> &inputs,
                         const std::vector<OpReqType> &req,
                         const std::vector<NDArray> &outputs) {
-  using namespace mshadow;
-  using namespace mshadow::expr;
-  using namespace mshadow_op;
   const SGDParam& param = nnvm::get<SGDParam>(attrs.parsed);
-  auto out_stype = outputs[0].storage_type();
-  if (common::ContainsOnlyStorage(inputs, kRowSparseStorage) &&
-      out_stype == kRowSparseStorage) {
+  const auto w_stype = inputs[0].storage_type();
+  const auto g_stype = inputs[1].storage_type();
+  const auto o_stype = outputs[0].storage_type();
+  if (o_stype == w_stype && g_stype == kRowSparseStorage &&
+      (w_stype == kDefaultStorage || w_stype == kRowSparseStorage)) {
     NDArray out = outputs[0];
-    SGDUpdateRspRspImpl<xpu>(param, ctx, inputs[0], inputs[1], req[0], &out);
+    // std update and lazy update with rsp grad
+    SGDUpdateRspImpl<xpu>(param, ctx, inputs[0], inputs[1], req[0], &out);
   } else {
     LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
   }
@@ -253,6 +274,7 @@ struct SGDMomParam : public dmlc::Parameter<SGDMomParam> {
   float wd;
   float rescale_grad;
   float clip_gradient;
+  bool lazy_update;
   DMLC_DECLARE_PARAMETER(SGDMomParam) {
     DMLC_DECLARE_FIELD(lr)
     .describe("Learning rate");
@@ -272,6 +294,10 @@ struct SGDMomParam : public dmlc::Parameter<SGDMomParam> {
     .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] "
               "If clip_gradient <= 0, gradient clipping is turned off. "
               "grad = max(min(grad, clip_gradient), -clip_gradient).");
+    DMLC_DECLARE_FIELD(lazy_update)
+    .set_default(true)
+    .describe("If true, lazy updates are applied if gradient's stype is row_sparse "
+              "and both weight and momentum have the same stype");
   }
 };
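
The new parameter is also exposed on the raw operator; a minimal sketch of
calling sgd_mom_update directly with a row_sparse gradient (assuming this patch):

    import mxnet as mx

    weight = mx.nd.ones((4, 2))
    mom = mx.nd.zeros((4, 2))
    grad = mx.nd.zeros((4, 2))
    grad[3] = 0.2
    grad = grad.tostype('row_sparse')

    # lazy_update=False forces the standard momentum update despite the sparse grad
    mx.nd.sgd_mom_update(weight, grad, mom, out=weight,
                         lr=0.1, momentum=0.9, wd=0.01, lazy_update=False)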
 
@@ -478,14 +504,17 @@ struct SGDMomDnsRspDnsKernel<req, gpu> {
   }
 };
 
+/*
+ * \brief sgd mom lazy update for dense weight, row_sparse grad, dense state.
+ */
 template<typename xpu>
-inline void SGDMomUpdateDnsRspDnsImpl(const SGDMomParam& param,
-                                      const OpContext& ctx,
-                                      const TBlob& weight,
-                                      const NDArray& grad,
-                                      const TBlob& mom,
-                                      const OpReqType& req,
-                                      TBlob *out) {
+inline void SGDMomLazyUpdateDnsRspDnsImpl(const SGDMomParam& param,
+                                          const OpContext& ctx,
+                                          const TBlob& weight,
+                                          const NDArray& grad,
+                                          const TBlob& mom,
+                                          const OpReqType& req,
+                                          TBlob *out) {
   using namespace mxnet_op;
   using namespace rowsparse;
   Stream<xpu>* s = ctx.get_stream<xpu>();
@@ -518,69 +547,78 @@ inline void SGDMomUpdateDnsRspDnsImpl(const SGDMomParam& param,
   });
 }
 
+/*
+ * \brief sgd momentum lazy update for row_sparse grad.
+ */
 template<typename xpu>
-inline void SGDMomUpdateRspRspRspImpl(const SGDMomParam& param,
-                                      const OpContext& ctx,
-                                      const NDArray& weight,
-                                      const NDArray& grad,
-                                      const NDArray& mom,
-                                      const OpReqType& req,
-                                      NDArray *out) {
-  using namespace mshadow;
-  using namespace mshadow::expr;
+inline void SGDMomLazyUpdateRspImpl(const SGDMomParam& param,
+                                    const OpContext& ctx,
+                                    const NDArray& weight,
+                                    const NDArray& grad,
+                                    const NDArray& mom,
+                                    const OpReqType& req,
+                                    NDArray *out) {
   using namespace mxnet_op;
   using namespace rowsparse;
-  CHECK_RSP_ALL_ROWS_NON_ZERO(weight, "SGDMomUpdate", "weights");
+  CheckAllRowsPresent(weight, "SGDMomUpdate", "weights");
   Stream<xpu>* s = ctx.get_stream<xpu>();
-  // fill mom with zero values in order to reuse the sgd mom dns impl
-  if (!mom.storage_initialized()) {
+  // fill mom with zero values (if it's in rsp storage)
+  // in order to reuse the sgd mom dns impl
+  if (mom.storage_type() == kRowSparseStorage && !mom.storage_initialized()) {
     NDArray mom_zeros = mom;
     FillDnsZerosRspImpl(s, &mom_zeros);
   }
   TBlob out_blob = out->data();
   // reuse dns rsp implementation when storage_shape == shape
-  SGDMomUpdateDnsRspDnsImpl<xpu>(param, ctx, weight.data(), grad,
-                                 mom.data(), req, &out_blob);
+  SGDMomLazyUpdateDnsRspDnsImpl<xpu>(param, ctx, weight.data(), grad,
+                                     mom.data(), req, &out_blob);
 }
 
 /*!
- * \brief Storge type inference function in optimizer.
- * \param n_rsp     The number of inputs that should be of row_sparse storage type
- *                  if kFComputeEx is dispatched
- * \param n_rsp_dns The number of inputs that should be of row_sparse or default storage type
- *                  if kFComputeEx is dispatched
+ * \brief Storage type inference function for optimizers which support both
+ *        lazy update and standard update, with states (e.g. 2nd order moment)
+ * \param num_states The number of states that could be row_sparse or dense
  */
-template<int n_rsp, int n_rsp_dns>
+template<size_t num_states, typename ParamType>
 inline bool StdOptStorageType(const nnvm::NodeAttrs& attrs,
                               const int dev_mask,
                               DispatchMode* dispatch_mode,
                               std::vector<int>* in_attrs,
                               std::vector<int>* out_attrs) {
   using namespace common;
-  CHECK_EQ(in_attrs->size(), static_cast<size_t>(n_rsp + n_rsp_dns));
+  const ParamType& param = nnvm::get<ParamType>(attrs.parsed);
+  // weight, grad, state 0, state 1, ... -> weight
+  CHECK_EQ(in_attrs->size(), 2 + num_states);
   CHECK_EQ(out_attrs->size(), 1U);
+  const int weight_stype = in_attrs->at(0);
+  const int grad_stype = in_attrs->at(1);
+  const int state_stype = in_attrs->at(2);
+  // the storage type of all states should be the same
+  for (size_t i = 3; i <  2 + num_states; i++) {
+    CHECK_EQ(state_stype, in_attrs->at(i))
+      << "Inconsistent storage types detected in state " << i;
+  }
   bool dispatched = false;
   if (!dispatched && ContainsOnlyStorage(*in_attrs, kDefaultStorage)) {
     // dns, ... -> dns
     dispatched = storage_type_assign(out_attrs, kDefaultStorage,
                                      dispatch_mode, DispatchMode::kFCompute);
   }
-  const std::vector<int> rsp_stypes(in_attrs->begin(), in_attrs->begin() + n_rsp);
-  const std::vector<int> rsp_dns_stypes(in_attrs->begin() + n_rsp, in_attrs->end());
-  if (!dispatched && ContainsOnlyStorage(rsp_stypes, kRowSparseStorage) &&
-      (ContainsOnlyStorage(rsp_dns_stypes, kRowSparseStorage) ||
-       ContainsOnlyStorage(rsp_dns_stypes, kDefaultStorage))) {
-    // rsp, ..., rsp/dns, ... -> rsp
-    dispatched = storage_type_assign(out_attrs, kRowSparseStorage,
+  if (!dispatched && grad_stype == kRowSparseStorage &&
+      (weight_stype == kRowSparseStorage || weight_stype == kDefaultStorage) &&
+      state_stype == weight_stype) {
+    // weight and state share stype, grad's stype = rsp
+    dispatched = storage_type_assign(out_attrs, static_cast<NDArrayStorageType>(weight_stype),
                                      dispatch_mode, DispatchMode::kFComputeEx);
     // warn users if lazy_update is turned on
-    if (dispatched && ContainsOnlyStorage(rsp_dns_stypes, kRowSparseStorage)) {
-      LogOnce("Optimizer with lazy_update = True detected. "
-      "Be aware that lazy update is different from standard update, "
-      "and may lead to different empirical results. See "
-      "https://mxnet.incubator.apache.org/api/python/optimization/optimization.html "
-      "for more details.");
-    }
+    if (dispatched && param.lazy_update) LogLazyUpdate();
+  }
+  if (!dispatched && grad_stype == kRowSparseStorage &&
+      weight_stype == kRowSparseStorage && state_stype == kDefaultStorage) {
+    // weight,  grad, state, ...  -> weight
+    // rsp,     rsp,  dns,   ...  -> rsp, standard update
+    dispatched = storage_type_assign(out_attrs, static_cast<NDArrayStorageType>(weight_stype),
+                                     dispatch_mode, DispatchMode::kFComputeEx);
   }
   if (!dispatched) {
     dispatched = dispatch_fallback(out_attrs, dispatch_mode);
@@ -588,10 +626,16 @@ inline bool StdOptStorageType(const nnvm::NodeAttrs& attrs,
   return dispatched;
 }
 
+/*
+ * \brief kernel for standard momentum update for dense weight, sparse grad and dense state.
+ */
 template<int req, typename xpu>
 struct SGDMomStdDnsRspDnsKernel;
 
 
+/*
+ * \brief standard momentum update for dense weight, row_sparse grad and dense states.
+ */
 template<typename xpu>
 void SGDMomStdUpdateDnsRspDnsImpl(const SGDMomParam& param,
                                   const OpContext& ctx,
@@ -601,19 +645,28 @@ void SGDMomStdUpdateDnsRspDnsImpl(const SGDMomParam& param,
                                   const OpReqType& req,
                                   TBlob *out);
 
+/*
+ * \brief standard momentum update for row_sparse grad.
+ *        both row_sparse and dense weight are supported.
+ */
 template<typename xpu>
-inline void SGDMomStdUpdateRspRspDnsImpl(const SGDMomParam& param,
-                                         const OpContext& ctx,
-                                         const NDArray& weight,
-                                         const NDArray& grad,
-                                         const NDArray& mom,
-                                         const OpReqType& req,
-                                         NDArray *out) {
-  using namespace mshadow;
-  using namespace mshadow::expr;
+inline void SGDMomStdUpdateRspImpl(const SGDMomParam& param,
+                                   const OpContext& ctx,
+                                   const NDArray& weight,
+                                   const NDArray& grad,
+                                   const NDArray& mom,
+                                   const OpReqType& req,
+                                   NDArray *out) {
   using namespace mxnet_op;
   using namespace rowsparse;
-  CHECK_RSP_ALL_ROWS_NON_ZERO(weight, "SGDMomUpdate", "weights");
+  CheckAllRowsPresent(weight, "SGDMomUpdate", "weights");
+  Stream<xpu>* s = ctx.get_stream<xpu>();
+  // fill mom with zero values (if it's in rsp storage)
+  // in order to reuse the sgd mom dns impl
+  if (mom.storage_type() == kRowSparseStorage && !mom.storage_initialized()) {
+    NDArray mom_zeros = mom;
+    FillDnsZerosRspImpl(s, &mom_zeros);
+  }
   TBlob out_blob = out->data();
   SGDMomStdUpdateDnsRspDnsImpl<xpu>(param, ctx, weight.data(), grad,
                                     mom.data(), req, &out_blob);
@@ -630,16 +683,25 @@ inline void SGDMomUpdateEx(const nnvm::NodeAttrs& attrs,
   auto &weight = inputs[0];
   auto &grad = inputs[1];
   auto &mom = inputs[2];
+  const auto w_stype = weight.storage_type();
+  const auto m_stype = mom.storage_type();
   const auto out_stype = outputs[0].storage_type();
   NDArray out = outputs[0];
-  if (common::ContainsOnlyStorage(inputs, kRowSparseStorage) &&
-      out_stype == kRowSparseStorage) {
-    SGDMomUpdateRspRspRspImpl<xpu>(param, ctx, weight, grad, mom, req[0], &out);
-  } else if (weight.storage_type() == kRowSparseStorage &&
-             grad.storage_type() == kRowSparseStorage &&
-             mom.storage_type() == kDefaultStorage &&
-             out_stype == kRowSparseStorage) {
-    SGDMomStdUpdateRspRspDnsImpl<xpu>(param, ctx, weight, grad, mom, req[0], &out);
+  const bool valid_weight = w_stype == kDefaultStorage || w_stype == kRowSparseStorage;
+  const bool valid_grad = grad.storage_type() == kRowSparseStorage;
+  const bool lazy_update = param.lazy_update;
+  CHECK(w_stype == out_stype) << "Inconsistent weight stype and output stype";
+  if (valid_weight && valid_grad && m_stype == w_stype) {
+    if (lazy_update) {
+      // rsp grad && m.stype = w.stype && lazy_update = true -> lazy update
+      SGDMomLazyUpdateRspImpl<xpu>(param, ctx, weight, grad, mom, req[0], &out);
+    } else {
+      // rsp grad && m.stype = w.stype && lazy_update = false -> std update
+      SGDMomStdUpdateRspImpl<xpu>(param, ctx, weight, grad, mom, req[0], &out);
+    }
+  } else if (w_stype == kRowSparseStorage && valid_grad && m_stype == kDefaultStorage) {
+    // rsp weight, rsp grad, dns state -> std update
+    SGDMomStdUpdateRspImpl<xpu>(param, ctx, weight, grad, mom, req[0], &out);
   } else {
     LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
   }
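
Summarizing the dispatch above, the output stype now follows the weight stype.
A sketch of the accepted input combinations (dns = default, rsp = row_sparse;
assuming this patch):

    import mxnet as mx

    w_dns = mx.nd.ones((4, 2))
    m_dns = mx.nd.zeros((4, 2))
    w_rsp = w_dns.tostype('row_sparse')
    m_rsp = m_dns.tostype('row_sparse')
    g = mx.nd.zeros((4, 2))
    g[1] = 1.0
    g_rsp = g.tostype('row_sparse')

    # dns weight, rsp grad, dns mom -> dns out (lazy or standard)
    # rsp weight, rsp grad, rsp mom -> rsp out (lazy or standard)
    # rsp weight, rsp grad, dns mom -> rsp out (standard only)
    out1 = mx.nd.sgd_mom_update(w_dns, g_rsp, m_dns, lr=0.1, momentum=0.9)
    out2 = mx.nd.sgd_mom_update(w_rsp, g_rsp, m_rsp, lr=0.1, momentum=0.9)
    print(out1.stype, out2.stype)   # 'default', 'row_sparse'
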
@@ -742,6 +804,7 @@ struct AdamParam : public dmlc::Parameter<AdamParam> {
   float wd;
   float rescale_grad;
   float clip_gradient;
+  bool lazy_update;
   DMLC_DECLARE_PARAMETER(AdamParam) {
     DMLC_DECLARE_FIELD(lr)
     .describe("Learning rate");
@@ -767,6 +830,10 @@ struct AdamParam : public dmlc::Parameter<AdamParam> {
     .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] "
               "If clip_gradient <= 0, gradient clipping is turned off. "
               "grad = max(min(grad, clip_gradient), -clip_gradient).");
+    DMLC_DECLARE_FIELD(lazy_update)
+    .set_default(true)
+    .describe("If true, lazy updates are applied if gradient's stype is row_sparse "
+              "and all of w, m and v have the same stype");
   }
 };
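
The corresponding operator-level sketch for Adam (dense weight and states,
row_sparse gradient; parameter values are arbitrary):

    import mxnet as mx

    weight = mx.nd.ones((4, 2))
    mean = mx.nd.zeros((4, 2))
    var = mx.nd.zeros((4, 2))
    grad = mx.nd.zeros((4, 2))
    grad[1] = 0.5
    grad = grad.tostype('row_sparse')

    # lazy_update chooses between the lazy and standard update for the rsp grad
    mx.nd.adam_update(weight, grad, mean, var, out=weight,
                      lr=0.01, beta1=0.9, beta2=0.999, epsilon=1e-8,
                      lazy_update=True)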
 
@@ -876,15 +943,18 @@ struct AdamDnsRspDnsKernel<req, gpu> {
   }
 };
 
+/*
+ * \brief lazy adam update for dense weight, dense states and rsp grad.
+ */
 template<typename xpu>
-inline void AdamUpdateDnsRspDnsImpl(const AdamParam& param,
-                                    const OpContext& ctx,
-                                    const TBlob& weight,
-                                    const NDArray& grad,
-                                    const TBlob& mean,
-                                    const TBlob& var,
-                                    const OpReqType& req,
-                                    TBlob *out) {
+inline void AdamLazyUpdateDnsRspDnsImpl(const AdamParam& param,
+                                        const OpContext& ctx,
+                                        const TBlob& weight,
+                                        const NDArray& grad,
+                                        const TBlob& mean,
+                                        const TBlob& var,
+                                        const OpReqType& req,
+                                        TBlob *out) {
   using namespace mxnet_op;
   using namespace rowsparse;
   Stream<xpu>* s = ctx.get_stream<xpu>();
@@ -920,39 +990,47 @@ inline void AdamUpdateDnsRspDnsImpl(const AdamParam& param,
   });
 }
 
+/*
+ * \brief lazy adam update for both row_sparse and dense weight.
+ *        grad is expected to be row_sparse.
+ */
 template<typename xpu>
-inline void AdamUpdateRspRspRspImpl(const AdamParam& param,
-                                    const OpContext& ctx,
-                                    const NDArray& weight,
-                                    const NDArray& grad,
-                                    const NDArray& mean,
-                                    const NDArray& var,
-                                    const OpReqType& req,
-                                    NDArray *out) {
-  using namespace mshadow;
-  using namespace mshadow::expr;
+inline void AdamLazyUpdateRspImpl(const AdamParam& param,
+                                  const OpContext& ctx,
+                                  const NDArray& weight,
+                                  const NDArray& grad,
+                                  const NDArray& mean,
+                                  const NDArray& var,
+                                  const OpReqType& req,
+                                  NDArray *out) {
   using namespace mxnet_op;
   using namespace rowsparse;
-  CHECK_RSP_ALL_ROWS_NON_ZERO(weight, "AdamUpdate", "weights");
+  CheckAllRowsPresent(weight, "AdamUpdate", "weights");
   Stream<xpu>* s = ctx.get_stream<xpu>();
   // fill mean and variance with zero values in order to reuse the sgd mom dns impl
-  if (!mean.storage_initialized()) {
+  if (mean.storage_type() == kRowSparseStorage && !mean.storage_initialized()) {
     NDArray mean_zeros = mean;
     FillDnsZerosRspImpl(s, &mean_zeros);
   }
-  if (!var.storage_initialized()) {
+  if (var.storage_type() == kRowSparseStorage && !var.storage_initialized()) {
     NDArray var_zeros = var;
     FillDnsZerosRspImpl(s, &var_zeros);
   }
   TBlob out_blob = out->data();
   // reuse dns rsp implementation when storage_shape == shape
-  AdamUpdateDnsRspDnsImpl<xpu>(param, ctx, weight.data(), grad, mean.data(),
-                               var.data(), req, &out_blob);
+  AdamLazyUpdateDnsRspDnsImpl<xpu>(param, ctx, weight.data(), grad, mean.data(),
+                                   var.data(), req, &out_blob);
 }
 
+/*
+ * \brief kernel for standard adam update for dense weight, row_sparse grad and dense states.
+ */
 template<int req, typename xpu>
 struct AdamStdDnsRspDnsKernel;
 
+/*
+ * \brief standard adam update for dense weight, row_sparse grad and dense states.
+ */
 template<typename xpu>
 void AdamStdUpdateDnsRspDnsImpl(const AdamParam& param,
                                 const OpContext& ctx,
@@ -963,18 +1041,22 @@ void AdamStdUpdateDnsRspDnsImpl(const AdamParam& param,
                                 const OpReqType& req,
                                 TBlob *out);
 
+/*
+ * \brief standard adam update for both row_sparse and dense weight.
+ *        states are expected to be dense, while grad is expected to be row_sparse.
+ */
 template<typename xpu>
-inline void AdamStdUpdateRspRspRspImpl(const AdamParam& param,
-                                       const OpContext& ctx,
-                                       const NDArray& weight,
-                                       const NDArray& grad,
-                                       const NDArray& mean,
-                                       const NDArray& var,
-                                       const OpReqType& req,
-                                       NDArray *out) {
+inline void AdamStdUpdateRspImpl(const AdamParam& param,
+                                 const OpContext& ctx,
+                                 const NDArray& weight,
+                                 const NDArray& grad,
+                                 const NDArray& mean,
+                                 const NDArray& var,
+                                 const OpReqType& req,
+                                 NDArray *out) {
   using namespace mxnet_op;
   using namespace rowsparse;
-  CHECK_RSP_ALL_ROWS_NON_ZERO(weight, "AdamStdUpdate", "weights");
+  CheckAllRowsPresent(weight, "AdamStdUpdate", "weights");
   TBlob out_blob = out->data();
   // reuse dns rsp implementation when storage_shape == shape
   AdamStdUpdateDnsRspDnsImpl<xpu>(param, ctx, weight.data(), grad, mean.data(),
@@ -988,21 +1070,30 @@ inline void AdamUpdateEx(const nnvm::NodeAttrs& attrs,
                          const std::vector<OpReqType> &req,
                          const std::vector<NDArray> &outputs) {
   const AdamParam& param = nnvm::get<AdamParam>(attrs.parsed);
-  const auto weight_stype = inputs[0].storage_type();
-  const auto grad_stype = inputs[1].storage_type();
-  const auto mean_stype = inputs[2].storage_type();
-  const auto var_stype = inputs[3].storage_type();
+  const auto w_stype = inputs[0].storage_type();
+  const auto g_stype = inputs[1].storage_type();
+  const auto m_stype = inputs[2].storage_type();
+  const auto v_stype = inputs[3].storage_type();
   const auto out_stype = outputs[0].storage_type();
   NDArray out = outputs[0];
-  if (common::ContainsOnlyStorage(inputs, kRowSparseStorage) &&
-      out_stype == kRowSparseStorage) {
-     AdamUpdateRspRspRspImpl<xpu>(param, ctx, inputs[0], inputs[1], inputs[2],
+  const bool valid_weight = w_stype == kDefaultStorage || w_stype == kRowSparseStorage;
+  CHECK(w_stype == out_stype) << "Inconsistent weight stype and output stype";
+  CHECK(m_stype == v_stype) << "Inconsistent mean stype and var stype";
+  if (valid_weight && g_stype == kRowSparseStorage && m_stype == w_stype) {
+     if (param.lazy_update) {
+       // rsp grad && m.stype = w.stype && lazy_update = true -> lazy update
+       AdamLazyUpdateRspImpl<xpu>(param, ctx, inputs[0], inputs[1], inputs[2],
                                   inputs[3], req[0], &out);
-  } else if (weight_stype == kRowSparseStorage && grad_stype == kRowSparseStorage &&
-             mean_stype == kDefaultStorage && var_stype == kDefaultStorage &&
-             out_stype == kRowSparseStorage) {
-     AdamStdUpdateRspRspRspImpl<xpu>(param, ctx, inputs[0], inputs[1], inputs[2],
-                                     inputs[3], req[0], &out);
+     } else {
+       // rsp grad && m.stype = w.stype && lazy_update = false -> std update
+       AdamStdUpdateRspImpl<xpu>(param, ctx, inputs[0], inputs[1], inputs[2],
+                                 inputs[3], req[0], &out);
+     }
+  } else if (w_stype == kRowSparseStorage && g_stype == kRowSparseStorage &&
+             m_stype == kDefaultStorage) {
+     // rsp, rsp, dns, dns -> rsp, standard update
+     AdamStdUpdateRspImpl<xpu>(param, ctx, inputs[0], inputs[1], inputs[2],
+                               inputs[3], req[0], &out);
   } else {
     LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
   }
@@ -1361,7 +1452,7 @@ inline void FtrlUpdateRspRspRspImpl(const FtrlParam& param,
   using namespace mshadow::expr;
   using namespace mxnet_op;
   using namespace rowsparse;
-  CHECK_RSP_ALL_ROWS_NON_ZERO(weight, "FtrlUpdate", "weights");
+  CheckAllRowsPresent(weight, "FtrlUpdate", "weights");
   Stream<xpu>* s = ctx.get_stream<xpu>();
   // fill z and n with zero values in order to reuse the sgd mom dns impl
   if (!z.storage_initialized()) {
@@ -1690,7 +1781,7 @@ inline void AdagradUpdateRspRspRspImpl(const AdagradParam& param,
   using namespace mshadow;
   using namespace mxnet_op;
   using namespace rowsparse;
-  CHECK_RSP_ALL_ROWS_NON_ZERO(weight, "AdagradUpdate", "weights");
+  CheckAllRowsPresent(weight, "AdagradUpdate", "weights");
   Stream<xpu>* s = ctx.get_stream<xpu>();
   // fill history with zero values
   if (!state.storage_initialized()) {
diff --git a/src/operator/optimizer_op.cc b/src/operator/optimizer_op.cc
index 0e73c1d..935c92a 100644
--- a/src/operator/optimizer_op.cc
+++ b/src/operator/optimizer_op.cc
@@ -132,6 +132,10 @@ struct SGDMomStdDnsRspDnsKernel<req, cpu> {
   }
 };
 
+/*
+ * \brief standard momentum update for dense weight on cpu.
+ *        state is expected to be dense, while grad is expected to be row_sparse.
+ */
 template<>
 void SGDMomStdUpdateDnsRspDnsImpl<cpu>(const SGDMomParam& param,
                                        const OpContext& ctx,
@@ -152,12 +156,12 @@ void SGDMomStdUpdateDnsRspDnsImpl<cpu>(const SGDMomParam& param,
     MSHADOW_IDX_TYPE_SWITCH(grad.aux_type(kIdx), IType, {
       MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
         DType* weight_data = weight.dptr<DType>();
-        IType* grad_idx = grad.aux_data(kIdx).dptr<IType>();
-        DType* grad_val = grad.data().dptr<DType>();
+        const IType* grad_idx = grad.aux_data(kIdx).dptr<IType>();
+        const DType* grad_val = grad.data().dptr<DType>();
         DType* mom_data = mom.dptr<DType>();
         DType* out_data = out->dptr<DType>();
-        nnvm::dim_t num_rows = weight.shape_[0];
-        auto row_length = weight.shape_.ProdShape(1, weight.ndim());
+        const nnvm::dim_t num_rows = weight.shape_[0];
+        const auto row_length = weight.shape_.ProdShape(1, weight.ndim());
         Tensor<cpu, 1, char> workspace = ctx.requested[0]
           .get_space_typed<cpu, 1, char>(Shape1(num_rows * sizeof(nnvm::dim_t)), s);
 
@@ -275,6 +279,40 @@ void AdamStdUpdateDnsRspDnsImpl<cpu>(const AdamParam& param,
   });
 }
 
+/*!
+ * \brief Storage type inference function for SGD.
+ */
+inline bool SGDStorageType(const nnvm::NodeAttrs& attrs,
+                           const int dev_mask,
+                           DispatchMode* dispatch_mode,
+                           std::vector<int>* in_attrs,
+                           std::vector<int>* out_attrs) {
+  using namespace common;
+  const SGDParam& param = nnvm::get<SGDParam>(attrs.parsed);
+  CHECK_EQ(in_attrs->size(), 2U);
+  CHECK_EQ(out_attrs->size(), 1U);
+  const int weight_stype = in_attrs->at(0);
+  const int grad_stype = in_attrs->at(1);
+  bool dispatched = false;
+  if (!dispatched && ContainsOnlyStorage(*in_attrs, kDefaultStorage)) {
+    // dns, ... -> dns
+    dispatched = storage_type_assign(out_attrs, kDefaultStorage,
+                                     dispatch_mode, DispatchMode::kFCompute);
+  }
+  if (!dispatched && grad_stype == kRowSparseStorage &&
+      (weight_stype == kRowSparseStorage || weight_stype == kDefaultStorage)) {
+    // grad's stype = rsp
+    dispatched = storage_type_assign(out_attrs, static_cast<NDArrayStorageType>(weight_stype),
+                                     dispatch_mode, DispatchMode::kFComputeEx);
+    // warn users if lazy_update is turned on
+    if (dispatched && param.lazy_update) LogLazyUpdate();
+  }
+  if (!dispatched) {
+    dispatched = dispatch_fallback(out_attrs, dispatch_mode);
+  }
+  return dispatched;
+}
+
 
 NNVM_REGISTER_OP(sgd_update)
 MXNET_ADD_SPARSE_OP_ALIAS(sgd_update)
@@ -282,13 +320,13 @@ MXNET_ADD_SPARSE_OP_ALIAS(sgd_update)
 
 It updates the weights using::
 
- weight = weight - learning_rate * gradient
+ weight = weight - learning_rate * (gradient + wd * weight)
 
-If weight is of ``row_sparse`` storage type,
+However, if gradient is of ``row_sparse`` storage type and ``lazy_update`` is True,
 only the row slices whose indices appear in grad.indices are updated::
 
  for row in gradient.indices:
-     weight[row] = weight[row] - learning_rate * gradient[row]
+     weight[row] = weight[row] - learning_rate * (gradient[row] + wd * weight[row])
 
 )code" ADD_FILELINE)
 .set_num_inputs(2)
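
As a quick numerical illustration of the two rules documented above (a NumPy
sketch; rescaling and gradient clipping are omitted):

    import numpy as np

    lr, wd = 0.1, 0.01
    weight = np.ones((4, 2))
    grad = np.zeros((4, 2))
    grad[1] = 0.5                     # only row 1 has a gradient

    # standard update: every row pays weight decay
    std = weight - lr * (grad + wd * weight)

    # lazy update: only rows in grad.indices (here: row 1) are touched
    lazy = weight.copy()
    for row in [1]:
        lazy[row] = lazy[row] - lr * (grad[row] + wd * lazy[row])

    print(std[0], lazy[0])            # row 0 differs: decayed vs. untouched
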
@@ -296,7 +334,7 @@ only the row slices whose indices appear in grad.indices are updated::
 .set_attr_parser(ParamParser<SGDParam>)
 .set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<2, 1>)
 .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>)
-.set_attr<FInferStorageType>("FInferStorageType", ElemwiseStorageType<2, 1, false, true, false>)
+.set_attr<FInferStorageType>("FInferStorageType", SGDStorageType)
 .set_attr<FCompute>("FCompute<cpu>", SGDUpdate<cpu>)
 .set_attr<FComputeEx>("FComputeEx<cpu>", SGDUpdateEx<cpu>)
 .add_argument("weight", "NDArray-or-Symbol", "Weight")
@@ -305,7 +343,7 @@ only the row slices whose indices appear in grad.indices are updated::
 
 NNVM_REGISTER_OP(sgd_mom_update)
 MXNET_ADD_SPARSE_OP_ALIAS(sgd_mom_update)
-.describe(R"code(Momentum update function for Stochastic Gradient Descent (SDG) optimizer.
+.describe(R"code(Momentum update function for Stochastic Gradient Descent (SGD) optimizer.
 
 Momentum update has better convergence rates on neural networks. Mathematically it looks
 like below:
@@ -323,10 +361,8 @@ It updates the weights using::
 
 Where the parameter ``momentum`` is the decay rate of momentum estimates at each epoch.
 
-If weight and grad are both of ``row_sparse`` storage type and momentum is of ``default`` storage type,
-standard update is applied.
-
-If weight, grad and momentum are all of ``row_sparse`` storage type,
+However, if grad's storage type is ``row_sparse``, ``lazy_update`` is True and weight's storage
+type is the same as momentum's storage type,
 only the row slices whose indices appear in grad.indices are updated (for both weight and momentum)::
 
   for row in gradient.indices:
@@ -339,7 +375,7 @@ only the row slices whose indices appear in grad.indices are updated (for both w
 .set_attr_parser(ParamParser<SGDMomParam>)
 .set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<3, 1>)
 .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<3, 1>)
-.set_attr<FInferStorageType>("FInferStorageType", StdOptStorageType<2, 1>)
+.set_attr<FInferStorageType>("FInferStorageType", StdOptStorageType<1, SGDMomParam>)
 .set_attr<nnvm::FMutateInputs>("FMutateInputs",
   [](const nnvm::NodeAttrs& attrs) {
     return std::vector<uint32_t>{2};
@@ -424,7 +460,7 @@ available at http://proceedings.mlr.press/v70/zheng17a/zheng17a.pdf.
 .add_argument("d", "NDArray-or-Symbol", "Internal state ``d_t``")
 .add_argument("v", "NDArray-or-Symbol", "Internal state ``v_t``")
 .add_argument("z", "NDArray-or-Symbol", "Internal state ``z_t``")
-.add_arguments(AdamParam::__FIELDS__());
+.add_arguments(FTMLParam::__FIELDS__());
 
 NNVM_REGISTER_OP(adam_update)
 MXNET_ADD_SPARSE_OP_ALIAS(adam_update)
@@ -447,7 +483,8 @@ It updates the weights using::
  v = beta2*v + (1-beta2)*(grad**2)
  w += - learning_rate * m / (sqrt(v) + epsilon)
 
-If w, m and v are all of ``row_sparse`` storage type,
+However, if grad's storage type is ``row_sparse``, ``lazy_update`` is True and the storage
+type of weight is the same as those of m and v,
 only the row slices whose indices appear in grad.indices are updated (for w, m and v)::
 
  for row in grad.indices:
@@ -465,7 +502,7 @@ only the row slices whose indices appear in grad.indices are updated (for w, m a
     return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
   })
 .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<4, 1>)
-.set_attr<FInferStorageType>("FInferStorageType", StdOptStorageType<2, 2>)
+.set_attr<FInferStorageType>("FInferStorageType", StdOptStorageType<2, AdamParam>)
 .set_attr<nnvm::FMutateInputs>("FMutateInputs",
   [](const nnvm::NodeAttrs& attrs) {
     return std::vector<uint32_t>{2, 3};
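
Tying this together, a NumPy sketch of the lazy Adam rule quoted in the operator
documentation above (rescaling, clipping and weight decay omitted):

    import numpy as np

    lr, beta1, beta2, eps = 0.01, 0.9, 0.999, 1e-8
    w = np.ones((4, 2))
    m = np.zeros((4, 2))
    v = np.zeros((4, 2))
    grad = np.zeros((4, 2))
    grad[2] = 0.3                     # row 2 is the only stored row

    # lazy update: w, m and v are modified only for rows in grad.indices
    for row in [2]:
        m[row] = beta1 * m[row] + (1 - beta1) * grad[row]
        v[row] = beta2 * v[row] + (1 - beta2) * grad[row] ** 2
        w[row] -= lr * m[row] / (np.sqrt(v[row]) + eps)

    print(w)
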
diff --git a/tests/python/unittest/test_optimizer.py b/tests/python/unittest/test_optimizer.py
index d1dc31a..90762f7 100644
--- a/tests/python/unittest/test_optimizer.py
+++ b/tests/python/unittest/test_optimizer.py
@@ -204,7 +204,6 @@ class PySGD(mx.optimizer.Optimizer):
     def update_multi_precision(self, index, weight, grad, state):
         self.update(index, weight, grad, state)
 
-@unittest.skip("Test fails intermittently. Temporarily disabled until fixed. Tracked at https://github.com/apache/incubator-mxnet/issues/9000")
 @with_seed()
 def test_sgd():
     opt1 = PySGD
@@ -233,16 +232,9 @@ def test_sgd():
                                 continue
                             compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype)
                             # test operator fallback on cpu
-                            if (default_context() == mx.cpu()):
-                                compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype,
-                                                  g_stype='row_sparse')
-                                if dtype != np.float16:
-                                    compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape[:2],
-                                                      dtype, w_stype='csr', g_stype='csr')
-    # test optimizer with a big shape
-    big_shape = (54686454, 1)
-    kwarg = {'momentum': 0.9, 'wd': 0.05}
-    compare_optimizer(opt1(**kwarg), opt2(**kwarg), big_shape, np.float32)
+                            if dtype != np.float16:
+                                compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape[:2],
+                                                  dtype, w_stype='csr', g_stype='csr')
 
 class PySparseSGD(mx.optimizer.Optimizer):
     """python reference implemenation of sgd"""
@@ -337,9 +329,11 @@ def test_sparse_sgd():
                             kwarg.update(mp_option)
                             compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype,
                                               w_stype='row_sparse', g_stype='row_sparse')
+                            compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype,
+                                              w_stype='default', g_stype='row_sparse')
 
 
-@with_seed(0)
+@with_seed()
 def test_std_sparse_sgd():
     opt1 = PySGD
     opt2 = mx.optimizer.SGD
@@ -360,6 +354,8 @@ def test_std_sparse_sgd():
                         kwarg.update(wd_option)
                         compare_optimizer(opt1(**kwarg), opt2(lazy_update=False, **kwarg), shape, dtype,
                                           w_stype='row_sparse', g_stype='row_sparse')
+                        compare_optimizer(opt1(**kwarg), opt2(lazy_update=False, **kwarg), shape, dtype,
+                                          w_stype='default', g_stype='row_sparse')
 
 
 class PyNAG(PySGD):
@@ -543,7 +539,7 @@ def test_ftml():
 class PyAdam(mx.optimizer.Optimizer):
     """python reference implemenation of adam"""
     def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8,
-                 decay_factor=(1 - 1e-8), lazy_update=False, **kwargs):
+                 decay_factor=(1 - 1e-8), lazy_update=True, **kwargs):
         super(PyAdam, self).__init__(learning_rate=learning_rate, **kwargs)
         self.beta1 = beta1
         self.beta2 = beta2
@@ -594,7 +590,7 @@ class PyAdam(mx.optimizer.Optimizer):
         for row in range(num_rows):
             # check row slices of all zeros
             all_zeros = mx.test_utils.almost_equal(grad[row].asnumpy(), np.zeros_like(grad[row].asnumpy()))
-            # skip zeros during sparse update
+            # skip zeros during lazy update
             if all_zeros and self.lazy_update:
                 continue
             grad[row] = grad[row] * self.rescale_grad + wd * weight[row]
@@ -635,15 +631,21 @@ def test_adam():
                                     not kwarg['multi_precision'])):
                             continue
                         # atol 2e-5 needed to pass with seed 1248389097
-                        compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype,
+                        compare_optimizer(opt1(lazy_update=False, **kwarg), opt2(**kwarg), shape, dtype,
                                           rtol=1e-4, atol=2e-5)
                         # atol 2e-5 needed to pass with seed 781809840
-                        compare_optimizer(opt1(lazy_update=True, **kwarg), opt2(**kwarg), shape,
+                        compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape,
                                           dtype, w_stype='row_sparse', g_stype='row_sparse',
                                           rtol=1e-4, atol=2e-5)
-                        compare_optimizer(opt1(**kwarg), opt2(lazy_update=False, **kwarg), shape,
+                        compare_optimizer(opt1(lazy_update=False, **kwarg), opt2(lazy_update=False, **kwarg), shape,
                                           dtype, w_stype='row_sparse', g_stype='row_sparse',
                                           rtol=1e-4, atol=2e-5)
+                        compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape,
+                                          dtype, w_stype='default', g_stype='row_sparse',
+                                          rtol=1e-4, atol=2e-5)
+                        compare_optimizer(opt1(lazy_update=False, **kwarg), opt2(lazy_update=False, **kwarg), shape,
+                                          dtype, w_stype='default', g_stype='row_sparse',
+                                          rtol=1e-4, atol=2e-5)
 
 # Signum
 class PySignum(mx.optimizer.Optimizer):
