You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@labs.apache.org by to...@apache.org on 2012/05/29 17:01:08 UTC
svn commit: r1343764 - in /labs/yay/trunk/core/src:
main/java/org/apache/yay/ test/java/org/apache/yay/
Author: tommaso
Date: Tue May 29 15:01:07 2012
New Revision: 1343764
URL: http://svn.apache.org/viewvc?rev=1343764&view=rev
Log:
Changing the NeuralNetworkFactory (NNF) signature: drop Set in favor of an array to preserve the order of the weight matrices; added a NOR test.
Modified:
labs/yay/trunk/core/src/main/java/org/apache/yay/BackPropagationLearningStrategy.java
labs/yay/trunk/core/src/main/java/org/apache/yay/FeedForwardStrategy.java
labs/yay/trunk/core/src/main/java/org/apache/yay/LearningStrategy.java
labs/yay/trunk/core/src/main/java/org/apache/yay/NeuralNetworkFactory.java
labs/yay/trunk/core/src/main/java/org/apache/yay/PredictionStrategy.java
labs/yay/trunk/core/src/main/java/org/apache/yay/VoidLearningStrategy.java
labs/yay/trunk/core/src/test/java/org/apache/yay/NeuralNetworkFactoryTest.java
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/BackPropagationLearningStrategy.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/BackPropagationLearningStrategy.java?rev=1343764&r1=1343763&r2=1343764&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/BackPropagationLearningStrategy.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/BackPropagationLearningStrategy.java Tue May 29 15:01:07 2012
@@ -33,7 +33,7 @@ public class BackPropagationLearningStra
}
@Override
- public Set<WeightsMatrix> learnWeights(Set<WeightsMatrix> weightsMatrixSet, Collection<TrainingExample<Long, Long>> trainingExamples) throws WeightLearningException {
+ public WeightsMatrix[] learnWeights(WeightsMatrix[] weightsMatrixSet, Collection<TrainingExample<Long, Long>> trainingExamples) throws WeightLearningException {
for (TrainingExample<Long, Long> trainingExample : trainingExamples) {
try {
Long output = neuralNetwork.predict(trainingExample);
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/FeedForwardStrategy.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/FeedForwardStrategy.java?rev=1343764&r1=1343763&r2=1343764&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/FeedForwardStrategy.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/FeedForwardStrategy.java Tue May 29 15:01:07 2012
@@ -51,7 +51,7 @@ public class FeedForwardStrategy impleme
@Override
- public Double predictOutput(Vector<Double> input, Set<WeightsMatrix> weightsMatrixSet) {
+ public Double predictOutput(Vector<Double> input, WeightsMatrix[] weightsMatrixSet) {
// TODO : fix this impl as it's very slow and commons-math Java1.4 constraint is so ugly to see...
RealVector v = matrixConverter.toRealVector(input);
RealMatrix x = v.outerProduct(new ArrayRealVector(new Double[]{1d})).transpose(); // a 1xN matrix
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/LearningStrategy.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/LearningStrategy.java?rev=1343764&r1=1343763&r2=1343764&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/LearningStrategy.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/LearningStrategy.java Tue May 29 15:01:07 2012
@@ -26,7 +26,7 @@ import java.util.Set;
*/
public interface LearningStrategy<F, O> {
- public Set<WeightsMatrix> learnWeights(Set<WeightsMatrix> weightsMatrixSet, Collection<TrainingExample<F, O>>
+ public WeightsMatrix[] learnWeights(WeightsMatrix[] weightsMatrixSet, Collection<TrainingExample<F, O>>
trainingExamples) throws WeightLearningException;
}
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/NeuralNetworkFactory.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/NeuralNetworkFactory.java?rev=1343764&r1=1343763&r2=1343764&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/NeuralNetworkFactory.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/NeuralNetworkFactory.java Tue May 29 15:01:07 2012
@@ -32,18 +32,18 @@ public class NeuralNetworkFactory {
* by a set of matrices, the learning and prediction strategies to be used.
*
* @param trainingExamples the training set
- * @param weightsMatrixSet the initial settings for weights matrixes
+ * @param weightsMatrixSet the initial settings for weights matrices
* @param learningStrategy a learning strategy
* @param predictionStrategy a prediction strategy
* @return a NeuralNetwork instance
* @throws CreationException
*/
public static NeuralNetwork<Double, Double> create(final Collection<TrainingExample<Double, Double>> trainingExamples,
- final Set<WeightsMatrix> weightsMatrixSet, final LearningStrategy learningStrategy,
+ final WeightsMatrix[] weightsMatrixSet, final LearningStrategy learningStrategy,
final PredictionStrategy<Double, Double> predictionStrategy) throws CreationException {
NeuralNetwork<Double, Double> neuralNetwork = new NeuralNetwork<Double, Double>() {
- private Set<WeightsMatrix> updatedWeightsMatrixSet = weightsMatrixSet;
+ private WeightsMatrix[] updatedWeightsMatrixSet = weightsMatrixSet;
@Override
public void learn(TrainingExample<Double, Double>... samples) throws LearningException {
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/PredictionStrategy.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/PredictionStrategy.java?rev=1343764&r1=1343763&r2=1343764&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/PredictionStrategy.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/PredictionStrategy.java Tue May 29 15:01:07 2012
@@ -18,13 +18,12 @@
*/
package org.apache.yay;
-import java.util.Set;
import java.util.Vector;
/**
*/
public interface PredictionStrategy<I, O> {
- public O predictOutput(Vector<I> input, Set<WeightsMatrix> weightsMatrixSet);
+ public O predictOutput(Vector<I> input, WeightsMatrix[] weightsMatrixSet);
}
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/VoidLearningStrategy.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/VoidLearningStrategy.java?rev=1343764&r1=1343763&r2=1343764&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/VoidLearningStrategy.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/VoidLearningStrategy.java Tue May 29 15:01:07 2012
@@ -27,7 +27,7 @@ import java.util.Set;
public class VoidLearningStrategy<F, O> implements LearningStrategy<F, O> {
@Override
- public Set<WeightsMatrix> learnWeights(Set<WeightsMatrix> weightsMatrixSet, Collection<TrainingExample<F, O>> trainingExamples) throws WeightLearningException {
+ public WeightsMatrix[] learnWeights(WeightsMatrix[] weightsMatrixSet, Collection<TrainingExample<F, O>> trainingExamples) throws WeightLearningException {
return weightsMatrixSet;
}
}
Modified: labs/yay/trunk/core/src/test/java/org/apache/yay/NeuralNetworkFactoryTest.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/NeuralNetworkFactoryTest.java?rev=1343764&r1=1343763&r2=1343764&view=diff
==============================================================================
--- labs/yay/trunk/core/src/test/java/org/apache/yay/NeuralNetworkFactoryTest.java (original)
+++ labs/yay/trunk/core/src/test/java/org/apache/yay/NeuralNetworkFactoryTest.java Tue May 29 15:01:07 2012
@@ -20,9 +20,7 @@ package org.apache.yay;
import org.junit.Test;
-import java.util.HashSet;
import java.util.LinkedList;
-import java.util.Set;
import java.util.Vector;
import static org.junit.Assert.assertEquals;
@@ -34,10 +32,9 @@ public class NeuralNetworkFactoryTest {
@Test
public void andNNCreationTest() throws Exception {
- Set<WeightsMatrix> andWeightsMatrixSet = new HashSet<WeightsMatrix>();
double[][] weights = {{-30d, 20d, 20d}};
WeightsMatrix singleAndLayerWeights = new WeightsMatrix(weights);
- andWeightsMatrixSet.add(singleAndLayerWeights);
+ WeightsMatrix[] andWeightsMatrixSet = new WeightsMatrix[]{singleAndLayerWeights};
NeuralNetwork<Double,Double> andNN = NeuralNetworkFactory.create(new LinkedList<TrainingExample<Double, Double>>(), andWeightsMatrixSet, new VoidLearningStrategy(), new FeedForwardStrategy(new SigmoidFunction()));
assertEquals(0l, Math.round(andNN.predict(createSample(1d, 0d))));
assertEquals(0l, Math.round(andNN.predict(createSample(0d, 1d))));
@@ -47,10 +44,9 @@ public class NeuralNetworkFactoryTest {
@Test
public void orNNCreationTest() throws Exception {
- Set<WeightsMatrix> orWeightsMatrixSet = new HashSet<WeightsMatrix>();
double[][] weights = {{-10d, 20d, 20d}};
WeightsMatrix singleOrLayerWeights = new WeightsMatrix(weights);
- orWeightsMatrixSet.add(singleOrLayerWeights);
+ WeightsMatrix[] orWeightsMatrixSet = new WeightsMatrix[]{singleOrLayerWeights};
NeuralNetwork<Double,Double> orNN = NeuralNetworkFactory.create(new LinkedList<TrainingExample<Double, Double>>(), orWeightsMatrixSet, new VoidLearningStrategy(), new FeedForwardStrategy(new SigmoidFunction()));
assertEquals(1l, Math.round(orNN.predict(createSample(1d, 0d))));
assertEquals(1l, Math.round(orNN.predict(createSample(0d, 1d))));
@@ -60,15 +56,26 @@ public class NeuralNetworkFactoryTest {
@Test
public void notNNCreationTest() throws Exception {
- Set<WeightsMatrix> notWeightsMatrixSet = new HashSet<WeightsMatrix>();
double[][] weights = {{10d, -20d}};
WeightsMatrix singleNotLayerWeights = new WeightsMatrix(weights);
- notWeightsMatrixSet.add(singleNotLayerWeights);
+ WeightsMatrix[] notWeightsMatrixSet = new WeightsMatrix[]{singleNotLayerWeights};
NeuralNetwork<Double,Double> orNN = NeuralNetworkFactory.create(new LinkedList<TrainingExample<Double, Double>>(), notWeightsMatrixSet, new VoidLearningStrategy(), new FeedForwardStrategy(new SigmoidFunction()));
assertEquals(1l, Math.round(orNN.predict(createSample(0d))));
assertEquals(0l, Math.round(orNN.predict(createSample(1d))));
}
+ @Test
+ public void norNNCreationTest() throws Exception {
+ WeightsMatrix firstNorLayerWeights = new WeightsMatrix(new double[][]{{0, 0, 0},{-30d, 20d, 20d}, {10d, -20d, -20d}});
+ WeightsMatrix secondNorLayerWeights = new WeightsMatrix(new double[][]{{-10d, 20d, 20d}});
+ WeightsMatrix[] norWeightsMatrixSet = new WeightsMatrix[]{firstNorLayerWeights,secondNorLayerWeights};
+ NeuralNetwork<Double,Double> norNN = NeuralNetworkFactory.create(new LinkedList<TrainingExample<Double, Double>>(), norWeightsMatrixSet, new VoidLearningStrategy(), new FeedForwardStrategy(new SigmoidFunction()));
+ assertEquals(0l, Math.round(norNN.predict(createSample(1d, 0d))));
+ assertEquals(0l, Math.round(norNN.predict(createSample(0d, 1d))));
+ assertEquals(1l, Math.round(norNN.predict(createSample(0d, 0d))));
+ assertEquals(1l, Math.round(norNN.predict(createSample(1d, 1d))));
+ }
+
private Example<Double> createSample(final Double... params) {
return new Example<Double>() {
@Override
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@labs.apache.org
For additional commands, e-mail: commits-help@labs.apache.org