Posted to commits@mxnet.apache.org by ha...@apache.org on 2018/06/25 05:51:52 UTC

[incubator-mxnet] branch master updated: Enable support for dense weight and sparse grad Adagrad updates (#11355)

This is an automated email from the ASF dual-hosted git repository.

haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 9b27262  Enable support for dense weight and sparse grad Adagrad updates (#11355)
9b27262 is described below

commit 9b27262580b6813b4e4fc0ee1bd66a3d4bce005d
Author: Leonard Lausen <le...@lausen.nl>
AuthorDate: Mon Jun 25 05:51:46 2018 +0000

    Enable support for dense weight and sparse grad Adagrad updates (#11355)
    
    * Support dense weight and sparse grad AdagradUpdate
    
    * Simplify AdagradStorageType
    
    * Add test
---
 python/mxnet/optimizer.py               |  2 +-
 src/operator/optimizer_op-inl.h         | 34 +++++++++++++++++++++++++--------
 tests/python/unittest/test_optimizer.py |  2 ++
 3 files changed, 29 insertions(+), 9 deletions(-)
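
For orientation, here is a minimal usage sketch of the combination this commit enables: a dense weight updated with a row_sparse gradient through the AdaGrad optimizer. This is illustrative only and not part of the patch (shapes, values, and the parameter index are made up; wd is set to 0.0 because the sparse kernel requires it, see AdagradStorageType below).

    import mxnet as mx

    shape = (4, 2)
    weight = mx.nd.ones(shape)                      # dense ('default' stype) weight
    grad = mx.nd.ones(shape).tostype('row_sparse')  # row_sparse gradient

    opt = mx.optimizer.AdaGrad(learning_rate=0.1, wd=0.0)
    state = opt.create_state(0, weight)             # AdaGrad history, same stype as the weight
    opt.update(0, weight, grad, state)              # previously required a row_sparse weight too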

diff --git a/python/mxnet/optimizer.py b/python/mxnet/optimizer.py
index 0c3fc90..267a402 100644
--- a/python/mxnet/optimizer.py
+++ b/python/mxnet/optimizer.py
@@ -1107,7 +1107,7 @@ class AdaGrad(Optimizer):
         lr = self._get_lr(index)
         wd = self._get_wd(index)
 
-        is_sparse = weight.stype == 'row_sparse' and grad.stype == 'row_sparse'
+        is_sparse = grad.stype == 'row_sparse'
         history = state
 
         if is_sparse:
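
The effect of the one-line change above, restated as a standalone sketch (my own paraphrase, not code from the repo): the sparse update path is now selected from the gradient's storage type alone, so a dense weight paired with a row_sparse gradient also takes it.

    def takes_sparse_path(weight_stype, grad_stype):
        # old condition: weight_stype == 'row_sparse' and grad_stype == 'row_sparse'
        return grad_stype == 'row_sparse'

    assert takes_sparse_path('row_sparse', 'row_sparse')   # already supported
    assert takes_sparse_path('default', 'row_sparse')      # newly enabled by this commit
    assert not takes_sparse_path('default', 'default')     # falls through to the dense update
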
diff --git a/src/operator/optimizer_op-inl.h b/src/operator/optimizer_op-inl.h
index 28b382c..9251b86 100644
--- a/src/operator/optimizer_op-inl.h
+++ b/src/operator/optimizer_op-inl.h
@@ -1663,16 +1663,20 @@ inline bool AdagradStorageType(const nnvm::NodeAttrs& attrs,
                                DispatchMode* dispatch_mode,
                                std::vector<int>* in_attrs,
                                std::vector<int>* out_attrs) {
+  const AdagradParam& param = nnvm::get<AdagradParam>(attrs.parsed);
   CHECK_EQ(in_attrs->size(), 3U);
   CHECK_EQ(out_attrs->size(), 1U);
-  const AdagradParam& param = nnvm::get<AdagradParam>(attrs.parsed);
+  const int weight_stype = in_attrs->at(0);
+  const int grad_stype = in_attrs->at(1);
+  const int state_stype = in_attrs->at(2);
   bool dispatched = false;
-  if (!dispatched && common::ContainsOnlyStorage(*in_attrs, kRowSparseStorage) &&
-      common::ContainsOnlyStorage(*in_attrs, kRowSparseStorage) &&
-      param.wd == 0.0f) {
-    // rsp, rsp, rsp -> rsp with wd = 0.0
-    dispatched = storage_type_assign(out_attrs, kRowSparseStorage,
-                                     dispatch_mode, DispatchMode::kFComputeEx);
+  if (!dispatched && grad_stype == kRowSparseStorage &&
+      (weight_stype == kRowSparseStorage || weight_stype == kDefaultStorage) &&
+      state_stype == weight_stype && param.wd == 0.0f) {
+    // weight and state share stype, grad's stype = rsp
+    dispatched = storage_type_assign(
+        out_attrs, static_cast<NDArrayStorageType>(weight_stype), dispatch_mode,
+        DispatchMode::kFComputeEx);
   }
   return dispatched;
 }
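
The new storage-type rule above, restated as a standalone Python sketch (storage types written as plain strings; a paraphrase of the C++, not code from the repo): the gradient must be row_sparse, weight and state must share a storage type (either row_sparse or dense/default), and weight decay must be zero; otherwise no FComputeEx dispatch is assigned.

    ROW_SPARSE, DEFAULT = 'row_sparse', 'default'

    def dispatches_fcompute_ex(weight_stype, grad_stype, state_stype, wd):
        return (grad_stype == ROW_SPARSE
                and weight_stype in (ROW_SPARSE, DEFAULT)
                and state_stype == weight_stype
                and wd == 0.0)

    assert dispatches_fcompute_ex(ROW_SPARSE, ROW_SPARSE, ROW_SPARSE, 0.0)   # rsp/rsp/rsp (existing path)
    assert dispatches_fcompute_ex(DEFAULT, ROW_SPARSE, DEFAULT, 0.0)         # dns/rsp/dns (new path)
    assert not dispatches_fcompute_ex(DEFAULT, ROW_SPARSE, ROW_SPARSE, 0.0)  # state must match weight
    assert not dispatches_fcompute_ex(DEFAULT, ROW_SPARSE, DEFAULT, 0.01)    # wd must be 0
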
@@ -1802,10 +1806,24 @@ inline void AdagradUpdateEx(const nnvm::NodeAttrs& attrs,
                             const std::vector<NDArray> &outputs) {
   using namespace mxnet_op;
   const AdagradParam& param = nnvm::get<AdagradParam>(attrs.parsed);
+
+  const auto weight_stype = inputs[0].storage_type();
+  const auto grad_stype = inputs[1].storage_type();
+  const auto state_stype = inputs[2].storage_type();
+  const auto output_stype = outputs[0].storage_type();
+
   if (common::ContainsOnlyStorage(inputs, kRowSparseStorage) &&
       common::ContainsOnlyStorage(outputs, kRowSparseStorage)) {
     NDArray out = outputs[0];
-    AdagradUpdateRspRspRspImpl<xpu>(param, ctx, inputs[0], inputs[1], inputs[2], req[0], &out);
+    AdagradUpdateRspRspRspImpl<xpu>(param, ctx, inputs[0], inputs[1], inputs[2],
+                                    req[0], &out);
+  } else if (state_stype == weight_stype && output_stype == weight_stype &&
+             weight_stype == kDefaultStorage &&
+             grad_stype == kRowSparseStorage) {
+    TBlob out_blob = outputs[0].data();
+    AdagradUpdateDnsRspDnsImpl<xpu>(param, ctx, inputs[0].data(), inputs[1],
+                                    inputs[2].data(), req[0],
+                                    &out_blob);
   } else {
     LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
   }
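
For intuition, a numpy sketch of what the new dense-weight / row_sparse-grad branch computes: only the rows present in the gradient are touched, and the history accumulates the squared rescaled gradient per element. This is my own illustration of the standard AdaGrad rule; details such as gradient clipping and the exact placement of epsilon follow the kernel, not this sketch.

    import numpy as np

    def adagrad_dns_rsp_dns(weight, history, grad_rows, grad_values,
                            lr, epsilon, rescale_grad=1.0):
        """In-place AdaGrad step on a dense weight from a row-sparse gradient."""
        for row, g in zip(grad_rows, grad_values):
            g = g * rescale_grad
            history[row] += g * g                                  # accumulate squared gradient
            weight[row] -= lr * g / (np.sqrt(history[row]) + epsilon)

    w = np.ones((4, 2), dtype=np.float32)
    h = np.zeros_like(w)
    adagrad_dns_rsp_dns(w, h, grad_rows=[0, 2],
                        grad_values=np.ones((2, 2), dtype=np.float32),
                        lr=0.1, epsilon=1e-7)
    # rows 1 and 3 of w are unchanged
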
diff --git a/tests/python/unittest/test_optimizer.py b/tests/python/unittest/test_optimizer.py
index fba10fb..a5b3d40 100644
--- a/tests/python/unittest/test_optimizer.py
+++ b/tests/python/unittest/test_optimizer.py
@@ -1034,6 +1034,8 @@ def test_adagrad():
                         if wd_option.get('wd', 0.0) == 0.0:
                             compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype,
                                               w_stype='row_sparse', g_stype='row_sparse')
+                            compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype,
+                                              g_stype='row_sparse')