Posted to commits@mxnet.apache.org by ha...@apache.org on 2018/06/25 05:51:52 UTC
[incubator-mxnet] branch master updated: Enable support for dense weight and sparse grad Adagrad updates (#11355)
This is an automated email from the ASF dual-hosted git repository.
haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 9b27262 Enable support for dense weight and sparse grad Adagrad updates (#11355)
9b27262 is described below
commit 9b27262580b6813b4e4fc0ee1bd66a3d4bce005d
Author: Leonard Lausen <le...@lausen.nl>
AuthorDate: Mon Jun 25 05:51:46 2018 +0000
Enable support for dense weight and sparse grad Adagrad updates (#11355)
* Support dense weight and sparse grad AdagradUpdate
* Simplify AdagradStorageType
* Add test
---
python/mxnet/optimizer.py | 2 +-
src/operator/optimizer_op-inl.h | 34 +++++++++++++++++++++++++--------
tests/python/unittest/test_optimizer.py | 2 ++
3 files changed, 29 insertions(+), 9 deletions(-)
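For context, the typical source of this storage combination is a layer such as Embedding with sparse_grad=True, which keeps its weight in default (dense) storage but emits a row_sparse gradient in the backward pass. A minimal sketch of that setup (the symbol names and dimensions below are illustrative assumptions, not taken from the commit):

    import mxnet as mx

    # Hypothetical setup: a dense weight whose gradient is row_sparse.
    data = mx.sym.Variable('data')
    weight = mx.sym.Variable('weight', stype='default')
    # sparse_grad=True makes the backward pass produce a row_sparse
    # gradient while the weight itself stays dense.
    emb = mx.sym.Embedding(data=data, weight=weight, input_dim=10000,
                           output_dim=128, sparse_grad=True)

With this commit, AdaGrad can apply its sparse (lazy) update directly to such a parameter instead of requiring the weight itself to be row_sparse.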
diff --git a/python/mxnet/optimizer.py b/python/mxnet/optimizer.py
index 0c3fc90..267a402 100644
--- a/python/mxnet/optimizer.py
+++ b/python/mxnet/optimizer.py
@@ -1107,7 +1107,7 @@ class AdaGrad(Optimizer):
lr = self._get_lr(index)
wd = self._get_wd(index)
- is_sparse = weight.stype == 'row_sparse' and grad.stype == 'row_sparse'
+ is_sparse = grad.stype == 'row_sparse'
history = state
if is_sparse:
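The Python-side change above relaxes the dispatch condition: the sparse code path is now taken whenever the gradient is row_sparse, regardless of the weight's storage type. A hedged sketch of what the optimizer ends up calling in the new dense-weight case (shapes and values are illustrative; wd must be 0 for this path, per the storage-type check below):

    import mxnet as mx

    weight = mx.nd.random.uniform(shape=(5, 4))   # dense weight
    grad = mx.nd.zeros((5, 4))
    grad[1] = 0.5
    grad = grad.tostype('row_sparse')             # row_sparse gradient
    history = mx.nd.zeros((5, 4))                 # dense AdaGrad state

    # Previously this combination fell through to the unimplemented path;
    # after this commit it dispatches to the dense-weight sparse-grad kernel.
    mx.nd.sparse.adagrad_update(weight, grad, history, lr=0.1, wd=0.0,
                                out=weight)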
diff --git a/src/operator/optimizer_op-inl.h b/src/operator/optimizer_op-inl.h
index 28b382c..9251b86 100644
--- a/src/operator/optimizer_op-inl.h
+++ b/src/operator/optimizer_op-inl.h
@@ -1663,16 +1663,20 @@ inline bool AdagradStorageType(const nnvm::NodeAttrs& attrs,
DispatchMode* dispatch_mode,
std::vector<int>* in_attrs,
std::vector<int>* out_attrs) {
+ const AdagradParam& param = nnvm::get<AdagradParam>(attrs.parsed);
CHECK_EQ(in_attrs->size(), 3U);
CHECK_EQ(out_attrs->size(), 1U);
- const AdagradParam& param = nnvm::get<AdagradParam>(attrs.parsed);
+ const int weight_stype = in_attrs->at(0);
+ const int grad_stype = in_attrs->at(1);
+ const int state_stype = in_attrs->at(2);
bool dispatched = false;
- if (!dispatched && common::ContainsOnlyStorage(*in_attrs, kRowSparseStorage) &&
- common::ContainsOnlyStorage(*in_attrs, kRowSparseStorage) &&
- param.wd == 0.0f) {
- // rsp, rsp, rsp -> rsp with wd = 0.0
- dispatched = storage_type_assign(out_attrs, kRowSparseStorage,
- dispatch_mode, DispatchMode::kFComputeEx);
+ if (!dispatched && grad_stype == kRowSparseStorage &&
+ (weight_stype == kRowSparseStorage || weight_stype == kDefaultStorage) &&
+ state_stype == weight_stype && param.wd == 0.0f) {
+ // weight and state share stype, grad's stype = rsp
+ dispatched = storage_type_assign(
+ out_attrs, static_cast<NDArrayStorageType>(weight_stype), dispatch_mode,
+ DispatchMode::kFComputeEx);
}
return dispatched;
}
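The rewritten storage-type function accepts two input patterns instead of one, with the output stype following the weight's stype: (rsp, rsp, rsp) dispatches to FComputeEx with a row_sparse output, and (default, rsp, default) dispatches to FComputeEx with a default output; any other combination, or wd != 0, is not dispatched. A sketch of how this can be observed from Python (assuming the 1.x NDArray API, where tostype and .stype behave as shown):

    import mxnet as mx

    g = mx.nd.zeros((4, 2))
    g[1] = 0.5
    g = g.tostype('row_sparse')

    # weight/state row_sparse -> row_sparse output
    w_rsp = mx.nd.ones((4, 2)).tostype('row_sparse')
    h_rsp = mx.nd.zeros((4, 2)).tostype('row_sparse')
    print(mx.nd.sparse.adagrad_update(w_rsp, g, h_rsp, lr=0.1).stype)  # row_sparse

    # weight/state dense -> dense output (new with this commit)
    w_dns = mx.nd.ones((4, 2))
    h_dns = mx.nd.zeros((4, 2))
    print(mx.nd.sparse.adagrad_update(w_dns, g, h_dns, lr=0.1).stype)  # default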
@@ -1802,10 +1806,24 @@ inline void AdagradUpdateEx(const nnvm::NodeAttrs& attrs,
const std::vector<NDArray> &outputs) {
using namespace mxnet_op;
const AdagradParam& param = nnvm::get<AdagradParam>(attrs.parsed);
+
+ const auto weight_stype = inputs[0].storage_type();
+ const auto grad_stype = inputs[1].storage_type();
+ const auto state_stype = inputs[2].storage_type();
+ const auto output_stype = outputs[0].storage_type();
+
if (common::ContainsOnlyStorage(inputs, kRowSparseStorage) &&
common::ContainsOnlyStorage(outputs, kRowSparseStorage)) {
NDArray out = outputs[0];
- AdagradUpdateRspRspRspImpl<xpu>(param, ctx, inputs[0], inputs[1], inputs[2], req[0], &out);
+ AdagradUpdateRspRspRspImpl<xpu>(param, ctx, inputs[0], inputs[1], inputs[2],
+ req[0], &out);
+ } else if (state_stype == weight_stype && output_stype == weight_stype &&
+ weight_stype == kDefaultStorage &&
+ grad_stype == kRowSparseStorage) {
+ TBlob out_blob = outputs[0].data();
+ AdagradUpdateDnsRspDnsImpl<xpu>(param, ctx, inputs[0].data(), inputs[1],
+ inputs[2].data(), req[0],
+ &out_blob);
} else {
LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
}
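The new branch routes to AdagradUpdateDnsRspDnsImpl, which applies a lazy update: only rows present in the row_sparse gradient are touched, which is what makes the dense-weight path cheap for large embedding matrices. A small hedged check of that behaviour (the numerics follow the usual AdaGrad rule, state += g^2 and w -= lr * g / (sqrt(state) + eps), assuming wd = 0):

    import mxnet as mx

    w = mx.nd.ones((3, 2))
    g = mx.nd.zeros((3, 2))
    g[1] = 0.5
    g = g.tostype('row_sparse')   # only row 1 is stored
    h = mx.nd.zeros((3, 2))

    mx.nd.sparse.adagrad_update(w, g, h, lr=0.1, out=w)
    print(w)  # rows 0 and 2 are still all ones; only row 1 changed
    print(h)  # state accumulated only for row 1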
diff --git a/tests/python/unittest/test_optimizer.py b/tests/python/unittest/test_optimizer.py
index fba10fb..a5b3d40 100644
--- a/tests/python/unittest/test_optimizer.py
+++ b/tests/python/unittest/test_optimizer.py
@@ -1034,6 +1034,8 @@ def test_adagrad():
if wd_option.get('wd', 0.0) == 0.0:
compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype,
w_stype='row_sparse', g_stype='row_sparse')
+ compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype,
+ g_stype='row_sparse')
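The added test exercises the new combination through compare_optimizer, leaving w_stype at its dense default while forcing g_stype='row_sparse'. An equivalent standalone parity check against a NumPy reference might look like the following sketch (the lr/epsilon values and helper-free structure are illustrative, not taken from the test suite):

    import mxnet as mx
    import numpy as np

    lr, eps = 0.1, 1e-7
    w = mx.nd.random.uniform(shape=(4, 3))
    g = mx.nd.zeros((4, 3))
    g[2] = 0.3
    g = g.tostype('row_sparse')
    h = mx.nd.zeros((4, 3))

    # NumPy reference for the single touched row.
    w_ref, h_ref = w.asnumpy(), h.asnumpy()
    h_ref[2] += 0.3 ** 2
    w_ref[2] -= lr * 0.3 / (np.sqrt(h_ref[2]) + eps)

    mx.nd.sparse.adagrad_update(w, g, h, lr=lr, epsilon=eps, wd=0.0, out=w)
    assert np.allclose(w.asnumpy(), w_ref)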