You are viewing a plain-text version of this content; the link to the canonical version was not preserved in this copy.
Posted to commits@hama.apache.org by yx...@apache.org on 2013/12/10 14:41:44 UTC
svn commit: r1549842 - in /hama/trunk: ./
ml/src/main/java/org/apache/hama/ml/ann/
ml/src/test/java/org/apache/hama/ml/ann/
Author: yxjiang
Date: Tue Dec 10 13:41:43 2013
New Revision: 1549842
URL: http://svn.apache.org/r1549842
Log:
HAMA-828: Improve code, fix typo and modify unclear comment in org.apache.hama.ml.ann package
Modified:
hama/trunk/CHANGES.txt
hama/trunk/ml/src/main/java/org/apache/hama/ml/ann/AbstractLayeredNeuralNetwork.java
hama/trunk/ml/src/main/java/org/apache/hama/ml/ann/NeuralNetwork.java
hama/trunk/ml/src/main/java/org/apache/hama/ml/ann/SmallLayeredNeuralNetwork.java
hama/trunk/ml/src/main/java/org/apache/hama/ml/ann/SmallLayeredNeuralNetworkMessage.java
hama/trunk/ml/src/test/java/org/apache/hama/ml/ann/TestSmallLayeredNeuralNetwork.java
Modified: hama/trunk/CHANGES.txt
URL: http://svn.apache.org/viewvc/hama/trunk/CHANGES.txt?rev=1549842&r1=1549841&r2=1549842&view=diff
==============================================================================
--- hama/trunk/CHANGES.txt (original)
+++ hama/trunk/CHANGES.txt Tue Dec 10 13:41:43 2013
@@ -16,6 +16,7 @@ Release 0.7.0 (unreleased changes)
IMPROVEMENTS
+ HAMA-828: Improve code, fix typo and modify unclear comment in org.apache.hama.ml.ann package (Yexi Jiang)
HAMA-699: Add commons module (Martin Illecker)
HAMA-818: Remove useless comments in GroomServer (edwardyoon)
Modified: hama/trunk/ml/src/main/java/org/apache/hama/ml/ann/AbstractLayeredNeuralNetwork.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/ann/AbstractLayeredNeuralNetwork.java?rev=1549842&r1=1549841&r2=1549842&view=diff
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/ann/AbstractLayeredNeuralNetwork.java (original)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/ann/AbstractLayeredNeuralNetwork.java Tue Dec 10 13:41:43 2013
@@ -31,6 +31,7 @@ import org.apache.hama.commons.math.Doub
import org.apache.hama.commons.math.FunctionFactory;
import com.google.common.base.Preconditions;
+import com.google.common.collect.Lists;
/**
* AbstractLayeredNeuralNetwork defines the general operations for derivative
@@ -66,7 +67,7 @@ abstract class AbstractLayeredNeuralNetw
protected LearningStyle learningStyle;
public static enum TrainingMethod {
- GRADIATE_DESCENT
+ GRADIENT_DESCENT
}
public static enum LearningStyle {
@@ -77,7 +78,7 @@ abstract class AbstractLayeredNeuralNetw
public AbstractLayeredNeuralNetwork() {
this.regularizationWeight = DEFAULT_REGULARIZATION_WEIGHT;
this.momentumWeight = DEFAULT_MOMENTUM_WEIGHT;
- this.trainingMethod = TrainingMethod.GRADIATE_DESCENT;
+ this.trainingMethod = TrainingMethod.GRADIENT_DESCENT;
this.learningStyle = LearningStyle.SUPERVISED;
}
@@ -229,7 +230,7 @@ abstract class AbstractLayeredNeuralNetw
// read layer size list
int numLayers = input.readInt();
- this.layerSizeList = new ArrayList<Integer>();
+ this.layerSizeList = Lists.newArrayList();
for (int i = 0; i < numLayers; ++i) {
this.layerSizeList.add(input.readInt());
}
Modified: hama/trunk/ml/src/main/java/org/apache/hama/ml/ann/NeuralNetwork.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/ann/NeuralNetwork.java?rev=1549842&r1=1549841&r2=1549842&view=diff
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/ann/NeuralNetwork.java (original)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/ann/NeuralNetwork.java Tue Dec 10 13:41:43 2013
@@ -39,6 +39,7 @@ import org.apache.hama.ml.util.DefaultFe
import org.apache.hama.ml.util.FeatureTransformer;
import com.google.common.base.Preconditions;
+import com.google.common.io.Closeables;
/**
* NeuralNetwork defines the general operations for all the derivative models.
@@ -85,7 +86,7 @@ abstract class NeuralNetwork implements
*/
public void setLearningRate(double learningRate) {
Preconditions.checkArgument(learningRate > 0,
- "Learning rate must larger than 0.");
+ "Learning rate must be larger than 0.");
this.learningRate = learningRate;
}
@@ -144,13 +145,16 @@ abstract class NeuralNetwork implements
Preconditions.checkArgument(this.modelPath != null,
"Model path has not been set.");
Configuration conf = new Configuration();
+ FSDataInputStream is = null;
try {
URI uri = new URI(this.modelPath);
FileSystem fs = FileSystem.get(uri, conf);
- FSDataInputStream is = new FSDataInputStream(fs.open(new Path(modelPath)));
+ is = new FSDataInputStream(fs.open(new Path(modelPath)));
this.readFields(is);
} catch (URISyntaxException e) {
e.printStackTrace();
+ } finally {
+ Closeables.close(is, false);
}
}
@@ -164,10 +168,17 @@ abstract class NeuralNetwork implements
Preconditions.checkArgument(this.modelPath != null,
"Model path has not been set.");
Configuration conf = new Configuration();
- FileSystem fs = FileSystem.get(conf);
- FSDataOutputStream stream = fs.create(new Path(this.modelPath), true);
- this.write(stream);
- stream.close();
+ FSDataOutputStream is = null;
+ try {
+ URI uri = new URI(this.modelPath);
+ FileSystem fs = FileSystem.get(uri, conf);
+ is = fs.create(new Path(this.modelPath), true);
+ this.write(is);
+ } catch (URISyntaxException e) {
+ e.printStackTrace();
+ }
+
+ Closeables.close(is, false);
}
/**
@@ -215,7 +226,7 @@ abstract class NeuralNetwork implements
Constructor[] constructors = featureTransformerCls
.getDeclaredConstructors();
Constructor constructor = constructors[0];
-
+
try {
this.featureTransformer = (FeatureTransformer) constructor
.newInstance(new Object[] {});
Modified: hama/trunk/ml/src/main/java/org/apache/hama/ml/ann/SmallLayeredNeuralNetwork.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/ann/SmallLayeredNeuralNetwork.java?rev=1549842&r1=1549841&r2=1549842&view=diff
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/ann/SmallLayeredNeuralNetwork.java (original)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/ann/SmallLayeredNeuralNetwork.java Tue Dec 10 13:41:43 2013
@@ -23,8 +23,8 @@ import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
-import java.util.Random;
+import org.apache.commons.lang.math.RandomUtils;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.LongWritable;
@@ -43,6 +43,7 @@ import org.apache.hama.commons.math.Func
import org.mortbay.log.Log;
import com.google.common.base.Preconditions;
+import com.google.common.collect.Lists;
/**
* SmallLayeredNeuralNetwork defines the general operations for derivative
@@ -70,10 +71,10 @@ public class SmallLayeredNeuralNetwork e
protected int finalLayerIdx;
public SmallLayeredNeuralNetwork() {
- this.layerSizeList = new ArrayList<Integer>();
- this.weightMatrixList = new ArrayList<DoubleMatrix>();
- this.prevWeightUpdatesList = new ArrayList<DoubleMatrix>();
- this.squashingFunctionList = new ArrayList<DoubleFunction>();
+ this.layerSizeList = Lists.newArrayList();
+ this.weightMatrixList = Lists.newArrayList();
+ this.prevWeightUpdatesList = Lists.newArrayList();
+ this.squashingFunctionList = Lists.newArrayList();
}
public SmallLayeredNeuralNetwork(String modelPath) {
@@ -86,7 +87,8 @@ public class SmallLayeredNeuralNetwork e
*/
public int addLayer(int size, boolean isFinalLayer,
DoubleFunction squashingFunction) {
- Preconditions.checkArgument(size > 0, "Size of layer must larger than 0.");
+ Preconditions.checkArgument(size > 0,
+ "Size of layer must be larger than 0.");
if (!isFinalLayer) {
size += 1;
}
@@ -107,11 +109,10 @@ public class SmallLayeredNeuralNetwork e
int col = sizePrevLayer;
DoubleMatrix weightMatrix = new DenseDoubleMatrix(row, col);
// initialize weights
- final Random rnd = new Random();
weightMatrix.applyToElements(new DoubleFunction() {
@Override
public double apply(double value) {
- return rnd.nextDouble() - 0.5;
+ return RandomUtils.nextDouble() - 0.5;
}
@Override
@@ -138,6 +139,10 @@ public class SmallLayeredNeuralNetwork e
}
}
+ /**
+ * Set the previous weight matrices.
+ * @param prevUpdates
+ */
void setPrevWeightMatrices(DoubleMatrix[] prevUpdates) {
this.prevWeightUpdatesList.clear();
for (DoubleMatrix prevUpdate : prevUpdates) {
@@ -176,8 +181,8 @@ public class SmallLayeredNeuralNetwork e
*/
public void setWeightMatrices(DoubleMatrix[] matrices) {
this.weightMatrixList = new ArrayList<DoubleMatrix>();
- for (int i = 0; i < matrices.length; ++i) {
- this.weightMatrixList.add(matrices[i]);
+ for (DoubleMatrix matrix : matrices) {
+ this.weightMatrixList.add(matrix);
}
}
@@ -197,8 +202,9 @@ public class SmallLayeredNeuralNetwork e
public void setWeightMatrix(int index, DoubleMatrix matrix) {
Preconditions.checkArgument(
- 0 <= index && index < this.weightMatrixList.size(),
- String.format("index [%d] out of range.", index));
+ 0 <= index && index < this.weightMatrixList.size(), String.format(
+ "index [%d] should be in range[%d, %d].", index, 0,
+ this.weightMatrixList.size()));
this.weightMatrixList.set(index, matrix);
}
@@ -208,7 +214,7 @@ public class SmallLayeredNeuralNetwork e
// read squash functions
int squashingFunctionSize = input.readInt();
- this.squashingFunctionList = new ArrayList<DoubleFunction>();
+ this.squashingFunctionList = Lists.newArrayList();
for (int i = 0; i < squashingFunctionSize; ++i) {
this.squashingFunctionList.add(FunctionFactory
.createDoubleFunction(WritableUtils.readString(input)));
@@ -216,8 +222,8 @@ public class SmallLayeredNeuralNetwork e
// read weights and construct matrices of previous updates
int numOfMatrices = input.readInt();
- this.weightMatrixList = new ArrayList<DoubleMatrix>();
- this.prevWeightUpdatesList = new ArrayList<DoubleMatrix>();
+ this.weightMatrixList = Lists.newArrayList();
+ this.prevWeightUpdatesList = Lists.newArrayList();
for (int i = 0; i < numOfMatrices; ++i) {
DoubleMatrix matrix = MatrixWritable.read(input);
this.weightMatrixList.add(matrix);
@@ -257,8 +263,8 @@ public class SmallLayeredNeuralNetwork e
*/
@Override
public DoubleVector getOutput(DoubleVector instance) {
- Preconditions.checkArgument(this.layerSizeList.get(0) == instance
- .getDimension() + 1, String.format(
+ Preconditions.checkArgument(this.layerSizeList.get(0) - 1 == instance
+ .getDimension(), String.format(
"The dimension of input instance should be %d.",
this.layerSizeList.get(0) - 1));
// transform the features to another space
@@ -336,8 +342,6 @@ public class SmallLayeredNeuralNetwork e
public DoubleMatrix[] trainByInstance(DoubleVector trainingInstance) {
DoubleVector transformedVector = this.featureTransformer
.transform(trainingInstance.sliceUnsafe(this.layerSizeList.get(0) - 1));
-
-
int inputDimension = this.layerSizeList.get(0) - 1;
int outputDimension;
@@ -389,11 +393,12 @@ public class SmallLayeredNeuralNetwork e
calculateTrainingError(labels,
output.deepCopy().sliceUnsafe(1, output.getDimension() - 1));
- if (this.trainingMethod.equals(TrainingMethod.GRADIATE_DESCENT)) {
+ if (this.trainingMethod.equals(TrainingMethod.GRADIENT_DESCENT)) {
return this.trainByInstanceGradientDescent(labels, internalResults);
+ } else {
+ throw new IllegalArgumentException(
+ String.format("Training method is not supported."));
}
- throw new IllegalArgumentException(
- String.format("Training method is not supported."));
}
/**
@@ -483,9 +488,6 @@ public class SmallLayeredNeuralNetwork e
* squashingFunction.applyDerivative(curLayerOutput.get(i)));
}
- // System.out.printf("Delta layer: %d, %s\n", curLayerIdx,
- // delta.toString());
-
// update weights
for (int i = 0; i < weightUpdateMatrix.getRowCount(); ++i) {
for (int j = 0; j < weightUpdateMatrix.getColumnCount(); ++j) {
@@ -495,9 +497,6 @@ public class SmallLayeredNeuralNetwork e
}
}
- // System.out.printf("Weight Layer %d, %s\n", curLayerIdx,
- // weightUpdateMatrix.toString());
-
return delta;
}
@@ -556,9 +555,7 @@ public class SmallLayeredNeuralNetwork e
protected void calculateTrainingError(DoubleVector labels, DoubleVector output) {
DoubleVector errors = labels.deepCopy().applyToElements(output,
this.costFunction);
- // System.out.printf("Labels: %s\tOutput: %s\n", labels, output);
this.trainingError = errors.sum();
- // System.out.printf("Training error: %s\n", errors);
}
/**
Modified: hama/trunk/ml/src/main/java/org/apache/hama/ml/ann/SmallLayeredNeuralNetworkMessage.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/ann/SmallLayeredNeuralNetworkMessage.java?rev=1549842&r1=1549841&r2=1549842&view=diff
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/ann/SmallLayeredNeuralNetworkMessage.java (original)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/ann/SmallLayeredNeuralNetworkMessage.java Tue Dec 10 13:41:43 2013
@@ -78,12 +78,12 @@ public class SmallLayeredNeuralNetworkMe
} else {
output.writeBoolean(true);
}
- for (int i = 0; i < curMatrices.length; ++i) {
- MatrixWritable.write(curMatrices[i], output);
+ for (DoubleMatrix matrix : curMatrices) {
+ MatrixWritable.write(matrix, output);
}
if (prevMatrices != null) {
- for (int i = 0; i < prevMatrices.length; ++i) {
- MatrixWritable.write(prevMatrices[i], output);
+ for (DoubleMatrix matrix : prevMatrices) {
+ MatrixWritable.write(matrix, output);
}
}
}
Modified: hama/trunk/ml/src/test/java/org/apache/hama/ml/ann/TestSmallLayeredNeuralNetwork.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/test/java/org/apache/hama/ml/ann/TestSmallLayeredNeuralNetwork.java?rev=1549842&r1=1549841&r2=1549842&view=diff
==============================================================================
--- hama/trunk/ml/src/test/java/org/apache/hama/ml/ann/TestSmallLayeredNeuralNetwork.java (original)
+++ hama/trunk/ml/src/test/java/org/apache/hama/ml/ann/TestSmallLayeredNeuralNetwork.java Tue Dec 10 13:41:43 2013
@@ -103,7 +103,7 @@ public class TestSmallLayeredNeuralNetwo
assertEquals(momentumWeight, annCopy.getMomemtumWeight(), 0.000001);
assertEquals(regularizationWeight, annCopy.getRegularizationWeight(),
0.000001);
- assertEquals(TrainingMethod.GRADIATE_DESCENT, annCopy.getTrainingMethod());
+ assertEquals(TrainingMethod.GRADIENT_DESCENT, annCopy.getTrainingMethod());
assertEquals(LearningStyle.UNSUPERVISED, annCopy.getLearningStyle());
// compare weights