You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@singa.apache.org by wa...@apache.org on 2015/05/27 16:39:08 UTC
[07/22] incubator-singa git commit: detect worker status through
ClusterRuntime which calls zookeeper. stop servers when all workers finish;
stop stub when all workers and servers finish
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/c318a98c/src/trainer/server.cc
----------------------------------------------------------------------
diff --git a/src/trainer/server.cc b/src/trainer/server.cc
index cd2bc02..5d530da 100644
--- a/src/trainer/server.cc
+++ b/src/trainer/server.cc
@@ -21,6 +21,7 @@ void Server::Setup(const UpdaterProto& proto,
shard_=shard;
}
+
void Server::Run(){
dealer_=std::make_shared<Dealer>(2*thread_id_);
dealer_->Connect(kInprocRouterEndpoint);
@@ -38,7 +39,12 @@ void Server::Run(){
break;
Msg* response=nullptr;
int type=msg->type();
- if (type==kConnect){
+ if (type== kStop){
+ msg->set_src(group_id_, server_id_, kServer);
+ msg->set_dst(-1,-1, kStub);
+ dealer_->Send(&msg);
+ break;
+ }else if (type==kConnect){
// TODO remove receiving pong msg
string pong((char*)msg->frame_data(), msg->frame_size());
CHECK_STREQ("PONG", pong.c_str());
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/c318a98c/src/trainer/trainer.cc
----------------------------------------------------------------------
diff --git a/src/trainer/trainer.cc b/src/trainer/trainer.cc
index bc6867d..37e9883 100644
--- a/src/trainer/trainer.cc
+++ b/src/trainer/trainer.cc
@@ -36,6 +36,20 @@ void Trainer::RegisterDefaultClasses(const singa::ModelProto& proto){
"Updater", CreateInstance(singa::SGDUpdater, singa::Updater));
}
+typedef struct HandleContext_{  // per-server context handed to HandleWorkerFinish through its void* argument
+ shared_ptr<Dealer> dealer;  // socket the callback uses to send the kStop message
+ int group_id, id;  // used as the (group, id) destination of the kStop message
+} HandleContext;
+
+void HandleWorkerFinish(void * ctx){  // callback passed to sWatchSGroup; ctx must point to a HandleContext
+ HandleContext* hctx=static_cast<HandleContext*> (ctx);  // NOTE(review): the caller passes &ctx.back() of a vector that keeps growing; a later push_back may reallocate and leave this pointer dangling
+ Msg* msg=new Msg();  // NOTE(review): presumably Send() takes ownership of msg - confirm
+ msg->set_src(-1,-1, kRuntime);
+ msg->set_dst(hctx->group_id, hctx->id, kServer);  // addressed to the server this context was created for
+ msg->set_type(kStop);  // on kStop, Server::Run re-addresses the message to the stub, sends it and breaks
+ hctx->dealer->Send(&msg);
+}
+
void Trainer::Start(const ModelProto& mproto, const ClusterProto& cproto,
int procs_id){
procs_id_=procs_id;
@@ -44,6 +58,7 @@ void Trainer::Start(const ModelProto& mproto, const ClusterProto& cproto,
auto cluster=Cluster::Get(cproto, procs_id);
// create servers
vector<shared_ptr<Server>> servers;
+ vector<HandleContext> ctx;
int nthreads=1; // the first socket is the router
if(cluster->has_server()){
int pid=cluster->procs_id();
@@ -54,10 +69,21 @@ void Trainer::Start(const ModelProto& mproto, const ClusterProto& cproto,
int end=start+cluster->nservers_per_group();
// the ParamShard for servers consists of a dictionary of Param objects
auto shard=make_shared<Server::ParamShard>();
- for(int sid=start;sid<end;sid++){
- auto server=make_shared<Server>(nthreads++, gid, sid);
- server->Setup(mproto.updater(), shard);
- servers.push_back(server);
+ if(start<end){
+ auto dealer=make_shared<Dealer>();
+ dealer->Connect(kInprocRouterEndpoint);
+ for(int sid=start;sid<end;sid++){
+ auto server=make_shared<Server>(nthreads++, gid, sid);
+ server->Setup(mproto.updater(), shard);
+ servers.push_back(server);
+ HandleContext hc;
+ hc.dealer=dealer;
+ hc.group_id=gid;
+ hc.id=sid;
+ ctx.push_back(hc);
+ cluster->runtime()->sWatchSGroup(gid, sid, HandleWorkerFinish,
+ &ctx.back());
+ }
}
}
@@ -152,12 +178,13 @@ void Trainer::Start(const ModelProto& mproto, const ClusterProto& cproto,
threads.push_back(std::thread(&Server::Run,server.get()));
for(auto worker: workers)
threads.push_back(std::thread(&Worker::Run,worker.get()));
- Run(shards);
+ Run(servers.size(), workers.size(), shards);
for(auto& thread: threads)
thread.join();
}
-void Trainer::Run(const std::map<int, shared_ptr<Trainer::ParamShard>>& shards){
+void Trainer::Run(int nworkers, int nservers,
+ const std::map<int, shared_ptr<Trainer::ParamShard>>& shards){
auto cluster=Cluster::Get();
procs_id_=cluster->procs_id();
auto router=make_shared<Router>();
@@ -166,7 +193,8 @@ void Trainer::Run(const std::map<int, shared_ptr<Trainer::ParamShard>>& shards){
router->Bind(cluster->endpoint());
map<int, shared_ptr<Dealer>> interprocs_dealers;
- while(true){
+ bool stop=false;
+ while(!stop){
Msg* msg=router->Receive();
if(msg==nullptr){
LOG(ERROR)<<"Connection broken!";
@@ -179,6 +207,18 @@ void Trainer::Run(const std::map<int, shared_ptr<Trainer::ParamShard>>& shards){
if(dst_flag == kStub&&(dst_procs==procs_id_||dst_procs==-1)){
if(type==kConnect){
msg =HandleConnect(&msg);
+ }else if(type==kStop){
+ if(msg->src_flag()==kServer)
+ nworkers--;
+ else if (msg->src_flag()==kWorkerParam)
+ nservers--;
+ delete msg;
+ msg=nullptr;
+ if(nworkers==0&&nservers==0){
+ stop=true;
+ break;
+ }
+ LOG(ERROR)<<"Stub recv Stop";
}else{
int group_id=msg->src_first();
int paramid=msg->target_first();
@@ -223,6 +263,7 @@ void Trainer::Run(const std::map<int, shared_ptr<Trainer::ParamShard>>& shards){
}
}
}
+ LOG(ERROR)<<"Stub finishes";
}
Msg* Trainer::HandleConnect(Msg** msg){
string ping((char*)(*msg)->frame_data(), (*msg)->frame_size());
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/c318a98c/src/trainer/worker.cc
----------------------------------------------------------------------
diff --git a/src/trainer/worker.cc b/src/trainer/worker.cc
index 9ef47a6..f0b54ea 100644
--- a/src/trainer/worker.cc
+++ b/src/trainer/worker.cc
@@ -12,16 +12,20 @@ using std::thread;
namespace singa {
Worker::Worker(int thread_id, int group_id, int worker_id):  // only stores the three ids in the matching members
thread_id_(thread_id), group_id_(group_id), worker_id_(worker_id){
- }
+
+}
void Worker::Setup(const ModelProto& model,  // stores the model config and training net, then joins a server group via the cluster runtime
shared_ptr<NeuralNet> train_net){
train_net_=train_net;
modelproto_=model;
+ auto cluster=Cluster::Get();
+ int sgid=group_id_/cluster->nworker_groups_per_server_group();  // integer division; the result is passed as the server-group argument below
+ cluster->runtime()->wJoinSGroup(group_id_, worker_id_, sgid);  // NOTE(review): the bool result (see ZKClusterRT::wJoinSGroup) is not checked
}
void Worker::Run(){
- param_dealer_=make_shared<Dealer>(2*thread_id_);
+ param_dealer_=make_shared<Dealer>(2*thread_id_);
param_dealer_->Connect(kInprocRouterEndpoint);
param_poller_.Add(param_dealer_.get());
layer_dealer_=make_shared<Dealer>(2*thread_id_+1);
@@ -87,6 +91,19 @@ void Worker::Run(){
RunOneBatch(step_, &perf);
step_++;
}
+
+ Stop();
+}
+
+void Worker::Stop(){  // called at the end of Worker::Run: leaves the server group, then sends kStop to the stub
+ auto cluster=Cluster::Get();
+ int sgid=group_id_/cluster->nworker_groups_per_server_group();  // same computation as in Worker::Setup
+ cluster->runtime()->wLeaveSGroup(group_id_, worker_id_, sgid);  // NOTE(review): the bool result (see ZKClusterRT::wLeaveSGroup) is not checked
+ Msg* msg=new Msg();  // NOTE(review): presumably Send() takes ownership of msg - confirm
+ msg->set_src(group_id_, worker_id_, kWorkerParam);
+ msg->set_dst(-1,-1, kStub);
+ msg->set_type(kStop);  // Trainer::Run handles kStop messages addressed to the stub
+ param_dealer_->Send(&msg);
}
int Worker::Put(shared_ptr<Param> param, int step){
Msg* msg=new Msg();
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/c318a98c/src/utils/cluster.cc
----------------------------------------------------------------------
diff --git a/src/utils/cluster.cc b/src/utils/cluster.cc
index 66c4ac8..b00a3cd 100644
--- a/src/utils/cluster.cc
+++ b/src/utils/cluster.cc
@@ -30,6 +30,9 @@ Cluster::Cluster(const ClusterProto &cluster, int procs_id) {
}
CHECK_EQ(endpoints_.size(), nprocs);
}
+ auto rt=new ZKClusterRT(cluster_.zookeeper_host());
+ rt->Init();
+ cluster_rt_=shared_ptr<ClusterRuntime>(static_cast<ClusterRuntime*>(rt));
}
void Cluster::SetupFolders(const ClusterProto &cluster){
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/c318a98c/src/utils/cluster_rt.cc
----------------------------------------------------------------------
diff --git a/src/utils/cluster_rt.cc b/src/utils/cluster_rt.cc
index d88ab46..433623d 100644
--- a/src/utils/cluster_rt.cc
+++ b/src/utils/cluster_rt.cc
@@ -39,7 +39,7 @@ bool ZKClusterRT::Init(){
}
bool ZKClusterRT::sWatchSGroup(int gid, int sid, rt_callback fn, void *ctx){
-
+
string path = getSGroupPath(gid);
struct Stat stat;
@@ -49,7 +49,7 @@ bool ZKClusterRT::sWatchSGroup(int gid, int sid, rt_callback fn, void *ctx){
if (ret == ZOK) ;
//need to create zk node first
else if (ret == ZNONODE){
- char buf[MAX_BUF_LEN];
+ char buf[MAX_BUF_LEN];
ret = zoo_create(zkhandle_, path.c_str(), NULL, -1, &ZOO_OPEN_ACL_UNSAFE, 0, buf, MAX_BUF_LEN);
if (ret == ZOK){
LOG(INFO) << "zookeeper node " << buf << " created";
@@ -77,13 +77,13 @@ bool ZKClusterRT::sWatchSGroup(int gid, int sid, rt_callback fn, void *ctx){
}
bool ZKClusterRT::wJoinSGroup(int gid, int wid, int s_group){
-
+
string path = getSGroupPath(s_group) + getWorkerPath(gid, wid);
- char buf[MAX_BUF_LEN];
-
+ char buf[MAX_BUF_LEN];
+
int ret = zoo_create(zkhandle_, path.c_str(), NULL, -1, &ZOO_OPEN_ACL_UNSAFE, ZOO_EPHEMERAL, buf, MAX_BUF_LEN);
if (ret == ZOK){
- LOG(INFO) << "zookeeper node " << buf << " created";
+ LOG(ERROR) << "zookeeper node " << buf << " created";
return true;
}
else if (ret == ZNODEEXISTS){
@@ -94,18 +94,18 @@ bool ZKClusterRT::wJoinSGroup(int gid, int wid, int s_group){
LOG(ERROR) << "zookeeper parent node " << getSGroupPath(s_group) << " not exist";
return false;
}
-
+
LOG(ERROR) << "Unhandled ZK error code: " << ret << " (zoo_create)";
return false;
}
bool ZKClusterRT::wLeaveSGroup(int gid, int wid, int s_group){
-
+
string path = getSGroupPath(s_group) + getWorkerPath(gid, wid);
-
+
int ret = zoo_delete(zkhandle_, path.c_str(), -1);
if (ret == ZOK){
- LOG(INFO) << "zookeeper node " << path << " deleted";
+ LOG(ERROR) << "zookeeper node " << path << " deleted";
return true;
}
else if (ret == ZNONODE){