You are viewing a plain text version of this content. The canonical (hyperlinked) version is available in the original mailing-list archive.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/03/21 23:18:28 UTC

[GitHub] rahul003 closed pull request #10183: [MXNET-120] Float16 support for distributed training

rahul003 closed pull request #10183: [MXNET-120] Float16 support for distributed training
URL: https://github.com/apache/incubator-mxnet/pull/10183
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (GitHub hides the original diff once a forked pull request is merged):

diff --git a/src/kvstore/kvstore_dist.h b/src/kvstore/kvstore_dist.h
index 7ab5783f7fc..4c00553db74 100644
--- a/src/kvstore/kvstore_dist.h
+++ b/src/kvstore/kvstore_dist.h
@@ -47,7 +47,7 @@ class KVStoreDist : public KVStoreLocal {
       : KVStoreLocal(use_device_comm), ps_worker_(nullptr), server_(nullptr) {
     if (IsWorkerNode()) {
       int new_customer_id = GetNewCustomerId();
-      ps_worker_ = new ps::KVWorker<real_t>(0, new_customer_id);
+      ps_worker_ = new ps::KVWorker<char>(0, new_customer_id);
       ps::StartAsync(new_customer_id, "mxnet\0");
       if (!ps::Postoffice::Get()->is_recovery()) {
         ps::Postoffice::Get()->Barrier(
@@ -228,17 +228,18 @@ class KVStoreDist : public KVStoreLocal {
           RunContext rctx, Engine::CallbackOnComplete cb) {
         // convert to ps keys
         size_t size = recv_buf.shape().Size();
-
+        int dtype = recv_buf.dtype();
+        int num_bytes = mshadow::mshadow_sizeof(dtype);
         PSKV& pskv = (gradient_compression_->get_type() == CompressionType::kNone) ?
-                      EncodeDefaultKey(key, size, false) :
-                      EncodeCompressedKey(key, size, false);
-        real_t* data = recv_buf.data().dptr<real_t>();
+                      EncodeDefaultKey(key, size, false, num_bytes) :
+                      EncodeCompressedKey(key, size, false, num_bytes);
+        char* data = static_cast<char*> (recv_buf.data().dptr_);
         // false means not to delete data when SArray is deleted
-        auto vals = new ps::SArray<real_t>(data, size, false);
+        auto vals = new ps::SArray<char>(data, size * num_bytes, false);
         // issue pull
-        int cmd = (gradient_compression_->get_type() != CompressionType::kNone) ?
-                  static_cast<int>(DataHandleType::kCompressedPushPull) :
-                  static_cast<int>(DataHandleType::kDefaultPushPull);
+        RequestType mode = (gradient_compression_->get_type() != CompressionType::kNone) ?
+                  RequestType::kCompressedPushPull : RequestType::kDefaultPushPull;
+        int cmd = GetCommandType(mode, dtype);
         CHECK_NOTNULL(ps_worker_)->ZPull(
           pskv.keys, vals, &pskv.lens, cmd, [vals, cb](){ delete vals; cb(); });
       };
@@ -329,18 +330,21 @@ class KVStoreDist : public KVStoreLocal {
         }
         CopyFromTo(merged, &comm_buf);
       }
-
+      int dtype = merged.dtype();
+      int num_bytes = mshadow::mshadow_sizeof(dtype);
       // push to servers
       if (storage_type == kDefaultStorage) {
         if (gradient_compression_->get_type() == CompressionType::kNone) {
-          PSKV& pskv = EncodeDefaultKey(key, comm_buf.shape().Size(), true);
+          PSKV& pskv = EncodeDefaultKey(key, comm_buf.shape().Size(), true, num_bytes);
           PushDefault(key, comm_buf, pskv, priority);
         } else {
+          CHECK_EQ(dtype, mshadow::kFloat32) << "Gradient compression is only supported for "
+                                             << "float32 type of parameters";
           // Note: gradient compression uses `do_merge` as proxy to
           // detect whether the push is initialization of a key or not.
           // is_active is false when push is initialization of key
           bool is_active = do_merge;
-          PSKV &pskv = EncodeCompressedKey(key, comm_buf.shape().Size(), is_active);
+          PSKV &pskv = EncodeCompressedKey(key, comm_buf.shape().Size(), is_active, num_bytes);
           // Returns push_pskv if active, else pull_pskv
           // we want inactive gc to send uncompressed gradients,
           // but sharded in the same way as later pushes would when gc becomes active
@@ -364,24 +368,23 @@ class KVStoreDist : public KVStoreLocal {
     auto &small_buf = compr_buf_[key];
     auto &res_buf = residual_[key];
     size_t original_size = comm_buf.shape().Size();
+    int dtype = comm_buf.dtype();
 
     // Init the small buffer and residual_ buffer for quantize
     if (small_buf.is_none()) {
-      small_buf = NDArray(TShape{pskv.size}, comm_buf.ctx(), false, comm_buf.dtype());
-      res_buf = NDArray(TShape{(int64_t) original_size}, comm_buf.ctx(),
-                        false, comm_buf.dtype());
+      small_buf = NDArray(TShape{pskv.size}, comm_buf.ctx(), false, dtype);
+      res_buf = NDArray(TShape{(int64_t) original_size}, comm_buf.ctx(), false, dtype);
       res_buf = 0;
     }
     gradient_compression_->Quantize(comm_buf, &small_buf, &res_buf, priority);
     auto push_to_servers =
-      [this, key, pskv, small_buf](RunContext rctx, Engine::CallbackOnComplete cb) {
-        size_t size = small_buf.shape().Size();
-        real_t* data = small_buf.data().dptr<real_t>();
+      [this, key, dtype, pskv, small_buf](RunContext rctx, Engine::CallbackOnComplete cb) {
+        size_t size = small_buf.shape().Size() * mshadow::mshadow_sizeof(dtype);
+        char* data = static_cast<char *> (small_buf.data().dptr_);
         // do push. false means no delete
-        ps::SArray<real_t> vals(data, size, false);
-        CHECK_NOTNULL(ps_worker_)->ZPush(
-          pskv.keys, vals, pskv.lens,
-          static_cast<int>(DataHandleType::kCompressedPushPull), [cb]() { cb(); });
+        ps::SArray<char> vals(data, size, false);
+        int cmd = GetCommandType(RequestType::kCompressedPushPull, dtype);
+        CHECK_NOTNULL(ps_worker_)->ZPush(pskv.keys, vals, pskv.lens, cmd, [cb]() { cb(); });
       };
     // acquire locks on both comm_buf and small_buf so that
     // pull (which uses comm_buf) for the same key waits till push finishes
@@ -398,14 +401,17 @@ class KVStoreDist : public KVStoreLocal {
   void PushDefault(int key, const NDArray &send_buf, const PSKV& pskv, int priority) {
     auto push_to_servers =
         [this, key, pskv, send_buf](RunContext rctx, Engine::CallbackOnComplete cb) {
+          int dtype = send_buf.dtype();
+          int num_bytes = mshadow::mshadow_sizeof(dtype);
           // convert to ps keys
-          size_t size = send_buf.shape().Size();
-          real_t* data = send_buf.data().dptr<real_t>();
+          size_t size = send_buf.shape().Size() * num_bytes;
+          char* data = static_cast<char *>(send_buf.data().dptr_);
           // do push. false means no delete
-          ps::SArray<real_t> vals(data, size, false);
+          ps::SArray<char> vals(data, size, false);
+          int cmd = GetCommandType(RequestType::kDefaultPushPull, dtype);
           CHECK_NOTNULL(ps_worker_)->ZPush(
               pskv.keys, vals, pskv.lens,
-              static_cast<int>(DataHandleType::kDefaultPushPull), [cb]() { cb(); });
+              cmd, [cb]() { cb(); });
         };
     Engine::Get()->PushAsync(
         push_to_servers,
@@ -422,23 +428,22 @@ class KVStoreDist : public KVStoreLocal {
     using namespace rowsparse;
     auto push_to_servers = [this, key, send_buf]
                            (RunContext rctx, Engine::CallbackOnComplete cb) {
-      real_t* data = send_buf.data().dptr<real_t>();
+      char* data = static_cast<char *>(send_buf.data().dptr_);
       const int64_t num_rows = send_buf.aux_shape(kIdx)[0];
       const auto offsets = send_buf.aux_data(kIdx).dptr<int64_t>();
       const auto unit_len = send_buf.shape().ProdShape(1, send_buf.shape().ndim());
+      int num_bytes = mshadow::mshadow_sizeof(send_buf.dtype());
       const int64_t size = num_rows * unit_len;
-
        // convert to ps keys in row sparse format
       PSKV& pskv = EncodeRowSparseKey(key, size, num_rows, offsets,
-                                      unit_len, send_buf.shape()[0]);
+                                      unit_len, send_buf.shape()[0], num_bytes);
       if (this->log_verbose_) {
         LOG(INFO) << "worker " << get_rank() << " push lens: " << pskv.lens << " keys: "
                   << pskv.keys << " size: " << size;
       }
-      ps::SArray<real_t> vals(data, size, false);
-      CHECK_NOTNULL(ps_worker_)->ZPush(pskv.keys, vals, pskv.lens,
-                                       static_cast<int>(DataHandleType::kRowSparsePushPull),
-                                       [cb]() { cb(); });
+      ps::SArray<char> vals(data, size * num_bytes, false);
+      int cmd = GetCommandType(RequestType::kRowSparsePushPull, send_buf.dtype());
+      CHECK_NOTNULL(ps_worker_)->ZPush(pskv.keys, vals, pskv.lens, cmd, [cb]() { cb(); });
     };
     Engine::Get()->PushAsync(
         push_to_servers,
@@ -462,25 +467,29 @@ class KVStoreDist : public KVStoreLocal {
       const TBlob idx_data = indices.data();
       size_t num_rows = idx_data.shape_.Size();
       recv_buf.CheckAndAlloc({mshadow::Shape1(num_rows)});
-      real_t* data = recv_buf.data().dptr<real_t>();
+      int dtype = recv_buf.dtype();
+      char* data = static_cast<char *>(recv_buf.data().dptr_);
       const auto offsets = idx_data.dptr<int64_t>();
       const auto unit_len = recv_buf.shape().ProdShape(1, recv_buf.shape().ndim());
       const int64_t size = num_rows * unit_len;
+      int num_bytes = mshadow::mshadow_sizeof(dtype);
       // convert to ps keys in row sparse format
       PSKV& pskv = EncodeRowSparseKey(key, size, num_rows, offsets,
-                                      unit_len, recv_buf.shape()[0]);
+                                      unit_len, recv_buf.shape()[0],
+                                      num_bytes);
       if (this->log_verbose_) {
         LOG(INFO) << "worker " << get_rank() << " pull lens: " << pskv.lens << " keys: "
                   << pskv.keys << " size: " << size;
       }
-      auto vals = new ps::SArray<real_t>(data, size, false);
+      auto vals = new ps::SArray<char>(data, size * num_bytes, false);
+      int cmd = GetCommandType(RequestType::kRowSparsePushPull, recv_buf.dtype());
       // copy indices to recv_buf. this needs to be done before ZPull
       // because after pull is done, the callback function returns and locks are released.
       // at this point, later functions may access the indices variable while copy happens
       mshadow::Copy(recv_buf.aux_data(kIdx).FlatTo1D<cpu, int64_t>(),
                     idx_data.FlatTo1D<cpu, int64_t>());
       CHECK_NOTNULL(ps_worker_)->ZPull(pskv.keys, vals, &pskv.lens,
-                                       static_cast<int>(DataHandleType::kRowSparsePushPull),
+                                       cmd,
                                        [vals, cb]() { delete vals; cb(); });
     };
     CHECK_NOTNULL(Engine::Get())->PushAsync(
@@ -506,12 +515,13 @@ class KVStoreDist : public KVStoreLocal {
   /**
    * \brief convert to keys in ps
    */
-  inline PSKV& EncodeDefaultKey(int key, size_t size, bool is_push) {
+  inline PSKV& EncodeDefaultKey(int key, size_t size, bool is_push, int num_bytes) {
     mu_.lock();
     PSKV& pskv = ps_kv_[key];
     mu_.unlock();
     if (!pskv.keys.empty()) {
-      CHECK_EQ(static_cast<size_t>(pskv.size), size) << "The value size cannot be changed";
+      CHECK_EQ(static_cast<size_t>(pskv.size), size * num_bytes)
+        << "The value size cannot be changed " << size * num_bytes;
     } else {
       auto krs = ps::Postoffice::Get()->GetServerKeyRanges();
       int num_servers = krs.size();
@@ -524,8 +534,8 @@ class KVStoreDist : public KVStoreLocal {
         ps::Key ps_key = krs[server].begin() + key;
         CHECK_LT(ps_key, krs[server].end());
         pskv.keys.push_back(ps_key);
-        pskv.lens.push_back(size);
-        pskv.size = size;
+        pskv.lens.push_back(size * num_bytes);
+        pskv.size = size * num_bytes;
       } else {
         // parition it to all servers
         pskv.size = 0;
@@ -536,10 +546,10 @@ class KVStoreDist : public KVStoreLocal {
           ps::Key ps_key = krs[i].begin() + key;
           CHECK_LT(ps_key, krs[i].end());
           pskv.keys.push_back(ps_key);
-          pskv.lens.push_back(part_size);
-          pskv.size += part_size;
+          pskv.lens.push_back(part_size * num_bytes);
+          pskv.size += part_size * num_bytes;
         }
-        CHECK_EQ(static_cast<size_t>(pskv.size), size);
+        CHECK_EQ(static_cast<size_t>(pskv.size), size * num_bytes);
       }
     }
     return pskv;
@@ -550,21 +560,21 @@ class KVStoreDist : public KVStoreLocal {
    * Divides original array into equal parts for each server
    * Populates both push and pull pskv on first call
    */
-  inline PSKV& EncodeCompressedKey(int key, size_t original_size, bool is_push) {
+  inline PSKV& EncodeCompressedKey(int key, size_t original_size, bool is_push, int num_bytes) {
     auto krs = ps::Postoffice::Get()->GetServerKeyRanges();
     int num_servers = krs.size();
     CHECK_GT(num_servers, 0);
 
     // represents size of data to be sent
     size_t compr_size = gradient_compression_->GetCompressedSize(original_size);
-
     mu_.lock();
     PSKV& pskv = (is_push) ? compr_ps_kv_[key].push : compr_ps_kv_[key].pull;
     mu_.unlock();
 
     if (!pskv.keys.empty()) {
       size_t size = (is_push) ? compr_size : original_size;
-      CHECK_EQ(static_cast<size_t >(pskv.size), size)<< "The value size can't be changed";
+      CHECK_EQ(static_cast<size_t >(pskv.size), size * num_bytes)
+        << "The value size can't be changed";
     } else {
       // populate both pull and push pskvs
       // push pskv has sizes corresponding to compressed data
@@ -586,10 +596,10 @@ class KVStoreDist : public KVStoreLocal {
         // data
         push_pskv.keys.push_back(ps_key);
         pull_pskv.keys.push_back(ps_key);
-        push_pskv.lens.push_back(compr_size);
-        pull_pskv.lens.push_back(original_size);
-        push_pskv.size = compr_size;
-        pull_pskv.size = original_size;
+        push_pskv.lens.push_back(compr_size * num_bytes);
+        pull_pskv.lens.push_back(original_size * num_bytes);
+        push_pskv.size = compr_size * num_bytes;
+        pull_pskv.size = original_size * num_bytes;
       } else {
         // partition it to all servers
         push_pskv.size = 0;
@@ -619,15 +629,17 @@ class KVStoreDist : public KVStoreLocal {
           push_pskv.keys.push_back(ps_key);
           pull_pskv.keys.push_back(ps_key);
           // push_pskv stores lengths of compressed blocks
-          push_pskv.lens.push_back(part_compr);
+          push_pskv.lens.push_back(part_compr * num_bytes);
           // pull_pskv stores lengths of original data
-          pull_pskv.lens.push_back(part_orig);
+          pull_pskv.lens.push_back(part_orig * num_bytes);
           push_pskv.size += part_compr;
           pull_pskv.size += part_orig;
         }
-        CHECK_EQ(static_cast<size_t>(push_pskv.size), compr_size);
-        CHECK_EQ(static_cast<size_t>(pull_pskv.size), original_size);
-        CHECK_EQ(push_pskv.lens.size(), num_servers*2);
+        push_pskv.size *= num_bytes;
+        pull_pskv.size *= num_bytes;
+        CHECK_EQ(static_cast<size_t>(push_pskv.size), compr_size * num_bytes);
+        CHECK_EQ(static_cast<size_t>(pull_pskv.size), original_size * num_bytes);
+        CHECK_EQ(push_pskv.lens.size(), num_servers * 2);
         }
       }
     return pskv;
@@ -636,7 +648,7 @@ class KVStoreDist : public KVStoreLocal {
   // Note: this encoding method for row sparse keys doesn't allow cross-layer batching
   inline PSKV& EncodeRowSparseKey(const int key, const int64_t size, const int64_t num_rows,
                                   const int64_t *offsets, const size_t unit_len,
-                                  const int64_t total_num_rows) {
+                                  const int64_t total_num_rows, int num_bytes) {
     using namespace common;
     mu_.lock();
     PSKV& pskv = ps_kv_[key];
@@ -669,13 +681,13 @@ class KVStoreDist : public KVStoreLocal {
             ps::Key ps_key = krs[i].begin() + key + (*offset - start_row);
             CHECK_LT(ps_key, krs[i].end());
             pskv.keys.push_back(ps_key);
-            pskv.lens.push_back(unit_len);
-            pskv.size += unit_len;
+            pskv.lens.push_back(unit_len * num_bytes);
+            pskv.size += (unit_len * num_bytes);
           }
           start_row = end_row;
         }
       }
-      CHECK_EQ(static_cast<size_t>(pskv.size), size);
+      CHECK_EQ(static_cast<size_t>(pskv.size), size * num_bytes);
     } else {
       // send it to a single random picked server
       int server = (key * 9973) % num_servers;
@@ -686,9 +698,9 @@ class KVStoreDist : public KVStoreLocal {
         ps::Key ps_key = krs[server].begin() + key + offsets[i];
         CHECK_LT(ps_key, krs[server].end());
         pskv.keys.push_back(ps_key);
-        pskv.lens.push_back(unit_len);
+        pskv.lens.push_back(unit_len * num_bytes);
       }
-      pskv.size = size;
+      pskv.size = size * num_bytes;
     }
     return pskv;
   }
@@ -696,7 +708,7 @@ class KVStoreDist : public KVStoreLocal {
   /**
    * \brief for worker to push and pull data
    */
-  ps::KVWorker<real_t>* ps_worker_;
+  ps::KVWorker<char>* ps_worker_;
   /**
    * \brief the server handle
    */
diff --git a/src/kvstore/kvstore_dist_server.h b/src/kvstore/kvstore_dist_server.h
index f1637c4e57d..3accb0945b7 100644
--- a/src/kvstore/kvstore_dist_server.h
+++ b/src/kvstore/kvstore_dist_server.h
@@ -44,10 +44,47 @@ enum class CommandType {
   kController, kStopServer, kSyncMode, kSetGradientCompression
 };
 
-enum class DataHandleType {
-  kDefaultPushPull, kCompressedPushPull, kRowSparsePushPull
+enum class RequestType {
+  kDefaultPushPull, kRowSparsePushPull, kCompressedPushPull
 };
 
+struct DataHandleType {
+  RequestType requestType;
+  int dtype;
+};
+
+/*!
+ * Uses Cantor pairing function to generate a unique number given two numbers.
+ * This number can also be inverted to find the unique pair whose Cantor value is this number.
+ * Ref: https://en.wikipedia.org/wiki/Pairing_function#Cantor_pairing_function
+ * \param requestType RequestType
+ * \param dtype integer
+ * \return Cantor value of arguments
+ */
+static int GetCommandType(RequestType requestType, int d) {
+  int m = static_cast<int>(requestType);
+  return (((m + d) * (m + d + 1)) / 2) + d;
+}
+
+/*!
+ * Unpairs Cantor value and finds the two integers used to pair.
+ * Then returns DataHandleType object with those numbers.
+ * \param cmd DataHandleCommand generated by GetCommandType function
+ * \return DataHandleType
+ */
+static DataHandleType DepairDataHandleType(int cmd) {
+  int w = std::floor((std::sqrt(8 * cmd + 1) - 1)/2);
+  int t = ((w * w) + w) / 2;
+  int y = cmd - t;
+  int x = w - y;
+  CHECK_GE(x, 0);
+  CHECK_GE(y, 0);
+  DataHandleType type;
+  type.requestType = static_cast<RequestType>(x);
+  type.dtype = y;
+  return type;
+}
+
 /**
  * \brief executor runs a function using the thread called \ref Start
  */
@@ -114,7 +151,7 @@ class KVStoreDistServer {
  public:
   KVStoreDistServer() {
     using namespace std::placeholders;
-    ps_server_ = new ps::KVServer<float>(0);
+    ps_server_ = new ps::KVServer<char>(0);
     static_cast<ps::SimpleApp*>(ps_server_)->set_request_handle(
         std::bind(&KVStoreDistServer::CommandHandle, this, _1, _2));
     ps_server_->set_request_handle(
@@ -149,6 +186,8 @@ class KVStoreDistServer {
   struct MergeBuf {
     std::vector<ps::KVMeta> request;
     NDArray array;
+    // temp_array is used to cast received values as float32 for computation if required
+    NDArray temp_array;
   };
 
   void CommandHandle(const ps::SimpleData& recved, ps::SimpleApp* app) {
@@ -171,32 +210,41 @@ class KVStoreDistServer {
   }
 
   void DataHandleEx(const ps::KVMeta& req_meta,
-                    const ps::KVPairs<real_t>& req_data,
-                    ps::KVServer<real_t>* server) {
-    DataHandleType recved_type = static_cast<DataHandleType>(req_meta.cmd);
-    if (recved_type == DataHandleType::kRowSparsePushPull) {
-      DataHandleRowSparse(req_meta, req_data, server);
-    } else if (recved_type == DataHandleType::kCompressedPushPull) {
-      DataHandleCompressed(req_meta, req_data, server);
-    } else {
-      DataHandleDefault(req_meta, req_data, server);
+                    const ps::KVPairs<char>& req_data,
+                    ps::KVServer<char>* server) {
+    DataHandleType type = DepairDataHandleType(req_meta.cmd);
+    switch (type.requestType) {
+      case RequestType::kRowSparsePushPull:
+        DataHandleRowSparse(type, req_meta, req_data, server);
+        break;
+      case RequestType::kCompressedPushPull:
+        DataHandleCompressed(type, req_meta, req_data, server);
+        break;
+      case RequestType::kDefaultPushPull:
+        DataHandleDefault(type, req_meta, req_data, server);
+        break;
     }
-    return;
   }
 
-  inline void ApplyUpdates(const int key, MergeBuf *merged, NDArray *stored,
-                           ps::KVServer<real_t>* server) {
+  inline void ApplyUpdates(const int key, const int dtype, MergeBuf *merged, NDArray *stored,
+                           ps::KVServer<char>* server) {
     if (merged->request.size() == (size_t) ps::NumWorkers()) {
       // let the main thread to execute updater_, which is necessary for python
       if (updater_) {
         exec_.Exec([this, key, merged, stored](){
-            CHECK(updater_);
-            updater_(key, merged->array, stored);
-          });
+          CHECK(updater_);
+          updater_(key, merged->array, stored);
+        });
       } else {
         // if no updater, just copy
         CopyFromTo(merged->array, stored);
       }
+      // better to cast once and store, than once for each pull
+      // we don't need to wait on this because unlike recvd, stored wont go out of scope
+      if (dtype != mshadow::kFloat32) {
+        auto& stored_dtype = store_[key].arr_dtype;
+        CopyFromTo(*stored, &stored_dtype, 0);
+      }
       if (log_verbose_)  {
         LOG(INFO) << "sync response to " << merged->request.size() << " workers";
       }
@@ -220,46 +268,137 @@ class KVStoreDistServer {
     }
   }
 
-  void DataHandleRowSparse(const ps::KVMeta& req_meta,
-                       const ps::KVPairs<real_t>& req_data,
-                       ps::KVServer<real_t>* server) {
+  void AccumulateRowSparseGrads(const NDArray& recved_realt, MergeBuf* merged) {
+    NDArray out(kRowSparseStorage, merged->array.shape(), Context());
+    std::vector<Engine::VarHandle> const_vars;
+    const_vars.push_back(recved_realt.var());
+    const_vars.push_back(merged->array.var());
+    // accumulate row_sparse gradients
+    // TODO(haibin) override + operator for row_sparse NDArray
+    // instead of calling BinaryComputeRspRsp directly
+    using namespace mshadow;
+    Engine::Get()->PushAsync(
+    [recved_realt, merged, out](RunContext ctx, Engine::CallbackOnComplete on_complete) {
+      op::ElemwiseBinaryOp::ComputeEx<cpu, op::mshadow_op::plus>(
+      {}, {}, {recved_realt, merged->array}, {kWriteTo}, {out});
+      on_complete();
+    }, recved_realt.ctx(), const_vars, {out.var()},
+    FnProperty::kNormal, 0, PROFILER_MESSAGE_FUNCNAME);
+    CopyFromTo(out, &(merged->array), 0);
+  }
+
+  void RowSparsePullResponse(int master_key, int dtype, size_t num_rows,
+                             const ps::KVMeta& req_meta,
+                             const ps::KVPairs<char>& req_data,
+                             ps::KVServer<char>* server) {
+    if (log_verbose_) LOG(INFO) << "pull: " << master_key;
+    ps::KVPairs<char> response;
+    if (num_rows == 0) {
+      std::vector<int> lens(req_data.keys.size(), 0);
+      response.keys = req_data.keys;
+      response.lens.CopyFrom(lens.begin(), lens.end());
+      server->Response(req_meta, response);
+      return;
+    }
+    const NDArray& stored = (dtype == mshadow::kFloat32) ? store_[master_key].arr_fp32 :
+                                                           store_[master_key].arr_dtype;
+    CHECK(!stored.is_none()) << "init " << master_key << " first";
+    if (dtype != mshadow::kFloat32) {
+      stored.WaitToRead();
+    }
+    auto shape = stored.shape();
+    auto unit_len = shape.ProdShape(1, shape.ndim());
+    int num_bytes = mshadow::mshadow_sizeof(dtype);
+    const char* data = static_cast<char *> (stored.data().dptr_);
+    auto len = unit_len * num_rows * num_bytes;
+    // concat values
+    response.vals.resize(len);
+    #pragma omp parallel for
+    for (size_t i = 1; i <= num_rows; i++) {
+      int key = DecodeKey(req_data.keys[i]);
+      int64_t row_id = key - master_key;
+      const auto src = data + row_id * unit_len * num_bytes;
+      auto begin = (i - 1) * unit_len * num_bytes;
+      auto end = i * unit_len * num_bytes;
+      response.vals.segment(begin, end).CopyFrom(src, unit_len * num_bytes);
+    }
+    // setup response
+    response.keys = req_data.keys;
+    std::vector<int> lens(req_data.keys.size(), unit_len);
+    lens[0] = 0;
+    response.lens.CopyFrom(lens.begin(), lens.end());
+    server->Response(req_meta, response);
+  }
+
+  void InitRowSparseStored(DataHandleType type,
+                           int master_key,
+                           size_t num_rows,
+                           const ps::KVMeta& req_meta,
+                           const ps::KVPairs<char>& req_data,
+                           ps::KVServer<char>* server) {
+    auto& stored = store_[master_key].arr_fp32;
+    auto& stored_dtype = store_[master_key].arr_dtype;
+    int num_bytes = mshadow::mshadow_sizeof(type.dtype);
+    auto unit_len = req_data.lens[1] / num_bytes;
+    CHECK_GT(unit_len, 0);
+    size_t ds[] = {num_rows, (size_t) unit_len};
+    TShape dshape(ds, ds + 2);
+    CHECK_EQ(req_data.vals.size(), num_rows * unit_len * num_bytes);
+
+    TBlob recv_blob;
+    MSHADOW_REAL_TYPE_SWITCH(type.dtype, DType, {
+      recv_blob = TBlob(reinterpret_cast<DType*>(req_data.vals.data()), dshape, cpu::kDevMask);
+    })
+    NDArray recved = NDArray(recv_blob, 0);
+    stored = NDArray(kRowSparseStorage, dshape, Context(), false, mshadow::kFloat32);
+    if (type.dtype != mshadow::kFloat32) {
+      stored_dtype = NDArray(kRowSparseStorage, dshape, Context(), false,
+                             type.dtype);
+    }
+    Engine::Get()->PushAsync(
+    [recved, stored](RunContext ctx, Engine::CallbackOnComplete on_complete) {
+      NDArray rsp = stored;
+      stored.CheckAndAlloc({mshadow::Shape1(recved.shape()[0])});
+      mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
+      using namespace mxnet::op;
+      nnvm::dim_t nnr = rsp.shape()[0];
+      MSHADOW_IDX_TYPE_SWITCH(rsp.aux_type(rowsparse::kIdx), IType, {
+        IType* idx = rsp.aux_data(rowsparse::kIdx).dptr<IType>();
+        mxnet_op::Kernel<PopulateFullIdxRspKernel, cpu>::Launch(s, nnr, idx);
+      });
+      if (recved.data().type_flag_ != mshadow::kFloat32) {
+        MSHADOW_TYPE_SWITCH(recved.data().type_flag_, SrcDType, {
+          rsp.data().FlatTo1D<cpu, float>() =
+          mshadow::expr::tcast<float>(recved.data().FlatTo1D<cpu, SrcDType>());
+        });
+      } else {
+        mshadow::Copy(rsp.data().FlatTo1D<cpu, float>(),
+                      recved.data().FlatTo1D<cpu, float>(), s);
+      }
+      on_complete();
+    }, recved.ctx(), {recved.var()}, {stored.var()},
+    FnProperty::kNormal, 0, PROFILER_MESSAGE_FUNCNAME);
+    if (type.dtype != mshadow::kFloat32) {
+      CopyFromTo(stored, stored_dtype);
+    }
+    stored.WaitToRead();
+    server->Response(req_meta);
+  }
+
+  void DataHandleRowSparse(DataHandleType type, const ps::KVMeta& req_meta,
+                           const ps::KVPairs<char>& req_data,
+                           ps::KVServer<char>* server) {
     int master_key = DecodeKey(req_data.keys[0]);
     auto num_rows = req_data.keys.size() - 1;
-    auto& stored = store_[master_key];
+    auto& stored = store_[master_key].arr_fp32;
     if (req_meta.push) {
       CHECK_GT(req_data.lens.size(), 0) << "req_data.lens cannot be empty";
       CHECK_EQ(req_data.lens[0], 0);
-      real_t* data = req_data.vals.data();
       if (stored.is_none()) {
         if (log_verbose_) LOG(INFO) << "initial push: " << master_key;
         // initialization
         CHECK_GT(num_rows, 0) << "init with empty data is not supported";
-        auto unit_len = req_data.lens[1];
-        CHECK_GT(unit_len, 0);
-        size_t ds[] = {num_rows, (size_t) unit_len};
-        TShape dshape(ds, ds + 2);
-        CHECK_EQ(req_data.vals.size(), num_rows * unit_len);
-        TBlob recv_blob(data, dshape, cpu::kDevMask);  // NOLINT(*)
-        NDArray recved = NDArray(recv_blob, 0);
-        stored = NDArray(kRowSparseStorage, dshape, Context());
-        Engine::Get()->PushAsync(
-          [recved, stored](RunContext ctx, Engine::CallbackOnComplete on_complete) {
-            NDArray rsp = stored;
-            stored.CheckAndAlloc({mshadow::Shape1(recved.shape()[0])});
-            mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
-            using namespace mxnet::op;
-            nnvm::dim_t nnr = rsp.shape()[0];
-            MSHADOW_IDX_TYPE_SWITCH(rsp.aux_type(rowsparse::kIdx), IType, {
-              IType* idx = rsp.aux_data(rowsparse::kIdx).dptr<IType>();
-              mxnet_op::Kernel<PopulateFullIdxRspKernel, cpu>::Launch(s, nnr, idx);
-            });
-            mshadow::Copy(rsp.data().FlatTo1D<cpu, float>(),
-                          recved.data().FlatTo1D<cpu, float>(), s);
-            on_complete();
-          }, recved.ctx(), {recved.var()}, {stored.var()},
-          FnProperty::kNormal, 0, PROFILER_MESSAGE_FUNCNAME);
-        stored.WaitToRead();
-        server->Response(req_meta);
+        InitRowSparseStored(type, master_key, num_rows, req_meta, req_data, server);
         return;
       }
       // synced push
@@ -268,53 +407,53 @@ class KVStoreDistServer {
         auto& merged = merge_buf_[master_key];
         if (merged.array.is_none()) {
           merged.array = NDArray(kRowSparseStorage, stored.shape(), Context());
+          merged.temp_array = NDArray(kRowSparseStorage, stored.shape(), Context());
         }
         if (num_rows == 0) {
           // reset to zeros
-          if (merged.request.size() == 0) {
+          if (merged.request.empty()) {
             merged.array = NDArray(kRowSparseStorage, stored.shape(), Context());
           } else {
             // nothing to aggregate
           }
           merged.request.push_back(req_meta);
-          ApplyUpdates(master_key, &merged,  &stored, server);
+          ApplyUpdates(master_key, type.dtype, &merged,  &stored, server);
           return;
-        }
-        auto unit_len = req_data.lens[1];
-        CHECK_GT(unit_len, 0);
-        // indices
-        std::vector<int64_t> indices(num_rows);
-        DecodeRowIds(req_data.keys, indices.data(), master_key, num_rows);
-        // data
-        TBlob idx_blob(indices.data(), mshadow::Shape1(num_rows), cpu::kDevMask);
-        size_t ds[] = {(size_t) num_rows, (size_t) unit_len};
-        TShape dshape(ds, ds + 2);
-        TBlob recv_blob(data, dshape, cpu::kDevMask); // NOLINT(*)
-        // row_sparse NDArray
-        NDArray recved(kRowSparseStorage, stored.shape(), recv_blob, {idx_blob}, 0);
-
-        if (merged.request.size() == 0) {
-          CopyFromTo(recved, &merged.array, 0);
         } else {
-          NDArray out(kRowSparseStorage, stored.shape(), Context());
-          std::vector<Engine::VarHandle> const_vars;
-          const_vars.push_back(recved.var());
-          const_vars.push_back(merged.array.var());
-          // accumulate row_sparse gradients
-          // TODO(haibin) override + operator for row_sparse NDArray
-          // instead of calling BinaryComputeRspRsp directly
-          using namespace mshadow;
-          Engine::Get()->PushAsync(
-            [recved, merged, out](RunContext ctx, Engine::CallbackOnComplete on_complete) {
-              op::ElemwiseBinaryOp::ComputeEx<cpu, op::mshadow_op::plus>(
-                {}, {}, {recved, merged.array}, {kWriteTo}, {out});
-              on_complete();
-            }, recved.ctx(), const_vars, {out.var()},
-            FnProperty::kNormal, 0, PROFILER_MESSAGE_FUNCNAME);
-          CopyFromTo(out, &merged.array, 0);
+          int num_bytes = mshadow::mshadow_sizeof(type.dtype);
+          auto unit_len = req_data.lens[1] / num_bytes;
+          CHECK_GT(unit_len, 0);
+          // indices
+          std::vector<int64_t> indices(num_rows);
+          DecodeRowIds(req_data.keys, indices.data(), master_key, num_rows);
+
+          // data
+          TBlob idx_blob(indices.data(), mshadow::Shape1(num_rows), cpu::kDevMask);
+          size_t ds[] = {(size_t) num_rows, (size_t) unit_len};
+          TShape dshape(ds, ds + 2);
+          TBlob recv_blob;
+          MSHADOW_REAL_TYPE_SWITCH(type.dtype, DType, {
+            recv_blob = TBlob(reinterpret_cast<DType*>(req_data.vals.data()),
+                              dshape, cpu::kDevMask);
+          })
+
+          // row_sparse NDArray
+          NDArray recved(kRowSparseStorage, stored.shape(), recv_blob, {idx_blob}, 0);
+
+          if (merged.request.empty()) {
+            CopyFromTo(recved, &merged.array, 0);
+            merged.array.WaitToRead();
+          } else {
+            if (type.dtype != mshadow::kFloat32) {
+              CopyFromTo(recved, merged.temp_array);
+              AccumulateRowSparseGrads(merged.temp_array, &merged);
+            } else {
+              AccumulateRowSparseGrads(recved, &merged);
+            }
+          }
+          merged.request.push_back(req_meta);
+          ApplyUpdates(master_key, type.dtype, &merged,  &stored, server);
         }
-        merged.request.push_back(req_meta);
-        ApplyUpdates(master_key, &merged,  &stored, server);
       } else {
         // async push
         if (log_verbose_) LOG(INFO) << "async push: " << master_key;
@@ -322,6 +461,7 @@ class KVStoreDistServer {
           server->Response(req_meta);
           return;
         }
+        auto& merged = merge_buf_[master_key];
         auto unit_len = req_data.lens[1];
         CHECK_GT(unit_len, 0);
         // indices
@@ -330,68 +470,59 @@ class KVStoreDistServer {
         TBlob idx_blob(indices.data(), mshadow::Shape1(num_rows), cpu::kDevMask);
         size_t ds[] = {(size_t) num_rows, (size_t) unit_len};
         TShape dshape(ds, ds + 2);
-        TBlob recv_blob(data, dshape, cpu::kDevMask); // NOLINT(*)
+        TBlob recv_blob;
+        MSHADOW_REAL_TYPE_SWITCH(type.dtype, DType, {
+          recv_blob = TBlob(reinterpret_cast<DType*>(req_data.vals.data()), dshape, cpu::kDevMask);
+        })
         NDArray recved(kRowSparseStorage, stored.shape(), recv_blob, {idx_blob}, 0);
-        exec_.Exec([this, master_key, &recved, &stored](){
+        if (type.dtype != mshadow::kFloat32) {
+          if (merged.temp_array.is_none()) {
+            merged.temp_array = NDArray(kRowSparseStorage, stored.shape(), Context());
+          }
+          CopyFromTo(recved, merged.temp_array);
+        }
+        const NDArray& recved_float = (type.dtype == mshadow::kFloat32) ? recved
+                                                                        : merged.temp_array;
+
+        exec_.Exec([this, master_key, &recved_float, &stored](){
             CHECK(updater_);
-            updater_(master_key, recved, &stored);
+            updater_(master_key, recved_float, &stored);
           });
         server->Response(req_meta);
         stored.WaitToRead();
       }
     } else {
-      // pull
-      if (log_verbose_) LOG(INFO) << "pull: " << master_key;
-      ps::KVPairs<real_t> response;
-      if (num_rows == 0) {
-        std::vector<int> lens(req_data.keys.size(), 0);
-        response.keys = req_data.keys;
-        response.lens.CopyFrom(lens.begin(), lens.end());
-        server->Response(req_meta, response);
-        return;
-      }
-      CHECK(!stored.is_none()) << "init " << master_key << " first";
-      auto shape = stored.shape();
-      auto unit_len = shape.ProdShape(1, shape.ndim());
-      const float* data = stored.data().dptr<float>();
-      auto len = unit_len * num_rows;
-      // concat values
-      response.vals.resize(len);
-      #pragma omp parallel for
-      for (size_t i = 1; i <= num_rows; i++) {
-        int key = DecodeKey(req_data.keys[i]);
-        int64_t row_id = key - master_key;
-        const auto src = data + row_id * unit_len;
-        auto begin = (i - 1) * unit_len;
-        auto end = i * unit_len;
-        response.vals.segment(begin, end).CopyFrom(src, unit_len);
-      }
-      // setup response
-      response.keys = req_data.keys;
-      std::vector<int> lens(req_data.keys.size(), unit_len);
-      lens[0] = 0;
-      response.lens.CopyFrom(lens.begin(), lens.end());
-      server->Response(req_meta, response);
+      RowSparsePullResponse(master_key, type.dtype, num_rows, req_meta, req_data, server);
     }
   }
 
-  void DefaultStorageResponse(int key, const NDArray& stored,
+  void DefaultStorageResponse(int key,
+                              int dtype,
                               const ps::KVMeta& req_meta,
-                              const ps::KVPairs<real_t> &req_data,
-                              ps::KVServer<real_t>* server) {
-    ps::KVPairs<real_t> response;
+                              const ps::KVPairs<char> &req_data,
+                              ps::KVServer<char>* server) {
+    ps::KVPairs<char> response;
+    const NDArray& stored = (dtype == mshadow::kFloat32) ? store_[key].arr_fp32 :
+                                                           store_[key].arr_dtype;
     CHECK(!stored.is_none()) << "init " << key << " first";
-    auto len = stored.shape().Size();
+    if (dtype != mshadow::kFloat32) {
+      stored.WaitToRead();
+    }
+    int num_bytes = mshadow::mshadow_sizeof(dtype);
+    auto len = stored.shape().Size() * num_bytes;
     response.keys = req_data.keys;
     response.lens = {len};
     // TODO(mli) try to remove this CopyFrom
-    response.vals.CopyFrom(static_cast<const float*>(stored.data().dptr_), len);
+    response.vals.CopyFrom(static_cast<const char*>(stored.data().dptr_), len);
     server->Response(req_meta, response);
   }
 
-  void DataHandleCompressed(const ps::KVMeta& req_meta,
-                            const ps::KVPairs<real_t> &req_data,
-                            ps::KVServer<real_t>* server) {
+  void DataHandleCompressed(DataHandleType type,
+                            const ps::KVMeta& req_meta,
+                            const ps::KVPairs<char> &req_data,
+                            ps::KVServer<char>* server) {
+    CHECK_EQ(type.dtype, mshadow::kFloat32)
+      << "Gradient compression is currently supported for fp32 only";
     if (req_meta.push) {
       // there used several WaitToRead, this is because \a recved's memory
       // could be deallocated when this function returns. so we need to make sure
@@ -404,12 +535,11 @@ class KVStoreDistServer {
 
       int original_size = DecodeKey(req_data.keys[0]);
       int key = DecodeKey(req_data.keys[1]);
-      auto& stored = store_[key];
+      auto& stored = store_[key].arr_fp32;
 
-      size_t ds[] = {(size_t)req_data.lens[1]};
+      size_t ds[] = {(size_t)req_data.lens[1] / mshadow::mshadow_sizeof(type.dtype)};
       TShape dshape(ds, ds + 1);
-      TBlob recv_blob((real_t*) req_data.vals.data(), // NOLINT(*)
-                      dshape, cpu::kDevMask);
+      TBlob recv_blob(reinterpret_cast<real_t*>(req_data.vals.data()), dshape, cpu::kDevMask);
       NDArray recved = NDArray(recv_blob, 0);
 
       NDArray decomp_buf = decomp_buf_[key];
@@ -437,7 +567,7 @@ class KVStoreDistServer {
           merged.array += decomp_buf;
         }
         merged.request.push_back(req_meta);
-        ApplyUpdates(key, &merged, &stored, server);
+        ApplyUpdates(key, type.dtype, &merged, &stored, server);
       } else {
         // async push
         gradient_compression_->Dequantize(recved, &decomp_buf, 0);
@@ -452,63 +582,89 @@ class KVStoreDistServer {
       CHECK_EQ(req_data.keys.size(), (size_t)1);
       CHECK_EQ(req_data.lens.size(), (size_t)0);
       int key = DecodeKey(req_data.keys[0]);
-      DefaultStorageResponse(key, store_[key], req_meta, req_data, server);
+      DefaultStorageResponse(key, type.dtype, req_meta, req_data, server);
     }
   }
 
-  void DataHandleDefault(const ps::KVMeta& req_meta,
-                         const ps::KVPairs<real_t> &req_data,
-                         ps::KVServer<real_t>* server) {
-    CHECK_EQ(req_meta.cmd, static_cast<int>(DataHandleType::kDefaultPushPull));
+  void DataHandleDefault(DataHandleType type, const ps::KVMeta& req_meta,
+                         const ps::KVPairs<char> &req_data,
+                         ps::KVServer<char>* server) {
     // do some check
     CHECK_EQ(req_data.keys.size(), (size_t)1);
     if (req_meta.push) {
       CHECK_EQ(req_data.lens.size(), (size_t)1);
       CHECK_EQ(req_data.vals.size(), (size_t)req_data.lens[0]);
     }
-
     int key = DecodeKey(req_data.keys[0]);
-    auto& stored = store_[key];
-
+    auto& stored = store_[key].arr_fp32;
+    auto& stored_dtype = store_[key].arr_dtype;
     // there used several WaitToRead, this is because \a recved's memory
     // could be deallocated when this function returns. so we need to make sure
     // the operators with \a NDArray are actually finished
     if (req_meta.push) {
-      size_t ds[] = {(size_t)req_data.lens[0]};
+      size_t ds[] = {(size_t) req_data.lens[0] / mshadow::mshadow_sizeof(type.dtype)};
       TShape dshape(ds, ds + 1);
-      TBlob recv_blob((real_t*)req_data.vals.data(), // NOLINT(*)
-                      dshape, cpu::kDevMask);
+      TBlob recv_blob;
+      MSHADOW_REAL_TYPE_SWITCH(type.dtype, DType, {
+        recv_blob = TBlob(reinterpret_cast<DType*>(req_data.vals.data()), dshape, cpu::kDevMask);
+      })
       NDArray recved = NDArray(recv_blob, 0);
       if (stored.is_none()) {
         // initialization
-        stored = NDArray(dshape, Context());
+        // stored is real_t
+        stored = NDArray(dshape, Context(), false, mshadow::kFloat32);
+        if (type.dtype != mshadow::kFloat32) {
+          stored_dtype = NDArray(dshape, Context(), false, type.dtype);
+          // no need to wait on stored_dtype because stored will be in scope
+        }
         CopyFromTo(recved, &stored, 0);
+        if (type.dtype != mshadow::kFloat32) {
+          CopyFromTo(stored, &stored_dtype, 0);
+        }
         server->Response(req_meta);
         stored.WaitToRead();
       } else if (sync_mode_) {
         // synced push
         auto& merged = merge_buf_[key];
         if (merged.array.is_none()) {
-          merged.array = NDArray(dshape, Context());
+          merged.array = NDArray(dshape, Context(), false, mshadow::kFloat32);
+          merged.temp_array = NDArray(dshape, Context(), false, mshadow::kFloat32);
         }
-        if (merged.request.size() == 0) {
-          CopyFromTo(recved, &merged.array, 0);
+        if (merged.request.empty()) {
+          CopyFromTo(recved, merged.array);
         } else {
-          merged.array += recved;
+          if (type.dtype == mshadow::kFloat32) {
+            merged.array += recved;
+          } else {
+            CopyFromTo(recved, merged.temp_array);
+            merged.array += merged.temp_array;
+          }
         }
         merged.request.push_back(req_meta);
-        ApplyUpdates(key, &merged, &stored, server);
+        ApplyUpdates(key, type.dtype, &merged, &stored, server);
       } else {
         // async push
-        exec_.Exec([this, key, &recved, &stored](){
+        auto& merged = merge_buf_[key];
+        if (type.dtype != mshadow::kFloat32) {
+          if (merged.temp_array.is_none()) {
+            merged.temp_array = NDArray(dshape, Context(), false, mshadow::kFloat32);
+          }
+          CopyFromTo(recved, merged.temp_array);
+        }
+        const NDArray& recved_float = (type.dtype == mshadow::kFloat32) ? recved
+                                                                        : merged.temp_array;
+        exec_.Exec([this, key, &recved_float, &stored](){
             CHECK(updater_);
-            updater_(key, recved, &stored);
+            updater_(key, recved_float, &stored);
           });
         server->Response(req_meta);
+        if (type.dtype != mshadow::kFloat32) {
+          CopyFromTo(stored, &stored_dtype, 0);
+        }
         stored.WaitToRead();
       }
     } else {
-      DefaultStorageResponse(key, stored, req_meta, req_data, server);
+      DefaultStorageResponse(key, type.dtype, req_meta, req_data, server);
     }
   }
 
@@ -525,10 +681,20 @@ class KVStoreDistServer {
   KVStore::Controller controller_;
   KVStore::Updater updater_;
 
+  /**
+   * \brief Server always works with float32 (realt) array as stored,
+   * but when datatype for a particular key is not float32, then server
+   * stores a cast of `arr_fp32` in `arr_dtype` so that pulls can be responded to without delay
+   */
+  struct StoredArr {
+    NDArray arr_fp32;
+    NDArray arr_dtype;
+  };
+
   /**
    * \brief store_ contains the value at kvstore for each key
    */
-  std::unordered_map<int, NDArray> store_;
+  std::unordered_map<int, StoredArr> store_;
 
   /**
    * \brief merge_buf_ is a buffer used if sync_mode is true. It represents
@@ -544,7 +710,7 @@ class KVStoreDistServer {
   std::unordered_map<int, NDArray> decomp_buf_;
 
   Executor exec_;
-  ps::KVServer<float>* ps_server_;
+  ps::KVServer<char>* ps_server_;
 
   // whether to LOG verbose information
   bool log_verbose_;
diff --git a/tests/nightly/dist_sync_kvstore.py b/tests/nightly/dist_sync_kvstore.py
index 3a3c916d782..031f16df160 100644
--- a/tests/nightly/dist_sync_kvstore.py
+++ b/tests/nightly/dist_sync_kvstore.py
@@ -20,6 +20,7 @@
 # pylint: skip-file
 import sys
 sys.path.insert(0, "../../python/")
+import argparse
 import mxnet as mx
 import numpy as np
 import numpy.random as rnd
@@ -31,121 +32,153 @@ def check_diff_to_scalar(A, x, rank=None):
     assert(np.sum(np.abs((A - x).asnumpy())) == 0), (rank, A.asnumpy(), x)
 
 # setup
-keys = ['3', '5', '7']
-rsp_keys = ['9', '11', '13']
-init_test_keys = [str(i) for i in range(200,300)]
-init_test_keys_big = [str(i) for i in range(300,400)]
-init_test_keys_device = [str(i) for i in range(400,500)]
-init_test_keys_device_big = [str(i) for i in range(500,600)]
-
-rate = 2
 shape = (2, 3)
 irregular_shape = (1211,1211)
 big_shape = (1200, 1200)        # bigger than MXNET_KVSTORE_BIGARRAY_BOUND
 
+keys_shape = ['3', '5', '7']
+keys_big_shape = ['99']
+fp16_keys_shape = ['4', '6', '8']
+fp16_keys_big_shape = ['100']
+
+rsp_keys_shape = ['9', '11', '13']
+rsp_keys_big_shape = ['97']
+fp16_rsp_keys_shape = ['10', '12', '14']
+fp16_rsp_keys_big_shape = ['98']
+
+keys_shapes = [(k, shape) for k in keys_shape] + [(k, big_shape) for k in keys_big_shape]
+fp16_keys_shapes = [(k, shape) for k in fp16_keys_shape] + [(k, big_shape) for k in fp16_keys_big_shape]
+
+init_test_keys = [str(i) for i in range(200, 300)]
+init_test_keys_big = [str(i) for i in range(300, 400)]
+init_test_keys_device = [str(i) for i in range(400, 500)]
+init_test_keys_device_big = [str(i) for i in range(500, 600)]
+
+compr_keys_shapes = [('1000', shape), ('1200', irregular_shape),('1300', big_shape)]
+compr_init_keys_shapes = [('1001', shape), ('1201', irregular_shape),('1301', big_shape)]
+compr_random_keys_shapes = [('1002', shape),('1202', irregular_shape),('1302', big_shape)]
+
+rate = 2
+
 kv = mx.kv.create('dist_sync')
 
+my_rank = kv.rank
+nworker = kv.num_workers
+
 def init_kv():
-    # init kv dns keys
-    kv.init(keys, [mx.nd.ones(shape)] * len(keys))
-    kv.init('99', mx.nd.ones(big_shape))
-    # init kv row_sparse keys
-    kv.init(rsp_keys, [mx.nd.ones(shape).tostype('row_sparse')] * len(rsp_keys))
-    kv.init('100', mx.nd.ones(big_shape).tostype('row_sparse'))
-    # worker info
-    my_rank = kv.rank
-    nworker = kv.num_workers
+    # # init kv dns keys
+    kv.init(keys_shape, [mx.nd.ones(shape)] * len(keys_shape))
+    kv.init(keys_big_shape, [mx.nd.ones(big_shape)] * len(keys_big_shape))
+    # # init kv row_sparse keys
+    kv.init(rsp_keys_shape, [mx.nd.ones(shape).tostype('row_sparse')] * len(rsp_keys_shape))
+    kv.init(rsp_keys_big_shape, [mx.nd.ones(big_shape).tostype('row_sparse')] * len(rsp_keys_big_shape))
+    # init fp16 dns keys
+    kv.init(fp16_keys_shape, [mx.nd.ones(shape, dtype='float16')] * len(keys_shape))
+    kv.init(fp16_keys_big_shape, [mx.nd.ones(big_shape, dtype='float16')] * len(keys_big_shape))
+    # init fp16 row_sparse keys
+    kv.init(fp16_rsp_keys_shape, [mx.nd.ones(shape, dtype='float16').tostype('row_sparse')] * len(rsp_keys_shape))
+    kv.init(fp16_rsp_keys_big_shape, [mx.nd.ones(big_shape, dtype='float16').tostype('row_sparse')] * len(rsp_keys_big_shape))
+
     # init updater on servers
     kv.set_optimizer(mx.optimizer.create('test', rescale_grad=rate))
-    return kv, my_rank, nworker
+    return kv
 
 def init_kv_compressed(kv):
     threshold = 0.5
-    kv.set_gradient_compression({'type': '2bit', 'threshold':threshold})
+    kv.set_gradient_compression({'type': '2bit', 'threshold': threshold})
     # init kv compression keys
-    kv.init('11221', mx.nd.zeros(big_shape))
-    kv.init('112221', mx.nd.zeros(irregular_shape))
-    kv.init('1121', mx.nd.zeros(shape))
+    for k, s in compr_keys_shapes:
+        kv.init(k, mx.nd.zeros(s))
     # to test inactive mode
-    kv.init('1122', mx.nd.ones(shape))
+    for k, s in compr_init_keys_shapes:
+        kv.init(k, mx.nd.ones(s))
     return kv, threshold
 
-def test_sync_push_pull():
-    kv, my_rank, nworker = init_kv()
-    def check_default_keys(kv, my_rank, nworker):
-        nrepeat = 3
+def test_sync_push_pull(nrepeat):
+    def check_default_keys(dtype, nrepeat):
         # checks pull after push in loop, because behavior during
         # consecutive pushes doesn't offer any guarantees
-        for i in range(nrepeat):
-            kv.push('3', mx.nd.ones(shape)*(my_rank+1))
-            kv.push('99', mx.nd.ones(big_shape)*(my_rank+1))
-            num = (nworker + 1) * nworker * rate / 2 * (i + 1) + 1
-            val = mx.nd.zeros(shape)
-            kv.pull('3', out=val)
-            check_diff_to_scalar(val, num)
-            val2 = mx.nd.zeros(big_shape)
-            kv.pull('99', out=val2)
-            check_diff_to_scalar(val2, num)
-
-    def check_row_sparse_keys(kv, my_rank, nworker):
-        nrepeat = 3
+        ks = keys_shapes if dtype == 'float32' else fp16_keys_shapes
+        for k, s in ks:
+            for i in range(nrepeat):
+                kv.push(k, mx.nd.ones(s, dtype=dtype)*(my_rank+1))
+                num = (nworker + 1) * nworker * rate / 2 * (i + 1) + 1
+                val = mx.nd.zeros(s, dtype=dtype)
+                kv.pull(k, out=val)
+                check_diff_to_scalar(val, num)
+
+    def check_row_sparse_keys(dtype, nrepeat):
         # prepare gradient
-        v = mx.nd.zeros(shape)
+        v = mx.nd.zeros(shape, dtype=dtype)
         my_row = my_rank % shape[0]
         v[my_row] = my_rank + 1
         # push
+        if dtype == 'float32':
+            k = rsp_keys_shape[0]
+        else:
+            k = fp16_rsp_keys_shape[0]
+        s = shape
         for i in range(nrepeat):
-            kv.push('9', v.tostype('row_sparse'))
+            kv.push(k, v.tostype('row_sparse'))
             # select a random subset of rows this worker is interested in
-            num_rows = shape[0]
+            num_rows = s[0]
             row_ids_np = np.random.randint(num_rows, size=num_rows)
-            row_ids = mx.nd.array(row_ids_np).reshape((num_rows/2, 2))
+            row_ids = mx.nd.array(row_ids_np).reshape((num_rows/2, 2)).astype(dtype)
             # perform pull
-            val = mx.nd.zeros(shape, stype='row_sparse')
-            kv.row_sparse_pull('9', out=val, row_ids=row_ids)
+            val = mx.nd.zeros(s, stype='row_sparse', dtype=dtype)
+            kv.row_sparse_pull(k, out=val, row_ids=row_ids)
             # prepare updated values
-            updated_val = mx.nd.ones(shape)
+            updated_val = mx.nd.ones(s, dtype=dtype)
             for rank in range(nworker):
-                row = rank % shape[0]
+                row = rank % s[0]
                 updated_val[row] += (rank + 1) * rate * (i+1)
             # verify subset of updated values
-            expected = mx.nd.zeros(shape)
+            expected = mx.nd.zeros(s, dtype=dtype)
             for row in row_ids_np:
                 expected[row] = updated_val[row]
-            check_diff_to_scalar(val, expected)
+            check_diff_to_scalar(val, expected, kv.rank)
 
-    def check_row_sparse_keys_with_zeros(kv, my_rank, nworker):
-        nrepeat = 3
+    def check_row_sparse_keys_with_zeros(dtype, nrepeat):
+        if dtype == 'float32':
+            k1 = rsp_keys_shape[1]
+            k2 = rsp_keys_big_shape[0]
+        else:
+            k1 = fp16_rsp_keys_shape[1]
+            k2 = fp16_rsp_keys_big_shape[0]
         # prepare gradient
-        v = mx.nd.sparse.zeros('row_sparse', shape)
-        big_v = mx.nd.sparse.zeros('row_sparse', big_shape)
+        v = mx.nd.sparse.zeros('row_sparse', shape, dtype=dtype)
+        big_v = mx.nd.sparse.zeros('row_sparse', big_shape, dtype=dtype)
         # push
         for i in range(nrepeat):
-            kv.push('11', v)
-            kv.push('100', big_v)
+            kv.push(k1, v)
+            kv.push(k2, big_v)
             # pull a subset of rows this worker is interested in
             all_row_ids = np.arange(shape[0])
             val = mx.nd.sparse.zeros('row_sparse', shape)
             big_val = mx.nd.sparse.zeros('row_sparse', big_shape)
-            kv.row_sparse_pull('11', out=val, row_ids=mx.nd.array(all_row_ids))
+            kv.row_sparse_pull(k1, out=val, row_ids=mx.nd.array(all_row_ids))
             big_all_row_ids = np.arange(big_shape[0])
-            kv.row_sparse_pull('100', out=big_val, row_ids=mx.nd.array(big_all_row_ids))
+            kv.row_sparse_pull(k2, out=big_val, row_ids=mx.nd.array(big_all_row_ids))
             # verify results
             check_diff_to_scalar(val, 1)
             check_diff_to_scalar(big_val, 1)
             # pull empty weights
-            kv.row_sparse_pull('11', out=val, row_ids=mx.nd.array([]))
-            kv.row_sparse_pull('100', out=big_val, row_ids=mx.nd.array([]))
+            kv.row_sparse_pull(k1, out=val, row_ids=mx.nd.array([]))
+            kv.row_sparse_pull(k2, out=big_val, row_ids=mx.nd.array([]))
             check_diff_to_scalar(val, 0)
             check_diff_to_scalar(big_val, 0)
 
-    def check_big_row_sparse_keys(kv, my_rank, nworker):
+    def check_big_row_sparse_keys(dtype, nrepeat):
+        if dtype == 'float32':
+            k = rsp_keys_big_shape[0]
+        else:
+            k = fp16_rsp_keys_big_shape[0]
+
         mx.random.seed(123)
         rnd.seed(123)
         density = 0.3
-        nrepeat = 3
         # prepare gradient
-        v = mx.nd.zeros(big_shape)
+        v = mx.nd.zeros(big_shape, dtype=dtype)
         idx_sample = rnd.rand(big_shape[0])
         indices = np.argwhere(idx_sample < density).flatten()
         # each worker chooses a subset of the indices to update
@@ -163,98 +196,103 @@ def check_big_row_sparse_keys(kv, my_rank, nworker):
             v[row] = my_rank + 1
         # push
         for i in range(nrepeat):
-            kv.push('100', v.tostype('row_sparse'))
+            kv.push(k, v.tostype('row_sparse'))
 
             # select a random subset of rows this worker is interested in
             mx.random.seed(my_rank)
             rnd.seed(my_rank)
             num_rows = big_shape[0]
             row_ids_np = np.random.randint(num_rows, size=num_rows)
-            row_ids = mx.nd.array(row_ids_np).reshape((num_rows/2, 2))
+            row_ids = mx.nd.array(row_ids_np).reshape((num_rows/2, 2)).astype(dtype)
             # perform pull
-            val = mx.nd.zeros(big_shape, stype='row_sparse')
-            kv.row_sparse_pull('100', out=val, row_ids=row_ids)
+            val = mx.nd.zeros(big_shape, stype='row_sparse', dtype=dtype)
+            kv.row_sparse_pull(k, out=val, row_ids=row_ids)
             # prepare expected result
-            updated_val = mx.nd.ones(big_shape)
+            updated_val = mx.nd.ones(big_shape, dtype=dtype)
             # apply updates from each worker
             for rank in range(nworker):
                 for row in update_rows[rank]:
                     updated_val[row] += (rank + 1) * rate * (i+1)
 
-            expected = mx.nd.zeros(big_shape)
+            expected = mx.nd.zeros(big_shape, dtype=dtype)
             for row in row_ids_np:
                 expected[row] = updated_val[row]
             check_diff_to_scalar(val, expected, rank=my_rank)
 
-    def check_compr_residual(kv, threshold, nworker):
-        for k,s in [('1121', shape),('112221',irregular_shape),('11221', big_shape)]:
+    for dtype in ['float16', 'float32']:
+        check_default_keys(dtype, nrepeat)
+        check_row_sparse_keys(dtype, nrepeat)
+        check_row_sparse_keys_with_zeros(dtype, nrepeat)
+        check_big_row_sparse_keys(dtype, nrepeat)
+    print('worker ' + str(my_rank) + ' is done with non compression tests')
+
+def test_sync_2bit_compression(threshold, nrepeat):
+    def check_compr_residual(threshold):
+        for k, s in compr_keys_shapes:
             # doesn't meet threshold
-            kv.push(k, mx.nd.ones(s)*0.4)
-            val=mx.nd.zeros(s)
+            kv.push(k, mx.nd.ones(s) * 0.4)
+            val = mx.nd.zeros(s)
             kv.pull(k,val)
             check_diff_to_scalar(val, 0)
 
             # just meets threshold with residual
-            kv.push(k, mx.nd.ones(s)*(threshold - 0.4))
+            kv.push(k, mx.nd.ones(s) * (threshold - 0.4))
             val2 = mx.nd.zeros(s)
             kv.pull(k,val2)
             curval = threshold * rate * nworker
             check_diff_to_scalar(val2, curval)
 
             # doesn't meet threshold
-            kv.push(k, mx.nd.ones(s)*0.2)
-            val3= mx.nd.zeros(s)
+            kv.push(k, mx.nd.ones(s) * 0.2)
+            val3 = mx.nd.zeros(s)
             kv.pull(k, val3)
             check_diff_to_scalar(val3, curval)
 
             # exceeds again
-            kv.push(k, mx.nd.ones(s)*(threshold-0.2))
+            kv.push(k, mx.nd.ones(s) * (threshold-0.2))
             val4 = mx.nd.zeros(s)
-            kv.pull(k,val4)
-            curval += threshold*rate*nworker
+            kv.pull(k, val4)
+            curval += threshold * rate * nworker
             check_diff_to_scalar(val4, curval)
             # residual is 0 now
 
-    def check_compr_ones(kv, threshold, nworker):
-        for k,s in [('1121', shape),('112221',irregular_shape),('11221', big_shape)]:
+    def check_compr_ones(threshold):
+        for k, s in compr_keys_shapes:
             val = mx.nd.zeros(s)
             kv.pull(k, val)
             curval = val[0][0].asnumpy()[0]
-            kv.push(k,mx.nd.ones(s)*threshold)
+            kv.push(k,mx.nd.ones(s) * threshold)
             val2 = mx.nd.zeros(s)
             kv.pull(k, val2)
-            newval = curval + rate*nworker*threshold
+            newval = curval + rate * nworker * threshold
             check_diff_to_scalar(val2, newval)
             # residual = 0  again
 
-    def check_compr_pull_before_push(kv):
-        for k,s in [('1121', shape),('112221',irregular_shape),
-                    ('11221', big_shape), ('1122',shape)]:
-            if k=='1122':
-                # tests that GC is not used for init of a key
-                val = mx.nd.zeros(s)
-                kv.pull(k, val)
-                check_diff_to_scalar(val, 1)
-            else:
-                val = mx.nd.ones(s)
-                kv.pull(k, val)
-                check_diff_to_scalar(val, 0)
+    def check_compr_pull_before_push():
+        for k,s in compr_keys_shapes:
+            val = mx.nd.ones(s)
+            kv.pull(k, val)
+            check_diff_to_scalar(val, 0)
+        for k, s in compr_init_keys_shapes:
+            # tests that GC is not used for init of a key
+            val = mx.nd.zeros(s)
+            kv.pull(k, val)
+            check_diff_to_scalar(val, 1)
 
-    def check_compr_zero(kv):
-        for k,s in [('1121', shape),('112221',irregular_shape),('11221', big_shape)]:
+    def check_compr_zero():
+        for k,s in compr_keys_shapes:
             kv.push(k, mx.nd.zeros(s))
             # to check that all are set to 0s
             val = mx.nd.ones(s)
             kv.pull(k, val)
             check_diff_to_scalar(val, 0)
 
-    def check_compr_random(kv, threshold, nworker):
+    def check_compr_random(threshold, nrepeat):
         # set a seed so all workers generate same data. knowing this helps
         # calculate expected value after pull
         mx.random.seed(123)
         rnd.seed(123)
-        nrepeat = 5
-        compr_random_keys_shapes = [('2121', shape),('212221',irregular_shape),('21221', big_shape)]
+
         # use new keys so residual is 0 for calculation of expected
         for k,s in compr_random_keys_shapes:
             kv.init(k, mx.nd.zeros(s))
@@ -278,39 +316,49 @@ def check_compr_random(kv, threshold, nworker):
                 decompr *= nworker * rate
                 assert_almost_equal(diff.asnumpy(), decompr)
 
-    print ('worker '+str(my_rank)+' started with non compression tests')
-    check_default_keys(kv, my_rank, nworker)
-    check_row_sparse_keys(kv, my_rank, nworker)
-    check_row_sparse_keys_with_zeros(kv, my_rank, nworker)
-    check_big_row_sparse_keys(kv, my_rank, nworker)
-    print('worker ' + str(my_rank) + ' is done with non compression tests')
-
-    # don't run non compressed keys after this as kvstore now is set to compressed
-    print ('worker '+str(my_rank)+' started with compression tests')
-    kv, threshold = init_kv_compressed(kv)
-    check_compr_pull_before_push(kv)
-    check_compr_zero(kv)
-    check_compr_residual(kv, threshold, nworker)
-    check_compr_ones(kv, threshold, nworker)
-    check_compr_random(kv, threshold, nworker)
+    print ('worker ' + str(my_rank) + ' started with compression tests')
+    check_compr_pull_before_push()
+    check_compr_zero()
+    check_compr_residual(threshold)
+    check_compr_ones(threshold)
+    check_compr_random(threshold, nrepeat)
     print('worker ' + str(my_rank) + ' is done with compression tests')
 
-def test_sync_init():
+def test_sync_init(gpu_tests=False):
+    def get_dtype(idx, cur_keys):
+        if idx < len(cur_keys)/2:
+            dtype = 'float32'
+        else:
+            dtype = 'float16'
+        return dtype
+
     def check_init(kv, cur_keys, cur_shape, device=False):
         ctx = mx.gpu(0) if device else mx.cpu()
-        val = [mx.nd.zeros(cur_shape, ctx) for i in cur_keys]
+        val = [mx.nd.zeros(cur_shape, ctx=ctx, dtype=get_dtype(i, cur_keys)) for i in range(len(cur_keys))]
         for i in range(len(cur_keys)):
             expected = i
-            kv.init(cur_keys[i], [mx.nd.ones(cur_shape, ctx) * i])
+            kv.init(cur_keys[i], [mx.nd.ones(cur_shape, ctx=ctx, dtype=get_dtype(i, cur_keys)) * i])
             kv.pull(cur_keys[i], out=val[i])
             check_diff_to_scalar(val[i], expected)
     check_init(kv, init_test_keys, shape)
     check_init(kv, init_test_keys_big, big_shape)
-    check_init(kv, init_test_keys_device, shape, device=True)
-    check_init(kv, init_test_keys_device_big, big_shape, device=True)
-    my_rank = kv.rank
-    print('worker ' + str(my_rank) + ' is initialized')
+    if gpu_tests:
+        check_init(kv, init_test_keys_device, shape, device=True)
+        check_init(kv, init_test_keys_device_big, big_shape, device=True)
+    print('worker ' + str(kv.rank) + ' is initialized')
 
 if __name__ == "__main__":
-    test_sync_init()
-    test_sync_push_pull()
+    parser = argparse.ArgumentParser(description='test distributed kvstore in dist_sync mode')
+    parser.add_argument('--nrepeat', type=int, default=5)
+    parser.add_argument('--type', type=str, default='all')
+    parser.add_argument('--gpu', action='store_true')
+    opt = parser.parse_args()
+    if opt.type == 'all' or  opt.type == 'init':
+        test_sync_init(opt.gpu)
+    kv = init_kv()
+    if opt.type == 'all' or  opt.type == 'default':
+        test_sync_push_pull(opt.nrepeat)
+    # dont run non compressed tests after this as kvstore compression is set now
+    if opt.type == 'all' or  opt.type == 'compressed':
+        kv, threshold = init_kv_compressed(kv)
+        test_sync_2bit_compression(threshold, opt.nrepeat)


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Service