You are viewing a plain-text version of this content; the canonical link to the original message was not preserved in this copy.
Posted to commits@labs.apache.org by to...@apache.org on 2012/05/29 17:01:08 UTC

svn commit: r1343764 - in /labs/yay/trunk/core/src: main/java/org/apache/yay/ test/java/org/apache/yay/

Author: tommaso
Date: Tue May 29 15:01:07 2012
New Revision: 1343764

URL: http://svn.apache.org/viewvc?rev=1343764&view=rev
Log:
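changing NNF (NeuralNetworkFactory) signature: drop Set in favor of Array to preserve the order of the per-layer weight matrices, added NOR test (note: the assertions in norNNCreationTest expect 0 for inputs (1,0)/(0,1) and 1 for (0,0)/(1,1), which is the XNOR truth table rather than NOR)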

Modified:
    labs/yay/trunk/core/src/main/java/org/apache/yay/BackPropagationLearningStrategy.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/FeedForwardStrategy.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/LearningStrategy.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/NeuralNetworkFactory.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/PredictionStrategy.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/VoidLearningStrategy.java
    labs/yay/trunk/core/src/test/java/org/apache/yay/NeuralNetworkFactoryTest.java

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/BackPropagationLearningStrategy.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/BackPropagationLearningStrategy.java?rev=1343764&r1=1343763&r2=1343764&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/BackPropagationLearningStrategy.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/BackPropagationLearningStrategy.java Tue May 29 15:01:07 2012
@@ -33,7 +33,7 @@ public class BackPropagationLearningStra
   }
 
   @Override
-  public Set<WeightsMatrix> learnWeights(Set<WeightsMatrix> weightsMatrixSet, Collection<TrainingExample<Long, Long>> trainingExamples) throws WeightLearningException {
+  public WeightsMatrix[] learnWeights(WeightsMatrix[] weightsMatrixSet, Collection<TrainingExample<Long, Long>> trainingExamples) throws WeightLearningException {
     for (TrainingExample<Long, Long> trainingExample : trainingExamples) {
       try {
         Long output = neuralNetwork.predict(trainingExample);

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/FeedForwardStrategy.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/FeedForwardStrategy.java?rev=1343764&r1=1343763&r2=1343764&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/FeedForwardStrategy.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/FeedForwardStrategy.java Tue May 29 15:01:07 2012
@@ -51,7 +51,7 @@ public class FeedForwardStrategy impleme
 
 
   @Override
-  public Double predictOutput(Vector<Double> input, Set<WeightsMatrix> weightsMatrixSet) {
+  public Double predictOutput(Vector<Double> input, WeightsMatrix[] weightsMatrixSet) {
     // TODO : fix this impl as it's very slow and commons-math Java1.4 constraint is so ugly to see...
     RealVector v = matrixConverter.toRealVector(input);
     RealMatrix x = v.outerProduct(new ArrayRealVector(new Double[]{1d})).transpose(); // a 1xN matrix

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/LearningStrategy.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/LearningStrategy.java?rev=1343764&r1=1343763&r2=1343764&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/LearningStrategy.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/LearningStrategy.java Tue May 29 15:01:07 2012
@@ -26,7 +26,7 @@ import java.util.Set;
  */
 public interface LearningStrategy<F, O> {
 
-  public Set<WeightsMatrix> learnWeights(Set<WeightsMatrix> weightsMatrixSet, Collection<TrainingExample<F, O>>
+  public WeightsMatrix[] learnWeights(WeightsMatrix[] weightsMatrixSet, Collection<TrainingExample<F, O>>
           trainingExamples) throws WeightLearningException;
 
 }

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/NeuralNetworkFactory.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/NeuralNetworkFactory.java?rev=1343764&r1=1343763&r2=1343764&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/NeuralNetworkFactory.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/NeuralNetworkFactory.java Tue May 29 15:01:07 2012
@@ -32,18 +32,18 @@ public class NeuralNetworkFactory {
    * by a set of matrices, the learning and prediction strategies to be used.
    *
    * @param trainingExamples   the training set
-   * @param weightsMatrixSet   the initial settings for weights matrixes
+   * @param weightsMatrixSet   the initial settings for weights matrices
    * @param learningStrategy   a learning strategy
    * @param predictionStrategy a prediction strategy
    * @return a NeuralNetwork instance
    * @throws CreationException
    */
   public static NeuralNetwork<Double, Double> create(final Collection<TrainingExample<Double, Double>> trainingExamples,
-                                                     final Set<WeightsMatrix> weightsMatrixSet, final LearningStrategy learningStrategy,
+                                                     final WeightsMatrix[] weightsMatrixSet, final LearningStrategy learningStrategy,
                                                      final PredictionStrategy<Double, Double> predictionStrategy) throws CreationException {
     NeuralNetwork<Double, Double> neuralNetwork = new NeuralNetwork<Double, Double>() {
 
-      private Set<WeightsMatrix> updatedWeightsMatrixSet = weightsMatrixSet;
+      private WeightsMatrix[] updatedWeightsMatrixSet = weightsMatrixSet;
 
       @Override
       public void learn(TrainingExample<Double, Double>... samples) throws LearningException {

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/PredictionStrategy.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/PredictionStrategy.java?rev=1343764&r1=1343763&r2=1343764&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/PredictionStrategy.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/PredictionStrategy.java Tue May 29 15:01:07 2012
@@ -18,13 +18,12 @@
  */
 package org.apache.yay;
 
-import java.util.Set;
 import java.util.Vector;
 
 /**
  */
 public interface PredictionStrategy<I, O> {
 
-  public O predictOutput(Vector<I> input, Set<WeightsMatrix> weightsMatrixSet);
+  public O predictOutput(Vector<I> input, WeightsMatrix[] weightsMatrixSet);
 
 }

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/VoidLearningStrategy.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/VoidLearningStrategy.java?rev=1343764&r1=1343763&r2=1343764&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/VoidLearningStrategy.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/VoidLearningStrategy.java Tue May 29 15:01:07 2012
@@ -27,7 +27,7 @@ import java.util.Set;
 public class VoidLearningStrategy<F, O> implements LearningStrategy<F, O> {
 
   @Override
-  public Set<WeightsMatrix> learnWeights(Set<WeightsMatrix> weightsMatrixSet, Collection<TrainingExample<F, O>> trainingExamples) throws WeightLearningException {
+  public WeightsMatrix[] learnWeights(WeightsMatrix[] weightsMatrixSet, Collection<TrainingExample<F, O>> trainingExamples) throws WeightLearningException {
     return weightsMatrixSet;
   }
 }

Modified: labs/yay/trunk/core/src/test/java/org/apache/yay/NeuralNetworkFactoryTest.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/NeuralNetworkFactoryTest.java?rev=1343764&r1=1343763&r2=1343764&view=diff
==============================================================================
--- labs/yay/trunk/core/src/test/java/org/apache/yay/NeuralNetworkFactoryTest.java (original)
+++ labs/yay/trunk/core/src/test/java/org/apache/yay/NeuralNetworkFactoryTest.java Tue May 29 15:01:07 2012
@@ -20,9 +20,7 @@ package org.apache.yay;
 
 import org.junit.Test;
 
-import java.util.HashSet;
 import java.util.LinkedList;
-import java.util.Set;
 import java.util.Vector;
 
 import static org.junit.Assert.assertEquals;
@@ -34,10 +32,9 @@ public class NeuralNetworkFactoryTest {
 
   @Test
   public void andNNCreationTest() throws Exception {
-    Set<WeightsMatrix> andWeightsMatrixSet = new HashSet<WeightsMatrix>();
     double[][] weights = {{-30d, 20d, 20d}};
     WeightsMatrix singleAndLayerWeights = new WeightsMatrix(weights);
-    andWeightsMatrixSet.add(singleAndLayerWeights);
+    WeightsMatrix[] andWeightsMatrixSet = new WeightsMatrix[]{singleAndLayerWeights};
     NeuralNetwork<Double,Double> andNN = NeuralNetworkFactory.create(new LinkedList<TrainingExample<Double, Double>>(), andWeightsMatrixSet, new VoidLearningStrategy(), new FeedForwardStrategy(new SigmoidFunction()));
     assertEquals(0l, Math.round(andNN.predict(createSample(1d, 0d))));
     assertEquals(0l, Math.round(andNN.predict(createSample(0d, 1d))));
@@ -47,10 +44,9 @@ public class NeuralNetworkFactoryTest {
 
   @Test
   public void orNNCreationTest() throws Exception {
-    Set<WeightsMatrix> orWeightsMatrixSet = new HashSet<WeightsMatrix>();
     double[][] weights = {{-10d, 20d, 20d}};
     WeightsMatrix singleOrLayerWeights = new WeightsMatrix(weights);
-    orWeightsMatrixSet.add(singleOrLayerWeights);
+    WeightsMatrix[] orWeightsMatrixSet = new WeightsMatrix[]{singleOrLayerWeights};
     NeuralNetwork<Double,Double> orNN = NeuralNetworkFactory.create(new LinkedList<TrainingExample<Double, Double>>(), orWeightsMatrixSet, new VoidLearningStrategy(), new FeedForwardStrategy(new SigmoidFunction()));
     assertEquals(1l, Math.round(orNN.predict(createSample(1d, 0d))));
     assertEquals(1l, Math.round(orNN.predict(createSample(0d, 1d))));
@@ -60,15 +56,26 @@ public class NeuralNetworkFactoryTest {
 
   @Test
   public void notNNCreationTest() throws Exception {
-    Set<WeightsMatrix> notWeightsMatrixSet = new HashSet<WeightsMatrix>();
     double[][] weights = {{10d, -20d}};
     WeightsMatrix singleNotLayerWeights = new WeightsMatrix(weights);
-    notWeightsMatrixSet.add(singleNotLayerWeights);
+    WeightsMatrix[] notWeightsMatrixSet = new WeightsMatrix[]{singleNotLayerWeights};
     NeuralNetwork<Double,Double> orNN = NeuralNetworkFactory.create(new LinkedList<TrainingExample<Double, Double>>(), notWeightsMatrixSet, new VoidLearningStrategy(), new FeedForwardStrategy(new SigmoidFunction()));
     assertEquals(1l, Math.round(orNN.predict(createSample(0d))));
     assertEquals(0l, Math.round(orNN.predict(createSample(1d))));
   }
 
+  @Test
+  public void norNNCreationTest() throws Exception {
+    WeightsMatrix firstNorLayerWeights = new WeightsMatrix(new double[][]{{0, 0, 0},{-30d, 20d, 20d}, {10d, -20d, -20d}});
+    WeightsMatrix secondNorLayerWeights = new WeightsMatrix(new double[][]{{-10d, 20d, 20d}});
+    WeightsMatrix[] norWeightsMatrixSet = new WeightsMatrix[]{firstNorLayerWeights,secondNorLayerWeights};
+    NeuralNetwork<Double,Double> norNN = NeuralNetworkFactory.create(new LinkedList<TrainingExample<Double, Double>>(), norWeightsMatrixSet, new VoidLearningStrategy(), new FeedForwardStrategy(new SigmoidFunction()));
+    assertEquals(0l, Math.round(norNN.predict(createSample(1d, 0d))));
+    assertEquals(0l, Math.round(norNN.predict(createSample(0d, 1d))));
+    assertEquals(1l, Math.round(norNN.predict(createSample(0d, 0d))));
+    assertEquals(1l, Math.round(norNN.predict(createSample(1d, 1d))));
+  }
+
   private Example<Double> createSample(final Double... params) {
     return new Example<Double>() {
       @Override



---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@labs.apache.org
For additional commands, e-mail: commits-help@labs.apache.org