Posted to dev@horn.apache.org by zj...@apache.org on 2016/02/03 00:59:05 UTC
[3/3] incubator-horn git commit: HORN-8: Implement asynchronous parameter server
HORN-8: Implement asynchronous parameter server
Project: http://git-wip-us.apache.org/repos/asf/incubator-horn/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-horn/commit/91c0c796
Tree: http://git-wip-us.apache.org/repos/asf/incubator-horn/tree/91c0c796
Diff: http://git-wip-us.apache.org/repos/asf/incubator-horn/diff/91c0c796
Branch: refs/heads/master
Commit: 91c0c796e76303a0e3cf27606fbc10a03d05ed0e
Parents: 8f412c6
Author: Lee Dongjin <do...@gmail.com>
Authored: Tue Feb 2 00:15:22 2016 +0900
Committer: Lee Dongjin <do...@gmail.com>
Committed: Tue Feb 2 00:16:29 2016 +0900
----------------------------------------------------------------------
.../org/apache/horn/bsp/ParameterMerger.java | 10 ++
.../apache/horn/bsp/ParameterMergerServer.java | 97 +++++++++++
.../bsp/SmallLayeredNeuralNetworkTrainer.java | 173 ++++++-------------
3 files changed, 162 insertions(+), 118 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-horn/blob/91c0c796/src/main/java/org/apache/horn/bsp/ParameterMerger.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/horn/bsp/ParameterMerger.java b/src/main/java/org/apache/horn/bsp/ParameterMerger.java
new file mode 100644
index 0000000..709331b
--- /dev/null
+++ b/src/main/java/org/apache/horn/bsp/ParameterMerger.java
@@ -0,0 +1,10 @@
+package org.apache.horn.bsp;
+
+import org.apache.hama.commons.math.DoubleMatrix;
+import org.apache.hama.ipc.VersionedProtocol;
+
+public interface ParameterMerger extends VersionedProtocol {
+ long versionID = 1L;
+
+ SmallLayeredNeuralNetworkMessage merge(double trainingError, DoubleMatrix[] weightUpdates, DoubleMatrix[] prevWeightUpdates);
+}
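
Editor's note: ParameterMerger extends Hama's VersionedProtocol, so it is served and consumed through Hama's RPC layer. Below is a minimal sketch of that pattern, mirroring the RPC.getServer / RPC.getProxy calls this commit adds to the trainer; the host name "master-host", the model path, and the counts passed to ParameterMergerServer are placeholders, not values from the commit.

  import java.net.InetSocketAddress;
  import java.util.concurrent.atomic.AtomicBoolean;
  import org.apache.hama.HamaConfiguration;
  import org.apache.hama.ipc.RPC;
  import org.apache.horn.bsp.ParameterMerger;
  import org.apache.horn.bsp.ParameterMergerServer;
  import org.apache.horn.bsp.SmallLayeredNeuralNetwork;

  public class MergerRpcSketch {
    public static void main(String[] args) throws Exception {
      HamaConfiguration conf = new HamaConfiguration();

      // Master side: publish a ParameterMergerServer over Hama RPC.
      SmallLayeredNeuralNetwork model = new SmallLayeredNeuralNetwork("/tmp/model");
      AtomicBoolean isConverge = new AtomicBoolean(false);
      RPC.Server server = RPC.getServer(
          new ParameterMergerServer(model, isConverge, 1, 100000, 2000),
          "master-host", 40042, conf);
      server.start();

      // Slave side: obtain a typed proxy and call merge() on it.
      ParameterMerger proxy = (ParameterMerger) RPC.getProxy(
          ParameterMerger.class, ParameterMerger.versionID,
          new InetSocketAddress("master-host", 40042), conf);
    }
  }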
http://git-wip-us.apache.org/repos/asf/incubator-horn/blob/91c0c796/src/main/java/org/apache/horn/bsp/ParameterMergerServer.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/horn/bsp/ParameterMergerServer.java b/src/main/java/org/apache/horn/bsp/ParameterMergerServer.java
new file mode 100644
index 0000000..54caf2b
--- /dev/null
+++ b/src/main/java/org/apache/horn/bsp/ParameterMergerServer.java
@@ -0,0 +1,97 @@
+package org.apache.horn.bsp;
+
+import com.google.common.base.Preconditions;
+
+import org.apache.hama.commons.math.DoubleMatrix;
+import org.mortbay.log.Log;
+
+import java.io.IOException;
+import java.util.concurrent.atomic.AtomicBoolean;
+
+public class ParameterMergerServer implements ParameterMerger {
+ /* The base model into which slave updates are merged. */
+ protected SmallLayeredNeuralNetwork inMemoryModel;
+
+ /* Whether training has converged and workers should terminate. */
+ protected AtomicBoolean isConverge;
+
+ /* The number of slave workers that send updates. */
+ protected int slaveCount;
+
+ /* After mergeLimit merges, terminate whether or not the result has converged. */
+ protected int mergeLimit;
+
+ /* The last n training errors; convergence is decided from their average. */
+ protected double[] trainingErrors;
+
+ /* Average error of the previous window; training is declared converged once the current window's average exceeds it. */
+ protected double prevAvgTrainingError = Double.MAX_VALUE;
+
+ /* Current write index into trainingErrors. */
+ protected int curTrainingError = 0;
+
+ /* How many merges have been conducted so far. */
+ protected int mergeCount = 0;
+
+ public ParameterMergerServer(SmallLayeredNeuralNetwork inMemoryModel, AtomicBoolean isConverge,
+ int slaveCount, int mergeLimit, int convergenceCheckInterval) {
+ this.inMemoryModel = inMemoryModel;
+ this.isConverge = isConverge;
+ this.slaveCount = slaveCount;
+ this.mergeLimit = mergeLimit;
+ this.trainingErrors = new double[convergenceCheckInterval];
+ }
+
+ @Override
+ public long getProtocolVersion(String s, long l) throws IOException {
+ return ParameterMerger.versionID;
+ }
+
+ @Override
+ public SmallLayeredNeuralNetworkMessage merge(double trainingError, DoubleMatrix[] weightUpdates,
+ DoubleMatrix[] prevWeightUpdates) {
+ Preconditions.checkArgument(weightUpdates.length == prevWeightUpdates.length);
+
+ Log.info(String.format("Start merging: %d.\n", this.mergeCount));
+
+ if (!this.isConverge.get()) {
+ for (int i = 0; i < weightUpdates.length; ++i) {
+ weightUpdates[i] = weightUpdates[i].divide(this.slaveCount);
+ prevWeightUpdates[i] = prevWeightUpdates[i].divide(this.slaveCount);
+ }
+
+ synchronized (inMemoryModel) {
+ this.inMemoryModel.updateWeightMatrices(weightUpdates);
+ this.inMemoryModel.setPrevWeightMatrices(prevWeightUpdates);
+
+ // add trainingError to trainingErrors
+ this.trainingErrors[this.curTrainingError++] = trainingError;
+
+ // check convergence
+ if (this.trainingErrors.length == this.curTrainingError) {
+ double curAvgTrainingError = 0.0;
+ for (int i = 0; i < this.curTrainingError; ++i) {
+ curAvgTrainingError += this.trainingErrors[i];
+ }
+ curAvgTrainingError /= this.trainingErrors.length;
+
+ if (prevAvgTrainingError < curAvgTrainingError) {
+ this.isConverge.set(true);
+ } else {
+ // update
+ prevAvgTrainingError = curAvgTrainingError;
+ this.curTrainingError = 0;
+ }
+ }
+
+ if (++this.mergeCount == this.mergeLimit) {
+ this.isConverge.set(true);
+ }
+ }
+ }
+
+ return new SmallLayeredNeuralNetworkMessage(
+ 0, this.isConverge.get(), this.inMemoryModel.getWeightMatrices(),
+ this.inMemoryModel.getPrevMatricesUpdates());
+ }
+}
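
Editor's note: the convergence rule in merge() above can be read in isolation: errors are collected into a fixed-size window; when the window fills, its average is compared with the previous window's average, and training is declared converged once the error stops decreasing. A standalone sketch of just that rule (an editor's illustration, not code from this commit):

  /* Windowed moving-average convergence test, as in merge() above. */
  public class ConvergenceWindow {
    private final double[] errors;
    private int next = 0;
    private double prevAvg = Double.MAX_VALUE;

    public ConvergenceWindow(int size) {
      this.errors = new double[size];
    }

    /* Records one training error; returns true once the average error
       of a full window fails to improve on the previous window. */
    public boolean record(double trainingError) {
      errors[next++] = trainingError;
      if (next < errors.length) {
        return false; // window not full yet
      }
      double avg = 0.0;
      for (double e : errors) {
        avg += e;
      }
      avg /= errors.length;
      if (prevAvg < avg) {
        return true;  // error rose again: converged
      }
      prevAvg = avg;  // still improving: start the next window
      next = 0;
      return false;
    }
  }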
http://git-wip-us.apache.org/repos/asf/incubator-horn/blob/91c0c796/src/main/java/org/apache/horn/bsp/SmallLayeredNeuralNetworkTrainer.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/horn/bsp/SmallLayeredNeuralNetworkTrainer.java b/src/main/java/org/apache/horn/bsp/SmallLayeredNeuralNetworkTrainer.java
index b4657f0..9e3d02f 100644
--- a/src/main/java/org/apache/horn/bsp/SmallLayeredNeuralNetworkTrainer.java
+++ b/src/main/java/org/apache/horn/bsp/SmallLayeredNeuralNetworkTrainer.java
@@ -29,9 +29,11 @@ import org.apache.hama.commons.io.VectorWritable;
import org.apache.hama.commons.math.DenseDoubleMatrix;
import org.apache.hama.commons.math.DoubleMatrix;
import org.apache.hama.commons.math.DoubleVector;
+import org.apache.hama.ipc.RPC;
import org.mortbay.log.Log;
import java.io.IOException;
+import java.net.InetSocketAddress;
import java.util.concurrent.atomic.AtomicBoolean;
/**
@@ -42,21 +44,26 @@ import java.util.concurrent.atomic.AtomicBoolean;
public final class SmallLayeredNeuralNetworkTrainer
extends
BSP<LongWritable, VectorWritable, NullWritable, NullWritable, SmallLayeredNeuralNetworkMessage> {
-
+ /* On the master worker: the base model for parameter merging. */
+ /* On a slave worker: the local neural network being trained. */
private SmallLayeredNeuralNetwork inMemoryModel;
+
+ /* Job configuration */
private Configuration conf;
+
/* Default batch size */
private int batchSize;
- /* check the interval between intervals */
- private double prevAvgTrainingError;
- private double curAvgTrainingError;
- private long convergenceCheckInterval;
- private long iterations;
- private long maxIterations;
+ /* Whether training has converged. */
private AtomicBoolean isConverge;
- private String modelPath;
+ /* On the master worker: the RPC server for asynchronous parameter merging. */
+ /* On a slave worker: null. */
+ private RPC.Server merger;
+
+ /* On the master worker: null. */
+ /* On a slave worker: the RPC proxy to the master's parameter merger. */
+ private ParameterMerger proxy;
/**
* Returns true if this worker is master worker.
@@ -77,20 +84,37 @@ public final class SmallLayeredNeuralNetworkTrainer
// At least one master & slave worker exist.
Preconditions.checkArgument(peer.getNumPeers() >= 2);
+ this.conf = peer.getConfiguration();
+ String modelPath = conf.get("modelPath");
+ this.inMemoryModel = new SmallLayeredNeuralNetwork(modelPath);
+ this.batchSize = conf.getInt("training.batch.size", 50);
+ this.isConverge = new AtomicBoolean(false);
+
+ int slaveCount = peer.getNumPeers() - 1;
+ int mergeLimit = conf.getInt("training.max.iterations", 100000);
+ int convergenceCheckInterval = peer.getNumPeers() * conf.getInt("convergence.check.interval",
+ 2000);
+ String master = peer.getPeerName(0); // the master is peer 0
+ String masterAddr = master.substring(0, master.indexOf(':'));
+ int port = conf.getInt("sync.server.port", 40042);
+
if (isMaster(peer)) {
+ try {
+ this.merger = RPC.getServer(new ParameterMergerServer(inMemoryModel, isConverge, slaveCount,
+ mergeLimit, convergenceCheckInterval), masterAddr, port, conf);
+ merger.start();
+ } catch (IOException e) {
+ e.printStackTrace();
+ }
Log.info("Begin to train");
+ } else {
+ InetSocketAddress addr = new InetSocketAddress(masterAddr, port);
+ try {
+ this.proxy = (ParameterMerger) RPC.getProxy(ParameterMerger.class, ParameterMerger.versionID, addr, conf);
+ } catch (IOException e) {
+ e.printStackTrace();
+ }
}
- this.isConverge = new AtomicBoolean(false);
- this.conf = peer.getConfiguration();
- this.iterations = 0;
- this.modelPath = conf.get("modelPath");
- this.maxIterations = conf.getLong("training.max.iterations", 100000);
- this.convergenceCheckInterval = conf.getLong("convergence.check.interval",
- 2000);
- this.modelPath = conf.get("modelPath");
- this.inMemoryModel = new SmallLayeredNeuralNetwork(modelPath);
- this.prevAvgTrainingError = Integer.MAX_VALUE;
- this.batchSize = conf.getInt("training.batch.size", 50);
}
@Override
@@ -102,8 +126,6 @@ public final class SmallLayeredNeuralNetworkTrainer
// write model to modelPath
if (isMaster(peer)) {
try {
- Log.info(String.format("End of training, number of iterations: %d.\n",
- this.iterations));
Log.info(String.format("Write model back to %s\n",
inMemoryModel.getModelPath()));
this.inMemoryModel.writeModelToFile();
@@ -117,21 +139,12 @@ public final class SmallLayeredNeuralNetworkTrainer
public void bsp(
BSPPeer<LongWritable, VectorWritable, NullWritable, NullWritable, SmallLayeredNeuralNetworkMessage> peer)
throws IOException, SyncException, InterruptedException {
- while (this.iterations++ < maxIterations) {
- // each slave-worker calculate the matrices updates according to local data
- if (!isMaster(peer)) {
+ if (!isMaster(peer)) {
+ while (!this.isConverge.get()) {
+ // each slave worker computes the matrix updates from its local data
+ // and merges them with the master
calculateUpdates(peer);
}
- peer.sync();
-
- // master merge the updates model
- if (isMaster(peer)) {
- mergeUpdates(peer);
- }
- peer.sync();
- if (this.isConverge.get()) {
- break;
- }
}
}
@@ -144,20 +157,6 @@ public final class SmallLayeredNeuralNetworkTrainer
private void calculateUpdates(
BSPPeer<LongWritable, VectorWritable, NullWritable, NullWritable, SmallLayeredNeuralNetworkMessage> peer)
throws IOException {
- // receive update information from master
- if (peer.getNumCurrentMessages() != 0) {
- SmallLayeredNeuralNetworkMessage inMessage = peer.getCurrentMessage();
- DoubleMatrix[] newWeights = inMessage.getCurMatrices();
- DoubleMatrix[] preWeightUpdates = inMessage.getPrevMatrices();
- this.inMemoryModel.setWeightMatrices(newWeights);
- this.inMemoryModel.setPrevWeightMatrices(preWeightUpdates);
- this.isConverge.set(inMessage.isConverge());
- // check converge
- if (isConverge.get()) {
- return;
- }
- }
-
DoubleMatrix[] weightUpdates = new DoubleMatrix[this.inMemoryModel.weightMatrixList
.size()];
for (int i = 0; i < weightUpdates.length; ++i) {
@@ -187,76 +186,14 @@ public final class SmallLayeredNeuralNetworkTrainer
weightUpdates[i] = weightUpdates[i].divide(batchSize);
}
- DoubleMatrix[] prevWeightUpdates = this.inMemoryModel
- .getPrevMatricesUpdates();
- SmallLayeredNeuralNetworkMessage outMessage = new SmallLayeredNeuralNetworkMessage(
- avgTrainingError, false, weightUpdates, prevWeightUpdates);
- peer.send(peer.getPeerName(0), outMessage);
- }
-
- /**
- * Merge the updates according to the updates of the grooms.
- *
- * @param peer
- * @throws IOException
- */
- private void mergeUpdates(
- BSPPeer<LongWritable, VectorWritable, NullWritable, NullWritable, SmallLayeredNeuralNetworkMessage> peer)
- throws IOException {
- int numMessages = peer.getNumCurrentMessages();
- boolean isConverge = false;
- if (numMessages == 0) { // converges
- this.isConverge.set(true);
- return;
- }
-
- double avgTrainingError = 0;
- DoubleMatrix[] matricesUpdates = null;
- DoubleMatrix[] prevMatricesUpdates = null;
-
- while (peer.getNumCurrentMessages() > 0) {
- SmallLayeredNeuralNetworkMessage message = peer.getCurrentMessage();
- if (matricesUpdates == null) {
- matricesUpdates = message.getCurMatrices();
- prevMatricesUpdates = message.getPrevMatrices();
- } else {
- SmallLayeredNeuralNetwork.matricesAdd(matricesUpdates,
- message.getCurMatrices());
- SmallLayeredNeuralNetwork.matricesAdd(prevMatricesUpdates,
- message.getPrevMatrices());
- }
- avgTrainingError += message.getTrainingError();
- }
-
- if (numMessages != 1) {
- avgTrainingError /= numMessages;
- for (int i = 0; i < matricesUpdates.length; ++i) {
- matricesUpdates[i] = matricesUpdates[i].divide(numMessages);
- prevMatricesUpdates[i] = prevMatricesUpdates[i].divide(numMessages);
- }
- }
- this.inMemoryModel.updateWeightMatrices(matricesUpdates);
- this.inMemoryModel.setPrevWeightMatrices(prevMatricesUpdates);
-
- // check convergence
- if (iterations % convergenceCheckInterval == 0) {
- if (prevAvgTrainingError < curAvgTrainingError) {
- // error cannot decrease any more
- isConverge = true;
- }
- // update
- prevAvgTrainingError = curAvgTrainingError;
- curAvgTrainingError = 0;
- }
- curAvgTrainingError += avgTrainingError / convergenceCheckInterval;
-
- // broadcast updated weight matrices
- for (String peerName : peer.getAllPeerNames()) {
- SmallLayeredNeuralNetworkMessage msg = new SmallLayeredNeuralNetworkMessage(
- 0, isConverge, this.inMemoryModel.getWeightMatrices(),
- this.inMemoryModel.getPrevMatricesUpdates());
- peer.send(peerName, msg);
- }
+ // exchange parameter updates with the master
+ SmallLayeredNeuralNetworkMessage inMessage = proxy.merge(avgTrainingError, weightUpdates,
+ this.inMemoryModel.getPrevMatricesUpdates());
+ DoubleMatrix[] newWeights = inMessage.getCurMatrices();
+ DoubleMatrix[] preWeightUpdates = inMessage.getPrevMatrices();
+ this.inMemoryModel.setWeightMatrices(newWeights);
+ this.inMemoryModel.setPrevWeightMatrices(preWeightUpdates);
+ this.isConverge.set(inMessage.isConverge());
}
}
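
Editor's note: the trainer reads all of its settings from the job configuration. A minimal sketch of a job setup with the keys this commit consumes; the values shown are the defaults the trainer falls back to, and the model path is a placeholder.

  import org.apache.hama.HamaConfiguration;

  public class TrainerConfSketch {
    public static HamaConfiguration buildConf() {
      HamaConfiguration conf = new HamaConfiguration();
      conf.set("modelPath", "/tmp/neuralnets.model");   // placeholder path
      conf.setInt("training.batch.size", 50);           // mini-batch size per parameter exchange
      conf.setInt("training.max.iterations", 100000);   // also used as the server's merge limit
      conf.setInt("convergence.check.interval", 2000);  // error-window size, scaled by peer count
      conf.setInt("sync.server.port", 40042);           // port the master's merger listens on
      return conf;
    }
  }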