Posted to dev@horn.apache.org by ed...@apache.org on 2016/05/26 23:09:33 UTC
[2/2] incubator-horn git commit: HORN-26: Double to float as a default type

HORN-26: Double to float as a default type
Project: http://git-wip-us.apache.org/repos/asf/incubator-horn/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-horn/commit/af88df41
Tree: http://git-wip-us.apache.org/repos/asf/incubator-horn/tree/af88df41
Diff: http://git-wip-us.apache.org/repos/asf/incubator-horn/diff/af88df41
Branch: refs/heads/master
Commit: af88df41bd6f70cef77b32e1a82f014255bab66f
Parents: b538634
Author: Edward J. Yoon <ed...@apache.org>
Authored: Wed May 25 13:51:01 2016 +0900
Committer: Edward J. Yoon <ed...@apache.org>
Committed: Fri May 27 08:06:49 2016 +0900
----------------------------------------------------------------------
README.md | 18 +-
bin/horn | 39 +-
conf/horn-env.sh | 3 +-
pom.xml | 4 +-
.../horn/core/AbstractLayeredNeuralNetwork.java | 54 +-
.../apache/horn/core/AbstractNeuralNetwork.java | 43 +-
.../horn/core/AbstractNeuralNetworkTrainer.java | 6 +-
.../java/org/apache/horn/core/AutoEncoder.java | 41 +-
.../horn/core/FloatFeatureTransformer.java | 17 +
src/main/java/org/apache/horn/core/HornJob.java | 12 +-
.../org/apache/horn/core/LayerInterface.java | 4 +-
.../apache/horn/core/LayeredNeuralNetwork.java | 172 +++--
.../horn/core/LayeredNeuralNetworkTrainer.java | 83 +--
src/main/java/org/apache/horn/core/Neuron.java | 68 +-
.../apache/horn/core/ParameterMergerServer.java | 10 +-
.../org/apache/horn/core/ParameterMessage.java | 48 +-
src/main/java/org/apache/horn/core/Synapse.java | 22 +-
.../horn/examples/MultiLayerPerceptron.java | 28 +-
.../horn/funcs/CategoricalCrossEntropy.java | 12 +-
.../org/apache/horn/funcs/CrossEntropy.java | 34 +-
.../org/apache/horn/funcs/FunctionFactory.java | 8 +-
.../org/apache/horn/funcs/IdentityFunction.java | 8 +-
src/main/java/org/apache/horn/funcs/ReLU.java | 14 +-
.../java/org/apache/horn/funcs/Sigmoid.java | 19 +-
.../java/org/apache/horn/funcs/SoftMax.java | 23 +-
.../org/apache/horn/funcs/SquaredError.java | 12 +-
src/main/java/org/apache/horn/funcs/Tanh.java | 12 +-
.../org/apache/horn/utils/MNISTConverter.java | 18 +-
.../org/apache/horn/utils/MNISTEvaluator.java | 14 +-
.../java/org/apache/horn/core/MLTestBase.java | 16 +-
.../org/apache/horn/core/TestAutoEncoder.java | 57 +-
.../java/org/apache/horn/core/TestNeuron.java | 44 +-
.../core/TestSmallLayeredNeuralNetwork.java | 658 -------------------
.../TestSmallLayeredNeuralNetworkMessage.java | 173 -----
.../horn/examples/MultiLayerPerceptronTest.java | 32 +-
35 files changed, 486 insertions(+), 1340 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-horn/blob/af88df41/README.md
----------------------------------------------------------------------
diff --git a/README.md b/README.md
index 5fd125e..4c9ec6d 100644
--- a/README.md
+++ b/README.md
@@ -8,10 +8,10 @@ Apache Horn provides a neuron-centric programming model for implementing the neu
```Java
@Override
public void forward(
- Iterable<Synapse<DoubleWritable, DoubleWritable>> messages)
+ Iterable<Synapse<FloatWritable, FloatWritable>> messages)
throws IOException {
- double sum = 0;
- for (Synapse<DoubleWritable, DoubleWritable> m : messages) {
+ float sum = 0;
+ for (Synapse<FloatWritable, FloatWritable> m : messages) {
sum += m.getInput() * m.getWeight();
}
this.feedforward(this.squashingFunction.apply(sum));
@@ -21,15 +21,15 @@ Then, we measure the margin of error of the output and adjust the weights accord
```Java
@Override
public void backward(
- Iterable<Synapse<DoubleWritable, DoubleWritable>> messages)
+ Iterable<Synapse<FloatWritable, FloatWritable>> messages)
throws IOException {
- double gradient = 0;
- for (Synapse<DoubleWritable, DoubleWritable> m : messages) {
+ float gradient = 0;
+ for (Synapse<FloatWritable, FloatWritable> m : messages) {
// Calculates error gradient for each neuron
- double gradient += (m.getDelta() * m.getWeight());
+ gradient += (m.getDelta() * m.getWeight());
// Weight corrections
- double weight = -this.getLearningRate() * this.getOutput()
+ float weight = -this.getLearningRate() * this.getOutput()
* m.getDelta() + this.getMomentumWeight() * m.getPrevWeight();
this.push(weight);
}
@@ -68,7 +68,7 @@ Then, train it with the following command (in this example, we used η 0.01, α 0.9,
0.01 0.9 0.0005 784 100 10 10 12000
```
-With this default example, you'll reach over the 95% accuracy. The local-mode parallel synchronous SGD based on multithreading will took around 30 mins ~ 1 hour to train.
+With this default example, you'll reach over 95% accuracy. In local mode, 6 tasks train the model in a synchronous parallel fashion, taking around 30 minutes.
## High Scalability
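Putting the two README snippets together, here is a minimal sketch of a complete neuron under the new float-typed API. The class name, the `extends` clause, and the imports are assumptions inferred from the signatures in this diff, not code from the commit:

```Java
import java.io.IOException;

import org.apache.hadoop.io.FloatWritable;
import org.apache.horn.core.Neuron;
import org.apache.horn.core.Synapse;

// Illustrative only: a neuron assembled from the README's
// forward/backward snippets above.
public class ExampleNeuron extends Neuron<Synapse<FloatWritable, FloatWritable>> {

  @Override
  public void forward(Iterable<Synapse<FloatWritable, FloatWritable>> messages)
      throws IOException {
    float sum = 0;
    // weighted sum of the incoming float-typed messages
    for (Synapse<FloatWritable, FloatWritable> m : messages) {
      sum += m.getInput() * m.getWeight();
    }
    this.feedforward(this.squashingFunction.apply(sum));
  }

  @Override
  public void backward(Iterable<Synapse<FloatWritable, FloatWritable>> messages)
      throws IOException {
    float gradient = 0;
    for (Synapse<FloatWritable, FloatWritable> m : messages) {
      // accumulate the error gradient for this neuron
      gradient += (m.getDelta() * m.getWeight());
      // weight correction using learning rate and momentum
      float weight = -this.getLearningRate() * this.getOutput()
          * m.getDelta() + this.getMomentumWeight() * m.getPrevWeight();
      this.push(weight);
    }
  }
}
```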
http://git-wip-us.apache.org/repos/asf/incubator-horn/blob/af88df41/bin/horn
----------------------------------------------------------------------
diff --git a/bin/horn b/bin/horn
index 8cbd106..d2cc6c5 100755
--- a/bin/horn
+++ b/bin/horn
@@ -72,43 +72,8 @@ CLASSPATH="${HORN_CONF_DIR}"
CLASSPATH=${CLASSPATH}:$JAVA_HOME/lib/tools.jar
# for developers, add Horn classes to CLASSPATH
-if [ -d "$HORN_HOME/core/target/classes" ]; then
- CLASSPATH=${CLASSPATH}:$HORN_HOME/core/target/classes
-fi
-if [ -d "$HORN_HOME/core/target/test-classes/classes" ]; then
- CLASSPATH=${CLASSPATH}:$HORN_HOME/core/target/test-classes
-fi
-
-# for developers, add Commons classes to CLASSPATH
-if [ -d "$HORN_HOME/commons/target/classes" ]; then
- CLASSPATH=${CLASSPATH}:$HORN_HOME/commons/target/classes
-fi
-if [ -d "$HORN_HOME/commons/target/test-classes/classes" ]; then
- CLASSPATH=${CLASSPATH}:$HORN_HOME/commons/target/test-classes
-fi
-
-# for developers, add Graph classes to CLASSPATH
-if [ -d "$HORN_HOME/graph/target/classes" ]; then
- CLASSPATH=${CLASSPATH}:$HORN_HOME/graph/target/classes
-fi
-if [ -d "$HORN_HOME/graph/target/test-classes/classes" ]; then
- CLASSPATH=${CLASSPATH}:$HORN_HOME/graph/target/test-classes
-fi
-
-# for developers, add ML classes to CLASSPATH
-if [ -d "$HORN_HOME/ml/target/classes" ]; then
- CLASSPATH=${CLASSPATH}:$HORN_HOME/ml/target/classes
-fi
-if [ -d "$HORN_HOME/ml/target/test-classes/classes" ]; then
- CLASSPATH=${CLASSPATH}:$HORN_HOME/ml/target/test-classes
-fi
-
-# add mesos classes to CLASSPATH
-if [ -d "$HORN_HOME/mesos/target/classes" ]; then
- CLASSPATH=${CLASSPATH}:$HORN_HOME/mesos/target/classes
-fi
-if [ -d "$HORN_HOME/mesos/target/test-classes/classes" ]; then
- CLASSPATH=${CLASSPATH}:$HORN_HOME/mesos/target/test-classes
+if [ -d "$HORN_HOME/target/classes" ]; then
+ CLASSPATH=${CLASSPATH}:$HORN_HOME/target/classes
fi
# so that filenames w/ spaces are handled correctly in loops below
http://git-wip-us.apache.org/repos/asf/incubator-horn/blob/af88df41/conf/horn-env.sh
----------------------------------------------------------------------
diff --git a/conf/horn-env.sh b/conf/horn-env.sh
index 26d190f..c60c2aa 100644
--- a/conf/horn-env.sh
+++ b/conf/horn-env.sh
@@ -22,5 +22,4 @@
# Set environment variables here.
# The java implementation to use. Required.
-export JAVA_HOME=/Library/Java/JavaVirtualMachines/jdk1.8.0_25.jdk/Contents/Home/
-
+export JAVA_HOME=/usr/lib/jvm/java-8-oracle
http://git-wip-us.apache.org/repos/asf/incubator-horn/blob/af88df41/pom.xml
----------------------------------------------------------------------
diff --git a/pom.xml b/pom.xml
index e7da3aa..cd00794 100644
--- a/pom.xml
+++ b/pom.xml
@@ -211,6 +211,7 @@
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-dependency-plugin</artifactId>
+ <version>2.1</version>
<executions>
<execution>
<id>copy-dependencies</id>
@@ -221,7 +222,8 @@
<configuration>
<outputDirectory>${project.basedir}/lib</outputDirectory>
<overWriteReleases>false</overWriteReleases>
- <overWriteSnapshots>true</overWriteSnapshots>
+ <overWriteSnapshots>false</overWriteSnapshots>
+ <overWriteIfNewer>true</overWriteIfNewer>
<excludeGroupIds>org.apache.horn</excludeGroupIds>
http://git-wip-us.apache.org/repos/asf/incubator-horn/blob/af88df41/src/main/java/org/apache/horn/core/AbstractLayeredNeuralNetwork.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/horn/core/AbstractLayeredNeuralNetwork.java b/src/main/java/org/apache/horn/core/AbstractLayeredNeuralNetwork.java
index 5ec57a2..4d1ea52 100644
--- a/src/main/java/org/apache/horn/core/AbstractLayeredNeuralNetwork.java
+++ b/src/main/java/org/apache/horn/core/AbstractLayeredNeuralNetwork.java
@@ -24,10 +24,10 @@ import java.util.List;
import org.apache.hadoop.io.WritableUtils;
import org.apache.hama.HamaConfiguration;
-import org.apache.hama.commons.math.DoubleDoubleFunction;
-import org.apache.hama.commons.math.DoubleFunction;
-import org.apache.hama.commons.math.DoubleMatrix;
-import org.apache.hama.commons.math.DoubleVector;
+import org.apache.hama.commons.math.FloatFloatFunction;
+import org.apache.hama.commons.math.FloatFunction;
+import org.apache.hama.commons.math.FloatMatrix;
+import org.apache.hama.commons.math.FloatVector;
import org.apache.horn.core.Constants.LearningStyle;
import org.apache.horn.core.Constants.TrainingMethod;
import org.apache.horn.funcs.CategoricalCrossEntropy;
@@ -49,19 +49,19 @@ import com.google.common.collect.Lists;
*/
abstract class AbstractLayeredNeuralNetwork extends AbstractNeuralNetwork {
- private static final double DEFAULT_REGULARIZATION_WEIGHT = 0;
- private static final double DEFAULT_MOMENTUM_WEIGHT = 0.1;
+ private static final float DEFAULT_REGULARIZATION_WEIGHT = 0;
+ private static final float DEFAULT_MOMENTUM_WEIGHT = 0.1f;
- double trainingError;
+ float trainingError;
/* The weight of regularization */
- protected double regularizationWeight;
+ protected float regularizationWeight;
/* The momentumWeight */
- protected double momentumWeight;
+ protected float momentumWeight;
/* The cost function of the model */
- protected DoubleDoubleFunction costFunction;
+ protected FloatFloatFunction costFunction;
/* Record the size of each layer */
protected List<Integer> layerSizeList;
@@ -92,14 +92,14 @@ abstract class AbstractLayeredNeuralNetwork extends AbstractNeuralNetwork {
*
* @param regularizationWeight
*/
- public void setRegularizationWeight(double regularizationWeight) {
+ public void setRegularizationWeight(float regularizationWeight) {
Preconditions.checkArgument(regularizationWeight >= 0
&& regularizationWeight < 1.0,
"Regularization weight must be in range [0, 1.0)");
this.regularizationWeight = regularizationWeight;
}
- public double getRegularizationWeight() {
+ public float getRegularizationWeight() {
return this.regularizationWeight;
}
@@ -108,13 +108,13 @@ abstract class AbstractLayeredNeuralNetwork extends AbstractNeuralNetwork {
*
* @param momentumWeight
*/
- public void setMomemtumWeight(double momentumWeight) {
+ public void setMomemtumWeight(float momentumWeight) {
Preconditions.checkArgument(momentumWeight >= 0 && momentumWeight <= 1.0,
"Momentum weight must be in range [0, 1.0]");
this.momentumWeight = momentumWeight;
}
- public double getMomemtumWeight() {
+ public float getMomemtumWeight() {
return this.momentumWeight;
}
@@ -139,7 +139,7 @@ abstract class AbstractLayeredNeuralNetwork extends AbstractNeuralNetwork {
*
* @param costFunction
*/
- public void setCostFunction(DoubleDoubleFunction costFunction) {
+ public void setCostFunction(FloatFloatFunction costFunction) {
this.costFunction = costFunction;
}
@@ -155,7 +155,7 @@ abstract class AbstractLayeredNeuralNetwork extends AbstractNeuralNetwork {
* @return The layer index, starts with 0.
*/
public abstract int addLayer(int size, boolean isFinalLayer,
- DoubleFunction squashingFunction, Class<? extends Neuron> neuronClass);
+ FloatFunction squashingFunction, Class<? extends Neuron> neuronClass);
/**
* Get the size of a particular layer.
@@ -184,9 +184,9 @@ abstract class AbstractLayeredNeuralNetwork extends AbstractNeuralNetwork {
* Get the weights between layer layerIdx and layerIdx + 1
*
* @param layerIdx The index of the layer
- * @return The weights in form of {@link DoubleMatrix}
+ * @return The weights in form of {@link FloatMatrix}
*/
- public abstract DoubleMatrix getWeightsByLayer(int layerIdx);
+ public abstract FloatMatrix getWeightsByLayer(int layerIdx);
/**
* Get the updated weights using one training instance.
@@ -196,7 +196,7 @@ abstract class AbstractLayeredNeuralNetwork extends AbstractNeuralNetwork {
* @return The update of each weight, in form of matrix list.
* @throws Exception
*/
- public abstract DoubleMatrix[] trainByInstance(DoubleVector trainingInstance);
+ public abstract FloatMatrix[] trainByInstance(FloatVector trainingInstance);
/**
* Get the output calculated by the model.
@@ -204,7 +204,7 @@ abstract class AbstractLayeredNeuralNetwork extends AbstractNeuralNetwork {
* @param instance The feature instance.
* @return a new vector with the result of the operation.
*/
- public abstract DoubleVector getOutput(DoubleVector instance);
+ public abstract FloatVector getOutput(FloatVector instance);
/**
* Calculate the training error based on the labels and outputs.
@@ -212,20 +212,20 @@ abstract class AbstractLayeredNeuralNetwork extends AbstractNeuralNetwork {
* @param labels
* @param output
*/
- protected abstract void calculateTrainingError(DoubleVector labels,
- DoubleVector output);
+ protected abstract void calculateTrainingError(FloatVector labels,
+ FloatVector output);
@Override
public void readFields(DataInput input) throws IOException {
super.readFields(input);
// read regularization weight
- this.regularizationWeight = input.readDouble();
+ this.regularizationWeight = input.readFloat();
// read momentum weight
- this.momentumWeight = input.readDouble();
+ this.momentumWeight = input.readFloat();
// read cost function
this.costFunction = FunctionFactory
- .createDoubleDoubleFunction(WritableUtils.readString(input));
+ .createFloatFloatFunction(WritableUtils.readString(input));
// read layer size list
int numLayers = input.readInt();
@@ -242,9 +242,9 @@ abstract class AbstractLayeredNeuralNetwork extends AbstractNeuralNetwork {
public void write(DataOutput output) throws IOException {
super.write(output);
// write regularization weight
- output.writeDouble(this.regularizationWeight);
+ output.writeFloat(this.regularizationWeight);
// write momentum weight
- output.writeDouble(this.momentumWeight);
+ output.writeFloat(this.momentumWeight);
// write cost function
WritableUtils.writeString(output, costFunction.getFunctionName());
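Note that `write()` and `readFields()` above must mirror each other field-for-field, so models serialized by the old double-based code cannot be read back by the new float-based code. A stand-alone sketch of the width mismatch, using plain `java.io` rather than commit code:

```Java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;

public class WidthMismatch {
  public static void main(String[] args) throws IOException {
    ByteArrayOutputStream bos = new ByteArrayOutputStream();
    DataOutputStream out = new DataOutputStream(bos);
    out.writeDouble(0.1); // old format: 8 bytes per value

    DataInputStream in = new DataInputStream(
        new ByteArrayInputStream(bos.toByteArray()));
    // new format reads only 4 bytes and misinterprets the bit
    // pattern: prints roughly 1.45, not 0.1
    System.out.println(in.readFloat());
  }
}
```

Existing serialized models therefore need to be retrained (or converted) after this change.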
http://git-wip-us.apache.org/repos/asf/incubator-horn/blob/af88df41/src/main/java/org/apache/horn/core/AbstractNeuralNetwork.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/horn/core/AbstractNeuralNetwork.java b/src/main/java/org/apache/horn/core/AbstractNeuralNetwork.java
index 77d6af0..64d5945 100644
--- a/src/main/java/org/apache/horn/core/AbstractNeuralNetwork.java
+++ b/src/main/java/org/apache/horn/core/AbstractNeuralNetwork.java
@@ -35,8 +35,6 @@ import org.apache.hadoop.io.Writable;
import org.apache.hadoop.io.WritableUtils;
import org.apache.hama.HamaConfiguration;
import org.apache.hama.bsp.BSPJob;
-import org.apache.hama.ml.util.DefaultFeatureTransformer;
-import org.apache.hama.ml.util.FeatureTransformer;
import com.google.common.base.Preconditions;
import com.google.common.io.Closeables;
@@ -52,10 +50,10 @@ public abstract class AbstractNeuralNetwork implements Writable {
protected HamaConfiguration conf;
protected FileSystem fs;
-
- private static final double DEFAULT_LEARNING_RATE = 0.5;
- protected double learningRate;
+ private static final float DEFAULT_LEARNING_RATE = 0.5f;
+
+ protected float learningRate;
protected boolean learningRateDecay = false;
// the name of the model
@@ -63,12 +61,12 @@ public abstract class AbstractNeuralNetwork implements Writable {
// the path to store the model
protected String modelPath;
- protected FeatureTransformer featureTransformer;
+ protected FloatFeatureTransformer featureTransformer;
public AbstractNeuralNetwork() {
this.learningRate = DEFAULT_LEARNING_RATE;
this.modelType = this.getClass().getSimpleName();
- this.featureTransformer = new DefaultFeatureTransformer();
+ this.featureTransformer = new FloatFeatureTransformer();
}
public AbstractNeuralNetwork(HamaConfiguration conf, String modelPath) {
@@ -88,13 +86,13 @@ public abstract class AbstractNeuralNetwork implements Writable {
*
* @param learningRate
*/
- public void setLearningRate(double learningRate) {
+ public void setLearningRate(float learningRate) {
Preconditions.checkArgument(learningRate > 0,
"Learning rate must be larger than 0.");
this.learningRate = learningRate;
}
- public double getLearningRate() {
+ public float getLearningRate() {
return this.learningRate;
}
@@ -111,15 +109,16 @@ public abstract class AbstractNeuralNetwork implements Writable {
*
* @param dataInputPath The path of the training data.
* @param trainingParams The parameters for training.
- * @throws InterruptedException
- * @throws ClassNotFoundException
+ * @throws InterruptedException
+ * @throws ClassNotFoundException
* @throws IOException
*/
- public BSPJob train(HamaConfiguration conf) throws ClassNotFoundException, IOException, InterruptedException {
+ public BSPJob train(HamaConfiguration conf) throws ClassNotFoundException,
+ IOException, InterruptedException {
Preconditions.checkArgument(this.modelPath != null,
"Please set the model path before training.");
// train with BSP job
- return trainInternal(conf);
+ return trainInternal(conf);
}
/**
@@ -128,8 +127,8 @@ public abstract class AbstractNeuralNetwork implements Writable {
* @param dataInputPath
* @param trainingParams
*/
- protected abstract BSPJob trainInternal(HamaConfiguration conf) throws IOException,
- InterruptedException, ClassNotFoundException;
+ protected abstract BSPJob trainInternal(HamaConfiguration conf)
+ throws IOException, InterruptedException, ClassNotFoundException;
/**
* Read the model meta-data from the specified location.
@@ -199,7 +198,7 @@ public abstract class AbstractNeuralNetwork implements Writable {
// read model type
this.modelType = WritableUtils.readString(input);
// read learning rate
- this.learningRate = input.readDouble();
+ this.learningRate = input.readFloat();
// read model path
this.modelPath = WritableUtils.readString(input);
@@ -214,7 +213,7 @@ public abstract class AbstractNeuralNetwork implements Writable {
featureTransformerBytes[i] = input.readByte();
}
- Class<? extends FeatureTransformer> featureTransformerCls = (Class<? extends FeatureTransformer>) SerializationUtils
+ Class<? extends FloatFeatureTransformer> featureTransformerCls = (Class<? extends FloatFeatureTransformer>) SerializationUtils
.deserialize(featureTransformerBytes);
Constructor[] constructors = featureTransformerCls
@@ -222,7 +221,7 @@ public abstract class AbstractNeuralNetwork implements Writable {
Constructor constructor = constructors[0];
try {
- this.featureTransformer = (FeatureTransformer) constructor
+ this.featureTransformer = (FloatFeatureTransformer) constructor
.newInstance(new Object[] {});
} catch (InstantiationException e) {
e.printStackTrace();
@@ -240,7 +239,7 @@ public abstract class AbstractNeuralNetwork implements Writable {
// write model type
WritableUtils.writeString(output, modelType);
// write learning rate
- output.writeDouble(learningRate);
+ output.writeFloat(learningRate);
// write model path
if (this.modelPath != null) {
WritableUtils.writeString(output, modelPath);
@@ -249,7 +248,7 @@ public abstract class AbstractNeuralNetwork implements Writable {
}
// serialize the class
- Class<? extends FeatureTransformer> featureTransformerCls = this.featureTransformer
+ Class<? extends FloatFeatureTransformer> featureTransformerCls = this.featureTransformer
.getClass();
byte[] featureTransformerBytes = SerializationUtils
.serialize(featureTransformerCls);
@@ -257,11 +256,11 @@ public abstract class AbstractNeuralNetwork implements Writable {
output.write(featureTransformerBytes);
}
- public void setFeatureTransformer(FeatureTransformer featureTransformer) {
+ public void setFeatureTransformer(FloatFeatureTransformer featureTransformer) {
this.featureTransformer = featureTransformer;
}
- public FeatureTransformer getFeatureTransformer() {
+ public FloatFeatureTransformer getFeatureTransformer() {
return this.featureTransformer;
}
http://git-wip-us.apache.org/repos/asf/incubator-horn/blob/af88df41/src/main/java/org/apache/horn/core/AbstractNeuralNetworkTrainer.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/horn/core/AbstractNeuralNetworkTrainer.java b/src/main/java/org/apache/horn/core/AbstractNeuralNetworkTrainer.java
index 3547a1a..d3cfa45 100644
--- a/src/main/java/org/apache/horn/core/AbstractNeuralNetworkTrainer.java
+++ b/src/main/java/org/apache/horn/core/AbstractNeuralNetworkTrainer.java
@@ -29,8 +29,6 @@ import org.apache.hama.bsp.BSP;
import org.apache.hama.bsp.BSPPeer;
import org.apache.hama.bsp.sync.SyncException;
import org.apache.hama.commons.io.VectorWritable;
-import org.apache.hama.ml.util.DefaultFeatureTransformer;
-import org.apache.hama.ml.util.FeatureTransformer;
/**
* The trainer that is used to train the {@link LayeredNeuralNetwork} with
@@ -50,14 +48,14 @@ public abstract class AbstractNeuralNetworkTrainer
protected int batchSize;
protected String trainingMode;
- protected FeatureTransformer featureTransformer;
+ protected FloatFeatureTransformer featureTransformer;
@Override
final public void setup(
BSPPeer<LongWritable, VectorWritable, NullWritable, NullWritable, Synapse<DoubleWritable, DoubleWritable>> peer)
throws IOException, SyncException, InterruptedException {
conf = peer.getConfiguration();
- featureTransformer = new DefaultFeatureTransformer();
+ featureTransformer = new FloatFeatureTransformer();
this.extraSetup(peer);
}
http://git-wip-us.apache.org/repos/asf/incubator-horn/blob/af88df41/src/main/java/org/apache/horn/core/AutoEncoder.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/horn/core/AutoEncoder.java b/src/main/java/org/apache/horn/core/AutoEncoder.java
index 1b7a406..e7b3233 100644
--- a/src/main/java/org/apache/horn/core/AutoEncoder.java
+++ b/src/main/java/org/apache/horn/core/AutoEncoder.java
@@ -23,11 +23,10 @@ import java.util.Map;
import org.apache.hadoop.fs.Path;
import org.apache.hama.HamaConfiguration;
import org.apache.hama.bsp.BSPJob;
-import org.apache.hama.commons.math.DenseDoubleVector;
-import org.apache.hama.commons.math.DoubleFunction;
-import org.apache.hama.commons.math.DoubleMatrix;
-import org.apache.hama.commons.math.DoubleVector;
-import org.apache.hama.ml.util.FeatureTransformer;
+import org.apache.hama.commons.math.DenseFloatVector;
+import org.apache.hama.commons.math.FloatFunction;
+import org.apache.hama.commons.math.FloatMatrix;
+import org.apache.hama.commons.math.FloatVector;
import org.apache.horn.core.Constants.LearningStyle;
import org.apache.horn.funcs.FunctionFactory;
@@ -54,15 +53,15 @@ public class AutoEncoder {
public AutoEncoder(int inputDimensions, int compressedDimensions) {
model = new LayeredNeuralNetwork();
model.addLayer(inputDimensions, false,
- FunctionFactory.createDoubleFunction("Sigmoid"), null);
+ FunctionFactory.createFloatFunction("Sigmoid"), null);
model.addLayer(compressedDimensions, false,
- FunctionFactory.createDoubleFunction("Sigmoid"), null);
+ FunctionFactory.createFloatFunction("Sigmoid"), null);
model.addLayer(inputDimensions, true,
- FunctionFactory.createDoubleFunction("Sigmoid"), null);
+ FunctionFactory.createFloatFunction("Sigmoid"), null);
model
.setLearningStyle(LearningStyle.UNSUPERVISED);
model.setCostFunction(FunctionFactory
- .createDoubleDoubleFunction("SquaredError"));
+ .createFloatFloatFunction("SquaredError"));
}
public AutoEncoder(HamaConfiguration conf, String modelPath) {
@@ -94,7 +93,7 @@ public class AutoEncoder {
*
* @param trainingInstance
*/
- public void trainOnline(DoubleVector trainingInstance) {
+ public void trainOnline(FloatVector trainingInstance) {
model.trainOnline(trainingInstance);
}
@@ -103,7 +102,7 @@ public class AutoEncoder {
*
* @return this matrix with encode the input.
*/
- public DoubleMatrix getEncodeWeightMatrix() {
+ public FloatMatrix getEncodeWeightMatrix() {
return model.getWeightsByLayer(0);
}
@@ -112,7 +111,7 @@ public class AutoEncoder {
*
* @return this matrix with decode the compressed information.
*/
- public DoubleMatrix getDecodeWeightMatrix() {
+ public FloatMatrix getDecodeWeightMatrix() {
return model.getWeightsByLayer(1);
}
@@ -122,21 +121,21 @@ public class AutoEncoder {
* @param inputInstance
* @return The compressed information.
*/
- private DoubleVector transform(DoubleVector inputInstance, int inputLayer) {
- DoubleVector internalInstance = new DenseDoubleVector(
+ private FloatVector transform(FloatVector inputInstance, int inputLayer) {
+ FloatVector internalInstance = new DenseFloatVector(
inputInstance.getDimension() + 1);
internalInstance.set(0, 1);
for (int i = 0; i < inputInstance.getDimension(); ++i) {
internalInstance.set(i + 1, inputInstance.get(i));
}
- DoubleFunction squashingFunction = model.getSquashingFunction(inputLayer);
- DoubleMatrix weightMatrix = null;
+ FloatFunction squashingFunction = model.getSquashingFunction(inputLayer);
+ FloatMatrix weightMatrix = null;
if (inputLayer == 0) {
weightMatrix = this.getEncodeWeightMatrix();
} else {
weightMatrix = this.getDecodeWeightMatrix();
}
- DoubleVector vec = weightMatrix.multiplyVectorUnsafe(internalInstance);
+ FloatVector vec = weightMatrix.multiplyVectorUnsafe(internalInstance);
vec = vec.applyToElements(squashingFunction);
return vec;
}
@@ -147,7 +146,7 @@ public class AutoEncoder {
* @param inputInstance
* @return a new vector with the encode input instance.
*/
- public DoubleVector encode(DoubleVector inputInstance) {
+ public FloatVector encode(FloatVector inputInstance) {
Preconditions
.checkArgument(
inputInstance.getDimension() == model.getLayerSize(0) - 1,
@@ -164,7 +163,7 @@ public class AutoEncoder {
* @param inputInstance
* @return a new vector with the decode input instance.
*/
- public DoubleVector decode(DoubleVector inputInstance) {
+ public FloatVector decode(FloatVector inputInstance) {
Preconditions
.checkArgument(
inputInstance.getDimension() == model.getLayerSize(1) - 1,
@@ -182,7 +181,7 @@ public class AutoEncoder {
* @return a new vector with output of the model according to given feature
* instance.
*/
- public DoubleVector getOutput(DoubleVector inputInstance) {
+ public FloatVector getOutput(FloatVector inputInstance) {
return model.getOutput(inputInstance);
}
@@ -191,7 +190,7 @@ public class AutoEncoder {
*
* @param featureTransformer
*/
- public void setFeatureTransformer(FeatureTransformer featureTransformer) {
+ public void setFeatureTransformer(FloatFeatureTransformer featureTransformer) {
this.model.setFeatureTransformer(featureTransformer);
}
http://git-wip-us.apache.org/repos/asf/incubator-horn/blob/af88df41/src/main/java/org/apache/horn/core/FloatFeatureTransformer.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/horn/core/FloatFeatureTransformer.java b/src/main/java/org/apache/horn/core/FloatFeatureTransformer.java
new file mode 100644
index 0000000..8fc7860
--- /dev/null
+++ b/src/main/java/org/apache/horn/core/FloatFeatureTransformer.java
@@ -0,0 +1,17 @@
+package org.apache.horn.core;
+
+import org.apache.hama.commons.math.FloatVector;
+
+public class FloatFeatureTransformer {
+
+ public FloatFeatureTransformer() {
+ }
+
+ /**
+ * Directly return the original features.
+ */
+ public FloatVector transform(FloatVector originalFeatures) {
+ return originalFeatures;
+ }
+
+}
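The new `FloatFeatureTransformer` replaces Hama's double-based `FeatureTransformer` with an identity pass-through; custom preprocessing plugs in by subclassing it and overriding `transform()`. A hypothetical example (not part of this commit) that scales 0-255 pixel features into [0, 1]:

```Java
import org.apache.hama.commons.math.FloatVector;
import org.apache.horn.core.FloatFeatureTransformer;

// Hypothetical subclass for illustration: normalize raw MNIST-style
// pixel values before they reach the network.
public class PixelScaleTransformer extends FloatFeatureTransformer {
  @Override
  public FloatVector transform(FloatVector originalFeatures) {
    for (int i = 0; i < originalFeatures.getDimension(); i++) {
      originalFeatures.set(i, originalFeatures.get(i) / 255f);
    }
    return originalFeatures;
  }
}
```

Such a transformer would be wired in via `setFeatureTransformer(...)`, which the `AbstractNeuralNetwork` and `AutoEncoder` diffs above now type against `FloatFeatureTransformer`.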
http://git-wip-us.apache.org/repos/asf/incubator-horn/blob/af88df41/src/main/java/org/apache/horn/core/HornJob.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/horn/core/HornJob.java b/src/main/java/org/apache/horn/core/HornJob.java
index 30e9e88..d178166 100644
--- a/src/main/java/org/apache/horn/core/HornJob.java
+++ b/src/main/java/org/apache/horn/core/HornJob.java
@@ -49,7 +49,7 @@ public class HornJob extends BSPJob {
Class<? extends Neuron> neuronClass) {
neuralNetwork
.addLayer(featureDimension, false,
- FunctionFactory.createDoubleFunction(func.getSimpleName()),
+ FunctionFactory.createFloatFunction(func.getSimpleName()),
neuronClass);
}
@@ -58,13 +58,13 @@ public class HornJob extends BSPJob {
Class<? extends Neuron> neuronClass) {
neuralNetwork
.addLayer(labels, true,
- FunctionFactory.createDoubleFunction(func.getSimpleName()),
+ FunctionFactory.createFloatFunction(func.getSimpleName()),
neuronClass);
}
public void setCostFunction(Class<? extends Function> func) {
neuralNetwork.setCostFunction(FunctionFactory
- .createDoubleDoubleFunction(func.getSimpleName()));
+ .createFloatFloatFunction(func.getSimpleName()));
}
public void setDouble(String name, double value) {
@@ -87,7 +87,7 @@ public class HornJob extends BSPJob {
this.neuralNetwork.setLearningStyle(style);
}
- public void setLearningRate(double learningRate) {
+ public void setLearningRate(float learningRate) {
this.neuralNetwork.setLearningRate(learningRate);
}
@@ -95,11 +95,11 @@ public class HornJob extends BSPJob {
this.conf.setInt("convergence.check.interval", n);
}
- public void setMomentumWeight(double momentumWeight) {
+ public void setMomentumWeight(float momentumWeight) {
this.neuralNetwork.setMomemtumWeight(momentumWeight);
}
- public void setRegularizationWeight(double regularizationWeight) {
+ public void setRegularizationWeight(float regularizationWeight) {
this.neuralNetwork.setRegularizationWeight(regularizationWeight);
}
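With the setters above now taking `float`, callers must pass float literals. A hedged usage sketch using the README's hyperparameters (η 0.01, α 0.9, λ 0.0005); the `HornJob` constructor arguments are assumed from the examples package, not shown in this excerpt:

```Java
import java.io.IOException;

import org.apache.hama.HamaConfiguration;
import org.apache.horn.core.HornJob;
import org.apache.horn.examples.MultiLayerPerceptron;

public class FloatJobSetup {
  public static void main(String[] args) throws IOException {
    // Constructor arguments are an assumption; the three setters
    // below are exactly those changed to float in this diff.
    HornJob job = new HornJob(new HamaConfiguration(),
        MultiLayerPerceptron.class);
    job.setLearningRate(0.01f);           // η, note the f suffix
    job.setMomentumWeight(0.9f);          // α
    job.setRegularizationWeight(0.0005f); // λ
  }
}
```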
http://git-wip-us.apache.org/repos/asf/incubator-horn/blob/af88df41/src/main/java/org/apache/horn/core/LayerInterface.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/horn/core/LayerInterface.java b/src/main/java/org/apache/horn/core/LayerInterface.java
index c010cc9..3e537a6 100644
--- a/src/main/java/org/apache/horn/core/LayerInterface.java
+++ b/src/main/java/org/apache/horn/core/LayerInterface.java
@@ -19,10 +19,10 @@ package org.apache.horn.core;
import java.io.IOException;
-import org.apache.hama.commons.math.DoubleVector;
+import org.apache.hama.commons.math.FloatVector;
public interface LayerInterface {
- public DoubleVector interlayer(DoubleVector intermediateOutput) throws IOException;
+ public FloatVector interlayer(FloatVector intermediateOutput) throws IOException;
}
http://git-wip-us.apache.org/repos/asf/incubator-horn/blob/af88df41/src/main/java/org/apache/horn/core/LayeredNeuralNetwork.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/horn/core/LayeredNeuralNetwork.java b/src/main/java/org/apache/horn/core/LayeredNeuralNetwork.java
index d33726e..aa8e68d 100644
--- a/src/main/java/org/apache/horn/core/LayeredNeuralNetwork.java
+++ b/src/main/java/org/apache/horn/core/LayeredNeuralNetwork.java
@@ -29,20 +29,20 @@ import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
-import org.apache.hadoop.io.DoubleWritable;
+import org.apache.hadoop.io.FloatWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.io.WritableUtils;
+import org.apache.hama.Constants;
import org.apache.hama.HamaConfiguration;
import org.apache.hama.bsp.BSPJob;
-import org.apache.hama.Constants;
-import org.apache.hama.commons.io.MatrixWritable;
+import org.apache.hama.commons.io.FloatMatrixWritable;
import org.apache.hama.commons.io.VectorWritable;
-import org.apache.hama.commons.math.DenseDoubleMatrix;
-import org.apache.hama.commons.math.DenseDoubleVector;
-import org.apache.hama.commons.math.DoubleFunction;
-import org.apache.hama.commons.math.DoubleMatrix;
-import org.apache.hama.commons.math.DoubleVector;
+import org.apache.hama.commons.math.DenseFloatMatrix;
+import org.apache.hama.commons.math.DenseFloatVector;
+import org.apache.hama.commons.math.FloatFunction;
+import org.apache.hama.commons.math.FloatMatrix;
+import org.apache.hama.commons.math.FloatVector;
import org.apache.hama.util.ReflectionUtils;
import org.apache.horn.core.Constants.LearningStyle;
import org.apache.horn.core.Constants.TrainingMethod;
@@ -71,13 +71,13 @@ public class LayeredNeuralNetwork extends AbstractLayeredNeuralNetwork {
private static final Log LOG = LogFactory.getLog(LayeredNeuralNetwork.class);
/* Weights between neurons at adjacent layers */
- protected List<DoubleMatrix> weightMatrixList;
+ protected List<FloatMatrix> weightMatrixList;
/* Previous weight updates between neurons at adjacent layers */
- protected List<DoubleMatrix> prevWeightUpdatesList;
+ protected List<FloatMatrix> prevWeightUpdatesList;
/* Different layers can have different squashing function */
- protected List<DoubleFunction> squashingFunctionList;
+ protected List<FloatFunction> squashingFunctionList;
protected List<Class<? extends Neuron>> neuronClassList;
@@ -129,12 +129,12 @@ public class LayeredNeuralNetwork extends AbstractLayeredNeuralNetwork {
* {@inheritDoc}
*/
public int addLayer(int size, boolean isFinalLayer,
- DoubleFunction squashingFunction, Class<? extends Neuron> neuronClass) {
+ FloatFunction squashingFunction, Class<? extends Neuron> neuronClass) {
return addLayer(size, isFinalLayer, squashingFunction, neuronClass, null);
}
public int addLayer(int size, boolean isFinalLayer,
- DoubleFunction squashingFunction, Class<? extends Neuron> neuronClass,
+ FloatFunction squashingFunction, Class<? extends Neuron> neuronClass,
Class<? extends IntermediateOutput> interlayer) {
Preconditions.checkArgument(size > 0,
"Size of layer must be larger than 0.");
@@ -162,21 +162,21 @@ public class LayeredNeuralNetwork extends AbstractLayeredNeuralNetwork {
// size of previous layer
int row = isFinalLayer ? size : size - 1;
int col = sizePrevLayer;
- DoubleMatrix weightMatrix = new DenseDoubleMatrix(row, col);
+ FloatMatrix weightMatrix = new DenseFloatMatrix(row, col);
// initialize weights
- weightMatrix.applyToElements(new DoubleFunction() {
+ weightMatrix.applyToElements(new FloatFunction() {
@Override
- public double apply(double value) {
- return RandomUtils.nextDouble() - 0.5;
+ public float apply(float value) {
+ return RandomUtils.nextFloat() - 0.5f;
}
@Override
- public double applyDerivative(double value) {
+ public float applyDerivative(float value) {
throw new UnsupportedOperationException("");
}
});
this.weightMatrixList.add(weightMatrix);
- this.prevWeightUpdatesList.add(new DenseDoubleMatrix(row, col));
+ this.prevWeightUpdatesList.add(new DenseFloatMatrix(row, col));
this.squashingFunctionList.add(squashingFunction);
this.neuronClassList.add(neuronClass);
@@ -189,9 +189,9 @@ public class LayeredNeuralNetwork extends AbstractLayeredNeuralNetwork {
*
* @param matrices
*/
- public void updateWeightMatrices(DoubleMatrix[] matrices) {
+ public void updateWeightMatrices(FloatMatrix[] matrices) {
for (int i = 0; i < matrices.length; ++i) {
- DoubleMatrix matrix = this.weightMatrixList.get(i);
+ FloatMatrix matrix = this.weightMatrixList.get(i);
this.weightMatrixList.set(i, matrix.add(matrices[i]));
}
}
@@ -201,7 +201,7 @@ public class LayeredNeuralNetwork extends AbstractLayeredNeuralNetwork {
*
* @param prevUpdates
*/
- void setPrevWeightMatrices(DoubleMatrix[] prevUpdates) {
+ void setPrevWeightMatrices(FloatMatrix[] prevUpdates) {
this.prevWeightUpdatesList.clear();
Collections.addAll(this.prevWeightUpdatesList, prevUpdates);
}
@@ -212,8 +212,8 @@ public class LayeredNeuralNetwork extends AbstractLayeredNeuralNetwork {
* @param destMatrices
* @param sourceMatrices
*/
- static void matricesAdd(DoubleMatrix[] destMatrices,
- DoubleMatrix[] sourceMatrices) {
+ static void matricesAdd(FloatMatrix[] destMatrices,
+ FloatMatrix[] sourceMatrices) {
for (int i = 0; i < destMatrices.length; ++i) {
destMatrices[i] = destMatrices[i].add(sourceMatrices[i]);
}
@@ -224,8 +224,8 @@ public class LayeredNeuralNetwork extends AbstractLayeredNeuralNetwork {
*
* @return The matrices in form of matrix array.
*/
- DoubleMatrix[] getWeightMatrices() {
- DoubleMatrix[] matrices = new DoubleMatrix[this.weightMatrixList.size()];
+ FloatMatrix[] getWeightMatrices() {
+ FloatMatrix[] matrices = new FloatMatrix[this.weightMatrixList.size()];
this.weightMatrixList.toArray(matrices);
return matrices;
}
@@ -235,8 +235,8 @@ public class LayeredNeuralNetwork extends AbstractLayeredNeuralNetwork {
*
* @param matrices
*/
- public void setWeightMatrices(DoubleMatrix[] matrices) {
- this.weightMatrixList = new ArrayList<DoubleMatrix>();
+ public void setWeightMatrices(FloatMatrix[] matrices) {
+ this.weightMatrixList = new ArrayList<FloatMatrix>();
Collections.addAll(this.weightMatrixList, matrices);
}
@@ -245,8 +245,8 @@ public class LayeredNeuralNetwork extends AbstractLayeredNeuralNetwork {
*
* @return The matrices in form of matrix array.
*/
- public DoubleMatrix[] getPrevMatricesUpdates() {
- DoubleMatrix[] prevMatricesUpdates = new DoubleMatrix[this.prevWeightUpdatesList
+ public FloatMatrix[] getPrevMatricesUpdates() {
+ FloatMatrix[] prevMatricesUpdates = new FloatMatrix[this.prevWeightUpdatesList
.size()];
for (int i = 0; i < this.prevWeightUpdatesList.size(); ++i) {
prevMatricesUpdates[i] = this.prevWeightUpdatesList.get(i);
@@ -254,7 +254,7 @@ public class LayeredNeuralNetwork extends AbstractLayeredNeuralNetwork {
return prevMatricesUpdates;
}
- public void setWeightMatrix(int index, DoubleMatrix matrix) {
+ public void setWeightMatrix(int index, FloatMatrix matrix) {
Preconditions.checkArgument(
0 <= index && index < this.weightMatrixList.size(), String.format(
"index [%d] should be in range[%d, %d].", index, 0,
@@ -287,7 +287,7 @@ public class LayeredNeuralNetwork extends AbstractLayeredNeuralNetwork {
this.squashingFunctionList = Lists.newArrayList();
for (int i = 0; i < squashingFunctionSize; ++i) {
this.squashingFunctionList.add(FunctionFactory
- .createDoubleFunction(WritableUtils.readString(input)));
+ .createFloatFunction(WritableUtils.readString(input)));
}
// read weights and construct matrices of previous updates
@@ -295,10 +295,10 @@ public class LayeredNeuralNetwork extends AbstractLayeredNeuralNetwork {
this.weightMatrixList = Lists.newArrayList();
this.prevWeightUpdatesList = Lists.newArrayList();
for (int i = 0; i < numOfMatrices; ++i) {
- DoubleMatrix matrix = MatrixWritable.read(input);
+ FloatMatrix matrix = FloatMatrixWritable.read(input);
this.weightMatrixList.add(matrix);
- this.prevWeightUpdatesList.add(new DenseDoubleMatrix(
- matrix.getRowCount(), matrix.getColumnCount()));
+ this.prevWeightUpdatesList.add(new DenseFloatMatrix(matrix.getRowCount(),
+ matrix.getColumnCount()));
}
}
@@ -317,22 +317,22 @@ public class LayeredNeuralNetwork extends AbstractLayeredNeuralNetwork {
// write squashing functions
output.writeInt(this.squashingFunctionList.size());
- for (DoubleFunction aSquashingFunctionList : this.squashingFunctionList) {
+ for (FloatFunction aSquashingFunctionList : this.squashingFunctionList) {
WritableUtils.writeString(output,
aSquashingFunctionList.getFunctionName());
}
// write weight matrices
output.writeInt(this.weightMatrixList.size());
- for (DoubleMatrix aWeightMatrixList : this.weightMatrixList) {
- MatrixWritable.write(aWeightMatrixList, output);
+ for (FloatMatrix aWeightMatrixList : this.weightMatrixList) {
+ FloatMatrixWritable.write(aWeightMatrixList, output);
}
// DO NOT WRITE WEIGHT UPDATE
}
@Override
- public DoubleMatrix getWeightsByLayer(int layerIdx) {
+ public FloatMatrix getWeightsByLayer(int layerIdx) {
return this.weightMatrixList.get(layerIdx);
}
@@ -340,19 +340,19 @@ public class LayeredNeuralNetwork extends AbstractLayeredNeuralNetwork {
* Get the output of the model according to given feature instance.
*/
@Override
- public DoubleVector getOutput(DoubleVector instance) {
+ public FloatVector getOutput(FloatVector instance) {
Preconditions.checkArgument(this.layerSizeList.get(0) - 1 == instance
.getDimension(), String.format(
"The dimension of input instance should be %d.",
this.layerSizeList.get(0) - 1));
// transform the features to another space
- DoubleVector transformedInstance = this.featureTransformer
+ FloatVector transformedInstance = this.featureTransformer
.transform(instance);
// add bias feature
- DoubleVector instanceWithBias = new DenseDoubleVector(
+ FloatVector instanceWithBias = new DenseFloatVector(
transformedInstance.getDimension() + 1);
- instanceWithBias.set(0, 0.99999); // set bias to be a little bit less than
- // 1.0
+ instanceWithBias.set(0, 0.99999f); // set bias to be a little bit less than
+ // 1.0
for (int i = 1; i < instanceWithBias.getDimension(); ++i) {
instanceWithBias.set(i, transformedInstance.get(i - 1));
}
@@ -368,7 +368,7 @@ public class LayeredNeuralNetwork extends AbstractLayeredNeuralNetwork {
* @param instanceWithBias The instance contains the features.
* @return Cached output of each layer.
*/
- public DoubleVector getOutputInternal(DoubleVector instanceWithBias) {
+ public FloatVector getOutputInternal(FloatVector instanceWithBias) {
// sets the output of input layer
Neuron[] inputLayer = neurons.get(0);
for (int i = 0; i < inputLayer.length; i++) {
@@ -379,7 +379,7 @@ public class LayeredNeuralNetwork extends AbstractLayeredNeuralNetwork {
forward(i);
}
- DoubleVector output = new DenseDoubleVector(
+ FloatVector output = new DenseFloatVector(
neurons.get(this.finalLayerIdx).length);
for (int i = 0; i < output.getDimension(); i++) {
output.set(i, neurons.get(this.finalLayerIdx)[i].getOutput());
@@ -404,17 +404,17 @@ public class LayeredNeuralNetwork extends AbstractLayeredNeuralNetwork {
*/
protected void forward(int fromLayer) {
int curLayerIdx = fromLayer + 1;
- DoubleMatrix weightMatrix = this.weightMatrixList.get(fromLayer);
+ FloatMatrix weightMatrix = this.weightMatrixList.get(fromLayer);
- DoubleFunction squashingFunction = getSquashingFunction(fromLayer);
- DoubleVector vec = new DenseDoubleVector(weightMatrix.getRowCount());
+ FloatFunction squashingFunction = getSquashingFunction(fromLayer);
+ FloatVector vec = new DenseFloatVector(weightMatrix.getRowCount());
for (int row = 0; row < weightMatrix.getRowCount(); row++) {
- List<Synapse<DoubleWritable, DoubleWritable>> msgs = new ArrayList<Synapse<DoubleWritable, DoubleWritable>>();
+ List<Synapse<FloatWritable, FloatWritable>> msgs = new ArrayList<Synapse<FloatWritable, FloatWritable>>();
for (int col = 0; col < weightMatrix.getColumnCount(); col++) {
- msgs.add(new Synapse<DoubleWritable, DoubleWritable>(
- new DoubleWritable(neurons.get(fromLayer)[col].getOutput()),
- new DoubleWritable(weightMatrix.get(row, col))));
+ msgs.add(new Synapse<FloatWritable, FloatWritable>(new FloatWritable(
+ neurons.get(fromLayer)[col].getOutput()), new FloatWritable(
+ weightMatrix.get(row, col))));
}
Neuron n;
@@ -459,20 +459,20 @@ public class LayeredNeuralNetwork extends AbstractLayeredNeuralNetwork {
*
* @param trainingInstance
*/
- public void trainOnline(DoubleVector trainingInstance) {
- DoubleMatrix[] updateMatrices = this.trainByInstance(trainingInstance);
+ public void trainOnline(FloatVector trainingInstance) {
+ FloatMatrix[] updateMatrices = this.trainByInstance(trainingInstance);
this.updateWeightMatrices(updateMatrices);
}
@Override
- public DoubleMatrix[] trainByInstance(DoubleVector trainingInstance) {
- DoubleVector transformedVector = this.featureTransformer
+ public FloatMatrix[] trainByInstance(FloatVector trainingInstance) {
+ FloatVector transformedVector = this.featureTransformer
.transform(trainingInstance.sliceUnsafe(this.layerSizeList.get(0) - 1));
int inputDimension = this.layerSizeList.get(0) - 1;
int outputDimension;
- DoubleVector inputInstance = null;
- DoubleVector labels = null;
+ FloatVector inputInstance = null;
+ FloatVector labels = null;
if (this.learningStyle == LearningStyle.SUPERVISED) {
outputDimension = this.layerSizeList.get(this.layerSizeList.size() - 1);
// validate training instance
@@ -484,7 +484,7 @@ public class LayeredNeuralNetwork extends AbstractLayeredNeuralNetwork {
trainingInstance.getDimension(), inputDimension
+ outputDimension));
- inputInstance = new DenseDoubleVector(this.layerSizeList.get(0));
+ inputInstance = new DenseFloatVector(this.layerSizeList.get(0));
inputInstance.set(0, 1); // add bias
// get the features from the transformed vector
for (int i = 0; i < inputDimension; ++i) {
@@ -502,7 +502,7 @@ public class LayeredNeuralNetwork extends AbstractLayeredNeuralNetwork {
"The dimension of training instance is %d, but requires %d.",
trainingInstance.getDimension(), inputDimension));
- inputInstance = new DenseDoubleVector(this.layerSizeList.get(0));
+ inputInstance = new DenseFloatVector(this.layerSizeList.get(0));
inputInstance.set(0, 1); // add bias
// get the features from the transformed vector
for (int i = 0; i < inputDimension; ++i) {
@@ -512,7 +512,7 @@ public class LayeredNeuralNetwork extends AbstractLayeredNeuralNetwork {
labels = transformedVector.deepCopy();
}
- DoubleVector output = this.getOutputInternal(inputInstance);
+ FloatVector output = this.getOutputInternal(inputInstance);
// get the training error
calculateTrainingError(labels, output);
@@ -532,27 +532,27 @@ public class LayeredNeuralNetwork extends AbstractLayeredNeuralNetwork {
* @param trainingInstance
* @return The weight update matrices.
*/
- private DoubleMatrix[] trainByInstanceGradientDescent(DoubleVector labels) {
+ private FloatMatrix[] trainByInstanceGradientDescent(FloatVector labels) {
// initialize weight update matrices
- DenseDoubleMatrix[] weightUpdateMatrices = new DenseDoubleMatrix[this.weightMatrixList
+ DenseFloatMatrix[] weightUpdateMatrices = new DenseFloatMatrix[this.weightMatrixList
.size()];
for (int m = 0; m < weightUpdateMatrices.length; ++m) {
- weightUpdateMatrices[m] = new DenseDoubleMatrix(this.weightMatrixList
- .get(m).getRowCount(), this.weightMatrixList.get(m).getColumnCount());
+ weightUpdateMatrices[m] = new DenseFloatMatrix(this.weightMatrixList.get(
+ m).getRowCount(), this.weightMatrixList.get(m).getColumnCount());
}
- DoubleVector deltaVec = new DenseDoubleVector(
+ FloatVector deltaVec = new DenseFloatVector(
this.layerSizeList.get(this.layerSizeList.size() - 1));
- DoubleFunction squashingFunction = this.squashingFunctionList
+ FloatFunction squashingFunction = this.squashingFunctionList
.get(this.squashingFunctionList.size() - 1);
- DoubleMatrix lastWeightMatrix = this.weightMatrixList
+ FloatMatrix lastWeightMatrix = this.weightMatrixList
.get(this.weightMatrixList.size() - 1);
for (int i = 0; i < deltaVec.getDimension(); ++i) {
- double finalOut = neurons.get(finalLayerIdx)[i].getOutput();
- double costFuncDerivative = this.costFunction.applyDerivative(
+ float finalOut = neurons.get(finalLayerIdx)[i].getOutput();
+ float costFuncDerivative = this.costFunction.applyDerivative(
labels.get(i), finalOut);
// add regularization
costFuncDerivative += this.regularizationWeight
@@ -584,36 +584,34 @@ public class LayeredNeuralNetwork extends AbstractLayeredNeuralNetwork {
* @param layer Index of current layer.
*/
private void backpropagate(int curLayerIdx,
- // DoubleVector nextLayerDelta, DoubleVector curLayerOutput,
- DenseDoubleMatrix weightUpdateMatrix) {
+ // FloatVector nextLayerDelta, FloatVector curLayerOutput,
+ DenseFloatMatrix weightUpdateMatrix) {
// get layer related information
- DoubleMatrix weightMatrix = this.weightMatrixList.get(curLayerIdx);
- DoubleMatrix prevWeightMatrix = this.prevWeightUpdatesList.get(curLayerIdx);
+ FloatMatrix weightMatrix = this.weightMatrixList.get(curLayerIdx);
+ FloatMatrix prevWeightMatrix = this.prevWeightUpdatesList.get(curLayerIdx);
- DoubleVector deltaVector = new DenseDoubleVector(
+ FloatVector deltaVector = new DenseFloatVector(
weightMatrix.getColumnCount());
for (int row = 0; row < weightMatrix.getColumnCount(); ++row) {
Neuron n = neurons.get(curLayerIdx)[row];
n.setWeightVector(weightMatrix.getRowCount());
- List<Synapse<DoubleWritable, DoubleWritable>> msgs = new ArrayList<Synapse<DoubleWritable, DoubleWritable>>();
+ List<Synapse<FloatWritable, FloatWritable>> msgs = new ArrayList<Synapse<FloatWritable, FloatWritable>>();
for (int col = 0; col < weightMatrix.getRowCount(); ++col) {
- double deltaOfNextLayer;
+ float deltaOfNextLayer;
if (curLayerIdx + 1 == this.finalLayerIdx)
deltaOfNextLayer = neurons.get(curLayerIdx + 1)[col].getDelta();
else
deltaOfNextLayer = neurons.get(curLayerIdx + 1)[col + 1].getDelta();
- msgs.add(new Synapse<DoubleWritable, DoubleWritable>(
- new DoubleWritable(deltaOfNextLayer), new DoubleWritable(
- weightMatrix.get(col, row)), new DoubleWritable(
- prevWeightMatrix.get(col, row))));
+ msgs.add(new Synapse<FloatWritable, FloatWritable>(new FloatWritable(
+ deltaOfNextLayer), new FloatWritable(weightMatrix.get(col, row)),
+ new FloatWritable(prevWeightMatrix.get(col, row))));
}
- Iterable<Synapse<DoubleWritable, DoubleWritable>> iterable = msgs;
try {
n.backward(msgs);
} catch (IOException e) {
@@ -653,7 +651,7 @@ public class LayeredNeuralNetwork extends AbstractLayeredNeuralNetwork {
job.setBspClass(LayeredNeuralNetworkTrainer.class);
job.getConfiguration().setInt(Constants.ADDITIONAL_BSP_TASKS, 1);
-
+
job.setInputPath(new Path(conf.get("training.input.path")));
job.setInputFormat(org.apache.hama.bsp.SequenceFileInputFormat.class);
job.setInputKeyClass(LongWritable.class);
@@ -666,8 +664,8 @@ public class LayeredNeuralNetwork extends AbstractLayeredNeuralNetwork {
}
@Override
- protected void calculateTrainingError(DoubleVector labels, DoubleVector output) {
- DoubleVector errors = labels.deepCopy().applyToElements(output,
+ protected void calculateTrainingError(FloatVector labels, FloatVector output) {
+ FloatVector errors = labels.deepCopy().applyToElements(output,
this.costFunction);
this.trainingError = errors.sum();
}
@@ -678,7 +676,7 @@ public class LayeredNeuralNetwork extends AbstractLayeredNeuralNetwork {
* @param idx
* @return a new vector with the result of the operation.
*/
- public DoubleFunction getSquashingFunction(int idx) {
+ public FloatFunction getSquashingFunction(int idx) {
return this.squashingFunctionList.get(idx);
}
http://git-wip-us.apache.org/repos/asf/incubator-horn/blob/af88df41/src/main/java/org/apache/horn/core/LayeredNeuralNetworkTrainer.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/horn/core/LayeredNeuralNetworkTrainer.java b/src/main/java/org/apache/horn/core/LayeredNeuralNetworkTrainer.java
index 275dd75..e0810e2 100644
--- a/src/main/java/org/apache/horn/core/LayeredNeuralNetworkTrainer.java
+++ b/src/main/java/org/apache/horn/core/LayeredNeuralNetworkTrainer.java
@@ -30,10 +30,10 @@ import org.apache.hama.HamaConfiguration;
import org.apache.hama.bsp.BSP;
import org.apache.hama.bsp.BSPPeer;
import org.apache.hama.bsp.sync.SyncException;
-import org.apache.hama.commons.io.VectorWritable;
-import org.apache.hama.commons.math.DenseDoubleMatrix;
-import org.apache.hama.commons.math.DoubleMatrix;
-import org.apache.hama.commons.math.DoubleVector;
+import org.apache.hama.commons.io.FloatVectorWritable;
+import org.apache.hama.commons.math.DenseFloatMatrix;
+import org.apache.hama.commons.math.FloatMatrix;
+import org.apache.hama.commons.math.FloatVector;
/**
* The trainer that train the {@link LayeredNeuralNetwork} based on BSP
@@ -42,7 +42,7 @@ import org.apache.hama.commons.math.DoubleVector;
*/
public final class LayeredNeuralNetworkTrainer
extends
- BSP<LongWritable, VectorWritable, NullWritable, NullWritable, ParameterMessage> {
+ BSP<LongWritable, FloatVectorWritable, NullWritable, NullWritable, ParameterMessage> {
private static final Log LOG = LogFactory
.getLog(LayeredNeuralNetworkTrainer.class);
@@ -58,7 +58,6 @@ public final class LayeredNeuralNetworkTrainer
private long convergenceCheckInterval;
private long iterations;
private long maxIterations;
- private long epoch;
private boolean isConverge;
private String modelPath;
@@ -68,16 +67,15 @@ public final class LayeredNeuralNetworkTrainer
* If the model path is specified, load the existing from storage location.
*/
public void setup(
- BSPPeer<LongWritable, VectorWritable, NullWritable, NullWritable, ParameterMessage> peer) {
+ BSPPeer<LongWritable, FloatVectorWritable, NullWritable, NullWritable, ParameterMessage> peer) {
if (peer.getPeerIndex() == 0) {
LOG.info("Begin to train");
}
this.isConverge = false;
this.conf = peer.getConfiguration();
this.iterations = 0;
- this.epoch = 0;
this.modelPath = conf.get("model.path");
- this.maxIterations = conf.getLong("training.max.iterations", 100000);
+ this.maxIterations = conf.getLong("training.max.iterations", Long.MAX_VALUE);
this.convergenceCheckInterval = conf.getLong("convergence.check.interval",
100);
this.inMemoryModel = new LayeredNeuralNetwork(conf, modelPath);
@@ -90,9 +88,9 @@ public final class LayeredNeuralNetworkTrainer
* Write the trained model back to stored location.
*/
public void cleanup(
- BSPPeer<LongWritable, VectorWritable, NullWritable, NullWritable, ParameterMessage> peer) {
+ BSPPeer<LongWritable, FloatVectorWritable, NullWritable, NullWritable, ParameterMessage> peer) {
// write model to modelPath
- if (peer.getPeerIndex() == 0) {
+ if (peer.getPeerIndex() == peer.getNumPeers() - 1) {
try {
LOG.info(String.format("End of training, number of iterations: %d.",
this.iterations));
@@ -105,18 +103,18 @@ public final class LayeredNeuralNetworkTrainer
}
}
- private List<DoubleVector> trainingSet = new ArrayList<DoubleVector>();
+ private List<FloatVector> trainingSet = new ArrayList<FloatVector>();
private Random r = new Random();
@Override
public void bsp(
- BSPPeer<LongWritable, VectorWritable, NullWritable, NullWritable, ParameterMessage> peer)
+ BSPPeer<LongWritable, FloatVectorWritable, NullWritable, NullWritable, ParameterMessage> peer)
throws IOException, SyncException, InterruptedException {
// load local data into memory
LongWritable key = new LongWritable();
- VectorWritable value = new VectorWritable();
+ FloatVectorWritable value = new FloatVectorWritable();
while (peer.readNext(key, value)) {
- DoubleVector v = value.getVector();
+ FloatVector v = value.getVector();
trainingSet.add(v);
}
@@ -131,18 +129,22 @@ public final class LayeredNeuralNetworkTrainer
mergeUpdates(peer);
}
}
-
+
peer.sync();
-
- if(isConverge) {
- if(peer.getPeerIndex() == peer.getNumPeers() - 1)
+
+ if (maxIterations == Long.MAX_VALUE && isConverge) {
+ if (peer.getPeerIndex() == peer.getNumPeers() - 1)
peer.sync();
break;
}
}
+
+ peer.sync();
+ if (peer.getPeerIndex() == peer.getNumPeers() - 1)
+ mergeUpdates(peer); // merge last updates
}
- private DoubleVector getRandomInstance() {
+ private FloatVector getRandomInstance() {
return trainingSet.get(r.nextInt(trainingSet.size()));
}
@@ -153,13 +155,13 @@ public final class LayeredNeuralNetworkTrainer
* @throws IOException
*/
private void calculateUpdates(
- BSPPeer<LongWritable, VectorWritable, NullWritable, NullWritable, ParameterMessage> peer)
+ BSPPeer<LongWritable, FloatVectorWritable, NullWritable, NullWritable, ParameterMessage> peer)
throws IOException {
// receive update information from master
if (peer.getNumCurrentMessages() != 0) {
ParameterMessage inMessage = peer.getCurrentMessage();
- DoubleMatrix[] newWeights = inMessage.getCurMatrices();
- DoubleMatrix[] preWeightUpdates = inMessage.getPrevMatrices();
+ FloatMatrix[] newWeights = inMessage.getCurMatrices();
+ FloatMatrix[] preWeightUpdates = inMessage.getPrevMatrices();
this.inMemoryModel.setWeightMatrices(newWeights);
this.inMemoryModel.setPrevWeightMatrices(preWeightUpdates);
this.isConverge = inMessage.isConverge();
@@ -169,18 +171,19 @@ public final class LayeredNeuralNetworkTrainer
}
}
- DoubleMatrix[] weightUpdates = new DoubleMatrix[this.inMemoryModel.weightMatrixList
+ FloatMatrix[] weightUpdates = new FloatMatrix[this.inMemoryModel.weightMatrixList
.size()];
for (int i = 0; i < weightUpdates.length; ++i) {
int row = this.inMemoryModel.weightMatrixList.get(i).getRowCount();
int col = this.inMemoryModel.weightMatrixList.get(i).getColumnCount();
- weightUpdates[i] = new DenseDoubleMatrix(row, col);
+ weightUpdates[i] = new DenseFloatMatrix(row, col);
}
// continue to train
- double avgTrainingError = 0.0;
+ float avgTrainingError = 0.0f;
for (int recordsRead = 0; recordsRead < batchSize; ++recordsRead) {
- DoubleVector trainingInstance = getRandomInstance();
+ FloatVector trainingInstance = getRandomInstance();
+
LayeredNeuralNetwork.matricesAdd(weightUpdates,
this.inMemoryModel.trainByInstance(trainingInstance));
avgTrainingError += this.inMemoryModel.trainingError;
@@ -192,7 +195,7 @@ public final class LayeredNeuralNetworkTrainer
weightUpdates[i] = weightUpdates[i].divide(batchSize);
}
- DoubleMatrix[] prevWeightUpdates = this.inMemoryModel
+ FloatMatrix[] prevWeightUpdates = this.inMemoryModel
.getPrevMatricesUpdates();
ParameterMessage outMessage = new ParameterMessage(avgTrainingError, false,
weightUpdates, prevWeightUpdates);
@@ -206,7 +209,7 @@ public final class LayeredNeuralNetworkTrainer
* @throws IOException
*/
private void mergeUpdates(
- BSPPeer<LongWritable, VectorWritable, NullWritable, NullWritable, ParameterMessage> peer)
+ BSPPeer<LongWritable, FloatVectorWritable, NullWritable, NullWritable, ParameterMessage> peer)
throws IOException {
int numMessages = peer.getNumCurrentMessages();
boolean converge = false;
@@ -216,8 +219,8 @@ public final class LayeredNeuralNetworkTrainer
}
double avgTrainingError = 0;
- DoubleMatrix[] matricesUpdates = null;
- DoubleMatrix[] prevMatricesUpdates = null;
+ FloatMatrix[] matricesUpdates = null;
+ FloatMatrix[] prevMatricesUpdates = null;
while (peer.getNumCurrentMessages() > 0) {
ParameterMessage message = peer.getCurrentMessage();
@@ -260,14 +263,16 @@ public final class LayeredNeuralNetworkTrainer
}
curAvgTrainingError += avgTrainingError / convergenceCheckInterval;
this.isConverge = converge;
-
- // broadcast updated weight matrices
- for (String peerName : peer.getAllPeerNames()) {
- ParameterMessage msg = new ParameterMessage(0, converge,
- this.inMemoryModel.getWeightMatrices(),
- this.inMemoryModel.getPrevMatricesUpdates());
- if (!peer.getPeerName().equals(peerName))
- peer.send(peerName, msg);
+
+ if (iterations < maxIterations) {
+ // broadcast updated weight matrices
+ for (String peerName : peer.getAllPeerNames()) {
+ ParameterMessage msg = new ParameterMessage(0, converge,
+ this.inMemoryModel.getWeightMatrices(),
+ this.inMemoryModel.getPrevMatricesUpdates());
+ if (!peer.getPeerName().equals(peerName))
+ peer.send(peerName, msg);
+ }
}
}
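The hunks above port the trainer's whole BSP flow to the float APIs: each peer accumulates per-instance updates into DenseFloatMatrix buffers, averages them over the mini-batch, and ships the result to the master peer in a ParameterMessage (note that mergeUpdates still accumulates avgTrainingError in a double; that context line is untouched by this commit). A minimal stand-alone sketch of the batch-averaging step, assuming FloatMatrix mirrors the DoubleMatrix API (add is an assumption here; getRowCount, getColumnCount, and divide all appear in the diff):

```Java
import org.apache.hama.commons.math.DenseFloatMatrix;
import org.apache.hama.commons.math.FloatMatrix;

public class BatchAverageSketch {
  // Mirrors calculateUpdates(): sum per-instance updates, then divide
  // by the batch size to get the averaged weight update.
  static FloatMatrix averageUpdates(FloatMatrix[] perInstance, int batchSize) {
    FloatMatrix sum = new DenseFloatMatrix(
        perInstance[0].getRowCount(), perInstance[0].getColumnCount());
    for (FloatMatrix u : perInstance) {
      sum = sum.add(u); // assumes FloatMatrix#add, the float analogue of DoubleMatrix#add
    }
    return sum.divide(batchSize);
  }
}
```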
http://git-wip-us.apache.org/repos/asf/incubator-horn/blob/af88df41/src/main/java/org/apache/horn/core/Neuron.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/horn/core/Neuron.java b/src/main/java/org/apache/horn/core/Neuron.java
index 1c0f475..908abf4 100644
--- a/src/main/java/org/apache/horn/core/Neuron.java
+++ b/src/main/java/org/apache/horn/core/Neuron.java
@@ -22,21 +22,21 @@ import java.io.DataOutput;
import java.io.IOException;
import org.apache.hadoop.io.Writable;
-import org.apache.hama.commons.math.DoubleFunction;
+import org.apache.hama.commons.math.FloatFunction;
public abstract class Neuron<M extends Writable> implements Writable, NeuronInterface<M> {
int id;
- double output;
- double weight;
- double delta;
+ float output;
+ float weight;
+ float delta;
- double momentumWeight;
- double learningRate;
+ float momentumWeight;
+ float learningRate;
int layerIndex;
boolean isOutputLayer;
- protected DoubleFunction squashingFunction;
+ protected FloatFunction squashingFunction;
public void setNeuronID(int id) {
this.id = id;
@@ -54,47 +54,47 @@ public abstract class Neuron<M extends Writable> implements Writable, NeuronInte
this.layerIndex = index;
}
- public void feedforward(double sum) {
+ public void feedforward(float sum) {
this.output = sum;
}
- public void backpropagate(double gradient) {
+ public void backpropagate(float gradient) {
this.delta = gradient;
}
- public void setDelta(double delta) {
+ public void setDelta(float delta) {
this.delta = delta;
}
- public double getDelta() {
+ public float getDelta() {
return delta;
}
- public void setWeight(double weight) {
+ public void setWeight(float weight) {
this.weight = weight;
}
- public void setOutput(double output) {
+ public void setOutput(float output) {
this.output = output;
}
- public double getOutput() {
+ public float getOutput() {
return output;
}
- public void setMomentumWeight(double momentumWeight) {
+ public void setMomentumWeight(float momentumWeight) {
this.momentumWeight = momentumWeight;
}
- public double getMomentumWeight() {
+ public float getMomentumWeight() {
return momentumWeight;
}
- public void setLearningRate(double learningRate) {
+ public void setLearningRate(float learningRate) {
this.learningRate = learningRate;
}
- public double getLearningRate() {
+ public float getLearningRate() {
return learningRate;
}
@@ -102,49 +102,49 @@ public abstract class Neuron<M extends Writable> implements Writable, NeuronInte
private int i;
- public void push(double weight) {
+ public void push(float weight) {
weights[i++] = weight;
}
- public double getUpdate() {
+ public float getUpdate() {
return weight;
}
- double[] weights;
+ float[] weights;
public void setWeightVector(int rowCount) {
i = 0;
- weights = new double[rowCount];
+ weights = new float[rowCount];
}
- public double[] getWeights() {
+ public float[] getWeights() {
return weights;
}
- public void setSquashingFunction(DoubleFunction squashingFunction) {
+ public void setSquashingFunction(FloatFunction squashingFunction) {
this.squashingFunction = squashingFunction;
}
@Override
public void readFields(DataInput in) throws IOException {
id = in.readInt();
- output = in.readDouble();
- weight = in.readDouble();
- delta = in.readDouble();
+ output = in.readFloat();
+ weight = in.readFloat();
+ delta = in.readFloat();
- momentumWeight = in.readDouble();
- learningRate = in.readDouble();
+ momentumWeight = in.readFloat();
+ learningRate = in.readFloat();
}
@Override
public void write(DataOutput out) throws IOException {
out.writeInt(id);
- out.writeDouble(output);
- out.writeDouble(weight);
- out.writeDouble(delta);
+ out.writeFloat(output);
+ out.writeFloat(weight);
+ out.writeFloat(delta);
- out.writeDouble(momentumWeight);
- out.writeDouble(learningRate);
+ out.writeFloat(momentumWeight);
+ out.writeFloat(learningRate);
}
}
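Neuron's per-neuron state (output, weight, delta, momentumWeight, learningRate) drops from five 8-byte doubles to five 4-byte floats, so each serialized neuron shrinks from 44 to 24 bytes including the int id. A small round-trip sketch of the Writable contract above; the helper and the 'blank' instance are hypothetical scaffolding:

```Java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;

import org.apache.hadoop.io.Writable;

public class NeuronIoSketch {
  // Round-trips any concrete Neuron through the float-based write()/readFields()
  // pair; 'blank' must be a fresh instance of the same subclass.
  static <M extends Writable> void roundTrip(Neuron<M> src, Neuron<M> blank)
      throws IOException {
    ByteArrayOutputStream bos = new ByteArrayOutputStream();
    src.write(new DataOutputStream(bos)); // int id + five floats = 24 bytes
    blank.readFields(new DataInputStream(
        new ByteArrayInputStream(bos.toByteArray())));
  }
}
```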
http://git-wip-us.apache.org/repos/asf/incubator-horn/blob/af88df41/src/main/java/org/apache/horn/core/ParameterMergerServer.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/horn/core/ParameterMergerServer.java b/src/main/java/org/apache/horn/core/ParameterMergerServer.java
index 7bd5543..70fb04f 100644
--- a/src/main/java/org/apache/horn/core/ParameterMergerServer.java
+++ b/src/main/java/org/apache/horn/core/ParameterMergerServer.java
@@ -22,9 +22,6 @@ import java.util.concurrent.atomic.AtomicBoolean;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
-import org.apache.hama.commons.math.DoubleMatrix;
-
-import com.google.common.base.Preconditions;
public class ParameterMergerServer implements ParameterMerger {
@@ -76,9 +73,8 @@ public class ParameterMergerServer implements ParameterMerger {
}
@Override
- public ParameterMessage merge(
- ParameterMessage msg) {
-
+ public ParameterMessage merge(ParameterMessage msg) {
+/*
double trainingError = msg.getTrainingError();
DoubleMatrix[] weightUpdates = msg.getCurMatrices();
DoubleMatrix[] prevWeightUpdates = msg.getPrevMatrices();
@@ -127,6 +123,8 @@ public class ParameterMergerServer implements ParameterMerger {
return new ParameterMessage(0, this.isConverge.get(),
this.inMemoryModel.getWeightMatrices(),
this.inMemoryModel.getPrevMatricesUpdates());
+ */
+ return null;
}
}
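This hunk disables the parameter-server merge path for now: the double-based body is commented out and merge() returns null until the server catches up with the float migration. Purely as an assumption extrapolated from the commented-out logic, a re-enabled float version might look like the sketch below (updateWeightMatrices is hypothetical; the other calls appear elsewhere in this diff):

```Java
  // Assumed float re-enablement of merge(); not part of this commit.
  @Override
  public ParameterMessage merge(ParameterMessage msg) {
    FloatMatrix[] weightUpdates = msg.getCurMatrices();
    FloatMatrix[] prevWeightUpdates = msg.getPrevMatrices();
    synchronized (this.inMemoryModel) {
      this.inMemoryModel.updateWeightMatrices(weightUpdates); // hypothetical API
      this.inMemoryModel.setPrevWeightMatrices(prevWeightUpdates);
      return new ParameterMessage(0, this.isConverge.get(),
          this.inMemoryModel.getWeightMatrices(),
          this.inMemoryModel.getPrevMatricesUpdates());
    }
  }
```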
http://git-wip-us.apache.org/repos/asf/incubator-horn/blob/af88df41/src/main/java/org/apache/horn/core/ParameterMessage.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/horn/core/ParameterMessage.java b/src/main/java/org/apache/horn/core/ParameterMessage.java
index 524c443..44697f2 100644
--- a/src/main/java/org/apache/horn/core/ParameterMessage.java
+++ b/src/main/java/org/apache/horn/core/ParameterMessage.java
@@ -22,9 +22,9 @@ import java.io.DataOutput;
import java.io.IOException;
import org.apache.hadoop.io.Writable;
-import org.apache.hama.commons.io.MatrixWritable;
-import org.apache.hama.commons.math.DenseDoubleMatrix;
-import org.apache.hama.commons.math.DoubleMatrix;
+import org.apache.hama.commons.io.FloatMatrixWritable;
+import org.apache.hama.commons.math.DenseFloatMatrix;
+import org.apache.hama.commons.math.FloatMatrix;
/**
* ParameterMessage transmits the messages between workers and parameter
@@ -33,18 +33,18 @@ import org.apache.hama.commons.math.DoubleMatrix;
*/
public class ParameterMessage implements Writable {
- protected double trainingError;
- protected DoubleMatrix[] curMatrices;
- protected DoubleMatrix[] prevMatrices;
+ protected float trainingError;
+ protected FloatMatrix[] curMatrices;
+ protected FloatMatrix[] prevMatrices;
protected boolean converge;
public ParameterMessage() {
this.converge = false;
- this.trainingError = 0.0d;
+ this.trainingError = 0.0f;
}
- public ParameterMessage(double trainingError, boolean converge,
- DoubleMatrix[] weightMatrices, DoubleMatrix[] prevMatrices) {
+ public ParameterMessage(float trainingError, boolean converge,
+ FloatMatrix[] weightMatrices, FloatMatrix[] prevMatrices) {
this.trainingError = trainingError;
this.converge = converge;
this.curMatrices = weightMatrices;
@@ -53,40 +53,40 @@ public class ParameterMessage implements Writable {
@Override
public void readFields(DataInput input) throws IOException {
- trainingError = input.readDouble();
+ trainingError = input.readFloat();
converge = input.readBoolean();
boolean hasCurMatrices = input.readBoolean();
if(hasCurMatrices) {
int numMatrices = input.readInt();
- curMatrices = new DenseDoubleMatrix[numMatrices];
+ curMatrices = new DenseFloatMatrix[numMatrices];
// read matrix updates
for (int i = 0; i < curMatrices.length; ++i) {
- curMatrices[i] = (DenseDoubleMatrix) MatrixWritable.read(input);
+ curMatrices[i] = (DenseFloatMatrix) FloatMatrixWritable.read(input);
}
}
boolean hasPrevMatrices = input.readBoolean();
if (hasPrevMatrices) {
int numMatrices = input.readInt();
- prevMatrices = new DenseDoubleMatrix[numMatrices];
+ prevMatrices = new DenseFloatMatrix[numMatrices];
// read previous matrices updates
for (int i = 0; i < prevMatrices.length; ++i) {
- prevMatrices[i] = (DenseDoubleMatrix) MatrixWritable.read(input);
+ prevMatrices[i] = (DenseFloatMatrix) FloatMatrixWritable.read(input);
}
}
}
@Override
public void write(DataOutput output) throws IOException {
- output.writeDouble(trainingError);
+ output.writeFloat(trainingError);
output.writeBoolean(converge);
if (curMatrices == null) {
output.writeBoolean(false);
} else {
output.writeBoolean(true);
output.writeInt(curMatrices.length);
- for (DoubleMatrix matrix : curMatrices) {
- MatrixWritable.write(matrix, output);
+ for (FloatMatrix matrix : curMatrices) {
+ FloatMatrixWritable.write(matrix, output);
}
}
@@ -95,8 +95,8 @@ public class ParameterMessage implements Writable {
} else {
output.writeBoolean(true);
output.writeInt(prevMatrices.length);
- for (DoubleMatrix matrix : prevMatrices) {
- MatrixWritable.write(matrix, output);
+ for (FloatMatrix matrix : prevMatrices) {
+ FloatMatrixWritable.write(matrix, output);
}
}
}
@@ -105,7 +105,7 @@ public class ParameterMessage implements Writable {
return trainingError;
}
- public void setTrainingError(double trainingError) {
+ public void setTrainingError(float trainingError) {
this.trainingError = trainingError;
}
@@ -117,19 +117,19 @@ public class ParameterMessage implements Writable {
this.converge = converge;
}
- public DoubleMatrix[] getCurMatrices() {
+ public FloatMatrix[] getCurMatrices() {
return curMatrices;
}
- public void setMatrices(DoubleMatrix[] curMatrices) {
+ public void setMatrices(FloatMatrix[] curMatrices) {
this.curMatrices = curMatrices;
}
- public DoubleMatrix[] getPrevMatrices() {
+ public FloatMatrix[] getPrevMatrices() {
return prevMatrices;
}
- public void setPrevMatrices(DoubleMatrix[] prevMatrices) {
+ public void setPrevMatrices(FloatMatrix[] prevMatrices) {
this.prevMatrices = prevMatrices;
}
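ParameterMessage now carries a float training error plus two optional FloatMatrix arrays, each written as a presence flag, a count, and FloatMatrixWritable-encoded matrices, exactly as the write()/readFields() pair above lays them out. A minimal construction sketch with made-up shapes:

```Java
import org.apache.hama.commons.math.DenseFloatMatrix;
import org.apache.hama.commons.math.FloatMatrix;

public class MessageSketch {
  public static void main(String[] args) {
    // Illustrative shapes for a network with two weight matrices.
    FloatMatrix[] updates = { new DenseFloatMatrix(3, 5), new DenseFloatMatrix(2, 4) };
    FloatMatrix[] prev    = { new DenseFloatMatrix(3, 5), new DenseFloatMatrix(2, 4) };
    ParameterMessage out = new ParameterMessage(0.42f, false, updates, prev);
    // Wire layout: float error, boolean converge, then each array behind
    // a presence flag and a length.
  }
}
```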
http://git-wip-us.apache.org/repos/asf/incubator-horn/blob/af88df41/src/main/java/org/apache/horn/core/Synapse.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/horn/core/Synapse.java b/src/main/java/org/apache/horn/core/Synapse.java
index 6dbada8..7e9db2a 100644
--- a/src/main/java/org/apache/horn/core/Synapse.java
+++ b/src/main/java/org/apache/horn/core/Synapse.java
@@ -21,7 +21,7 @@ import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
-import org.apache.hadoop.io.DoubleWritable;
+import org.apache.hadoop.io.FloatWritable;
import org.apache.hadoop.io.Writable;
/**
@@ -30,16 +30,16 @@ import org.apache.hadoop.io.Writable;
public class Synapse<M extends Writable, W extends Writable> implements
Writable {
- DoubleWritable message;
- DoubleWritable weight;
- DoubleWritable prevWeight;
+ FloatWritable message;
+ FloatWritable weight;
+ FloatWritable prevWeight;
- public Synapse(DoubleWritable message, DoubleWritable weight) {
+ public Synapse(FloatWritable message, FloatWritable weight) {
this.message = message;
this.weight = weight;
}
- public Synapse(DoubleWritable message, DoubleWritable weight, DoubleWritable prevWeight) {
+ public Synapse(FloatWritable message, FloatWritable weight, FloatWritable prevWeight) {
this.message = message;
this.weight = weight;
this.prevWeight = prevWeight;
@@ -48,25 +48,25 @@ public class Synapse<M extends Writable, W extends Writable> implements
/**
* @return the activation or error message
*/
- public double getMessage() {
+ public float getMessage() {
return message.get();
}
- public double getInput() {
+ public float getInput() {
// returns the input
return message.get();
}
- public double getDelta() {
+ public float getDelta() {
// returns the delta
return message.get();
}
- public double getWeight() {
+ public float getWeight() {
return weight.get();
}
- public double getPrevWeight() {
+ public float getPrevWeight() {
return prevWeight.get();
}
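Synapse is now a pair (or, with prevWeight, a triple) of FloatWritables, and the single message field is exposed through three aliases, getMessage/getInput/getDelta, depending on which pass is reading it. A small sketch of the momentum-style weight correction the backward pass computes from it, with illustrative values (the formula mirrors the MultiLayerPerceptron hunk below):

```Java
import org.apache.hadoop.io.FloatWritable;

public class SynapseSketch {
  public static void main(String[] args) {
    // Hypothetical synapse: delta message, current weight, previous weight.
    Synapse<FloatWritable, FloatWritable> s = new Synapse<FloatWritable, FloatWritable>(
        new FloatWritable(0.02f), new FloatWritable(0.30f), new FloatWritable(0.28f));

    float learningRate = 0.01f, output = 0.7f, momentum = 0.9f; // illustrative
    float correction = -learningRate * output * s.getDelta()
        + momentum * s.getPrevWeight(); // = -0.00014f + 0.252f
  }
}
```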
http://git-wip-us.apache.org/repos/asf/incubator-horn/blob/af88df41/src/main/java/org/apache/horn/examples/MultiLayerPerceptron.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/horn/examples/MultiLayerPerceptron.java b/src/main/java/org/apache/horn/examples/MultiLayerPerceptron.java
index ac17cc4..a787dda 100644
--- a/src/main/java/org/apache/horn/examples/MultiLayerPerceptron.java
+++ b/src/main/java/org/apache/horn/examples/MultiLayerPerceptron.java
@@ -19,7 +19,7 @@ package org.apache.horn.examples;
import java.io.IOException;
-import org.apache.hadoop.io.DoubleWritable;
+import org.apache.hadoop.io.FloatWritable;
import org.apache.hama.HamaConfiguration;
import org.apache.horn.core.Constants.TrainingMethod;
import org.apache.horn.core.HornJob;
@@ -32,14 +32,14 @@ import org.apache.horn.funcs.SoftMax;
public class MultiLayerPerceptron {
public static class StandardNeuron extends
- Neuron<Synapse<DoubleWritable, DoubleWritable>> {
+ Neuron<Synapse<FloatWritable, FloatWritable>> {
@Override
public void forward(
- Iterable<Synapse<DoubleWritable, DoubleWritable>> messages)
+ Iterable<Synapse<FloatWritable, FloatWritable>> messages)
throws IOException {
- double sum = 0;
- for (Synapse<DoubleWritable, DoubleWritable> m : messages) {
+ float sum = 0;
+ for (Synapse<FloatWritable, FloatWritable> m : messages) {
sum += m.getInput() * m.getWeight();
}
this.feedforward(squashingFunction.apply(sum));
@@ -47,15 +47,15 @@ public class MultiLayerPerceptron {
@Override
public void backward(
- Iterable<Synapse<DoubleWritable, DoubleWritable>> messages)
+ Iterable<Synapse<FloatWritable, FloatWritable>> messages)
throws IOException {
- double gradient = 0;
- for (Synapse<DoubleWritable, DoubleWritable> m : messages) {
+ float gradient = 0;
+ for (Synapse<FloatWritable, FloatWritable> m : messages) {
// Calculates error gradient for each neuron
gradient += (m.getDelta() * m.getWeight());
// Weight corrections
- double weight = -this.getLearningRate() * this.getOutput()
+ float weight = -this.getLearningRate() * this.getOutput()
* m.getDelta() + this.getMomentumWeight() * m.getPrevWeight();
this.push(weight);
}
@@ -66,8 +66,8 @@ public class MultiLayerPerceptron {
}
public static HornJob createJob(HamaConfiguration conf, String modelPath,
- String inputPath, double learningRate, double momemtumWeight,
- double regularizationWeight, int features, int hu, int labels,
+ String inputPath, float learningRate, float momentumWeight,
+ float regularizationWeight, int features, int hu, int labels,
int miniBatch, int maxIteration) throws IOException {
HornJob job = new HornJob(conf, MultiLayerPerceptron.class);
@@ -79,7 +79,7 @@ public class MultiLayerPerceptron {
job.setMomentumWeight(momentumWeight);
job.setRegularizationWeight(regularizationWeight);
- job.setConvergenceCheckInterval(600);
+ job.setConvergenceCheckInterval(1000);
job.setBatchSize(miniBatch);
job.setTrainingMethod(TrainingMethod.GRADIENT_DESCENT);
@@ -104,8 +104,8 @@ public class MultiLayerPerceptron {
}
HornJob ann = createJob(new HamaConfiguration(), args[0], args[1],
- Double.parseDouble(args[2]), Double.parseDouble(args[3]),
- Double.parseDouble(args[4]), Integer.parseInt(args[5]),
+ Float.parseFloat(args[2]), Float.parseFloat(args[3]),
+ Float.parseFloat(args[4]), Integer.parseInt(args[5]),
Integer.parseInt(args[6]), Integer.parseInt(args[7]),
Integer.parseInt(args[8]), Integer.parseInt(args[9]));
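createJob and main() now take float hyperparameters end to end. A hypothetical submission with MNIST-like values, following the ten positional arguments that main() parses above (model path, input path, learning rate, momentum, regularization, features, hidden units, labels, mini-batch size, max iterations):

```Java
import java.io.IOException;

import org.apache.hama.HamaConfiguration;
import org.apache.horn.core.HornJob;

public class SubmitSketch {
  public static void main(String[] args) throws IOException {
    // All values below are illustrative, not recommended settings.
    HornJob job = MultiLayerPerceptron.createJob(new HamaConfiguration(),
        "/tmp/mlp.model", "/tmp/mnist.seq",
        0.002f,        // learningRate
        0.9f,          // momentumWeight
        0.0005f,       // regularizationWeight
        784, 100, 10,  // features, hidden units, labels
        10, 10000);    // miniBatch, maxIteration
  }
}
```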
http://git-wip-us.apache.org/repos/asf/incubator-horn/blob/af88df41/src/main/java/org/apache/horn/funcs/CategoricalCrossEntropy.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/horn/funcs/CategoricalCrossEntropy.java b/src/main/java/org/apache/horn/funcs/CategoricalCrossEntropy.java
index 96c228a..887f24d 100644
--- a/src/main/java/org/apache/horn/funcs/CategoricalCrossEntropy.java
+++ b/src/main/java/org/apache/horn/funcs/CategoricalCrossEntropy.java
@@ -17,22 +17,22 @@
*/
package org.apache.horn.funcs;
-import org.apache.hama.commons.math.DoubleDoubleFunction;
+import org.apache.hama.commons.math.FloatFloatFunction;
/**
* for softmaxed output
*/
-public class CategoricalCrossEntropy extends DoubleDoubleFunction {
+public class CategoricalCrossEntropy extends FloatFloatFunction {
- private static final double epsilon = 1e-8;
+ private static final float epsilon = 1e-8f;
@Override
- public double apply(double target, double actual) {
- return -target * Math.log(Math.max(actual, epsilon));
+ public float apply(float target, float actual) {
+ return -target * (float) Math.log(Math.max(actual, epsilon));
}
@Override
- public double applyDerivative(double target, double actual) {
+ public float applyDerivative(float target, float actual) {
// o - y
return -(target - actual);
}
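The epsilon clamp keeps Math.log away from zero, so a confident wrong prediction produces a large but finite float loss rather than infinity. A quick check against the class as committed:

```Java
public class CceSketch {
  public static void main(String[] args) {
    CategoricalCrossEntropy cce = new CategoricalCrossEntropy();
    float worst = cce.apply(1.0f, 0.0f);  // -log(max(0, 1e-8)) = -log(1e-8) ≈ 18.42f
    float good  = cce.apply(1.0f, 0.99f); // -log(0.99) ≈ 0.01f
    System.out.println(worst + " vs " + good);
  }
}
```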
http://git-wip-us.apache.org/repos/asf/incubator-horn/blob/af88df41/src/main/java/org/apache/horn/funcs/CrossEntropy.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/horn/funcs/CrossEntropy.java b/src/main/java/org/apache/horn/funcs/CrossEntropy.java
index a096be0..822b0ba 100644
--- a/src/main/java/org/apache/horn/funcs/CrossEntropy.java
+++ b/src/main/java/org/apache/horn/funcs/CrossEntropy.java
@@ -17,7 +17,7 @@
*/
package org.apache.horn.funcs;
-import org.apache.hama.commons.math.DoubleDoubleFunction;
+import org.apache.hama.commons.math.FloatFloatFunction;
/**
* The cross entropy cost function.
@@ -27,29 +27,23 @@ import org.apache.hama.commons.math.DoubleDoubleFunction;
* where t denotes the target value, y denotes the estimated value.
* </pre>
*/
-public class CrossEntropy extends DoubleDoubleFunction {
+public class CrossEntropy extends FloatFloatFunction {
+
+ private static final float epsilon = 1e-8f;
- private static final double epsilon = 1e-8;
-
@Override
- public double apply(double target, double actual) {
- double adjustedTarget = (target == 0 ? 0.000001 : target);
- adjustedTarget = (target == 1.0 ? 0.999999 : adjustedTarget);
- double adjustedActual = (actual == 0 ? 0.000001 : actual);
- adjustedActual = (actual == 1 ? 0.999999 : adjustedActual);
-
- return -target * Math.log(Math.max(actual, epsilon)) - (1 - target)
- * Math.log(Math.max(1 - actual, epsilon));
- // return -adjustedTarget * Math.log(adjustedActual) - (1 - adjustedTarget) * Math.log(adjustedActual);
+ public float apply(float target, float actual) {
+ return -target * (float) Math.log(Math.max(actual, epsilon)) - (1 - target)
+ * (float) Math.log(Math.max(1 - actual, epsilon));
}
-
+
@Override
- public double applyDerivative(double target, double actual) {
- double adjustedTarget = (target == 0 ? 0.000001 : target);
- adjustedTarget = (target == 1.0 ? 0.999999 : adjustedTarget);
- double adjustedActual = (actual == 0 ? 0.000001 : actual);
- adjustedActual = (actual == 1 ? 0.999999 : adjustedActual);
-
+ public float applyDerivative(float target, float actual) {
+ float adjustedTarget = (target == 0 ? 0.000001f : target);
+ adjustedTarget = (target == 1.0 ? 0.999999f : adjustedTarget);
+ float adjustedActual = (actual == 0 ? 0.000001f : actual);
+ adjustedActual = (actual == 1 ? 0.999999f : adjustedActual);
+
return -adjustedTarget / adjustedActual + (1 - adjustedTarget)
/ (1 - adjustedActual);
}
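One asymmetry worth noting: apply() now guards with the 1e-8f epsilon while applyDerivative() keeps the older 0.000001f/0.999999f adjustments, so loss and gradient saturate at different bounds. Behavior at the boundary, given the class as committed:

```Java
public class CrossEntropySketch {
  public static void main(String[] args) {
    CrossEntropy ce = new CrossEntropy();
    float loss = ce.apply(1.0f, 0.0f);           // -log(1e-8) ≈ 18.42f
    float grad = ce.applyDerivative(1.0f, 0.0f); // ≈ -0.999999f / 0.000001f ≈ -1.0e6f
    System.out.println(loss + ", " + grad);
  }
}
```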
http://git-wip-us.apache.org/repos/asf/incubator-horn/blob/af88df41/src/main/java/org/apache/horn/funcs/FunctionFactory.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/horn/funcs/FunctionFactory.java b/src/main/java/org/apache/horn/funcs/FunctionFactory.java
index 4310a95..41861e9 100644
--- a/src/main/java/org/apache/horn/funcs/FunctionFactory.java
+++ b/src/main/java/org/apache/horn/funcs/FunctionFactory.java
@@ -17,8 +17,8 @@
*/
package org.apache.horn.funcs;
-import org.apache.hama.commons.math.DoubleDoubleFunction;
-import org.apache.hama.commons.math.DoubleFunction;
+import org.apache.hama.commons.math.FloatFloatFunction;
+import org.apache.hama.commons.math.FloatFunction;
/**
* Factory to create the functions.
@@ -32,7 +32,7 @@ public class FunctionFactory {
* @param functionName
* @return an appropriate float function.
*/
- public static DoubleFunction createDoubleFunction(String functionName) {
+ public static FloatFunction createFloatFunction(String functionName) {
if (functionName.equalsIgnoreCase(Sigmoid.class.getSimpleName())) {
return new Sigmoid();
} else if (functionName.equalsIgnoreCase(Tanh.class.getSimpleName())) {
@@ -56,7 +56,7 @@ public class FunctionFactory {
* @param functionName
* @return an appropriate float float function.
*/
- public static DoubleDoubleFunction createDoubleDoubleFunction(
+ public static FloatFloatFunction createFloatFloatFunction(
String functionName) {
if (functionName.equalsIgnoreCase(SquaredError.class.getSimpleName())) {
return new SquaredError();
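Both factory lookups stay keyed on the simple class name; only the method names and return types change. Minimal usage, assuming the Sigmoid and SquaredError branches shown in this file:

```Java
import org.apache.hama.commons.math.FloatFloatFunction;
import org.apache.hama.commons.math.FloatFunction;

public class FactorySketch {
  public static void main(String[] args) {
    FloatFunction squash = FunctionFactory.createFloatFunction("Sigmoid");
    FloatFloatFunction cost = FunctionFactory.createFloatFloatFunction("SquaredError");
    float activated = squash.apply(0.5f);     // sigmoid(0.5) ≈ 0.6225f
    float loss = cost.apply(1.0f, activated); // squared error against target 1.0
  }
}
```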
http://git-wip-us.apache.org/repos/asf/incubator-horn/blob/af88df41/src/main/java/org/apache/horn/funcs/IdentityFunction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/horn/funcs/IdentityFunction.java b/src/main/java/org/apache/horn/funcs/IdentityFunction.java
index 01e2e67..7ad4771 100644
--- a/src/main/java/org/apache/horn/funcs/IdentityFunction.java
+++ b/src/main/java/org/apache/horn/funcs/IdentityFunction.java
@@ -17,21 +17,21 @@
*/
package org.apache.horn.funcs;
-import org.apache.hama.commons.math.DoubleFunction;
+import org.apache.hama.commons.math.FloatFunction;
/**
* The identity function f(x) = x.
*
*/
-public class IdentityFunction extends DoubleFunction {
+public class IdentityFunction extends FloatFunction {
@Override
- public double apply(double value) {
+ public float apply(float value) {
return value;
}
@Override
- public double applyDerivative(double value) {
+ public float applyDerivative(float value) {
return 1;
}
http://git-wip-us.apache.org/repos/asf/incubator-horn/blob/af88df41/src/main/java/org/apache/horn/funcs/ReLU.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/horn/funcs/ReLU.java b/src/main/java/org/apache/horn/funcs/ReLU.java
index 85af867..2f14f54 100644
--- a/src/main/java/org/apache/horn/funcs/ReLU.java
+++ b/src/main/java/org/apache/horn/funcs/ReLU.java
@@ -17,7 +17,7 @@
*/
package org.apache.horn.funcs;
-import org.apache.hama.commons.math.DoubleFunction;
+import org.apache.hama.commons.math.FloatFunction;
/**
* The rectifier function
@@ -26,19 +26,19 @@ import org.apache.hama.commons.math.DoubleFunction;
* f(x) = max(0, x)
* </pre>
*/
-public class ReLU extends DoubleFunction {
+public class ReLU extends FloatFunction {
@Override
- public double apply(double value) {
- return Math.max(0.001, value);
+ public float apply(float value) {
+ return Math.max(0.001f, value);
}
@Override
- public double applyDerivative(double value) {
+ public float applyDerivative(float value) {
if (value > 0)
- return 0.999;
+ return 0.999f;
else
- return 0.001;
+ return 0.001f;
}
}
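Despite the f(x) = max(0, x) javadoc, the committed implementation is a clamped variant: activations floor at 0.001f and the derivative returns 0.999f/0.001f instead of 1/0, presumably to keep gradients from dying outright (an inference; the commit does not say). For contrast, a textbook float ReLU would be:

```Java
import org.apache.hama.commons.math.FloatFunction;

// Strict ReLU sketch for comparison; the committed class above
// intentionally deviates from this.
public class StrictReLU extends FloatFunction {
  @Override
  public float apply(float value) {
    return Math.max(0f, value);
  }

  @Override
  public float applyDerivative(float value) {
    return value > 0 ? 1f : 0f;
  }
}
```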