Posted to commits@singa.apache.org by wa...@apache.org on 2015/09/23 17:07:30 UTC
[4/5] incubator-singa git commit: SINGA-21 Code review 5
SINGA-21 Code review 5
review trainer.cc/.h, driver.cc/.h, singa.h, main.cc
- rewrite the include list in driver.h
- move template implementations from driver.h to driver.cc
- format code
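
Moving the template implementations out of the header works here only because every instantiation still happens inside driver.cc itself: Driver::Init calls each Register* helper with concrete types, which triggers implicit instantiation in that translation unit. A minimal single-file sketch of the rule (made-up names, not SINGA code):

    #include <iostream>

    // "header" side: the template member is only declared.
    class Registry {
     public:
      template <typename T>
      int Register(int type);
    };

    // ".cc" side: the definition is out of line. Calls from other files
    // would fail to link unless the definition were visible to them or
    // explicitly instantiated; calls in this same file are fine.
    template <typename T>
    int Registry::Register(int type) {
      std::cout << "registered type " << type << "\n";
      return 1;
    }

    int main() {
      Registry r;
      r.Register<int>(0);     // implicit instantiation happens here
      r.Register<double>(1);
      return 0;
    }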
Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/366e6a82
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/366e6a82
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/366e6a82
Branch: refs/heads/master
Commit: 366e6a82684aff9c0b31e904e3c45dcca2163490
Parents: f50d293
Author: wang sheng <wa...@gmail.com>
Authored: Wed Sep 23 15:20:20 2015 +0800
Committer: wang sheng <wa...@gmail.com>
Committed: Wed Sep 23 15:28:43 2015 +0800
----------------------------------------------------------------------
include/driver.h | 45 +-------
include/trainer/trainer.h | 80 ++++++-------
src/driver.cc | 52 ++++++++-
src/main.cc | 10 +-
src/neuralnet/neuralnet.cc | 4 +-
src/trainer/trainer.cc | 250 +++++++++++++++++++---------------------
6 files changed, 211 insertions(+), 230 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/366e6a82/include/driver.h
----------------------------------------------------------------------
diff --git a/include/driver.h b/include/driver.h
index 7d15c98..563be77 100644
--- a/include/driver.h
+++ b/include/driver.h
@@ -22,7 +22,8 @@
#ifndef SINGA_DRIVER_H_
#define SINGA_DRIVER_H_
-#include "singa.h"
+#include "proto/job.pb.h"
+#include "proto/singa.pb.h"
namespace singa {
@@ -119,48 +120,6 @@ class Driver {
SingaProto singa_conf_;
};
-template<typename Subclass, typename Type>
-int Driver::RegisterLayer(const Type& type) {
- auto factory = Singleton<Factory<singa::Layer>>::Instance();
- factory->Register(type, CreateInstance(Subclass, Layer));
- return 1;
-}
-
-template<typename Subclass, typename Type>
-int Driver::RegisterParam(const Type& type) {
- auto factory = Singleton<Factory<singa::Param>>::Instance();
- factory->Register(type, CreateInstance(Subclass, Param));
- return 1;
-}
-
-template<typename Subclass, typename Type>
-int Driver::RegisterParamGenerator(const Type& type) {
- auto factory = Singleton<Factory<singa::ParamGenerator>>::Instance();
- factory->Register(type, CreateInstance(Subclass, ParamGenerator));
- return 1;
-}
-
-template<typename Subclass, typename Type>
-int Driver::RegisterUpdater(const Type& type) {
- auto factory = Singleton<Factory<singa::Updater>>::Instance();
- factory->Register(type, CreateInstance(Subclass, Updater));
- return 1;
-}
-
-template<typename Subclass, typename Type>
-int Driver::RegisterLRGenerator(const Type& type) {
- auto factory = Singleton<Factory<singa::LRGenerator>>::Instance();
- factory->Register(type, CreateInstance(Subclass, LRGenerator));
- return 1;
-}
-
-template<typename Subclass, typename Type>
-int Driver::RegisterWorker(const Type& type) {
- auto factory = Singleton<Factory<singa::Worker>>::Instance();
- factory->Register(type, CreateInstance(Subclass, Worker));
- return 1;
-}
-
} // namespace singa
#endif // SINGA_DRIVER_H_
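
All six removed Register* bodies share one pattern: fetch a process-wide Factory through Singleton, then bind a type tag to a creator for the subclass. A simplified, self-contained sketch of that pattern follows; the names and signatures are illustrative stand-ins, not the actual utils/factory.h or utils/singleton.h API.

    #include <functional>
    #include <map>

    template <typename T>
    class Factory {
     public:
      // bind a type tag to a function that creates a T subclass
      void Register(int type, std::function<T*()> creator) {
        creators_[type] = creator;
      }
      T* Create(int type) { return creators_.at(type)(); }
     private:
      std::map<int, std::function<T*()>> creators_;
    };

    template <typename T>
    class Singleton {
     public:
      static T* Instance() {
        static T instance;  // constructed once, on first use
        return &instance;
      }
    };

    struct Layer { virtual ~Layer() = default; };
    struct MyLayer : Layer {};

    int main() {
      auto factory = Singleton<Factory<Layer>>::Instance();
      factory->Register(1, [] { return new MyLayer(); });
      delete factory->Create(1);  // Layer* that owns a MyLayer
      return 0;
    }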
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/366e6a82/include/trainer/trainer.h
----------------------------------------------------------------------
diff --git a/include/trainer/trainer.h b/include/trainer/trainer.h
index d3d332f..1c0e039 100644
--- a/include/trainer/trainer.h
+++ b/include/trainer/trainer.h
@@ -19,26 +19,24 @@
*
*************************************************************/
-#ifndef INCLUDE_TRAINER_TRAINER_H_
-#define INCLUDE_TRAINER_TRAINER_H_
+#ifndef SINGA_TRAINER_TRAINER_H_
+#define SINGA_TRAINER_TRAINER_H_
#include <queue>
-#include <vector>
#include <unordered_map>
+#include <vector>
+#include "communication/socket.h"
+#include "neuralnet/neuralnet.h"
#include "proto/job.pb.h"
#include "proto/singa.pb.h"
+#include "trainer/server.h"
+#include "trainer/worker.h"
+#include "utils/factory.h"
#include "utils/param.h"
#include "utils/singleton.h"
-#include "utils/factory.h"
-#include "neuralnet/neuralnet.h"
-#include "trainer/worker.h"
-#include "trainer/server.h"
-#include "communication/socket.h"
namespace singa {
-using std::vector;
-
/**
* Every running process has a training object which launches one or more
* worker (and server) threads.
@@ -77,7 +75,7 @@ class Trainer{
* @param jobConf
* @return server instances
*/
- vector<Server*> CreateServers(const JobProto& jobConf);
+ std::vector<Server*> CreateServers(const JobProto& jobConf);
/**
* Create workers instances.
* @param nthread total num of threads in current procs which is used to
@@ -86,8 +84,7 @@ class Trainer{
* @param jobConf
* @return worker instances
*/
- vector<Worker*> CreateWorkers(const JobProto& jobConf);
-
+ std::vector<Worker*> CreateWorkers(const JobProto& jobConf);
/**
* Setup workers and servers.
*
@@ -98,12 +95,11 @@ class Trainer{
* @param workers
* @param servers
*/
- void SetupWorkerServer(
- const JobProto& jobConf,
- const vector<Worker*>& workers,
- const vector<Server*>& servers);
-
- void Run(const vector<Worker*>& workers, const vector<Server*>& servers);
+ void SetupWorkerServer(const JobProto& jobConf,
+ const std::vector<Worker*>& workers,
+ const std::vector<Server*>& servers);
+ void Run(const std::vector<Worker*>& workers,
+ const std::vector<Server*>& servers);
/**
* Display metrics to log (standard output)
*/
@@ -118,24 +114,20 @@ class Trainer{
* Handle messages to local servers and local stub
*/
void HandleLocalMsg(std::queue<Msg*>* msg_queue, Msg** msg);
-
- /**
- * Generate a request message to Get the parameter object.
- */
- const vector<Msg*> HandleGet(ParamEntry* entry, Msg** msg);
- void HandleGetResponse(ParamEntry* entry, Msg** msg);
-
- /**
- * Generate a request message to Update the parameter object.
- */
- const vector<Msg*> HandleUpdate(ParamEntry* entry, Msg** msg);
+ /**
+ * Generate a request message to Get the parameter object.
+ */
+ const std::vector<Msg*> HandleGet(ParamEntry* entry, Msg** msg);
+ void HandleGetResponse(ParamEntry* entry, Msg** msg);
+ /**
+ * Generate a request message to Update the parameter object.
+ */
+ const std::vector<Msg*> HandleUpdate(ParamEntry* entry, Msg** msg);
void HandleUpdateResponse(ParamEntry* entry, Msg** msg);
-
/**
- * Generate a request message to Put the parameter object.
- */
- const vector<Msg*> HandlePut(ParamEntry* entry, Msg** msg);
-
+ * Generate a request message to Put the parameter object.
+ */
+ const std::vector<Msg*> HandlePut(ParamEntry* entry, Msg** msg);
/**
* Called by HandlePut, HandleUpdate and HandleGet functions
* @param type message type
@@ -145,7 +137,7 @@ class Trainer{
* @param ret generated messages
*/
void GenMsgs(int type, int version, ParamEntry* entry,
- Msg* msg, vector<Msg*> *ret);
+ Msg* msg, std::vector<Msg*> *ret);
/**
* Get a hash id for a Param object from a group.
*
@@ -157,13 +149,15 @@ class Trainer{
}
protected:
- int procs_id_;
- Router *router_;
+ int procs_id_ = -1;
+ Router *router_ = nullptr;
std::unordered_map<int, ParamEntry*> worker_shard_;
//!< map from slice to the server that updates it
- vector<int> slice2server_;
- //stub will destroy all neuralnets in the end
- vector<NeuralNet*> nets_;
+ std::vector<int> slice2server_;
+ // a buffer of created nets, will destroy them all in destructor
+ std::vector<NeuralNet*> nets_;
};
-} /* singa */
-#endif // INCLUDE_TRAINER_TRAINER_H_
+
+} // namespace singa
+
+#endif // SINGA_TRAINER_TRAINER_H_
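
Two things to note in this header. Qualifying std::vector explicitly (and dropping the blanket `using std::vector;`) keeps the using-declaration from leaking into every file that includes trainer.h. And the members now carry C++11 in-class default initializers, so procs_id_ and router_ hold defined values even before Start() assigns them. A tiny illustration of the idiom, with a made-up Stub class:

    class Router;  // a pointer member only needs a forward declaration

    class Stub {
     protected:
      int procs_id_ = -1;         // -1 = "not assigned yet"
      Router* router_ = nullptr;  // created later, e.g. in Start()
    };

    int main() {
      Stub s;  // members already hold their defaults; no ctor needed
      return 0;
    }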
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/366e6a82/src/driver.cc
----------------------------------------------------------------------
diff --git a/src/driver.cc b/src/driver.cc
index 28d21c2..42a1330 100644
--- a/src/driver.cc
+++ b/src/driver.cc
@@ -24,24 +24,27 @@
#include <cblas.h>
#include <glog/logging.h>
#include <string>
+#include "neuralnet/neuralnet.h"
+#include "neuralnet/layer.h"
+#include "trainer/trainer.h"
+#include "utils/common.h"
+#include "utils/factory.h"
+#include "utils/singleton.h"
#include "utils/tinydir.h"
namespace singa {
void Driver::Init(int argc, char **argv) {
google::InitGoogleLogging(argv[0]);
-
// unique job ID generated from singa-run.sh, passed in as "-singa_job <id>"
int arg_pos = ArgPos(argc, argv, "-singa_job");
job_id_ = (arg_pos != -1) ? atoi(argv[arg_pos+1]) : -1;
-
// global signa conf passed by singa-run.sh as "-singa_conf <path>"
arg_pos = ArgPos(argc, argv, "-singa_conf");
if (arg_pos != -1)
ReadProtoFromTextFile(argv[arg_pos+1], &singa_conf_);
else
ReadProtoFromTextFile("conf/singa.conf", &singa_conf_);
-
// job conf passed by users as "-conf <path>"
arg_pos = ArgPos(argc, argv, "-conf");
CHECK_NE(arg_pos, -1);
@@ -107,7 +110,47 @@ void Driver::Init(int argc, char **argv) {
RegisterParamGenerator<UniformSqrtFanInOutGen>(kUniformSqrtFanInOut);
}
+template<typename Subclass, typename Type>
+int Driver::RegisterLayer(const Type& type) {
+ auto factory = Singleton<Factory<singa::Layer>>::Instance();
+ factory->Register(type, CreateInstance(Subclass, Layer));
+ return 1;
+}
+
+template<typename Subclass, typename Type>
+int Driver::RegisterParam(const Type& type) {
+ auto factory = Singleton<Factory<singa::Param>>::Instance();
+ factory->Register(type, CreateInstance(Subclass, Param));
+ return 1;
+}
+
+template<typename Subclass, typename Type>
+int Driver::RegisterParamGenerator(const Type& type) {
+ auto factory = Singleton<Factory<singa::ParamGenerator>>::Instance();
+ factory->Register(type, CreateInstance(Subclass, ParamGenerator));
+ return 1;
+}
+
+template<typename Subclass, typename Type>
+int Driver::RegisterUpdater(const Type& type) {
+ auto factory = Singleton<Factory<singa::Updater>>::Instance();
+ factory->Register(type, CreateInstance(Subclass, Updater));
+ return 1;
+}
+template<typename Subclass, typename Type>
+int Driver::RegisterLRGenerator(const Type& type) {
+ auto factory = Singleton<Factory<singa::LRGenerator>>::Instance();
+ factory->Register(type, CreateInstance(Subclass, LRGenerator));
+ return 1;
+}
+
+template<typename Subclass, typename Type>
+int Driver::RegisterWorker(const Type& type) {
+ auto factory = Singleton<Factory<singa::Worker>>::Instance();
+ factory->Register(type, CreateInstance(Subclass, Worker));
+ return 1;
+}
void Driver::Submit(bool resume, const JobProto& jobConf) {
if (singa_conf_.has_log_dir())
@@ -118,9 +161,8 @@ void Driver::Submit(bool resume, const JobProto& jobConf) {
LOG(FATAL) << "workspace does not exist: " << jobConf.cluster().workspace();
if (jobConf.num_openblas_threads() != 1)
LOG(WARNING) << "openblas with "
- << jobConf.num_openblas_threads() << " threads";
+ << jobConf.num_openblas_threads() << " threads";
openblas_set_num_threads(jobConf.num_openblas_threads());
-
JobProto job;
job.CopyFrom(jobConf);
job.set_id(job_id_);
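
Init finds each flag with ArgPos and reads the value that follows it (e.g. "-singa_job <id>"). A plausible stand-alone version of such a helper, shown for context; the real ArgPos lives in the SINGA utils and its implementation may differ:

    #include <cstdlib>
    #include <cstring>
    #include <iostream>

    // linear scan for a flag; returns its index in argv, or -1 if absent
    int ArgPos(int argc, char** argv, const char* flag) {
      for (int i = 1; i < argc; ++i)
        if (std::strcmp(argv[i], flag) == 0) return i;
      return -1;
    }

    int main(int argc, char** argv) {
      // mirrors the job-id lookup in Driver::Init; assumes the flag, if
      // present, is followed by its value
      int pos = ArgPos(argc, argv, "-singa_job");
      int job_id = (pos != -1) ? std::atoi(argv[pos + 1]) : -1;
      std::cout << "job id: " << job_id << "\n";
      return 0;
    }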
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/366e6a82/src/main.cc
----------------------------------------------------------------------
diff --git a/src/main.cc b/src/main.cc
index 5e94de4..5d2ab2f 100644
--- a/src/main.cc
+++ b/src/main.cc
@@ -45,20 +45,20 @@
*/
int main(int argc, char **argv) {
- // must create driver at the beginning and call its Init method.
+ // must create driver at the beginning and call its Init method.
singa::Driver driver;
driver.Init(argc, argv);
- // if -resume in argument list, set resume to true; otherwise false
+ // if -resume in argument list, set resume to true; otherwise false
int resume_pos = singa::ArgPos(argc, argv, "-resume");
bool resume = (resume_pos != -1);
- // users can register new subclasses of layer, updater, etc.
+ // users can register new subclasses of layer, updater, etc.
- // get the job conf, and custmize it if need
+ // get the job conf, and custmize it if need
singa::JobProto jobConf = driver.job_conf();
- // submit the job
+ // submit the job
driver.Submit(resume, jobConf);
return 0;
}
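
Put together, a user program follows exactly the sequence this file documents: construct the driver, Init, check for -resume, optionally register custom subclasses, then Submit. A hypothetical example; the commented-out registration call and its arguments are invented for illustration:

    #include "singa.h"

    int main(int argc, char** argv) {
      singa::Driver driver;
      driver.Init(argc, argv);
      bool resume = (singa::ArgPos(argc, argv, "-resume") != -1);
      // hypothetical custom layer registration, e.g.:
      // driver.RegisterLayer<MyLayer, int>(kMyLayerType);
      singa::JobProto jobConf = driver.job_conf();
      driver.Submit(resume, jobConf);
      return 0;
    }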
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/366e6a82/src/neuralnet/neuralnet.cc
----------------------------------------------------------------------
diff --git a/src/neuralnet/neuralnet.cc b/src/neuralnet/neuralnet.cc
index 200824a..775a5a7 100644
--- a/src/neuralnet/neuralnet.cc
+++ b/src/neuralnet/neuralnet.cc
@@ -19,10 +19,10 @@
*
*************************************************************/
+#include "neuralnet/neuralnet.h"
+
#include <algorithm>
#include <queue>
-
-#include "neuralnet/neuralnet.h"
#include "utils/singleton.h"
namespace singa {
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/366e6a82/src/trainer/trainer.cc
----------------------------------------------------------------------
diff --git a/src/trainer/trainer.cc b/src/trainer/trainer.cc
index 8a0589e..ecfc94a 100644
--- a/src/trainer/trainer.cc
+++ b/src/trainer/trainer.cc
@@ -19,25 +19,21 @@
*
*************************************************************/
-#include <thread>
-#include <vector>
-#include <map>
-#include <chrono>
+#include "trainer/trainer.h"
+
#include <glog/logging.h>
-#include "utils/tinydir.h"
#include <unistd.h>
+#include <map>
+#include <thread>
+#include "mshadow/tensor.h"
+#include "proto/common.pb.h"
#include "utils/cluster.h"
#include "utils/common.h"
-#include "proto/common.pb.h"
-#include "trainer/trainer.h"
-#include "mshadow/tensor.h"
-
+#include "utils/tinydir.h"
namespace singa {
+
using std::vector;
-using std::map;
-using std::queue;
-using namespace std::chrono;
using std::string;
/***********************Trainer****************************/
@@ -47,12 +43,82 @@ Trainer::~Trainer() {
delete p;
}
+void Trainer::Start(bool resume, const SingaProto& singaConf, JobProto* job) {
+ // register job to zookeeper at the beginning
+ auto cluster = Cluster::Setup(job->id(), singaConf, job->cluster());
+ if (resume) Resume(job);
+ router_ = new Router();
+ router_->Bind(kInprocRouterEndpoint);
+ const string hostip = cluster->hostip();
+ int port = router_->Bind("tcp://" + hostip + ":*");
+ // register endpoint to zookeeper
+ cluster->Register(getpid(), hostip + ":" + std::to_string(port));
+ const vector<Worker*> workers = CreateWorkers(*job);
+ const vector<Server*> servers = CreateServers(*job);
+ SetupWorkerServer(*job, workers, servers);
+#ifdef USE_MPI
+ int nthreads = workers.size() + servers.size();
+ for (int i = 0; i < nthreads; i++)
+ MPIQueues.push_back(make_shared<SafeQueue>());
+#endif
+ vector<std::thread> threads;
+ for (auto server : servers)
+ threads.push_back(std::thread(&Server::Run, server));
+ for (auto worker : workers)
+ threads.push_back(std::thread(&Worker::Run, worker));
+ Run(workers, servers);
+ for (auto& thread : threads)
+ thread.join();
+ for (auto server : servers)
+ delete server;
+ for (auto worker : workers)
+ delete worker;
+}
+
+void Trainer::Resume(JobProto* jobConf) {
+ tinydir_dir dir;
+ string folder = Cluster::Get()->checkpoint_folder();
+ tinydir_open(&dir, folder.c_str());
+ int latest_step = 0;
+ // there would be multi checkpoint files (from diff workers) for one step
+ vector<string> ck_files;
+ // iterate all files to get the files for the last checkpoint
+ while (dir.has_next) {
+ tinydir_file file;
+ tinydir_readfile(&dir, &file);
+ tinydir_next(&dir);
+ char* ch = strstr(file.name, "step");
+ if (ch == nullptr) {
+ if (file.name[0] != '.')
+ LOG(INFO) << "Irregular file in checkpoint folder: " << file.name;
+ continue;
+ }
+ LOG(INFO) << "Add checkpoint file for resume: " << ch;
+ int step = atoi(ch+4);
+ if (step == latest_step) {
+ ck_files.push_back(file.name);
+ } else if (step > latest_step) {
+ latest_step = step;
+ ck_files.clear();
+ ck_files.push_back(string(file.name));
+ }
+ }
+ if (latest_step > 0) {
+ jobConf->set_step(latest_step);
+ if (!jobConf->has_reset_param_version())
+ jobConf->set_reset_param_version(false);
+ jobConf->clear_checkpoint_path();
+ for (auto ck_file : ck_files)
+ jobConf->add_checkpoint_path(folder + "/" + ck_file);
+ }
+ tinydir_close(&dir);
+}
+
const vector<int> SliceParams(const vector<Param*>& params) {
// for load-balance among servers in a group and among server groups
int nserver_grps = Cluster::Get()->nserver_groups();
int nservers_per_grp = Cluster::Get()->nservers_per_group();
int lcm = LeastCommonMultiple(nserver_grps, nservers_per_grp);
-
// collect sizes of unique Params
std::vector<int> paramsize;
for (auto param : params)
@@ -86,10 +152,9 @@ const vector<int> SliceParams(const vector<Param*>& params) {
return slices;
}
-void Trainer::SetupWorkerServer(
- const JobProto& job_conf,
- const vector<Worker*>& workers,
- const vector<Server*>& servers) {
+void Trainer::SetupWorkerServer(const JobProto& job_conf,
+ const vector<Worker*>& workers,
+ const vector<Server*>& servers) {
auto cluster = Cluster::Get();
int grp_size = cluster->nworkers_per_group();
const auto& net_conf = job_conf.neuralnet();
@@ -97,7 +162,6 @@ void Trainer::SetupWorkerServer(
nets_.push_back(net);
// MUST do SliceParam before share param/net with others
auto slices = SliceParams(net->params());
-
std::unordered_map<int, NeuralNet*> grp_net;
int first_grp = workers.size() ? workers.at(0)->grp_id() : -1;
for (auto worker : workers) {
@@ -107,13 +171,17 @@ void Trainer::SetupWorkerServer(
NeuralNet* valid_net = nullptr;
if (grp_net.find(grp_id) == grp_net.end()) {
if (grp_id == first_grp) {
- // test are performed only by the first group now. TODO update.
+ // test are performed only by the first group now.
+ // TODO(wangwei) update.
if (first_grp == 0 && job_conf.test_steps() && worker_id == 0) {
- test_net = NeuralNet::Create(net_conf, kTest, 1); // hard code for exp
+ // hard code for exp
+ // TODO(wangwei) move test unit out as an independent module
+ test_net = NeuralNet::Create(net_conf, kTest, 1);
test_net->ShareParamsFrom(net);
nets_.push_back(test_net);
}
- // validation are performed only by the first group. TODO update.
+ // validation are performed only by the first group.
+ // TODO(wangwei) update.
if (first_grp == 0 && job_conf.valid_steps() && worker_id == 0) {
valid_net = NeuralNet::Create(net_conf, kValidation, 1);
valid_net->ShareParamsFrom(net);
@@ -123,7 +191,7 @@ void Trainer::SetupWorkerServer(
} else {
grp_net[grp_id] = NeuralNet::Create(net_conf, kTrain, grp_size);
nets_.push_back(grp_net[grp_id]);
- if(cluster->share_memory())
+ if (cluster->share_memory())
grp_net[grp_id]->ShareParamsFrom(net);
}
for (auto layer : grp_net[grp_id]->layers()) {
@@ -141,12 +209,10 @@ void Trainer::SetupWorkerServer(
<< worker->id() << " net " << grp_net[grp_id];
worker->Setup(job_conf, grp_net[grp_id], valid_net, test_net);
}
-
// partition among server groups, each group maintains one sub-set for sync
auto slice2group = PartitionSlices(cluster->nserver_groups(), slices);
// partition within one server group, each server updates for one sub-set
slice2server_ = PartitionSlices(cluster->nservers_per_group(), slices);
-
for (auto server : servers)
server->Setup(job_conf.updater(), slice2group, slice2server_);
}
@@ -156,14 +222,13 @@ vector<Server*> Trainer::CreateServers(const JobProto& job) {
vector<Server*> servers;
if (!cluster->has_server())
return servers;
-
int server_procs = cluster->procs_id();
// if true, server procs (logical) id starts after worker procs
if (cluster->server_worker_separate())
server_procs -= cluster->nworker_procs();
const vector<int> rng = cluster->ExecutorRng(server_procs,
- cluster->nservers_per_group(),
- cluster->nservers_per_procs());
+ cluster->nservers_per_group(),
+ cluster->nservers_per_procs());
int gstart = rng[0], gend = rng[1], start = rng[2], end = rng[3];
for (int gid = gstart; gid < gend; gid++) {
for (int sid = start; sid < end; sid++) {
@@ -174,15 +239,14 @@ vector<Server*> Trainer::CreateServers(const JobProto& job) {
return servers;
}
-
vector<Worker*> Trainer::CreateWorkers(const JobProto& job) {
- auto cluster=Cluster::Get();
+ auto cluster = Cluster::Get();
vector<Worker*> workers;
- if(!cluster->has_worker())
+ if (!cluster->has_worker())
return workers;
const vector<int> rng = cluster->ExecutorRng(cluster->procs_id(),
- cluster->nworkers_per_group(),
- cluster->nworkers_per_procs());
+ cluster->nworkers_per_group(),
+ cluster->nworkers_per_procs());
int gstart = rng[0], gend = rng[1], wstart = rng[2], wend = rng[3];
for (int gid = gstart; gid < gend; gid++) {
for (int wid = wstart; wid < wend; wid++) {
@@ -194,93 +258,13 @@ vector<Worker*> Trainer::CreateWorkers(const JobProto& job) {
return workers;
}
-void Trainer::Resume(JobProto* jobConf) {
- tinydir_dir dir;
- string folder = Cluster::Get()->checkpoint_folder();
- tinydir_open(&dir, folder.c_str());
- int latest_step = 0;
- // there would be multi checkpoint files (from diff workers) for one step
- vector<string> ck_files;
- // iterate all files to get the files for the last checkpoint
- while (dir.has_next) {
- tinydir_file file;
- tinydir_readfile(&dir, &file);
- tinydir_next(&dir);
- char* ch = strstr(file.name, "step");
- if (ch == nullptr) {
- if (file.name[0] != '.')
- LOG(INFO) << "Irregular file in checkpoint folder: " << file.name;
- continue;
- }
-
- LOG(INFO) << "Add checkpoint file for resume: " << ch;
- int step = atoi(ch+4);
- if (step == latest_step) {
- ck_files.push_back(file.name);
- } else if(step > latest_step) {
- latest_step = step;
- ck_files.clear();
- ck_files.push_back(string(file.name));
- }
- }
-
- if (latest_step > 0) {
- jobConf->set_step(latest_step);
- if (!jobConf->has_reset_param_version())
- jobConf->set_reset_param_version(false);
- jobConf->clear_checkpoint_path();
- for (auto ck_file : ck_files)
- jobConf->add_checkpoint_path(folder + "/" + ck_file);
- }
- tinydir_close(&dir);
-}
-
-void Trainer::Start(bool resume, const SingaProto& singaConf, JobProto* job) {
- // register job to zookeeper at the beginning
- auto cluster = Cluster::Setup(job->id(), singaConf, job->cluster());
- if (resume)
- Resume(job);
-
- router_ = new Router();
- router_->Bind(kInprocRouterEndpoint);
- const string hostip = cluster->hostip();
- int port = router_->Bind("tcp://" + hostip + ":*");
- // register endpoint to zookeeper
- cluster->Register(getpid(), hostip + ":" + std::to_string(port));
-
- const vector<Worker*> workers = CreateWorkers(*job);
- const vector<Server*> servers = CreateServers(*job);
- SetupWorkerServer(*job, workers, servers);
-
-#ifdef USE_MPI
- int nthreads = workers.size() + servers.size();
- for (int i = 0; i < nthreads; i++)
- MPIQueues.push_back(make_shared<SafeQueue>());
-#endif
- vector<std::thread> threads;
- for(auto server : servers)
- threads.push_back(std::thread(&Server::Run, server));
- for(auto worker : workers)
- threads.push_back(std::thread(&Worker::Run, worker));
- Run(workers, servers);
- for(auto& thread : threads)
- thread.join();
- for(auto server : servers)
- delete server;
- for(auto worker : workers)
- delete worker;
-}
-
-void Trainer::Run(
- const vector<Worker*>& workers,
- const vector<Server*>& servers) {
+void Trainer::Run(const vector<Worker*>& workers,
+ const vector<Server*>& servers) {
int nworkers = workers.size(), nservers = servers.size();
auto cluster = Cluster::Get();
procs_id_ = cluster->procs_id();
LOG(INFO) << "Stub in process " << procs_id_ << " starts";
-
- map<int, Dealer*> inter_dealers; // for sending msg to other procs
-
+ std::map<int, Dealer*> inter_dealers; // for sending msg to other procs
std::queue<Msg*> msg_queue;
while (true) {
Msg* msg = nullptr;
@@ -343,26 +327,27 @@ Dealer* Trainer::CreateInterProcsDealer(int dst_procs) {
// forward to other procs
auto cluster = Cluster::Get();
auto dealer = new Dealer();
- while(cluster->endpoint(dst_procs)=="") {
- //kCollectSleepTime));
+ while (cluster->endpoint(dst_procs) == "") {
+ // kCollectSleepTime));
std::this_thread::sleep_for(std::chrono::milliseconds(3000));
- LOG(ERROR)<<"waiting for procs "<< dst_procs<<" to register";
+ LOG(ERROR) << "waiting for procs " << dst_procs << " to register";
}
dealer->Connect("tcp://"+cluster->endpoint(dst_procs));
return dealer;
}
-void Trainer::HandleLocalMsg(queue<Msg*>* msg_queue, Msg** msg) {
+void Trainer::HandleLocalMsg(std::queue<Msg*>* msg_queue, Msg** msg) {
Msg* msgg = *msg;
int paramid = ParamID(msgg->trgt_val());
int type = msgg->type();
int grp;
ParamEntry *entry = nullptr;
- switch (type) { // TODO process other requests, e.g. RESTful
+ // TODO(wangwei) process other requests, e.g. RESTful
+ switch (type) {
case kUpdate:
grp = AddrGrp(msgg->src());
entry = worker_shard_.at(Hash(grp, paramid));
- for(auto update_msg : HandleUpdate(entry, msg))
+ for (auto update_msg : HandleUpdate(entry, msg))
msg_queue->push(update_msg);
break;
case kRUpdate:
@@ -373,7 +358,7 @@ void Trainer::HandleLocalMsg(queue<Msg*>* msg_queue, Msg** msg) {
case kGet:
grp = AddrGrp(msgg->src());
entry = worker_shard_.at(Hash(grp, paramid));
- for(auto get_msg : HandleGet(entry, msg))
+ for (auto get_msg : HandleGet(entry, msg))
msg_queue->push(get_msg);
break;
case kRGet:
@@ -384,22 +369,22 @@ void Trainer::HandleLocalMsg(queue<Msg*>* msg_queue, Msg** msg) {
case kPut:
grp = AddrGrp(msgg->src());
entry = worker_shard_.at(Hash(grp, paramid));
- for(auto put_msg : HandlePut(entry, msg))
+ for (auto put_msg : HandlePut(entry, msg))
msg_queue->push(put_msg);
break;
default:
- LOG(ERROR)<<"Unknow message type:"<<type;
+ LOG(ERROR) << "Unknow message type:" << type;
break;
}
}
-void Trainer::GenMsgs(int type, int version, ParamEntry* entry,
- Msg* msg, vector<Msg*> *ret) {
+void Trainer::GenMsgs(int type, int version, ParamEntry* entry, Msg* msg,
+ vector<Msg*> *ret) {
int src_grp = AddrGrp(msg->src());
int dst_grp = src_grp / Cluster::Get()->nworker_groups_per_server_group();
- auto param=entry->shares.at(0);
+ auto param = entry->shares.at(0);
for (int idx = 0 ; idx < param->num_slices(); idx++) {
- int slice_id =param->slice_start() + idx;
+ int slice_id = param->slice_start() + idx;
int server = slice2server_[slice_id];
int dst_procs = Cluster::Get()->ProcsIDOf(dst_grp, server, kServer);
Msg* new_msg = nullptr;
@@ -440,10 +425,10 @@ const vector<Msg*> Trainer::HandleUpdate(ParamEntry *entry, Msg** msg) {
// average local gradient
if (entry->num_local > 1) {
auto it = entry->shares.begin();
- auto shape=mshadow::Shape1((*it)->size());
- mshadow::Tensor<mshadow::cpu,1> sum((*it)->mutable_cpu_grad(), shape);
+ auto shape = mshadow::Shape1((*it)->size());
+ mshadow::Tensor<mshadow::cpu, 1> sum((*it)->mutable_cpu_grad(), shape);
for (++it; it != entry->shares.end(); it++) {
- mshadow::Tensor<mshadow::cpu,1> grad((*it)->mutable_cpu_grad(), shape);
+ mshadow::Tensor<mshadow::cpu, 1> grad((*it)->mutable_cpu_grad(), shape);
sum += grad;
}
}
@@ -480,4 +465,5 @@ void Trainer::HandleUpdateResponse(ParamEntry* entry, Msg** msg) {
param->set_version(version);
DeleteMsg(msg);
}
-} /* singa */
+
+} // namespace singa
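
The reordered Start() above reduces to a standard fan-out/join: one std::thread per server and per worker, the stub's Run() loop on the main thread, then join everything and delete the instances. The same shape as a self-contained sketch, with trivial stand-ins for Server and Worker:

    #include <thread>
    #include <vector>

    struct Server { void Run() { /* serve parameter requests */ } };
    struct Worker { void Run() { /* run training steps */ } };

    int main() {
      std::vector<Server*> servers{new Server};
      std::vector<Worker*> workers{new Worker, new Worker};
      std::vector<std::thread> threads;
      for (auto server : servers)
        threads.push_back(std::thread(&Server::Run, server));
      for (auto worker : workers)
        threads.push_back(std::thread(&Worker::Run, worker));
      // ... the stub's message loop (Run) would execute here ...
      for (auto& t : threads) t.join();
      for (auto server : servers) delete server;
      for (auto worker : workers) delete worker;
      return 0;
    }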