Posted to commits@singa.apache.org by wa...@apache.org on 2015/09/27 16:34:31 UTC
[08/13] incubator-singa git commit: SINGA-70 Refactor API of Layer, Worker, Server and Driver
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/321ef96a/include/neuralnet/neuron_layer.h
----------------------------------------------------------------------
diff --git a/include/neuralnet/neuron_layer.h b/include/neuralnet/neuron_layer.h
index 6c4647d..51ba304 100644
--- a/include/neuralnet/neuron_layer.h
+++ b/include/neuralnet/neuron_layer.h
@@ -7,9 +7,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
-*
+*
* http://www.apache.org/licenses/LICENSE-2.0
-*
+*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -38,9 +38,9 @@ class ConvolutionLayer : public NeuronLayer {
public:
~ConvolutionLayer();
- void Setup(const LayerProto& proto, int npartitions) override;
- void ComputeFeature(int flag, Metric* perf) override;
- void ComputeGradient(int flag, Metric* perf) override;
+ void Setup(const LayerProto& proto, const vector<Layer*>& srclayers) override;
+ void ComputeFeature(int flag, const vector<Layer*>& srclayers) override;
+ void ComputeGradient(int flag, const vector<Layer*>& srclayers) override;
const std::vector<Param*> GetParams() const override {
std::vector<Param*> params{weight_, bias_};
return params;
@@ -63,15 +63,15 @@ class ConvolutionLayer : public NeuronLayer {
*/
class CConvolutionLayer : public ConvolutionLayer {
public:
- void ComputeFeature(int flag, Metric* perf) override;
- void ComputeGradient(int flag, Metric* perf) override;
+ void ComputeFeature(int flag, const vector<Layer*>& srclayers) override;
+ void ComputeGradient(int flag, const vector<Layer*>& srclayers) override;
};
class DropoutLayer : public NeuronLayer {
public:
- void Setup(const LayerProto& proto, int npartitions) override;
- void ComputeFeature(int flag, Metric* perf) override;
- void ComputeGradient(int flag, Metric* perf) override;
+ void Setup(const LayerProto& proto, const vector<Layer*>& srclayers) override;
+ void ComputeFeature(int flag, const vector<Layer*>& srclayers) override;
+ void ComputeGradient(int flag, const vector<Layer*>& srclayers) override;
protected:
// drop probability
float pdrop_;
@@ -90,9 +90,9 @@ class DropoutLayer : public NeuronLayer {
* b_i, the neuron after normalization, N is the total num of kernels
*/
class LRNLayer : public NeuronLayer {
- void Setup(const LayerProto& proto, int npartitions) override;
- void ComputeFeature(int flag, Metric *perf) override;
- void ComputeGradient(int flag, Metric* perf) override;
+ void Setup(const LayerProto& proto, const vector<Layer*>& srclayers) override;
+ void ComputeFeature(int flag, const vector<Layer*>& srclayers) override;
+ void ComputeGradient(int flag, const vector<Layer*>& srclayers) override;
protected:
//! shape of the bottom layer feature
@@ -106,9 +106,9 @@ class LRNLayer : public NeuronLayer {
class PoolingLayer : public NeuronLayer {
public:
- void Setup(const LayerProto& proto, int npartitions) override;
- void ComputeFeature(int flag, Metric *perf) override;
- void ComputeGradient(int flag, Metric* perf) override;
+ void Setup(const LayerProto& proto, const vector<Layer*>& srclayers) override;
+ void ComputeFeature(int flag, const vector<Layer*>& srclayers) override;
+ void ComputeGradient(int flag, const vector<Layer*>& srclayers) override;
protected:
int kernel_, pad_, stride_;
@@ -121,26 +121,26 @@ class PoolingLayer : public NeuronLayer {
*/
class CPoolingLayer : public PoolingLayer {
public:
- void Setup(const LayerProto& proto, int npartitions);
- void ComputeFeature(int flag, Metric *perf) override;
- void ComputeGradient(int flag, Metric* perf) override;
+ void Setup(const LayerProto& proto, const vector<Layer*>& srclayers);
+ void ComputeFeature(int flag, const vector<Layer*>& srclayers) override;
+ void ComputeGradient(int flag, const vector<Layer*>& srclayers) override;
private:
Blob<float> mask_;
};
class ReLULayer : public NeuronLayer {
public:
- void Setup(const LayerProto& proto, int npartitions) override;
- void ComputeFeature(int flag, Metric *perf) override;
- void ComputeGradient(int flag, Metric* perf) override;
+ void Setup(const LayerProto& proto, const vector<Layer*>& srclayers) override;
+ void ComputeFeature(int flag, const vector<Layer*>& srclayers) override;
+ void ComputeGradient(int flag, const vector<Layer*>& srclayers) override;
};
class InnerProductLayer : public NeuronLayer {
public:
~InnerProductLayer();
- void Setup(const LayerProto& proto, int npartitions) override;
- void ComputeFeature(int flag, Metric* perf) override;
- void ComputeGradient(int flag, Metric* perf) override;
+ void Setup(const LayerProto& proto, const vector<Layer*>& srclayers) override;
+ void ComputeFeature(int flag, const vector<Layer*>& srclayers) override;
+ void ComputeGradient(int flag, const vector<Layer*>& srclayers) override;
const std::vector<Param*> GetParams() const override {
std::vector<Param*> params{weight_, bias_};
return params;
@@ -159,9 +159,9 @@ class InnerProductLayer : public NeuronLayer {
*/
class STanhLayer : public NeuronLayer {
public:
- void Setup(const LayerProto& proto, int npartitions) override;
- void ComputeFeature(int flag, Metric *perf) override;
- void ComputeGradient(int flag, Metric* perf) override;
+ void Setup(const LayerProto& proto, const vector<Layer*>& srclayers) override;
+ void ComputeFeature(int flag, const vector<Layer*>& srclayers) override;
+ void ComputeGradient(int flag, const vector<Layer*>& srclayers) override;
};
/**
@@ -174,19 +174,19 @@ class SigmoidLayer: public Layer {
using Layer::ComputeFeature;
using Layer::ComputeGradient;
- void Setup(const LayerProto& proto, int npartitions) override;
- void ComputeFeature(int flag, Metric* perf) override;
- void ComputeGradient(int flag, Metric* perf) override;
+ void Setup(const LayerProto& proto, const vector<Layer*>& srclayers) override;
+ void ComputeFeature(int flag, const vector<Layer*>& srclayers) override;
+ void ComputeGradient(int flag, const vector<Layer*>& srclayers) override;
};
/**
* Base layer for RBM models.
*/
-class RBMLayer: public Layer {
+class RBMLayer: virtual public Layer {
public:
virtual ~RBMLayer() {}
- void Setup(const LayerProto& proto, int npartitions) override;
+ void Setup(const LayerProto& proto, const vector<Layer*>& srclayers) override;
const Blob<float>& neg_data(const Layer* layer) {
return neg_data_;
}
@@ -218,12 +218,12 @@ class RBMLayer: public Layer {
/**
* RBM visible layer
*/
-class RBMVisLayer: public RBMLayer {
+class RBMVisLayer: public RBMLayer, public LossLayer {
public:
~RBMVisLayer();
- void Setup(const LayerProto& proto, int npartitions) override;
- void ComputeFeature(int flag, Metric* perf) override;
- void ComputeGradient(int flag, Metric* perf) override;
+ void Setup(const LayerProto& proto, const vector<Layer*>& srclayers) override;
+ void ComputeFeature(int flag, const vector<Layer*>& srclayers) override;
+ void ComputeGradient(int flag, const vector<Layer*>& srclayers) override;
private:
RBMLayer* hid_layer_;
@@ -235,9 +235,9 @@ class RBMVisLayer: public RBMLayer {
class RBMHidLayer: public RBMLayer {
public:
~RBMHidLayer();
- void Setup(const LayerProto& proto, int npartitions) override;
- void ComputeFeature(int flag, Metric* perf) override;
- void ComputeGradient(int flag, Metric* perf) override;
+ void Setup(const LayerProto& proto, const vector<Layer*>& srclayers) override;
+ void ComputeFeature(int flag, const vector<Layer*>& srclayers) override;
+ void ComputeGradient(int flag, const vector<Layer*>& srclayers) override;
private:
RBMLayer *vis_layer_;
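The hunks above change every neuron layer from Setup(const LayerProto&, int npartitions) and Compute*(int flag, Metric*) to signatures that receive the source layers explicitly. A minimal sketch of a custom layer written against the refactored API (the accessors data(this), mutable_grad(this), cpu_data() and ReshapeLike() follow the Blob/Layer conventions of this codebase but should be treated as assumptions here):

#include <algorithm>
#include "neuralnet/neuron_layer.h"

namespace singa {

class MyReLULayer : public NeuronLayer {
 public:
  // srclayers replaces the old npartitions argument; the layer reads its
  // input shape directly from its first source layer.
  void Setup(const LayerProto& proto, const vector<Layer*>& srclayers) override {
    Layer::Setup(proto, srclayers);
    data_.ReshapeLike(srclayers[0]->data(this));
    grad_.ReshapeLike(data_);
  }
  // Metric* is gone from the signature; performance collection now happens
  // outside the compute functions.
  void ComputeFeature(int flag, const vector<Layer*>& srclayers) override {
    const float* in = srclayers[0]->data(this).cpu_data();
    float* out = data_.mutable_cpu_data();
    for (int i = 0; i < data_.count(); ++i)
      out[i] = std::max(0.f, in[i]);
  }
  void ComputeGradient(int flag, const vector<Layer*>& srclayers) override {
    const float* in = srclayers[0]->data(this).cpu_data();
    const float* gout = grad_.cpu_data();
    float* gin = srclayers[0]->mutable_grad(this)->mutable_cpu_data();
    for (int i = 0; i < grad_.count(); ++i)
      gin[i] = in[i] > 0 ? gout[i] : 0.f;  // gradient flows only where input > 0
  }
};

}  // namespace singa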
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/321ef96a/include/server.h
----------------------------------------------------------------------
diff --git a/include/server.h b/include/server.h
new file mode 100644
index 0000000..4b75430
--- /dev/null
+++ b/include/server.h
@@ -0,0 +1,133 @@
+/************************************************************
+*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements. See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership. The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied. See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*
+*************************************************************/
+
+#ifndef SINGA_SERVER_H_
+#define SINGA_SERVER_H_
+
+#include <unordered_map>
+#include <vector>
+#include "comm/socket.h"
+#include "proto/job.pb.h"
+#include "utils/param.h"
+#include "utils/updater.h"
+
+namespace singa {
+
+ /* Respond to workers' get/put/update requests, and periodically sync with
+ * other servers.
+ *
+ * Normally, the Server creates a response message for each request, which
+ * is sent back to the one who issued the request. However, if a request
+ * is not processed successfully, the original message is returned. The
+ * server does not know whether the returned message is a response or the
+ * original message; it just sends it to the router. The router then decides
+ * whether to re-send the request to the server or send it to the worker.
+ */
+class Server {
+ public:
+ ~Server();
+ Server(int group_id, int server_id,
+ const JobProto& job_conf,
+ const std::vector<int>& slice2group,
+ const std::vector<int>& slice2server);
+ void Run();
+ inline int grp_id() const { return grp_id_; }
+ inline int id() const { return id_; }
+
+ protected:
+ /**
+ * Process GET request.
+ *
+ * @return the original message or a response message which contains the
+ * values of the Param at the requested version.
+ */
+ Msg* HandleGet(Msg** msg);
+ /**
+ * Process Update request.
+ *
+ * It waits until it has received the gradients from all workers in the same
+ * worker group. After updating, it responds to each sender with the new Param
+ * values. It may generate a sync message to the server group that maintains
+ * the global version of the updated Param (slice).
+ *
+ * Note: there is no counter for each worker group on the number of received
+ * update requests. Hence it is possible that the server would conduct the
+ * update when it receives x requests from group a and y requests from group
+ * b where x + y = group size. To avoid this problem, we can
+ * -# maintain request list for each group for each Param at the server side
+ * -# do not span a worker group among multiple nodes. then the updates from
+ * the same group would be locally aggregated on the worker node. And the
+ * server would conduct the update immediately after receiving the aggregated
+ * request.
+ * -# launch only one worker group.
+ *
+ * @return the original message or a response message
+ */
+ const std::vector<Msg*> HandleUpdate(Msg **msg);
+ /**
+ * Process PUT request.
+ *
+ * @return the original message or response message. If we don't want to
+ * acknowledge the put request, then return nullptr.
+ */
+ Msg* HandlePut(Msg **msg);
+ /**
+ * Handle sync request from other server groups.
+ *
+ * It adds updates of Param (slice) from other server groups directly to
+ * local Param (slice). Currently, each Param (slice) has a master group,
+ * i.e., slice2group_[sliceid], which would receive such requests from all
+ * other server groups for the Param object.
+ *
+ * @param msg request msg containing the parameter updates
+ * @return response msg that contains the fresh parameter values.
+ */
+ Msg* HandleSyncRequest(Msg** msg);
+ /**
+ * Handle sync response.
+ *
+ * The response msg includes the latest values of a Param object from the
+ * server group that maintains this Param object.
+ * The local Param values are replaced with the sum of the received Param
+ * values and the local updates made since the sync request was sent.
+ *
+ * @param msg the response message
+ */
+ void HandleSyncResponse(Msg** msg);
+
+ protected:
+ int grp_id_ = -1;
+ int id_ = -1;
+ Updater* updater_ = nullptr;
+ //!< map from slice ID to its ParamEntry; entries are deleted in the destructor
+ std::unordered_map<int, ParamEntry*> shard_;
+ std::vector<int> slice2group_, slice2server_;
+ //!< num of updates from last sync with master server group for a param/slice
+ std::vector<int> n_updates_;
+ //!< num of sync requests that have not been responded
+ std::vector<int> n_pending_sync_;
+ std::vector<Blob<float>> last_sync_;
+ std::unordered_map<int, std::vector<Msg*>> buffer_requests_;
+};
+
+} // namespace singa
+
+#endif // SINGA_SERVER_H_
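The note in HandleUpdate() above lists three ways to avoid mixing update requests from different worker groups. Option 1 (a per-group request list for each Param) boils down to the following bookkeeping; this is a hypothetical sketch, not code from this patch, and group_size is an assumption:

#include <map>
#include <utility>
#include <vector>

class Msg;  // from comm/msg.h

// Buffer one update request; return the complete batch only once group_size
// requests from the same worker group have arrived for this slice.
std::vector<Msg*> BufferUpdate(
    std::map<std::pair<int, int>, std::vector<Msg*>>* buf,
    int grp_id, int slice_id, Msg* msg, int group_size) {
  auto& pending = (*buf)[{grp_id, slice_id}];
  pending.push_back(msg);
  if (static_cast<int>(pending.size()) < group_size)
    return {};                    // still waiting for the rest of the group
  std::vector<Msg*> batch;
  batch.swap(pending);            // hand the full batch to the updater
  buf->erase({grp_id, slice_id});
  return batch;
}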
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/321ef96a/include/singa.h
----------------------------------------------------------------------
diff --git a/include/singa.h b/include/singa.h
index d4ee557..6c801ab 100644
--- a/include/singa.h
+++ b/include/singa.h
@@ -7,9 +7,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
-*
+*
* http://www.apache.org/licenses/LICENSE-2.0
-*
+*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -22,16 +22,15 @@
#ifndef SINGA_SINGA_H_
#define SINGA_SINGA_H_
-#include "communication/socket.h"
+#include "comm/socket.h"
#include "neuralnet/neuralnet.h"
#include "neuralnet/layer.h"
#include "proto/job.pb.h"
#include "proto/singa.pb.h"
-#include "trainer/trainer.h"
#include "utils/common.h"
#include "utils/param.h"
#include "utils/singleton.h"
#include "utils/factory.h"
-#include "driver.h"
+#include "./driver.h"
#endif // SINGA_SINGA_H_
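With trainer/trainer.h dropped from the umbrella header, applications now enter SINGA through the Driver pulled in by ./driver.h. A minimal sketch of the new entry point (Driver's method names follow the driver.h added in this patch series and should be treated as assumptions):

#include "singa.h"

int main(int argc, char** argv) {
  singa::Driver driver;
  driver.Init(argc, argv);   // parses command-line flags and the job config
  // driver.Train(...);      // launches workers/servers per the job config
  return 0;
}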
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/321ef96a/include/stub.h
----------------------------------------------------------------------
diff --git a/include/stub.h b/include/stub.h
new file mode 100644
index 0000000..719f033
--- /dev/null
+++ b/include/stub.h
@@ -0,0 +1,109 @@
+/************************************************************
+*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements. See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership. The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied. See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*
+*************************************************************/
+
+#ifndef SINGA_STUB_H_
+#define SINGA_STUB_H_
+
+#include <queue>
+#include <unordered_map>
+#include <vector>
+#include <string>
+#include "comm/socket.h"
+#include "neuralnet/neuralnet.h"
+#include "proto/job.pb.h"
+#include "proto/singa.pb.h"
+#include "utils/factory.h"
+#include "utils/param.h"
+#include "utils/singleton.h"
+#include "./server.h"
+#include "./worker.h"
+
+namespace singa {
+
+class Stub {
+ public:
+ ~Stub();
+ /**
+ * Find an endpoint to bind.
+ */
+ void Setup();
+ /**
+ * The Stub instance runs this function in the main thread to handle (e.g.,
+ * forward) messages from workers and servers.
+ *
+ * @param[in] slice2server the k-th value is the ID of the server that is in
+ * charge of updating the Param slice with ID k. Large Param objects are
+ * sliced into subsets for load-balance. Different subsets are updated by
+ * different servers.
+ */
+ void Run(const vector<int>& slice2server,
+ const std::vector<Worker*>& workers,
+ const std::vector<Server*>& servers);
+
+ const std::string& endpoint() const {
+ return endpoint_;
+ }
+
+ protected:
+ /**
+ * Create a socket to send msg to the specified process
+ * @param dst_procs the dst process (logical) ID
+ * @return the newly created socket
+ */
+ Dealer* CreateInterProcsDealer(int dst_procs);
+ /**
+ * Generate a request message to Get the parameter object.
+ */
+ const std::vector<Msg*> HandleGetRequest(ParamEntry* entry, Msg** msg);
+ void HandleGetResponse(ParamEntry* entry, Msg** msg);
+ /**
+ * Generate a request message to Update the parameter object.
+ */
+ const std::vector<Msg*> HandleUpdateRequest(ParamEntry* entry, Msg** msg);
+ /**
+ * Handle response msg from servers for the update requests.
+ */
+ void HandleUpdateResponse(ParamEntry* entry, Msg** msg);
+ /**
+ * Generate a request message to Put the parameter object.
+ */
+ const std::vector<Msg*> HandlePutRequest(ParamEntry* entry, Msg** msg);
+ /**
+ * Called by HandlePut, HandleUpdate and HandleGet functions
+ * @param type message type
+ * @param version param version
+ * @param entry
+ * @param msg
+ * @param ret generated messages
+ */
+ void GenMsgs(int type, int version, ParamEntry* entry,
+ Msg* msg, std::vector<Msg*> *ret);
+
+
+ protected:
+ Router *router_ = nullptr;
+ std::string endpoint_;
+ std::vector<int> slice2server_;
+};
+
+} // namespace singa
+
+#endif // SINGA_STUB_H_
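All of the Handle*Request functions above fan a request out per Param slice, using slice2server_ to address each per-slice message. A sketch of that routing step (a hypothetical helper; only the k-th-value convention documented in Run() is taken from this patch):

#include <vector>

// For slices [first_slice, first_slice + num_slices), return the ID of the
// server in charge of each one; slice2server's k-th value is the server
// responsible for updating slice k.
std::vector<int> RouteSlices(const std::vector<int>& slice2server,
                             int first_slice, int num_slices) {
  std::vector<int> dst;
  for (int s = first_slice; s < first_slice + num_slices; ++s)
    dst.push_back(slice2server.at(s));
  return dst;
}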
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/321ef96a/include/trainer/server.h
----------------------------------------------------------------------
diff --git a/include/trainer/server.h b/include/trainer/server.h
deleted file mode 100644
index 84b3a41..0000000
--- a/include/trainer/server.h
+++ /dev/null
@@ -1,132 +0,0 @@
-/************************************************************
-*
-* Licensed to the Apache Software Foundation (ASF) under one
-* or more contributor license agreements. See the NOTICE file
-* distributed with this work for additional information
-* regarding copyright ownership. The ASF licenses this file
-* to you under the Apache License, Version 2.0 (the
-* "License"); you may not use this file except in compliance
-* with the License. You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing,
-* software distributed under the License is distributed on an
-* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-* KIND, either express or implied. See the License for the
-* specific language governing permissions and limitations
-* under the License.
-*
-*************************************************************/
-
-#ifndef SINGA_TRAINER_SERVER_H_
-#define SINGA_TRAINER_SERVER_H_
-
-#include <unordered_map>
-#include <vector>
-#include "communication/socket.h"
-#include "proto/job.pb.h"
-#include "utils/param.h"
-#include "utils/updater.h"
-
-namespace singa {
-
- /* Repsond to worker's get/put/udpate request, and periodically syncing with
- * other servers.
- *
- * Normally, the Server creates a response message for each request which
- * will be sent back to the one who issued the request. However, if the request
- * are not processed successfully, the original message will be returned. The
- * sever does not know the returned message (response or the original message),
- * it just sends it to the router. The router will decide to re-send the
- * request to the server or send it to the worker.
- */
-class Server {
- public:
- Server(int group_id, int server_id);
- ~Server();
- void Setup(const UpdaterProto& proto, const std::vector<int>& slice2group,
- const std::vector<int>& slice2server);
- void Run();
- inline int grp_id() const { return grp_id_; }
- inline int id() const { return id_; }
-
- protected:
- /**
- * Process GET request.
- *
- * @return the orignal message or a response message which contains the values
- * of the Param with the request version.
- */
- Msg* HandleGet(Msg** msg);
- /**
- * Process Update request.
- *
- * It waits until received the gradients from all workers from the same worker
- * group. After updating, it responses to each sender with the new Param
- * values. It may generate a sync message to the server group that maintains
- * the global version of the updated Param (slice).
- *
- * Note: there is no counter for each worker group on the number of received
- * update requests. Hence it is possible that the server would conduct the
- * update when it receives x requests from group a and y requests from group
- * b where x + y = group size. To avoid this problem, we can
- * 1. maintain request list for each group for each Param at the server side
- * 2. do not span a worker group among multiple nodes. then the updates from
- * the same group would be locally aggregated on the worker node. And the
- * server would conduct the update immediately after receiving the aggregated
- * request.
- * 3. launch only one worker group.
- *
- * @return the orignal message or response message
- */
- const std::vector<Msg*> HandleUpdate(Msg **msg);
- /**
- * Process PUT request.
- *
- * @return the original message or response message. If we don't want to
- * acknowledge the put request, then return nullptr.
- */
- Msg* HandlePut(Msg **msg);
- /**
- * Handle sync request from other server groups.
- *
- * It adds updates of Param (slice) from other server groups directly to
- * local Param (slice). Currently, each Param (slice) has a master group,
- * i.e., slice2group_[sliceid], which would receive such requests from all
- * other server groups for the Param object.
- *
- * @param msg request msg containing the parameter updates
- * @return response msg that contains the fresh parameter values.
- */
- Msg* HandleSyncRequest(Msg** msg);
- /**
- * Handle sync response.
- *
- * The response msg includes the latest values of a Param object, for which
- * this server sent the sync request to the master/maintainer group.
- * The local Param values are replaced with the addition result of local
- * udpates since the sync request was sent and the received Param values.
- *
- * @param response message
- */
- void HandleSyncResponse(Msg** msg);
-
- protected:
- int grp_id_ = -1;
- int id_ = -1;
- Updater* updater_ = nullptr;
- //!< map from slice ID to slice and deleted in the destructor
- std::unordered_map<int, ParamEntry*> shard_;
- std::vector<int> slice2group_, slice2server_;
- //!< num of updates from last sync with master server group for a param/slice
- std::vector<int> n_updates_;
- //!< num of sync requests that have not been responded
- std::vector<int> n_pending_sync_;
- std::vector<Blob<float>> last_sync_;
- std::unordered_map<int, std::vector<Msg*>> buffer_requests_;
-};
-
-} // namespace singa
-
-#endif // SINGA_TRAINER_SERVER_H_
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/321ef96a/include/trainer/trainer.h
----------------------------------------------------------------------
diff --git a/include/trainer/trainer.h b/include/trainer/trainer.h
deleted file mode 100644
index 1c0e039..0000000
--- a/include/trainer/trainer.h
+++ /dev/null
@@ -1,163 +0,0 @@
-/************************************************************
-*
-* Licensed to the Apache Software Foundation (ASF) under one
-* or more contributor license agreements. See the NOTICE file
-* distributed with this work for additional information
-* regarding copyright ownership. The ASF licenses this file
-* to you under the Apache License, Version 2.0 (the
-* "License"); you may not use this file except in compliance
-* with the License. You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing,
-* software distributed under the License is distributed on an
-* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-* KIND, either express or implied. See the License for the
-* specific language governing permissions and limitations
-* under the License.
-*
-*************************************************************/
-
-#ifndef SINGA_TRAINER_TRAINER_H_
-#define SINGA_TRAINER_TRAINER_H_
-
-#include <queue>
-#include <unordered_map>
-#include <vector>
-#include "communication/socket.h"
-#include "neuralnet/neuralnet.h"
-#include "proto/job.pb.h"
-#include "proto/singa.pb.h"
-#include "trainer/server.h"
-#include "trainer/worker.h"
-#include "utils/factory.h"
-#include "utils/param.h"
-#include "utils/singleton.h"
-
-namespace singa {
-
-/**
- * Every running process has a training object which launches one or more
- * worker (and server) threads.
- *
- * The main thread runs a loop to forward messages between workers and servers.
- */
-class Trainer{
- public:
- ~Trainer();
- /**
- * Entrance function which construct the workers and servers, and luanch
- * one thread per worker/server.
- *
- * @param resume if true resume the training from the latest checkpoint files
- * @param singaConf global singa configuration including zookeeper and
- * @param jobConf job configuration, including cluster and model configuration
- */
- void Start(bool resume, const SingaProto& singaConf, JobProto* jobConf);
-
- protected:
- /**
- * Setting the checkpoint field of model configuration to resume training.
- *
- * The checkpoint folder will be searched to get the files for the latest
- * checkpoint, which will be added into the checkpoint field. The workers
- * would then load the values of params from the checkpoint files.
- *
- * @param jobConf job configuration
- */
- void Resume(JobProto* jobConf);
- /**
- * Create server instances.
- * @param nthread total num of threads in current procs which is used to
- * assign each thread a local thread ID. The number of workers is extracted
- * from Cluster
- * @param jobConf
- * @return server instances
- */
- std::vector<Server*> CreateServers(const JobProto& jobConf);
- /**
- * Create workers instances.
- * @param nthread total num of threads in current procs which is used to
- * assign each thread a local thread ID. The number of workers is extracted
- * from Cluster
- * @param jobConf
- * @return worker instances
- */
- std::vector<Worker*> CreateWorkers(const JobProto& jobConf);
- /**
- * Setup workers and servers.
- *
- * For each worker, create and assign a neuralnet to it.
- * For each server, create and assign the param shard to it.
- * Create the partition map from slice ID to server
- * @param modelConf
- * @param workers
- * @param servers
- */
- void SetupWorkerServer(const JobProto& jobConf,
- const std::vector<Worker*>& workers,
- const std::vector<Server*>& servers);
- void Run(const std::vector<Worker*>& workers,
- const std::vector<Server*>& servers);
- /**
- * Display metrics to log (standard output)
- */
- void DisplayMetric(Msg** msg);
- /**
- * Create a socket to send msg to the specified process
- * @param dst_procs the dst process (logical) ID
- * @return the newly created socket
- */
- Dealer* CreateInterProcsDealer(int dst_procs);
- /**
- * Handle messages to local servers and local stub
- */
- void HandleLocalMsg(std::queue<Msg*>* msg_queue, Msg** msg);
- /**
- * Generate a request message to Get the parameter object.
- */
- const std::vector<Msg*> HandleGet(ParamEntry* entry, Msg** msg);
- void HandleGetResponse(ParamEntry* entry, Msg** msg);
- /**
- * Generate a request message to Update the parameter object.
- */
- const std::vector<Msg*> HandleUpdate(ParamEntry* entry, Msg** msg);
- void HandleUpdateResponse(ParamEntry* entry, Msg** msg);
- /**
- * Generate a request message to Put the parameter object.
- */
- const std::vector<Msg*> HandlePut(ParamEntry* entry, Msg** msg);
- /**
- * Called by HandlePut, HandleUpdate and HandleGet functions
- * @param type message type
- * @param version param version
- * @param entry
- * @param msg
- * @param ret generated messages
- */
- void GenMsgs(int type, int version, ParamEntry* entry,
- Msg* msg, std::vector<Msg*> *ret);
- /**
- * Get a hash id for a Param object from a group.
- *
- * Simple multiple group_id with a large prime number 997 (assuming there are
- * no more than 997 worker groups) and plus owner param id.
- */
- inline int Hash(int grp_id, int param_id) {
- return grp_id * 997 + param_id;
- }
-
- protected:
- int procs_id_ = -1;
- Router *router_ = nullptr;
- std::unordered_map<int, ParamEntry*> worker_shard_;
- //!< map from slice to the server that updates it
- std::vector<int> slice2server_;
- // a buffer of created nets, will destroy them all in destructor
- std::vector<NeuralNet*> nets_;
-};
-
-} // namespace singa
-
-#endif // SINGA_TRAINER_TRAINER_H_
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/321ef96a/include/trainer/worker.h
----------------------------------------------------------------------
diff --git a/include/trainer/worker.h b/include/trainer/worker.h
deleted file mode 100644
index 66439ec..0000000
--- a/include/trainer/worker.h
+++ /dev/null
@@ -1,258 +0,0 @@
-/************************************************************
-*
-* Licensed to the Apache Software Foundation (ASF) under one
-* or more contributor license agreements. See the NOTICE file
-* distributed with this work for additional information
-* regarding copyright ownership. The ASF licenses this file
-* to you under the Apache License, Version 2.0 (the
-* "License"); you may not use this file except in compliance
-* with the License. You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing,
-* software distributed under the License is distributed on an
-* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-* KIND, either express or implied. See the License for the
-* specific language governing permissions and limitations
-* under the License.
-*
-*************************************************************/
-
-#ifndef SINGA_TRAINER_WORKER_H_
-#define SINGA_TRAINER_WORKER_H_
-
-#include <string>
-#include "communication/socket.h"
-#include "neuralnet/neuralnet.h"
-#include "proto/job.pb.h"
-
-namespace singa {
-
-//!< sleep 5 milliseconds if the Param is not updated to the expected version
-const int kCollectSleepTime = 5;
-/**
- * The Worker class which runs the training algorithm.
- * The first worker group will initialize parameters of the Net,
- * and put them into the distributed memory/table.
- * The virtual function TrainOneBatch and TestOneBatch implement the
- * training and test algorithm for one mini-batch data.
- *
- * Child workers override the two functions to implement their training
- * algorithms, e.g., the BPWorker/CDWorker/BPTTWorker implements the BP/CD/BPTT
- * algorithm respectively.
- */
-class Worker {
- public:
- static Worker* Create(const JobProto& proto);
- /**
- * @param thread_id local thread index within the procs
- * @param grp_id global worker group ID
- * @param id worker ID within the group
- */
- virtual void Init(int grp_id, int id);
- virtual ~Worker();
- /**
- * Setup members
- */
- void Setup(const JobProto& job, NeuralNet* train_net, NeuralNet* valid_net,
- NeuralNet* test_net);
- /**
- * Init all local params (i.e., params from layers resident in this worker).
- *
- * If the param is owned by the worker, then init it and put it to servers.
- * Otherwise call Get() to get the param. The Get may not send get request.
- * Because the param's own is in the same procs. Once the owner initializes
- * the param, its version is visiable to all shares.
- * If the training starts from scrath, the params are initialzed using random
- * distributions, e.g., Gaussian distribution. After that, the worker may
- * train for a couple of steps to warmup the params before put
- * them to servers (warmup of JobProto controls this).
- *
- * If the owner param is available from checkpoint file, then its
- * values are parsed from the checkpoint file instead of randomly initialized.
- * For params who do not have checkpoints, randomly init them.
- */
- void InitLocalParams();
- /**
- * Main function of Worker.
- *
- * Train the neuralnet step by step, test/validation is done periodically.
- */
- void Run();
- /**
- * Checkpoint all params owned by the worker from the first group onto disk.
- * The serialization is done using BlobProtos which includes the name, version
- * and values of each Param.
- * Different worker would generate different checkpoint files. The file path
- * is <workspace>/checkpoint-<jobname>-step<step>-worker<worker_id>.bin
- * @param step training step of this worker
- * @param net the training net whose params will be dumped.
- */
- void Checkpoint(int step, NeuralNet* net);
- /**
- * Test the perforance of the learned model on validation or test dataset.
- * Test is done by the first group.
- * @param net, neural network
- */
- void Test(int nsteps, Phase phase, NeuralNet* net);
- /**
- * Train one mini-batch.
- * Test/Validation is done before training.
- */
- virtual void TrainOneBatch(int step, Metric* perf) = 0;
- /**
- * Test/validate one mini-batch.
- */
- virtual void TestOneBatch(int step, Phase phase, NeuralNet* net,
- Metric* perf) = 0;
- /**
- * Report performance to the stub.
- *
- * @param prefix display prefix, e.g., 'Train', 'Test'
- * @param perf
- */
- void Report(const std::string& prefix, const Metric & perf);
- /**
- * Put Param to server.
- * @param param
- * @param step used as current param version for the put request
- */
- int Put(Param* param, int step);
- /**
- * Get Param with specific version from server
- * If the current version >= the requested version, then return.
- * Otherwise send a get request to stub who would forwards it to servers.
- * @param param
- * @param step requested param version
- */
- int Get(Param* param, int step);
- /**
- * Update Param
- * @param param
- * @param step training step used for updating (e.g., deciding learning rate)
- */
- int Update(Param* param, int step);
- /**
- * Block until the param is updated since sending the update request
- *
- * @param param
- * @param step not used
- */
- int Collect(Param* param, int step);
- /**
- * Call Collect for every param of net
- */
- int CollectAll(NeuralNet* net, int step);
- /**
- * Receive blobs from other workers due to model partitions.
- */
- void ReceiveBlobs(bool data, bool grad, BridgeLayer* layer, NeuralNet* net);
- /**
- * Send blobs to other workers due to model partitions.
- */
- void SendBlobs(bool data, bool grad, BridgeLayer* layer, NeuralNet* net);
- /**
- * Check is it time to display training info, e.g., loss and precison.
- */
- inline bool DisplayNow(int step) const {
- return job_conf_.disp_freq() > 0
- && step >= job_conf_.disp_after()
- && ((step - job_conf_.disp_after()) % job_conf_.disp_freq() == 0);
- }
- /**
- * Check is it time to display training info, e.g., loss and precison.
- */
- inline bool DisplayDebugInfo(int step) const {
- return DisplayNow(step) && job_conf_.debug() && grp_id_ == 0;
- }
- /**
- * Check is it time to stop
- */
- inline bool StopNow(int step) const {
- return step >= job_conf_.train_steps();
- }
- /**
- * Check is it time to do checkpoint.
- */
- inline bool CheckpointNow(int step) const {
- return grp_id_ == 0
- && job_conf_.checkpoint_freq() > 0
- && step >= job_conf_.checkpoint_after()
- && ((step - job_conf_.checkpoint_after())
- % job_conf_.checkpoint_freq() == 0);
- }
- /**
- * Check is it time to do test.
- * @param step the ::Train() has been called this num times.
- */
- inline bool TestNow(int step) const {
- return grp_id_ == 0
- && job_conf_.test_freq() > 0
- && job_conf_.test_steps() > 0
- && step >= job_conf_.test_after()
- && ((step - job_conf_.test_after()) % job_conf_.test_freq() == 0);
- }
- /**
- * Check is it time to do validation.
- * @param step the ::Train() has been called step times.
- */
- inline bool ValidateNow(int step) const {
- return grp_id_ == 0
- && job_conf_.valid_freq() > 0
- && job_conf_.valid_steps() > 0
- && step >= job_conf_.valid_after()
- && ((step - job_conf_.valid_after()) % job_conf_.valid_freq() == 0);
- }
- /**
- * @return group ID
- */
- int grp_id() const { return grp_id_; }
- /**
- * @reutrn worker ID within the worker group.
- */
- int id() const { return id_; }
-
- protected:
- int grp_id_ = -1, id_ = -1;
- int step_ = 0;
- JobProto job_conf_;
- NeuralNet* train_net_ = nullptr;
- NeuralNet* test_net_ = nullptr;
- NeuralNet* validation_net_ = nullptr;
- Dealer* layer_dealer_ = nullptr;
- Dealer* dealer_ = nullptr;
-};
-
-class BPWorker: public Worker {
- public:
- void TrainOneBatch(int step, Metric* perf) override;
- void TestOneBatch(int step, Phase phase, NeuralNet* net, Metric* perf)
- override;
- void Forward(int step, Phase phase, NeuralNet* net, Metric* perf);
- void Backward(int step, NeuralNet* net);
-};
-
-class CDWorker: public Worker {
- public:
- void TrainOneBatch(int step, Metric* perf) override;
- void TestOneBatch(int step, Phase phase, NeuralNet* net, Metric* perf)
- override;
-};
-
-inline int BlobTrgt(int grp, int layer) {
- return (grp << 16) | layer;
-}
-
-inline int BlobGrp(int blob_trgt) {
- return blob_trgt >> 16;
-}
-
-inline int BlobLayer(int blob_trgt) {
- static int mask = (1 << 16) -1;
- return blob_trgt & mask;
-}
-
-} // namespace singa
-
-#endif // SINGA_TRAINER_WORKER_H_
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/321ef96a/include/utils/param.h
----------------------------------------------------------------------
diff --git a/include/utils/param.h b/include/utils/param.h
index e6c8c7c..f690438 100644
--- a/include/utils/param.h
+++ b/include/utils/param.h
@@ -7,9 +7,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
-*
+*
* http://www.apache.org/licenses/LICENSE-2.0
-*
+*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -25,12 +25,13 @@
#include <memory>
#include <string>
#include <vector>
-#include "communication/msg.h"
+
+#include "comm/msg.h"
#include "proto/job.pb.h"
#include "utils/blob.h"
namespace singa {
-
+using std::vector;
/**
* Base parameter generator which initializes parameter values.
*/
@@ -92,7 +93,34 @@ class UniformSqrtFanInOutGen : public UniformGen {
*/
class Param {
public:
- static Param* Create(const ParamProto& proto);
+ /**
+ * Create an instance of (sub) Param class based on the type from the
+ * configuration.
+ *
+ * @param[in] conf configuration
+ * @return a pointer to an instance
+ */
+ static Param* Create(const ParamProto& conf);
+
+ /**
+ * Try to slice the Param objects (from a neural net) into a given number of
+ * servers (groups) evenly. This is to achieve load-balance among servers.
+ *
+ * It does not change the Param objects, but just computes the length of each
+ * slice.
+ *
+ * @param num number of servers (groups) for maintaining the Param objects.
+ * @param params all Param objects from a neural net.
+ * @return the length of each slice.
+ */
+ static const vector<int> ComputeSlices(int num, const vector<Param*>& params);
+ /**
+ * It computes the length of each slice and slices the Param objects by adding
+ * the slicing information into every Param object.
+ *
+ * @copydetails ComputeSlices()
+ */
+ static void SliceParams(int num, const vector<Param*>& params);
Param() {}
virtual ~Param() {}
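The slicing described for ComputeSlices()/SliceParams() aims at evenly distributing the total Param volume over num servers without touching the Param objects themselves. A sketch of that even-split idea (an illustration of the stated goal, not the exact algorithm of this patch; it works on raw sizes rather than Param*):

#include <vector>

// Split the concatenation of all Param values into contiguous slices of at
// most ceil(total / num) elements, so num servers get roughly equal load.
std::vector<int> ComputeSlicesSketch(int num, const std::vector<int>& param_sizes) {
  int total = 0;
  for (int s : param_sizes) total += s;
  int avg = total / num + (total % num != 0);  // ceiling of total / num
  std::vector<int> slices;
  for (int s : param_sizes) {
    while (s > avg) {                  // a large Param spans several slices
      slices.push_back(avg);
      s -= avg;
    }
    if (s > 0) slices.push_back(s);    // remainder, or a small Param
  }
  return slices;
}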
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/321ef96a/include/worker.h
----------------------------------------------------------------------
diff --git a/include/worker.h b/include/worker.h
new file mode 100644
index 0000000..58f02c4
--- /dev/null
+++ b/include/worker.h
@@ -0,0 +1,311 @@
+/************************************************************
+*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements. See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership. The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied. See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*
+*************************************************************/
+
+#ifndef SINGA_WORKER_H_
+#define SINGA_WORKER_H_
+
+#include <string>
+#include <vector>
+#include "comm/socket.h"
+#include "neuralnet/neuralnet.h"
+#include "proto/job.pb.h"
+
+namespace singa {
+
+//!< sleep 5 milliseconds if the Param is not updated to the expected version
+const int kCollectSleepTime = 5;
+/**
+ * The Worker class which runs the training algorithm.
+ * The first worker group will initialize parameters of the Net,
+ * and put them into the distributed memory/table.
+ * The virtual function TrainOneBatch and TestOneBatch implement the
+ * training and test algorithm for one mini-batch data.
+ *
+ * Child workers override the two functions to implement their training
+ * algorithms, e.g., BPWorker/CDWorker/BPTTWorker implement the BP/CD/BPTT
+ * algorithms, respectively.
+ */
+class Worker {
+ public:
+ /**
+ * Create an instance of the subclass of Worker.
+ *
+ * @param[in] conf configuration of the TrainOneBatch algorithm. Different
+ * Worker subclasses implement different algorithms. Hence the creation is
+ * based on the TrainOneBatch algorithm type. Currently SINGA
+ * provides two algorithms:
+ * -# Back-propagation for the feed-forward models, e.g., CNN and MLP, and the
+ * recurrent neural networks.
+ * -# Contrastive divergence for the energy models, e.g., RBM.
+ *
+ * @return a pointer to the instance of the Worker subclass.
+ */
+ static Worker* Create(const AlgProto& conf);
+ virtual ~Worker();
+ /**
+ * @param[in] grp_id global worker group ID
+ * @param[in] id worker ID within the group
+ * @param[in] conf job configuration
+ * @param[in] train_net pointer to the training neural net, which could be
+ * shared with other workers from the same group. Different workers run over
+ * different subsets of layers.
+ * @param[in] val_net pointer to the validation neural net. Currently only the
+ * first worker from the first group would have validation neural net. All
+ * other workers receive nullptr for this argument.
+ * @param[in] test_net pointer to the test neural net. Currently only the
+ * first worker from the first group would have test neural net. All other
+ * workers receive nullptr for this argument.
+ */
+ virtual void Setup(int grp_id, int id, const JobProto& conf,
+ NeuralNet* train_net, NeuralNet* val_net, NeuralNet* test_net);
+
+ /**
+ * Main function of Worker.
+ *
+ * Train the neuralnet step by step, test/validation is done periodically.
+ */
+ void Run();
+
+ /**
+ * Init values of Param instances associated with local layers (i.e., layers
+ * dispatched to this worker).
+ *
+ * If one Param is owned by the worker, then it should be initialized and put
+ * to servers. Otherwise Get() should be called to get the Param. The Get()
+ * may not send get requests if the Param owner is in the same procs, for
+ * which case the memory space of the Param objects are shared. But if this
+ * worker and the Param owner worker run on different devices (e.g., GPUs),
+ * then the get request would be sent.
+ *
+ * If the training starts from scratch, every Param object is initialized using
+ * ParamGenerator. After that, the worker may
+ * train for a couple of steps to warm up the params before putting them to
+ * servers (the warmup field of JobProto controls this).
+ *
+ * If one Param object's name matches that of one Param object from the
+ * checkpoint files, its Param values would be loaded from checkpoint files.
+ *
+ * @param[in] job_conf job configuration which provides settings for
+ * checkpoint file paths, warmup steps and Param versions.
+ * @param[out] net pointer to a neural net whose Param values will be
+ * initialized.
+ */
+ void InitNetParams(const JobProto& job_conf, NeuralNet* net);
+
+ /**
+ * Checkpoint all Param objects owned by the worker onto disk.
+ * The serialization is done using BlobProtos which includes the name, version
+ * and values of each Param object.
+ * Different workers would generate different checkpoint files. The file path
+ * is <workspace>/checkpoint-<jobname>-step<step>-worker<worker_id>.bin
+ * @param[in] step training step
+ * @param[in] folder directory to put the checkpoint file
+ * @param net the training net whose Param objects will be dumped.
+ */
+ void Checkpoint(int step, const std::string& folder, NeuralNet* net);
+
+ /**
+ * Train one mini-batch.
+ * Test/Validation is done before training.
+ *
+ * @param[in] step training step.
+ * @param[in] net neural net to be trained.
+ */
+ virtual void TrainOneBatch(int step, NeuralNet* net) = 0;
+
+ /**
+ * Test/validate one mini-batch data.
+ *
+ * @param[in] step test step.
+ * @param[in] phase test could be done for validation or test phase.
+ * @param[in] net neural net for test
+ */
+ virtual void TestOneBatch(int step, Phase phase, NeuralNet* net) = 0;
+
+ /**
+ * Display information from layers.
+ *
+ * @param flag could be a combination of multiple phases, e.g., kTest|kForward;
+ * it is passed to the Layer::ToString() function of each layer to decide
+ * what to display.
+ * @param prefix display prefix, e.g., 'Train step 100', 'Test step 90'.
+ * @param net display layers from this neural net.
+ */
+ void Display(int flag, const std::string& prefix, NeuralNet* net);
+
+ /**
+ * Put Param values to server.
+ *
+ * @param param
+ * @param step used as current param version for the put request
+ */
+ int Put(int step, Param* param);
+
+ /**
+ * Get Param with specific version from server
+ * If the current version >= the requested version, then return.
+ * Otherwise send a get request to the stub, which forwards it to servers.
+ * @param param
+ * @param step requested param version
+ */
+ int Get(int step, Param* param);
+
+ /**
+ * Update Param.
+ *
+ * @param param
+ * @param step training step used for updating (e.g., deciding learning rate).
+ */
+ int Update(int step, Param* param);
+
+ /**
+ * Wait for the response of the update/get requests.
+ *
+ * @param param
+ * @param step not used now.
+ */
+ int Collect(int step, Param* param);
+
+ /**
+ * Call Collect() for every param of net
+ */
+ int CollectAll(int step, NeuralNet* net);
+
+ /**
+ * Receive blobs from other workers due to model partitions.
+ */
+ void ReceiveBlobs(bool data, bool grad, BridgeLayer* layer, NeuralNet* net);
+
+ /**
+ * Send blobs to other workers due to model partitions.
+ */
+ void SendBlobs(bool data, bool grad, BridgeLayer* layer, NeuralNet* net);
+
+
+ /**
+ * @param[in] step
+ * @return true if it is time to display training info, e.g., loss; otherwise
+ * false.
+ */
+ inline bool DisplayNow(int step) const {
+ return job_conf_.disp_freq() > 0
+ && step >= job_conf_.disp_after()
+ && ((step - job_conf_.disp_after()) % job_conf_.disp_freq() == 0);
+ }
+ /**
+ * @param[in] step
+ * @return true if it is time to finish the training; otherwise false.
+ */
+ inline bool StopNow(int step) const {
+ return step >= job_conf_.train_steps();
+ }
+ /**
+ * @param[in] step
+ * @return true if it is time to checkpoint Param objects; otherwise false.
+ */
+ inline bool CheckpointNow(int step) const {
+ return job_conf_.checkpoint_freq() > 0
+ && step >= job_conf_.checkpoint_after()
+ && ((step - job_conf_.checkpoint_after())
+ % job_conf_.checkpoint_freq() == 0);
+ }
+ /**
+ * @param[in] step
+ * @return true if it is time to run a test over the test dataset.
+ */
+ inline bool TestNow(int step) const {
+ return job_conf_.test_freq() > 0
+ && job_conf_.test_steps() > 0
+ && step >= job_conf_.test_after()
+ && ((step - job_conf_.test_after()) % job_conf_.test_freq() == 0);
+ }
+ /**
+ * @param[in] step
+ * @return true if it is time to run validation over the validation dataset.
+ */
+ inline bool ValidateNow(int step) const {
+ return job_conf_.validate_freq() > 0
+ && job_conf_.validate_steps() > 0
+ && step >= job_conf_.validate_after()
+ && ((step - job_conf_.validate_after()) % job_conf_.validate_freq() == 0);
+ }
+ /**
+ * @return a vector with pointers to all neural nets.
+ */
+ const std::vector<NeuralNet*> GetNets() const {
+ return std::vector<NeuralNet*> {train_net_, val_net_, test_net_};
+ }
+ /**
+ * @return training net.
+ */
+ inline NeuralNet* train_net() const {
+ return train_net_;
+ }
+ /**
+ * @return group ID
+ */
+ inline int grp_id() const { return grp_id_; }
+ /**
+ * @return worker ID within the worker group.
+ */
+ inline int id() const { return id_; }
+
+ protected:
+ int grp_id_ = -1, id_ = -1;
+ int step_ = 0;
+ JobProto job_conf_;
+ NeuralNet* train_net_ = nullptr;
+ NeuralNet* test_net_ = nullptr;
+ NeuralNet* val_net_ = nullptr;
+ Dealer* layer_dealer_ = nullptr;
+ Dealer* dealer_ = nullptr;
+};
+
+class BPWorker: public Worker {
+ public:
+ void TrainOneBatch(int step, NeuralNet* net) override;
+ void TestOneBatch(int step, Phase phase, NeuralNet* net) override;
+ void Forward(int step, Phase phase, NeuralNet* net);
+ void Backward(int step, NeuralNet* net);
+};
+
+class CDWorker: public Worker {
+ public:
+ void TrainOneBatch(int step, NeuralNet* net) override;
+ void TestOneBatch(int step, Phase phase, NeuralNet* net) override;
+};
+
+inline int BlobTrgt(int grp, int layer) {
+ return (grp << 16) | layer;
+}
+
+inline int BlobGrp(int blob_trgt) {
+ return blob_trgt >> 16;
+}
+
+inline int BlobLayer(int blob_trgt) {
+ static int mask = (1 << 16) -1;
+ return blob_trgt & mask;
+}
+
+} // namespace singa
+
+#endif // SINGA_WORKER_H_
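BlobTrgt()/BlobGrp()/BlobLayer() above pack a worker group ID and a layer ID into one int, 16 bits each, which caps both IDs at 65535. A quick round-trip check of that packing:

#include <cassert>

int main() {
  int grp = 3, layer = 42;
  int trgt = (grp << 16) | layer;    // BlobTrgt(grp, layer)
  assert((trgt >> 16) == grp);       // BlobGrp(trgt) recovers the group
  assert((trgt & 0xffff) == layer);  // BlobLayer(trgt) recovers the layer
  return 0;
}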
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/321ef96a/src/comm/msg.cc
----------------------------------------------------------------------
diff --git a/src/comm/msg.cc b/src/comm/msg.cc
new file mode 100644
index 0000000..2521c28
--- /dev/null
+++ b/src/comm/msg.cc
@@ -0,0 +1,215 @@
+/************************************************************
+*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements. See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership. The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied. See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*
+*************************************************************/
+
+#include "comm/msg.h"
+
+#include <glog/logging.h>
+
+namespace singa {
+
+#ifdef USE_ZMQ
+Msg::~Msg() {
+ if (msg_ != nullptr)
+ zmsg_destroy(&msg_);
+ frame_ = nullptr;
+}
+
+Msg::Msg() {
+ msg_ = zmsg_new();
+}
+
+Msg::Msg(const Msg& msg) {
+ src_ = msg.src_;
+ dst_ = msg.dst_;
+ type_ = msg.type_;
+ trgt_val_ = msg.trgt_val_;
+ trgt_version_ = msg.trgt_version_;
+ msg_ = zmsg_dup(msg.msg_);
+}
+
+Msg::Msg(int src, int dst) {
+ src_ = src;
+ dst_ = dst;
+ msg_ = zmsg_new();
+}
+
+void Msg::SwapAddr() {
+ std::swap(src_, dst_);
+}
+
+int Msg::size() const {
+ return zmsg_content_size(msg_);
+}
+
+void Msg::AddFrame(const void* addr, int nBytes) {
+ zmsg_addmem(msg_, addr, nBytes);
+}
+
+int Msg::FrameSize() {
+ return zframe_size(frame_);
+}
+
+void* Msg::FrameData() {
+ return zframe_data(frame_);
+}
+
+char* Msg::FrameStr() {
+ return zframe_strdup(frame_);
+}
+bool Msg::NextFrame() {
+ frame_ = zmsg_next(msg_);
+ return frame_ != nullptr;
+}
+
+void Msg::FirstFrame() {
+ frame_ = zmsg_first(msg_);
+}
+
+void Msg::LastFrame() {
+ frame_ = zmsg_last(msg_);
+}
+
+void Msg::ParseFromZmsg(zmsg_t* msg) {
+ char* tmp = zmsg_popstr(msg);
+ sscanf(tmp, "%d %d %d %d %d",
+ &src_, &dst_, &type_, &trgt_val_, &trgt_version_);
+ frame_ = zmsg_first(msg);
+ msg_ = msg;
+}
+
+zmsg_t* Msg::DumpToZmsg() {
+ zmsg_pushstrf(msg_, "%d %d %d %d %d",
+ src_, dst_, type_, trgt_val_, trgt_version_);
+ zmsg_t *tmp = msg_;
+ msg_ = nullptr;
+ return tmp;
+}
+
+// frame marker indicating this frame is serialized printf-style
+#define FMARKER "*singa*"
+
+#define kMaxFrameLen 2048
+
+int Msg::AddFormatFrame(const char *format, ...) {
+ va_list argptr;
+ va_start(argptr, format);
+ int size = strlen(FMARKER);
+ char dst[kMaxFrameLen];
+ memcpy(dst, FMARKER, size);
+ dst[size++] = 0;
+ while (*format) {
+ if (*format == 'i') {
+ int x = va_arg(argptr, int);
+ dst[size++] = 'i';
+ memcpy(dst + size, &x, sizeof(x));
+ size += sizeof(x);
+ } else if (*format == 'f') {
+ float x = static_cast<float> (va_arg(argptr, double));
+ dst[size++] = 'f';
+ memcpy(dst + size, &x, sizeof(x));
+ size += sizeof(x);
+ } else if (*format == '1') {
+ uint8_t x = va_arg(argptr, int);
+ memcpy(dst + size, &x, sizeof(x));
+ size += sizeof(x);
+ } else if (*format == '2') {
+ uint16_t x = va_arg(argptr, int);
+ memcpy(dst + size, &x, sizeof(x));
+ size += sizeof(x);
+ } else if (*format == '4') {
+ uint32_t x = va_arg(argptr, uint32_t);
+ memcpy(dst + size, &x, sizeof(x));
+ size += sizeof(x);
+ } else if (*format == 's') {
+ char* x = va_arg(argptr, char *);
+ dst[size++] = 's';
+ memcpy(dst + size, x, strlen(x));
+ size += strlen(x);
+ dst[size++] = 0;
+ } else if (*format == 'p') {
+ void* x = va_arg(argptr, void *);
+ dst[size++] = 'p';
+ memcpy(dst + size, &x, sizeof(x));
+ size += sizeof(x);
+ } else {
+ LOG(ERROR) << "Unknown format " << *format;
+ }
+ format++;
+ CHECK_LE(size, kMaxFrameLen);
+ }
+ va_end(argptr);
+ zmsg_addmem(msg_, dst, size);
+ return size;
+}
+
+int Msg::ParseFormatFrame(const char *format, ...) {
+ va_list argptr;
+ va_start(argptr, format);
+ char* src = zframe_strdup(frame_);
+ CHECK_STREQ(FMARKER, src);
+ int size = strlen(FMARKER) + 1;
+ while (*format) {
+ if (*format == 'i') {
+ int *x = va_arg(argptr, int *);
+ CHECK_EQ(src[size++], 'i');
+ memcpy(x, src + size, sizeof(*x));
+ size += sizeof(*x);
+ } else if (*format == 'f') {
+ float *x = va_arg(argptr, float *);
+ CHECK_EQ(src[size++], 'f');
+ memcpy(x, src + size, sizeof(*x));
+ size += sizeof(*x);
+ } else if (*format == '1') {
+ uint8_t *x = va_arg(argptr, uint8_t *);
+ memcpy(x, src + size, sizeof(*x));
+ size += sizeof(*x);
+ } else if (*format == '2') {
+ uint16_t *x = va_arg(argptr, uint16_t *);
+ memcpy(x, src + size, sizeof(*x));
+ size += sizeof(*x);
+ } else if (*format == '4') {
+ uint32_t *x = va_arg(argptr, uint32_t *);
+ memcpy(x, src + size, sizeof(*x));
+ size += sizeof(*x);
+ } else if (*format == 's') {
+ char* x = va_arg(argptr, char *);
+ CHECK_EQ(src[size++], 's');
+ int len = strlen(src + size);
+ memcpy(x, src + size, len);
+ x[len] = 0;
+ size += len + 1;
+ } else if (*format == 'p') {
+ void** x = va_arg(argptr, void **);
+ CHECK_EQ(src[size++], 'p');
+ memcpy(x, src + size, sizeof(*x));
+ size += sizeof(*x);
+ } else {
+ LOG(ERROR) << "Unknown format type " << *format;
+ }
+ format++;
+ }
+ va_end(argptr);
+ free(src); // zframe_strdup() allocates with malloc(), so release with free()
+ return size;
+}
+#endif
+
+} // namespace singa
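AddFormatFrame()/ParseFormatFrame() above implement a printf-style tagged serialization: the frame starts with the "*singa*" marker, and each 'i'/'f'/'s'/'p' value is preceded by its tag byte (the '1'/'2'/'4' fixed-width types are written untagged). A usage sketch of the round trip (assumes a USE_ZMQ build and the Msg API from comm/msg.h):

#include "comm/msg.h"

void FormatFrameRoundTrip() {
  singa::Msg msg;
  msg.AddFormatFrame("if", 7, 0.25f);  // writes marker, 'i' + int, 'f' + float
  msg.FirstFrame();                    // position frame_ on the frame just added
  int i = 0;
  float f = 0.f;
  msg.ParseFormatFrame("if", &i, &f);  // checks marker and tags, reads values
  // now i == 7 and f == 0.25f
}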
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/321ef96a/src/comm/socket.cc
----------------------------------------------------------------------
diff --git a/src/comm/socket.cc b/src/comm/socket.cc
new file mode 100644
index 0000000..b9c7810
--- /dev/null
+++ b/src/comm/socket.cc
@@ -0,0 +1,180 @@
+/************************************************************
+*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements. See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership. The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied. See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*
+*************************************************************/
+#include "comm/socket.h"
+
+#include <glog/logging.h>
+
+namespace singa {
+
+#ifdef USE_ZMQ
+Poller::Poller() {
+ poller_ = zpoller_new(nullptr);
+}
+
+Poller::Poller(SocketInterface* socket) {
+ poller_ = zpoller_new(nullptr);
+ Add(socket);
+}
+
+void Poller::Add(SocketInterface* socket) {
+ zsock_t* zsock = static_cast<zsock_t*>(socket->InternalID());
+ zpoller_add(poller_, zsock);
+ zsock2Socket_[zsock] = socket;
+}
+
+SocketInterface* Poller::Wait(int timeout) {
+ zsock_t* sock = static_cast<zsock_t*>(zpoller_wait(poller_, timeout));
+ if (sock != nullptr)
+ return zsock2Socket_[sock];
+ else
+ return nullptr;
+}
+
+bool Poller::Terminated() {
+ return zpoller_terminated(poller_);
+}
+
+
+Dealer::Dealer() : Dealer(-1) {}
+
+Dealer::Dealer(int id) : id_(id) {
+ dealer_ = zsock_new(ZMQ_DEALER);
+ CHECK_NOTNULL(dealer_);
+}
+
+Dealer::~Dealer() {
+ zsock_destroy(&dealer_);
+}
+
+int Dealer::Connect(const std::string& endpoint) {
+ CHECK_GT(endpoint.length(), 0);
+ if (endpoint.length()) {
+ CHECK_EQ(zsock_connect(dealer_, "%s", endpoint.c_str()), 0);
+ return 1;
+ }
+ return 0;
+}
+
+int Dealer::Send(Msg** msg) {
+ zmsg_t* zmsg = (*msg)->DumpToZmsg();
+ zmsg_send(&zmsg, dealer_);
+ delete *msg;
+ *msg = nullptr;
+ return 1;
+}
+
+Msg* Dealer::Receive() {
+ zmsg_t* zmsg = zmsg_recv(dealer_);
+ if (zmsg == nullptr)
+ return nullptr;
+ Msg* msg = new Msg();
+ msg->ParseFromZmsg(zmsg);
+ return msg;
+}
+
+void* Dealer::InternalID() const {
+ return dealer_;
+}
+
+Router::Router() : Router(100) {}
+
+Router::Router(int bufsize) {
+ nBufmsg_ = 0;
+ bufsize_ = bufsize;
+ router_ = zsock_new(ZMQ_ROUTER);
+ CHECK_NOTNULL(router_);
+ poller_ = zpoller_new(router_);
+ CHECK_NOTNULL(poller_);
+}
+
+Router::~Router() {
+ zsock_destroy(&router_);
+ for (auto it : id2addr_)
+ zframe_destroy(&it.second);
+ for (auto it : bufmsg_) {
+ for (auto *msg : it.second)
+ zmsg_destroy(&msg);
+ }
+}
+int Router::Bind(const std::string& endpoint) {
+ int port = -1;
+ if (endpoint.length()) {
+ port = zsock_bind(router_, "%s", endpoint.c_str());
+ }
+ CHECK_NE(port, -1) << endpoint;
+  LOG(INFO) << "successfully bound to " << endpoint + ":" + std::to_string(port);
+ return port;
+}
+
+int Router::Send(Msg **msg) {
+ zmsg_t* zmsg = (*msg)->DumpToZmsg();
+ int dstid = (*msg)->dst();
+ if (id2addr_.find(dstid) != id2addr_.end()) {
+ // the connection has already been set up
+ zframe_t* addr = zframe_dup(id2addr_[dstid]);
+ zmsg_prepend(zmsg, &addr);
+ zmsg_send(&zmsg, router_);
+ } else {
+ // the connection is not ready, buffer the message
+ if (bufmsg_.size() == 0)
+ nBufmsg_ = 0;
+ bufmsg_[dstid].push_back(zmsg);
+ ++nBufmsg_;
+ CHECK_LE(nBufmsg_, bufsize_);
+ }
+ delete *msg;
+ *msg = nullptr;
+ return 1;
+}
+
+Msg* Router::Receive() {
+ zmsg_t* zmsg = zmsg_recv(router_);
+ if (zmsg == nullptr) {
+ LOG(ERROR) << "Connection broken!";
+ exit(0);
+ }
+ zframe_t* dealer = zmsg_pop(zmsg);
+ Msg* msg = new Msg();
+ msg->ParseFromZmsg(zmsg);
+ if (id2addr_.find(msg->src()) == id2addr_.end()) {
+    // new connection: store the sender's identity frame and deliver any
+    // messages buffered for it
+ id2addr_[msg->src()] = dealer;
+ if (bufmsg_.find(msg->src()) != bufmsg_.end()) {
+ for (auto& it : bufmsg_.at(msg->src())) {
+ zframe_t* addr = zframe_dup(dealer);
+ zmsg_prepend(it, &addr);
+ zmsg_send(&it, router_);
+ }
+ bufmsg_.erase(msg->src());
+ }
+ } else {
+ zframe_destroy(&dealer);
+ }
+ return msg;
+}
+
+void* Router::InternalID() const {
+ return router_;
+}
+#endif
+
+} // namespace singa
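
Dealer and Router above wrap ZMQ_DEALER/ZMQ_ROUTER sockets; the Router
buffers outgoing messages for destinations whose zmq identity it has not yet
seen, and learns identities from incoming messages. A single-process sketch
of the handshake (endpoint and ids are illustrative, Receive() blocks, and
Msg::SwapAddr is assumed retained from the removed src/communication/msg.cc):

    singa::Router router;
    router.Bind("tcp://*:5555");            // returns the bound port for tcp

    singa::Dealer dealer(/*id=*/1);
    dealer.Connect("tcp://localhost:5555");

    // the first message from the dealer also registers its identity
    // with the router (see Router::Receive above)
    singa::Msg* ping = new singa::Msg(/*src=*/1, /*dst=*/0);
    dealer.Send(&ping);                     // Send() deletes and nulls *msg

    singa::Msg* req = router.Receive();
    req->SwapAddr();                        // reply to the sender (dst = 1)
    router.Send(&req);

    singa::Msg* reply = dealer.Receive();
    delete reply;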
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/321ef96a/src/communication/msg.cc
----------------------------------------------------------------------
diff --git a/src/communication/msg.cc b/src/communication/msg.cc
deleted file mode 100644
index 6042057..0000000
--- a/src/communication/msg.cc
+++ /dev/null
@@ -1,215 +0,0 @@
-/************************************************************
-*
-* Licensed to the Apache Software Foundation (ASF) under one
-* or more contributor license agreements. See the NOTICE file
-* distributed with this work for additional information
-* regarding copyright ownership. The ASF licenses this file
-* to you under the Apache License, Version 2.0 (the
-* "License"); you may not use this file except in compliance
-* with the License. You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing,
-* software distributed under the License is distributed on an
-* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-* KIND, either express or implied. See the License for the
-* specific language governing permissions and limitations
-* under the License.
-*
-*************************************************************/
-
-#include "communication/msg.h"
-
-#include <glog/logging.h>
-
-namespace singa {
-
-#ifdef USE_ZMQ
-Msg::~Msg() {
- if (msg_ != nullptr)
- zmsg_destroy(&msg_);
- frame_ = nullptr;
-}
-
-Msg::Msg() {
- msg_ = zmsg_new();
-}
-
-Msg::Msg(const Msg& msg) {
- src_ = msg.src_;
- dst_ = msg.dst_;
- type_ = msg.type_;
- trgt_val_ = msg.trgt_val_;
- trgt_version_ = msg.trgt_version_;
- msg_ = zmsg_dup(msg.msg_);
-}
-
-Msg::Msg(int src, int dst) {
- src_ = src;
- dst_ = dst;
- msg_ = zmsg_new();
-}
-
-void Msg::SwapAddr() {
- std::swap(src_, dst_);
-}
-
-int Msg::size() const {
- return zmsg_content_size(msg_);
-}
-
-void Msg::AddFrame(const void* addr, int nBytes) {
- zmsg_addmem(msg_, addr, nBytes);
-}
-
-int Msg::FrameSize() {
- return zframe_size(frame_);
-}
-
-void* Msg::FrameData() {
- return zframe_data(frame_);
-}
-
-char* Msg::FrameStr() {
- return zframe_strdup(frame_);
-}
-bool Msg::NextFrame() {
- frame_ = zmsg_next(msg_);
- return frame_ != nullptr;
-}
-
-void Msg::FirstFrame() {
- frame_ = zmsg_first(msg_);
-}
-
-void Msg::LastFrame() {
- frame_ = zmsg_last(msg_);
-}
-
-void Msg::ParseFromZmsg(zmsg_t* msg) {
- char* tmp = zmsg_popstr(msg);
- sscanf(tmp, "%d %d %d %d %d",
- &src_, &dst_, &type_, &trgt_val_, &trgt_version_);
- frame_ = zmsg_first(msg);
- msg_ = msg;
-}
-
-zmsg_t* Msg::DumpToZmsg() {
- zmsg_pushstrf(msg_, "%d %d %d %d %d",
- src_, dst_, type_, trgt_val_, trgt_version_);
- zmsg_t *tmp = msg_;
- msg_ = nullptr;
- return tmp;
-}
-
-// frame marker indicating this frame is serialize like printf
-#define FMARKER "*singa*"
-
-#define kMaxFrameLen 2048
-
-int Msg::AddFormatFrame(const char *format, ...) {
- va_list argptr;
- va_start(argptr, format);
- int size = strlen(FMARKER);
- char dst[kMaxFrameLen];
- memcpy(dst, FMARKER, size);
- dst[size++] = 0;
- while (*format) {
- if (*format == 'i') {
- int x = va_arg(argptr, int);
- dst[size++] = 'i';
- memcpy(dst + size, &x, sizeof(x));
- size += sizeof(x);
- } else if (*format == 'f') {
- float x = static_cast<float> (va_arg(argptr, double));
- dst[size++] = 'f';
- memcpy(dst + size, &x, sizeof(x));
- size += sizeof(x);
- } else if (*format == '1') {
- uint8_t x = va_arg(argptr, int);
- memcpy(dst + size, &x, sizeof(x));
- size += sizeof(x);
- } else if (*format == '2') {
- uint16_t x = va_arg(argptr, int);
- memcpy(dst + size, &x, sizeof(x));
- size += sizeof(x);
- } else if (*format == '4') {
- uint32_t x = va_arg(argptr, uint32_t);
- memcpy(dst + size, &x, sizeof(x));
- size += sizeof(x);
- } else if (*format == 's') {
- char* x = va_arg(argptr, char *);
- dst[size++] = 's';
- memcpy(dst + size, x, strlen(x));
- size += strlen(x);
- dst[size++] = 0;
- } else if (*format == 'p') {
- void* x = va_arg(argptr, void *);
- dst[size++] = 'p';
- memcpy(dst + size, &x, sizeof(x));
- size += sizeof(x);
- } else {
- LOG(ERROR) << "Unknown format " << *format;
- }
- format++;
- CHECK_LE(size, kMaxFrameLen);
- }
- va_end(argptr);
- zmsg_addmem(msg_, dst, size);
- return size;
-}
-
-int Msg::ParseFormatFrame(const char *format, ...) {
- va_list argptr;
- va_start(argptr, format);
- char* src = zframe_strdup(frame_);
- CHECK_STREQ(FMARKER, src);
- int size = strlen(FMARKER) + 1;
- while (*format) {
- if (*format == 'i') {
- int *x = va_arg(argptr, int *);
- CHECK_EQ(src[size++], 'i');
- memcpy(x, src + size, sizeof(*x));
- size += sizeof(*x);
- } else if (*format == 'f') {
- float *x = va_arg(argptr, float *);
- CHECK_EQ(src[size++], 'f');
- memcpy(x, src + size, sizeof(*x));
- size += sizeof(*x);
- } else if (*format == '1') {
- uint8_t *x = va_arg(argptr, uint8_t *);
- memcpy(x, src + size, sizeof(*x));
- size += sizeof(*x);
- } else if (*format == '2') {
- uint16_t *x = va_arg(argptr, uint16_t *);
- memcpy(x, src + size, sizeof(*x));
- size += sizeof(*x);
- } else if (*format == '4') {
- uint32_t *x = va_arg(argptr, uint32_t *);
- memcpy(x, src + size, sizeof(*x));
- size += sizeof(*x);
- } else if (*format == 's') {
- char* x = va_arg(argptr, char *);
- CHECK_EQ(src[size++], 's');
- int len = strlen(src + size);
- memcpy(x, src + size, len);
- x[len] = 0;
- size += len + 1;
- } else if (*format == 'p') {
- void** x = va_arg(argptr, void **);
- CHECK_EQ(src[size++], 'p');
- memcpy(x, src + size, sizeof(*x));
- size += sizeof(*x);
- } else {
- LOG(ERROR) << "Unknown format type " << *format;
- }
- format++;
- }
- va_end(argptr);
- delete src;
- return size;
-}
-#endif
-
-} // namespace singa
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/321ef96a/src/communication/socket.cc
----------------------------------------------------------------------
diff --git a/src/communication/socket.cc b/src/communication/socket.cc
deleted file mode 100644
index 60e1cc1..0000000
--- a/src/communication/socket.cc
+++ /dev/null
@@ -1,180 +0,0 @@
-/************************************************************
-*
-* Licensed to the Apache Software Foundation (ASF) under one
-* or more contributor license agreements. See the NOTICE file
-* distributed with this work for additional information
-* regarding copyright ownership. The ASF licenses this file
-* to you under the Apache License, Version 2.0 (the
-* "License"); you may not use this file except in compliance
-* with the License. You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing,
-* software distributed under the License is distributed on an
-* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-* KIND, either express or implied. See the License for the
-* specific language governing permissions and limitations
-* under the License.
-*
-*************************************************************/
-#include "communication/socket.h"
-
-#include <glog/logging.h>
-
-namespace singa {
-
-#ifdef USE_ZMQ
-Poller::Poller() {
- poller_ = zpoller_new(nullptr);
-}
-
-Poller::Poller(SocketInterface* socket) {
- poller_ = zpoller_new(nullptr);
- Add(socket);
-}
-
-void Poller::Add(SocketInterface* socket) {
- zsock_t* zsock = static_cast<zsock_t*>(socket->InternalID());
- zpoller_add(poller_, zsock);
- zsock2Socket_[zsock] = socket;
-}
-
-SocketInterface* Poller::Wait(int timeout) {
- zsock_t* sock = static_cast<zsock_t*>(zpoller_wait(poller_, timeout));
- if (sock != nullptr)
- return zsock2Socket_[sock];
- else
- return nullptr;
-}
-
-bool Poller::Terminated() {
- return zpoller_terminated(poller_);
-}
-
-
-Dealer::Dealer() : Dealer(-1) {}
-
-Dealer::Dealer(int id) : id_(id) {
- dealer_ = zsock_new(ZMQ_DEALER);
- CHECK_NOTNULL(dealer_);
-}
-
-Dealer::~Dealer() {
- zsock_destroy(&dealer_);
-}
-
-int Dealer::Connect(const std::string& endpoint) {
- CHECK_GT(endpoint.length(), 0);
- if (endpoint.length()) {
- CHECK_EQ(zsock_connect(dealer_, "%s", endpoint.c_str()), 0);
- return 1;
- }
- return 0;
-}
-
-int Dealer::Send(Msg** msg) {
- zmsg_t* zmsg = (*msg)->DumpToZmsg();
- zmsg_send(&zmsg, dealer_);
- delete *msg;
- *msg = nullptr;
- return 1;
-}
-
-Msg* Dealer::Receive() {
- zmsg_t* zmsg = zmsg_recv(dealer_);
- if (zmsg == nullptr)
- return nullptr;
- Msg* msg = new Msg();
- msg->ParseFromZmsg(zmsg);
- return msg;
-}
-
-void* Dealer::InternalID() const {
- return dealer_;
-}
-
-Router::Router() : Router(100) {}
-
-Router::Router(int bufsize) {
- nBufmsg_ = 0;
- bufsize_ = bufsize;
- router_ = zsock_new(ZMQ_ROUTER);
- CHECK_NOTNULL(router_);
- poller_ = zpoller_new(router_);
- CHECK_NOTNULL(poller_);
-}
-
-Router::~Router() {
- zsock_destroy(&router_);
- for (auto it : id2addr_)
- zframe_destroy(&it.second);
- for (auto it : bufmsg_) {
- for (auto *msg : it.second)
- zmsg_destroy(&msg);
- }
-}
-int Router::Bind(const std::string& endpoint) {
- int port = -1;
- if (endpoint.length()) {
- port = zsock_bind(router_, "%s", endpoint.c_str());
- }
- CHECK_NE(port, -1) << endpoint;
- LOG(INFO) << "bind successfully to " << endpoint + ":" + std::to_string(port);
- return port;
-}
-
-int Router::Send(Msg **msg) {
- zmsg_t* zmsg = (*msg)->DumpToZmsg();
- int dstid = (*msg)->dst();
- if (id2addr_.find(dstid) != id2addr_.end()) {
- // the connection has already been set up
- zframe_t* addr = zframe_dup(id2addr_[dstid]);
- zmsg_prepend(zmsg, &addr);
- zmsg_send(&zmsg, router_);
- } else {
- // the connection is not ready, buffer the message
- if (bufmsg_.size() == 0)
- nBufmsg_ = 0;
- bufmsg_[dstid].push_back(zmsg);
- ++nBufmsg_;
- CHECK_LE(nBufmsg_, bufsize_);
- }
- delete *msg;
- *msg = nullptr;
- return 1;
-}
-
-Msg* Router::Receive() {
- zmsg_t* zmsg = zmsg_recv(router_);
- if (zmsg == nullptr) {
- LOG(ERROR) << "Connection broken!";
- exit(0);
- }
- zframe_t* dealer = zmsg_pop(zmsg);
- Msg* msg = new Msg();
- msg->ParseFromZmsg(zmsg);
- if (id2addr_.find(msg->src()) == id2addr_.end()) {
- // new connection, store the sender's identfier and send buffered messages
- // for it
- id2addr_[msg->src()] = dealer;
- if (bufmsg_.find(msg->src()) != bufmsg_.end()) {
- for (auto& it : bufmsg_.at(msg->src())) {
- zframe_t* addr = zframe_dup(dealer);
- zmsg_prepend(it, &addr);
- zmsg_send(&it, router_);
- }
- bufmsg_.erase(msg->src());
- }
- } else {
- zframe_destroy(&dealer);
- }
- return msg;
-}
-
-void* Router::InternalID() const {
- return router_;
-}
-#endif
-
-} // namespace singa
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/321ef96a/src/driver.cc
----------------------------------------------------------------------
diff --git a/src/driver.cc b/src/driver.cc
index 6fa70ee..d3f0f3e 100644
--- a/src/driver.cc
+++ b/src/driver.cc
@@ -19,16 +19,17 @@
*
*************************************************************/
-#include "driver.h"
-
#include <glog/logging.h>
+#include <set>
#include <string>
#include "neuralnet/layer.h"
-#include "trainer/trainer.h"
#include "utils/common.h"
#include "utils/tinydir.h"
+#include "utils/cluster.h"
+#include "./stub.h"
+#include "./driver.h"
-extern "C" void openblas_set_num_threads(int);
+extern "C" void openblas_set_num_threads(int num);
namespace singa {
@@ -109,22 +110,192 @@ void Driver::Init(int argc, char **argv) {
}
-void Driver::Submit(bool resume, const JobProto& jobConf) {
+void Driver::Train(bool resume, const JobProto& job_conf) {
+ Cluster::Setup(job_id_, singa_conf_, job_conf.cluster());
if (singa_conf_.has_log_dir())
- SetupLog(singa_conf_.log_dir(), std::to_string(job_id_)
- + "-" + jobConf.name());
+ SetupLog(singa_conf_.log_dir(),
+ std::to_string(job_id_) + "-" + job_conf.name());
tinydir_dir workspace;
- if (tinydir_open(&workspace, jobConf.cluster().workspace().c_str()) == -1)
- LOG(FATAL) << "workspace does not exist: " << jobConf.cluster().workspace();
- if (jobConf.num_openblas_threads() != 1)
- LOG(WARNING) << "openblas with "
- << jobConf.num_openblas_threads() << " threads";
- openblas_set_num_threads(jobConf.num_openblas_threads());
+ if (tinydir_open(&workspace, job_conf.cluster().workspace().c_str()) == -1)
+    LOG(FATAL) << "workspace does not exist: " << job_conf.cluster().workspace();
+ if (job_conf.num_openblas_threads() != 1)
+    LOG(WARNING) << "openblas launches "
+ << job_conf.num_openblas_threads() << " threads";
+ openblas_set_num_threads(job_conf.num_openblas_threads());
+
JobProto job;
- job.CopyFrom(jobConf);
+ job.CopyFrom(job_conf);
+ if (resume)
+ SetupForResume(&job);
job.set_id(job_id_);
- Trainer trainer;
- trainer.Start(resume, singa_conf_, &job);
+ Train(job);
}
+void Driver::Train(const JobProto& job_conf) {
+ auto cluster = Cluster::Get();
+ int nserver_grps = cluster->nserver_groups();
+ int grp_size = cluster->nworkers_per_group();
+ Stub stub;
+ // no need to create Stub if there is only a single worker without servers,
+ // i.e., the training will be conducted by the single worker.
+ if (grp_size > 1 || nserver_grps > 0) {
+ stub.Setup();
+ // TODO(wangwei) register endpoint to zookeeper if > 1 procs;
+ cluster->Register(getpid(), stub.endpoint()); // getpid() is from unistd.h
+ }
+
+ NeuralNet* net = NeuralNet::Create(job_conf.neuralnet(), kTrain, grp_size);
+ const vector<Worker*> workers = CreateWorkers(job_conf, net);
+ const vector<Server*> servers = CreateServers(job_conf, net);
+
+#ifdef USE_MPI
+ int nthreads = workers.size() + servers.size() + 1;
+ for (int i = 0; i < nthreads; i++)
+ MPIQueues.push_back(make_shared<SafeQueue>());
+#endif
+
+ vector<std::thread> threads;
+ for (auto server : servers)
+ threads.push_back(std::thread(&Server::Run, server));
+ for (auto worker : workers)
+ threads.push_back(std::thread(&Worker::Run, worker));
+ if (grp_size > 1 || nserver_grps > 0) {
+ int nservers_per_grp = cluster->nservers_per_group();
+ int lcm = LeastCommonMultiple(nservers_per_grp, nserver_grps);
+ auto slices = Param::ComputeSlices(lcm, net->params());
+ auto slice2server = PartitionSlices(nservers_per_grp, slices);
+ stub.Run(slice2server, workers, servers);
+ }
+
+ for (auto& thread : threads)
+ thread.join();
+ for (auto server : servers)
+ delete server;
+ delete net;
+ std::set<NeuralNet*> deleted{net, nullptr};
+ for (auto worker : workers) {
+ for (auto ptr : worker->GetNets())
+ if (deleted.find(ptr) == deleted.end()) {
+ delete ptr;
+ deleted.insert(ptr);
+ }
+ delete worker;
+ }
+}
+
+void Driver::SetupForResume(JobProto* job_conf) {
+ tinydir_dir dir;
+ std::string folder = Cluster::Get()->checkpoint_folder();
+ tinydir_open(&dir, folder.c_str());
+ int latest_step = 0;
+  // there may be multiple checkpoint files (from different workers) for one step
+ vector<std::string> ck_files;
+ // iterate all files to get the files for the last checkpoint
+ while (dir.has_next) {
+ tinydir_file file;
+ tinydir_readfile(&dir, &file);
+ tinydir_next(&dir);
+ char* ch = strstr(file.name, "step");
+ if (ch == nullptr) {
+ if (file.name[0] != '.')
+ LOG(INFO) << "Irregular file in checkpoint folder: " << file.name;
+ continue;
+ }
+ LOG(INFO) << "Add checkpoint file for resume: " << ch;
+ int step = atoi(ch+4);
+ if (step == latest_step) {
+ ck_files.push_back(file.name);
+ } else if (step > latest_step) {
+ latest_step = step;
+ ck_files.clear();
+ ck_files.push_back(std::string(file.name));
+ }
+ }
+ if (latest_step > 0) {
+ job_conf->set_step(latest_step);
+ if (!job_conf->has_reset_param_version())
+ job_conf->set_reset_param_version(false);
+ job_conf->clear_checkpoint_path();
+ for (auto ck_file : ck_files)
+ job_conf->add_checkpoint_path(folder + "/" + ck_file);
+ }
+ tinydir_close(&dir);
+}
+
+const vector<Worker*> Driver::CreateWorkers(const JobProto& job_conf,
+ NeuralNet* net) {
+ auto cluster = Cluster::Get();
+ vector<Worker*> workers;
+ if (!cluster->has_worker()) return workers;
+ int wgrp_size = cluster->nworkers_per_group();
+ int nservers_per_grp = cluster->nservers_per_group();
+ int nserver_grps = cluster->nserver_groups();
+ int lcm = LeastCommonMultiple(nserver_grps, nservers_per_grp);
+ const vector<int> rng = cluster->ExecutorRng(cluster->procs_id(),
+ cluster->nworkers_per_group(), cluster->nworkers_per_procs());
+ int gstart = rng[0], gend = rng[1], wstart = rng[2], wend = rng[3];
+ for (int gid = gstart; gid < gend; gid++) {
+ NeuralNet* train_net = nullptr, *test_net = nullptr, *val_net = nullptr;
+ if (gid == gstart) {
+ train_net = net;
+ Param::SliceParams(lcm, train_net->params());
+ // test and validation are performed by the 1st group.
+ if (gid == 0 && job_conf.test_steps() > 0) {
+ test_net = NeuralNet::Create(job_conf.neuralnet(), kTest, 1);
+ test_net->ShareParamsFrom(train_net);
+ }
+ if (gid == 0 && job_conf.validate_steps() > 0) {
+ val_net = NeuralNet::Create(job_conf.neuralnet(), kVal, 1);
+ val_net->ShareParamsFrom(train_net);
+ }
+ } else {
+ train_net = NeuralNet::Create(job_conf.neuralnet(), kTrain, wgrp_size);
+ if (cluster->share_memory()) {
+ train_net->ShareParamsFrom(net);
+ } else {
+ Param::SliceParams(lcm, train_net->params());
+ }
+ }
+ for (int wid = wstart; wid < wend; wid++) {
+ auto *worker = Worker::Create(job_conf.train_one_batch());
+ // TODO(wangwei) extend to test among workers in a grp
+ if (wid == 0)
+ worker->Setup(gid, wid, job_conf, train_net, val_net, test_net);
+ else
+ worker->Setup(gid, wid, job_conf, train_net, nullptr, nullptr);
+ workers.push_back(worker);
+ }
+ }
+ return workers;
+}
+
+const vector<Server*> Driver::CreateServers(const JobProto& job_conf,
+ NeuralNet* net) {
+ auto cluster = Cluster::Get();
+ vector<Server*> servers;
+ if (!cluster->has_server()) return servers;
+ int nservers_per_grp = cluster->nservers_per_group();
+ int nserver_grps = cluster->nserver_groups();
+ int lcm = LeastCommonMultiple(nserver_grps, nservers_per_grp);
+ auto slices = Param::ComputeSlices(lcm, net->params());
+ // partition among server groups, each group maintains one sub-set for sync
+ auto slice2group = PartitionSlices(nserver_grps, slices);
+ // partition within one server group, each server updates for one sub-set
+ auto slice2server = PartitionSlices(nservers_per_grp, slices);
+
+ int server_procs = cluster->procs_id();
+  // when servers and workers run in separate procs, server (logical) proc
+  // ids start after the worker procs
+ if (cluster->server_worker_separate())
+ server_procs -= cluster->nworker_procs();
+ const vector<int> rng = cluster->ExecutorRng(server_procs,
+ cluster->nservers_per_group(), cluster->nservers_per_procs());
+ int gstart = rng[0], gend = rng[1], start = rng[2], end = rng[3];
+ for (int gid = gstart; gid < gend; gid++) {
+ for (int sid = start; sid < end; sid++) {
+ auto server = new Server(gid, sid, job_conf, slice2group, slice2server);
+ servers.push_back(server);
+ }
+ }
+ return servers;
+}
} // namespace singa
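
The server creation above cuts every Param into
lcm(nserver_groups, nservers_per_group) slices so that the same slice
boundaries partition evenly both across server groups (slice2group, for
inter-group sync) and across servers within a group (slice2server, for
updates). A standalone sketch of that arithmetic, assuming a round-robin
assignment purely for illustration (the actual PartitionSlices policy is not
shown in this diff):

    #include <cstdio>

    static int Gcd(int a, int b) { return b == 0 ? a : Gcd(b, a % b); }

    int main() {
      int nserver_grps = 2, nservers_per_grp = 3;
      int lcm = nserver_grps / Gcd(nserver_grps, nservers_per_grp)
                * nservers_per_grp;         // 6 slices per Param
      for (int i = 0; i < lcm; i++)         // hypothetical round-robin mapping
        std::printf("slice %d -> group %d, server %d\n",
                    i, i % nserver_grps, i % nservers_per_grp);
      return 0;
    }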
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/321ef96a/src/main.cc
----------------------------------------------------------------------
diff --git a/src/main.cc b/src/main.cc
index 5d2ab2f..99c91b8 100644
--- a/src/main.cc
+++ b/src/main.cc
@@ -7,9 +7,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
-*
+*
* http://www.apache.org/licenses/LICENSE-2.0
-*
+*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -19,9 +19,9 @@
*
*************************************************************/
-#include "singa.h"
+#include "./singa.h"
/**
- * \file main.cc provides an example main func.
+ * \file main.cc provides an example main function.
*
* Like the main func of Hadoop, it prepares the job configuration and submits it
* to the Driver which starts the training.
@@ -31,19 +31,17 @@
* func must call Driver::Init at the beginning, and pass the job configuration
* and resume option to the Driver for job submission.
*
- * Optionally, users can register their own implemented classes, e.g., layer,
- * updater, through the registration func provided by the Driver.
+ * Optionally, users can register their own implemented subclasses of Layer,
+ * Updater, etc. through the registration function provided by the Driver.
*
* Users must pass at least one argument to the singa-run.sh, i.e., the job
* configuration file which includes the cluster topology setting. Other fields,
* e.g., neuralnet and updater, can be configured in main.cc.
*
* TODO
- * Add helper functions for users to generate their configurations easily.
- * e.g., AddLayer(layer_type, source_layers, meta_data),
- * or, MLP(layer1_size, layer2_size, tanh, loss);
+ * Add helper functions for users to generate configurations for popular models
+ * easily, e.g., MLP(layer1_size, layer2_size, tanh, loss);
*/
-
int main(int argc, char **argv) {
// must create driver at the beginning and call its Init method.
singa::Driver driver;
@@ -58,7 +56,7 @@ int main(int argc, char **argv) {
  // get the job conf, and customize it if needed
singa::JobProto jobConf = driver.job_conf();
- // submit the job
- driver.Submit(resume, jobConf);
+ // submit the job for training
+ driver.Train(resume, jobConf);
return 0;
}
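
Putting the pieces together, a user-defined main needs only Init, optional
registration, and Train. Everything in the sketch below except the commented
registration line is taken from the diffs above; the RegisterLayer name and
the FooLayer/kFooLayer identifiers are assumptions, not confirmed by this
commit:

    #include "./singa.h"

    int main(int argc, char **argv) {
      singa::Driver driver;
      driver.Init(argc, argv);                      // must be called first
      // hypothetical: register a user-defined Layer subclass here, e.g.
      // driver.RegisterLayer<FooLayer, int>(kFooLayer);
      bool resume = false;                          // or parse from argv
      singa::JobProto jobConf = driver.job_conf();  // customize if needed
      driver.Train(resume, jobConf);
      return 0;
    }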