You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ha...@apache.org on 2018/01/18 04:13:35 UTC

[incubator-mxnet] branch master updated: standard adam update for sparse tensor (#9468)

This is an automated email from the ASF dual-hosted git repository.

haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 9ae63d1  standard adam update for sparse tensor (#9468)
9ae63d1 is described below

commit 9ae63d16b600b52d7ec6c83cdffe427a27583fee
Author: Ziyue Huang <zy...@gmail.com>
AuthorDate: Thu Jan 18 12:13:26 2018 +0800

    standard adam update for sparse tensor (#9468)
---
 python/mxnet/optimizer.py               | 31 +++++++------
 src/operator/optimizer_op-inl.h         | 77 ++++++++++++++++++++++++++++++---
 src/operator/optimizer_op.cc            | 62 +++++++++++++++++++++++++-
 src/operator/optimizer_op.cu            | 68 +++++++++++++++++++++++++++++
 tests/python/unittest/test_optimizer.py |  6 +--
 5 files changed, 222 insertions(+), 22 deletions(-)

diff --git a/python/mxnet/optimizer.py b/python/mxnet/optimizer.py
index 4285aec..57340be 100644
--- a/python/mxnet/optimizer.py
+++ b/python/mxnet/optimizer.py
@@ -780,15 +780,8 @@ class Adam(Optimizer):
     This class implements the optimizer described in *Adam: A Method for
     Stochastic Optimization*, available at http://arxiv.org/abs/1412.6980.
 
-    The optimizer updates the weight by::
-
-        rescaled_grad = clip(grad * rescale_grad + wd * weight, clip_gradient)
-        m = beta1 * m + (1 - beta1) * rescaled_grad
-        v = beta2 * v + (1 - beta2) * (rescaled_grad**2)
-        w = w - learning_rate * m / (sqrt(v) + epsilon)
-
-    If the storage types of weight, state and grad are all ``row_sparse``, \
-    **sparse updates** are applied by::
+    If the storage types of weight and grad are both ``row_sparse``, and ``lazy_update`` is True, \
+    **lazy updates** are applied by::
 
         for row in grad.indices:
             rescaled_grad[row] = clip(grad[row] * rescale_grad + wd * weight[row], clip_gradient)
@@ -796,12 +789,19 @@ class Adam(Optimizer):
             v[row] = beta2 * v[row] + (1 - beta2) * (rescaled_grad[row]**2)
             w[row] = w[row] - learning_rate * m[row] / (sqrt(v[row]) + epsilon)
 
-    The sparse update only updates the mean and var for the weights whose row_sparse
+    The lazy update only updates the mean and var for the weights whose row_sparse
     gradient indices appear in the current batch, rather than updating it for all indices.
     Compared with the original update, it can provide large improvements in model training
     throughput for some applications. However, it provides slightly different semantics than
     the original update, and may lead to different empirical results.
 
+    Otherwise, **standard updates** are applied by::
+
+        rescaled_grad = clip(grad * rescale_grad + wd * weight, clip_gradient)
+        m = beta1 * m + (1 - beta1) * rescaled_grad
+        v = beta2 * v + (1 - beta2) * (rescaled_grad**2)
+        w = w - learning_rate * m / (sqrt(v) + epsilon)
+
     This optimizer accepts the following parameters in addition to those accepted
     by :class:`.Optimizer`.
 
@@ -815,19 +815,24 @@ class Adam(Optimizer):
         Exponential decay rate for the second moment estimates.
     epsilon : float, optional
         Small value to avoid division by 0.
+    lazy_update : bool, optional
+        Default is True. If True, lazy updates are applied \
+        if the storage types of weight and grad are both ``row_sparse``.
     """
     def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8,
-                 **kwargs):
+                 lazy_update=True, **kwargs):
         super(Adam, self).__init__(learning_rate=learning_rate, **kwargs)
         self.beta1 = beta1
         self.beta2 = beta2
         self.epsilon = epsilon
+        self.lazy_update = lazy_update
 
     def create_state(self, index, weight):
+        stype = weight.stype if self.lazy_update else 'default'
         return (zeros(weight.shape, weight.context, dtype=weight.dtype,
-                      stype=weight.stype),  # mean
+                      stype=stype),  # mean
                 zeros(weight.shape, weight.context, dtype=weight.dtype,
-                      stype=weight.stype))  # variance
+                      stype=stype))  # variance
 
     def update(self, index, weight, grad, state):
         assert(isinstance(weight, NDArray))
diff --git a/src/operator/optimizer_op-inl.h b/src/operator/optimizer_op-inl.h
index 42721a9..0e0dd7f 100644
--- a/src/operator/optimizer_op-inl.h
+++ b/src/operator/optimizer_op-inl.h
@@ -859,6 +859,71 @@ inline void AdamUpdateRspRspRspImpl(const AdamParam& param,
                                var.data(), req, &out_blob);
 }
 
+template<int req>
+struct AdamStdDnsRspDnsKernel {
+  template<typename DType, typename IType, typename RType>
+  MSHADOW_XINLINE static void Map(int i, const nnvm::dim_t row_length, DType* out_data,
+    DType* mean_data, DType* var_data, const DType* weight_data, const IType* grad_idx,
+    const DType* grad_data, const RType* prefix_sum, const DType clip_gradient,
+    const DType beta1, const DType beta2, const DType lr, const DType wd,
+    const DType epsilon, const DType rescale_grad) {
+    using namespace mshadow_op;
+    const bool non_zero = (i == 0) ? prefix_sum[0] > 0
+                                   : prefix_sum[i] > prefix_sum[i-1];
+
+    const index_t row_i = i * row_length;
+    const RType grad_i = (prefix_sum[i]-1) * row_length;
+    for (index_t j = 0; j < row_length; j++) {
+      const index_t data_i = row_i + j;
+      const DType grad_rescaled = non_zero ? static_cast<DType>(
+                                               grad_data[grad_i + j] * rescale_grad +
+                                               weight_data[data_i] * wd)
+                                           : static_cast<DType>(weight_data[data_i] * wd);
+      if (clip_gradient >= 0.0f) {
+        mean_data[data_i] = beta1 * mean_data[data_i] + (1.f - beta1) *
+                            clip::Map(grad_rescaled, clip_gradient);
+        var_data[data_i] =  beta2 * var_data[data_i] + (1.f - beta2) * square::Map(
+                            clip::Map(grad_rescaled, clip_gradient));
+      } else {
+        mean_data[data_i] = beta1 * mean_data[data_i] + (1.f - beta1) * grad_rescaled;
+        var_data[data_i] = beta2 * var_data[data_i] +
+                           (1.f - beta2) * square::Map(grad_rescaled);
+      }
+      KERNEL_ASSIGN(out_data[data_i], req, weight_data[data_i] - lr * mean_data[data_i] /
+                    (square_root::Map(var_data[data_i]) + epsilon));
+    }
+  }
+};
+
+
+template<typename xpu>
+void AdamStdUpdateDnsRspDnsImpl(const AdamParam& param,
+                                const OpContext& ctx,
+                                const TBlob& weight,
+                                const NDArray& grad,
+                                const TBlob& mean,
+                                const TBlob& var,
+                                const OpReqType& req,
+                                TBlob *out);
+
+template<typename xpu>
+inline void AdamStdUpdateRspRspRspImpl(const AdamParam& param,
+                                       const OpContext& ctx,
+                                       const NDArray& weight,
+                                       const NDArray& grad,
+                                       const NDArray& mean,
+                                       const NDArray& var,
+                                       const OpReqType& req,
+                                       NDArray *out) {
+  using namespace mxnet_op;
+  using namespace rowsparse;
+  CHECK_RSP_ALL_ROWS_NON_ZERO(weight, "AdamStdUpdate", "weights");
+  mshadow::Stream<xpu>* s = ctx.get_stream<xpu>();
+  TBlob out_blob = out->data();
+  // reuse dns rsp implementation when storage_shape == shape
+  AdamStdUpdateDnsRspDnsImpl<xpu>(param, ctx, weight.data(), grad, mean.data(),
+                                  var.data(), req, &out_blob);
+}
 
 template<typename xpu>
 inline void AdamUpdateEx(const nnvm::NodeAttrs& attrs,
@@ -868,18 +933,20 @@ inline void AdamUpdateEx(const nnvm::NodeAttrs& attrs,
                          const std::vector<NDArray> &outputs) {
   const AdamParam& param = nnvm::get<AdamParam>(attrs.parsed);
   const auto weight_stype = inputs[0].storage_type();
+  const auto grad_stype = inputs[1].storage_type();
   const auto mean_stype = inputs[2].storage_type();
   const auto var_stype = inputs[3].storage_type();
   const auto out_stype = outputs[0].storage_type();
-  CHECK_EQ(mean_stype, weight_stype) << "Inconsistent storage type detected between "
-           << " mean.stype = " << mean_stype << " and weight.stype = " << weight_stype;
-  CHECK_EQ(var_stype, weight_stype) << "Inconsistent storage type detected between "
-           << " var.stype = " << var_stype << " and weight.stype = " << weight_stype;
+  NDArray out = outputs[0];
   if (common::ContainsOnlyStorage(inputs, kRowSparseStorage) &&
       out_stype == kRowSparseStorage) {
-     NDArray out = outputs[0];
      AdamUpdateRspRspRspImpl<xpu>(param, ctx, inputs[0], inputs[1], inputs[2],
                                   inputs[3], req[0], &out);
+  } else if (weight_stype == kRowSparseStorage && grad_stype == kRowSparseStorage &&
+             mean_stype == kDefaultStorage && var_stype == kDefaultStorage &&
+             out_stype == kRowSparseStorage) {
+     AdamStdUpdateRspRspRspImpl<xpu>(param, ctx, inputs[0], inputs[1], inputs[2],
+                                     inputs[3], req[0], &out);
   } else {
     LOG(FATAL) << "Not implemented: " << operator_string(attrs, ctx, inputs, req, outputs);
   }
diff --git a/src/operator/optimizer_op.cc b/src/operator/optimizer_op.cc
index 8760fe9..136769a 100644
--- a/src/operator/optimizer_op.cc
+++ b/src/operator/optimizer_op.cc
@@ -148,6 +148,62 @@ void SGDMomStdUpdateDnsRspDnsImpl<cpu>(const SGDMomParam& param,
   });
 }
 
+template<>
+void AdamStdUpdateDnsRspDnsImpl<cpu>(const AdamParam& param,
+                                     const OpContext& ctx,
+                                     const TBlob& weight,
+                                     const NDArray& grad,
+                                     const TBlob& mean,
+                                     const TBlob& var,
+                                     const OpReqType& req,
+                                     TBlob *out) {
+  using namespace mxnet_op;
+  using namespace rowsparse;
+  using namespace mshadow;
+  Stream<cpu>* s = ctx.get_stream<cpu>();
+  if (req == kNullOp) return;
+  CHECK_EQ(req, kWriteInplace) << "kWriteInplace is expected for sparse adam_update";
+  CHECK_GT(weight.shape_.Size(), 0);
+  CHECK_GT(mean.shape_.Size(), 0);
+  CHECK_GT(var.shape_.Size(), 0);
+
+  MSHADOW_REAL_TYPE_SWITCH(weight.type_flag_, DType, {
+    MSHADOW_IDX_TYPE_SWITCH(grad.aux_type(kIdx), IType, {
+      MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
+        const DType* weight_data = weight.dptr<DType>();
+        const IType* grad_idx = grad.aux_data(kIdx).dptr<IType>();
+        const DType* grad_val = grad.data().dptr<DType>();
+        DType* mean_data = mean.dptr<DType>();
+        DType* var_data = var.dptr<DType>();
+        DType* out_data = out->dptr<DType>();
+        nnvm::dim_t num_rows = weight.shape_[0];
+        nnvm::dim_t row_length = weight.shape_.ProdShape(1, weight.ndim());
+        Tensor<cpu, 1, char> workspace = ctx.requested[0]
+          .get_space_typed<cpu, 1, char>(Shape1(num_rows * sizeof(nnvm::dim_t)), s);
+
+        nnvm::dim_t* prefix_sum = reinterpret_cast<nnvm::dim_t*>(workspace.dptr_);
+        // mark row flags
+        Kernel<set_zero, cpu>::Launch(s, num_rows, prefix_sum);
+        if (grad.storage_initialized()) {
+          Kernel<MarkRowFlgKernel, cpu>::Launch(s, grad.aux_shape(kIdx)[0],
+            prefix_sum, grad_idx);
+          // calculate inclusive prefix sum
+          for (nnvm::dim_t i = 1; i < num_rows; i++) {
+            prefix_sum[i] += prefix_sum[i - 1];
+          }
+        }
+
+        Kernel<AdamStdDnsRspDnsKernel<req_type>, cpu>::Launch(s, num_rows, row_length,
+          out_data, mean_data, var_data, weight_data, grad_idx, grad_val, prefix_sum,
+          static_cast<DType>(param.clip_gradient), static_cast<DType>(param.beta1),
+          static_cast<DType>(param.beta2), static_cast<DType>(param.lr),
+          static_cast<DType>(param.wd), static_cast<DType>(param.epsilon),
+          static_cast<DType>(param.rescale_grad));
+      });
+    });
+  });
+}
+
 
 NNVM_REGISTER_OP(sgd_update)
 MXNET_ADD_SPARSE_OP_ALIAS(sgd_update)
@@ -329,8 +385,12 @@ only the row slices whose indices appear in grad.indices are updated (for w, m a
 .set_num_outputs(1)
 .set_attr_parser(ParamParser<AdamParam>)
 .set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<4, 1>)
+.set_attr<FResourceRequest>("FResourceRequest",
+  [](const NodeAttrs& attrs) {
+    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+  })
 .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<4, 1>)
-.set_attr<FInferStorageType>("FInferStorageType", ElemwiseStorageType<4, 1, false, true, false>)
+.set_attr<FInferStorageType>("FInferStorageType", StdOptStorageType<2, 2>)
 .set_attr<nnvm::FMutateInputs>("FMutateInputs",
   [](const nnvm::NodeAttrs& attrs) {
     return std::vector<uint32_t>{2, 3};
diff --git a/src/operator/optimizer_op.cu b/src/operator/optimizer_op.cu
index 891f24f..1bd6117 100644
--- a/src/operator/optimizer_op.cu
+++ b/src/operator/optimizer_op.cu
@@ -94,6 +94,74 @@ void SGDMomStdUpdateDnsRspDnsImpl<gpu>(const SGDMomParam& param,
   });
 }
 
+template<>
+void AdamStdUpdateDnsRspDnsImpl<gpu>(const AdamParam& param,
+                                     const OpContext& ctx,
+                                     const TBlob& weight,
+                                     const NDArray& grad,
+                                     const TBlob& mean,
+                                     const TBlob& var,
+                                     const OpReqType& req,
+                                     TBlob *out) {
+  using namespace mxnet_op;
+  using namespace rowsparse;
+  using namespace mshadow;
+  Stream<gpu>* s = ctx.get_stream<gpu>();
+  if (req == kNullOp) return;
+  CHECK_EQ(req, kWriteInplace) << "kWriteInplace is expected for sparse adam_update";
+  CHECK_GT(weight.shape_.Size(), 0);
+  CHECK_GT(mean.shape_.Size(), 0);
+  CHECK_GT(var.shape_.Size(), 0);
+
+  MSHADOW_REAL_TYPE_SWITCH(weight.type_flag_, DType, {
+    MSHADOW_IDX_TYPE_SWITCH(grad.aux_type(kIdx), IType, {
+      MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
+        const DType* weight_data = weight.dptr<DType>();
+        const IType* grad_idx = grad.aux_data(kIdx).dptr<IType>();
+        const DType* grad_val = grad.data().dptr<DType>();
+        DType* mean_data = mean.dptr<DType>();
+        DType* var_data = var.dptr<DType>();
+        DType* out_data = out->dptr<DType>();
+        nnvm::dim_t num_rows = weight.shape_[0];
+        nnvm::dim_t row_length = weight.shape_.ProdShape(1, weight.ndim());
+        nnvm::dim_t* prefix_sum = NULL;
+        void* d_temp_storage = NULL;
+        size_t temp_storage_bytes = 0;
+        cub::DeviceScan::InclusiveSum(d_temp_storage,
+                                      temp_storage_bytes,
+                                      prefix_sum,
+                                      prefix_sum,
+                                      num_rows,
+                                      Stream<gpu>::GetStream(s));
+        Tensor<gpu, 1, char> workspace = ctx.requested[0]
+          .get_space_typed<gpu, 1, char>(Shape1(num_rows * sizeof(nnvm::dim_t) +
+                                         temp_storage_bytes), s);
+        prefix_sum = reinterpret_cast<nnvm::dim_t*>(workspace.dptr_);
+        d_temp_storage = workspace.dptr_ + num_rows*sizeof(nnvm::dim_t);
+        // mark row flags
+        Fill<false>(s, TBlob(prefix_sum, Shape1(num_rows), gpu::kDevMask), kWriteTo, 0);
+        if (grad.storage_initialized()) {
+          Kernel<MarkRowFlgKernel, gpu>::Launch(s, grad.aux_shape(kIdx)[0],
+            prefix_sum, grad_idx);
+          // calculate inclusive prefix sum
+          cub::DeviceScan::InclusiveSum(d_temp_storage,
+                                        temp_storage_bytes,
+                                        prefix_sum,
+                                        prefix_sum,
+                                        num_rows,
+                                        Stream<gpu>::GetStream(s));
+        }
+
+        Kernel<AdamStdDnsRspDnsKernel<req_type>, gpu>::Launch(s, num_rows, row_length,
+          out_data, mean_data, var_data, weight_data, grad_idx, grad_val, prefix_sum,
+          static_cast<DType>(param.clip_gradient), static_cast<DType>(param.beta1),
+          static_cast<DType>(param.beta2), static_cast<DType>(param.lr),
+          static_cast<DType>(param.wd), static_cast<DType>(param.epsilon),
+          static_cast<DType>(param.rescale_grad));
+      });
+    });
+  });
+}
 
 NNVM_REGISTER_OP(signsgd_update)
 .set_attr<FCompute>("FCompute<gpu>", SignSGDUpdate<gpu>);
diff --git a/tests/python/unittest/test_optimizer.py b/tests/python/unittest/test_optimizer.py
index 2d22391..26ff48b 100644
--- a/tests/python/unittest/test_optimizer.py
+++ b/tests/python/unittest/test_optimizer.py
@@ -520,10 +520,10 @@ def test_adam():
                                     not kwarg['multi_precision'])):
                             continue
                         compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype)
-                        if (default_context() == mx.cpu()):
-                            compare_optimizer(opt1(sparse_update=True, **kwarg), opt2(**kwarg), shape,
+                        compare_optimizer(opt1(sparse_update=True, **kwarg), opt2(**kwarg), shape,
+                                          dtype, w_stype='row_sparse', g_stype='row_sparse')
+                        compare_optimizer(opt1(**kwarg), opt2(lazy_update=False, **kwarg), shape,
                                           dtype, w_stype='row_sparse', g_stype='row_sparse')
-
 
 # Signum
 class PySignum(mx.optimizer.Optimizer):

-- 
To stop receiving notification emails like this one, please contact
commits@mxnet.apache.org.