Posted to commits@singa.apache.org by wa...@apache.org on 2015/07/18 10:38:40 UTC

[2/3] incubator-singa git commit: SINGA-32 Implement Synchronous training frameworks

SINGA-32 Implement Synchronous training frameworks

For the synchronous training frameworks, one worker group and one server group are launched.
Gradients for the same Param are aggregated locally at each process's stub.
The server does not perform the update until it has received all gradients for the same Param (slice).
After the update, the server sends the new Param (slice) values back to every process that sent an update request.
The worker_shard_ and server_shard_ consist of ParamEntry objects, each of which stores the information of one unique Param (slice),
e.g., the number of shares of that Param (slice) and the local shares themselves.
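
A minimal sketch of how a stub could use a ParamEntry to tell when all local gradient shares of one Param have arrived. The ParamEntry fields (num_update, num_local, shares) and the Param accessors (mutable_grad(), mutable_cpu_data(), size(), GenUpdateMsg()) are taken from include/utils/param.h in this commit; the aggregation loop itself is illustrative only, not the actual stub code.

// Sketch only: aggregate the local gradient shares tracked by a ParamEntry and
// report whether an update request should now be generated for the owner Param.
// Field/accessor names come from include/utils/param.h; the logic is illustrative.
#include "utils/param.h"

namespace singa {

bool ReadyToUpdate(ParamEntry* entry) {
  entry->num_update++;                        // one more local worker reported its gradient
  if (entry->num_update < entry->num_local)   // wait for all local shares
    return false;
  Param* owner = entry->shares.at(0);         // shares[0] holds the aggregated gradient
  float* dst = owner->mutable_grad()->mutable_cpu_data();
  for (size_t i = 1; i < entry->shares.size(); ++i) {
    const float* src = entry->shares[i]->mutable_grad()->mutable_cpu_data();
    for (int k = 0; k < owner->size(); ++k)
      dst[k] += src[k];                       // element-wise sum of gradients
  }
  entry->num_update = 0;
  return true;  // caller would then call owner->GenUpdateMsg(copy, slice_idx) per slice
}

}  // namespace singa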

The Msg class is improved to have a cleaner and simpler API. The msg header now includes a src (int), a dst (int) and a trgt (int value and int version),
representing the source addr, destination addr and target of the msg. The address is constructed by the
entity that creates the msg. Any addr is valid as long as it uniquely identifies one entity.
The function Addr(int grp, int id_or_proc, int type) is provided to construct the addr from a
group ID, a worker/server ID (or procs ID) and an entity type (kServer, kStub, etc.). Functions are also provided to extract
the group ID and the worker/server ID from the addr (int). Similarly, the target field can be constructed using the ParamTrgt function,
which packs the Param ID and the Slice ID into a single target value (int). ParamID() and SliceID() extract this info from the target value.
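
A small usage sketch of the new addressing helpers follows. The bit layouts are the ones defined in include/communication/msg.h and include/utils/param.h in this commit (addr = (grp << 16) | (id_or_proc << 8) | type; trgt = (param_id << 16) | slice_id); the entity/msg type constants kStub, kServer and kUpdate are assumed to be the enum values referenced elsewhere in the code base, and the message construction itself is illustrative only.

// Sketch only: build an update request from a stub to a server and read the
// addressing/target fields back. Helper signatures match msg.h and param.h.
#include "communication/msg.h"
#include "utils/param.h"

namespace singa {

Msg* MakeUpdateMsg(int worker_grp, int server_grp, int server_id,
                   int param_id, int slice_id, int version) {
  int src = Addr(worker_grp, 0, kStub);          // stub of the worker group's procs (assumed type constant)
  int dst = Addr(server_grp, server_id, kServer);
  Msg* msg = new Msg(src, dst);                  // Msg(int src, int dst) added by this commit
  msg->set_type(kUpdate);
  msg->set_trgt(ParamTrgt(param_id, slice_id), version);
  return msg;
}

void InspectMsg(Msg* msg) {
  int grp = AddrGrp(msg->src());                 // worker/server group ID
  int id = AddrID(msg->src());                   // worker/server (or procs) ID
  int type = AddrType(msg->dst());               // entity type, e.g., kServer
  int param = ParamID(msg->trgt_val());          // Param ID packed in the target
  int slice = SliceID(msg->trgt_val());          // Slice ID packed in the target
  (void)grp; (void)id; (void)type; (void)param; (void)slice;
}

}  // namespace singa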


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/585e275f
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/585e275f
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/585e275f

Branch: refs/heads/master
Commit: 585e275fdf050db25eb9c583fb54ae39714d9b20
Parents: 7954a87
Author: wang wei <wa...@comp.nus.edu.sg>
Authored: Tue Jul 14 14:21:46 2015 +0800
Committer: wang wei <wa...@comp.nus.edu.sg>
Committed: Fri Jul 17 12:02:40 2015 +0800

----------------------------------------------------------------------
 include/communication/msg.h    | 238 +++++++---
 include/communication/socket.h |   1 +
 include/neuralnet/base_layer.h |  47 +-
 include/neuralnet/neuralnet.h  |   2 +-
 include/trainer/server.h       |  40 +-
 include/trainer/trainer.h      | 179 ++++----
 include/trainer/worker.h       | 223 +++++----
 include/utils/common.h         |  24 +-
 include/utils/param.h          | 323 ++++++++-----
 include/utils/updater.h        |   1 +
 src/communication/msg.cc       | 170 ++++++-
 src/communication/socket.cc    |  16 +-
 src/neuralnet/layer.cc         |   4 +-
 src/neuralnet/neuralnet.cc     |  33 +-
 src/proto/model.proto          |   6 +-
 src/test/test_paramslicer.cc   |   2 +
 src/test/test_shard.cc         |   8 +-
 src/trainer/server.cc          | 306 +++++++------
 src/trainer/trainer.cc         | 890 ++++++++++++++++--------------------
 src/trainer/worker.cc          | 418 +++++++++--------
 src/utils/cluster.cc           |   6 +-
 src/utils/common.cc            |  93 +++-
 src/utils/param.cc             | 205 +++++----
 23 files changed, 1879 insertions(+), 1356 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/585e275f/include/communication/msg.h
----------------------------------------------------------------------
diff --git a/include/communication/msg.h b/include/communication/msg.h
index 11b6012..1570936 100644
--- a/include/communication/msg.h
+++ b/include/communication/msg.h
@@ -4,7 +4,6 @@
 // TODO(wangwei): make it a compiler argument
 #define USE_ZMQ
 
-#include <string>
 #include <utility>
 
 #ifdef USE_ZMQ
@@ -12,86 +11,199 @@
 #endif
 
 namespace singa {
+/**
+ * Wrapper to generate message address
+ * @param grp worker/server group id
+ * @param id_or_proc worker/server id or procs id
+ * @param type msg type
+ */
+inline int Addr(int grp, int id_or_proc, int type) {
+  return (grp << 16) | (id_or_proc << 8) | type;
+}
+
+/**
+ * Parse group id from addr.
+ *
+ * @return group id
+ */
+inline int AddrGrp(int addr) {
+  return addr >> 16;
+}
+/**
+ * Parse worker/server id from addr.
+ *
+ * @return id
+ */
+inline int AddrID(int addr) {
+  static const int mask = (1 << 8) - 1;
+  return (addr >> 8) & mask;
+}
 
+/**
+ * Parse worker/server procs from addr.
+ *
+ * @return procs id
+ */
+inline int AddrProc(int addr) {
+  return AddrID(addr);
+}
+/**
+ * Parse msg type from addr
+ * @return msg type
+ */
+inline int AddrType(int addr) {
+  static const int mask = (1 << 8) -1;
+  return addr & mask;
+}
+
+/**
+ * Msg used to transfer Param info (gradient or value), feature blob, etc
+ * between workers, stubs and servers.
+ *
+ * Each msg has a source addr and dest addr identified by a unique integer.
+ * It is also associated with a target field (value and version) for ease of
+ * getting some meta info (e.g., parameter id) from the msg.
+ *
+ * Other data is added into the message as frames.
+ */
 class Msg {
  public:
+  ~Msg();
   Msg();
+  /**
+   * Construct the msg providing source and destination addr.
+   */
+  Msg(int src, int dst);
+  /**
+   * Copy constructor.
+   */
   Msg(const Msg& msg);
-  ~Msg();
-  int size() const;
-
   /**
-    * @param first worker/server group id
-    * @param second worker/server id within the group
-    * @param flag 0 for server, 1 for worker, 2 for stub
-    */
-  inline void set_src(int first, int second, int flag) {
-    src_ = (first << kOff1) | (second << kOff2) | flag;
-  }
-  inline void set_dst(int first, int second, int flag) {
-    dst_ = (first << kOff1) | (second << kOff2) | flag;
-  }
-  inline void set_src(int procs_id, int flag) { set_src(procs_id, 0, flag); }
-  inline void set_dst(int procs_id, int flag) { set_dst(procs_id, 0, flag); }
-  inline int src() const { return src_; }
-  inline int dst() const { return dst_; }
-  inline int src_first() const { return src_ >> kOff1; }
-  inline int dst_first() const { return dst_ >> kOff1; }
-  inline int src_second() const { return (src_ & kMask1) >> kOff2; }
-  inline int dst_second() const { return (dst_ & kMask1) >> kOff2; }
-  inline int src_flag() const { return src_&kMask2; }
-  inline int dst_flag() const { return dst_&kMask2; }
-  inline void SwapAddr() { std::swap(src_, dst_); }
-  inline void set_type(int type) { type_ = type; }
-  inline int type() const { return type_; }
-  inline void set_trgt(int first, int second, int third) {
-    trgt_first_ = first;
-    trgt_second_ = second;
-    trgt_third_ = third;
-  }
-  inline int trgt_first() const { return trgt_first_; }
-  inline int trgt_second() const { return trgt_second_; }
-  inline int trgt_third() const { return trgt_third_; }
- /**
-   * Copy src and dst address, including first, id, flag
-   */
-  inline Msg* CopyAddr() {
-    Msg* msg = new Msg();
-    msg->src_ = src_;
-    msg->dst_ = dst_;
-    return msg;
-  }
-  inline void SetAddr(Msg* msg) {
-    src_ = msg->src_;
-    dst_ = msg->dst_;
-  }
+   * Swap the src/dst addr
+   */
+  void SwapAddr();
   /**
    * Add a frame (a chunck of bytes) into the message
    */
-  void add_frame(const void* addr, int nBytes);
-  int frame_size();
-  void* frame_data();
+  void AddFrame(const void* addr, int nBytes);
+  /**
+   * @return num of bytes of the current frame.
+   */
+  int FrameSize();
+  /**
+   * @return the pointer to the current frame data.
+   */
+  void* FrameData();
+  /**
+   * @return the data of the current frame as c string
+   */
+  char* FrameStr();
+  /**
+   * Move the cursor to the first frame.
+   */
+  void FirstFrame();
+  /**
+   * Move the cursor to the last frame.
+   */
+  void LastFrame();
+  /**
+   * Move the cursor to the next frame
+   * @return true if the next frame is not NULL; otherwise false
+   */
+  bool NextFrame();
+  /**
+   *  Add a 'format' frame to the msg (like CZMQ's zsock_send).
+   *
+   *  The format is a string that defines the type of each field.
+   *  The format can contain any of these characters, each corresponding to
+   *  one or two arguments:
+   *  i = int (signed)
+   *  1 = uint8_t
+   *  2 = uint16_t
+   *  4 = uint32_t
+   *  8 = uint64_t
+   *  p = void * (sends the pointer value, only meaningful over inproc)
+   *  s = char**
+   *
+   *  Returns size of the added content.
+   */
+  int AddFormatFrame(const char *format, ...);
   /**
-    * Move the cursor to the next frame
-    * @return true if the next frame is not NULL; otherwise false
-    */
-  bool next_frame();
+   *  Parse the current frame added using AddFormatFrame(const char*, ...).
+   *
+   *  The format is a string that defines the type of each field.
+   *  The format can contain any of these characters, each corresponding to
+   *  one or two arguments:
+   *  i = int (signed)
+   *  1 = uint8_t
+   *  2 = uint16_t
+   *  4 = uint32_t
+   *  8 = uint64_t
+   *  p = void * (sends the pointer value, only meaningful over inproc)
+   *  s = char**
+   *
+   *  Returns size of the parsed content.
+   */
+  int ParseFormatFrame(const char* format, ...);
+
 #ifdef USE_ZMQ
   void ParseFromZmsg(zmsg_t* msg);
   zmsg_t* DumpToZmsg();
 #endif
- protected:
-  static const unsigned int kOff1 = 16;
-  static const unsigned int kOff2 = 4;
-  static const unsigned int kMask1 = (1 << kOff1) - 1;
-  static const unsigned int kMask2 = (1 << kOff2) - 1;
 
+  /**
+   * @return msg size in terms of bytes, ignore meta info.
+   */
+  int size() const;
+  /**
+   * Set source addr.
+   * @param addr unique identify one worker/server/stub in the current job
+   */
+  void set_src(int addr) { src_ = addr; }
+  /**
+   * @return source addr.
+   */
+  int src() const { return src_; }
+  /**
+   * Set destination addr.
+   * @param addr unique identify one worker/server/stub in the current job
+   */
+  void set_dst(int addr) { dst_ = addr; }
+  /**
+   * @return dst addr.
+   */
+  int dst() const { return dst_; }
+  /**
+   * Set msg type, e.g., kPut, kGet, kUpdate, kRequest
+   */
+  void set_type(int type) { type_ = type; }
+  /**
+   * @return msg type.
+   */
+  int type() const { return type_; }
+  /**
+   * Set msg target.
+   *
+   * One msg has a target to identify some entity in worker/server/stub.
+   * The target is associated with a version, e.g., Param version.
+   */
+  void set_trgt(int val, int version) {
+    trgt_val_ = val;
+    trgt_version_ = version;
+  }
+  int trgt_val() const {
+    return trgt_val_;
+  }
+  int trgt_version() const {
+    return trgt_version_;
+  }
+
+ protected:
   int src_ = 0;
   int dst_ = 0;
   int type_ = 0;
-  int trgt_first_ = 0;
-  int trgt_second_ = 0;
-  int trgt_third_ = 0;
+  int trgt_val_ = 0;
+  int trgt_version_ = 0;
 #ifdef USE_ZMQ
   zmsg_t* msg_ = nullptr;
   zframe_t *frame_ = nullptr;

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/585e275f/include/communication/socket.h
----------------------------------------------------------------------
diff --git a/include/communication/socket.h b/include/communication/socket.h
index 5a9598c..fe06aad 100644
--- a/include/communication/socket.h
+++ b/include/communication/socket.h
@@ -43,6 +43,7 @@ class SocketInterface {
 class Poller {
  public:
   Poller();
+  Poller(SocketInterface* socket);
   /**
     * Add a socket for polling; Multiple sockets can be polled together by
     * adding them into the same poller.

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/585e275f/include/neuralnet/base_layer.h
----------------------------------------------------------------------
diff --git a/include/neuralnet/base_layer.h b/include/neuralnet/base_layer.h
index 047e43d..6ae7c50 100644
--- a/include/neuralnet/base_layer.h
+++ b/include/neuralnet/base_layer.h
@@ -204,6 +204,9 @@ class Layer {
   virtual bool is_bridgedstlayer() const {
     return false;
   }
+  virtual bool is_bridgelayer() const {
+    return false;
+  }
 
  protected:
   LayerProto layer_proto_;
@@ -211,19 +214,30 @@ class Layer {
   vector<Layer*> srclayers_, dstlayers_;
 };
 
+class BridgeLayer : public Layer {
+ public:
+  void set_ready(bool a) {
+    ready_ = a;
+  }
+  bool ready() const {
+    return ready_;
+  }
+  bool is_bridgelayer() const override {
+    return true;
+  }
+
+ protected:
+  //!< true if received grad from BridgeDstLayer
+  bool ready_;
+};
 /**
  * For sending data to layer on other threads which may resident on other nodes
  * due to layer/data partition.
  */
-class BridgeSrcLayer: public Layer {
+class BridgeSrcLayer: public BridgeLayer {
  public:
   using Layer::ComputeFeature;
   using Layer::ComputeGradient;
-  using Layer::data;
-  using Layer::mutable_data;
-  using Layer::grad;
-  using Layer::mutable_grad;
-  using Layer::is_bridgesrclayer;
 
   void ComputeFeature(Phase phase, Metric* perf) override {}
   void ComputeGradient(Phase phase) override {
@@ -246,22 +260,12 @@ class BridgeSrcLayer: public Layer {
   bool is_bridgesrclayer() const override {
     return true;
   }
-  void set_ready(bool a) {
-    ready_ = a;
-  }
-  bool ready() const {
-    return ready_;
-  }
-
- protected:
-  //!< true if received grad from BridgeDstLayer
-  bool ready_;
 };
 /**
  * For recv data from layer on other threads which may resident on other nodes
  * due to layer/data partiton
  */
-class BridgeDstLayer: public Layer {
+class BridgeDstLayer: public BridgeLayer {
  public:
   using Layer::ComputeFeature;
   using Layer::ComputeGradient;
@@ -275,15 +279,6 @@ class BridgeDstLayer: public Layer {
   bool is_bridgedstlayer() const {
     return true;
   }
-  void set_ready(bool ready) {
-    ready_ = ready;
-  }
-  bool ready() const {
-    return ready_;
-  }
- protected:
-  //!< true if received data from BridgeSrcLayer
-  bool ready_;
 };
 
 /**

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/585e275f/include/neuralnet/neuralnet.h
----------------------------------------------------------------------
diff --git a/include/neuralnet/neuralnet.h b/include/neuralnet/neuralnet.h
index 2e19d0c..6aec88e 100644
--- a/include/neuralnet/neuralnet.h
+++ b/include/neuralnet/neuralnet.h
@@ -58,7 +58,7 @@ class NeuralNet {
   /**
    * Share memory of parameter values from other neuralnet
    */
-  void ShareParams(shared_ptr<NeuralNet> other);
+  void ShareParamsFrom(shared_ptr<NeuralNet> other);
 
   const std::vector<Layer*>& layers() {
     return layers_;

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/585e275f/include/trainer/server.h
----------------------------------------------------------------------
diff --git a/include/trainer/server.h b/include/trainer/server.h
index 96a1437..7fb60c4 100644
--- a/include/trainer/server.h
+++ b/include/trainer/server.h
@@ -7,9 +7,7 @@
 #include "proto/model.pb.h"
 #include "communication/socket.h"
 
-using std::shared_ptr;
 namespace singa {
-typedef std::unordered_map<int, Param*> ServerShard;
 /* Repsond to worker's get/put/udpate request, and periodically syncing with
   * other servers.
   *
@@ -22,17 +20,17 @@ typedef std::unordered_map<int, Param*> ServerShard;
   */
 class Server{
  public:
-
   Server(int thread_id, int group_id, int server_id);
-  virtual ~Server() {};
-  void Setup(const UpdaterProto& proto, shared_ptr<ServerShard> shard,
+  virtual ~Server();
+  void Setup(const UpdaterProto& proto,
+      std::unordered_map<int, ParamEntry*>* shard,
       const vector<int>& slice2group);
   void Run();
-  const int group_id() const {
-    return group_id_;
+  const int grp_id() const {
+    return grp_id_;
   }
-  const int server_id() const {
-    return server_id_;
+  const int id() const {
+    return id_;
   }
 
  protected:
@@ -42,14 +40,14 @@ class Server{
    *
    * @return the orignal message or response message
    */
-	virtual Msg* HandleGet(Param* param, Msg** msg);
+	virtual Msg* HandleGet(Msg** msg);
 
 	/**
 	 * Process Update request.
    *
    * @return the orignal message or response message
    */
-	virtual Msg* HandleUpdate(Param* param, Msg** msg);
+  const vector<Msg*> HandleUpdate(Msg **msg);
 
 	/**
 	 * Process PUT request.
@@ -62,15 +60,23 @@ class Server{
 	/**
    * TODO Process SYNC request.
 	 */
-	virtual Msg* HandleSyncRequest(Param* param, Msg** msg);
+	virtual Msg* HandleSyncRequest(Msg** msg);
+
+  /**
+   * Generate sync message which sends local mastered Param slice to other
+   * server groups
+   * @param param slice to be sync with others
+   * @return sync messages
+   */
+  const vector<Msg*> GenSyncMsgs(Param* param);
 
  protected:
-  int thread_id_,group_id_, server_id_;
-  shared_ptr<Dealer> dealer_;
-  shared_ptr<Updater> updater_;
-  shared_ptr<ServerShard> shard_;
+  int thread_id_,grp_id_, id_;
+  Updater* updater_;
+  std::unordered_map<int, ParamEntry*> *shard_;
   vector<int> slice2group_;
-  std::map<int, shared_ptr<Blob<float>>> last_data_;
+  std::unordered_map<int, shared_ptr<Blob<float>>> last_data_;
+  std::unordered_map<int, vector<Msg*>> buffer_requests_;
 };
 } /* Server */
 #endif //INCLUDE_TRAINER_SERVER_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/585e275f/include/trainer/trainer.h
----------------------------------------------------------------------
diff --git a/include/trainer/trainer.h b/include/trainer/trainer.h
index bc81e72..50526ae 100644
--- a/include/trainer/trainer.h
+++ b/include/trainer/trainer.h
@@ -1,9 +1,9 @@
 #ifndef INCLUDE_TRAINER_TRAINER_H_
 #define INCLUDE_TRAINER_TRAINER_H_
 #include <unordered_map>
+#include <queue>
 #include "proto/cluster.pb.h"
 #include "proto/model.pb.h"
-#include "utils/updater.h"
 #include "utils/param.h"
 #include "utils/singleton.h"
 #include "utils/factory.h"
@@ -14,66 +14,6 @@
 
 namespace singa {
 /**
- * Callback function for zookeeper
- */
-void HandleWorkerFinish(void * ctx);
-/**
- * Zookeeper handler context used by HandleWorkerFinish(void*)function.
- */
-typedef struct HandleContext_{
-  shared_ptr<Dealer> dealer;
-  int group_id, id;
-} HandleContext;
-/**
-  * ParamInfo is used to construct a parameter shard.
-  *
-  * For each worker group:
-  *   Every unique Param object is associated with a ParamCounter object whose
-  *   param field points the to Param object itself.
-  *
-  *   Param objects sharing the same values (due to data parallelism) are
-  *   associated with the same ParamCounter whose param field also shares the
-  *   same values.
-  *
-  *   Usage: we need to aggregate gradients from all workers for the shared
-  *   parameters before sending the update request. The nUpdate counter counts
-  *   the number.
-  *
-  * TODO test with different physical architectures.
-  */
-class ParamInfo{
-   public:
-  ParamInfo(Param* p,int local, int owner):
-    num_update(0), next_version(-1),num_local(local), num_total(1),
-    owner_procs(owner){
-      shares.push_back(p);
-    }
-
-  /**
-    * Associate the counter to a Param object.
-    *
-    * @param p
-    * @param local 1 if this Param object is used by workers in this procs, 0
-    *  otherwise
-    * @param owner the procs id of the worker who ownes this Param object
-    */
-  void AddParam(Param* p, bool local){
-    num_local+=local;
-    num_total+=1;
-    if(local)
-      shares.push_back(p);
-  }
-  int num_update, next_version; //!< all counters are atomic
-
-  int num_local; //!< # local workers uses the shared parameter
-  int num_total; //!< # total workers uses the shared parameter
-  int owner_procs; //!< the procs id of the worker that owns the parameter
-  vector<Param*> shares;
-};
-
-typedef std::map<int, shared_ptr<ParamInfo>> WorkerShard;
-
-/**
  * Every running process has a training object which launches one or more
  * worker (and server) threads.
  *
@@ -82,23 +22,56 @@ typedef std::map<int, shared_ptr<ParamInfo>> WorkerShard;
 
 class Trainer{
  public:
+  ~Trainer();
   /**
-   * Start the training in one process
+   * Entrance function which construct the workers and servers, and luanch
+   * one thread per worker/server. TODO rename variables about cluster config,
+   * job config, etc.
    *
-   * @param modelproto
-   * @param clusterproto
+   * @param mconf model configuration
+   * @param globalconf global singa configuration
+   * @param cconf cluster configuration
+   * @param job job ID
    */
-  void Start(const ModelProto& modelproto, const GlobalProto& globalproto, 
-             const ClusterProto& clusterproto, const int procs_id);
+  void Start(const ModelProto& mconf, const GlobalProto& globalconf,
+             const ClusterProto& cconf, const int job);
 
   // TODO add Resume() function to continue training from a previously stopped
   // point.
-
  protected:
-  vector<Server*> CreateServers(int nthread, const ModelProto& mproto,
-      const vector<int> slices, vector<HandleContext*>* ctx);
-  vector<Worker*> CreateWorkers(int nthread, const ModelProto& mproto,
-      vector<int> *slice_size);
+  /**
+   * Create server instances.
+   * @param nthread total num of threads in current procs which is used to
+   * assign each thread a local thread ID. The number of workers is extracted
+   * from Cluster
+   * @param model_conf
+   * @return server instances
+   */
+  vector<Server*> CreateServers(int nthread, const ModelProto& mproto);
+  /**
+   * Create workers instances.
+   * @param nthread total num of threads in current procs which is used to
+   * assign each thread a local thread ID. The number of workers is extracted
+   * from Cluster
+   * @param model_conf
+   * @return worker instances
+   */
+  vector<Worker*> CreateWorkers(int nthread, const ModelProto& mproto);
+
+  /**
+   * Setup workers and servers.
+   *
+   * For each worker, create and assign a neuralnet to it.
+   * For each server, create and assign the param shard to it.
+   * Create the partition map from slice ID to server
+   * @param model_conf
+   * @param workers
+   * @param servers
+   */
+  void SetupWorkerServer(
+    const ModelProto& model_conf,
+    const vector<Worker*>& workers,
+    const vector<Server*>& servers);
 
   void Run(const vector<Worker*>& workers, const vector<Server*>& servers);
   /**
@@ -111,37 +84,73 @@ class Trainer{
    * implementation class as the value, e.g., <"Updater" SGDUpdater>.
    */
   void RegisterDefaultClasses(const singa::ModelProto& proto);
-
   /**
-   * Workers from the same group resident in the same process share the same
-   * WorkerShard which contains ParamCounters for Param objects used/updated by
-   * these worekrs. Shared Param objects are associated with the same
-   * ParamCounter.
+   * Generate msg to trigger synchronization with other server groups.
+   *
+   * @param server the local server index whom the message is sent to
+   * @param servers all local servers
+   * @return sync msg
+   */
+  Msg* GenSyncReminderMsg(int server, const vector<Server*>& servers);
+  /**
+   * Display metrics to log (standard output)
+   */
+  void DisplayMetric(Msg** msg);
+  /**
+   * Create a socket to send msg to the specified process
+   * @param dst_procs the dst process (logical) ID
+   * @return the newly created socket
    */
+  Dealer* CreateInterProcsDealer(int dst_procs);
+  /**
+   * Handle messages to local servers and local stub
+   */
+  void HandleLocalMsg(std::queue<Msg*>* msg_queue, Msg** msg);
 
 	/**
 	 * Generate a request message to Get the parameter object.
 	 */
-	virtual const vector<Msg*> HandleGet(shared_ptr<ParamInfo>counter, Msg** msg);
-	virtual void HandleGetResponse(shared_ptr<ParamInfo>counter, Msg** msg);
+	const vector<Msg*> HandleGet(ParamEntry* entry, Msg** msg);
+	void HandleGetResponse(ParamEntry* entry, Msg** msg);
 
 	/**
 	 * Generate a request message to Update the parameter object.
 	 */
-	virtual const vector<Msg*> HandleUpdate(shared_ptr<ParamInfo>counter, Msg** msg);
-  virtual void HandleUpdateResponse(shared_ptr<ParamInfo>counter, Msg** msg);
+	const vector<Msg*> HandleUpdate(ParamEntry* entry, Msg** msg);
+  void HandleUpdateResponse(ParamEntry* entry, Msg** msg);
 
   /**
 	 * Generate a request message to Put the parameter object.
 	 */
-	virtual const vector<Msg*> HandlePut(shared_ptr<ParamInfo>counter, Msg** msg);
-	virtual Msg* HandleConnect(Msg** msg);
+	const vector<Msg*> HandlePut(ParamEntry* entry, Msg** msg);
+
+  /**
+   * Called by HandlePut, HandleUpdate and HandleGet functions
+   * @param type message type
+   * @param version param version
+   * @param entry
+   * @param msg
+   * @param ret generated messages
+   */
+  void GenMsgs(int type, int version, ParamEntry* entry,
+    Msg* msg, vector<Msg*> *ret);
+  /**
+   * Get a hash id for a Param object from a group.
+   *
+   * Simple multiple group_id with a large prime number 997 (assuming there are
+   * no more than 997 worker groups) and plus owner param id.
+   */
+  inline int Hash(int grp_id, int param_id) {
+    return grp_id * 997 + param_id;
+  }
 
  protected:
   int procs_id_;
-  shared_ptr<Router> router_;
-  std::unordered_map<int, shared_ptr<WorkerShard>> worker_shards_;
-  shared_ptr<ServerShard> server_shard_;
+  Router *router_;
+  std::unordered_map<int, ParamEntry*> worker_shard_;
+  //!< map from slice ID to slice, used by servers and deleted in the destructor
+  std::unordered_map<int, ParamEntry*> server_shard_;
+  //!< map from slice to the server that updates it
   vector<int> slice2server_;
 };
 } /* singa */

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/585e275f/include/trainer/worker.h
----------------------------------------------------------------------
diff --git a/include/trainer/worker.h b/include/trainer/worker.h
index 3283ee9..ad04c1b 100644
--- a/include/trainer/worker.h
+++ b/include/trainer/worker.h
@@ -1,46 +1,70 @@
-#ifndef INCLUDE_TRAINER_WORKER_H_
-#define INCLUDE_TRAINER_WORKER_H_
-#include <map>
-#include <exception>
+#ifndef SINGA_TRAINER_WORKER_H_
+#define SINGA_TRAINER_WORKER_H_
 #include "neuralnet/neuralnet.h"
 #include "proto/model.pb.h"
-#include "utils/cluster.h"
 #include "utils/updater.h"
 #include "communication/socket.h"
-#include "communication/msg.h"
 
 namespace singa {
-const int kCollectSleepTime=5;//milliseconds;
+//!< sleep 5 milliseconds if the Param is not updated to the expected version
+const int kCollectSleepTime=5;
 /**
  * The Worker class which runs the training algorithm.
  * The first worker group will initialize parameters of the Net,
  * and put them into the distributed memory/table.
+ * The virtual function TrainOneBatch and TestOneBatch implement the
+ * training and test algorithm for one mini-batch data.
+ *
+ * Child workers override the two functions to implement their training
+ * algorithms, e.g., the BPWorker/CDWorker/BPTTWorker implements the BP/CD/BPTT
+ * algorithm respectively.
  */
 class Worker {
  public:
-  Worker(int thread_id, int group_id, int worker_id);
-  virtual ~Worker(){}
-  void Setup(const ModelProto& model, shared_ptr<NeuralNet> train_net);
-  void set_test_net(shared_ptr<NeuralNet> test_net){
-    test_net_=test_net;
-  }
-  void set_validation_net(shared_ptr<NeuralNet> val_net){
-    validation_net_=val_net;
-  }
-
-  void Stop();
-  int Put(Param* param, int step);
-  int Get(Param* param, int step);
-  int Update(Param* param, int step);
-  int Collect(Param* param, int step);
-  int CollectAll(shared_ptr<NeuralNet> net, int step);
   /**
-    * check validation/test firstly, then TrainOneBatch
-    * Performance collects performance for the whole neuralnet.
-    * Hence, no need to collect performance in every thread.
-    * Only the main thread will pass none null perf.
+   * @param thread_id local thread index within the procs
+   * @param grp_id global worker group ID
+   * @param id worker ID within the group
+   */
+  Worker(int thread_id, int grp_id, int id);
+  virtual ~Worker();
+  /**
+   * Setup members
+   */
+  void Setup(const ModelProto& model, shared_ptr<NeuralNet> train_net,
+      shared_ptr<NeuralNet> valid_net, shared_ptr<NeuralNet> test_net);
+  /**
+    * Main function of Worker.
+    *
+    * Train the neuralnet step by step, test/validation is done periodically.
     */
-  void RunOneBatch(int step, Metric* perf=nullptr);
+  void Run();
+  /**
+   * Resume from snapshot
+   */
+  void Resume();
+  /**
+   * Init all local params (i.e., params from layers resident in this worker).
+   *
+   * If the param is owned by the worker, then init it and put it to servers.
+   * Otherwise call Get() to get the param. The Get may not send get request.
+   * Because the param's own is in the same procs. Once the owner initializes
+   * the param, its version is visiable to all shares.
+   * If the training starts from scrath, the params are initialzed using random
+   * distributions, e.g., Gaussian distribution. After that, the worker may
+   * train for a couple of steps to warmup the params before put
+   * them to servers (warmup of ModelProto controls this).
+   *
+   * TODO(wangwei) If the worker is resumed from checkpoint, the owner param's
+   * values are parsed from the checkpoint file instead of randomly initialized.
+   */
+  void InitLocalParams();
+  /**
+    * Test the perforance of the learned model on validation or test dataset.
+    * Test is done by the first group.
+    * @param net, neural network
+    */
+  void Test(int nsteps, Phase phase, shared_ptr<NeuralNet> net);
   /**
     * Train one mini-batch.
     * Test/Validation is done before training.
@@ -52,96 +76,104 @@ class Worker {
   virtual void TestOneBatch(int step, Phase phase, shared_ptr<NeuralNet> net,
       Metric* perf)=0;
   /**
-    * Test the perforance of the learned model on validation or test dataset.
-    * Test is done by the first group.
-    * @param net, neural network
-    */
-  void Test(int nsteps, Phase phase, shared_ptr<NeuralNet> net);
+   * Report performance to the stub.
+   *
+   * @param prefix display prefix, e.g., 'Train', 'Test'
+   * @param perf
+   */
+  void Report(const string& prefix, const Metric & perf);
 
   /**
-    * Main function of Worker.
-    * 1. Train the neuralnet step by step, test/validation is done periodically.
-    * 2. TODO Communicate with others, e.g., zookeeper, after every step.
-    */
-  virtual void Run();
+   * Put Param to server.
+   * @param param
+   * @param step used as current param version for the put request
+   */
+  int Put(Param* param, int step);
+  /**
+   * Get Param with specific version from server
+   * If the current version >= the requested version, then return.
+   * Otherwise send a get request to stub who would forwards it to servers.
+   * @param param
+   * @param step requested param version
+   */
+  int Get(Param* param, int step);
+  /**
+   * Update Param
+   * @param param
+   * @param step training step used for updating (e.g., deciding learning rate)
+   */
+  int Update(Param* param, int step);
+  /**
+   * Block until the param is updated since sending the update request
+   *
+   * @param param
+   * @param step not used
+   */
+  int Collect(Param* param, int step);
+  /**
+   * Call Collect for every param of net
+   */
+  int CollectAll(shared_ptr<NeuralNet> net, int step);
+  /**
+   * Receive blobs from other workers due to model partitions.
+   */
+  void ReceiveBlobs(
+    bool data, bool grad, BridgeLayer* layer, shared_ptr<NeuralNet> net);
+  /**
+   * Send blobs to other workers due to model partitions.
+   */
+  void SendBlobs(
+    bool data, bool grad, BridgeLayer* layer, shared_ptr<NeuralNet> net);
 
   /**
    * Check is it time to display training info, e.g., loss and precison.
    */
-  const bool DisplayNow(const int step) const {
-    return (modelproto_.display_frequency() > 0
-        && step >= modelproto_.display_after_steps()
-        && ((step - modelproto_.display_after_steps())
-          % modelproto_.display_frequency() == 0));
-  }
-
-  const bool DisplayDebugInfo(const int step) const {
-    return DisplayNow(step)&&modelproto_.debug()&&group_id_==0;
-  }
-  void DisplayPerformance(const string& prefix, const Metric & perf);
-
+  inline bool DisplayNow(int step) const;
+  /**
+   * Check is it time to display training info, e.g., loss and precison.
+   */
+  inline bool DisplayDebugInfo(int step) const;
   /**
-   * return true if the stop condition is satisfied, e.g., the maximum number
-   * of steps have been reached.
+   * Check is it time to stop
    */
-  const bool StopNow(const int step) const{
-    return (step >= modelproto_.train_steps());
-  }
+  inline bool StopNow(int step) const;
   /**
    * Check is it time to do checkpoint.
-   * @param step the ::Train() has been called this num times.
    */
-  const bool CheckpointNow(const int step) const{
-    return (group_id_==0
-        && modelproto_.checkpoint_frequency() > 0
-        && step >= modelproto_.checkpoint_after_steps()
-        && ((step - modelproto_.checkpoint_after_steps())
-          % modelproto_.checkpoint_frequency() == 0));
-  }
+  inline bool CheckpointNow(int step) const;
   /**
    * Check is it time to do test.
    * @param step the ::Train() has been called this num times.
    */
-  const bool TestNow(const int step) const{
-    return (group_id_==0
-        && modelproto_.test_frequency() > 0
-        && modelproto_.test_steps() > 0
-        && step >= modelproto_.test_after_steps()
-        && ((step - modelproto_.test_after_steps())
-          % modelproto_.test_frequency() == 0));
-  }
+  inline bool TestNow(int step) const;
   /**
    * Check is it time to do validation.
    * @param step the ::Train() has been called step times.
    */
-  const bool ValidateNow(const int step) {
-    return (group_id_==0
-        && modelproto_.validation_frequency() > 0
-        && modelproto_.validation_steps() > 0
-        && step >= modelproto_.validation_after_steps()
-        && ((step - modelproto_.validation_after_steps())
-          % modelproto_.validation_frequency() == 0));
-  }
+  inline bool ValidateNow(int step) const;
 
   /**
-   * TODO Resume from snapshot
-  void Resume();
+   * @return group ID
    */
-  void ReceiveBlobs(shared_ptr<NeuralNet> net);
-  void SendBlob();
-  void ConnectStub(shared_ptr<Dealer> dealer, EntityType type);
+  int grp_id() const { return grp_id_;}
+
+  /**
+   * @reutrn worker ID within the worker group.
+   */
+  int id() const { return id_;}
+
  protected:
-  int thread_id_, group_id_, worker_id_;
+  int thread_id_, grp_id_, id_;
   int step_;
   ModelProto modelproto_;
   shared_ptr<NeuralNet> train_net_, test_net_, validation_net_;
-  shared_ptr<Dealer> layer_dealer_, dealer_;
-  shared_ptr<Updater> updater_;
+  Dealer* layer_dealer_, *dealer_;
+  Updater* updater_;
 };
 
 class BPWorker: public Worker{
  public:
-  BPWorker(int thread_id, int group_id, int worker_id);
+  BPWorker(int thread_id, int grp_id, int id);
   ~BPWorker(){}
   void TrainOneBatch(int step, Metric* perf) override;
   void TestOneBatch(int step, Phase phase, shared_ptr<NeuralNet> net,
@@ -150,6 +182,19 @@ class BPWorker: public Worker{
   void Forward(int step, Phase phase, shared_ptr<NeuralNet> net, Metric* perf);
   void Backward(int step, shared_ptr<NeuralNet> net);
 };
+
+inline int BlobTrgt(int grp, int layer) {
+  return (grp << 16) | layer;
+}
+
+inline int BlobGrp(int blob_trgt) {
+  return blob_trgt >> 16;
+}
+
+inline int BlobLayer(int blob_trgt) {
+  static int mask = (1 << 16) -1;
+  return blob_trgt & mask;
+}
 }  // namespace singa
 
-#endif  // INCLUDE_TRAINER_WORKER_H_
+#endif  // SINGA_TRAINER_WORKER_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/585e275f/include/utils/common.h
----------------------------------------------------------------------
diff --git a/include/utils/common.h b/include/utils/common.h
index 022a1dd..ef83031 100644
--- a/include/utils/common.h
+++ b/include/utils/common.h
@@ -10,6 +10,7 @@
 #include "proto/common.pb.h"
 
 namespace singa {
+using std::vector;
 
 std::string IntVecToString(const std::vector<int>& vec);
 std::string VStringPrintf(std::string fmt, va_list l);
@@ -25,6 +26,24 @@ void WriteProtoToBinaryFile(const google::protobuf::Message& proto,
 
 const std::string CurrentDateTime();
 void  CreateFolder(const std::string name);
+/**
+ * Slice a set of large Params into small pieces such that they can be roughtly
+ * equally partitioned into a fixed number of boxes.
+ *
+ * @param num total number of boxes to store the small pieces
+ * @param sizes size of all Params
+ * @return all slices for each Param
+ */
+const vector<vector<int>> Slice(int num, const vector<int>& sizes);
+/**
+ * Partition slices into boxes.
+ *
+ * @param num number of boxes
+ * @param slices slice sizes
+ * @return box id for each slice
+ */
+const vector<int> PartitionSlices(int num, const vector<int>& slices);
+
 /*
 inline void Sleep(int millisec=1){
   std::this_thread::sleep_for(std::chrono::milliseconds(millisec));
@@ -46,6 +65,8 @@ void SetupLog(const std::string& workspace, const std::string& model);
  */
 class Metric {
  public:
+  Metric() {}
+  explicit Metric(const std::string& str);
   /**
    * Add one metric.
    *
@@ -60,7 +81,7 @@ class Metric {
    */
   void Reset();
   /**
-   * Generate a one line string for logging
+   * Generate a one-line string for logging
    */
   const std::string ToLogString() const;
   /**
@@ -71,6 +92,7 @@ class Metric {
    * Parse the metric from a string
    */
   void ParseFrom(const std::string& msg);
+
  private:
   std::unordered_map<std::string, std::pair<int, float>> entry_;
 };

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/585e275f/include/utils/param.h
----------------------------------------------------------------------
diff --git a/include/utils/param.h b/include/utils/param.h
index 781fdb6..ed30d40 100644
--- a/include/utils/param.h
+++ b/include/utils/param.h
@@ -2,110 +2,42 @@
 #define INCLUDE_UTILS_PARAM_H_
 #include <vector>
 #include <string>
-#include <map>
-#include <functional>
 #include "proto/model.pb.h"
 #include "utils/blob.h"
 #include "communication/msg.h"
-// Base paramter class.
+
+/**
+ * Base paramter class.
+ *
+ * The Param object is a set of parameters, e.g., the (sub) weight matrix or
+ * (sub) bias vector.
+ *
+ * It has at a gradient Blob and data Blob for gradients and parameter values.
+ * Since some layers (or neuralnet) share parameter values, the data Blob is a
+ * shared pointer which can be assigned to many Param objects' data field.
+ *
+ * It provides access methods like data(), grad(). It also provides functions
+ * for generating messages and parsing messages to transferring the Param
+ * objects among worker-worker, worker-server and server-server.
+ *
+ * Param objects are of different sizes, which makes it hard to acheive
+ * load-balance among servers. Hence, we slice large Param objects into small
+ * pieces. At the server side, one slice is a Param object.
+ */
 namespace singa {
 class Param {
  public:
   Param();
   virtual ~Param(){ }
   /**
-   * Generate the message for a get request, i.e., get parameters from a server
-   *
-   * This function is called at worker/stub side.
-   * @param copy decides whether to copy the parameter values from the server.
-   * @param slice_idx index of the slice from which the message is generated.
-   * @return generated message without setting src, dst, target fields.
-   */
-  virtual Msg* GenGetMsg(bool copy, int slice_idx);
-  /**
-   * Generate the message for a put request, i.e., put parameters to a server.
-   * \copydetails GenGetMsg(bool, int);
-   */
-  virtual Msg* GenPutMsg(bool copy, int slice_idx);
-  /**
-   * Generate the message for a update request, i.e., pass info to server for
-   * parameter update.
-   * \copydetails GenGetMsg(bool, int);
-   */
-  virtual Msg* GenUpdateMsg(bool copy, int slice_idx);
-  /**
-   * Generate the message for a synchronization request between server groups.
-   *
-   * This function is called at server side where the Param is actually a slice
-   * of an original Param object.
-   * */
-  virtual Msg* GenSyncMsg(int offset, int size);
-  /**
-   * Generate the message to response the update request.
-   *
-   * This function is called at the server side, where the Param is actually a slice
-   * of an original Param object.
-   * @param copy if true copy the parameter value into the message, otherwise
-   * only transfer the pointer of the parameter values.
-   * @return response message pointer
-   */
-  virtual Msg* GenUpdateResponseMsg(bool copy);
-
-  /**
-   * Server handling function for get request.
-   *
-   * @param msg  request message
-   * @return resposne message
-   */
-  virtual Msg* HandleGetMsg(Msg** msg);
-  /**
-   * Server handling function for put request.
-   *
-   * \copydetails HandleGetMsg(Msg**)
-   */
-  virtual Msg* HandlePutMsg(Msg** msg);
-  /**
-   * Server handling function for synchronization message
-   *
-   * \copydetails HandleGetMsg(Msg**)
-   */
-  virtual Msg* HandleSyncMsg(Msg** msg);
-  /**
-   * Server parses update request message.
-   *
-   * @param msg
-   * @return 1 for copy, 0 for no copy
-   */
-  virtual int ParseUpdateMsg(Msg** msg);
-  /**
-   * Worker/Stub parsing function for get response.
-   *
-   * @param msg
-   * @param slice_idx index for the slice
-   */
-  virtual int ParseGetResponseMsg(Msg** msg, int slice_idx);
-  /**
-   * Worker/Server parsing function for update response
-   *
-   * \copydetails ParseGetResponseMsg(Msg**, int);
-   */
-  virtual int ParseUpdateResponseMsg(Msg** msg, int slice_idx);
-  /**
-   * Server parsing function for synchronization response.
-   *
-   * \copydetails ParseGetResponseMsg(Msg** , int);
-   */
-  virtual int ParseSyncResponseMsg(Msg** msg, int slice_idx);
-
-  /**
    * Setup param object
    *
-   * @param proto includes learning rate/weight decay multipliers
-   * @param shape
+   * @param conf param configuration, include learning rate multiplier etc.
+   * @param shape one value per dimension
    */
-  virtual void Setup(const ParamProto& proto, const std::vector<int>& shape);
+  virtual void Setup(const ParamProto& conf, const std::vector<int>& shape);
   /*
-   * Fill the values according to initmethod, e.g., gaussian distribution
+   * Fill the values according to init method, e.g., gaussian distribution.
    *
    * @param version initial version
    */
@@ -115,56 +47,69 @@ class Param {
    *
    * @param other the Param object whose owner owns the data blob
    */
-  void ShareData(Param* other){
-    proto_.set_owner(other->owner());
-    if(data_!=nullptr)
-      CHECK(std::equal(data_->shape().begin(), data_->shape().end(),
-          other->data_->shape().begin()));
-    data_=other->data_;
-  }
+  void ShareFrom(const Param& other);
+
+  /**
+   * Scale the learning rate when updating parameters in the Param object
+   */
   float learning_rate_multiplier() {
     return proto_.learning_rate_multiplier();
   }
+  /**
+   * Scale the weight decay when updating parameters in the Param object
+   */
   float weight_decay_multiplier() {
     return proto_.weight_decay_multiplier();
   }
+  /**
+   * Parameter name used for Param re-use in other model or sharing between
+   * layers
+   */
   const std::string& name() {
     return proto_.name();
   }
   /**
-   * if the Param shares data with others, then owner is the id of that param.
+   * If it shares data from others, then owner is the id of that Param,
    * otherwise it is itself's id.
    */
-  const int owner() const{
+  const int owner() const {
     return proto_.owner();
   }
-  int id() const{
+  /**
+   * ID start from 0 and ordered for all Param from the same neuralnet
+   */
+  int id() const {
     return proto_.id();
   }
-  void set_id(int id){
+  /**
+   * Set ID
+   */
+  void set_id(int id) {
     proto_.set_id(id);
     proto_.set_owner(id);
   }
 
   /**
-   * return the version of the parameter value shared by multiple workers
+   * Param version is stored inside the data blob to enable all Param objs
+   * sharing the same values have the same version.
+   * @return the param version
    */
   int version() const {
     return data_->version();
   }
 
   void set_version(int v) {
-    data_->set_version(v); // TODO read version from data blob
+    data_->set_version(v);
   }
 
   /**
-   * return the version of the parameter value local to a worker
+   * @return the version of the parameter value local to a worker
    */
   int local_version() const {
     return local_version_;
   }
 
-  void set_local_version(int v){
+  void set_local_version(int v) {
     local_version_=v;
   }
    /**
@@ -173,9 +118,6 @@ class Param {
   int size() const {
     return data_->count();
   }
-  /**
-   * Return const mem address for the content of this parameter
-   */
   const Blob<float> &data() {
     return *data_;
   }
@@ -191,14 +133,6 @@ class Param {
   Blob<float> *mutable_grad() {
     return &grad_;
   }
-
-  const Blob<float> &history() {
-    return history_;
-  }
-  Blob<float> *mutable_history() {
-    return &history_;
-  }
-
   float* mutable_cpu_data(){
     return data_->mutable_cpu_data();
   }
@@ -208,6 +142,10 @@ class Param {
   float* mutable_cpu_history(){
     return history_.mutable_cpu_data();
   }
+
+  /**
+   * @return slice start ID
+   */
   int slice_start() const {
     return slice_start_;
   }
@@ -216,10 +154,108 @@ class Param {
     return num_slices_;
   }
 
+  /**
+   * Add a slice
+   *
+   * @param slice_id
+   * @param size num of floats for this slice
+   */
   void AddSlice(int slice_id, int size);
+  /**********************Msg related functions***************************/
+
+  /**
+   * Generate the message for a get request, i.e., get parameters from a server
+   *
+   * This function is called at worker/stub side.
+   * @param copy decides whether to copy the parameter values from the server.
+   * @param slice_idx index of the slice from which the message is generated.
+   * @return generated message without setting src, dst, target fields.
+   */
+  virtual Msg* GenGetMsg(bool copy, int slice_idx);
+  /**
+   * Generate the message for a put request, i.e., put parameters to a server.
+   * \copydetails GenGetMsg(bool, int);
+   */
+  virtual Msg* GenPutMsg(bool copy, int slice_idx);
+  /**
+   * Generate the message for a update request, i.e., pass info to server for
+   * parameter update.
+   * \copydetails GenGetMsg(bool, int);
+   */
+  virtual Msg* GenUpdateMsg(bool copy, int slice_idx);
+  /**
+   * Generate the message for a synchronization request between server groups.
+   *
+   * This function is called at server side where the Param is actually a slice
+   * of an original Param object.
+   * */
+  virtual Msg* GenSyncMsg(int offset, int size);
+  /**
+   * Generate the messages to response the update requests.
+   *
+   * This function is called at the server side, where the Param is actually a
+   * slice of an original Param object.
+   *
+   * @param msgs for synchronous training, there would be multiple procs in
+   * which workers sharing the same Param (slice) objects. Their update requests
+   * is bufferred and handled together. For asynchrnous training, there is only
+   * request in msgs.
+   * @return response messages
+   */
+  virtual const vector<Msg*> GenUpdateResponseMsgs(const vector<Msg*>& msgs);
+
+  /**
+   * Server handling function for get request.
+   *
+   * @param msg request
+   * @param reserve if true reserve the msg space for the calling function;
+   * otherwise the msg should be freed inside the function.
+   * @return resposne message
+   */
+  virtual Msg* HandleGetMsg(Msg** msg, bool reserve = false);
+  /**
+   * Server handling function for put request.
+   *
+   * \copydetails HandleGetMsg(Msg**, bool reserve)
+   */
+  virtual Msg* HandlePutMsg(Msg** msg, bool reserve = false);
+  /**
+   * Server handling function for synchronization message
+   *
+   * \copydetails HandleGetMsg(Msg**, bool reserve)
+   */
+  virtual Msg* HandleSyncMsg(Msg** msg, bool reserve = false);
+  /**
+   * Worker/Stub parsing function for get response.
+   *
+   * @param msg
+   * @param slice_idx index for the slice
+   */
+  virtual int ParseGetResponseMsg(Msg* msg, int slice_idx);
+  /**
+   * Worker/Server parsing function for update response
+   *
+   * \copydetails ParseGetResponseMsg(Msg**, int);
+   */
+  virtual int ParseUpdateResponseMsg(Msg* msg, int slice_idx);
+  /**
+   * Server parse update requests.
+   * \copydetails GenUpdateResponseMsgs(const vector<Msg*>& msgs);
+   */
+  virtual void ParseUpdateMsgs(const vector<Msg*>& msgs);
+  /**
+   * Server parsing function for synchronization response.
+   *
+   * \copydetails ParseGetResponseMsg(Msg** , int);
+   */
+  virtual int ParseSyncResponseMsg(Msg* msg, int slice_idx);
 
  protected:
-  void ParseResponseMsg(Msg** msg, int slice_idx);
+  /**
+   * Implement the common code of ParseGetResponseMsg and ParseUpdateResponseMsg
+   * \copydetails ParseSyncResponseMsg(Msg* msg, int slice_idx);
+   */
+  void ParseResponseMsg(Msg* msg, int slice_idx);
 
  protected:
 
@@ -227,16 +263,61 @@ class Param {
    * name of the parameter used to share wights between neuralnets
    */
   std::string name_;
-  shared_ptr<Blob<float>> data_;
-  int slice_start_, num_slices_;
+  int local_version_;
+  //!< the ID of the first slice
+  int slice_start_;
+  int num_slices_;
+  //!< offset and size of each slice
   vector<int> slice_offset_, slice_size_;
+
+  //!< for debug checking
   vector<bool> pending_put_,pending_get_, pending_update_;
   int num_pending_requests_;
+
+  shared_ptr<Blob<float>> data_;
   //! gradient, history gradient of this parameter
   Blob<float> grad_, history_;
   ParamProto proto_;
-  int local_version_;
 };
+
+/**
+ * ParamEntry is used for aggregating gradients of Params shared by workers from
+ * the same group.
+ *
+ * For each worker group, every unique Param object has a ParamEntry object.
+ * Param objects sharing the same values are associated with the same
+ * ParamEntry.
+ */
+class ParamEntry{
+ public:
+  ParamEntry();
+  ParamEntry(int total, Param* p);
+  /**
+   * Associate the counter to a Param object.
+   *
+   * @param p
+   * @param local 1 if it is used by workers in this procs, 0 otherwise
+   */
+  void AddParam(bool local, Param* p);
+  int num_update, next_version;
+  int num_local; //!< # local workers using the shared parameter
+  int num_total; //!< # total workers using the shared parameter
+  //!< Shares are deleted by neuralnet's destructor
+  vector<Param*> shares;
+};
+
+inline int ParamTrgt(int param_id, int slice_id) {
+  return (param_id << 16) | slice_id;
+}
+
+inline int ParamID(int param_trgt) {
+  return param_trgt >> 16;
+}
+
+inline int SliceID(int param_trgt) {
+  static int mask = (1 << 16) -1;
+  return param_trgt & mask;
+}
 }  // namespace singa
 
 #endif  // INCLUDE_UTILS_PARAM_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/585e275f/include/utils/updater.h
----------------------------------------------------------------------
diff --git a/include/utils/updater.h b/include/utils/updater.h
index 0d408d8..ea6d74a 100644
--- a/include/utils/updater.h
+++ b/include/utils/updater.h
@@ -9,6 +9,7 @@ namespace singa{
  */
 class Updater{
  public:
+  virtual ~Updater() {}
   virtual void Init(const UpdaterProto &proto){
     proto_=proto;
   }

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/585e275f/src/communication/msg.cc
----------------------------------------------------------------------
diff --git a/src/communication/msg.cc b/src/communication/msg.cc
index 38512d2..ccf02c4 100644
--- a/src/communication/msg.cc
+++ b/src/communication/msg.cc
@@ -1,62 +1,194 @@
+#include <glog/logging.h>
 #include "communication/msg.h"
 
 namespace singa {
 
 #ifdef USE_ZMQ
+Msg::~Msg() {
+  if (msg_ != nullptr)
+    zmsg_destroy(&msg_);
+  frame_ = nullptr;
+}
+
 Msg::Msg() {
   msg_ = zmsg_new();
 }
 
-Msg::Msg(const Msg& msg){
-  src_=msg.src_;
-  dst_=msg.dst_;
-  type_=msg.type_;
-  trgt_first_=msg.trgt_first_;
-  trgt_second_=msg.trgt_second_;
+Msg::Msg(const Msg& msg) {
+  src_ = msg.src_;
+  dst_ = msg.dst_;
+  type_ = msg.type_;
+  trgt_val_ = msg.trgt_val_;
+  trgt_version_ = msg.trgt_version_;
   msg_=zmsg_dup(msg.msg_);
 }
 
-Msg::~Msg() {
-  if (msg_ != nullptr)
-    zmsg_destroy(&msg_);
+Msg::Msg(int src, int dst) {
+  src_ = src;
+  dst_ = dst;
+  msg_ = zmsg_new();
 }
 
-int Msg::size() const{
+void Msg::SwapAddr() {
+  std::swap(src_, dst_);
+}
+
+int Msg::size() const {
   return zmsg_content_size(msg_);
 }
 
-void Msg::add_frame(const void* addr, int nBytes) {
+void Msg::AddFrame(const void* addr, int nBytes) {
   zmsg_addmem(msg_, addr, nBytes);
 }
 
-int Msg::frame_size() {
+int Msg::FrameSize() {
   return zframe_size(frame_);
 }
 
-void* Msg::frame_data() {
+void* Msg::FrameData() {
   return zframe_data(frame_);
 }
 
-bool Msg::next_frame() {
+char* Msg::FrameStr() {
+  return zframe_strdup(frame_);
+}
+bool Msg::NextFrame() {
   frame_ = zmsg_next(msg_);
   return frame_ != nullptr;
 }
 
+void Msg::FirstFrame() {
+  frame_ = zmsg_first(msg_);
+}
+
+void Msg::LastFrame() {
+  frame_ = zmsg_last(msg_);
+}
+
 void Msg::ParseFromZmsg(zmsg_t* msg) {
   char* tmp = zmsg_popstr(msg);
-  sscanf(tmp, "%d %d %d %d %d %d",
-         &src_, &dst_, &type_, &trgt_first_, &trgt_second_, &trgt_third_);
-  frame_ = zmsg_next(msg);
+  sscanf(tmp, "%d %d %d %d %d",
+         &src_, &dst_, &type_, &trgt_val_, &trgt_version_);
+  frame_ = zmsg_first(msg);
   msg_ = msg;
 }
 
 zmsg_t* Msg::DumpToZmsg() {
-  zmsg_pushstrf(msg_, "%d %d %d %d %d %d",
-      src_, dst_, type_, trgt_first_, trgt_second_, trgt_third_);
+  zmsg_pushstrf(msg_, "%d %d %d %d %d",
+      src_, dst_, type_, trgt_val_, trgt_version_);
   zmsg_t *tmp = msg_;
   msg_ = nullptr;
   return tmp;
 }
+
+// frame marker indicating this frame is serialize like printf
+#define FMARKER "*singa*"
+
+#define kMaxFrameLen 2048
+
+int Msg::AddFormatFrame(const char *format, ...) {
+  va_list argptr;
+  va_start(argptr, format);
+  int size = strlen(FMARKER);
+  char dst[kMaxFrameLen];
+  memcpy(dst, FMARKER, size);
+  dst[size++] = 0;
+  while (*format) {
+    if (*format == 'i') {
+      int x = va_arg(argptr, int);
+      dst[size++] = 'i';
+      memcpy(dst + size, &x, sizeof(x));
+      size += sizeof(x);
+    } else if (*format == 'f') {
+      float x = static_cast<float> (va_arg(argptr, double));
+      dst[size++] = 'f';
+      memcpy(dst + size, &x, sizeof(x));
+      size += sizeof(x);
+    } else if (*format == '1') {
+      uint8_t x = va_arg(argptr, int);
+      memcpy(dst + size, &x, sizeof(x));
+      size += sizeof(x);
+    } else if (*format == '2') {
+      uint16_t x = va_arg(argptr, int);
+      memcpy(dst + size, &x, sizeof(x));
+      size += sizeof(x);
+    } else if (*format == '4') {
+      uint32_t x = va_arg(argptr, uint32_t);
+      memcpy(dst + size, &x, sizeof(x));
+      size += sizeof(x);
+    } else if (*format == 's') {
+      char* x = va_arg(argptr, char *);
+      dst[size++] = 's';
+      memcpy(dst + size, x, strlen(x));
+      size += strlen(x);
+      dst[size++] = 0;
+    } else if (*format == 'p') {
+      void* x = va_arg(argptr, void *);
+      dst[size++] = 'p';
+      memcpy(dst + size, &x, sizeof(x));
+      size += sizeof(x);
+    } else {
+      LOG(ERROR) << "Unknown format " << *format;
+    }
+    format++;
+    CHECK_LE(size, kMaxFrameLen);
+  }
+  va_end(argptr);
+  zmsg_addmem(msg_, dst, size);
+  return size;
+}
+
+int Msg::ParseFormatFrame(const char *format, ...) {
+  va_list argptr;
+  va_start(argptr, format);
+  char* src = zframe_strdup(frame_);
+  CHECK_STREQ(FMARKER, src);
+  int size = strlen(FMARKER) + 1;
+  while (*format) {
+    if (*format == 'i') {
+      int *x = va_arg(argptr, int *);
+      CHECK_EQ(src[size++], 'i');
+      memcpy(x, src + size, sizeof(*x));
+      size += sizeof(*x);
+    } else if (*format == 'f') {
+      float *x = va_arg(argptr, float *);
+      CHECK_EQ(src[size++], 'f');
+      memcpy(x, src + size, sizeof(*x));
+      size += sizeof(*x);
+    }else if (*format == '1') {
+      uint8_t *x = va_arg(argptr, uint8_t *);
+      memcpy(x, src + size, sizeof(*x));
+      size += sizeof(*x);
+    } else if (*format == '2') {
+      uint16_t *x = va_arg(argptr, uint16_t *);
+      memcpy(x, src + size, sizeof(*x));
+      size += sizeof(*x);
+    } else if (*format == '4') {
+      uint32_t *x = va_arg(argptr, uint32_t *);
+      memcpy(x, src + size, sizeof(*x));
+      size += sizeof(*x);
+    } else if (*format == 's') {
+      char* x = va_arg(argptr, char *);
+      int len = strlen(src + size);
+      CHECK_EQ(src[size++], 's');
+      memcpy(x, src + size, len);
+      x[len] = 0;
+      size += len + 1;
+    } else if (*format == 'p') {
+      void** x = va_arg(argptr, void **);
+      CHECK_EQ(src[size++], 'p');
+      memcpy(x, src + size, sizeof(*x));
+      size += sizeof(*x);
+    } else {
+      LOG(ERROR) << "Unknown format type " << *format;
+    }
+    format++;
+  }
+  va_end(argptr);
+  free(src);  // zframe_strdup() allocates with malloc, so release with free
+  return size;
+}
+
 #endif
 
 }  // namespace singa
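
For readers skimming the diff: a rough, illustrative sketch of the new format-frame API (not part of the commit; the addresses and the single-int payload are made up, mirroring how Server::HandleUpdate in server.cc unpacks num_update):

  // sender: pack an int into a frame via the printf-like format string
  Msg* msg = new Msg(Addr(0, 0, kStub), Addr(0, 0, kServer));
  msg->set_type(kUpdate);
  msg->AddFormatFrame("i", 1);      // e.g., one local share contributed

  // receiver: position on the first frame and unpack with the same format
  int num_update = 0;
  msg->FirstFrame();
  msg->ParseFormatFrame("i", &num_update);
  DeleteMsg(&msg);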

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/585e275f/src/communication/socket.cc
----------------------------------------------------------------------
diff --git a/src/communication/socket.cc b/src/communication/socket.cc
index 0cb0982..4e6ec2c 100644
--- a/src/communication/socket.cc
+++ b/src/communication/socket.cc
@@ -9,6 +9,11 @@ Poller::Poller() {
   poller_ = zpoller_new(nullptr);
 }
 
+Poller::Poller(SocketInterface* socket) {
+  poller_ = zpoller_new(nullptr);
+  Add(socket);
+}
+
 void Poller::Add(SocketInterface* socket) {
   zsock_t* zsock = static_cast<zsock_t*>(socket->InternalID());
   zpoller_add(poller_, zsock);
@@ -20,8 +25,9 @@ SocketInterface* Poller::Wait(int timeout) {
   if (sock != nullptr)
     return zsock2Socket_[sock];
   else
-    return nullptr;
+  return nullptr;
 }
+
 bool Poller::Terminated(){
   return zpoller_terminated(poller_);
 }
@@ -32,8 +38,6 @@ Dealer::Dealer() : Dealer(-1) {}
 Dealer::Dealer(int id) : id_(id) {
   dealer_ = zsock_new(ZMQ_DEALER);
   CHECK_NOTNULL(dealer_);
-  poller_ = zpoller_new(dealer_);
-  CHECK_NOTNULL(poller_);
 }
 
 Dealer::~Dealer() {
@@ -123,8 +127,10 @@ int Router::Send(Msg **msg) {
 
 Msg* Router::Receive() {
   zmsg_t* zmsg = zmsg_recv(router_);
-  if (zmsg == nullptr)
-    return nullptr;
+  if (zmsg == nullptr) {
+    LOG(ERROR) << "Connection broken!";
+    exit(0);
+  }
   zframe_t* dealer = zmsg_pop(zmsg);
   Msg* msg = new Msg();
   msg->ParseFromZmsg(zmsg);
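
The new Poller(SocketInterface*) constructor supports the timeout-based receive loop that Server::Run switches to below; a minimal sketch under assumed values (dealer id 1, 100 ms timeout), not a definitive usage:

  Dealer* dealer = new Dealer(1);
  CHECK(dealer->Connect(kInprocRouterEndpoint));
  Poller poll(dealer);              // registers the socket with the poller
  while (true) {
    auto* sock = poll.Wait(100);    // returns nullptr on timeout
    if (poll.Terminated()) break;   // poller interrupted
    if (sock == nullptr) continue;  // timeout, nothing to read yet
    Msg* msg = dealer->Receive();
    // ... dispatch on msg->type() ...
    DeleteMsg(&msg);
  }
  delete dealer;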

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/585e275f/src/neuralnet/layer.cc
----------------------------------------------------------------------
diff --git a/src/neuralnet/layer.cc b/src/neuralnet/layer.cc
index b40b676..0a1c665 100644
--- a/src/neuralnet/layer.cc
+++ b/src/neuralnet/layer.cc
@@ -235,7 +235,7 @@ void LabelLayer::ParseRecords(Phase phase, const vector<Record>& records,
 /*********************LMDBDataLayer**********************************/
 void LMDBDataLayer::ComputeFeature(Phase phase, Metric* perf){
   if(random_skip_){
-    int nskip=rand()%random_skip_;
+    int nskip = rand() % random_skip_;
     int n=0;
     CHECK_EQ(mdb_cursor_get(mdb_cursor_, &mdb_key_,
           &mdb_value_, MDB_FIRST), MDB_SUCCESS);
@@ -637,7 +637,7 @@ void RGBImageLayer::Setup(const LayerProto& proto, int npartitions) {
 /***************Implementation for ShardDataLayer**************************/
 void ShardDataLayer::ComputeFeature(Phase phase, Metric* perf){
   if(random_skip_){
-    int nskip=rand()%random_skip_;
+    int nskip = rand() % random_skip_;
     LOG(INFO)<<"Random Skip "<<nskip<<" records, there are "<<shard_->Count()
       <<" records in total";
     string key;

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/585e275f/src/neuralnet/neuralnet.cc
----------------------------------------------------------------------
diff --git a/src/neuralnet/neuralnet.cc b/src/neuralnet/neuralnet.cc
index 6d82734..25609ed 100644
--- a/src/neuralnet/neuralnet.cc
+++ b/src/neuralnet/neuralnet.cc
@@ -5,6 +5,7 @@
 #include "utils/singleton.h"
 
 namespace singa {
+// macros to shorten the code
 #define LayerT(x) LayerProto_LayerType_k##x
 
 #define RegisterLayer(factory, id) \
@@ -36,31 +37,31 @@ void NeuralNet::RegisterLayers() {
 }
 
 shared_ptr<NeuralNet> NeuralNet::Create(
-    const NetProto& conf,
+    const NetProto& net_conf,
     Phase phase,
     int npartitions) {
-  NetProto proto;
-  proto.CopyFrom(conf);
-  proto.clear_layer();
+  NetProto conf;
+  conf.CopyFrom(net_conf);
+  conf.clear_layer();
   // exclude layers according to phase
-  for (const auto& layer : conf.layer()) {
+  for (const auto& layer : net_conf.layer()) {
     bool include = true;
-    for (auto x : layer.exclude()) {
-      if (x == phase)
+    for (auto p : layer.exclude()) {
+      if (p == phase)
         include = false;
     }
     if (include) {
-      LayerProto* lp = proto.add_layer();
-      lp->CopyFrom(layer);
+      LayerProto* layer_conf = conf.add_layer();
+      layer_conf->CopyFrom(layer);
       // using net partition if layer partition is not set
-      if (!lp->has_partition_dim())
-        lp->set_partition_dim(proto.partition_dim());
+      if (!layer_conf->has_partition_dim())
+        layer_conf->set_partition_dim(net_conf.partition_dim());
     }
   }
-  LOG(INFO) << "NeuralNet config is\n" << proto.DebugString();
+  LOG(INFO) << "NeuralNet config is\n" << conf.DebugString();
 
   // TODO(wangwei) create net based on net type, e.g., directed, undirected, etc
-  auto net = std::make_shared<NeuralNet>(proto, npartitions);
+  auto net = std::make_shared<NeuralNet>(conf, npartitions);
   return net;
 }
 
@@ -120,7 +121,7 @@ void NeuralNet::CreateNetFromGraph(Graph* graph, int npartitions) {
       auto params = (*it)->GetParams();
       CHECK_EQ(params.size(), owner_params.size());
       for (size_t i = 0; i < params.size(); i++)
-        params.at(i)->ShareData(owner_params.at(i));
+        params.at(i)->ShareFrom(*owner_params.at(i));
     }
   }
 }
@@ -349,7 +350,7 @@ std::string NeuralNet::ToAdjacency() {
   return disp;
 }
 
-void NeuralNet::ShareParams(shared_ptr<NeuralNet> other) {
+void NeuralNet::ShareParamsFrom(shared_ptr<NeuralNet> other) {
   for (auto& layer : layers_) {
     auto otherlayer = other->name2layer(layer->name());
     if (otherlayer != nullptr) {
@@ -357,7 +358,7 @@ void NeuralNet::ShareParams(shared_ptr<NeuralNet> other) {
       const auto& params = layer->GetParams();
       CHECK_EQ(params.size(), otherparams.size());
       for (size_t i = 0; i < params.size(); i++) {
-        params[i]->ShareData(otherparams[i]);
+        params[i]->ShareFrom(*otherparams[i]);
       }
     }
   }
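
With ShareData/ShareParams renamed to ShareFrom/ShareParamsFrom, sharing Param values between two nets built from the same NetProto reads as below (a sketch only; the kTrain/kTest phase values and the single-partition setup are assumptions):

  auto train_net = NeuralNet::Create(net_conf, kTrain, 1);
  auto test_net = NeuralNet::Create(net_conf, kTest, 1);
  // layers with the same name in both nets reuse the training net's Param data
  test_net->ShareParamsFrom(train_net);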

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/585e275f/src/proto/model.proto
----------------------------------------------------------------------
diff --git a/src/proto/model.proto b/src/proto/model.proto
index a8de5d5..5b22b3f 100644
--- a/src/proto/model.proto
+++ b/src/proto/model.proto
@@ -42,6 +42,8 @@ message ModelProto {
   optional int32 checkpoint_frequency = 34 [default = 0];
   // send parameters to servers after training for this num of steps
   optional int32 warmup_steps = 35 [default = 0];
+  // whether to resume training from the last checkpoint
+  optional bool resume = 36 [default = false];
 
    // start display after this num steps
   optional int32 display_after_steps =  60[default = 0];
@@ -60,7 +62,7 @@ message ModelProto {
 message NetProto {
   repeated LayerProto layer = 1;
   // partitioning type for parallelism
-  optional int32 partition_dim = 2 [default = -1];
+  optional int32 partition_dim = 2 [default = 0];
 }
 
 // weight matrix should be defined before bias vector
@@ -209,7 +211,7 @@ message RGBImageProto {
   optional string meanfile = 4 [default = ""];
 }
 
-message PrefetchProto{
+message PrefetchProto {
   repeated LayerProto sublayers = 1;
 }
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/585e275f/src/test/test_paramslicer.cc
----------------------------------------------------------------------
diff --git a/src/test/test_paramslicer.cc b/src/test/test_paramslicer.cc
index bbff616..759c18b 100644
--- a/src/test/test_paramslicer.cc
+++ b/src/test/test_paramslicer.cc
@@ -6,6 +6,7 @@ using namespace singa;
 
 const int param_size[]={2400,32,25600,32, 51200,64,57600,10};
 
+/*
 class ParamSlicerTest : public ::testing::Test {
   public:
     ParamSlicerTest() {
@@ -45,3 +46,4 @@ TEST_F(ParamSlicerTest, MultipleBox){
   ASSERT_EQ(slicer.Get(3).size(),1);
   ASSERT_EQ(slicer.Get(nparams-1).back(), slices.size()-1);
 }
+*/

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/585e275f/src/test/test_shard.cc
----------------------------------------------------------------------
diff --git a/src/test/test_shard.cc b/src/test/test_shard.cc
index 6fe478e..dcc3026 100644
--- a/src/test/test_shard.cc
+++ b/src/test/test_shard.cc
@@ -17,7 +17,7 @@ std::string tuple[] = {"firsttuple",
 using namespace singa;
 
 TEST(DataShardTest, CreateDataShard) {
-  std::string path = "src/test/data/shard_test";
+  std::string path = "src/test/shard_test";
   mkdir(path.c_str(), 0755);
   DataShard shard(path, DataShard::kCreate, 50);
   shard.Insert(key[0], tuple[0]);
@@ -27,7 +27,7 @@ TEST(DataShardTest, CreateDataShard) {
 }
 
 TEST(DataShardTest, AppendDataShard) {
-  std::string path = "src/test/data/shard_test";
+  std::string path = "src/test/shard_test";
   DataShard shard(path, DataShard::kAppend, 50);
   shard.Insert(key[3], tuple[3]);
   shard.Insert(key[4], tuple[4]);
@@ -35,14 +35,14 @@ TEST(DataShardTest, AppendDataShard) {
 }
 
 TEST(DataShardTest, CountDataShard) {
-  std::string path = "src/test/data/shard_test";
+  std::string path = "src/test/shard_test";
   DataShard shard(path, DataShard::kRead, 50);
   int count = shard.Count();
   ASSERT_EQ(5, count);
 }
 
 TEST(DataShardTest, ReadDataShard) {
-  std::string path = "src/test/data/shard_test";
+  std::string path = "src/test/shard_test";
   DataShard shard(path, DataShard::kRead, 50);
   std::string k, t;
   ASSERT_TRUE(shard.Next(&k, &t));

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/585e275f/src/trainer/server.cc
----------------------------------------------------------------------
diff --git a/src/trainer/server.cc b/src/trainer/server.cc
index cbb0ee1..8c97d5d 100644
--- a/src/trainer/server.cc
+++ b/src/trainer/server.cc
@@ -1,6 +1,5 @@
-#include <list>
-#include <tuple>
-#include <queue>
+#include <thread>
+#include <chrono>
 #include "mshadow/tensor.h"
 #include "trainer/server.h"
 #include "utils/param.h"
@@ -9,101 +8,93 @@
 #include "utils/cluster.h"
 #include "proto/common.pb.h"
 
-using namespace mshadow;
 namespace singa {
+using namespace mshadow;
+
 Server::Server(int thread_id,int group_id, int server_id):
-  thread_id_(thread_id),group_id_(group_id), server_id_(server_id){}
+  thread_id_(thread_id),grp_id_(group_id), id_(server_id){
+}
 
 void Server::Setup(const UpdaterProto& proto,
-    shared_ptr<ServerShard> shard, const vector<int>& slice2group){
-	//VLOG(3) << "Parsing config file for host "<<hosts[id_] << " server id = " <<id_;
-  updater_=shared_ptr<Updater>(Singleton<Factory<Updater>>::Instance()
-      ->Create("Updater"));
+    std::unordered_map<int, ParamEntry*>* shard,
+    const vector<int>& slice2group) {
+  updater_ = Singleton<Factory<Updater>>::Instance()->Create("Updater");
   updater_->Init(proto);
-  shard_=shard;
-  slice2group_=slice2group;
+  shard_ = shard;
+  slice2group_ = slice2group;
+}
+
+Server::~Server() {
+  delete updater_;
 }
 
-void Server::Run(){
-  LOG(ERROR)<<"Server (group_id = "<<group_id_
-    <<", id = "<<server_id_<<") starts";
-  dealer_=std::make_shared<Dealer>(2*thread_id_);
-  dealer_->Connect(kInprocRouterEndpoint);
-  auto cluster=Cluster::Get();
-  Msg* ping=new Msg();
-  ping->set_src(group_id_, server_id_, kServer);
-  ping->set_dst(-1,-1,kStub);
-  ping->add_frame("PING", 4);
+void Stop(void* running) {
+  *static_cast<bool *>(running) = false;
+}
+
+void Server::Run() {
+  LOG(ERROR) << "Server (group = " << grp_id_ <<", id = " << id_ << ") start";
+  auto dealer = new Dealer(2*thread_id_);
+  CHECK(dealer->Connect(kInprocRouterEndpoint));
+  Msg* ping = new Msg(Addr(grp_id_, id_, kServer), Addr(-1, -1, kStub));
   ping->set_type(kConnect);
-  dealer_->Send(&ping);
+  dealer->Send(&ping);
+
+  auto cluster = Cluster::Get();
+  bool running = true;
+  CHECK(cluster->runtime()->WatchSGroup(grp_id_, id_, Stop, &running));
+
+  int nserver_grps = cluster->nserver_groups();
   vector<Param*> master_params;
   size_t syncEntry=0;
-  //start recv loop and process requests
-  while (true){
-    Msg* msg=dealer_->Receive();
-    if (msg==nullptr)
-      break;
-    Msg* response=nullptr, *sync=nullptr;
+  Poller poll(dealer);
+  // start recv loop and process requests
+  while (running) {
+    auto *sock = poll.Wait(cluster->poll_time());
+    if (poll.Terminated()) {
+      LOG(ERROR) << "Connection broken!";
+      exit(0);
+    } else if (sock == nullptr) {
+      continue;
+    }
+    Msg* msg = dealer->Receive();
+    if (msg == nullptr) break;
+    Msg* response = nullptr;
     int type=msg->type();
-    if (type== kStop){
-      msg->set_src(group_id_, server_id_, kServer);
-      msg->set_dst(-1,-1, kStub);
-      dealer_->Send(&msg);
-      break;
-    }else if (type==kConnect){
-      // TODO remove receiving pong msg
-      string pong((char*)msg->frame_data(), msg->frame_size());
-      CHECK_STREQ("PONG", pong.c_str());
-      DeleteMsg(&msg);
-    }else if(type==kPut){
-      int pid = msg->trgt_second();
+    int slice_id = SliceID(msg->trgt_val());
+    if (type == kPut) {
       response = HandlePut(&msg);
-      if(slice2group_[pid]==group_id_)
-        master_params.push_back(shard_->at(pid));
-    }else{
-      int pid=msg->trgt_second();
-      if(shard_->find(pid)==shard_->end()){
+      if (slice2group_[slice_id] == grp_id_)
+        master_params.push_back(shard_->at(slice_id)->shares.at(0));
+    } else {
+      if (shard_->find(slice_id) == shard_->end()) {
         // delay the processing by re-queue the msg.
-        response=msg;
-        //LOG(INFO)<<"Requeue msg"<<type;
-      }else if(type == kSyncReminder){
+        response = msg;
+      } else if (type == kSyncReminder) {
         DeleteMsg(&msg);
-        if(syncEntry>=master_params.size())
+        if (syncEntry >= master_params.size())
           continue;
-        auto param=master_params.at(syncEntry);
+        auto param = master_params.at(syncEntry);
         // control the frequency of synchronization
         // currently sync is triggerred only when the slice is updated
         // by local worker or other workers for at least nserver_groups times.
         // TODO may optimize the trigger condition.
-        if(abs(param->local_version()-param->version())>=cluster->nserver_groups()){
-          // TODO replace the argument (0,0) to sync a chunk instead of a slice
-          sync=param->GenSyncMsg(0,0);
-          for(int i=0;i<cluster->nserver_groups();i++){
-            if(i!=group_id_) {
-              Msg* tmp=sync;
-              if(i<cluster->nserver_groups()-1)
-                tmp= new Msg(*sync);
-              // assume only one server per group, TODO generalize it
-              tmp->set_dst(i, 0, kServer);
-              tmp->set_src(group_id_, server_id_, kServer);
-              dealer_->Send(&tmp);
-              param->set_version(param->local_version());
-              //LOG(ERROR)<<"sync slice="<<param->id()<<" to procs "<<i;
-            }
-          }
-          syncEntry=(syncEntry+1)%master_params.size();
+        if (abs(param->local_version() - param->version()) >= nserver_grps) {
+          for (auto msg : GenSyncMsgs(param))
+            dealer->Send(&msg);
+          syncEntry = (syncEntry+1) % master_params.size();
         }
-      }else{
-        auto param=shard_->at(pid);
-        switch (type){
+      } else {
+        switch (type) {
           case kGet:
-            response=HandleGet(param, &msg);
+            response = HandleGet(&msg);
             break;
           case kUpdate:
-            response = HandleUpdate(param, &msg);
+            for (auto reply : HandleUpdate(&msg))
+              dealer->Send(&reply);
             break;
           case kSyncRequest:
-            response = HandleSyncRequest(param, &msg);
+            response = HandleSyncRequest(&msg);
             break;
           default:
             LOG(ERROR)<<"Unknown message type "<<type;
@@ -111,96 +102,149 @@ void Server::Run(){
         }
       }
     }
-    if (response!=nullptr)
-      dealer_->Send(&response);
+    if (response != nullptr)
+      dealer->Send(&response);
   }
-  LOG(ERROR)<<"Server (group_id = "<<group_id_
-    <<", id = "<<server_id_<<") stops";
+
+  // send stop msg to stub
+  Msg* msg = new Msg(Addr(grp_id_, id_, kServer), Addr(-1, -1, kStub));
+  msg->set_type(kStop);
+  dealer->Send(&msg);
+  std::this_thread::sleep_for(std::chrono::milliseconds(1000));
+
+  LOG(ERROR) << "Server (group = " << grp_id_ << ", id = " << id_ << ") stops";
+  delete dealer;
 }
 
-Msg* Server::HandlePut(Msg **msg){
-  int version=(*msg)->trgt_third();
-  int pid=(*msg)->trgt_second();
-  Param* param=nullptr;
-  if(shard_->find(pid)!=shard_->end()){
-    LOG(ERROR)<<"Param ("<<pid<<") is put more than once";
-    param=shard_->at(pid);
-  }else{
-    auto factory=Singleton<Factory<Param>>::Instance();
-    param=factory ->Create("Param");
-    (*shard_)[pid]=param;
+const vector<Msg*> Server::GenSyncMsgs(Param* param) {
+  vector<Msg*> ret;
+  // TODO replace the argument (0,0) to sync a chunk instead of a slice
+  auto msg = param->GenSyncMsg(0, 0);
+  auto cluster = Cluster::Get();
+  for (int i = 0; i < cluster->nserver_groups(); i++) {
+    if (i != grp_id_) {
+      Msg* tmp = msg;
+      if (i < cluster->nserver_groups() - 1)
+        tmp = new Msg(*msg);
+      // assume only one server per group, TODO generalize it
+      tmp->set_dst(Addr(i, 0, kServer));
+      tmp->set_src(Addr(grp_id_, id_, kServer));
+      ret.push_back(tmp);
+      param->set_version(param->local_version());
+      //LOG(ERROR)<<"sync slice="<<param->id()<<" to procs "<<i;
+    }
   }
-  auto response=param->HandlePutMsg(msg);
+  return ret;
+}
+
+Msg* Server::HandlePut(Msg **msg) {
+  int version = (*msg)->trgt_version();
+  int slice_id = SliceID((*msg)->trgt_val());
+  if (shard_->find(slice_id) != shard_->end())
+    LOG(FATAL) << "Param (" << slice_id << ") is put more than once";
+
+  auto param = Singleton<Factory<Param>>::Instance()->Create("Param");
+  auto response = param->HandlePutMsg(msg, true);
+  // parse num of shares of this param from a worker group
+  int num_shares = 1;
+  if ((*msg)->NextFrame())
+    (*msg)->ParseFormatFrame("i", &num_shares);
+  DeleteMsg(msg);
+  (*shard_)[slice_id] = new ParamEntry(num_shares, param);
   // must set version after HandlePutMsg which allocates the memory
   param->set_version(version);
   param->set_local_version(version);
-  param->set_id(pid);
+  param->set_id(slice_id);
   //LOG(ERROR)<<"put norm "<<param->data().asum_data()<<", "<<pid;
-  if(Cluster::Get()->nserver_groups()>1 &&
-      slice2group_[pid]!=group_id_){
-    last_data_[pid]=std::make_shared<Blob<float>>();
-    last_data_[pid]->ReshapeLike(param->data());
-    last_data_[pid]->CopyFrom(param->data());
+  // allocate blob for param sync between groups.
+  if (Cluster::Get()->nserver_groups() > 1 && slice2group_[slice_id] != grp_id_) {
+    last_data_[slice_id] = std::make_shared<Blob<float>>();
+    last_data_[slice_id]->ReshapeLike(param->data());
+    last_data_[slice_id]->CopyFrom(param->data());
   }
-  LOG(INFO)<<"server ("<<group_id_<<", "<<server_id_
-    <<") put slice="<<pid<<" size="<<param->size();
+  LOG(INFO)<<"server (group = " << grp_id_ << ", id = " << id_ <<") put slice="
+    << slice_id << " size=" << param->size();
   return response;
 }
 
-Msg* Server::HandleGet(Param* param, Msg **msg){
-  if(param->version()<(*msg)->trgt_third())
+Msg* Server::HandleGet(Msg **msg) {
+  int val = (*msg)->trgt_val();
+  auto param = shard_->at(SliceID(val))->shares.at(0);
+  // re-queue the request if the param is not updated to the required version
+  if (param->version() < (*msg)->trgt_version())
     return *msg;
-  else{
-    auto reply= param->HandleGetMsg(msg);
-    int paramid=reply->trgt_first(), slice=reply->trgt_second();
-    reply->set_trgt(paramid, slice, param->version());
+  else {
+    // LOG(ERROR) << "get " << slice << " from "<<(*msg)->src_first();
+    auto reply = param->HandleGetMsg(msg);
+    reply->set_trgt(val, param->version());
     return reply;
   }
 }
 
-Msg* Server::HandleUpdate(Param* param, Msg **msg) {
-  auto* tmp=static_cast<Msg*>((*msg)->CopyAddr());
-  tmp->SwapAddr();
-  int paramid=(*msg)->trgt_first();
-  int sliceid=(*msg)->trgt_second();
-  int step=(*msg)->trgt_third();
-  bool copy=param->ParseUpdateMsg(msg);
-  updater_->Update(step, param);
-  param->set_local_version(param->local_version()+1);
-  auto response=param->GenUpdateResponseMsg(copy);
-  response->set_trgt(paramid, sliceid, param->local_version());
-  response->SetAddr(tmp);
-  delete tmp;
-  return response;
+const vector<Msg*> Server::HandleUpdate(Msg **msg) {
+  vector<Msg*> ret;
+  int sliceid = SliceID((*msg)->trgt_val());
+  auto entry = shard_->at(sliceid);
+  buffer_requests_[sliceid].push_back(*msg);
+  int num_update;
+  (*msg)->LastFrame();
+  (*msg)->ParseFormatFrame("i", &num_update);
+  (*msg)->FirstFrame();
+  entry->num_update += num_update;
+  // LOG(ERROR) << "update "<<sliceid<< " from "<<(*msg)->src_second()
+  //  << ", " << num_update << " total " << entry->num_total;
+  // do update until recv gradients from all shares of this param/slice
+  if (entry->num_update >= entry->num_total) {
+    CHECK_EQ(entry->num_update, entry->num_total);
+    auto& request = buffer_requests_.at(sliceid);
+    int step = (*msg)->trgt_version();
+    auto param = entry->shares.at(0);
+    // extract and aggregate gradients
+    param->ParseUpdateMsgs(request);
+    updater_->Update(step, param);
+    param->set_local_version(param->local_version() + 1);
+    // response to all shares of this param
+    for (auto response : param->GenUpdateResponseMsgs(request)) {
+      response->set_trgt((*msg)->trgt_val(), param->local_version());
+      ret.push_back(response);
+    }
+    request.clear();
+    entry->num_update = 0;
+  }
+  *msg = nullptr;
+  return ret;
 }
 
-Msg* Server::HandleSyncRequest(Param* param, Msg **msg){
+Msg* Server::HandleSyncRequest(Msg **msg) {
+  Msg* msgg = *msg;
+  int slice = SliceID(msgg->trgt_val());
+  auto param = shard_->at(slice)->shares.at(0);
   Msg* response=nullptr;
   auto shape=Shape1(param->size());
-  CHECK_EQ((*msg)->frame_size(), param->size()*sizeof(float));
-  Tensor<cpu, 1> tmp(static_cast<float*>((*msg)->frame_data()), shape);
+  CHECK_EQ(msgg->FrameSize(), param->size()*sizeof(float));
+  Tensor<cpu, 1> tmp(static_cast<float*>(msgg->FrameData()), shape);
   Tensor<cpu, 1> cur(param->mutable_cpu_data(), shape);
   //LOG(ERROR)<<"Recv sync for "<<param->id();
-  if(slice2group_[param->id()]==group_id_){
+  if (slice2group_[slice] == grp_id_) {
+    // recv sync msg on slice I am mastering
     cur+=tmp;
     param->set_local_version(param->local_version()+1);
-  }else{
+  } else {  // recv sync msg on slice mastered by others
     TensorContainer<cpu, 1> diff(shape);
     Tensor<cpu, 1> prev(last_data_[param->id()]->mutable_cpu_data(), shape);
     diff=cur-prev;
-    (*msg)->next_frame();
+    msgg->NextFrame();
     int bandwidth;
-    sscanf(static_cast<char*>((*msg)->frame_data()), "%d", &bandwidth);
-    if(bandwidth>0){
-      response=new Msg();
+    msgg->ParseFormatFrame("i", &bandwidth);
+    if (bandwidth > 0) {
+      // send back my updates to the server group mastering this param
+      response = new Msg(msgg->dst(), msgg->src());
       response->set_type(kSyncRequest);
-      response->set_trgt(-1, param->id(), param->version());
-      response->add_frame(diff.dptr, param->size()*sizeof(float));
-      (*msg)->SwapAddr();
-      response->SetAddr(*msg);
+      response->set_trgt(param->id(), param->version());
+      response->AddFrame(diff.dptr, param->size()*sizeof(float));
       prev=diff+tmp;
       Copy(cur, prev);
-    }else{
+    } else {  // no bandwidth, aggregate my updates for next sync
       Copy(prev, tmp);
       cur=tmp+diff;
     }