You are viewing a plain text version of this content. The canonical link for it pointed to the original archive page (the hyperlink was lost in this plain-text version).
Posted to commits@labs.apache.org by to...@apache.org on 2015/10/04 08:21:17 UTC
svn commit: r1706651 - in /labs/yay/trunk: api/src/main/java/org/apache/yay/
core/src/main/java/org/apache/yay/core/
core/src/main/java/org/apache/yay/core/utils/
core/src/test/java/org/apache/yay/core/
Author: tommaso
Date: Sun Oct 4 06:21:16 2015
New Revision: 1706651
URL: http://svn.apache.org/viewvc?rev=1706651&view=rev
Log:
enabling multiple NN output neurons
Added:
labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkIntegrationTest.java
- copied, changed from r1705721, labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkFactoryTest.java
Removed:
labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkFactoryTest.java
Modified:
labs/yay/trunk/api/src/main/java/org/apache/yay/Hypothesis.java
labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetwork.java
labs/yay/trunk/api/src/main/java/org/apache/yay/TrainingExample.java
labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java
labs/yay/trunk/core/src/main/java/org/apache/yay/core/BasicPerceptron.java
labs/yay/trunk/core/src/main/java/org/apache/yay/core/DefaultDerivativeUpdateFunction.java
labs/yay/trunk/core/src/main/java/org/apache/yay/core/LogisticRegressionCostFunction.java
labs/yay/trunk/core/src/main/java/org/apache/yay/core/NeuralNetworkFactory.java
labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ExamplesFactory.java
labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java
labs/yay/trunk/core/src/test/java/org/apache/yay/core/BasicPerceptronTest.java
labs/yay/trunk/core/src/test/java/org/apache/yay/core/LogisticRegressionCostFunctionTest.java
Modified: labs/yay/trunk/api/src/main/java/org/apache/yay/Hypothesis.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/api/src/main/java/org/apache/yay/Hypothesis.java?rev=1706651&r1=1706650&r2=1706651&view=diff
==============================================================================
--- labs/yay/trunk/api/src/main/java/org/apache/yay/Hypothesis.java (original)
+++ labs/yay/trunk/api/src/main/java/org/apache/yay/Hypothesis.java Sun Oct 4 06:21:16 2015
@@ -18,6 +18,8 @@
*/
package org.apache.yay;
+import java.util.List;
+
/**
* In machine learning an hypothesis is the output of applying a learning
* algorithm to a training set, an hypothesis maps new inputs to possible outputs.
@@ -45,7 +47,7 @@ public interface Hypothesis<T, I, O> {
* @return the predicted output
* @throws PredictionException if any error occurs during the prediction phase
*/
- O predict(Input<I> input) throws PredictionException;
+ O[] predict(Input<I> input) throws PredictionException;
/**
* Let this <code>Hypothesis</code> learn by experience, in the form of a {@link org.apache.yay.TrainingSet}
Modified: labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetwork.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetwork.java?rev=1706651&r1=1706650&r2=1706651&view=diff
==============================================================================
--- labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetwork.java (original)
+++ labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetwork.java Sun Oct 4 06:21:16 2015
@@ -27,13 +27,4 @@ import org.apache.commons.math3.linear.R
*/
public interface NeuralNetwork extends Hypothesis<RealMatrix, Double, Double> {
- /**
- * Predict the output for a given input
- *
- * @param input the input to evaluate
- * @return the predicted output
- * @throws PredictionException if any error occurs during the prediction phase
- */
- Double[] getOutputVector(Input<Double> input) throws PredictionException;
-
}
Modified: labs/yay/trunk/api/src/main/java/org/apache/yay/TrainingExample.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/api/src/main/java/org/apache/yay/TrainingExample.java?rev=1706651&r1=1706650&r2=1706651&view=diff
==============================================================================
--- labs/yay/trunk/api/src/main/java/org/apache/yay/TrainingExample.java (original)
+++ labs/yay/trunk/api/src/main/java/org/apache/yay/TrainingExample.java Sun Oct 4 06:21:16 2015
@@ -28,6 +28,6 @@ public interface TrainingExample<F, O> e
*
* @return the output
*/
- O getOutput();
+ O[] getOutput();
}
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java?rev=1706651&r1=1706650&r2=1706651&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java Sun Oct 4 06:21:16 2015
@@ -19,13 +19,10 @@
package org.apache.yay.core;
import java.util.Arrays;
-import java.util.DoubleSummaryStatistics;
import java.util.Iterator;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
-import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.commons.math3.linear.RealMatrix;
-import org.apache.commons.math3.linear.RealVector;
import org.apache.yay.CostFunction;
import org.apache.yay.DerivativeUpdateFunction;
import org.apache.yay.LearningStrategy;
@@ -34,7 +31,6 @@ import org.apache.yay.PredictionStrategy
import org.apache.yay.TrainingExample;
import org.apache.yay.TrainingSet;
import org.apache.yay.WeightLearningException;
-import org.apache.yay.core.utils.ConversionUtils;
/**
* Back propagation learning algorithm for neural networks implementation (see
@@ -71,7 +67,7 @@ public class BackPropagationLearningStra
public BackPropagationLearningStrategy() {
// commonly used defaults
- this.predictionStrategy = new FeedForwardStrategy(new SigmoidFunction());
+ this.predictionStrategy = new FeedForwardStrategy(new TanhFunction());
this.costFunction = new LogisticRegressionCostFunction();
this.alpha = DEFAULT_ALPHA;
this.threshold = DEFAULT_THRESHOLD;
@@ -85,7 +81,7 @@ public class BackPropagationLearningStra
try {
int iterations = 0;
- NeuralNetwork hypothesis = NeuralNetworkFactory.create(weightsMatrixSet, new VoidLearningStrategy<Double, Double>(), predictionStrategy, new MaxSelectionFunction<Double>());
+ NeuralNetwork hypothesis = NeuralNetworkFactory.create(weightsMatrixSet, new VoidLearningStrategy<Double, Double>(), predictionStrategy);
Iterator<TrainingExample<Double, Double>> iterator = trainingExamples.iterator();
double cost = Double.MAX_VALUE;
@@ -142,7 +138,10 @@ public class BackPropagationLearningStra
double[][] updatedWeights = weightsMatrixSet[l].getData();
for (int i = 0; i < updatedWeights.length; i++) {
for (int j = 0; j < updatedWeights[i].length; j++) {
- updatedWeights[i][j] = updatedWeights[i][j] - alpha * derivatives[l].getData()[i][j];
+ double curVal = updatedWeights[i][j];
+ if (curVal > 0d && curVal < 1d) {
+ updatedWeights[i][j] = updatedWeights[i][j] - alpha * derivatives[l].getData()[i][j];
+ }
}
}
updatedParameters[l] = new Array2DRowRealMatrix(updatedWeights);
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/BasicPerceptron.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/BasicPerceptron.java?rev=1706651&r1=1706650&r2=1706651&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/BasicPerceptron.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/BasicPerceptron.java Sun Oct 4 06:21:16 2015
@@ -62,7 +62,7 @@ public class BasicPerceptron implements
Collection<Double> doubles = ConversionUtils.toValuesCollection(example.getFeatures());
Double[] inputs = doubles.toArray(new Double[doubles.size()]);
Double calculatedOutput = perceptronNeuron.elaborate(inputs);
- int diff = calculatedOutput.compareTo(example.getOutput());
+ int diff = calculatedOutput.compareTo(example.getOutput()[0]);
if (diff > 0) {
for (int i = 0; i < currentWeights.length; i++) {
currentWeights[i] += inputs[i];
@@ -90,17 +90,10 @@ public class BasicPerceptron implements
}
@Override
- public Double predict(Input<Double> input) throws PredictionException {
- return perceptronNeuron.elaborate(ConversionUtils.toValuesCollection(input.getFeatures()).toArray(
+ public Double[] predict(Input<Double> input) throws PredictionException {
+ Double output = perceptronNeuron.elaborate(ConversionUtils.toValuesCollection(input.getFeatures()).toArray(
new Double[input.getFeatures().size()]));
+ return new Double[]{output};
}
- @Override
- public Double[] getOutputVector(Input<Double> input) throws PredictionException {
- Double elaborate = perceptronNeuron.elaborate(ConversionUtils.toValuesCollection(input.getFeatures()).toArray(
- new Double[input.getFeatures().size()]));
- Double[] ar = new Double[1];
- ar[0] = elaborate;
- return ar;
- }
}
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/DefaultDerivativeUpdateFunction.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/DefaultDerivativeUpdateFunction.java?rev=1706651&r1=1706650&r2=1706651&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/DefaultDerivativeUpdateFunction.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/DefaultDerivativeUpdateFunction.java Sun Oct 4 06:21:16 2015
@@ -78,10 +78,6 @@ public class DefaultDerivativeUpdateFunc
count++;
}
- return createDerivatives(triangle, count);
- }
-
- private RealMatrix[] createDerivatives(RealMatrix[] triangle, double count) {
RealMatrix[] derivatives = new RealMatrix[triangle.length];
for (int i = 0; i < triangle.length; i++) {
// TODO : introduce regularization diversification on bias term (currently not regularized)
@@ -111,16 +107,17 @@ public class DefaultDerivativeUpdateFunc
private RealVector calculateOutputError(TrainingExample<Double, Double> trainingExample, RealVector[] activations) {
RealVector output = activations[activations.length - 1];
- Double[] sampleOutput = new Double[output.getDimension()];
- int sampleOutputIntValue = trainingExample.getOutput().intValue();
- if (sampleOutputIntValue < sampleOutput.length) {
- sampleOutput[sampleOutputIntValue] = 1d;
- } else if (sampleOutput.length == 1) {
- sampleOutput[0] = trainingExample.getOutput();
- } else {
- throw new RuntimeException("problem with multiclass output mapping");
- }
- RealVector learnedOutputRealVector = new ArrayRealVector(sampleOutput); // turn example output to a vector
+// Double[] sampleOutput = new Double[output.getDimension()];
+ Double[] actualOutput = trainingExample.getOutput();
+// int sampleOutputIntValue = actualOutput.intValue();
+// if (sampleOutputIntValue < sampleOutput.length) {
+// sampleOutput[sampleOutputIntValue] = 1d;
+// } else if (sampleOutput.length == 1) {
+// sampleOutput[0] = actualOutput;
+// } else {
+// throw new RuntimeException("problem with multiclass output mapping");
+// }
+ RealVector learnedOutputRealVector = new ArrayRealVector(actualOutput); // turn example output to a vector
// TODO : improve error calculation -> this could be er_a = out_a * (1 - out_a) * (tgt_a - out_a)
return output.subtract(learnedOutputRealVector);
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/LogisticRegressionCostFunction.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/LogisticRegressionCostFunction.java?rev=1706651&r1=1706650&r2=1706651&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/LogisticRegressionCostFunction.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/LogisticRegressionCostFunction.java Sun Oct 4 06:21:16 2015
@@ -74,11 +74,12 @@ public class LogisticRegressionCostFunct
Double res = 0d;
for (TrainingExample<Double, Double> input : trainingExamples) {
- // TODO : handle this for multiple outputs (multi class classification)
- Double predictedOutput = hypothesis.predict(input);
- Double sampleOutput = input.getOutput();
- res += sampleOutput * Math.log(predictedOutput) + (1d - sampleOutput)
- * Math.log(1d - predictedOutput);
+ Double[] predictedOutput = hypothesis.predict(input);
+ Double[] sampleOutput = input.getOutput();
+ for (int i = 0; i < predictedOutput.length; i++) {
+ res += sampleOutput[i] * Math.log(predictedOutput[i]) + (1d - sampleOutput[i])
+ * Math.log(1d - predictedOutput[i]);
+ }
}
return (-1d / trainingExamples.length) * res;
}
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/NeuralNetworkFactory.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/NeuralNetworkFactory.java?rev=1706651&r1=1706650&r2=1706651&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/NeuralNetworkFactory.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/NeuralNetworkFactory.java Sun Oct 4 06:21:16 2015
@@ -18,11 +18,8 @@
*/
package org.apache.yay.core;
-import java.util.ArrayList;
-import java.util.Arrays;
import java.util.Collection;
import org.apache.commons.math3.linear.RealMatrix;
-import org.apache.commons.math3.linear.RealVector;
import org.apache.yay.CreationException;
import org.apache.yay.Input;
import org.apache.yay.LearningException;
@@ -30,7 +27,6 @@ import org.apache.yay.LearningStrategy;
import org.apache.yay.NeuralNetwork;
import org.apache.yay.PredictionException;
import org.apache.yay.PredictionStrategy;
-import org.apache.yay.SelectionFunction;
import org.apache.yay.TrainingSet;
import org.apache.yay.WeightLearningException;
import org.apache.yay.core.utils.ConversionUtils;
@@ -51,12 +47,10 @@ public class NeuralNetworkFactory {
* @throws org.apache.yay.CreationException
*/
public static NeuralNetwork create(final RealMatrix[] realMatrixSet, final LearningStrategy<Double, Double> learningStrategy,
- final PredictionStrategy<Double, Double> predictionStrategy,
- final SelectionFunction<Collection<Double>, Double> selectionFunction) throws CreationException {
+ final PredictionStrategy<Double, Double> predictionStrategy) throws CreationException {
return new NeuralNetwork() {
- @Override
- public Double[] getOutputVector(Input<Double> input) throws PredictionException {
+ private Double[] getOutputVector(Input<Double> input) throws PredictionException {
Collection<Double> inputVector = ConversionUtils.toValuesCollection(input.getFeatures());
return predictionStrategy.predictOutput(inputVector, updatedRealMatrixSet);
}
@@ -83,10 +77,9 @@ public class NeuralNetworkFactory {
}
@Override
- public Double predict(Input<Double> input) throws PredictionException {
+ public Double[] predict(Input<Double> input) throws PredictionException {
try {
- Double[] doubles = getOutputVector(input);
- return selectionFunction.selectOutput(Arrays.asList(doubles));
+ return getOutputVector(input);
} catch (Exception e) {
throw new PredictionException(e);
}
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ExamplesFactory.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ExamplesFactory.java?rev=1706651&r1=1706650&r2=1706651&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ExamplesFactory.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ExamplesFactory.java Sun Oct 4 06:21:16 2015
@@ -39,22 +39,22 @@ public class ExamplesFactory {
}
@Override
- public Double getOutput() {
- return output;
+ public Double[] getOutput() {
+ return new Double[]{output};
}
};
}
- public static TrainingExample<Double, Collection<Double[]>> createSGMExample(final Collection<Double[]> output,
+ public static TrainingExample<Double, Double> createDoubleArrayTrainingExample(final Double[] output,
final Double... featuresValues) {
- return new TrainingExample<Double, Collection<Double[]>>() {
+ return new TrainingExample<Double, Double>() {
@Override
public ArrayList<Feature<Double>> getFeatures() {
return doublesToFeatureVector(featuresValues);
}
@Override
- public Collection<Double[]> getOutput() {
+ public Double[] getOutput() {
return output;
}
};
Modified: labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java?rev=1706651&r1=1706650&r2=1706651&view=diff
==============================================================================
--- labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java (original)
+++ labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java Sun Oct 4 06:21:16 2015
@@ -41,9 +41,9 @@ public class BackPropagationLearningStra
public void testLearningWithRandomNetwork() throws Exception {
BackPropagationLearningStrategy backPropagationLearningStrategy = new BackPropagationLearningStrategy();
- RealMatrix[] initialWeights = createRandomWeights();
+ RealMatrix[] initialWeights = createRandomWeights(2);
- Collection<TrainingExample<Double, Double>> samples = createSamples(1000, initialWeights[0].getColumnDimension() - 1);
+ Collection<TrainingExample<Double, Double>> samples = createSamples(1000, initialWeights[0].getColumnDimension() - 1, 2);
TrainingSet<Double, Double> trainingSet = new TrainingSet<Double, Double>(samples);
RealMatrix[] learntWeights = backPropagationLearningStrategy.learnWeights(initialWeights, trainingSet);
assertNotNull(learntWeights);
@@ -62,9 +62,9 @@ public class BackPropagationLearningStra
BackPropagationLearningStrategy backPropagationLearningStrategy = new BackPropagationLearningStrategy(alpha, threshold,
predictionStrategy, costFunction);
- RealMatrix[] initialWeights = createRandomWeights();
+ RealMatrix[] initialWeights = createRandomWeights(10);
- Collection<TrainingExample<Double, Double>> samples = createSamples(1000, initialWeights[0].getColumnDimension() - 1);
+ Collection<TrainingExample<Double, Double>> samples = createSamples(1000, initialWeights[0].getColumnDimension() - 1, 10);
TrainingSet<Double, Double> trainingSet = new TrainingSet<Double, Double>(samples);
RealMatrix[] learntWeights = backPropagationLearningStrategy.learnWeights(initialWeights, trainingSet);
assertNotNull(learntWeights);
@@ -82,7 +82,7 @@ public class BackPropagationLearningStra
}
}
- private RealMatrix[] createRandomWeights() {
+ private RealMatrix[] createRandomWeights(int outputSize) {
Random r = new Random();
int weightsCount = (Math.abs(r.nextInt()) % 5) + 2;
@@ -95,7 +95,7 @@ public class BackPropagationLearningStra
} else {
cols = initialWeights[i - 1].getRowDimension();
if (i == weightsCount - 1) {
- rows = 1;
+ rows = outputSize;
}
}
double[][] d = new double[rows][cols];
@@ -137,7 +137,7 @@ public class BackPropagationLearningStra
initialWeights[1] = new Array2DRowRealMatrix(new double[][]{{0d, 0d, 0d, 0d}, {1d, 0.5d, 1d, 0.5d}, {1d, 0.1d, 8d, 0.1d}, {1d, 0.1d, 8d, 0.2d}}); // 4 x 4
initialWeights[2] = new Array2DRowRealMatrix(new double[][]{{1d, 2d, 0.3d, 0.5d}}); // 1 x 4
- Collection<TrainingExample<Double, Double>> samples = createSamples(10000, 2);
+ Collection<TrainingExample<Double, Double>> samples = createSamples(10000, 2, 1);
TrainingSet<Double, Double> trainingSet = new TrainingSet<Double, Double>(samples);
RealMatrix[] learntWeights = backPropagationLearningStrategy.learnWeights(initialWeights, trainingSet);
assertNotNull(learntWeights);
@@ -169,7 +169,7 @@ public class BackPropagationLearningStra
initialWeights[1] = new Array2DRowRealMatrix(new double[][]{{0d, 0d, 0d, 0d}, {1d, 0.5d, 1d, 0.5d}, {1d, 0.1d, 8d, 0.1d}, {1d, 0.1d, 8d, 0.2d}});
initialWeights[2] = new Array2DRowRealMatrix(new double[][]{{1d, 2d, 0.3d, 0.5d}});
- Collection<TrainingExample<Double, Double>> samples = createSamples(10000, 2);
+ Collection<TrainingExample<Double, Double>> samples = createSamples(10000, 2, 1);
TrainingSet<Double, Double> trainingSet = new TrainingSet<Double, Double>(samples);
RealMatrix[] learntWeights = backPropagationLearningStrategy.learnWeights(initialWeights, trainingSet);
assertNotNull(learntWeights);
@@ -201,7 +201,7 @@ public class BackPropagationLearningStra
{1d, Math.random(), Math.random(), Math.random()}
});
- Collection<TrainingExample<Double, Double>> samples = createSamples(1000000, 2);
+ Collection<TrainingExample<Double, Double>> samples = createSamples(1000000, 2, 1);
TrainingSet<Double, Double> trainingSet = new TrainingSet<Double, Double>(samples);
RealMatrix[] learntWeights = backPropagationLearningStrategy.learnWeights(initialWeights, trainingSet);
assertNotNull(learntWeights);
@@ -211,14 +211,18 @@ public class BackPropagationLearningStra
assertFalse(learntWeights[2].equals(initialWeights[2]));
}
- private Collection<TrainingExample<Double, Double>> createSamples(int size, int noOfFeatures) {
+ private Collection<TrainingExample<Double, Double>> createSamples(int size, int noOfFeatures, int noOfOutputs) {
Collection<TrainingExample<Double, Double>> trainingExamples = new ArrayList<TrainingExample<Double, Double>>(size);
for (int i = 0; i < size; i++) {
Double[] featureValues = new Double[noOfFeatures];
for (int j = 0; j < noOfFeatures; j++) {
featureValues[j] = Math.random();
}
- trainingExamples.add(ExamplesFactory.createDoubleTrainingExample(1d, featureValues));
+ Double[] outputs = new Double[noOfOutputs];
+ for (int j = 0; j < outputs.length; j++) {
+ outputs[j] = Math.random();
+ }
+ trainingExamples.add(ExamplesFactory.createDoubleArrayTrainingExample(outputs, featureValues));
}
return trainingExamples;
}
Modified: labs/yay/trunk/core/src/test/java/org/apache/yay/core/BasicPerceptronTest.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/core/BasicPerceptronTest.java?rev=1706651&r1=1706650&r2=1706651&view=diff
==============================================================================
--- labs/yay/trunk/core/src/test/java/org/apache/yay/core/BasicPerceptronTest.java (original)
+++ labs/yay/trunk/core/src/test/java/org/apache/yay/core/BasicPerceptronTest.java Sun Oct 4 06:21:16 2015
@@ -86,7 +86,7 @@ public class BasicPerceptronTest {
r.nextDouble(), r.nextDouble(), r.nextDouble(), r.nextDouble(),
r.nextDouble(), r.nextDouble());
basicPerceptron.learn(bigDataset);
- Double output = basicPerceptron.predict(createInput(r));
+ Double output = basicPerceptron.predict(createInput(r))[0];
assertTrue(output == 0d || output == 1d);
}
@@ -102,7 +102,7 @@ public class BasicPerceptronTest {
r.nextDouble(), r.nextDouble());
basicPerceptron.learn(bigDataset);
TrainingExample<Double, Double> input = createInput(r);
- Double output = basicPerceptron.predict(input);
+ Double output = basicPerceptron.predict(input)[0];
assertTrue(output == 0d || output == 1d);
basicPerceptron.learn(createTrainingExample(1d, r.nextDouble(),
r.nextDouble(), r.nextDouble(), r.nextDouble(), r.nextDouble(),
@@ -110,7 +110,7 @@ public class BasicPerceptronTest {
r.nextDouble(), r.nextDouble(), r.nextDouble(), r.nextDouble(),
r.nextDouble(), r.nextDouble(), r.nextDouble(), r.nextDouble(),
r.nextDouble(), r.nextDouble()));
- Double secondOutput = basicPerceptron.predict(input);
+ Double secondOutput = basicPerceptron.predict(input)[0];
assertTrue(secondOutput == 0d || secondOutput == 1d);
}
@@ -135,7 +135,7 @@ public class BasicPerceptronTest {
public void testLearnAndPredictWithSmallDataset() throws Exception {
BasicPerceptron basicPerceptron = new BasicPerceptron(1d, 2d, 3d, 4d);
basicPerceptron.learn(smallDataset);
- Double output = basicPerceptron.predict(createTrainingExample(null, 1d, 6d, 0.4d));
+ Double output = basicPerceptron.predict(createTrainingExample(null, 1d, 6d, 0.4d))[0];
assertEquals(Double.valueOf(1d), output);
}
@@ -157,8 +157,8 @@ public class BasicPerceptronTest {
}
@Override
- public Double getOutput() {
- return output;
+ public Double[] getOutput() {
+ return new Double[]{output};
}
};
}
Modified: labs/yay/trunk/core/src/test/java/org/apache/yay/core/LogisticRegressionCostFunctionTest.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/core/LogisticRegressionCostFunctionTest.java?rev=1706651&r1=1706650&r2=1706651&view=diff
==============================================================================
--- labs/yay/trunk/core/src/test/java/org/apache/yay/core/LogisticRegressionCostFunctionTest.java (original)
+++ labs/yay/trunk/core/src/test/java/org/apache/yay/core/LogisticRegressionCostFunctionTest.java Sun Oct 4 06:21:16 2015
@@ -64,8 +64,7 @@ public class LogisticRegressionCostFunct
final RealMatrix[] orWeightsMatrixSet = new RealMatrix[]{singleOrLayerWeights};
final NeuralNetwork neuralNetwork = NeuralNetworkFactory.create(orWeightsMatrixSet,
- new VoidLearningStrategy<Double, Double>(), new FeedForwardStrategy(
- new SigmoidFunction()), new MaxSelectionFunction<Double>());
+ new VoidLearningStrategy<Double, Double>(), new FeedForwardStrategy(new SigmoidFunction()));
Double cost = costFunction.calculateAggregatedCost(trainingSet, neuralNetwork);
assertTrue("cost should not be negative", cost > 0d);
Copied: labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkIntegrationTest.java (from r1705721, labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkFactoryTest.java)
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkIntegrationTest.java?p2=labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkIntegrationTest.java&p1=labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkFactoryTest.java&r1=1705721&r2=1706651&rev=1706651&view=diff
==============================================================================
--- labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkFactoryTest.java (original)
+++ labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkIntegrationTest.java Sun Oct 4 06:21:16 2015
@@ -19,31 +19,39 @@
package org.apache.yay.core;
import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Random;
+
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.yay.CreationException;
import org.apache.yay.Feature;
import org.apache.yay.Input;
+import org.apache.yay.LearningStrategy;
import org.apache.yay.NeuralNetwork;
+import org.apache.yay.TrainingExample;
+import org.apache.yay.TrainingSet;
+import org.apache.yay.core.utils.ExamplesFactory;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
/**
- * Testcase for {@link org.apache.yay.core.NeuralNetworkFactory}
+ * Integration test for NN
*/
-public class NeuralNetworkFactoryTest {
+public class NeuralNetworkIntegrationTest {
@Test
public void andNNCreationTest() throws Exception {
double[][] weights = {{-30d, 20d, 20d}};
RealMatrix singleAndLayerWeights = new Array2DRowRealMatrix(weights);
RealMatrix[] andRealMatrixSet = new RealMatrix[]{singleAndLayerWeights};
- NeuralNetwork andNN = createFFNN(andRealMatrixSet);
- assertEquals(0l, Math.round(andNN.predict(createSample(1d, 0d))));
- assertEquals(0l, Math.round(andNN.predict(createSample(0d, 1d))));
- assertEquals(0l, Math.round(andNN.predict(createSample(0d, 0d))));
- assertEquals(1l, Math.round(andNN.predict(createSample(1d, 1d))));
+ NeuralNetwork andNN = createNN(andRealMatrixSet, new VoidLearningStrategy<Double, Double>());
+ assertEquals(0l, Math.round(andNN.predict(createSample(1d, 0d))[0]));
+ assertEquals(0l, Math.round(andNN.predict(createSample(0d, 1d))[0]));
+ assertEquals(0l, Math.round(andNN.predict(createSample(0d, 0d))[0]));
+ assertEquals(1l, Math.round(andNN.predict(createSample(1d, 1d))[0]));
}
@Test
@@ -51,11 +59,11 @@ public class NeuralNetworkFactoryTest {
double[][] weights = {{-10d, 20d, 20d}};
RealMatrix singleOrLayerWeights = new Array2DRowRealMatrix(weights);
RealMatrix[] orRealMatrixSet = new RealMatrix[]{singleOrLayerWeights};
- NeuralNetwork orNN = createFFNN(orRealMatrixSet);
- assertEquals(1l, Math.round(orNN.predict(createSample(1d, 0d))));
- assertEquals(1l, Math.round(orNN.predict(createSample(0d, 1d))));
- assertEquals(0l, Math.round(orNN.predict(createSample(0d, 0d))));
- assertEquals(1l, Math.round(orNN.predict(createSample(1d, 1d))));
+ NeuralNetwork orNN = createNN(orRealMatrixSet, new VoidLearningStrategy<Double, Double>());
+ assertEquals(1l, Math.round(orNN.predict(createSample(1d, 0d))[0]));
+ assertEquals(1l, Math.round(orNN.predict(createSample(0d, 1d))[0]));
+ assertEquals(0l, Math.round(orNN.predict(createSample(0d, 0d))[0]));
+ assertEquals(1l, Math.round(orNN.predict(createSample(1d, 1d))[0]));
}
@Test
@@ -63,9 +71,9 @@ public class NeuralNetworkFactoryTest {
double[][] weights = {{10d, -20d}};
RealMatrix singleNotLayerWeights = new Array2DRowRealMatrix(weights);
RealMatrix[] notRealMatrixSet = new RealMatrix[]{singleNotLayerWeights};
- NeuralNetwork orNN = createFFNN(notRealMatrixSet);
- assertEquals(1l, Math.round(orNN.predict(createSample(0d))));
- assertEquals(0l, Math.round(orNN.predict(createSample(1d))));
+ NeuralNetwork orNN = createNN(notRealMatrixSet, new VoidLearningStrategy<Double, Double>());
+ assertEquals(1l, Math.round(orNN.predict(createSample(0d))[0]));
+ assertEquals(0l, Math.round(orNN.predict(createSample(1d))[0]));
}
@Test
@@ -73,11 +81,11 @@ public class NeuralNetworkFactoryTest {
RealMatrix firstNorLayerWeights = new Array2DRowRealMatrix(new double[][]{{0, 0, 0}, {-30d, 20d, 20d}, {10d, -20d, -20d}});
RealMatrix secondNorLayerWeights = new Array2DRowRealMatrix(new double[][]{{-10d, 20d, 20d}});
RealMatrix[] norRealMatrixSet = new RealMatrix[]{firstNorLayerWeights, secondNorLayerWeights};
- NeuralNetwork norNN = createFFNN(norRealMatrixSet);
- assertEquals(0l, Math.round(norNN.predict(createSample(1d, 0d))));
- assertEquals(0l, Math.round(norNN.predict(createSample(0d, 1d))));
- assertEquals(1l, Math.round(norNN.predict(createSample(0d, 0d))));
- assertEquals(1l, Math.round(norNN.predict(createSample(1d, 1d))));
+ NeuralNetwork norNN = createNN(norRealMatrixSet, new VoidLearningStrategy<Double, Double>());
+ assertEquals(0l, Math.round(norNN.predict(createSample(1d, 0d))[0]));
+ assertEquals(0l, Math.round(norNN.predict(createSample(0d, 1d))[0]));
+ assertEquals(1l, Math.round(norNN.predict(createSample(0d, 0d))[0]));
+ assertEquals(1l, Math.round(norNN.predict(createSample(1d, 1d))[0]));
}
@Test
@@ -87,17 +95,17 @@ public class NeuralNetworkFactoryTest {
RealMatrix[] RealMatrixes = new RealMatrix[]{firstLayer, secondLayer};
- NeuralNetwork neuralNetwork = createFFNN(RealMatrixes);
+ NeuralNetwork neuralNetwork = createNN(RealMatrixes, new VoidLearningStrategy<Double, Double>());
- Double prdictedValue = neuralNetwork.predict(createSample(5d, 6d, 7d));
+ Double prdictedValue = neuralNetwork.predict(createSample(5d, 6d, 7d))[0];
assertEquals(1l, Math.round(prdictedValue));
assertEquals(Double.valueOf(0.9975273768433653d), prdictedValue);
}
- private NeuralNetwork createFFNN(RealMatrix[] realMatrixes)
+ private NeuralNetwork createNN(RealMatrix[] realMatrixes, LearningStrategy<Double, Double> learningStrategy)
throws CreationException {
- return NeuralNetworkFactory.create(realMatrixes, new VoidLearningStrategy<Double, Double>(),
- new FeedForwardStrategy(new SigmoidFunction()), new MaxSelectionFunction<Double>());
+ return NeuralNetworkFactory.create(realMatrixes, learningStrategy,
+ new FeedForwardStrategy(new SigmoidFunction()));
}
private Input<Double> createSample(final Double... params) {
@@ -117,4 +125,87 @@ public class NeuralNetworkFactoryTest {
}
};
}
+
+ @Test
+ public void testRandomNNEvaluation() throws Exception {
+ int noOfOutputs = 10;
+ RealMatrix[] randomWeights = createRandomWeights(noOfOutputs);
+ NeuralNetwork nn = createNN(randomWeights, new BackPropagationLearningStrategy());
+ int noOfFeatures = randomWeights[0].getColumnDimension() - 1;
+ Collection<TrainingExample<Double, Double>> samples = createSamples(1000000, noOfFeatures, noOfOutputs);
+ nn.learn(new TrainingSet<Double, Double>(samples));
+ for (TrainingExample<Double, Double> sample : samples) {
+ Double[] predictedOutput = nn.predict(sample);
+ Double[] expectedOutput = sample.getOutput();
+ boolean equals = Arrays.equals(expectedOutput, predictedOutput);
+// if (!equals) {
+// System.err.println(Arrays.toString(expectedOutput) + " vs " + Arrays.toString(predictedOutput));
+// } else {
+// System.err.println("equals!");
+// }
+ }
+
+ }
+
+ private Collection<TrainingExample<Double, Double>> createSamples(int size, int noOfFeatures, int noOfOutputs) {
+ Random r = new Random();
+ Collection<TrainingExample<Double, Double>> trainingExamples = new ArrayList<TrainingExample<Double, Double>>(size);
+ for (int i = 0; i < size; i++) {
+ Double[] featureValues = new Double[noOfFeatures];
+ for (int j = 0; j < noOfFeatures; j++) {
+ featureValues[j] = r.nextInt(100) / 101d;
+ }
+ Double[] outputs = new Double[noOfOutputs];
+ for (int j = 0; j < outputs.length; j++) {
+ outputs[j] = r.nextInt(100) / 101d;
+ }
+ trainingExamples.add(ExamplesFactory.createDoubleArrayTrainingExample(outputs, featureValues));
+ }
+ return trainingExamples;
+ }
+
+ private RealMatrix[] createRandomWeights(int outputSize) {
+ Random r = new Random();
+ int weightsCount = (Math.abs(r.nextInt()) % 5) + 2;
+
+ RealMatrix[] initialWeights = new RealMatrix[weightsCount];
+ for (int i = 0; i < weightsCount; i++) {
+ int rows = (Math.abs(r.nextInt()) % 4) + 2;
+ int cols;
+ if (i == 0) {
+ cols = (Math.abs(r.nextInt()) % 4) + 2;
+ } else {
+ cols = initialWeights[i - 1].getRowDimension();
+ if (i == weightsCount - 1) {
+ rows = outputSize;
+ }
+ }
+ double[][] d = new double[rows][cols];
+ for (int c = 0; c < cols; c++) {
+ if (i == weightsCount - 1) {
+ if (c == 0) {
+ d[0][c] = 1d;
+ } else {
+ d[0][c] = r.nextDouble();
+ }
+ } else {
+ d[0][c] = 0;
+ }
+ }
+
+ for (int k = 1; k < rows; k++) {
+ for (int j = 0; j < cols; j++) {
+ double val;
+ if (j == 0) {
+ val = 1d;
+ } else {
+ val = r.nextDouble();
+ }
+ d[k][j] = val;
+ }
+ }
+ initialWeights[i] = new Array2DRowRealMatrix(d);
+ }
+ return initialWeights;
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@labs.apache.org
For additional commands, e-mail: commits-help@labs.apache.org