You are viewing a plain-text version of this content; the link to its canonical version was not preserved in this plain-text copy.
Posted to commits@singa.apache.org by wa...@apache.org on 2015/05/27 16:39:08 UTC

[07/22] incubator-singa git commit: Detect worker status through ClusterRuntime, which calls ZooKeeper. Stop servers when all workers finish; stop the stub when all workers and servers finish.

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/c318a98c/src/trainer/server.cc
----------------------------------------------------------------------
diff --git a/src/trainer/server.cc b/src/trainer/server.cc
index cd2bc02..5d530da 100644
--- a/src/trainer/server.cc
+++ b/src/trainer/server.cc
@@ -38,7 +38,13 @@ void Server::Run(){
       break;
     Msg* response=nullptr;
     int type=msg->type();
-    if (type==kConnect){
+    if (type==kStop){
+      // Forward the stop notice to the stub, which counts finished servers.
+      msg->set_src(group_id_, server_id_, kServer);
+      msg->set_dst(-1,-1, kStub);
+      dealer_->Send(&msg);
+      break;
+    }else if (type==kConnect){
       // TODO remove receiving pong msg
       string pong((char*)msg->frame_data(), msg->frame_size());
       CHECK_STREQ("PONG", pong.c_str());

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/c318a98c/src/trainer/trainer.cc
----------------------------------------------------------------------
diff --git a/src/trainer/trainer.cc b/src/trainer/trainer.cc
index bc6867d..37e9883 100644
--- a/src/trainer/trainer.cc
+++ b/src/trainer/trainer.cc
@@ -36,6 +36,27 @@ void Trainer::RegisterDefaultClasses(const singa::ModelProto& proto){
       "Updater", CreateInstance(singa::SGDUpdater, singa::Updater));
 }
 
+// Context handed to HandleWorkerFinish through sWatchSGroup. The runtime is
+// given a raw pointer to it, so every instance needs a stable address.
+struct HandleContext{
+  shared_ptr<Dealer> dealer;
+  int group_id, id;
+};
+
+// Watch callback registered via ClusterRuntime::sWatchSGroup; it sends kStop
+// to server (group_id, id). Presumably fired once the workers of that server
+// group have finished -- TODO confirm against ZKClusterRT::sWatchSGroup.
+// NOTE(review): may run on a ZooKeeper client thread rather than the thread
+// that created the dealer; confirm that Dealer tolerates this.
+void HandleWorkerFinish(void * ctx){
+  HandleContext* hctx=static_cast<HandleContext*> (ctx);
+  Msg* msg=new Msg();
+  msg->set_src(-1,-1, kRuntime);
+  msg->set_dst(hctx->group_id, hctx->id, kServer);
+  msg->set_type(kStop);
+  hctx->dealer->Send(&msg);
+}
+
 void Trainer::Start(const ModelProto& mproto, const ClusterProto& cproto,
     int procs_id){
   procs_id_=procs_id;
@@ -44,6 +65,7 @@ void Trainer::Start(const ModelProto& mproto, const ClusterProto& cproto,
   auto cluster=Cluster::Get(cproto, procs_id);
   // create servers
   vector<shared_ptr<Server>> servers;
+  vector<HandleContext> ctx;
   int nthreads=1; // the first socket is the router
   if(cluster->has_server()){
     int pid=cluster->procs_id();
@@ -54,10 +76,25 @@ void Trainer::Start(const ModelProto& mproto, const ClusterProto& cproto,
     int end=start+cluster->nservers_per_group();
     // the ParamShard for servers consists of a dictionary of Param objects
     auto shard=make_shared<Server::ParamShard>();
-    for(int sid=start;sid<end;sid++){
-      auto server=make_shared<Server>(nthreads++, gid, sid);
-      server->Setup(mproto.updater(), shard);
-      servers.push_back(server);
+    if(start<end){
+      auto dealer=make_shared<Dealer>();
+      dealer->Connect(kInprocRouterEndpoint);
+      // Each watch gets a raw pointer into ctx (presumably kept by the
+      // runtime until the callback fires), so reserve up front: a
+      // reallocating push_back would leave earlier pointers dangling.
+      ctx.reserve(end-start);
+      for(int sid=start;sid<end;sid++){
+        auto server=make_shared<Server>(nthreads++, gid, sid);
+        server->Setup(mproto.updater(), shard);
+        servers.push_back(server);
+        HandleContext hc;
+        hc.dealer=dealer;
+        hc.group_id=gid;
+        hc.id=sid;
+        ctx.push_back(hc);
+        cluster->runtime()->sWatchSGroup(gid, sid, HandleWorkerFinish,
+            &ctx.back());
+      }
     }
   }
 
@@ -152,12 +189,15 @@ void Trainer::Start(const ModelProto& mproto, const ClusterProto& cproto,
     threads.push_back(std::thread(&Server::Run,server.get()));
   for(auto worker: workers)
     threads.push_back(std::thread(&Worker::Run,worker.get()));
-  Run(shards);
+  Run(workers.size(), servers.size(), shards);
   for(auto& thread: threads)
     thread.join();
 }
 
-void Trainer::Run(const std::map<int, shared_ptr<Trainer::ParamShard>>& shards){
+// Stub loop: dispatches messages received by the router until every server
+// and worker started by this process has reported kStop.
+void Trainer::Run(int nworkers, int nservers,
+    const std::map<int, shared_ptr<Trainer::ParamShard>>& shards){
   auto cluster=Cluster::Get();
   procs_id_=cluster->procs_id();
   auto router=make_shared<Router>();
@@ -166,7 +206,8 @@ void Trainer::Run(const std::map<int, shared_ptr<Trainer::ParamShard>>& shards){
     router->Bind(cluster->endpoint());
 
   map<int, shared_ptr<Dealer>> interprocs_dealers;
-  while(true){
+  bool stop=false;
+  while(!stop){
     Msg* msg=router->Receive();
     if(msg==nullptr){
       LOG(ERROR)<<"Connection broken!";
@@ -179,6 +220,19 @@ void Trainer::Run(const std::map<int, shared_ptr<Trainer::ParamShard>>& shards){
       if(dst_flag == kStub&&(dst_procs==procs_id_||dst_procs==-1)){
         if(type==kConnect){
           msg =HandleConnect(&msg);
+        }else if(type==kStop){
+          LOG(INFO)<<"Stub recv Stop";
+          // Count down the local servers/workers that have finished.
+          if(msg->src_flag()==kServer)
+            nservers--;
+          else if (msg->src_flag()==kWorkerParam)
+            nworkers--;
+          delete msg;
+          msg=nullptr;
+          if(nworkers==0&&nservers==0){
+            stop=true;
+            break;
+          }
         }else{
           int group_id=msg->src_first();
           int paramid=msg->target_first();
@@ -223,6 +277,7 @@ void Trainer::Run(const std::map<int, shared_ptr<Trainer::ParamShard>>& shards){
       }
     }
   }
+  LOG(INFO)<<"Stub finishes";
 }
 Msg* Trainer::HandleConnect(Msg** msg){
   string ping((char*)(*msg)->frame_data(), (*msg)->frame_size());

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/c318a98c/src/trainer/worker.cc
----------------------------------------------------------------------
diff --git a/src/trainer/worker.cc b/src/trainer/worker.cc
index 9ef47a6..f0b54ea 100644
--- a/src/trainer/worker.cc
+++ b/src/trainer/worker.cc
@@ -12,12 +12,17 @@ using std::thread;
 namespace singa {
 Worker::Worker(int thread_id, int group_id, int worker_id):
   thread_id_(thread_id), group_id_(group_id), worker_id_(worker_id){
-  }
+}
 
 void Worker::Setup(const ModelProto& model,
     shared_ptr<NeuralNet> train_net){
   train_net_=train_net;
   modelproto_=model;
+  // Join this worker's server group in the cluster runtime; Worker::Stop
+  // leaves it again once Run's training loop ends.
+  auto cluster=Cluster::Get();
+  int sgid=group_id_/cluster->nworker_groups_per_server_group();
+  cluster->runtime()->wJoinSGroup(group_id_, worker_id_, sgid);
 }
 
 void Worker::Run(){
@@ -87,6 +92,21 @@ void Worker::Run(){
     RunOneBatch(step_, &perf);
     step_++;
   }
+
+  Stop();
+}
+
+// Leaves the server group in the cluster runtime, then sends kStop to the
+// stub so that it can account for this worker having finished.
+void Worker::Stop(){
+  auto cluster=Cluster::Get();
+  int sgid=group_id_/cluster->nworker_groups_per_server_group();
+  cluster->runtime()->wLeaveSGroup(group_id_, worker_id_, sgid);
+  Msg* msg=new Msg();
+  msg->set_src(group_id_, worker_id_, kWorkerParam);
+  msg->set_dst(-1,-1, kStub);
+  msg->set_type(kStop);
+  param_dealer_->Send(&msg);
 }
 int Worker::Put(shared_ptr<Param> param, int step){
   Msg* msg=new Msg();

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/c318a98c/src/utils/cluster.cc
----------------------------------------------------------------------
diff --git a/src/utils/cluster.cc b/src/utils/cluster.cc
index 66c4ac8..b00a3cd 100644
--- a/src/utils/cluster.cc
+++ b/src/utils/cluster.cc
@@ -30,6 +30,11 @@ Cluster::Cluster(const ClusterProto &cluster, int procs_id) {
     }
     CHECK_EQ(endpoints_.size(), nprocs);
   }
+  // ZooKeeper-backed runtime through which workers join/leave server groups.
+  auto rt=std::make_shared<ZKClusterRT>(cluster_.zookeeper_host());
+  CHECK(rt->Init()) << "Failed to initialize the ZooKeeper cluster runtime at "
+      << cluster_.zookeeper_host();
+  cluster_rt_=rt;
 }
 
 void Cluster::SetupFolders(const ClusterProto &cluster){
 
 void Cluster::SetupFolders(const ClusterProto &cluster){

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/c318a98c/src/utils/cluster_rt.cc
----------------------------------------------------------------------
diff --git a/src/utils/cluster_rt.cc b/src/utils/cluster_rt.cc
index d88ab46..433623d 100644
--- a/src/utils/cluster_rt.cc
+++ b/src/utils/cluster_rt.cc
@@ -39,7 +39,7 @@ bool ZKClusterRT::Init(){
 }
 
 bool ZKClusterRT::sWatchSGroup(int gid, int sid, rt_callback fn, void *ctx){
- 
+
   string path = getSGroupPath(gid);
   struct Stat stat;
 
@@ -49,7 +49,7 @@ bool ZKClusterRT::sWatchSGroup(int gid, int sid, rt_callback fn, void *ctx){
   if (ret == ZOK) ;
   //need to create zk node first
   else if (ret == ZNONODE){
-    char buf[MAX_BUF_LEN]; 
+    char buf[MAX_BUF_LEN];
     ret = zoo_create(zkhandle_, path.c_str(), NULL, -1, &ZOO_OPEN_ACL_UNSAFE, 0, buf, MAX_BUF_LEN);
     if (ret == ZOK){
       LOG(INFO) << "zookeeper node " << buf << " created";
@@ -77,13 +77,13 @@ bool ZKClusterRT::sWatchSGroup(int gid, int sid, rt_callback fn, void *ctx){
 }
 
 bool ZKClusterRT::wJoinSGroup(int gid, int wid, int s_group){
-  
+
   string path = getSGroupPath(s_group) + getWorkerPath(gid, wid);
-  char buf[MAX_BUF_LEN]; 
-  
+  char buf[MAX_BUF_LEN];
+
   int ret = zoo_create(zkhandle_, path.c_str(), NULL, -1, &ZOO_OPEN_ACL_UNSAFE, ZOO_EPHEMERAL, buf, MAX_BUF_LEN);
   if (ret == ZOK){
     LOG(INFO) << "zookeeper node " << buf << " created";
     return true;
   }
   else if (ret == ZNODEEXISTS){
@@ -94,18 +94,18 @@ bool ZKClusterRT::wJoinSGroup(int gid, int wid, int s_group){
     LOG(ERROR) << "zookeeper parent node " << getSGroupPath(s_group) << " not exist";
     return false;
   }
-  
+
   LOG(ERROR) << "Unhandled ZK error code: " << ret << " (zoo_create)";
   return false;
 }
 
 bool ZKClusterRT::wLeaveSGroup(int gid, int wid, int s_group){
-  
+
   string path = getSGroupPath(s_group) + getWorkerPath(gid, wid);
-  
+
   int ret = zoo_delete(zkhandle_, path.c_str(), -1);
   if (ret == ZOK){
-    LOG(INFO) << "zookeeper node " << path << " deleted";
+    LOG(ERROR) << "zookeeper node " << path << " deleted";
     return true;
   }
   else if (ret == ZNONODE){