You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@singa.apache.org by wa...@apache.org on 2015/06/25 15:45:00 UTC
[2/6] incubator-singa git commit: SINGA-8 Implement distributed
Hogwild The program is simply tested using two processes. TODO 1. read
process endpoints from the zookeeper instead of hard-coding them. 2. split
large parameters to avoid load-balance issue
SINGA-8 Implement distributed Hogwild
The program is simply tested using two processes.
TODO
1. Read process endpoints from ZooKeeper instead of hard-coding them.
2. Split large parameters to avoid load-balancing issues among server groups.
Currently, server groups are assigned an (almost) equal number of Param objects,
but these objects may differ considerably in memory size.
Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/f4370118
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/f4370118
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/f4370118
Branch: refs/heads/master
Commit: f4370118c91f688fdc8c84d0d590096f2e93586c
Parents: a019958
Author: wang wei <wa...@comp.nus.edu.sg>
Authored: Wed Jun 17 16:17:19 2015 +0800
Committer: wang wei <wa...@comp.nus.edu.sg>
Committed: Thu Jun 25 11:49:32 2015 +0800
----------------------------------------------------------------------
examples/cifar10/cluster-dist.conf | 8 +++
examples/cifar10/hostfile | 22 +-----
include/communication/msg.h | 2 +-
include/communication/socket.h | 11 ++-
include/trainer/server.h | 14 ++--
include/trainer/trainer.h | 5 +-
include/utils/cluster.h | 10 ++-
include/utils/param.h | 3 +-
src/communication/socket.cc | 8 ++-
src/proto/cluster.proto | 6 ++
src/test/test_paramslicer.cc | 47 +++++++++++++
src/trainer/server.cc | 121 ++++++++++++++++++++++++++------
src/trainer/trainer.cc | 62 +++++++++++++++-
src/trainer/worker.cc | 8 ++-
src/utils/param.cc | 15 ++--
15 files changed, 272 insertions(+), 70 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f4370118/examples/cifar10/cluster-dist.conf
----------------------------------------------------------------------
diff --git a/examples/cifar10/cluster-dist.conf b/examples/cifar10/cluster-dist.conf
new file mode 100644
index 0000000..1a4e2c2
--- /dev/null
+++ b/examples/cifar10/cluster-dist.conf
@@ -0,0 +1,8 @@
+nworker_groups: 2
+nserver_groups: 2
+nservers_per_group: 1
+nworkers_per_group: 1
+nworkers_per_procs: 1
+workspace: "examples/cifar10/"
+hostfile: "examples/cifar10/hostfile"
+poll_time: 100
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f4370118/examples/cifar10/hostfile
----------------------------------------------------------------------
diff --git a/examples/cifar10/hostfile b/examples/cifar10/hostfile
index 83e06e5..eda7414 100644
--- a/examples/cifar10/hostfile
+++ b/examples/cifar10/hostfile
@@ -1,20 +1,2 @@
-awan-2-26-0
-awan-2-27-0
-awan-2-28-0
-awan-2-29-0
-awan-2-30-0
-awan-2-31-0
-awan-2-32-0
-awan-2-33-0
-awan-2-34-0
-awan-2-35-0
-awan-2-36-0
-awan-2-37-0
-awan-2-38-0
-awan-2-39-0
-awan-2-40-0
-awan-2-41-0
-awan-2-42-0
-awan-2-43-0
-awan-2-44-0
-awan-2-45-0
+localhost:9733
+localhost:9734
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f4370118/include/communication/msg.h
----------------------------------------------------------------------
diff --git a/include/communication/msg.h b/include/communication/msg.h
index c3ef1c7..60a359a 100644
--- a/include/communication/msg.h
+++ b/include/communication/msg.h
@@ -23,6 +23,7 @@ class Msg {
* @param second worker/server id within the group
* @param flag 0 for server, 1 for worker, 2 for stub
*/
+<<<<<<< HEAD
inline void set_src(int first, int second, int flag) {
src_ = (first << kOff1) | (second << kOff2) | flag;
}
@@ -78,7 +79,6 @@ class Msg {
void ParseFromZmsg(zmsg_t* msg);
zmsg_t* DumpToZmsg();
#endif
-
protected:
static const unsigned int kOff1 = 16;
static const unsigned int kOff2 = 4;
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f4370118/include/communication/socket.h
----------------------------------------------------------------------
diff --git a/include/communication/socket.h b/include/communication/socket.h
index d1cb400..b98656e 100644
--- a/include/communication/socket.h
+++ b/include/communication/socket.h
@@ -19,10 +19,10 @@ class SocketInterface {
public:
virtual ~SocketInterface() {}
/**
- * Send a message to connected socket(s), non-blocking. The message
- * will be deallocated after sending, thus should not be used after
+ * Send a message to connected socket(s), non-blocking. The message
+ * will be deallocated after sending, thus should not be used after
* calling Send();
- *
+ *
* @param msg The message to be sent
* @return 1 for success queuing the message for sending, 0 for failure
*/
@@ -56,6 +56,11 @@ class Poller {
*/
SocketInterface* Wait(int duration);
+ /**
+ * @return true if the poller is terminated due to process interrupt
+ */
+ virtual bool Terminated()=0;
+
protected:
#ifdef USE_ZMQ
zpoller_t *poller_;
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f4370118/include/trainer/server.h
----------------------------------------------------------------------
diff --git a/include/trainer/server.h b/include/trainer/server.h
index b07741f..a8995fb 100644
--- a/include/trainer/server.h
+++ b/include/trainer/server.h
@@ -27,6 +27,12 @@ class Server{
void Setup(const UpdaterProto& proto, shared_ptr<ServerShard> shard,
const vector<int>& slice2group);
void Run();
+ const int group_id() const {
+ return group_id_;
+ }
+ const int server_id() const {
+ return server_id_;
+ }
protected:
@@ -50,24 +56,20 @@ class Server{
* @return the original message or response message. If we don't need to
* acknowledge the put request, then return nullptr.
*/
- virtual void HandlePut(shared_ptr<Param> param, Msg **msg);
+ virtual Msg* HandlePut(Msg **msg);
/**
* TODO Process SYNC request.
*/
virtual Msg* HandleSyncRequest(shared_ptr<Param> param, Msg** msg);
- /**
- * TODO Process SYNC response.
- virtual int HandleSyncResponse(shared_ptr<Param> param, Msg** msg);
- */
-
protected:
int thread_id_,group_id_, server_id_;
shared_ptr<Dealer> dealer_;
shared_ptr<Updater> updater_;
shared_ptr<ServerShard> shard_;
vector<int> slice2group_;
+ std::map<int, shared_ptr<Blob<float>>> last_data_;
};
} /* Server */
#endif //INCLUDE_TRAINER_SERVER_H_
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f4370118/include/trainer/trainer.h
----------------------------------------------------------------------
diff --git a/include/trainer/trainer.h b/include/trainer/trainer.h
index ed93374..fbbfd0b 100644
--- a/include/trainer/trainer.h
+++ b/include/trainer/trainer.h
@@ -95,13 +95,14 @@ class Trainer{
// point.
protected:
-
vector<shared_ptr<Server>> CreateServers(int nthread, const ModelProto& mproto,
const vector<int> slices, vector<HandleContext*>* ctx);
vector<shared_ptr<Worker>> CreateWorkers(int nthread,
const ModelProto& mproto, vector<int> *slice_size);
- void Run(int nworkers, int nservers);
+ void Run(const vector<shared_ptr<Worker>>& workers,
+ const vector<shared_ptr<Server>>& servers,
+ const std::map<int, shared_ptr<ParamShard>>& shards);
/**
* Register default implementations for all base classes used in the system,
* e.g., the Updater, BaseMsg, etc.
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f4370118/include/utils/cluster.h
----------------------------------------------------------------------
diff --git a/include/utils/cluster.h b/include/utils/cluster.h
index 9648bfe..0eeb808 100644
--- a/include/utils/cluster.h
+++ b/include/utils/cluster.h
@@ -112,11 +112,15 @@ class Cluster {
}
/**
- * bandwidth MB/s
- float bandwidth() const {
+ * bandwidth Bytes/s
+ */
+ const int bandwidth() const {
return cluster_.bandwidth();
}
- */
+
+ const int poll_time() const {
+ return cluster_.poll_time();
+ }
shared_ptr<ClusterRuntime> runtime() const {
return cluster_rt_;
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f4370118/include/utils/param.h
----------------------------------------------------------------------
diff --git a/include/utils/param.h b/include/utils/param.h
index 897c97a..d449fba 100644
--- a/include/utils/param.h
+++ b/include/utils/param.h
@@ -71,6 +71,7 @@ class Param {
*/
virtual Msg* HandleSyncMsg(Msg** msg);
+<<<<<<< HEAD
/**
* Server parses update request message.
*
@@ -105,6 +106,7 @@ class Param {
* @param shape
*/
virtual void Setup(const ParamProto& proto, const std::vector<int>& shape);
+ virtual void Setup(const vector<int>& shape);
/*
* Fill the values according to initmethod, e.g., gaussian distribution
*
@@ -238,7 +240,6 @@ class Param {
ParamProto proto_;
int local_version_;
};
-
} // namespace singa
#endif // INCLUDE_UTILS_PARAM_H_
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f4370118/src/communication/socket.cc
----------------------------------------------------------------------
diff --git a/src/communication/socket.cc b/src/communication/socket.cc
index 5321724..38c0d79 100644
--- a/src/communication/socket.cc
+++ b/src/communication/socket.cc
@@ -19,9 +19,14 @@ SocketInterface* Poller::Wait(int timeout) {
zsock_t* sock = static_cast<zsock_t*>(zpoller_wait(poller_, timeout));
if (sock != nullptr)
return zsock2Socket_[sock];
- return nullptr;
+ else
+ return nullptr;
+}
+bool Poller::Terminated(){
+ return zpoller_terminated(poller_);
}
+
Dealer::Dealer() : Dealer(-1) {}
Dealer::Dealer(int id) : id_(id) {
@@ -31,6 +36,7 @@ Dealer::Dealer(int id) : id_(id) {
CHECK_NOTNULL(poller_);
}
+<<<<<<< HEAD
Dealer::~Dealer() {
zsock_destroy(&dealer_);
}
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f4370118/src/proto/cluster.proto
----------------------------------------------------------------------
diff --git a/src/proto/cluster.proto b/src/proto/cluster.proto
index 4f7e661..3317f2a 100644
--- a/src/proto/cluster.proto
+++ b/src/proto/cluster.proto
@@ -38,6 +38,12 @@ message ClusterProto {
optional bool server_update = 40 [default = true];
// share memory space between worker groups in one procs
optional bool share_memory = 41 [default = true];
+
+ // bandwidth of ethernet, Bytes per second, default is 1 Gbps
+ optional int32 bandwidth=50 [default=134217728];
+ // poll time in milliseconds
+ optional int32 poll_time=51 [default =100];
+>>>>>>> SINGA-8 Implement distributed Hogwild
}
message ServerTopology {
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f4370118/src/test/test_paramslicer.cc
----------------------------------------------------------------------
diff --git a/src/test/test_paramslicer.cc b/src/test/test_paramslicer.cc
new file mode 100644
index 0000000..bbff616
--- /dev/null
+++ b/src/test/test_paramslicer.cc
@@ -0,0 +1,47 @@
+#include "utils/param.h"
+#include "gtest/gtest.h"
+
+
+using namespace singa;
+
+const int param_size[]={2400,32,25600,32, 51200,64,57600,10};
+
+class ParamSlicerTest : public ::testing::Test {
+ public:
+ ParamSlicerTest() {
+ ParamProto proto;
+ int nparams=sizeof(param_size)/sizeof(int);
+ for(int i=0;i<nparams;i++){
+ vector<int> shape{param_size[i]};
+ auto param=std::make_shared<Param>();
+ param->Setup(proto, shape);
+ param->set_id(i);
+ params.push_back(param);
+ }
+ }
+ protected:
+ vector<shared_ptr<Param>> params;
+};
+
+// all params are stored in one box, no need to split
+TEST_F(ParamSlicerTest, OneBox){
+ int nparams=sizeof(param_size)/sizeof(int);
+ ParamSlicer slicer;
+ int num=1;
+ auto slices=slicer.Slice(num, params);
+ ASSERT_EQ(slices.size(),nparams);
+ ASSERT_EQ(slicer.Get(1).size(),1);
+ ASSERT_EQ(slicer.Get(2).size(),1);
+ ASSERT_EQ(slicer.Get(nparams-1).back(), slices.size()-1);
+}
+
+// there are multiple boxes
+TEST_F(ParamSlicerTest, MultipleBox){
+ int nparams=sizeof(param_size)/sizeof(int);
+ ParamSlicer slicer;
+ int num=4;
+ auto slices=slicer.Slice(num, params);
+ ASSERT_EQ(slicer.Get(1).size(),1);
+ ASSERT_EQ(slicer.Get(3).size(),1);
+ ASSERT_EQ(slicer.Get(nparams-1).back(), slices.size()-1);
+}
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f4370118/src/trainer/server.cc
----------------------------------------------------------------------
diff --git a/src/trainer/server.cc b/src/trainer/server.cc
index 04f6040..5185c51 100644
--- a/src/trainer/server.cc
+++ b/src/trainer/server.cc
@@ -1,13 +1,14 @@
#include <list>
#include <tuple>
#include <queue>
+#include "mshadow/tensor.h"
#include "trainer/server.h"
#include "utils/param.h"
#include "utils/singleton.h"
#include "utils/factory.h"
#include "utils/cluster.h"
-
+using namespace mshadow;
namespace singa {
Server::Server(int thread_id,int group_id, int server_id):
thread_id_(thread_id),group_id_(group_id), server_id_(server_id){}
@@ -23,21 +24,23 @@ void Server::Setup(const UpdaterProto& proto,
}
void Server::Run(){
+ LOG(INFO)<<"Server (group_id= "<<group_id_<<", id="<<server_id_<<") starts";
dealer_=std::make_shared<Dealer>(2*thread_id_);
dealer_->Connect(kInprocRouterEndpoint);
-
+ auto cluster=Cluster::Get();
Msg* ping=new Msg();
ping->set_src(group_id_, server_id_, kServer);
ping->set_dst(-1,-1,kStub);
ping->add_frame("PING", 4);
ping->set_type(kConnect);
dealer_->Send(&ping);
+ int syncEntry=0;
//start recv loop and process requests
while (true){
Msg* msg=dealer_->Receive();
if (msg==nullptr)
break;
- Msg* response=nullptr;
+ Msg* response=nullptr, *sync=nullptr;
int type=msg->type();
if (type== kStop){
msg->set_src(group_id_, server_id_, kServer);
@@ -48,26 +51,47 @@ void Server::Run(){
// TODO remove receiving pong msg
string pong((char*)msg->frame_data(), msg->frame_size());
CHECK_STREQ("PONG", pong.c_str());
- delete msg;
+ DeleteMsg(&msg);
}else if(type==kPut){
- int pid=msg->trgt_second();
- shared_ptr<Param> param=nullptr;
- if(shard_->find(pid)!=shard_->end()){
- LOG(ERROR)<<"Param ("<<pid<<") is put more than once";
- param=shard_->at(pid);
- }else{
- param=shared_ptr<Param>(Singleton<Factory<Param>>::Instance()
- ->Create("Param"));
- param->set_id(pid);
- (*shard_)[pid]=param;
- }
- HandlePut(param, &msg);
+ response = HandlePut(&msg);
}else{
int pid=msg->trgt_second();
if(shard_->find(pid)==shard_->end()){
// delay the processing by re-queuing the msg.
response=msg;
DLOG(ERROR)<<"Requeue msg";
+ }else if(type==kSyncReminder){
+ DeleteMsg(&msg);
+ unsigned nchecks=0, nparams=shard_->size();
+ while(nchecks<nparams
+ &&group_locator_->at(shard_->at(syncEntry))!=group_id_){
+ syncEntry=(syncEntry+1)%nparams;
+ nchecks++;
+ }
+ if(nchecks==nparams) continue;
+ auto param=shard_->at(syncEntry);
+ if(param->local_version()!=param->version()){
+ sync=param->GenSyncMsg(true);
+ for(int i=0;i<cluster->nserver_groups();i++){
+ if(i!=group_id_) {
+ Msg* tmp=sync;
+ if(i<cluster->nserver_groups()-1)
+ tmp= new Msg(sync);
+ tmp->set_dst(i, server_locator_->at(param), kServer);
+ tmp->set_src(group_id_, server_id_, kServer);
+ dealer_->Send(&tmp);
+ param->set_version(param->local_version());
+ //DLOG(ERROR)<<"sync";
+ }
+ }
+ }
+ }else {
+ int pid=msg->target_first();
+ if(shard_->find(pid)==shard_->end()){
+ // delay the processing by re-queuing the msg.
+ response=msg;
+ LOG(ERROR)<<"Requeue";
+>>>>>>> SINGA-8 Implement distributed Hogwild
} else{
auto param=shard_->at(pid);
switch (type){
@@ -80,20 +104,42 @@ void Server::Run(){
case kSyncRequest:
response = HandleSyncRequest(param, &msg);
break;
- }
- if (response!=nullptr){
- dealer_->Send(&response);
+ default:
+ LOG(ERROR)<<"Unknown message type "<<type;
+ break;
}
}
}
+ if (response!=nullptr)
+ dealer_->Send(&response);
}
+ LOG(INFO)<<"Server (group_id= "<<group_id_<<", id="<<server_id_<<") stops";
}
-void Server::HandlePut(shared_ptr<Param> param, Msg **msg){
+Msg* Server::HandlePut(Msg **msg){
int version=(*msg)->trgt_third();
- param->HandlePutMsg(msg);
+ int pid=(*msg)->target_first();
+ shared_ptr<Param> param=nullptr;
+ if(shard_->find(pid)!=shard_->end()){
+ LOG(ERROR)<<"Param ("<<pid<<") is put more than once";
+ param=shard_->at(pid);
+ }else{
+ auto factory=Singleton<Factory<Param>>::Instance();
+ param=shared_ptr<Param>(factory ->Create("Param"));
+ param->set_id(pid);
+ (*shard_)[pid]=param;
+ }
+ auto response=param->HandlePutMsg(msg);
// must set version after HandlePutMsg which allocates the memory
param->set_version(version);
+ if(Cluster::Get()->nserver_groups()>1 &&
+ group_locator_->at(param)!=group_id_){
+ last_data_[pid]=std::make_shared<Blob<float>>();
+ last_data_[pid]->ReshapeLike(param->data());
+ last_data_[pid]->CopyFrom(param->data());
+ }
+ LOG(INFO)<<"Server put param "<<pid<<" size="<<param->size()<<" Bytes";
+ return response;
}
Msg* Server::HandleGet(shared_ptr<Param> param, Msg **msg){
@@ -124,7 +170,36 @@ Msg* Server::HandleUpdate(shared_ptr<Param> param, Msg **msg) {
}
Msg* Server::HandleSyncRequest(shared_ptr<Param> param, Msg **msg){
- return param->HandleSyncMsg(msg);
+ Msg* response=nullptr;
+ auto shape=Shape1(param->size());
+ CHECK_EQ((*msg)->frame_size(), param->size()*sizeof(float));
+ Tensor<cpu, 1> tmp(static_cast<float*>((*msg)->frame_data()), shape);
+ Tensor<cpu, 1> cur(param->mutable_cpu_data(), shape);
+ if(group_locator_->at(param)==group_id_){
+ cur+=tmp;
+ param->set_local_version(param->local_version()+1);
+ }else{
+ TensorContainer<cpu, 1> diff(shape);
+ Tensor<cpu, 1> prev(last_data_[param->id()]->mutable_cpu_data(), shape);
+ diff=cur-prev;
+ (*msg)->next_frame();
+ int bandwidth;
+ sscanf(static_cast<char*>((*msg)->frame_data()), "%d", &bandwidth);
+ if(bandwidth>0){
+ response=new Msg();
+ response->set_type(kSyncRequest);
+ response->set_target(param->id(), param->version());
+ response->add_frame(diff.dptr, param->size()*sizeof(float));
+ (*msg)->SwapAddr();
+ response->SetAddr(*msg);
+ prev=diff+tmp;
+ Copy(cur, prev);
+ }else{
+ Copy(prev, tmp);
+ cur=tmp+diff;
+ }
+ }
+ DeleteMsg(msg);
+ return response;
}
-
} /* singa */
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f4370118/src/trainer/trainer.cc
----------------------------------------------------------------------
diff --git a/src/trainer/trainer.cc b/src/trainer/trainer.cc
index 6c08a3a..bdc1416 100644
--- a/src/trainer/trainer.cc
+++ b/src/trainer/trainer.cc
@@ -2,12 +2,16 @@
#include <vector>
#include <map>
#include <queue>
+#include <chrono>
#include <glog/logging.h>
#include "proto/common.pb.h"
#include "trainer/trainer.h"
#include "mshadow/tensor.h"
using std::vector;
using std::map;
+using namespace std::chrono;
+
+typedef std::chrono::milliseconds TimeT;
namespace singa {
@@ -21,14 +25,17 @@ void Trainer::RegisterDefaultClasses(const singa::ModelProto& proto){
}
void HandleWorkerFinish(void * ctx){
+ /*
HandleContext* hctx=static_cast<HandleContext*> (ctx);
Msg* msg=new Msg();
msg->set_src(-1,-1, kRuntime);
msg->set_dst(hctx->group_id, hctx->id, kServer);
msg->set_type(kStop);
hctx->dealer->Send(&msg);
+ */
}
+<<<<<<< HEAD
const std::unordered_map<int, vector<std::pair<int, int>>> SliceParams(int num,
const vector<shared_ptr<Param>>& params){
std::unordered_map<int, vector<std::pair<int, int>>> paramid2slices;
@@ -276,20 +283,51 @@ void Trainer::Start(const ModelProto& mproto, const ClusterProto& cproto,
threads.push_back(std::thread(&Server::Run,server.get()));
for(auto worker: workers)
threads.push_back(std::thread(&Worker::Run,worker.get()));
- Run(workers.size(), servers.size());
+ Run(workers, servers, shards);
for(auto& thread: threads)
thread.join();
for(auto x: ctx)
delete x;
}
-void Trainer::Run(int nworkers, int nservers){
+inline int bandwidth(int bytes, system_clock::time_point start){
+ auto now=system_clock::now();
+ auto duration=duration_cast<TimeT> (now - start);
+ return static_cast<int>(bytes*1000.f/duration.count());
+}
+void Trainer::Run(const vector<shared_ptr<Worker>>& workers,
+ const vector<shared_ptr<Server>>& servers,
+ const std::map<int, shared_ptr<Trainer::ParamShard>>& shards){
auto cluster=Cluster::Get();
procs_id_=cluster->procs_id();
+ LOG(INFO)<<"Stub in process "<<procs_id_<<" starts";
map<int, shared_ptr<Dealer>> interprocs_dealers;
std::queue<Msg*> msg_queue;
bool stop=false;
+ auto start=std::chrono::system_clock::now();
+ float amount=0.f;
+ Poller poll;
+ poll.Add(router_.get());
+ int sync_server=0, nworkers=workers.size(), nservers=servers.size();
while(!stop){
+ Socket *sock=poll.Wait(cluster->poll_time());
+ if(poll.Terminated()){
+ LOG(ERROR)<<"Connection broken!";
+ exit(0);
+ }else if(sock==nullptr){
+ if(cluster->nserver_groups()>1&&
+ bandwidth(amount, start)<cluster->bandwidth()){
+ Msg* msg=new Msg();
+ msg->set_src(-1,-1, kStub);
+ msg->set_dst(servers[sync_server]->group_id(),
+ servers[sync_server]->server_id(), kServer);
+ msg->set_type(kSyncReminder);
+ sync_server=(sync_server+1)%servers.size();
+ router_->Send(&msg);
+ //LOG(ERROR)<<"Reminder";
+ }
+ continue;
+ }
Msg* msg=router_->Receive();
if(msg==nullptr){
LOG(ERROR)<<"Connection broken!";
@@ -360,6 +398,7 @@ void Trainer::Run(int nworkers, int nservers){
msg_queue.push(x);
break;
default:
+ LOG(ERROR)<<"Unknow message type:"<<type;
break;
}
}else{
@@ -374,12 +413,30 @@ void Trainer::Run(int nworkers, int nservers){
msg->dst_second(), msg->dst_flag());
}
if(dst_procs_id!=procs_id_){
+ // forward to other procs
+ if (interprocs_dealers.find(dst_procs_id)==interprocs_dealers.end()){
+ auto dealer=make_shared<Dealer>();
+ interprocs_dealers[dst_procs_id]=dealer;
+ dealer->Connect("tcp://"+cluster->endpoint(dst_procs_id));
+ }
+ if(bandwidth(amount, start) <=cluster->bandwidth()){
+ start=std::chrono::system_clock::now();
+ amount=0;
+ }
+ amount+=msg->size();
+ interprocs_dealers[dst_procs_id]->Send(&msg);
}else{
+ if(type==kSyncRequest){
+ char buf[32];
+ sprintf(buf, "%d", cluster->bandwidth()-bandwidth(amount, start));
+ msg->add_frame(buf, strlen(buf));
+ }
router_->Send(&msg);
}
}
}
}
+ LOG(INFO)<<"Stub in process "<<procs_id_<<" stops";
}
Msg* Trainer::HandleConnect(Msg** msg){
string ping((char*)(*msg)->frame_data(), (*msg)->frame_size());
@@ -394,7 +451,6 @@ Msg* Trainer::HandleConnect(Msg** msg){
*msg=NULL;
return reply;
}
-
const vector<Msg*> Trainer::HandleGet(shared_ptr<ParamInfo> pi, Msg** msg){
Msg* msgg=*msg;
vector<Msg*> replies;
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f4370118/src/trainer/worker.cc
----------------------------------------------------------------------
diff --git a/src/trainer/worker.cc b/src/trainer/worker.cc
index 788e77c..37acb14 100644
--- a/src/trainer/worker.cc
+++ b/src/trainer/worker.cc
@@ -47,6 +47,7 @@ void Worker::ConnectStub(shared_ptr<Dealer> dealer, EntityType type){
}
void Worker::Run(){
+ LOG(INFO)<<"Worker (group_id= "<<group_id_<<", id="<<worker_id_<<") starts";
dealer_=make_shared<Dealer>(2*thread_id_);
ConnectStub(dealer_, kWorkerParam);
for(auto layer: train_net_->layers())
@@ -61,8 +62,10 @@ void Worker::Run(){
for(auto layer: train_net_->layers()){
if(layer->partitionid()==worker_id_)
for(auto param: layer->GetParams()){
+ // only owners fill the memory of parameter values.
+ // others share the memory with owners hence do not need to put/get.
if(param->owner() == param->id()){
- if(group_id_==0)
+ if(group_id_%Cluster::Get()->nworker_groups_per_server_group()==0)
param->InitValues(0);
else
Get(param, modelproto_.warmup_steps());
@@ -70,7 +73,7 @@ void Worker::Run(){
}
}
Metric perf;
- if(group_id_==0){
+ if(group_id_%Cluster::Get()->nworker_groups_per_server_group()==0){
for(step_=0;step_<modelproto_.warmup_steps();step_++)
RunOneBatch(step_, &perf);
for(auto layer: train_net_->layers()){
@@ -86,6 +89,7 @@ void Worker::Run(){
}
Stop();
+ LOG(INFO)<<"Worker (group_id= "<<group_id_<<", id="<<worker_id_<<") stops";
}
void Worker::Stop(){
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f4370118/src/utils/param.cc
----------------------------------------------------------------------
diff --git a/src/utils/param.cc b/src/utils/param.cc
index deff6f4..4ad17ce 100644
--- a/src/utils/param.cc
+++ b/src/utils/param.cc
@@ -133,8 +133,12 @@ Msg* Param::GenUpdateMsg(bool copy, int idx){
return msg;
}
-Msg* Param::GenSyncMsg(){
- return nullptr;
+Msg* Param::GenSyncMsg(bool copy, int v){
+ Msg* msg=new Msg();
+ msg->set_type(kSyncRequest);
+ msg->set_target(id(), local_version());
+ msg->add_frame(mutable_cpu_data(), size()*sizeof(float));
+ return msg;
}
Msg* Param::HandlePutMsg(Msg** msg){
@@ -146,9 +150,9 @@ Msg* Param::HandlePutMsg(Msg** msg){
proto_.set_learning_rate_multiplier(lr);
proto_.set_weight_decay_multiplier(wc);
vector<int> shape{size};
- grad_.Reshape(shape);
- history_.Reshape(shape);
- data_=std::make_shared<Blob<float>>(shape);
+ Setup(shape);
+ set_local_version((*msg)->target_second());
+ set_version((*msg)->target_second());
if(ptr==nullptr){
CHECK((*msg)->next_frame());
CHECK_EQ(size* sizeof(float), (*msg)->frame_size());
@@ -201,6 +205,7 @@ Msg* Param::HandleSyncMsg(Msg** msg){
return nullptr;
}
+<<<<<<< HEAD
int Param::ParseSyncResponseMsg(Msg** msg, int slice_idx){
DeleteMsg(msg);
return 1;