You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@hama.apache.org by ed...@apache.org on 2013/06/11 00:11:07 UTC
svn commit: r1491624 - in /hama/trunk: ./
ml/src/main/java/org/apache/hama/ml/perception/
ml/src/test/java/org/apache/hama/ml/perception/
Author: edwardyoon
Date: Mon Jun 10 22:11:06 2013
New Revision: 1491624
URL: http://svn.apache.org/r1491624
Log:
HAMA-760: Add new features to existing Multi Layer Perceptron (Yexi Jiang via edwardyoon)
Added:
hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/CrossEntropy.java
Modified:
hama/trunk/CHANGES.txt
hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/CostFunction.java
hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/CostFunctionFactory.java
hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/MultiLayerPerceptron.java
hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/Sigmoid.java
hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/SmallMLPMessage.java
hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/SmallMLPTrainer.java
hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/SmallMultiLayerPerceptron.java
hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/SquaredError.java
hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/SquashingFunctionFactory.java
hama/trunk/ml/src/test/java/org/apache/hama/ml/perception/TestSmallMLPMessage.java
hama/trunk/ml/src/test/java/org/apache/hama/ml/perception/TestSmallMultiLayerPerceptron.java
Modified: hama/trunk/CHANGES.txt
URL: http://svn.apache.org/viewvc/hama/trunk/CHANGES.txt?rev=1491624&r1=1491623&r2=1491624&view=diff
==============================================================================
--- hama/trunk/CHANGES.txt (original)
+++ hama/trunk/CHANGES.txt Mon Jun 10 22:11:06 2013
@@ -15,6 +15,7 @@ Release 0.7 (unreleased changes)
IMPROVEMENTS
+ HAMA-760: Add new features to existing Multi Layer Perceptron (Yexi Jiang via edwardyoon)
HAMA-758: Send message to non-exist vertex makes the job fail (MaoYuan Xian via edwardyoon)
HAMA-757: The partitioning job output should be un-splitable (MaoYuan Xian via edwardyoon)
HAMA-754: PartitioningRunner should write raw records to partition files (edwardyoon)
Modified: hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/CostFunction.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/CostFunction.java?rev=1491624&r1=1491623&r2=1491624&view=diff
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/CostFunction.java (original)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/CostFunction.java Mon Jun 10 22:11:06 2013
@@ -19,7 +19,6 @@ package org.apache.hama.ml.perception;
/**
* The common interface for cost functions.
- *
*/
public abstract class CostFunction {
Modified: hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/CostFunctionFactory.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/CostFunctionFactory.java?rev=1491624&r1=1491623&r2=1491624&view=diff
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/CostFunctionFactory.java (original)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/CostFunctionFactory.java Mon Jun 10 22:11:06 2013
@@ -32,9 +32,10 @@ public class CostFunctionFactory {
public static CostFunction getCostFunction(String name) {
if (name.equalsIgnoreCase("SquaredError")) {
return new SquaredError();
- } else if (name.equalsIgnoreCase("LogisticError")) {
- return new LogisticCostFunction();
+ } else if (name.equalsIgnoreCase("CrossEntropy")) {
+ return new CrossEntropy();
}
- return new SquaredError();
+ throw new IllegalStateException(String.format(
+ "No cost function with name '%s' found.", name));
}
}
Added: hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/CrossEntropy.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/CrossEntropy.java?rev=1491624&view=auto
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/CrossEntropy.java (added)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/CrossEntropy.java Mon Jun 10 22:11:06 2013
@@ -0,0 +1,53 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hama.ml.perception;
+
+/**
+ * The cross entropy cost function.
+ *
+ * <pre>
+ * cost(t, y) = - t * log(y) - (1 - t) * log(1 - y),
+ * where t denotes the target value, y denotes the estimated value.
+ * </pre>
+ */
+public class CrossEntropy extends CostFunction {
+
+ @Override
+ public double calculate(double target, double actual) {
+ return -target * Math.log(actual) - (1 - target) * Math.log(1 - actual);
+ }
+
+ @Override
+ public double calculateDerivative(double target, double actual) {
+ double adjustedTarget = target;
+ double adjustedActual = actual;
+ if (adjustedActual == 1) {
+ adjustedActual = 0.999;
+ } else if (actual == 0) {
+ adjustedActual = 0.001;
+ }
+ if (adjustedTarget == 1) {
+ adjustedTarget = 0.999;
+ } else if (adjustedTarget == 0) {
+ adjustedTarget = 0.001;
+ }
+ return -adjustedTarget / adjustedActual + (1 - adjustedTarget)
+ / (1 - adjustedActual);
+ }
+
+}
Modified: hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/MultiLayerPerceptron.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/MultiLayerPerceptron.java?rev=1491624&r1=1491623&r2=1491624&view=diff
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/MultiLayerPerceptron.java (original)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/MultiLayerPerceptron.java Mon Jun 10 22:11:06 2013
@@ -36,7 +36,7 @@ public abstract class MultiLayerPerceptr
/* Model meta-data */
protected String MLPType;
protected double learningRate;
- protected boolean regularization;
+ protected double regularization;
protected double momentum;
protected int numberOfLayers;
protected String squashingFunctionName;
@@ -50,20 +50,33 @@ public abstract class MultiLayerPerceptr
* Initialize the MLP.
*
* @param learningRate Larger learningRate makes MLP learn more aggressive.
- * @param regularization Turn on regularization make MLP less likely to
- * overfit.
+ * Learning rate must be positive.
+ * @param regularization Regularization makes MLP less likely to overfit. The
+ * value of regularization cannot be negative or too large,
+ * otherwise it will affect the precision.
* @param momentum The momentum makes the historical adjust have affect to
- * current adjust.
+ * current adjust. The weight of momentum cannot be negative.
* @param squashingFunctionName The name of squashing function.
* @param costFunctionName The name of the cost function.
* @param layerSizeArray The number of neurons for each layer. Note that the
* actual size of each layer is one more than the input size.
*/
- public MultiLayerPerceptron(double learningRate, boolean regularization,
+ public MultiLayerPerceptron(double learningRate, double regularization,
double momentum, String squashingFunctionName, String costFunctionName,
int[] layerSizeArray) {
+ this.MLPType = getTypeName();
+ if (learningRate <= 0) {
+ throw new IllegalStateException("learning rate cannot be negative.");
+ }
this.learningRate = learningRate;
+ if (regularization < 0 || regularization >= 0.5) {
+ throw new IllegalStateException(
+ "regularization weight must be in range (0, 0.5).");
+ }
this.regularization = regularization; // no regularization
+ if (momentum < 0) {
+ throw new IllegalStateException("momentum weight cannot be negative.");
+ }
this.momentum = momentum; // no momentum
this.squashingFunctionName = squashingFunctionName;
this.costFunctionName = costFunctionName;
@@ -101,8 +114,12 @@ public abstract class MultiLayerPerceptr
* @param featureVector The feature of an instance to feed the perceptron.
* @return The results.
*/
- public abstract DoubleVector output(DoubleVector featureVector)
- throws Exception;
+ public abstract DoubleVector output(DoubleVector featureVector);
+
+ /**
+ * Use the class name as the type name.
+ */
+ protected abstract String getTypeName();
/**
* Read the model meta-data from the specified location.
@@ -131,7 +148,7 @@ public abstract class MultiLayerPerceptr
return learningRate;
}
- public boolean isRegularization() {
+ public double isRegularization() {
return regularization;
}
Modified: hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/Sigmoid.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/Sigmoid.java?rev=1491624&r1=1491623&r2=1491624&view=diff
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/Sigmoid.java (original)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/Sigmoid.java Mon Jun 10 22:11:06 2013
@@ -17,7 +17,6 @@
*/
package org.apache.hama.ml.perception;
-
/**
* The Sigmoid function
*
Modified: hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/SmallMLPMessage.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/SmallMLPMessage.java?rev=1491624&r1=1491623&r2=1491624&view=diff
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/SmallMLPMessage.java (original)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/SmallMLPMessage.java Mon Jun 10 22:11:06 2013
@@ -29,18 +29,47 @@ import org.apache.hama.ml.writable.Matri
* {@link SmallMultiLayerPerceptron}. It send the whole parameter matrix from
* one task to another.
*/
-public class SmallMLPMessage extends MLPMessage {
+class SmallMLPMessage extends MLPMessage {
private int owner; // the ID of the task who creates the message
+ private int numOfUpdatedMatrices;
private DenseDoubleMatrix[] weightUpdatedMatrices;
- private int numOfMatrices;
+ private int numOfPrevUpdatedMatrices;
+ private DenseDoubleMatrix[] prevWeightUpdatedMatrices;
- public SmallMLPMessage(int owner, boolean terminated, DenseDoubleMatrix[] mat) {
+ /**
+ * Use this constructor when a slave sends a message to the master.
+ *
+ * @param owner The owner that creates the message
+ * @param terminated Whether the training is terminated for the owner task
+ * @param weightUpdatedMatrics The weight updates
+ */
+ public SmallMLPMessage(int owner, boolean terminated,
+ DenseDoubleMatrix[] weightUpdatedMatrics) {
super(terminated);
this.owner = owner;
- this.weightUpdatedMatrices = mat;
- this.numOfMatrices = this.weightUpdatedMatrices == null ? 0
+ this.weightUpdatedMatrices = weightUpdatedMatrics;
+ this.numOfUpdatedMatrices = this.weightUpdatedMatrices == null ? 0
: this.weightUpdatedMatrices.length;
+ this.numOfPrevUpdatedMatrices = 0;
+ this.prevWeightUpdatedMatrices = null;
+ }
+
+ /**
+ * Use this constructor when the master sends a message to a slave.
+ *
+ * @param owner The owner that creates the message
+ * @param terminated Whether the training is terminated for the owner task
+ * @param weightUpdatedMatrices The weight updates
+ * @param prevWeightUpdatedMatrices The previous weight updates, used for momentum
+ */
+ public SmallMLPMessage(int owner, boolean terminated,
+ DenseDoubleMatrix[] weightUpdatedMatrices,
+ DenseDoubleMatrix[] prevWeightUpdatedMatrices) {
+ this(owner, terminated, weightUpdatedMatrices);
+ this.prevWeightUpdatedMatrices = prevWeightUpdatedMatrices;
+ this.numOfPrevUpdatedMatrices = this.prevWeightUpdatedMatrices == null ? 0
+ : this.prevWeightUpdatedMatrices.length;
}
/**
@@ -57,30 +86,44 @@ public class SmallMLPMessage extends MLP
*
* @return
*/
- public DenseDoubleMatrix[] getWeightsUpdatedMatrices() {
+ public DenseDoubleMatrix[] getWeightUpdatedMatrices() {
return this.weightUpdatedMatrices;
}
+ public DenseDoubleMatrix[] getPrevWeightsUpdatedMatrices() {
+ return this.prevWeightUpdatedMatrices;
+ }
+
@Override
public void readFields(DataInput input) throws IOException {
this.owner = input.readInt();
this.terminated = input.readBoolean();
- this.numOfMatrices = input.readInt();
- this.weightUpdatedMatrices = new DenseDoubleMatrix[this.numOfMatrices];
- for (int i = 0; i < this.numOfMatrices; ++i) {
+ this.numOfUpdatedMatrices = input.readInt();
+ this.weightUpdatedMatrices = new DenseDoubleMatrix[this.numOfUpdatedMatrices];
+ for (int i = 0; i < this.numOfUpdatedMatrices; ++i) {
this.weightUpdatedMatrices[i] = (DenseDoubleMatrix) MatrixWritable
.read(input);
}
+ this.numOfPrevUpdatedMatrices = input.readInt();
+ this.prevWeightUpdatedMatrices = new DenseDoubleMatrix[this.numOfPrevUpdatedMatrices];
+ for (int i = 0; i < this.numOfPrevUpdatedMatrices; ++i) {
+ this.prevWeightUpdatedMatrices[i] = (DenseDoubleMatrix) MatrixWritable
+ .read(input);
+ }
}
@Override
public void write(DataOutput output) throws IOException {
output.writeInt(this.owner);
output.writeBoolean(this.terminated);
- output.writeInt(this.numOfMatrices);
- for (int i = 0; i < this.numOfMatrices; ++i) {
+ output.writeInt(this.numOfUpdatedMatrices);
+ for (int i = 0; i < this.numOfUpdatedMatrices; ++i) {
MatrixWritable.write(this.weightUpdatedMatrices[i], output);
}
+ output.writeInt(this.numOfPrevUpdatedMatrices);
+ for (int i = 0; i < this.numOfPrevUpdatedMatrices; ++i) {
+ MatrixWritable.write(this.prevWeightUpdatedMatrices[i], output);
+ }
}
}
Modified: hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/SmallMLPTrainer.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/SmallMLPTrainer.java?rev=1491624&r1=1491623&r2=1491624&view=diff
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/SmallMLPTrainer.java (original)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/SmallMLPTrainer.java Mon Jun 10 22:11:06 2013
@@ -33,7 +33,7 @@ import org.apache.hama.ml.writable.Vecto
/**
* The perceptron trainer for small scale MLP.
*/
-public class SmallMLPTrainer extends PerceptronTrainer {
+class SmallMLPTrainer extends PerceptronTrainer {
private static final Log LOG = LogFactory.getLog(SmallMLPTrainer.class);
/* used by master only, check whether all slaves finishes reading */
@@ -66,7 +66,7 @@ public class SmallMLPTrainer extends Per
// build model from scratch
if (modelPath == null || modelPath.trim().length() == 0) {
double learningRate = Double.parseDouble(conf.get("learningRate"));
- boolean regularization = Boolean.parseBoolean(conf.get("regularization"));
+ double regularization = Double.parseDouble(conf.get("regularization"));
double momentum = Double.parseDouble(conf.get("momentum"));
String squashingFunctionName = conf.get("squashingFunctionName");
String costFunctionName = conf.get("costFunctionName");
@@ -184,7 +184,7 @@ public class SmallMLPTrainer extends Per
this.statusSet.set(message.getOwner());
}
- DenseDoubleMatrix[] weightUpdates = message.getWeightsUpdatedMatrices();
+ DenseDoubleMatrix[] weightUpdates = message.getWeightUpdatedMatrices();
for (int m = 0; m < mergedUpdates.length; ++m) {
mergedUpdates[m] = (DenseDoubleMatrix) mergedUpdates[m]
.add(weightUpdates[m]);
@@ -206,12 +206,14 @@ public class SmallMLPTrainer extends Per
// update the weight matrices
this.inMemoryPerceptron.updateWeightMatrices(mergedUpdates);
+ this.inMemoryPerceptron.setPrevWeightUpdateMatrices(mergedUpdates);
}
// broadcast updated weight matrices
for (String peerName : peer.getAllPeerNames()) {
SmallMLPMessage msg = new SmallMLPMessage(peer.getPeerIndex(),
- this.terminateTraining, this.inMemoryPerceptron.getWeightMatrices());
+ this.terminateTraining, this.inMemoryPerceptron.getWeightMatrices(),
+ this.inMemoryPerceptron.getPrevWeightUpdateMatrices());
peer.send(peerName, msg);
}
@@ -233,7 +235,9 @@ public class SmallMLPTrainer extends Per
this.terminateTraining = message.isTerminated();
// each slave renew its weight matrices
this.inMemoryPerceptron.setWeightMatrices(message
- .getWeightsUpdatedMatrices());
+ .getWeightUpdatedMatrices());
+ this.inMemoryPerceptron.setPrevWeightUpdateMatrices(message
+ .getPrevWeightsUpdatedMatrices());
if (this.terminateTraining) {
return true;
}
@@ -272,8 +276,8 @@ public class SmallMLPTrainer extends Per
weightUpdates[m] = (DenseDoubleMatrix) weightUpdates[m].divide(count);
}
- LOG.info(String.format("Task %d has read %d records.",
- peer.getPeerIndex(), this.numTrainingInstanceRead));
+ LOG.info(String.format("Task %d has read %d records.", peer.getPeerIndex(),
+ this.numTrainingInstanceRead));
// send the weight updates to master task
SmallMLPMessage message = new SmallMLPMessage(peer.getPeerIndex(),
Modified: hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/SmallMultiLayerPerceptron.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/SmallMultiLayerPerceptron.java?rev=1491624&r1=1491623&r2=1491624&view=diff
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/SmallMultiLayerPerceptron.java (original)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/SmallMultiLayerPerceptron.java Mon Jun 10 22:11:06 2013
@@ -65,16 +65,19 @@ public final class SmallMultiLayerPercep
/* The in-memory weight matrix */
private DenseDoubleMatrix[] weightMatrice;
+ /* Previous weight updates, used for momentum */
+ private DenseDoubleMatrix[] prevWeightUpdateMatrices;
+
/**
* {@inheritDoc}
*/
- public SmallMultiLayerPerceptron(double learningRate, boolean regularization,
+ public SmallMultiLayerPerceptron(double learningRate, double regularization,
double momentum, String squashingFunctionName, String costFunctionName,
int[] layerSizeArray) {
super(learningRate, regularization, momentum, squashingFunctionName,
costFunctionName, layerSizeArray);
- this.MLPType = "SmallMLP";
initializeWeightMatrix();
+ this.initializePrevWeightUpdateMatrix();
}
/**
@@ -85,6 +88,7 @@ public final class SmallMultiLayerPercep
if (modelPath != null) {
try {
this.readFromModel();
+ this.initializePrevWeightUpdateMatrix();
} catch (IOException e) {
e.printStackTrace();
}
@@ -113,20 +117,30 @@ public final class SmallMultiLayerPercep
}
}
+ /**
+ * Initialize the previous weight update matrices, which are used for momentum.
+ */
+ private void initializePrevWeightUpdateMatrix() {
+ this.prevWeightUpdateMatrices = new DenseDoubleMatrix[this.numberOfLayers - 1];
+ for (int i = 0; i < this.prevWeightUpdateMatrices.length; ++i) {
+ int row = this.layerSizeArray[i] + 1;
+ int col = this.layerSizeArray[i + 1];
+ this.prevWeightUpdateMatrices[i] = new DenseDoubleMatrix(row, col);
+ }
+ }
+
@Override
/**
* {@inheritDoc}
* The model meta-data is stored in memory.
*/
- public DoubleVector output(DoubleVector featureVector) throws Exception {
+ public DoubleVector output(DoubleVector featureVector) {
List<double[]> outputCache = this.outputInternal(featureVector);
// the output of the last layer is the output of the MLP
return new DenseDoubleVector(outputCache.get(outputCache.size() - 1));
}
- private List<double[]> outputInternal(DoubleVector featureVector)
- throws Exception {
-
+ private List<double[]> outputInternal(DoubleVector featureVector) {
// store the output of the hidden layers and output layer, each array store
// one layer
List<double[]> outputCache = new ArrayList<double[]>();
@@ -134,7 +148,7 @@ public final class SmallMultiLayerPercep
// start from the first hidden layer
double[] intermediateResults = new double[this.layerSizeArray[0] + 1];
if (intermediateResults.length - 1 != featureVector.getDimension()) {
- throw new Exception(
+ throw new IllegalStateException(
"Input feature dimension incorrect! The dimension of input layer is "
+ (this.layerSizeArray[0] - 1)
+ ", but the dimension of input feature is "
@@ -227,17 +241,31 @@ public final class SmallMultiLayerPercep
double[] outputLayerOutput = outputCache.get(outputCache.size() - 1);
double[] lastHiddenLayerOutput = outputCache.get(outputCache.size() - 2);
+ DenseDoubleMatrix prevWeightUpdateMatrix = this.prevWeightUpdateMatrices[this.prevWeightUpdateMatrices.length - 1];
for (int j = 0; j < delta.length; ++j) {
- delta[j] = this.squashingFunction
- .calculateDerivative(outputLayerOutput[j])
- * this.costFunction.calculateDerivative(trainingLabels[j],
- outputLayerOutput[j]);
+ delta[j] = this.costFunction.calculateDerivative(trainingLabels[j],
+ outputLayerOutput[j]);
+ // add regularization term
+ if (this.regularization != 0.0) {
+ double derivativeRegularization = 0.0;
+ DenseDoubleMatrix weightMatrix = this.weightMatrice[this.weightMatrice.length - 1];
+ for (int k = 0; k < this.layerSizeArray[this.layerSizeArray.length - 1]; ++k) {
+ derivativeRegularization += weightMatrix.get(k, j);
+ }
+ derivativeRegularization /= this.layerSizeArray[this.layerSizeArray.length - 1];
+ delta[j] += this.regularization * derivativeRegularization;
+ }
+
+ delta[j] *= this.squashingFunction
+ .calculateDerivative(outputLayerOutput[j]);
// calculate the weight update matrix between the last hidden layer and
// the output layer
for (int i = 0; i < this.layerSizeArray[this.layerSizeArray.length - 2] + 1; ++i) {
- double updatedValue = this.learningRate * delta[j]
+ double updatedValue = -this.learningRate * delta[j]
* lastHiddenLayerOutput[i];
+ // add momentum
+ updatedValue += this.momentum * prevWeightUpdateMatrix.get(i, j);
weightUpdateMatrices[weightUpdateMatrices.length - 1].set(i, j,
updatedValue);
}
@@ -270,6 +298,7 @@ public final class SmallMultiLayerPercep
double[] curLayerOutput = outputCache.get(curLayerIdx);
double[] prevLayerOutput = outputCache.get(prevLayerIdx);
+ DenseDoubleMatrix prevWeightUpdateMatrix = this.prevWeightUpdateMatrices[curLayerIdx - 1];
// for each neuron j in nextLayer, calculate the delta
for (int j = 0; j < delta.length; ++j) {
// aggregate delta from next layer
@@ -283,7 +312,10 @@ public final class SmallMultiLayerPercep
// calculate the weight update matrix between the previous layer and the
// current layer
for (int i = 0; i < weightUpdateMatrices[prevLayerIdx].getRowCount(); ++i) {
- double updatedValue = this.learningRate * delta[j] * prevLayerOutput[i];
+ double updatedValue = -this.learningRate * delta[j]
+ * prevLayerOutput[i];
+ // add momentum
+ updatedValue += this.momentum * prevWeightUpdateMatrix.get(i, j);
weightUpdateMatrices[prevLayerIdx].set(i, j, updatedValue);
}
}
@@ -349,7 +381,7 @@ public final class SmallMultiLayerPercep
public void readFields(DataInput input) throws IOException {
this.MLPType = WritableUtils.readString(input);
this.learningRate = input.readDouble();
- this.regularization = input.readBoolean();
+ this.regularization = input.readDouble();
this.momentum = input.readDouble();
this.numberOfLayers = input.readInt();
this.squashingFunctionName = WritableUtils.readString(input);
@@ -373,7 +405,7 @@ public final class SmallMultiLayerPercep
public void write(DataOutput output) throws IOException {
WritableUtils.writeString(output, MLPType);
output.writeDouble(learningRate);
- output.writeBoolean(regularization);
+ output.writeDouble(regularization);
output.writeDouble(momentum);
output.writeInt(numberOfLayers);
WritableUtils.writeString(output, squashingFunctionName);
@@ -402,6 +434,11 @@ public final class SmallMultiLayerPercep
FileSystem fs = FileSystem.get(uri, conf);
FSDataInputStream is = new FSDataInputStream(fs.open(new Path(modelPath)));
this.readFields(is);
+ if (!this.MLPType.equals(this.getClass().getName())) {
+ throw new IllegalStateException(String.format(
+ "Model type incorrect, cannot load model '%s' for '%s'.",
+ this.MLPType, this.getClass().getName()));
+ }
} catch (URISyntaxException e) {
e.printStackTrace();
}
@@ -425,10 +462,19 @@ public final class SmallMultiLayerPercep
return this.weightMatrice;
}
+ DenseDoubleMatrix[] getPrevWeightUpdateMatrices() {
+ return this.prevWeightUpdateMatrices;
+ }
+
void setWeightMatrices(DenseDoubleMatrix[] newMatrices) {
this.weightMatrice = newMatrices;
}
+ void setPrevWeightUpdateMatrices(
+ DenseDoubleMatrix[] newPrevWeightUpdateMatrices) {
+ this.prevWeightUpdateMatrices = newPrevWeightUpdateMatrices;
+ }
+
/**
* Update the weight matrices with given updates.
*
@@ -462,4 +508,9 @@ public final class SmallMultiLayerPercep
return sb.toString();
}
+ @Override
+ protected String getTypeName() {
+ return this.getClass().getName();
+ }
+
}
Modified: hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/SquaredError.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/SquaredError.java?rev=1491624&r1=1491623&r2=1491624&view=diff
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/SquaredError.java (original)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/SquaredError.java Mon Jun 10 22:11:06 2013
@@ -40,7 +40,8 @@ public class SquaredError extends CostFu
* {@inheritDoc}
*/
public double calculateDerivative(double target, double actual) {
- return target - actual;
+ // return target - actual;
+ return actual - target;
}
}
Modified: hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/SquashingFunctionFactory.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/SquashingFunctionFactory.java?rev=1491624&r1=1491623&r2=1491624&view=diff
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/SquashingFunctionFactory.java (original)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/SquashingFunctionFactory.java Mon Jun 10 22:11:06 2013
@@ -36,7 +36,8 @@ public class SquashingFunctionFactory {
} else if (name.equalsIgnoreCase("Tanh")) {
return new Tanh();
}
- return new Sigmoid();
+ throw new IllegalStateException(String.format(
+ "No squashing function with name '%s' found.", name));
}
}
Modified: hama/trunk/ml/src/test/java/org/apache/hama/ml/perception/TestSmallMLPMessage.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/test/java/org/apache/hama/ml/perception/TestSmallMLPMessage.java?rev=1491624&r1=1491623&r2=1491624&view=diff
==============================================================================
--- hama/trunk/ml/src/test/java/org/apache/hama/ml/perception/TestSmallMLPMessage.java (original)
+++ hama/trunk/ml/src/test/java/org/apache/hama/ml/perception/TestSmallMLPMessage.java Mon Jun 10 22:11:06 2013
@@ -1,4 +1,3 @@
-
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
@@ -33,7 +32,6 @@ import org.apache.hadoop.fs.Path;
import org.apache.hama.ml.math.DenseDoubleMatrix;
import org.junit.Test;
-
/**
* Test the functionalities of SmallMLPMessage
*
@@ -41,12 +39,10 @@ import org.junit.Test;
public class TestSmallMLPMessage {
@Test
- public void testReadWrite() {
+ public void testReadWriteWithoutPrevUpdate() {
int owner = 101;
double[][] mat = { { 1, 2, 3 }, { 4, 5, 6 }, { 7, 8, 9 } };
-
double[][] mat2 = { { 10, 20 }, { 30, 40 }, { 50, 60 } };
-
double[][][] mats = { mat, mat2 };
DenseDoubleMatrix[] matrices = new DenseDoubleMatrix[] {
@@ -68,11 +64,10 @@ public class TestSmallMLPMessage {
outMessage.readFields(in);
assertEquals(owner, outMessage.getOwner());
- DenseDoubleMatrix[] outMatrices = outMessage.getWeightsUpdatedMatrices();
+ DenseDoubleMatrix[] outMatrices = outMessage.getWeightUpdatedMatrices();
// check each matrix
for (int i = 0; i < outMatrices.length; ++i) {
- double[][] outMat = outMessage.getWeightsUpdatedMatrices()[i]
- .getValues();
+ double[][] outMat = outMatrices[i].getValues();
for (int j = 0; j < outMat.length; ++j) {
assertArrayEquals(mats[i][j], outMat[j], 0.0001);
}
@@ -84,6 +79,69 @@ public class TestSmallMLPMessage {
} catch (URISyntaxException e) {
e.printStackTrace();
}
+ }
+
+ @Test
+ public void testReadWriteWithPrevUpdate() {
+ int owner = 101;
+ double[][] mat = { { 1, 2, 3 }, { 4, 5, 6 }, { 7, 8, 9 } };
+ double[][] mat2 = { { 10, 20 }, { 30, 40 }, { 50, 60 } };
+ double[][][] mats = { mat, mat2 };
+
+ double[][] prevMat = { { 0.1, 0.2, 0.3 }, { 0.4, 0.5, 0.6 },
+ { 0.7, 0.8, 0.9 } };
+ double[][] prevMat2 = { { 1, 2 }, { 3, 4 }, { 5, 6 } };
+ double[][][] prevMats = { prevMat, prevMat2 };
+
+ DenseDoubleMatrix[] matrices = new DenseDoubleMatrix[] {
+ new DenseDoubleMatrix(mat), new DenseDoubleMatrix(mat2) };
+
+ DenseDoubleMatrix[] prevMatrices = new DenseDoubleMatrix[] {
+ new DenseDoubleMatrix(prevMat), new DenseDoubleMatrix(prevMat2) };
+
+ boolean terminated = false;
+ SmallMLPMessage message = new SmallMLPMessage(owner, terminated, matrices,
+ prevMatrices);
+
+ Configuration conf = new Configuration();
+ String strPath = "/tmp/testSmallMLPMessageWithPrevMatrices";
+ Path path = new Path(strPath);
+ try {
+ FileSystem fs = FileSystem.get(new URI(strPath), conf);
+ FSDataOutputStream out = fs.create(path, true);
+ message.write(out);
+ out.close();
+
+ FSDataInputStream in = fs.open(path);
+ SmallMLPMessage outMessage = new SmallMLPMessage(0, false, null);
+ outMessage.readFields(in);
+
+ assertEquals(owner, outMessage.getOwner());
+ assertEquals(terminated, outMessage.isTerminated());
+ DenseDoubleMatrix[] outMatrices = outMessage.getWeightUpdatedMatrices();
+ // check each matrix
+ for (int i = 0; i < outMatrices.length; ++i) {
+ double[][] outMat = outMatrices[i].getValues();
+ for (int j = 0; j < outMat.length; ++j) {
+ assertArrayEquals(mats[i][j], outMat[j], 0.0001);
+ }
+ }
+
+ DenseDoubleMatrix[] outPrevMatrices = outMessage
+ .getPrevWeightsUpdatedMatrices();
+ // check each matrix
+ for (int i = 0; i < outPrevMatrices.length; ++i) {
+ double[][] outMat = outPrevMatrices[i].getValues();
+ for (int j = 0; j < outMat.length; ++j) {
+ assertArrayEquals(prevMats[i][j], outMat[j], 0.0001);
+ }
+ }
+ fs.delete(path, true);
+ } catch (IOException e) {
+ e.printStackTrace();
+ } catch (URISyntaxException e) {
+ e.printStackTrace();
+ }
}
}
Modified: hama/trunk/ml/src/test/java/org/apache/hama/ml/perception/TestSmallMultiLayerPerceptron.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/test/java/org/apache/hama/ml/perception/TestSmallMultiLayerPerceptron.java?rev=1491624&r1=1491623&r2=1491624&view=diff
==============================================================================
--- hama/trunk/ml/src/test/java/org/apache/hama/ml/perception/TestSmallMultiLayerPerceptron.java (original)
+++ hama/trunk/ml/src/test/java/org/apache/hama/ml/perception/TestSmallMultiLayerPerceptron.java Mon Jun 10 22:11:06 2013
@@ -50,8 +50,8 @@ public class TestSmallMultiLayerPerceptr
@Test
public void testWriteReadMLP() {
String modelPath = "/tmp/sampleModel-testWriteReadMLP.data";
- double learningRate = 0.5;
- boolean regularization = false; // no regularization
+ double learningRate = 0.3;
+ double regularization = 0.0; // no regularization
double momentum = 0; // no momentum
String squashingFunctionName = "Sigmoid";
String costFunctionName = "SquaredError";
@@ -70,9 +70,9 @@ public class TestSmallMultiLayerPerceptr
Configuration conf = new Configuration();
FileSystem fs = FileSystem.get(conf);
mlp = new SmallMultiLayerPerceptron(modelPath);
- assertEquals("SmallMLP", mlp.getMLPType());
+ assertEquals(mlp.getClass().getName(), mlp.getMLPType());
assertEquals(learningRate, mlp.getLearningRate(), 0.001);
- assertEquals(regularization, mlp.isRegularization());
+ assertEquals(regularization, mlp.isRegularization(), 0.001);
assertEquals(layerSizeArray.length, mlp.getNumberOfLayers());
assertEquals(momentum, mlp.getMomentum(), 0.001);
assertEquals(squashingFunctionName, mlp.getSquashingFunctionName());
@@ -97,10 +97,10 @@ public class TestSmallMultiLayerPerceptr
FileSystem fs = FileSystem.get(conf);
FSDataOutputStream output = fs.create(new Path(modelPath), true);
- String MLPType = "SmallMLP";
+ String MLPType = SmallMultiLayerPerceptron.class.getName();
double learningRate = 0.5;
- boolean regularization = false;
- double momentum = 0;
+ double regularization = 0.0;
+ double momentum = 0.1;
String squashingFunctionName = "Sigmoid";
String costFunctionName = "SquaredError";
int[] layerSizeArray = new int[] { 3, 2, 3, 3 };
@@ -108,7 +108,7 @@ public class TestSmallMultiLayerPerceptr
WritableUtils.writeString(output, MLPType);
output.writeDouble(learningRate);
- output.writeBoolean(regularization);
+ output.writeDouble(regularization);
output.writeDouble(momentum);
output.writeInt(numberOfLayers);
WritableUtils.writeString(output, squashingFunctionName);
@@ -162,10 +162,10 @@ public class TestSmallMultiLayerPerceptr
}
/**
- * Test the MLP on XOR problem.
+ * Test training with squared error on the XOR problem.
*/
@Test
- public void testSingleInstanceTraining() {
+ public void testTrainWithSquaredError() {
// generate training data
DoubleVector[] trainingData = new DenseDoubleVector[] {
new DenseDoubleVector(new double[] { 0, 0, 0 }),
@@ -174,8 +174,8 @@ public class TestSmallMultiLayerPerceptr
new DenseDoubleVector(new double[] { 1, 1, 0 }) };
// set parameters
- double learningRate = 0.6;
- boolean regularization = false; // no regularization
+ double learningRate = 0.5;
+ double regularization = 0.02; // small non-zero regularization weight
double momentum = 0; // no momentum
String squashingFunctionName = "Sigmoid";
String costFunctionName = "SquaredError";
@@ -207,6 +207,142 @@ public class TestSmallMultiLayerPerceptr
}
/**
+ * Test training with cross entropy on the XOR problem.
+ */
+ @Test
+ public void testTrainWithCrossEntropy() {
+ // generate training data
+ DoubleVector[] trainingData = new DenseDoubleVector[] {
+ new DenseDoubleVector(new double[] { 0, 0, 0 }),
+ new DenseDoubleVector(new double[] { 0, 1, 1 }),
+ new DenseDoubleVector(new double[] { 1, 0, 1 }),
+ new DenseDoubleVector(new double[] { 1, 1, 0 }) };
+
+ // set parameters
+ double learningRate = 0.5;
+ double regularization = 0.0; // no regularization
+ double momentum = 0; // no momentum
+ String squashingFunctionName = "Sigmoid";
+ String costFunctionName = "CrossEntropy";
+ int[] layerSizeArray = new int[] { 2, 7, 1 };
+ SmallMultiLayerPerceptron mlp = new SmallMultiLayerPerceptron(learningRate,
+ regularization, momentum, squashingFunctionName, costFunctionName,
+ layerSizeArray);
+
+ try {
+ // train by multiple instances
+ Random rnd = new Random();
+ for (int i = 0; i < 20000; ++i) {
+ DenseDoubleMatrix[] weightUpdates = mlp
+ .trainByInstance(trainingData[rnd.nextInt(4)]);
+ mlp.updateWeightMatrices(weightUpdates);
+ }
+
+ // System.out.printf("Weight matrices: %s\n",
+ // mlp.weightsToString(mlp.getWeightMatrices()));
+ for (int i = 0; i < trainingData.length; ++i) {
+ DenseDoubleVector testVec = (DenseDoubleVector) trainingData[i]
+ .slice(2);
+ assertEquals(trainingData[i].toArray()[2], mlp.output(testVec)
+ .toArray()[0], 0.2);
+ }
+ } catch (Exception e) {
+ e.printStackTrace();
+ }
+ }
+
+ /**
+ * Test training with regularization.
+ */
+ @Test
+ public void testWithRegularization() {
+ // generate training data
+ DoubleVector[] trainingData = new DenseDoubleVector[] {
+ new DenseDoubleVector(new double[] { 0, 0, 0 }),
+ new DenseDoubleVector(new double[] { 0, 1, 1 }),
+ new DenseDoubleVector(new double[] { 1, 0, 1 }),
+ new DenseDoubleVector(new double[] { 1, 1, 0 }) };
+
+ // set parameters
+ double learningRate = 0.5;
+ double regularization = 0.02; // regularization should be a tiny number
+ double momentum = 0; // no momentum
+ String squashingFunctionName = "Sigmoid";
+ String costFunctionName = "CrossEntropy";
+ int[] layerSizeArray = new int[] { 2, 7, 1 };
+ SmallMultiLayerPerceptron mlp = new SmallMultiLayerPerceptron(learningRate,
+ regularization, momentum, squashingFunctionName, costFunctionName,
+ layerSizeArray);
+
+ try {
+ // train by multiple instances
+ Random rnd = new Random();
+ for (int i = 0; i < 10000; ++i) {
+ DenseDoubleMatrix[] weightUpdates = mlp
+ .trainByInstance(trainingData[rnd.nextInt(4)]);
+ mlp.updateWeightMatrices(weightUpdates);
+ }
+
+ // System.out.printf("Weight matrices: %s\n",
+ // mlp.weightsToString(mlp.getWeightMatrices()));
+ for (int i = 0; i < trainingData.length; ++i) {
+ DenseDoubleVector testVec = (DenseDoubleVector) trainingData[i]
+ .slice(2);
+ assertEquals(trainingData[i].toArray()[2], mlp.output(testVec)
+ .toArray()[0], 0.2);
+ }
+ } catch (Exception e) {
+ e.printStackTrace();
+ }
+ }
+
+ /**
+ * Test training with momentum.
+ * The MLP can converge faster.
+ */
+ @Test
+ public void testWithMomentum() {
+ // generate training data
+ DoubleVector[] trainingData = new DenseDoubleVector[] {
+ new DenseDoubleVector(new double[] { 0, 0, 0 }),
+ new DenseDoubleVector(new double[] { 0, 1, 1 }),
+ new DenseDoubleVector(new double[] { 1, 0, 1 }),
+ new DenseDoubleVector(new double[] { 1, 1, 0 }) };
+
+ // set parameters
+ double learningRate = 0.5;
+ double regularization = 0.02; // regularization should be a tiny number
+ double momentum = 0.5; // momentum enabled
+ String squashingFunctionName = "Sigmoid";
+ String costFunctionName = "CrossEntropy";
+ int[] layerSizeArray = new int[] { 2, 7, 1 };
+ SmallMultiLayerPerceptron mlp = new SmallMultiLayerPerceptron(learningRate,
+ regularization, momentum, squashingFunctionName, costFunctionName,
+ layerSizeArray);
+
+ try {
+ // train by multiple instances
+ Random rnd = new Random();
+ for (int i = 0; i < 3000; ++i) {
+ DenseDoubleMatrix[] weightUpdates = mlp
+ .trainByInstance(trainingData[rnd.nextInt(4)]);
+ mlp.updateWeightMatrices(weightUpdates);
+ }
+
+ // System.out.printf("Weight matrices: %s\n",
+ // mlp.weightsToString(mlp.getWeightMatrices()));
+ for (int i = 0; i < trainingData.length; ++i) {
+ DenseDoubleVector testVec = (DenseDoubleVector) trainingData[i]
+ .slice(2);
+ assertEquals(trainingData[i].toArray()[2], mlp.output(testVec)
+ .toArray()[0], 0.2);
+ }
+ } catch (Exception e) {
+ e.printStackTrace();
+ }
+ }
+
+ /**
* Test the XOR problem.
*/
@Test
@@ -246,8 +382,8 @@ public class TestSmallMultiLayerPerceptr
// begin training
String modelPath = "/tmp/xorModel-training-by-xor.data";
double learningRate = 0.6;
- boolean regularization = false; // no regularization
- double momentum = 0; // no momentum
+ double regularization = 0.02; // small non-zero regularization weight
+ double momentum = 0.3; // momentum enabled
String squashingFunctionName = "Tanh";
String costFunctionName = "SquaredError";
int[] layerSizeArray = new int[] { 2, 5, 1 };
@@ -256,7 +392,7 @@ public class TestSmallMultiLayerPerceptr
layerSizeArray);
Map<String, String> trainingParams = new HashMap<String, String>();
- trainingParams.put("training.iteration", "10000");
+ trainingParams.put("training.iteration", "1000");
trainingParams.put("training.mode", "minibatch.gradient.descent");
trainingParams.put("training.batch.size", "100");
trainingParams.put("tasks", "3");