Posted to commits@mxnet.apache.org by ha...@apache.org on 2018/05/01 05:51:21 UTC
[incubator-mxnet] branch master updated: [MXNET-358] support dense weight & sparse grad for adam, sgd and sgd_momentum (#10664)
This is an automated email from the ASF dual-hosted git repository.
haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 9f8f042 [MXNET-358] support dense weight & sparse grad for adam, sgd and sgd_momentum (#10664)
9f8f042 is described below
commit 9f8f042d6a603d7d5119de2aeab726bcdcfde78c
Author: Haibin Lin <li...@gmail.com>
AuthorDate: Mon Apr 30 22:51:15 2018 -0700
[MXNET-358] support dense weight & sparse grad for adam, sgd and sgd_momentum (#10664)
* + support for dense weight with sparse grad for adam & sgd mom
* fix test
* sgd passes
* fix typo
* support adam
* update doc
---
python/mxnet/optimizer.py | 10 +-
src/operator/operator_common.h | 18 +-
src/operator/optimizer_op-inl.h | 333 ++++++++++++++++++++------------
src/operator/optimizer_op.cc | 71 +++++--
tests/python/unittest/test_optimizer.py | 36 ++--
5 files changed, 301 insertions(+), 167 deletions(-)
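For orientation, here is a minimal sketch (not part of the commit) of the user-facing behavior this change enables: an optimizer step with a dense weight and a row_sparse gradient. It assumes an MXNet build that includes this patch; shapes and values are illustrative.

import mxnet as mx

# dense weight, row_sparse gradient with non-zero entries only in rows 0 and 2
weight = mx.nd.ones((4, 2))
grad = mx.nd.zeros((4, 2))
grad[0] = 0.5
grad[2] = 0.5
grad = grad.tostype('row_sparse')

opt = mx.optimizer.SGD(learning_rate=0.1, momentum=0.9, lazy_update=True)
state = opt.create_state(0, weight)   # momentum buffer; dense because the weight is dense
opt.update(0, weight, grad, state)    # with lazy_update, rows 1 and 3 are left untouched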
diff --git a/python/mxnet/optimizer.py b/python/mxnet/optimizer.py
index 65540a9..1d2fd2e 100644
--- a/python/mxnet/optimizer.py
+++ b/python/mxnet/optimizer.py
@@ -434,7 +434,7 @@ register = Optimizer.register # pylint: disable=invalid-name
class SGD(Optimizer):
"""The SGD optimizer with momentum and weight decay.
- If the storage types of weight and grad are both ``row_sparse``, and ``lazy_update`` is True, \
+ If the storage type of grad is ``row_sparse`` and ``lazy_update`` is True, \
**lazy updates** are applied by::
for row in grad.indices:
@@ -494,8 +494,8 @@ class SGD(Optimizer):
def create_state(self, index, weight):
momentum = None
- stype = weight.stype if self.lazy_update else 'default'
if self.momentum != 0.0:
+ stype = weight.stype if self.lazy_update else 'default'
momentum = zeros(weight.shape, weight.context, dtype=weight.dtype, stype=stype)
return momentum
@@ -515,7 +515,7 @@ class SGD(Optimizer):
if not multi_precision:
if state is not None:
sgd_mom_update(weight, grad, state, out=weight,
- lr=lr, wd=wd, **kwargs)
+ lazy_update=self.lazy_update, lr=lr, wd=wd, **kwargs)
else:
sgd_update(weight, grad, out=weight, lazy_update=self.lazy_update,
lr=lr, wd=wd, **kwargs)
@@ -986,7 +986,7 @@ class Adam(Optimizer):
This class implements the optimizer described in *Adam: A Method for
Stochastic Optimization*, available at http://arxiv.org/abs/1412.6980.
- If the storage types of weight and grad are both ``row_sparse``, and ``lazy_update`` is True, \
+ If the storage type of grad is ``row_sparse``, and ``lazy_update`` is True, \
**lazy updates** are applied by::
for row in grad.indices:
@@ -1059,7 +1059,7 @@ class Adam(Optimizer):
mean, var = state
adam_update(weight, grad, mean, var, out=weight,
- lr=lr, wd=wd, **kwargs)
+ lazy_update=self.lazy_update, lr=lr, wd=wd, **kwargs)
@register
class AdaGrad(Optimizer):
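The Adam changes above follow the same pattern. A hedged sketch of the Python-level usage (illustrative values, not taken from the commit): the default lazy_update=True only touches the rows present in grad.indices, while lazy_update=False forces the standard update over all rows.

import mxnet as mx

weight = mx.nd.ones((4, 2))
grad = mx.nd.zeros((4, 2))
grad[1] = 1.0
grad = grad.tostype('row_sparse')

adam = mx.optimizer.Adam(learning_rate=0.01, lazy_update=True)  # pass False for the standard update
state = adam.create_state(0, weight)   # (mean, var); dense since the weight is dense
adam.update(0, weight, grad, state)    # lazy: only row 1 of weight/mean/var changes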
diff --git a/src/operator/operator_common.h b/src/operator/operator_common.h
index a629ba5..0a9cd08 100644
--- a/src/operator/operator_common.h
+++ b/src/operator/operator_common.h
@@ -471,14 +471,18 @@ inline void ParamParser(nnvm::NodeAttrs* attrs) {
attrs->parsed = std::move(param);
}
-#define CHECK_RSP_ALL_ROWS_NON_ZERO(rsp, func, param) \
- { \
- CHECK(rsp.storage_shape()[0] == rsp.shape()[0]) << func \
- << " for RowSparse " << param << " is only implemented for " \
- << "RowSparse " << param << " with all rows containing non-zeros. " \
- << "Expects " << param << ".values.shape[0] (" << rsp.storage_shape()[0] \
- << ") == " << param << ".shape[0] (" << rsp.shape()[0] << ")."; \
+inline void CheckAllRowsPresent(const NDArray& arr, const std::string& func,
+ const std::string& param) {
+ if (arr.storage_type() == kRowSparseStorage) {
+ CHECK(arr.storage_shape()[0] == arr.shape()[0]) << func
+ << " for RowSparse " << param << " is only implemented for "
+ << "RowSparse " << param << " with all rows containing non-zeros. "
+ << "Expects " << param << ".data.shape[0] (" << arr.storage_shape()[0]
+ << ") == " << param << ".shape[0] (" << arr.shape()[0] << ").";
+ } else {
+ CHECK(arr.storage_type() == kDefaultStorage);
}
+}
inline void LogUnimplementedOp(const nnvm::NodeAttrs& attrs,
const OpContext &ctx,
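As a rough illustration of what the refactored check enforces (a sketch, not code from this commit): when the weight is row_sparse, every row must be materialized, i.e. weight.data.shape[0] == weight.shape[0]; a dense weight now passes the check trivially.

import mxnet as mx

w_full = mx.nd.ones((4, 2)).tostype('row_sparse')   # all 4 rows stored: data.shape[0] == shape[0]
w_part = mx.nd.zeros((4, 2))
w_part[0] = 1.0
w_part = w_part.tostype('row_sparse')               # only row 0 stored: data.shape[0] == 1

print(w_full.data.shape[0], w_full.shape[0])        # 4 4 -> accepted by the updaters
print(w_part.data.shape[0], w_part.shape[0])        # 1 4 -> CheckAllRowsPresent fails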
diff --git a/src/operator/optimizer_op-inl.h b/src/operator/optimizer_op-inl.h
index dfc7bef..28b382c 100644
--- a/src/operator/optimizer_op-inl.h
+++ b/src/operator/optimizer_op-inl.h
@@ -42,6 +42,18 @@
namespace mxnet {
namespace op {
+
+/*
+ * \brief log message for optimizers with lazy update.
+ */
+inline void LogLazyUpdate() {
+ common::LogOnce("Optimizer with lazy_update = True detected. "
+ "Be aware that lazy update with row_sparse gradient is different from "
+ "standard update, and may lead to different empirical results. See "
+ "https://mxnet.incubator.apache.org/api/python/optimization/optimization.html "
+ "for more details.");
+}
+
struct SGDParam : public dmlc::Parameter<SGDParam> {
float lr;
float wd;
@@ -66,7 +78,7 @@ struct SGDParam : public dmlc::Parameter<SGDParam> {
"grad = max(min(grad, clip_gradient), -clip_gradient).");
DMLC_DECLARE_FIELD(lazy_update)
.set_default(true)
- .describe("If true, lazy updates are applied.");
+ .describe("If true, lazy updates are applied if gradient's stype is row_sparse.");
}
};
@@ -167,6 +179,10 @@ struct SGDDnsRspKernel<req, cpu> {
}
};
+/*
+ * \brief SGD update implementation for dense weight and row_sparse grad.
+ * Both standard update and lazy update are supported.
+ */
template<typename xpu>
inline void SGDUpdateDnsRspImpl(const SGDParam& param,
const OpContext &ctx,
@@ -190,6 +206,7 @@ inline void SGDUpdateDnsRspImpl(const SGDParam& param,
MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
DType* weight_data = weight.dptr<DType>();
float wd = param.wd;
+ // apply standard weight decay if not lazy update
if (!param.lazy_update) {
Kernel<op_with_req<mshadow_op::mul, req_type>, xpu>::Launch(s, weight.Size(),
weight_data, weight_data, static_cast<DType>(1 - param.lr * param.wd));
@@ -214,14 +231,18 @@ inline void SGDUpdateDnsRspImpl(const SGDParam& param,
});
}
+/*
+ * \brief SGD update implementation for row_sparse grad.
+ * Both standard update and lazy update are supported.
+ */
template<typename xpu>
-inline void SGDUpdateRspRspImpl(const SGDParam& param,
- const OpContext& ctx,
- const NDArray& weight,
- const NDArray& grad,
- const OpReqType& req,
- NDArray *out) {
- CHECK_RSP_ALL_ROWS_NON_ZERO(weight, "SGDUpdate", "weights");
+inline void SGDUpdateRspImpl(const SGDParam& param,
+ const OpContext& ctx,
+ const NDArray& weight,
+ const NDArray& grad,
+ const OpReqType& req,
+ NDArray *out) {
+ CheckAllRowsPresent(weight, "SGDUpdate", "weights");
// reuse dns rsp implementation when storage_shape == shape
TBlob out_blob = out->data();
SGDUpdateDnsRspImpl<xpu>(param, ctx, weight.data(), grad, req, &out_blob);
@@ -233,15 +254,15 @@ inline void SGDUpdateEx(const nnvm::NodeAttrs& attrs,
const std::vector<NDArray> &inputs,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &outputs) {
- using namespace mshadow;
- using namespace mshadow::expr;
- using namespace mshadow_op;
const SGDParam& param = nnvm::get<SGDParam>(attrs.parsed);
- auto out_stype = outputs[0].storage_type();
- if (common::ContainsOnlyStorage(inputs, kRowSparseStorage) &&
- out_stype == kRowSparseStorage) {
+ const auto w_stype = inputs[0].storage_type();
+ const auto g_stype = inputs[1].storage_type();
+ const auto o_stype = outputs[0].storage_type();
+ if (o_stype == w_stype && g_stype == kRowSparseStorage &&
+ (w_stype == kDefaultStorage || w_stype == kRowSparseStorage)) {
NDArray out = outputs[0];
- SGDUpdateRspRspImpl<xpu>(param, ctx, inputs[0], inputs[1], req[0], &out);
+ // std update and lazy update with rsp grad
+ SGDUpdateRspImpl<xpu>(param, ctx, inputs[0], inputs[1], req[0], &out);
} else {
LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
}
@@ -253,6 +274,7 @@ struct SGDMomParam : public dmlc::Parameter<SGDMomParam> {
float wd;
float rescale_grad;
float clip_gradient;
+ bool lazy_update;
DMLC_DECLARE_PARAMETER(SGDMomParam) {
DMLC_DECLARE_FIELD(lr)
.describe("Learning rate");
@@ -272,6 +294,10 @@ struct SGDMomParam : public dmlc::Parameter<SGDMomParam> {
.describe("Clip gradient to the range of [-clip_gradient, clip_gradient] "
"If clip_gradient <= 0, gradient clipping is turned off. "
"grad = max(min(grad, clip_gradient), -clip_gradient).");
+ DMLC_DECLARE_FIELD(lazy_update)
+ .set_default(true)
+ .describe("If true, lazy updates are applied if gradient's stype is row_sparse "
+ "and both weight and momentum have the same stype");
}
};
@@ -478,14 +504,17 @@ struct SGDMomDnsRspDnsKernel<req, gpu> {
}
};
+/*
+ * \brief sgd mom lazy update for dense weight, row_sparse grad, dense state.
+ */
template<typename xpu>
-inline void SGDMomUpdateDnsRspDnsImpl(const SGDMomParam& param,
- const OpContext& ctx,
- const TBlob& weight,
- const NDArray& grad,
- const TBlob& mom,
- const OpReqType& req,
- TBlob *out) {
+inline void SGDMomLazyUpdateDnsRspDnsImpl(const SGDMomParam& param,
+ const OpContext& ctx,
+ const TBlob& weight,
+ const NDArray& grad,
+ const TBlob& mom,
+ const OpReqType& req,
+ TBlob *out) {
using namespace mxnet_op;
using namespace rowsparse;
Stream<xpu>* s = ctx.get_stream<xpu>();
@@ -518,69 +547,78 @@ inline void SGDMomUpdateDnsRspDnsImpl(const SGDMomParam& param,
});
}
+/*
+ * \brief sgd momentum lazy update for row_sparse grad.
+ */
template<typename xpu>
-inline void SGDMomUpdateRspRspRspImpl(const SGDMomParam& param,
- const OpContext& ctx,
- const NDArray& weight,
- const NDArray& grad,
- const NDArray& mom,
- const OpReqType& req,
- NDArray *out) {
- using namespace mshadow;
- using namespace mshadow::expr;
+inline void SGDMomLazyUpdateRspImpl(const SGDMomParam& param,
+ const OpContext& ctx,
+ const NDArray& weight,
+ const NDArray& grad,
+ const NDArray& mom,
+ const OpReqType& req,
+ NDArray *out) {
using namespace mxnet_op;
using namespace rowsparse;
- CHECK_RSP_ALL_ROWS_NON_ZERO(weight, "SGDMomUpdate", "weights");
+ CheckAllRowsPresent(weight, "SGDMomUpdate", "weights");
Stream<xpu>* s = ctx.get_stream<xpu>();
- // fill mom with zero values in order to reuse the sgd mom dns impl
- if (!mom.storage_initialized()) {
+ // fill mom with zero values (if it's in rsp storage)
+ // in order to reuse the sgd mom dns impl
+ if (mom.storage_type() == kRowSparseStorage && !mom.storage_initialized()) {
NDArray mom_zeros = mom;
FillDnsZerosRspImpl(s, &mom_zeros);
}
TBlob out_blob = out->data();
// reuse dns rsp implementation when storage_shape == shape
- SGDMomUpdateDnsRspDnsImpl<xpu>(param, ctx, weight.data(), grad,
- mom.data(), req, &out_blob);
+ SGDMomLazyUpdateDnsRspDnsImpl<xpu>(param, ctx, weight.data(), grad,
+ mom.data(), req, &out_blob);
}
/*!
- * \brief Storge type inference function in optimizer.
- * \param n_rsp The number of inputs that should be of row_sparse storage type
- * if kFComputeEx is dispatched
- * \param n_rsp_dns The number of inputs that should be of row_sparse or default storage type
- * if kFComputeEx is dispatched
+ * \brief Storage type inference function for optimizers which support both
+ * lazy update and standard update, with states (e.g. 2nd order moment)
+ * \param num_states The number of states that could be row_sparse or dense
*/
-template<int n_rsp, int n_rsp_dns>
+template<size_t num_states, typename ParamType>
inline bool StdOptStorageType(const nnvm::NodeAttrs& attrs,
const int dev_mask,
DispatchMode* dispatch_mode,
std::vector<int>* in_attrs,
std::vector<int>* out_attrs) {
using namespace common;
- CHECK_EQ(in_attrs->size(), static_cast<size_t>(n_rsp + n_rsp_dns));
+ const ParamType& param = nnvm::get<ParamType>(attrs.parsed);
+ // weight, grad, state 0, state 1, ... -> weight
+ CHECK_EQ(in_attrs->size(), 2 + num_states);
CHECK_EQ(out_attrs->size(), 1U);
+ const int weight_stype = in_attrs->at(0);
+ const int grad_stype = in_attrs->at(1);
+ const int state_stype = in_attrs->at(2);
+ // the storage type of all states should be the same
+ for (size_t i = 3; i < 2 + num_states; i++) {
+ CHECK_EQ(state_stype, in_attrs->at(i))
+ << "Inconsistent storage types detected in state " << i;
+ }
bool dispatched = false;
if (!dispatched && ContainsOnlyStorage(*in_attrs, kDefaultStorage)) {
// dns, ... -> dns
dispatched = storage_type_assign(out_attrs, kDefaultStorage,
dispatch_mode, DispatchMode::kFCompute);
}
- const std::vector<int> rsp_stypes(in_attrs->begin(), in_attrs->begin() + n_rsp);
- const std::vector<int> rsp_dns_stypes(in_attrs->begin() + n_rsp, in_attrs->end());
- if (!dispatched && ContainsOnlyStorage(rsp_stypes, kRowSparseStorage) &&
- (ContainsOnlyStorage(rsp_dns_stypes, kRowSparseStorage) ||
- ContainsOnlyStorage(rsp_dns_stypes, kDefaultStorage))) {
- // rsp, ..., rsp/dns, ... -> rsp
- dispatched = storage_type_assign(out_attrs, kRowSparseStorage,
+ if (!dispatched && grad_stype == kRowSparseStorage &&
+ (weight_stype == kRowSparseStorage || weight_stype == kDefaultStorage) &&
+ state_stype == weight_stype) {
+ // weight and state share stype, grad's stype = rsp
+ dispatched = storage_type_assign(out_attrs, static_cast<NDArrayStorageType>(weight_stype),
dispatch_mode, DispatchMode::kFComputeEx);
// warn users if lazy_update is turned on
- if (dispatched && ContainsOnlyStorage(rsp_dns_stypes, kRowSparseStorage)) {
- LogOnce("Optimizer with lazy_update = True detected. "
- "Be aware that lazy update is different from standard update, "
- "and may lead to different empirical results. See "
- "https://mxnet.incubator.apache.org/api/python/optimization/optimization.html "
- "for more details.");
- }
+ if (dispatched && param.lazy_update) LogLazyUpdate();
+ }
+ if (!dispatched && grad_stype == kRowSparseStorage &&
+ weight_stype == kRowSparseStorage && state_stype == kDefaultStorage) {
+ // weight, grad, state, ... -> weight
+ // rsp, rsp, dns, ... -> rsp, standard update
+ dispatched = storage_type_assign(out_attrs, static_cast<NDArrayStorageType>(weight_stype),
+ dispatch_mode, DispatchMode::kFComputeEx);
}
if (!dispatched) {
dispatched = dispatch_fallback(out_attrs, dispatch_mode);
@@ -588,10 +626,16 @@ inline bool StdOptStorageType(const nnvm::NodeAttrs& attrs,
return dispatched;
}
+/*
+ * \brief kernel for standard momentum update for dense weight, sparse grad and dense state.
+ */
template<int req, typename xpu>
struct SGDMomStdDnsRspDnsKernel;
+/*
+ * \brief standard momentum update for dense weight, row_sparse grad and dense states.
+ */
template<typename xpu>
void SGDMomStdUpdateDnsRspDnsImpl(const SGDMomParam& param,
const OpContext& ctx,
@@ -601,19 +645,28 @@ void SGDMomStdUpdateDnsRspDnsImpl(const SGDMomParam& param,
const OpReqType& req,
TBlob *out);
+/*
+ * \brief standard momentum update for row_sparse grad.
+ * both row_sparse and dense weight are supported.
+ */
template<typename xpu>
-inline void SGDMomStdUpdateRspRspDnsImpl(const SGDMomParam& param,
- const OpContext& ctx,
- const NDArray& weight,
- const NDArray& grad,
- const NDArray& mom,
- const OpReqType& req,
- NDArray *out) {
- using namespace mshadow;
- using namespace mshadow::expr;
+inline void SGDMomStdUpdateRspImpl(const SGDMomParam& param,
+ const OpContext& ctx,
+ const NDArray& weight,
+ const NDArray& grad,
+ const NDArray& mom,
+ const OpReqType& req,
+ NDArray *out) {
using namespace mxnet_op;
using namespace rowsparse;
- CHECK_RSP_ALL_ROWS_NON_ZERO(weight, "SGDMomUpdate", "weights");
+ CheckAllRowsPresent(weight, "SGDMomUpdate", "weights");
+ Stream<xpu>* s = ctx.get_stream<xpu>();
+ // fill mom with zero values (if it's in rsp storage)
+ // in order to reuse the sgd mom dns impl
+ if (mom.storage_type() == kRowSparseStorage && !mom.storage_initialized()) {
+ NDArray mom_zeros = mom;
+ FillDnsZerosRspImpl(s, &mom_zeros);
+ }
TBlob out_blob = out->data();
SGDMomStdUpdateDnsRspDnsImpl<xpu>(param, ctx, weight.data(), grad,
mom.data(), req, &out_blob);
@@ -630,16 +683,25 @@ inline void SGDMomUpdateEx(const nnvm::NodeAttrs& attrs,
auto &weight = inputs[0];
auto &grad = inputs[1];
auto &mom = inputs[2];
+ const auto w_stype = weight.storage_type();
+ const auto m_stype = mom.storage_type();
const auto out_stype = outputs[0].storage_type();
NDArray out = outputs[0];
- if (common::ContainsOnlyStorage(inputs, kRowSparseStorage) &&
- out_stype == kRowSparseStorage) {
- SGDMomUpdateRspRspRspImpl<xpu>(param, ctx, weight, grad, mom, req[0], &out);
- } else if (weight.storage_type() == kRowSparseStorage &&
- grad.storage_type() == kRowSparseStorage &&
- mom.storage_type() == kDefaultStorage &&
- out_stype == kRowSparseStorage) {
- SGDMomStdUpdateRspRspDnsImpl<xpu>(param, ctx, weight, grad, mom, req[0], &out);
+ const bool valid_weight = w_stype == kDefaultStorage || w_stype == kRowSparseStorage;
+ const bool valid_grad = grad.storage_type() == kRowSparseStorage;
+ const bool lazy_update = param.lazy_update;
+ CHECK(w_stype == out_stype) << "Inconsistent weight stype and output stype";
+ if (valid_weight && valid_grad && m_stype == w_stype) {
+ if (lazy_update) {
+ // rsp grad && m.stype = w.stype && lazy_update = true -> lazy update
+ SGDMomLazyUpdateRspImpl<xpu>(param, ctx, weight, grad, mom, req[0], &out);
+ } else {
+ // rsp grad && m.stype = w.stype && lazy_update = false -> std update
+ SGDMomStdUpdateRspImpl<xpu>(param, ctx, weight, grad, mom, req[0], &out);
+ }
+ } else if (w_stype == kRowSparseStorage && valid_grad && m_stype == kDefaultStorage) {
+ // rsp weight, rsp grad, dns state -> std update
+ SGDMomStdUpdateRspImpl<xpu>(param, ctx, weight, grad, mom, req[0], &out);
} else {
LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
}
@@ -742,6 +804,7 @@ struct AdamParam : public dmlc::Parameter<AdamParam> {
float wd;
float rescale_grad;
float clip_gradient;
+ bool lazy_update;
DMLC_DECLARE_PARAMETER(AdamParam) {
DMLC_DECLARE_FIELD(lr)
.describe("Learning rate");
@@ -767,6 +830,10 @@ struct AdamParam : public dmlc::Parameter<AdamParam> {
.describe("Clip gradient to the range of [-clip_gradient, clip_gradient] "
"If clip_gradient <= 0, gradient clipping is turned off. "
"grad = max(min(grad, clip_gradient), -clip_gradient).");
+ DMLC_DECLARE_FIELD(lazy_update)
+ .set_default(true)
+ .describe("If true, lazy updates are applied if gradient's stype is row_sparse "
+ "and all of w, m and v have the same stype");
}
};
@@ -876,15 +943,18 @@ struct AdamDnsRspDnsKernel<req, gpu> {
}
};
+/*
+ * \brief lazy adam update for dense weight, dense states and rsp grad.
+ */
template<typename xpu>
-inline void AdamUpdateDnsRspDnsImpl(const AdamParam& param,
- const OpContext& ctx,
- const TBlob& weight,
- const NDArray& grad,
- const TBlob& mean,
- const TBlob& var,
- const OpReqType& req,
- TBlob *out) {
+inline void AdamLazyUpdateDnsRspDnsImpl(const AdamParam& param,
+ const OpContext& ctx,
+ const TBlob& weight,
+ const NDArray& grad,
+ const TBlob& mean,
+ const TBlob& var,
+ const OpReqType& req,
+ TBlob *out) {
using namespace mxnet_op;
using namespace rowsparse;
Stream<xpu>* s = ctx.get_stream<xpu>();
@@ -920,39 +990,47 @@ inline void AdamUpdateDnsRspDnsImpl(const AdamParam& param,
});
}
+/*
+ * \brief lazy adam update for both row_sparse and dense weight.
+ * grad is expected to be row_sparse.
+ */
template<typename xpu>
-inline void AdamUpdateRspRspRspImpl(const AdamParam& param,
- const OpContext& ctx,
- const NDArray& weight,
- const NDArray& grad,
- const NDArray& mean,
- const NDArray& var,
- const OpReqType& req,
- NDArray *out) {
- using namespace mshadow;
- using namespace mshadow::expr;
+inline void AdamLazyUpdateRspImpl(const AdamParam& param,
+ const OpContext& ctx,
+ const NDArray& weight,
+ const NDArray& grad,
+ const NDArray& mean,
+ const NDArray& var,
+ const OpReqType& req,
+ NDArray *out) {
using namespace mxnet_op;
using namespace rowsparse;
- CHECK_RSP_ALL_ROWS_NON_ZERO(weight, "AdamUpdate", "weights");
+ CheckAllRowsPresent(weight, "AdamUpdate", "weights");
Stream<xpu>* s = ctx.get_stream<xpu>();
// fill mean and variance with zero values in order to reuse the sgd mom dns impl
- if (!mean.storage_initialized()) {
+ if (mean.storage_type() == kRowSparseStorage && !mean.storage_initialized()) {
NDArray mean_zeros = mean;
FillDnsZerosRspImpl(s, &mean_zeros);
}
- if (!var.storage_initialized()) {
+ if (var.storage_type() == kRowSparseStorage && !var.storage_initialized()) {
NDArray var_zeros = var;
FillDnsZerosRspImpl(s, &var_zeros);
}
TBlob out_blob = out->data();
// reuse dns rsp implementation when storage_shape == shape
- AdamUpdateDnsRspDnsImpl<xpu>(param, ctx, weight.data(), grad, mean.data(),
- var.data(), req, &out_blob);
+ AdamLazyUpdateDnsRspDnsImpl<xpu>(param, ctx, weight.data(), grad, mean.data(),
+ var.data(), req, &out_blob);
}
+/*
+ * \brief kernel for standard adam update for dense weight, row_sparse grad and dense states.
+ */
template<int req, typename xpu>
struct AdamStdDnsRspDnsKernel;
+/*
+ * \brief standard adam update for dense weight, row_sparse grad and dense states.
+ */
template<typename xpu>
void AdamStdUpdateDnsRspDnsImpl(const AdamParam& param,
const OpContext& ctx,
@@ -963,18 +1041,22 @@ void AdamStdUpdateDnsRspDnsImpl(const AdamParam& param,
const OpReqType& req,
TBlob *out);
+/*
+ * \brief standard adam update for both row_sparse and dense weight.
+ * states are expected to be dense, while grad is expected to be row_sparse.
+ */
template<typename xpu>
-inline void AdamStdUpdateRspRspRspImpl(const AdamParam& param,
- const OpContext& ctx,
- const NDArray& weight,
- const NDArray& grad,
- const NDArray& mean,
- const NDArray& var,
- const OpReqType& req,
- NDArray *out) {
+inline void AdamStdUpdateRspImpl(const AdamParam& param,
+ const OpContext& ctx,
+ const NDArray& weight,
+ const NDArray& grad,
+ const NDArray& mean,
+ const NDArray& var,
+ const OpReqType& req,
+ NDArray *out) {
using namespace mxnet_op;
using namespace rowsparse;
- CHECK_RSP_ALL_ROWS_NON_ZERO(weight, "AdamStdUpdate", "weights");
+ CheckAllRowsPresent(weight, "AdamStdUpdate", "weights");
TBlob out_blob = out->data();
// reuse dns rsp implementation when storage_shape == shape
AdamStdUpdateDnsRspDnsImpl<xpu>(param, ctx, weight.data(), grad, mean.data(),
@@ -988,21 +1070,30 @@ inline void AdamUpdateEx(const nnvm::NodeAttrs& attrs,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &outputs) {
const AdamParam& param = nnvm::get<AdamParam>(attrs.parsed);
- const auto weight_stype = inputs[0].storage_type();
- const auto grad_stype = inputs[1].storage_type();
- const auto mean_stype = inputs[2].storage_type();
- const auto var_stype = inputs[3].storage_type();
+ const auto w_stype = inputs[0].storage_type();
+ const auto g_stype = inputs[1].storage_type();
+ const auto m_stype = inputs[2].storage_type();
+ const auto v_stype = inputs[3].storage_type();
const auto out_stype = outputs[0].storage_type();
NDArray out = outputs[0];
- if (common::ContainsOnlyStorage(inputs, kRowSparseStorage) &&
- out_stype == kRowSparseStorage) {
- AdamUpdateRspRspRspImpl<xpu>(param, ctx, inputs[0], inputs[1], inputs[2],
+ const bool valid_weight = w_stype == kDefaultStorage || w_stype == kRowSparseStorage;
+ CHECK(w_stype == out_stype) << "Inconsistent weight stype and output stype";
+ CHECK(m_stype == v_stype) << "Inconsistent mean stype and var stype";
+ if (valid_weight && g_stype == kRowSparseStorage && m_stype == w_stype) {
+ if (param.lazy_update) {
+ // rsp grad && m.stype = w.stype && lazy_update = true -> lazy update
+ AdamLazyUpdateRspImpl<xpu>(param, ctx, inputs[0], inputs[1], inputs[2],
inputs[3], req[0], &out);
- } else if (weight_stype == kRowSparseStorage && grad_stype == kRowSparseStorage &&
- mean_stype == kDefaultStorage && var_stype == kDefaultStorage &&
- out_stype == kRowSparseStorage) {
- AdamStdUpdateRspRspRspImpl<xpu>(param, ctx, inputs[0], inputs[1], inputs[2],
- inputs[3], req[0], &out);
+ } else {
+ // rsp grad && m.stype = w.stype && lazy_update = false -> std update
+ AdamStdUpdateRspImpl<xpu>(param, ctx, inputs[0], inputs[1], inputs[2],
+ inputs[3], req[0], &out);
+ }
+ } else if (w_stype == kRowSparseStorage && g_stype == kRowSparseStorage &&
+ m_stype == kDefaultStorage) {
+ // rsp, rsp, dns, dns -> rsp, standard update
+ AdamStdUpdateRspImpl<xpu>(param, ctx, inputs[0], inputs[1], inputs[2],
+ inputs[3], req[0], &out);
} else {
LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
}
@@ -1361,7 +1452,7 @@ inline void FtrlUpdateRspRspRspImpl(const FtrlParam& param,
using namespace mshadow::expr;
using namespace mxnet_op;
using namespace rowsparse;
- CHECK_RSP_ALL_ROWS_NON_ZERO(weight, "FtrlUpdate", "weights");
+ CheckAllRowsPresent(weight, "FtrlUpdate", "weights");
Stream<xpu>* s = ctx.get_stream<xpu>();
// fill z and n with zero values in order to reuse the sgd mom dns impl
if (!z.storage_initialized()) {
@@ -1690,7 +1781,7 @@ inline void AdagradUpdateRspRspRspImpl(const AdagradParam& param,
using namespace mshadow;
using namespace mxnet_op;
using namespace rowsparse;
- CHECK_RSP_ALL_ROWS_NON_ZERO(weight, "AdagradUpdate", "weights");
+ CheckAllRowsPresent(weight, "AdagradUpdate", "weights");
Stream<xpu>* s = ctx.get_stream<xpu>();
// fill history with zero values
if (!state.storage_initialized()) {
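To summarize the dispatch rules that StdOptStorageType and the *UpdateEx functions above implement, a hedged Python-style restatement (illustrative only, not code from the commit):

def momentum_update_kind(w_stype, g_stype, m_stype, lazy_update):
    """Sketch of which sgd_mom/adam path handles a given stype combination."""
    if w_stype == g_stype == m_stype == 'default':
        return 'dense kernel (FCompute)'
    if g_stype == 'row_sparse' and w_stype in ('default', 'row_sparse'):
        if m_stype == w_stype:
            return 'lazy update' if lazy_update else 'standard update'
        if w_stype == 'row_sparse' and m_stype == 'default':
            return 'standard update'
    return 'fallback (LogUnimplementedOp)'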
diff --git a/src/operator/optimizer_op.cc b/src/operator/optimizer_op.cc
index 0e73c1d..935c92a 100644
--- a/src/operator/optimizer_op.cc
+++ b/src/operator/optimizer_op.cc
@@ -132,6 +132,10 @@ struct SGDMomStdDnsRspDnsKernel<req, cpu> {
}
};
+/*
+ * \brief standard momentum update for dense weight on cpu.
+ * state is expected to be dense, while grad is expected to be row_sparse.
+ */
template<>
void SGDMomStdUpdateDnsRspDnsImpl<cpu>(const SGDMomParam& param,
const OpContext& ctx,
@@ -152,12 +156,12 @@ void SGDMomStdUpdateDnsRspDnsImpl<cpu>(const SGDMomParam& param,
MSHADOW_IDX_TYPE_SWITCH(grad.aux_type(kIdx), IType, {
MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
DType* weight_data = weight.dptr<DType>();
- IType* grad_idx = grad.aux_data(kIdx).dptr<IType>();
- DType* grad_val = grad.data().dptr<DType>();
+ const IType* grad_idx = grad.aux_data(kIdx).dptr<IType>();
+ const DType* grad_val = grad.data().dptr<DType>();
DType* mom_data = mom.dptr<DType>();
DType* out_data = out->dptr<DType>();
- nnvm::dim_t num_rows = weight.shape_[0];
- auto row_length = weight.shape_.ProdShape(1, weight.ndim());
+ const nnvm::dim_t num_rows = weight.shape_[0];
+ const auto row_length = weight.shape_.ProdShape(1, weight.ndim());
Tensor<cpu, 1, char> workspace = ctx.requested[0]
.get_space_typed<cpu, 1, char>(Shape1(num_rows * sizeof(nnvm::dim_t)), s);
@@ -275,6 +279,40 @@ void AdamStdUpdateDnsRspDnsImpl<cpu>(const AdamParam& param,
});
}
+/*!
+ * \brief Storage type inference function for SGD.
+ */
+inline bool SGDStorageType(const nnvm::NodeAttrs& attrs,
+ const int dev_mask,
+ DispatchMode* dispatch_mode,
+ std::vector<int>* in_attrs,
+ std::vector<int>* out_attrs) {
+ using namespace common;
+ const SGDParam& param = nnvm::get<SGDParam>(attrs.parsed);
+ CHECK_EQ(in_attrs->size(), 2U);
+ CHECK_EQ(out_attrs->size(), 1U);
+ const int weight_stype = in_attrs->at(0);
+ const int grad_stype = in_attrs->at(1);
+ bool dispatched = false;
+ if (!dispatched && ContainsOnlyStorage(*in_attrs, kDefaultStorage)) {
+ // dns, ... -> dns
+ dispatched = storage_type_assign(out_attrs, kDefaultStorage,
+ dispatch_mode, DispatchMode::kFCompute);
+ }
+ if (!dispatched && grad_stype == kRowSparseStorage &&
+ (weight_stype == kRowSparseStorage || weight_stype == kDefaultStorage)) {
+ // grad's stype = rsp
+ dispatched = storage_type_assign(out_attrs, static_cast<NDArrayStorageType>(weight_stype),
+ dispatch_mode, DispatchMode::kFComputeEx);
+ // warn users if lazy_update is turned on
+ if (dispatched && param.lazy_update) LogLazyUpdate();
+ }
+ if (!dispatched) {
+ dispatched = dispatch_fallback(out_attrs, dispatch_mode);
+ }
+ return dispatched;
+}
+
NNVM_REGISTER_OP(sgd_update)
MXNET_ADD_SPARSE_OP_ALIAS(sgd_update)
@@ -282,13 +320,13 @@ MXNET_ADD_SPARSE_OP_ALIAS(sgd_update)
It updates the weights using::
- weight = weight - learning_rate * gradient
+ weight = weight - learning_rate * (gradient + wd * weight)
-If weight is of ``row_sparse`` storage type,
+However, if gradient is of ``row_sparse`` storage type and ``lazy_update`` is True,
only the row slices whose indices appear in grad.indices are updated::
for row in gradient.indices:
- weight[row] = weight[row] - learning_rate * gradient[row]
+ weight[row] = weight[row] - learning_rate * (gradient[row] + wd * weight[row])
)code" ADD_FILELINE)
.set_num_inputs(2)
@@ -296,7 +334,7 @@ only the row slices whose indices appear in grad.indices are updated::
.set_attr_parser(ParamParser<SGDParam>)
.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<2, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>)
-.set_attr<FInferStorageType>("FInferStorageType", ElemwiseStorageType<2, 1, false, true, false>)
+.set_attr<FInferStorageType>("FInferStorageType", SGDStorageType)
.set_attr<FCompute>("FCompute<cpu>", SGDUpdate<cpu>)
.set_attr<FComputeEx>("FComputeEx<cpu>", SGDUpdateEx<cpu>)
.add_argument("weight", "NDArray-or-Symbol", "Weight")
@@ -305,7 +343,7 @@ only the row slices whose indices appear in grad.indices are updated::
NNVM_REGISTER_OP(sgd_mom_update)
MXNET_ADD_SPARSE_OP_ALIAS(sgd_mom_update)
-.describe(R"code(Momentum update function for Stochastic Gradient Descent (SDG) optimizer.
+.describe(R"code(Momentum update function for Stochastic Gradient Descent (SGD) optimizer.
Momentum update has better convergence rates on neural networks. Mathematically it looks
like below:
@@ -323,10 +361,8 @@ It updates the weights using::
Where the parameter ``momentum`` is the decay rate of momentum estimates at each epoch.
-If weight and grad are both of ``row_sparse`` storage type and momentum is of ``default`` storage type,
-standard update is applied.
-
-If weight, grad and momentum are all of ``row_sparse`` storage type,
+However, if grad's storage type is ``row_sparse``, ``lazy_update`` is True and weight's storage
+type is the same as momentum's storage type,
only the row slices whose indices appear in grad.indices are updated (for both weight and momentum)::
for row in gradient.indices:
@@ -339,7 +375,7 @@ only the row slices whose indices appear in grad.indices are updated (for both w
.set_attr_parser(ParamParser<SGDMomParam>)
.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<3, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<3, 1>)
-.set_attr<FInferStorageType>("FInferStorageType", StdOptStorageType<2, 1>)
+.set_attr<FInferStorageType>("FInferStorageType", StdOptStorageType<1, SGDMomParam>)
.set_attr<nnvm::FMutateInputs>("FMutateInputs",
[](const nnvm::NodeAttrs& attrs) {
return std::vector<uint32_t>{2};
@@ -424,7 +460,7 @@ available at http://proceedings.mlr.press/v70/zheng17a/zheng17a.pdf.
.add_argument("d", "NDArray-or-Symbol", "Internal state ``d_t``")
.add_argument("v", "NDArray-or-Symbol", "Internal state ``v_t``")
.add_argument("z", "NDArray-or-Symbol", "Internal state ``z_t``")
-.add_arguments(AdamParam::__FIELDS__());
+.add_arguments(FTMLParam::__FIELDS__());
NNVM_REGISTER_OP(adam_update)
MXNET_ADD_SPARSE_OP_ALIAS(adam_update)
@@ -447,7 +483,8 @@ It updates the weights using::
v = beta2*v + (1-beta2)*(grad**2)
w += - learning_rate * m / (sqrt(v) + epsilon)
-If w, m and v are all of ``row_sparse`` storage type,
+However, if grad's storage type is ``row_sparse``, ``lazy_update`` is True and the storage
+type of weight is the same as those of m and v,
only the row slices whose indices appear in grad.indices are updated (for w, m and v)::
for row in grad.indices:
@@ -465,7 +502,7 @@ only the row slices whose indices appear in grad.indices are updated (for w, m a
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<4, 1>)
-.set_attr<FInferStorageType>("FInferStorageType", StdOptStorageType<2, 2>)
+.set_attr<FInferStorageType>("FInferStorageType", StdOptStorageType<2, AdamParam>)
.set_attr<nnvm::FMutateInputs>("FMutateInputs",
[](const nnvm::NodeAttrs& attrs) {
return std::vector<uint32_t>{2, 3};
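The reworded sgd_update documentation above can be checked directly from Python. A hedged example (values illustrative): with a row_sparse gradient, lazy_update=True skips rows absent from grad.indices entirely, while lazy_update=False applies weight decay to every row.

import mxnet as mx

w_lazy = mx.nd.ones((3, 1))
w_std = mx.nd.ones((3, 1))
g = mx.nd.zeros((3, 1))
g[0] = 1.0
g = g.tostype('row_sparse')             # only row 0 appears in g.indices

mx.nd.sgd_update(w_lazy, g, out=w_lazy, lazy_update=True, lr=0.1, wd=0.1)
mx.nd.sgd_update(w_std, g, out=w_std, lazy_update=False, lr=0.1, wd=0.1)
print(w_lazy.asnumpy())   # rows 1-2 untouched (no decay applied to them)
print(w_std.asnumpy())    # every row updated: w = w - lr * (grad + wd * w)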
diff --git a/tests/python/unittest/test_optimizer.py b/tests/python/unittest/test_optimizer.py
index d1dc31a..90762f7 100644
--- a/tests/python/unittest/test_optimizer.py
+++ b/tests/python/unittest/test_optimizer.py
@@ -204,7 +204,6 @@ class PySGD(mx.optimizer.Optimizer):
def update_multi_precision(self, index, weight, grad, state):
self.update(index, weight, grad, state)
-@unittest.skip("Test fails intermittently. Temporarily disabled until fixed. Tracked at https://github.com/apache/incubator-mxnet/issues/9000")
@with_seed()
def test_sgd():
opt1 = PySGD
@@ -233,16 +232,9 @@ def test_sgd():
continue
compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype)
# test operator fallback on cpu
- if (default_context() == mx.cpu()):
- compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype,
- g_stype='row_sparse')
- if dtype != np.float16:
- compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape[:2],
- dtype, w_stype='csr', g_stype='csr')
- # test optimizer with a big shape
- big_shape = (54686454, 1)
- kwarg = {'momentum': 0.9, 'wd': 0.05}
- compare_optimizer(opt1(**kwarg), opt2(**kwarg), big_shape, np.float32)
+ if dtype != np.float16:
+ compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape[:2],
+ dtype, w_stype='csr', g_stype='csr')
class PySparseSGD(mx.optimizer.Optimizer):
"""python reference implemenation of sgd"""
@@ -337,9 +329,11 @@ def test_sparse_sgd():
kwarg.update(mp_option)
compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype,
w_stype='row_sparse', g_stype='row_sparse')
+ compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype,
+ w_stype='default', g_stype='row_sparse')
-@with_seed(0)
+@with_seed()
def test_std_sparse_sgd():
opt1 = PySGD
opt2 = mx.optimizer.SGD
@@ -360,6 +354,8 @@ def test_std_sparse_sgd():
kwarg.update(wd_option)
compare_optimizer(opt1(**kwarg), opt2(lazy_update=False, **kwarg), shape, dtype,
w_stype='row_sparse', g_stype='row_sparse')
+ compare_optimizer(opt1(**kwarg), opt2(lazy_update=False, **kwarg), shape, dtype,
+ w_stype='default', g_stype='row_sparse')
class PyNAG(PySGD):
@@ -543,7 +539,7 @@ def test_ftml():
class PyAdam(mx.optimizer.Optimizer):
"""python reference implemenation of adam"""
def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8,
- decay_factor=(1 - 1e-8), lazy_update=False, **kwargs):
+ decay_factor=(1 - 1e-8), lazy_update=True, **kwargs):
super(PyAdam, self).__init__(learning_rate=learning_rate, **kwargs)
self.beta1 = beta1
self.beta2 = beta2
@@ -594,7 +590,7 @@ class PyAdam(mx.optimizer.Optimizer):
for row in range(num_rows):
# check row slices of all zeros
all_zeros = mx.test_utils.almost_equal(grad[row].asnumpy(), np.zeros_like(grad[row].asnumpy()))
- # skip zeros during sparse update
+ # skip zeros during lazy update
if all_zeros and self.lazy_update:
continue
grad[row] = grad[row] * self.rescale_grad + wd * weight[row]
@@ -635,15 +631,21 @@ def test_adam():
not kwarg['multi_precision'])):
continue
# atol 2e-5 needed to pass with seed 1248389097
- compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype,
+ compare_optimizer(opt1(lazy_update=False, **kwarg), opt2(**kwarg), shape, dtype,
rtol=1e-4, atol=2e-5)
# atol 2e-5 needed to pass with seed 781809840
- compare_optimizer(opt1(lazy_update=True, **kwarg), opt2(**kwarg), shape,
+ compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape,
dtype, w_stype='row_sparse', g_stype='row_sparse',
rtol=1e-4, atol=2e-5)
- compare_optimizer(opt1(**kwarg), opt2(lazy_update=False, **kwarg), shape,
+ compare_optimizer(opt1(lazy_update=False, **kwarg), opt2(lazy_update=False, **kwarg), shape,
dtype, w_stype='row_sparse', g_stype='row_sparse',
rtol=1e-4, atol=2e-5)
+ compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape,
+ dtype, w_stype='default', g_stype='row_sparse',
+ rtol=1e-4, atol=2e-5)
+ compare_optimizer(opt1(lazy_update=False, **kwarg), opt2(lazy_update=False, **kwarg), shape,
+ dtype, w_stype='default', g_stype='row_sparse',
+ rtol=1e-4, atol=2e-5)
# Signum
class PySignum(mx.optimizer.Optimizer):
--
To stop receiving notification emails like this one, please contact
haibin@apache.org.