You are viewing a plain-text version of this content; the link to its canonical version was not preserved in this copy.
Posted to commits@singa.apache.org by wa...@apache.org on 2015/08/15 10:11:19 UTC
[06/12] incubator-singa git commit: SINGA-55 Refactor main.cc and singa.h
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/2498ff13/include/singa.h
----------------------------------------------------------------------
diff --cc include/singa.h
index 6fb9e97,82df64b..52d1f90
--- a/include/singa.h
+++ b/include/singa.h
@@@ -1,92 -1,40 +1,88 @@@
#ifndef SINGA_SINGA_H_
#define SINGA_SINGA_H_
++
++#include <cblas.h>
#include <gflags/gflags.h>
#include <glog/logging.h>
--#include <cblas.h>
-
-#include "utils/common.h"
+#include <string>
-
- #include "utils/common.h"
++#include "communication/socket.h"
++#include "neuralnet/neuralnet.h"
#include "proto/job.pb.h"
#include "proto/singa.pb.h"
--
++#include "trainer/trainer.h"
++#include "utils/common.h"
#include "utils/param.h"
#include "utils/singleton.h"
#include "utils/factory.h"
--#include "neuralnet/neuralnet.h"
--#include "trainer/trainer.h"
--#include "communication/socket.h"
-
+namespace singa {
+
-DEFINE_string(singa_conf, "conf/singa.conf", "Global config file");
+/**
+ * Driver of a SINGA program: initializes the runtime, registers subclasses
+ * with their factories and submits the job.
+ *
+ * Call Init() first: it sets the job ID that Submit() stamps into the job
+ * configuration and that job_id() returns.
+ *
+ * NOTE: the Register*() templates are declared here but defined in
+ * driver.cc, so they are only instantiated in that translation unit.
+ * Calling them with a user-defined T from another .cc file will fail to
+ * link. TODO: move the template definitions into this header.
+ */
+class Driver {
+ public:
+ /**
+ * Init SINGA: init glog, parse the command-line flags (including the job
+ * ID), and register the built-in layer, worker, updater and param
+ * subclasses.
+ *
+ * May be used for MPI init if it is used for message passing.
+ *
+ * @param argc argument count, as received by main().
+ * @param argv argument vector, as received by main().
+ */
+ void Init(int argc, char** argv);
+ /**
+ * Register a Layer subclass.
+ *
+ * T is the subclass.
+ * @param type layer type ID. If called by users, it should be different
+ * from the types of built-in layers.
+ * @return 0 if success; otherwise -1.
+ */
+ template<typename T>
+ int RegisterLayer(int type);
+ /**
+ * Register an Updater subclass.
+ *
+ * T is the subclass.
+ * @param type updater type ID. If called by users, it should be different
+ * from the types of built-in updaters.
+ * @return 0 if success; otherwise -1.
+ */
+ template<typename T>
+ int RegisterUpdater(int type);
+ /**
+ * Register a Worker subclass.
+ *
+ * T is the subclass.
+ * @param type worker type ID. If called by users, it should be different
+ * from the types of built-in workers.
+ * @return 0 if success; otherwise -1.
+ */
+ template<typename T>
+ int RegisterWorker(int type);
+ /**
+ * Register a Param subclass.
+ *
+ * T is the subclass.
+ * @param type param type. If called by users, it should be different from
+ * the types of built-in params. SINGA currently provides only one built-in
+ * Param implementation whose type ID is 0.
+ * @return 0 if success; otherwise -1.
+ */
+ template<typename T>
+ int RegisterParam(int type);
-
+ /**
+ * Submit the job configuration for starting the job.
+ * @param resume resume from last checkpoint if true.
+ * @param job job configuration
+ */
+ void Submit(bool resume, const JobProto& job);
-
+ /**
+ * @return job ID which is generated by zookeeper and passed in by the
+ * launching script; -1 if Init() has not been called yet.
+ */
- int job_id() const {
- return job_id_;
- }
++ int job_id() const { return job_id_; }
+
+ private:
+ // Set from the --job flag by Init(); -1 (the flag's default) until then,
+ // so job_id() and Submit() never read an uninitialized value.
+ int job_id_ = -1;
+};
+
-namespace singa {
-void SubmitJob(int job, bool resume, const JobProto& jobConf) {
- SingaProto singaConf;
- ReadProtoFromTextFile(FLAGS_singa_conf.c_str(), &singaConf);
- if (singaConf.has_log_dir())
- SetupLog(singaConf.log_dir(),
- std::to_string(job) + "-" + jobConf.name());
- if (jobConf.num_openblas_threads() != 1)
- LOG(WARNING) << "openblas is set with " << jobConf.num_openblas_threads()
- << " threads";
- openblas_set_num_threads(jobConf.num_openblas_threads());
- JobProto proto;
- proto.CopyFrom(jobConf);
- proto.set_id(job);
- Trainer trainer;
- trainer.Start(resume, singaConf, &proto);
-}
} // namespace singa
--#endif // SINGA_SINGA_H_
++#endif // SINGA_SINGA_H_
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/2498ff13/src/driver.cc
----------------------------------------------------------------------
diff --cc src/driver.cc
index 05c1195,0000000..5469583
mode 100644,000000..100644
--- a/src/driver.cc
+++ b/src/driver.cc
@@@ -1,101 -1,0 +1,102 @@@
+#include "singa.h"
++
+namespace singa {
+
+/**
+ * The --job and --singa_conf flags are normally supplied by the launching
+ * script (singa-run.sh) rather than set by users directly.
+ */
+DEFINE_int32(job, -1, "Unique job ID generated from singa-run.sh");
+DEFINE_string(singa_conf, "conf/singa.conf", "Global config file");
+
+/**
+ * Initialize glog, parse the command-line flags, record the job ID and
+ * register every built-in Layer, Updater, Worker and Param subclass with
+ * its factory.
+ *
+ * NOTE(review): argc/argv are received by value, so the caller's argc is
+ * not updated when ParseCommandLineFlags(..., true) strips parsed flags.
+ */
+void Driver::Init(int argc, char **argv) {
+ google::InitGoogleLogging(argv[0]);
+ gflags::ParseCommandLineFlags(&argc, &argv, true);
+ job_id_ = FLAGS_job;
+
+ // register layers
+ RegisterLayer<BridgeDstLayer>(kBridgeDst);
+ RegisterLayer<BridgeSrcLayer>(kBridgeSrc);
+ RegisterLayer<ConvolutionLayer>(kConvolution);
+ RegisterLayer<ConcateLayer>(kConcate);
+ RegisterLayer<DropoutLayer>(kDropout);
+ RegisterLayer<InnerProductLayer>(kInnerProduct);
+ RegisterLayer<LabelLayer>(kLabel);
+ RegisterLayer<LRNLayer>(kLRN);
+ RegisterLayer<MnistLayer>(kMnist);
+ RegisterLayer<PrefetchLayer>(kPrefetch);
+ RegisterLayer<PoolingLayer>(kPooling);
+ RegisterLayer<RGBImageLayer>(kRGBImage);
+ RegisterLayer<ReLULayer>(kReLU);
+ RegisterLayer<ShardDataLayer>(kShardData);
+ RegisterLayer<SliceLayer>(kSlice);
+ RegisterLayer<SoftmaxLossLayer>(kSoftmaxLoss);
+ RegisterLayer<SplitLayer>(kSplit);
+ RegisterLayer<TanhLayer>(kTanh);
+ RegisterLayer<RBMVisLayer>(kRBMVis);
+ RegisterLayer<RBMHidLayer>(kRBMHid);
+#ifdef USE_LMDB
- RegisterLayer(factory, LMDBData);
++ RegisterLayer<LMDBDataLayer>(kLMDBData);
+#endif
+
+ // register updaters
+ RegisterUpdater<AdaGradUpdater>(kAdaGrad);
+ RegisterUpdater<NesterovUpdater>(kNesterov);
- // TODO(wangwei) RegisterUpdater<kRMSPropUpdater>(kRMSProp);
++ // TODO(wangwei) RegisterUpdater<kRMSPropUpdater>(kRMSProp);
+ RegisterUpdater<SGDUpdater>(kSGD);
+
+ // register workers
+ RegisterWorker<BPWorker>(kBP);
+ RegisterWorker<CDWorker>(kCD);
+
+ // register the only built-in Param implementation, under type ID 0
+ RegisterParam<Param>(0);
+}
+
+/**
+ * Register Layer subclass T with the Singleton Factory<Layer> under the
+ * given type ID.
+ * @return 0 (success), as documented in singa.h; this function has no
+ * failure path of its own.
+ */
+template<typename T>
+int Driver::RegisterLayer(int type) {
+ auto factory = Singleton<Factory<singa::Layer>>::Instance();
+ factory->Register(type, CreateInstance(T, Layer));
+ return 0;
+}
+
+/**
+ * Register Param subclass T with the Singleton Factory<Param> under the
+ * given type ID.
+ * @return 0 (success), as documented in singa.h.
+ */
+template<typename T>
+int Driver::RegisterParam(int type) {
+ auto factory = Singleton<Factory<singa::Param>>::Instance();
+ factory->Register(type, CreateInstance(T, Param));
+ return 0;
+}
+
+/**
+ * Register Updater subclass T with the Singleton Factory<Updater> under
+ * the given type ID.
+ * @return 0 (success), as documented in singa.h.
+ */
+template<typename T>
+int Driver::RegisterUpdater(int type) {
+ auto factory = Singleton<Factory<singa::Updater>>::Instance();
+ factory->Register(type, CreateInstance(T, Updater));
+ return 0;
+}
+
+/**
+ * Register Worker subclass T with the Singleton Factory<Worker> under the
+ * given type ID.
+ * @return 0 (success), as documented in singa.h.
+ */
+template<typename T>
+int Driver::RegisterWorker(int type) {
+ auto factory = Singleton<Factory<singa::Worker>>::Instance();
+ factory->Register(type, CreateInstance(T, Worker));
+ return 0;
+}
+
+/**
+ * Load the global SINGA config named by --singa_conf, set up logging and
+ * the OpenBLAS thread count, then start a Trainer on a copy of jobConf
+ * stamped with this driver's job ID.
+ */
+void Driver::Submit(bool resume, const JobProto& jobConf) {
+ SingaProto singaConf;
+ ReadProtoFromTextFile(FLAGS_singa_conf.c_str(), &singaConf);
+ // Log files are tagged "<job id>-<job name>"; use job_id_ so the tag
+ // always matches the ID stamped into the JobProto below.
+ if (singaConf.has_log_dir())
+ SetupLog(singaConf.log_dir(), std::to_string(job_id_)
+ + "-" + jobConf.name());
+ if (jobConf.num_openblas_threads() != 1)
+ LOG(WARNING) << "openblas with "
+ << jobConf.num_openblas_threads() << " threads";
+ openblas_set_num_threads(jobConf.num_openblas_threads());
+
+ // jobConf is const, so hand the trainer a copy that carries the job ID.
+ JobProto job;
+ job.CopyFrom(jobConf);
+ job.set_id(job_id_);
+ Trainer trainer;
+ trainer.Start(resume, singaConf, &job);
+}
+
+} // namespace singa