You are viewing a plain-text version of this content. The canonical link for it was a hyperlink ("here") that is not preserved in this plain-text copy.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2017/12/12 22:19:58 UTC

[GitHub] piiswrong closed pull request #8373: distribute training in fp16

piiswrong closed pull request #8373: distribute training in fp16
URL: https://github.com/apache/incubator-mxnet/pull/8373
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below, because GitHub would not otherwise display it:

diff --git a/R-package/src/kvstore.cc b/R-package/src/kvstore.cc
index b15106b1dd..c7351709f9 100644
--- a/R-package/src/kvstore.cc
+++ b/R-package/src/kvstore.cc
@@ -161,9 +161,9 @@ void KVStore::Update(int index, const NDArray& grad, NDArray *weight) {
 }
 
 
-Rcpp::RObject KVStore::Create(const char *type) {
+Rcpp::RObject KVStore::Create(const char *type, const char *data_type) {
   KVStoreHandle handle;
-  MX_CALL(MXKVStoreCreate(type, &handle));
+  MX_CALL(MXKVStoreCreate(type, data_type, &handle));
   return Rcpp::internal::make_new_object(new KVStore(handle));
 }
 
diff --git a/R-package/src/kvstore.h b/R-package/src/kvstore.h
index f936130428..6ac71bcdce 100644
--- a/R-package/src/kvstore.h
+++ b/R-package/src/kvstore.h
@@ -65,7 +65,9 @@ class KVStore {
    * \brief create a KVStore
+   * \param type the KVStore type, forwarded to MXKVStoreCreate
+   * \param data_type element type for distributed push/pull: "float16", otherwise float32
    * \return the created KVStore
    */
-  static Rcpp::RObject Create(const char *type);
+  static Rcpp::RObject Create(const char *type, const char *data_type = "float32");
   /*! \brief initialize the R cpp Module */
   static void InitRcppModule();
   // destructor
diff --git a/cpp-package/include/mxnet-cpp/kvstore.h b/cpp-package/include/mxnet-cpp/kvstore.h
index 9c3c81f37f..eda9c62e7c 100644
--- a/cpp-package/include/mxnet-cpp/kvstore.h
+++ b/cpp-package/include/mxnet-cpp/kvstore.h
@@ -35,7 +35,12 @@ namespace cpp {
 
 class KVStore {
  public:
-  static void SetType(const std::string& type);
+  /*!
+   * \brief create the KVStore handle held by get_kvstore() via MXKVStoreCreate
+   * \param type the KVStore type, e.g. "local" or a "dist_*" type
+   * \param data_type element type for distributed push/pull: "float16", otherwise float32
+   */
+  static void SetType(const std::string& type, const std::string& data_type = "float32");
   static void RunServer();
   static void Init(int key, const NDArray& val);
   static void Init(const std::vector<int>& keys, const std::vector<NDArray>& vals);
diff --git a/cpp-package/include/mxnet-cpp/kvstore.hpp b/cpp-package/include/mxnet-cpp/kvstore.hpp
index f2b5e74990..76313f0220 100644
--- a/cpp-package/include/mxnet-cpp/kvstore.hpp
+++ b/cpp-package/include/mxnet-cpp/kvstore.hpp
@@ -73,8 +73,8 @@ inline KVStore*& KVStore::get_kvstore() {
 
 inline KVStore::KVStore() {}
 
-inline void KVStore::SetType(const std::string& type) {
-  CHECK_EQ(MXKVStoreCreate(type.c_str(), &(get_kvstore()->get_handle())), 0);
+inline void KVStore::SetType(const std::string& type, const std::string& data_type) {
+  CHECK_EQ(MXKVStoreCreate(type.c_str(), data_type.c_str(), &(get_kvstore()->get_handle())), 0);
 }
 
 inline void KVStore::RunServer() {
diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index 55b840dd2c..d4871034e2 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -1529,6 +1529,8 @@ MXNET_DLL int MXInitPSEnv(mx_uint num_vars,
+ * \param data_type element type used by distributed kvstores for push/pull:
+ *        "float16" selects half precision, any other value float32
  * \return 0 when success, -1 when failure happens
  */
 MXNET_DLL int MXKVStoreCreate(const char *type,
+                              const char *data_type,
                               KVStoreHandle *out);
 /*!
  * \brief Delete a KVStore handle.
diff --git a/include/mxnet/kvstore.h b/include/mxnet/kvstore.h
index ddaa207dab..255c94bc46 100644
--- a/include/mxnet/kvstore.h
+++ b/include/mxnet/kvstore.h
@@ -57,7 +57,9 @@ class KVStore {
    *   - 'dist_*' : multi-machines
+   * \param data_type element type used by 'dist_*' stores for push/pull:
+   *   "float16" selects half precision, any other value real_t
    * \return a new created KVStore.
    */
-  static KVStore *Create(const char *type = "local");
+  static KVStore *Create(const char *type = "local", const char *data_type = "float32");
 
   /**
    * \brief return the type
diff --git a/perl-package/AI-MXNet/lib/AI/MXNet/KVStore.pm b/perl-package/AI-MXNet/lib/AI/MXNet/KVStore.pm
index 84a890dcc9..674ce83bcb 100644
--- a/perl-package/AI-MXNet/lib/AI/MXNet/KVStore.pm
+++ b/perl-package/AI-MXNet/lib/AI/MXNet/KVStore.pm
@@ -452,9 +452,10 @@ method _send_command_to_servers(Int $head, Str $body)
         The created AI::MXNet::KVStore
 =cut
 
-method create(Str $name='local')
+method create(Str $name='local', Str $data_type='float32')
 {
-    my $handle = check_call(AI::MXNetCAPI::KVStoreCreate($name));
+    # $data_type ('float32' or 'float16') is handed to AI::MXNetCAPI::KVStoreCreate
+    my $handle = check_call(AI::MXNetCAPI::KVStoreCreate($name, $data_type));
     return __PACKAGE__->new(handle => $handle);
 }
 
diff --git a/python/mxnet/kvstore.py b/python/mxnet/kvstore.py
index adfef9a949..d88a26ff82 100644
--- a/python/mxnet/kvstore.py
+++ b/python/mxnet/kvstore.py
@@ -531,7 +531,7 @@ def _send_command_to_servers(self, head, body):
         check_call(_LIB.MXKVStoreSendCommmandToServers(
             self.handle, mx_uint(head), c_str(body)))
 
-def create(name='local'):
+def create(name='local', data_type='float32'):
     """Creates a new KVStore.
 
     For single machine training, there are two commonly used types:
@@ -570,5 +570,6 @@ def create(name='local'):
         raise TypeError('name must be a string')
     handle = KVStoreHandle()
     check_call(_LIB.MXKVStoreCreate(c_str(name),
+                                    c_str(data_type),
                                     ctypes.byref(handle)))
     return KVStore(handle)
diff --git a/python/mxnet/kvstore_server.py b/python/mxnet/kvstore_server.py
index 2504b4674a..e643db6044 100644
--- a/python/mxnet/kvstore_server.py
+++ b/python/mxnet/kvstore_server.py
@@ -72,12 +72,19 @@ def run(self):
         _ctrl_proto = ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_char_p, ctypes.c_void_p)
         check_call(_LIB.MXKVStoreRunServer(self.handle, _ctrl_proto(self._controller()), None))
 
-def _init_kvstore_server_module():
-    """Start server/scheduler."""
+def _init_kvstore_server_module(data_type='float32'):
+    """Start server/scheduler.
+
+    Parameters
+    ----------
+    data_type : str
+        Element type passed to ``create('dist', data_type)``. Workers and servers
+        exchange raw buffers of this type, so it should match the workers' setting.
+    """
     is_worker = ctypes.c_int()
     check_call(_LIB.MXKVStoreIsWorkerNode(ctypes.byref(is_worker)))
     if is_worker.value == 0:
-        kvstore = create('dist')
+        kvstore = create('dist', data_type)
         server = KVStoreServer(kvstore)
         server.run()
         sys.exit()
diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/KVStore.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/KVStore.scala
index 94dd25497a..30bcc0b0b5 100644
--- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/KVStore.scala
+++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/KVStore.scala
@@ -46,9 +46,10 @@ object KVStore {
    *     - dist works for multi-machines (multiple processes)
+   * @param data_type element type for distributed push/pull, "float32" (default) or "float16"
    * @return The created KVStore
    */
-  def create(name: String = "local"): KVStore = {
+  def create(name: String = "local", data_type: String = "float32"): KVStore = {
     val handle = new KVStoreHandleRef
-    checkCall(_LIB.mxKVStoreCreate(name, handle))
+    checkCall(_LIB.mxKVStoreCreate(name, data_type, handle))
     new KVStore(handle.value)
   }
 }
diff --git a/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc b/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc
index db0f11e27f..3eac6ffeca 100644
--- a/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc
+++ b/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc
@@ -646,13 +646,15 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStoreIsWorkerNode
 }
 
 JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStoreCreate
-  (JNIEnv *env, jobject obj, jstring name, jobject kvStoreHandle) {
+  (JNIEnv *env, jobject obj, jstring name, jstring data_type_name, jobject kvStoreHandle) {
   jclass refLongClass = env->FindClass("ml/dmlc/mxnet/Base$RefLong");
   jfieldID refLongFid = env->GetFieldID(refLongClass, "value", "J");
 
   KVStoreHandle out;
   const char *type = env->GetStringUTFChars(name, 0);
-  int ret = MXKVStoreCreate(type, &out);
+  const char *data_type = env->GetStringUTFChars(data_type_name, 0);
+  int ret = MXKVStoreCreate(type, data_type, &out);
   env->ReleaseStringUTFChars(name, type);
+  env->ReleaseStringUTFChars(data_type_name, data_type);
 
   env->SetLongField(kvStoreHandle, refLongFid, reinterpret_cast<jlong>(out));
diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc
index 1d348a5b40..7eb08d6b36 100644
--- a/src/c_api/c_api.cc
+++ b/src/c_api/c_api.cc
@@ -727,9 +727,10 @@ int MXDataIterGetPadNum(DataIterHandle handle, int *pad) {
 }
 
 int MXKVStoreCreate(const char *type,
+                    const char *data_type,
                     KVStoreHandle *out) {
   API_BEGIN();
-  *out = KVStore::Create(type);
+  *out = KVStore::Create(type, data_type);
   API_END();
 }
 
diff --git a/src/kvstore/kvstore.cc b/src/kvstore/kvstore.cc
index a288676102..bd3d93b7a4 100644
--- a/src/kvstore/kvstore.cc
+++ b/src/kvstore/kvstore.cc
@@ -31,7 +31,7 @@
 
 namespace mxnet {
 
-KVStore* KVStore::Create(const char *type_name) {
+KVStore* KVStore::Create(const char *type_name, const char *data_type) {
   std::string tname = type_name;
   std::transform(tname.begin(), tname.end(), tname.begin(), ::tolower);
   KVStore* kv = nullptr;
@@ -45,7 +45,14 @@ KVStore* KVStore::Create(const char *type_name) {
 
   if (has("dist")) {
 #if MXNET_USE_DIST_KVSTORE
-    kv = new kvstore::KVStoreDist(use_device_comm);
+    // Only "float16" selects half-precision push/pull; a null pointer (possible
+    // through the C API) or any other value keeps real_t, without calling strcmp.
+    const std::string dtype_name = data_type == nullptr ? "float32" : data_type;
+    if (dtype_name == "float16") {
+      kv = new kvstore::KVStoreDist<mshadow::half::half_t>(use_device_comm);
+    } else {
+      kv = new kvstore::KVStoreDist<real_t>(use_device_comm);
+    }
     if (!has("_async") && kv->IsWorkerNode() && kv->get_rank() == 0) {
       // configure the server to be the sync mode
       kv->SendCommandToServers(kvstore::kSyncMode, "");
diff --git a/src/kvstore/kvstore_dist.h b/src/kvstore/kvstore_dist.h
index 5e62be8c4c..c67395bf76 100644
--- a/src/kvstore/kvstore_dist.h
+++ b/src/kvstore/kvstore_dist.h
@@ -45,12 +45,14 @@ namespace kvstore {
  * it's the server node's job to control the data consistency among all
  * workers. see details on \ref ServerHandle::Start
+ * \tparam DType element type of the values pushed to and pulled from the servers
  */
+template <typename DType>
 class KVStoreDist : public KVStoreLocal {
  public:
   explicit KVStoreDist(bool use_device_comm)
       : KVStoreLocal(use_device_comm), ps_worker_(nullptr), server_(nullptr) {
     if (IsWorkerNode()) {
-      ps_worker_ = new ps::KVWorker<real_t>(0);
+      ps_worker_ = new ps::KVWorker<DType>(0);
       ps::StartAsync("mxnet\0");
       if (!ps::Postoffice::Get()->is_recovery()) {
         ps::Postoffice::Get()->Barrier(
@@ -113,7 +115,7 @@ class KVStoreDist : public KVStoreLocal {
   void RunServer(const Controller& controller) override {
     CHECK(!IsWorkerNode());
     if (IsServerNode()) {
-      server_ = new KVStoreDistServer();
+      server_ = new KVStoreDistServer<DType>();
       server_->set_controller(controller);
     }
 
@@ -175,7 +177,14 @@ class KVStoreDist : public KVStoreLocal {
       if (recv_buf.is_none()) {
         // it may happen for the first time a no-rank-0 worker pull the weight.
         recv_buf = NDArray(grouped_vals[i][0]->shape(), pinned_ctx_,
-                           true, grouped_vals[i][0]->dtype());
+                           true, mshadow::DataType<DType>::kFlag);
+      }
+      // A cast buffer in the destination dtype is only needed when it differs
+      // from the wire type DType; otherwise recv_buf is broadcast directly.
+      const bool needs_cast = grouped_vals[i][0]->dtype() != mshadow::DataType<DType>::kFlag;
+      auto& tmp = tmp_buf_[key];
+      if (needs_cast && tmp.is_none()) {
+        tmp = NDArray(grouped_vals[i][0]->shape(), pinned_ctx_, true, grouped_vals[i][0]->dtype());
       }
       auto pull_from_servers = [this, key, recv_buf](
           RunContext rctx, Engine::CallbackOnComplete cb) {
@@ -185,9 +191,9 @@ class KVStoreDist : public KVStoreLocal {
 #if MKL_EXPERIMENTAL == 1
         mkl_set_tblob_eager_mode(recv_buf.data());
 #endif
-        real_t* data = recv_buf.data().dptr<real_t>();
+        DType* data = recv_buf.data().dptr<DType>();
         // false means not to delete data when SArray is deleted
-        auto vals = new ps::SArray<real_t>(data, size, false);
+        auto vals = new ps::SArray<DType>(data, size, false);
         // issue pull
         CHECK_NOTNULL(ps_worker_)->ZPull(
           pskv.keys, vals, &pskv.lens, kDefaultPushPull, [vals, cb](){ delete vals; cb(); });
@@ -201,8 +207,14 @@ class KVStoreDist : public KVStoreLocal {
           FnProperty::kNormal,
           priority,
           PROFILER_MESSAGE("KVStoreDistDefaultPull"));
-
-      comm_->Broadcast(key, recv_buf, grouped_vals[i], priority);
+      if (grouped_vals[i][0]->dtype() != mshadow::DataType<DType>::kFlag) {
+        // recv_buf holds DType values: cast them into tmp (the destination dtype)
+        // at the same priority as the pull and broadcast, then broadcast the copy.
+        CopyFromTo(recv_buf, &tmp, priority);
+        comm_->Broadcast(key, tmp, grouped_vals[i], priority);
+      } else {
+        comm_->Broadcast(key, recv_buf, grouped_vals[i], priority);
+      }
     }
   }
 
@@ -272,16 +282,38 @@ class KVStoreDist : public KVStoreLocal {
         // This shouldn't affect training of networks though because training involves
         // a sequence of push, pull, then push. This imposes ordering that the
         // second push happens after the first pull, and the pull happens after first push.
-        send_buf = merged;  // avoid memory copy
+        if (merged.dtype() == mshadow::DataType<DType>::kFlag) {
+          send_buf = merged;  // avoid memory copy
+        } else {
+          // Only a dtype mismatch needs a dedicated DType buffer; allocate it once per key.
+          if (send_buf.is_none()) {
+            if (storage_type == kDefaultStorage) {
+              send_buf = NDArray(merged.shape(), pinned_ctx_,
+                                 true, mshadow::DataType<DType>::kFlag);
+            } else {
+              send_buf = NDArray(storage_type, merged.shape(), pinned_ctx_,
+                                 true, mshadow::DataType<DType>::kFlag);
+            }
+          }
+          CopyFromTo(merged, &send_buf);
+        }
       } else {
         if (send_buf.is_none()) {
           if (storage_type == kDefaultStorage) {
-            send_buf = NDArray(merged.shape(), pinned_ctx_, true, merged.dtype());
+            send_buf = NDArray(merged.shape(), pinned_ctx_, true, mshadow::DataType<DType>::kFlag);
           } else {
-            send_buf = NDArray(storage_type, merged.shape(), pinned_ctx_, true, merged.dtype());
+            send_buf = NDArray(storage_type, merged.shape(), pinned_ctx_,
+                               true, mshadow::DataType<DType>::kFlag);
           }
         }
-        CopyFromTo(merged, &send_buf);
+        if (merged.dtype() == mshadow::DataType<DType>::kFlag) {
+          CopyFromTo(merged, &send_buf);
+        } else {
+          // Stage into a pinned buffer in merged's own dtype, then cast that to DType.
+          NDArray tmp = NDArray(merged.shape(), pinned_ctx_, true, merged.dtype());
+          CopyFromTo(merged, &tmp);
+          CopyFromTo(tmp, &send_buf);
+        }
       }
 
       // push to servers
@@ -295,9 +324,9 @@ class KVStoreDist : public KVStoreLocal {
 #if MKL_EXPERIMENTAL == 1
           mkl_set_tblob_eager_mode(send_buf.data());
 #endif
-          real_t* data = send_buf.data().dptr<real_t>();
+          DType* data = send_buf.data().dptr<DType>();
           // do push. false means no delete
-          ps::SArray<real_t> vals(data, size, false);
+          ps::SArray<DType> vals(data, size, false);
           CHECK_NOTNULL(ps_worker_)->ZPush(
               pskv.keys, vals, pskv.lens, 0, [cb]() { cb(); });
         };
@@ -329,7 +358,7 @@ class KVStoreDist : public KVStoreLocal {
 #if MKL_EXPERIMENTAL == 1
       mkl_set_tblob_eager_mode(recv_buf.data());
 #endif
-      real_t* data = recv_buf.data().dptr<real_t>();
+      DType* data = recv_buf.data().dptr<DType>();
       const auto offsets = indices.data().dptr<int64_t>();
       const auto unit_len = recv_buf.shape().ProdShape(1, recv_buf.shape().ndim());
       const int64_t size = num_rows * unit_len;
@@ -340,7 +369,7 @@ class KVStoreDist : public KVStoreLocal {
         LOG(INFO) << "worker " << get_rank() << " pull lens: " << pskv.lens << " keys: "
                   << pskv.keys << " size: " << size;
       }
-      auto vals = new ps::SArray<real_t>(data, size, false);
+      auto vals = new ps::SArray<DType>(data, size, false);
       // copy indices to recv_buf. this needs to be done before ZPull
       // because after pull is done, the callback function returns and locks are released.
       // at this point, later functions may access the indices variable while copy happens
@@ -367,7 +396,7 @@ class KVStoreDist : public KVStoreLocal {
 #if MKL_EXPERIMENTAL == 1
       mkl_set_tblob_eager_mode(send_buf.data());
 #endif
-      real_t* data = send_buf.data().dptr<real_t>();
+      DType* data = send_buf.data().dptr<DType>();
       const int64_t num_rows = send_buf.aux_shape(kIdx)[0];
       const auto offsets = send_buf.aux_data(kIdx).dptr<int64_t>();
       const auto unit_len = send_buf.shape().ProdShape(1, send_buf.shape().ndim());
@@ -380,7 +409,7 @@ class KVStoreDist : public KVStoreLocal {
         LOG(INFO) << "worker " << get_rank() << " push lens: " << pskv.lens << " keys: "
                   << pskv.keys << " size: " << size;
       }
-      ps::SArray<real_t> vals(data, size, false);
+      ps::SArray<DType> vals(data, size, false);
       CHECK_NOTNULL(ps_worker_)->ZPush(pskv.keys, vals, pskv.lens, kRowSparsePushPull, [cb]() {
         cb();
       });
@@ -531,17 +560,19 @@ class KVStoreDist : public KVStoreLocal {
   /**
    * \brief for worker to push and pull data
    */
-  ps::KVWorker<real_t>* ps_worker_;
+  ps::KVWorker<DType>* ps_worker_;
   /**
    * \brief the server handle
    */
-  KVStoreDistServer* server_;
+  KVStoreDistServer<DType>* server_;
   /**
    * \brief threshold for partition
    */
   size_t bigarray_bound_;
   /// \brief send & recver buffer
   std::unordered_map<int, NDArray> comm_buf_;
+  /// \brief per-key pinned buffers in the destination dtype for casting pulled DType values
+  std::unordered_map<int, NDArray> tmp_buf_;
   bool log_verbose_;
 };
 
diff --git a/src/kvstore/kvstore_dist_server.h b/src/kvstore/kvstore_dist_server.h
index bedb5398a0..57532bbb67 100644
--- a/src/kvstore/kvstore_dist_server.h
+++ b/src/kvstore/kvstore_dist_server.h
@@ -106,11 +106,17 @@ class Executor {
   std::condition_variable cond_;
 };
 
+/*!
+ * \brief server side of the distributed kvstore, built on ps::KVServer<DType>
+ * \tparam DType element type of the values exchanged with workers; values handled
+ *         by DataHandleDefault are still stored and merged as real_t
+ */
+template <typename DType>
 class KVStoreDistServer {
  public:
   KVStoreDistServer() {
     using namespace std::placeholders;
-    ps_server_ = new ps::KVServer<float>(0);
+    ps_server_ = new ps::KVServer<DType>(0);
     static_cast<ps::SimpleApp*>(ps_server_)->set_request_handle(
         std::bind(&KVStoreDistServer::CommandHandle, this, _1, _2));
     ps_server_->set_request_handle(
@@ -162,8 +163,8 @@ class KVStoreDistServer {
   }
 
   void DataHandleEx(const ps::KVMeta& req_meta,
-                    const ps::KVPairs<real_t>& req_data,
-                    ps::KVServer<real_t>* server) {
+                    const ps::KVPairs<DType>& req_data,
+                    ps::KVServer<DType>* server) {
     if (req_meta.cmd == kRowSparsePushPull) {
       DataHandleRowSparse(req_meta, req_data, server);
     } else {
@@ -173,7 +174,7 @@ class KVStoreDistServer {
   }
 
   inline void ApplyUpdates(const int key, MergeBuf *merged, NDArray *stored,
-                           ps::KVServer<real_t>* server) {
+                           ps::KVServer<DType>* server) {
     if (merged->request.size() == (size_t) ps::NumWorkers()) {
       // let the main thread to execute updater_, which is necessary for python
       if (updater_) {
@@ -209,15 +210,15 @@ class KVStoreDistServer {
   }
 
   void DataHandleRowSparse(const ps::KVMeta& req_meta,
-                       const ps::KVPairs<real_t>& req_data,
-                       ps::KVServer<real_t>* server) {
+                       const ps::KVPairs<DType>& req_data,
+                       ps::KVServer<DType>* server) {
     int master_key = DecodeKey(req_data.keys[0]);
     auto num_rows = req_data.keys.size() - 1;
     auto& stored = store_[master_key];
     if (req_meta.push) {
       CHECK_GT(req_data.lens.size(), 0) << "req_data.lens cannot be empty";
       CHECK_EQ(req_data.lens[0], 0);
-      real_t* data = req_data.vals.data();
+      DType* data = req_data.vals.data();
       if (stored.is_none()) {
         if (log_verbose_) LOG(INFO) << "initial push: " << master_key;
         // initialization
@@ -235,8 +236,10 @@ class KVStoreDistServer {
             stored.CheckAndAlloc({mshadow::Shape1(recved.shape()[0])});
             mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
             op::PopulateFullIdxRspImpl(s, &rsp);
-            mshadow::Copy(rsp.data().FlatTo1D<cpu, float>(),
-                          recved.data().FlatTo1D<cpu, float>(), s);
+            // NOTE(review): FlatTo1D<cpu, DType> presumably type-checks both blobs against
+            // DType; confirm `stored` is allocated with DType, not the float32 default.
+            mshadow::Copy(rsp.data().FlatTo1D<cpu, DType>(),
+                          recved.data().FlatTo1D<cpu, DType>(), s);
           }, recved.ctx(), {recved.var()}, {stored.var()},
           FnProperty::kNormal, 0, PROFILER_MESSAGE_FUNCNAME);
         stored.WaitToRead();
@@ -325,7 +326,7 @@ class KVStoreDistServer {
     } else {
       // pull
       if (log_verbose_) LOG(INFO) << "pull: " << master_key;
-      ps::KVPairs<real_t> response;
+      ps::KVPairs<DType> response;
       if (num_rows == 0) {
         std::vector<int> lens(req_data.keys.size(), 0);
         response.keys = req_data.keys;
@@ -336,7 +337,7 @@ class KVStoreDistServer {
       CHECK(!stored.is_none()) << "init " << master_key << " first";
       auto shape = stored.shape();
       auto unit_len = shape.ProdShape(1, shape.ndim());
-      const float* data = stored.data().dptr<float>();
+      const DType* data = stored.data().dptr<DType>();
       auto len = unit_len * num_rows;
       // concat values
       response.vals.resize(len);
@@ -359,8 +360,8 @@ class KVStoreDistServer {
   }
 
   void DataHandleDefault(const ps::KVMeta& req_meta,
-                         const ps::KVPairs<real_t> &req_data,
-                         ps::KVServer<real_t>* server) {
+                         const ps::KVPairs<DType> &req_data,
+                         ps::KVServer<DType>* server) {
     CHECK_EQ(req_meta.cmd, kDefaultPushPull);
     // do some check
     CHECK_EQ(req_data.keys.size(), (size_t)1);
@@ -378,25 +379,34 @@ class KVStoreDistServer {
     if (req_meta.push) {
       size_t ds[] = {(size_t)req_data.lens[0]};
       TShape dshape(ds, ds + 1);
-      TBlob recv_blob((real_t*)req_data.vals.data(), // NOLINT(*)
+      TBlob recv_blob((DType*)req_data.vals.data(), // NOLINT(*)
                       dshape, cpu::kDevMask);
       NDArray recved = NDArray(recv_blob, 0);
+      // Pushed values arrive as DType, while `stored` and merge buffers use the default
+      // dtype (real_t). When they differ, cast once and store/merge the cast copy.
+      NDArray recved_tmp;
+      if (recved.dtype() != mshadow::DataType<real_t>::kFlag) {
+        recved_tmp = NDArray(dshape, Context::CPU(0), false, mshadow::DataType<real_t>::kFlag);
+        CopyFromTo(recved, &recved_tmp, 0);
+        recved_tmp.WaitToRead();
+      }
+      const NDArray& incoming = recved_tmp.is_none() ? recved : recved_tmp;
       if (stored.is_none()) {
         // initialization
         stored = NDArray(dshape, Context());
-        CopyFromTo(recved, &stored, 0);
+        CopyFromTo(incoming, &stored, 0);
         server->Response(req_meta);
         stored.WaitToRead();
       } else if (sync_mode_) {
         // synced push
         auto& merged = merge_buf_[key];
         if (merged.array.is_none()) {
           merged.array = NDArray(dshape, Context());
         }
         if (merged.request.size() == 0) {
-          CopyFromTo(recved, &merged.array, 0);
+          CopyFromTo(incoming, &merged.array, 0);
         } else {
-          merged.array += recved;
+          merged.array += incoming;
         }
         merged.request.push_back(req_meta);
         ApplyUpdates(key, &merged, &stored, server);
@@ -411,13 +433,25 @@ class KVStoreDistServer {
       }
     } else {
       // pull
-      ps::KVPairs<real_t> response;
+      ps::KVPairs<DType> response;
       CHECK(!stored.is_none()) << "init " << key << " first";
       auto len = stored.shape().Size();
       response.keys = req_data.keys;
       response.lens = {len};
       // TODO(mli) try to remove this CopyFrom
-      response.vals.CopyFrom(static_cast<const float*>(stored.data().dptr_), len);
+      // response.vals is filled from a raw DType pointer, so stage through a CPU
+      // temporary of dtype DType when `stored` has another dtype or is not on CPU.
+      if (stored.dtype() != mshadow::DataType<DType>::kFlag ||
+          stored.ctx().dev_mask() != cpu::kDevMask) {
+        stored.WaitToRead();
+        NDArray tmp = NDArray(stored.shape(), Context::CPU(0),
+                              false, mshadow::DataType<DType>::kFlag);
+        CopyFromTo(stored, &tmp, 0);
+        tmp.WaitToRead();
+        response.vals.CopyFrom(static_cast<const DType*>(tmp.data().dptr_), len);
+      } else {
+        response.vals.CopyFrom(static_cast<const DType*>(stored.data().dptr_), len);
+      }
       server->Response(req_meta, response);
     }
   }
@@ -438,7 +470,7 @@ class KVStoreDistServer {
   std::unordered_map<int, MergeBuf> merge_buf_;
 
   Executor exec_;
-  ps::KVServer<float>* ps_server_;
+  ps::KVServer<DType>* ps_server_;
 
   // whether to LOG verbose information
   bool log_verbose_;


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services