You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by cj...@apache.org on 2017/11/05 15:36:47 UTC

[incubator-mxnet] 03/04: Misc fixes for sparse distributed training (#8345)

This is an automated email from the ASF dual-hosted git repository.

cjolivier01 pushed a commit to branch v0.12.0
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git

commit 0d3c77f343c7fc667cb92580e13fdfcc6c25b650
Author: Haibin Lin <li...@gmail.com>
AuthorDate: Sat Oct 21 00:23:20 2017 -0700

    Misc fixes for sparse distributed training (#8345)
    
    * remove mshadow::range in init_op.h
    
    * add unit test
    
    * remove pass by ptr, add unit test for pull empty weights
    
    * fix range in key partition
    
    * remove wrong comment
    
    * remove change for partition
    
    * remove unused var
    
    * add int64 to arange. add checkpointing example
---
 example/sparse/linear_classification.py |  7 ++++++
 src/kvstore/kvstore_dist.h              | 41 +++++++++++++++------------------
 src/operator/tensor/init_op.h           | 19 +++++++--------
 tests/nightly/dist_sync_kvstore.py      | 31 ++++++++++++++-----------
 tests/python/unittest/test_ndarray.py   |  2 ++
 tests/python/unittest/test_optimizer.py |  4 ++++
 6 files changed, 58 insertions(+), 46 deletions(-)

diff --git a/example/sparse/linear_classification.py b/example/sparse/linear_classification.py
index b173d04..70f8963 100644
--- a/example/sparse/linear_classification.py
+++ b/example/sparse/linear_classification.py
@@ -96,6 +96,7 @@ if __name__ == '__main__':
     # get the sparse weight parameter
     weight_index = mod._exec_group.param_names.index('weight')
     weight_param = mod._exec_group.param_arrays[weight_index]
+    all_row_ids = mx.nd.arange(0, num_features, dtype='int64')
     speedometer = mx.callback.Speedometer(batch_size, 100)
 
     logging.info('Training started ...')
@@ -118,9 +119,15 @@ if __name__ == '__main__':
             speedometer_param = mx.model.BatchEndParam(epoch=epoch, nbatch=nbatch,
                                                        eval_metric=metric, locals=locals())
             speedometer(speedometer_param)
+        # pull all rows before making a checkpoint
+        if kv:
+            kv.row_sparse_pull('weight', weight_param, row_ids=[all_row_ids],
+                               priority=-weight_index)
         # evaluate metric on validation dataset
         score = mod.score(eval_data, ['nll_loss'])
         logging.info('epoch %d, eval nll = %s ' % (epoch, score[0][1]))
+        save_optimizer_states = 'dist' not in kv.type
+        mod.save_checkpoint("checkpoint", epoch, save_optimizer_states=False)
         # reset the iterator for next pass of data
         data_iter.reset()
     logging.info('Training completed.')
diff --git a/src/kvstore/kvstore_dist.h b/src/kvstore/kvstore_dist.h
index 2d5e52f..5e62be8 100644
--- a/src/kvstore/kvstore_dist.h
+++ b/src/kvstore/kvstore_dist.h
@@ -42,10 +42,6 @@ namespace kvstore {
 /**
  * \brief distributed kvstore
  *
- * for a worker node, it always guarantees that all push and pull issued from
- * this worker on the same key are serialized. namely push(3) and then pull(3),
- * then the data pulled is always containing the modification from the push(3).
- *
  * it's the server node's job to control the data consistency among all
  * workers. see details on \ref ServerHandle::Start
  */
@@ -248,7 +244,7 @@ class KVStoreDist : public KVStoreLocal {
         LOG(FATAL) << "RowSparsePull with multiple values is not implemented yet";
       } else {
         auto& indices = target_val_rowids[0].second;
-        PullRowSparse_(key, &recv_buf, indices, priority);
+        PullRowSparse_(key, recv_buf, indices, priority);
         comm_->BroadcastRowSparse(key, recv_buf, grouped_val_rowid, num_vals == 1, priority);
       }
     }
@@ -322,24 +318,24 @@ class KVStoreDist : public KVStoreLocal {
   }
 
   // pull row sparse weight into `recv_buf` based on indices given by `indices`
-  void PullRowSparse_(const int key, NDArray *recv_buf, const NDArray& indices, int priority) {
+  void PullRowSparse_(const int key, const NDArray& recv_buf,
+                      const NDArray& indices, int priority) {
     using namespace rowsparse;
     auto pull_from_servers = [this, key, recv_buf, indices]
                              (RunContext rctx, Engine::CallbackOnComplete cb) {
       // allocate memory for the buffer
       size_t num_rows = indices.shape().Size();
-      recv_buf->CheckAndAlloc({mshadow::Shape1(num_rows)});
+      recv_buf.CheckAndAlloc({mshadow::Shape1(num_rows)});
 #if MKL_EXPERIMENTAL == 1
-      mkl_set_tblob_eager_mode(recv_buf->data());
+      mkl_set_tblob_eager_mode(recv_buf.data());
 #endif
-      real_t* data = recv_buf->data().dptr<real_t>();
-      auto indices_data = indices.data();
-      const auto offsets = indices_data.dptr<int64_t>();
-      const auto unit_len = recv_buf->shape().ProdShape(1, recv_buf->shape().ndim());
+      real_t* data = recv_buf.data().dptr<real_t>();
+      const auto offsets = indices.data().dptr<int64_t>();
+      const auto unit_len = recv_buf.shape().ProdShape(1, recv_buf.shape().ndim());
       const int64_t size = num_rows * unit_len;
        // convert to ps keys in row sparse format
       PSKV& pskv = EncodeRowSparseKey(key, size, num_rows, offsets,
-                                      unit_len, recv_buf->shape()[0]);
+                                      unit_len, recv_buf.shape()[0]);
       if (this->log_verbose_) {
         LOG(INFO) << "worker " << get_rank() << " pull lens: " << pskv.lens << " keys: "
                   << pskv.keys << " size: " << size;
@@ -348,8 +344,8 @@ class KVStoreDist : public KVStoreLocal {
       // copy indices to recv_buf. this needs to be done before ZPull
       // because after pull is done, the callback function returns and locks are released.
       // at this point, later functions may access the indices variable while copy happens
-      mshadow::Copy(recv_buf->aux_data(kIdx).FlatTo1D<cpu, int64_t>(),
-                    indices_data.FlatTo1D<cpu, int64_t>());
+      mshadow::Copy(recv_buf.aux_data(kIdx).FlatTo1D<cpu, int64_t>(),
+                    indices.data().FlatTo1D<cpu, int64_t>());
       CHECK_NOTNULL(ps_worker_)->ZPull(pskv.keys, vals, &pskv.lens, kRowSparsePushPull,
         [vals, cb]() { delete vals; cb(); });
     };
@@ -357,7 +353,7 @@ class KVStoreDist : public KVStoreLocal {
         pull_from_servers,
         pinned_ctx_,
         {indices.var()},
-        {recv_buf->var()},
+        {recv_buf.var()},
         FnProperty::kNormal,
         priority,
         PROFILER_MESSAGE("KVStoreDistRowSparsePull"));
@@ -366,15 +362,14 @@ class KVStoreDist : public KVStoreLocal {
   // push row sparse gradient
   void PushRowSparse(int key, const NDArray &send_buf, int priority) {
     using namespace rowsparse;
-    auto push_to_servers = [this, key, &send_buf]
+    auto push_to_servers = [this, key, send_buf]
                            (RunContext rctx, Engine::CallbackOnComplete cb) {
 #if MKL_EXPERIMENTAL == 1
       mkl_set_tblob_eager_mode(send_buf.data());
 #endif
       real_t* data = send_buf.data().dptr<real_t>();
-      bool init = send_buf.storage_initialized();
-      const int64_t num_rows = init ? send_buf.aux_shape(kIdx)[0] : 0;
-      const auto offsets = init ? send_buf.aux_data(kIdx).dptr<int64_t>() : nullptr;
+      const int64_t num_rows = send_buf.aux_shape(kIdx)[0];
+      const auto offsets = send_buf.aux_data(kIdx).dptr<int64_t>();
       const auto unit_len = send_buf.shape().ProdShape(1, send_buf.shape().ndim());
       const int64_t size = num_rows * unit_len;
 
@@ -472,7 +467,7 @@ class KVStoreDist : public KVStoreLocal {
     return pskv;
   }
 
-  // TODO(haibin) this encoding method for row sparse keys doesn't allow cross-layer batching
+  // Note: this encoding method for row sparse keys doesn't allow cross-layer batching
   inline PSKV& EncodeRowSparseKey(const int key, const int64_t size, const int64_t num_rows,
                                   const int64_t *offsets, const size_t unit_len,
                                   const int64_t total_num_rows) {
@@ -495,15 +490,15 @@ class KVStoreDist : public KVStoreLocal {
         ps::Key master_key = krs[i].begin() + key;
         pskv.keys.push_back(master_key);
         pskv.lens.push_back(0);
-        if (offsets) {
+        if (offsets && size > 0) {
           // calculate partition ranges
           int64_t part_num_rows =
             llround(static_cast<double>(total_num_rows) / num_servers * (i + 1)) -
             llround(static_cast<double>(total_num_rows) / num_servers * i);
           auto end_row = start_row + part_num_rows;
+          // search for offsets in [start_row, end_row)
           auto lb = std::lower_bound(offsets, offsets + num_rows, start_row);
           auto ub = std::upper_bound(offsets, offsets + num_rows, end_row - 1);
-
           for (auto offset = lb; offset < ub; offset++) {
             ps::Key ps_key = krs[i].begin() + key + (*offset - start_row);
             CHECK_LT(ps_key, krs[i].end());
diff --git a/src/operator/tensor/init_op.h b/src/operator/tensor/init_op.h
index 97bda90..ea4243e 100644
--- a/src/operator/tensor/init_op.h
+++ b/src/operator/tensor/init_op.h
@@ -93,6 +93,7 @@ struct RangeParam : public dmlc::Parameter<RangeParam> {
     .add_enum("float16", mshadow::kFloat16)
     .add_enum("uint8", mshadow::kUint8)
     .add_enum("int32", mshadow::kInt32)
+    .add_enum("int64", mshadow::kInt64)
     .describe("Target data type.");
   }
 };
@@ -179,6 +180,13 @@ void FillCompute(const nnvm::NodeAttrs& attrs,
   });
 }
 
+struct PopulateFullIdxRspKernel {
+  template<typename IType>
+  MSHADOW_XINLINE static void Map(int i, IType* out) {
+    KERNEL_ASSIGN(out[i], kWriteTo, i);
+  }
+};
+
 // Fill in the indices and values of a RowSparse NDArray to represent a zeros NDArray,
 // instead of the usual compact representation.
 template<typename xpu>
@@ -192,21 +200,14 @@ inline void FillDnsZerosRspImpl(mshadow::Stream<xpu> *s, NDArray *dst) {
     MSHADOW_IDX_TYPE_SWITCH(dst->aux_type(kIdx), IType, {
       auto num_rows = dst->shape()[0];
       dst->CheckAndAlloc({Shape1(num_rows)});
-      auto idx = dst->aux_data(kIdx).FlatTo1D<xpu, IType>(s);
+      auto idx = dst->aux_data(kIdx);
       auto val = dst->data();
       Kernel<set_zero, xpu>::Launch(s, val.Size(), val.dptr<DType>());
-      ASSIGN_DISPATCH(idx, kWriteTo, range<IType>(0, num_rows, 1, 1));
+      Kernel<PopulateFullIdxRspKernel, xpu>::Launch(s, num_rows, idx.dptr<IType>());
     });
   });
 }
 
-struct PopulateFullIdxRspKernel {
-  template<typename IType>
-  MSHADOW_XINLINE static void Map(int i, IType* out) {
-    KERNEL_ASSIGN(out[i], kWriteTo, i);
-  }
-};
-
 // Fill full indices NDArray with zeros by updating the aux shape.
 template<typename xpu>
 void PopulateFullIdxRspImpl(mshadow::Stream<xpu> *s, NDArray *dst) {
diff --git a/tests/nightly/dist_sync_kvstore.py b/tests/nightly/dist_sync_kvstore.py
index 5f1b11f..900d6bb 100644
--- a/tests/nightly/dist_sync_kvstore.py
+++ b/tests/nightly/dist_sync_kvstore.py
@@ -39,7 +39,7 @@ init_test_keys_device_big = [str(i) for i in range(500,600)]
 
 rate = 2
 shape = (2, 3)
-big_shape = (1200, 1200)        # bigger than BIGARRAY_BOUND
+big_shape = (1200, 1200)        # bigger than MXNET_KVSTORE_BIGARRAY_BOUND
 
 kv = mx.kv.create('dist_sync')
 
@@ -104,24 +104,27 @@ def test_sync_push_pull():
     def check_row_sparse_keys_with_zeros(kv, my_rank, nworker):
         nrepeat = 3
         # prepare gradient
-        v = mx.nd.zeros(shape)
-        big_v = mx.nd.zeros(big_shape)
+        v = mx.nd.sparse.zeros('row_sparse', shape)
+        big_v = mx.nd.sparse.zeros('row_sparse', big_shape)
         # push
         for i in range(nrepeat):
-            kv.push('11', v.tostype('row_sparse'))
-            kv.push('100', big_v.tostype('row_sparse'))
-
+            kv.push('11', v)
+            kv.push('100', big_v)
             # pull a subset of rows this worker is interested in
             all_row_ids = np.arange(shape[0])
-            val = mx.nd.ones(shape).tostype('row_sparse')
-            big_val = mx.nd.ones(big_shape).tostype('row_sparse')
-            kv.row_sparse_pull('11', out=val, row_ids=mx.nd.array(all_row_ids, dtype='int64'))
-            big_num_rows = shape[0]
+            val = mx.nd.sparse.zeros('row_sparse', shape)
+            big_val = mx.nd.sparse.zeros('row_sparse', big_shape)
+            kv.row_sparse_pull('11', out=val, row_ids=mx.nd.array(all_row_ids))
             big_all_row_ids = np.arange(big_shape[0])
-            kv.row_sparse_pull('100', out=big_val, row_ids=mx.nd.array(big_all_row_ids, dtype='int64'))
+            kv.row_sparse_pull('100', out=big_val, row_ids=mx.nd.array(big_all_row_ids))
             # verify results
-            check_diff_to_scalar(val, mx.nd.ones(shape))
-            check_diff_to_scalar(big_val, mx.nd.ones(big_shape))
+            check_diff_to_scalar(val, 1)
+            check_diff_to_scalar(big_val, 1)
+            # pull empty weights
+            kv.row_sparse_pull('11', out=val, row_ids=mx.nd.array([]))
+            kv.row_sparse_pull('100', out=big_val, row_ids=mx.nd.array([]))
+            check_diff_to_scalar(val, 0)
+            check_diff_to_scalar(big_val, 0)
 
     def check_big_row_sparse_keys(kv, my_rank, nworker):
         mx.random.seed(123)
@@ -154,7 +157,7 @@ def test_sync_push_pull():
             rnd.seed(my_rank)
             num_rows = big_shape[0]
             row_ids_np = np.random.randint(num_rows, size=num_rows)
-            row_ids = mx.nd.array(row_ids_np, dtype='int64')
+            row_ids = mx.nd.array(row_ids_np)
             # perform pull
             val = mx.nd.zeros(big_shape, stype='row_sparse')
             kv.row_sparse_pull('100', out=val, row_ids=row_ids)
diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py
index 576d963..fc8c350 100644
--- a/tests/python/unittest/test_ndarray.py
+++ b/tests/python/unittest/test_ndarray.py
@@ -734,6 +734,8 @@ def test_output():
     assert_almost_equal(out.asnumpy(), zeros.asnumpy())
     mx.nd.full(shape, 2, out=out)
     assert_almost_equal(out.asnumpy(), ones.asnumpy() * 2)
+    arange_out = mx.nd.arange(0, 20, dtype='int64')
+    assert_almost_equal(arange_out.asnumpy(), np.arange(0, 20))
 
 def test_ndarray_fluent():
     has_grad = set(['flatten', 'expand_dims', 'flip', 'tile', 'transpose', 'sum', 'nansum', 'prod',
diff --git a/tests/python/unittest/test_optimizer.py b/tests/python/unittest/test_optimizer.py
index 8666b9e..1a26434 100644
--- a/tests/python/unittest/test_optimizer.py
+++ b/tests/python/unittest/test_optimizer.py
@@ -232,6 +232,10 @@ def test_sgd():
                                 if dtype != np.float16:
                                     compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape[:2],
                                                       dtype, w_stype='csr', g_stype='csr')
+    # test optimizer with a big shape
+    big_shape = (54686454, 1)
+    kwarg = {'momentum': 0.9, 'wd': 0.05}
+    compare_optimizer(opt1(**kwarg), opt2(**kwarg), big_shape, np.float32)
 
 class PySparseSGD(mx.optimizer.Optimizer):
     """python reference implemenation of sgd"""

-- 
To stop receiving notification emails like this one, please contact
"commits@mxnet.apache.org" <co...@mxnet.apache.org>.