You are viewing a plain text version of this content. The canonical link to it ("here") was a hyperlink that is not preserved in this plain-text copy.
Posted to commits@labs.apache.org by to...@apache.org on 2015/10/04 08:21:17 UTC

svn commit: r1706651 - in /labs/yay/trunk: api/src/main/java/org/apache/yay/ core/src/main/java/org/apache/yay/core/ core/src/main/java/org/apache/yay/core/utils/ core/src/test/java/org/apache/yay/core/

Author: tommaso
Date: Sun Oct  4 06:21:16 2015
New Revision: 1706651

URL: http://svn.apache.org/viewvc?rev=1706651&view=rev
Log:
enabling multiple NN output neurons

Added:
    labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkIntegrationTest.java
      - copied, changed from r1705721, labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkFactoryTest.java
Removed:
    labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkFactoryTest.java
Modified:
    labs/yay/trunk/api/src/main/java/org/apache/yay/Hypothesis.java
    labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetwork.java
    labs/yay/trunk/api/src/main/java/org/apache/yay/TrainingExample.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/core/BasicPerceptron.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/core/DefaultDerivativeUpdateFunction.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/core/LogisticRegressionCostFunction.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/core/NeuralNetworkFactory.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ExamplesFactory.java
    labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java
    labs/yay/trunk/core/src/test/java/org/apache/yay/core/BasicPerceptronTest.java
    labs/yay/trunk/core/src/test/java/org/apache/yay/core/LogisticRegressionCostFunctionTest.java

Modified: labs/yay/trunk/api/src/main/java/org/apache/yay/Hypothesis.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/api/src/main/java/org/apache/yay/Hypothesis.java?rev=1706651&r1=1706650&r2=1706651&view=diff
==============================================================================
--- labs/yay/trunk/api/src/main/java/org/apache/yay/Hypothesis.java (original)
+++ labs/yay/trunk/api/src/main/java/org/apache/yay/Hypothesis.java Sun Oct  4 06:21:16 2015
@@ -18,6 +18,8 @@
  */
 package org.apache.yay;
 
+import java.util.List;
+
 /**
  * In machine learning an hypothesis is the output of applying a learning
  * algorithm to a training set, an hypothesis maps new inputs to possible outputs.
@@ -45,7 +47,7 @@ public interface Hypothesis<T, I, O> {
    * @return the predicted output
    * @throws PredictionException if any error occurs during the prediction phase
    */
-  O predict(Input<I> input) throws PredictionException;
+  O[] predict(Input<I> input) throws PredictionException;
 
   /**
    * Let this <code>Hypothesis</code> learn by experience, in the form of a {@link org.apache.yay.TrainingSet}

Modified: labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetwork.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetwork.java?rev=1706651&r1=1706650&r2=1706651&view=diff
==============================================================================
--- labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetwork.java (original)
+++ labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetwork.java Sun Oct  4 06:21:16 2015
@@ -27,13 +27,4 @@ import org.apache.commons.math3.linear.R
  */
 public interface NeuralNetwork extends Hypothesis<RealMatrix, Double, Double> {
 
-  /**
-   * Predict the output for a given input
-   *
-   * @param input the input to evaluate
-   * @return the predicted output
-   * @throws PredictionException if any error occurs during the prediction phase
-   */
-  Double[] getOutputVector(Input<Double> input) throws PredictionException;
-
 }

Modified: labs/yay/trunk/api/src/main/java/org/apache/yay/TrainingExample.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/api/src/main/java/org/apache/yay/TrainingExample.java?rev=1706651&r1=1706650&r2=1706651&view=diff
==============================================================================
--- labs/yay/trunk/api/src/main/java/org/apache/yay/TrainingExample.java (original)
+++ labs/yay/trunk/api/src/main/java/org/apache/yay/TrainingExample.java Sun Oct  4 06:21:16 2015
@@ -28,6 +28,6 @@ public interface TrainingExample<F, O> e
    *
    * @return the output
    */
-  O getOutput();
+  O[] getOutput();
 
 }

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java?rev=1706651&r1=1706650&r2=1706651&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java Sun Oct  4 06:21:16 2015
@@ -19,13 +19,10 @@
 package org.apache.yay.core;
 
 import java.util.Arrays;
-import java.util.DoubleSummaryStatistics;
 import java.util.Iterator;
 
 import org.apache.commons.math3.linear.Array2DRowRealMatrix;
-import org.apache.commons.math3.linear.ArrayRealVector;
 import org.apache.commons.math3.linear.RealMatrix;
-import org.apache.commons.math3.linear.RealVector;
 import org.apache.yay.CostFunction;
 import org.apache.yay.DerivativeUpdateFunction;
 import org.apache.yay.LearningStrategy;
@@ -34,7 +31,6 @@ import org.apache.yay.PredictionStrategy
 import org.apache.yay.TrainingExample;
 import org.apache.yay.TrainingSet;
 import org.apache.yay.WeightLearningException;
-import org.apache.yay.core.utils.ConversionUtils;
 
 /**
  * Back propagation learning algorithm for neural networks implementation (see
@@ -71,7 +67,7 @@ public class BackPropagationLearningStra
 
   public BackPropagationLearningStrategy() {
     // commonly used defaults
-    this.predictionStrategy = new FeedForwardStrategy(new SigmoidFunction());
+    this.predictionStrategy = new FeedForwardStrategy(new TanhFunction());
     this.costFunction = new LogisticRegressionCostFunction();
     this.alpha = DEFAULT_ALPHA;
     this.threshold = DEFAULT_THRESHOLD;
@@ -85,7 +81,7 @@ public class BackPropagationLearningStra
     try {
       int iterations = 0;
 
-      NeuralNetwork hypothesis = NeuralNetworkFactory.create(weightsMatrixSet, new VoidLearningStrategy<Double, Double>(), predictionStrategy, new MaxSelectionFunction<Double>());
+      NeuralNetwork hypothesis = NeuralNetworkFactory.create(weightsMatrixSet, new VoidLearningStrategy<Double, Double>(), predictionStrategy);
       Iterator<TrainingExample<Double, Double>> iterator = trainingExamples.iterator();
 
       double cost = Double.MAX_VALUE;
@@ -142,7 +138,10 @@ public class BackPropagationLearningStra
       double[][] updatedWeights = weightsMatrixSet[l].getData();
       for (int i = 0; i < updatedWeights.length; i++) {
         for (int j = 0; j < updatedWeights[i].length; j++) {
-          updatedWeights[i][j] = updatedWeights[i][j] - alpha * derivatives[l].getData()[i][j];
+          double curVal = updatedWeights[i][j];
+          if (curVal > 0d && curVal < 1d) {
+            updatedWeights[i][j] = updatedWeights[i][j] - alpha * derivatives[l].getData()[i][j];
+          }
         }
       }
       updatedParameters[l] = new Array2DRowRealMatrix(updatedWeights);

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/BasicPerceptron.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/BasicPerceptron.java?rev=1706651&r1=1706650&r2=1706651&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/BasicPerceptron.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/BasicPerceptron.java Sun Oct  4 06:21:16 2015
@@ -62,7 +62,7 @@ public class BasicPerceptron implements
     Collection<Double> doubles = ConversionUtils.toValuesCollection(example.getFeatures());
     Double[] inputs = doubles.toArray(new Double[doubles.size()]);
     Double calculatedOutput = perceptronNeuron.elaborate(inputs);
-    int diff = calculatedOutput.compareTo(example.getOutput());
+    int diff = calculatedOutput.compareTo(example.getOutput()[0]);
     if (diff > 0) {
       for (int i = 0; i < currentWeights.length; i++) {
         currentWeights[i] += inputs[i];
@@ -90,17 +90,10 @@ public class BasicPerceptron implements
   }
 
   @Override
-  public Double predict(Input<Double> input) throws PredictionException {
-    return perceptronNeuron.elaborate(ConversionUtils.toValuesCollection(input.getFeatures()).toArray(
+  public Double[] predict(Input<Double> input) throws PredictionException {
+    Double output = perceptronNeuron.elaborate(ConversionUtils.toValuesCollection(input.getFeatures()).toArray(
             new Double[input.getFeatures().size()]));
+    return new Double[]{output};
   }
 
-  @Override
-  public Double[] getOutputVector(Input<Double> input) throws PredictionException {
-    Double elaborate = perceptronNeuron.elaborate(ConversionUtils.toValuesCollection(input.getFeatures()).toArray(
-            new Double[input.getFeatures().size()]));
-    Double[] ar = new Double[1];
-    ar[0] = elaborate;
-    return ar;
-  }
 }

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/DefaultDerivativeUpdateFunction.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/DefaultDerivativeUpdateFunction.java?rev=1706651&r1=1706650&r2=1706651&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/DefaultDerivativeUpdateFunction.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/DefaultDerivativeUpdateFunction.java Sun Oct  4 06:21:16 2015
@@ -78,10 +78,6 @@ public class DefaultDerivativeUpdateFunc
       count++;
     }
 
-    return createDerivatives(triangle, count);
-  }
-
-  private RealMatrix[] createDerivatives(RealMatrix[] triangle, double count) {
     RealMatrix[] derivatives = new RealMatrix[triangle.length];
     for (int i = 0; i < triangle.length; i++) {
       // TODO : introduce regularization diversification on bias term (currently not regularized)
@@ -111,16 +107,17 @@ public class DefaultDerivativeUpdateFunc
   private RealVector calculateOutputError(TrainingExample<Double, Double> trainingExample, RealVector[] activations) {
     RealVector output = activations[activations.length - 1];
 
-    Double[] sampleOutput = new Double[output.getDimension()];
-    int sampleOutputIntValue = trainingExample.getOutput().intValue();
-    if (sampleOutputIntValue < sampleOutput.length) {
-      sampleOutput[sampleOutputIntValue] = 1d;
-    } else if (sampleOutput.length == 1) {
-      sampleOutput[0] = trainingExample.getOutput();
-    } else {
-      throw new RuntimeException("problem with multiclass output mapping");
-    }
-    RealVector learnedOutputRealVector = new ArrayRealVector(sampleOutput); // turn example output to a vector
+//    Double[] sampleOutput = new Double[output.getDimension()];
+    Double[] actualOutput = trainingExample.getOutput();
+//    int sampleOutputIntValue = actualOutput.intValue();
+//    if (sampleOutputIntValue < sampleOutput.length) {
+//      sampleOutput[sampleOutputIntValue] = 1d;
+//    } else if (sampleOutput.length == 1) {
+//      sampleOutput[0] = actualOutput;
+//    } else {
+//      throw new RuntimeException("problem with multiclass output mapping");
+//    }
+    RealVector learnedOutputRealVector = new ArrayRealVector(actualOutput); // turn example output to a vector
 
     // TODO : improve error calculation -> this could be er_a = out_a * (1 - out_a) * (tgt_a - out_a)
     return output.subtract(learnedOutputRealVector);

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/LogisticRegressionCostFunction.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/LogisticRegressionCostFunction.java?rev=1706651&r1=1706650&r2=1706651&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/LogisticRegressionCostFunction.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/LogisticRegressionCostFunction.java Sun Oct  4 06:21:16 2015
@@ -74,11 +74,12 @@ public class LogisticRegressionCostFunct
     Double res = 0d;
 
     for (TrainingExample<Double, Double> input : trainingExamples) {
-      // TODO : handle this for multiple outputs (multi class classification)
-      Double predictedOutput = hypothesis.predict(input);
-      Double sampleOutput = input.getOutput();
-      res += sampleOutput * Math.log(predictedOutput) + (1d - sampleOutput)
-              * Math.log(1d - predictedOutput);
+      Double[] predictedOutput = hypothesis.predict(input);
+      Double[] sampleOutput = input.getOutput();
+      for (int i = 0; i < predictedOutput.length; i++) {
+        res += sampleOutput[i] * Math.log(predictedOutput[i]) + (1d - sampleOutput[i])
+                * Math.log(1d - predictedOutput[i]);
+      }
     }
     return (-1d / trainingExamples.length) * res;
   }

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/NeuralNetworkFactory.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/NeuralNetworkFactory.java?rev=1706651&r1=1706650&r2=1706651&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/NeuralNetworkFactory.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/NeuralNetworkFactory.java Sun Oct  4 06:21:16 2015
@@ -18,11 +18,8 @@
  */
 package org.apache.yay.core;
 
-import java.util.ArrayList;
-import java.util.Arrays;
 import java.util.Collection;
 import org.apache.commons.math3.linear.RealMatrix;
-import org.apache.commons.math3.linear.RealVector;
 import org.apache.yay.CreationException;
 import org.apache.yay.Input;
 import org.apache.yay.LearningException;
@@ -30,7 +27,6 @@ import org.apache.yay.LearningStrategy;
 import org.apache.yay.NeuralNetwork;
 import org.apache.yay.PredictionException;
 import org.apache.yay.PredictionStrategy;
-import org.apache.yay.SelectionFunction;
 import org.apache.yay.TrainingSet;
 import org.apache.yay.WeightLearningException;
 import org.apache.yay.core.utils.ConversionUtils;
@@ -51,12 +47,10 @@ public class NeuralNetworkFactory {
    * @throws org.apache.yay.CreationException
    */
   public static NeuralNetwork create(final RealMatrix[] realMatrixSet, final LearningStrategy<Double, Double> learningStrategy,
-                                     final PredictionStrategy<Double, Double> predictionStrategy,
-                                     final SelectionFunction<Collection<Double>, Double> selectionFunction) throws CreationException {
+                                     final PredictionStrategy<Double, Double> predictionStrategy) throws CreationException {
     return new NeuralNetwork() {
 
-      @Override
-      public Double[] getOutputVector(Input<Double> input) throws PredictionException {
+      private Double[] getOutputVector(Input<Double> input) throws PredictionException {
         Collection<Double> inputVector = ConversionUtils.toValuesCollection(input.getFeatures());
         return predictionStrategy.predictOutput(inputVector, updatedRealMatrixSet);
       }
@@ -83,10 +77,9 @@ public class NeuralNetworkFactory {
       }
 
       @Override
-      public Double predict(Input<Double> input) throws PredictionException {
+      public Double[] predict(Input<Double> input) throws PredictionException {
         try {
-          Double[] doubles = getOutputVector(input);
-          return selectionFunction.selectOutput(Arrays.asList(doubles));
+          return getOutputVector(input);
         } catch (Exception e) {
           throw new PredictionException(e);
         }

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ExamplesFactory.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ExamplesFactory.java?rev=1706651&r1=1706650&r2=1706651&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ExamplesFactory.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ExamplesFactory.java Sun Oct  4 06:21:16 2015
@@ -39,22 +39,22 @@ public class ExamplesFactory {
       }
 
       @Override
-      public Double getOutput() {
-        return output;
+      public Double[] getOutput() {
+        return new Double[]{output};
       }
     };
   }
 
-  public static TrainingExample<Double, Collection<Double[]>> createSGMExample(final Collection<Double[]> output,
+  public static TrainingExample<Double, Double> createDoubleArrayTrainingExample(final Double[] output,
                                                                             final Double... featuresValues) {
-    return new TrainingExample<Double, Collection<Double[]>>() {
+    return new TrainingExample<Double, Double>() {
       @Override
       public ArrayList<Feature<Double>> getFeatures() {
         return doublesToFeatureVector(featuresValues);
       }
 
       @Override
-      public Collection<Double[]> getOutput() {
+      public Double[] getOutput() {
         return output;
       }
     };

Modified: labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java?rev=1706651&r1=1706650&r2=1706651&view=diff
==============================================================================
--- labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java (original)
+++ labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java Sun Oct  4 06:21:16 2015
@@ -41,9 +41,9 @@ public class BackPropagationLearningStra
   public void testLearningWithRandomNetwork() throws Exception {
     BackPropagationLearningStrategy backPropagationLearningStrategy = new BackPropagationLearningStrategy();
 
-    RealMatrix[] initialWeights = createRandomWeights();
+    RealMatrix[] initialWeights = createRandomWeights(2);
 
-    Collection<TrainingExample<Double, Double>> samples = createSamples(1000, initialWeights[0].getColumnDimension() - 1);
+    Collection<TrainingExample<Double, Double>> samples = createSamples(1000, initialWeights[0].getColumnDimension() - 1, 2);
     TrainingSet<Double, Double> trainingSet = new TrainingSet<Double, Double>(samples);
     RealMatrix[] learntWeights = backPropagationLearningStrategy.learnWeights(initialWeights, trainingSet);
     assertNotNull(learntWeights);
@@ -62,9 +62,9 @@ public class BackPropagationLearningStra
     BackPropagationLearningStrategy backPropagationLearningStrategy = new BackPropagationLearningStrategy(alpha, threshold,
             predictionStrategy, costFunction);
 
-    RealMatrix[] initialWeights = createRandomWeights();
+    RealMatrix[] initialWeights = createRandomWeights(10);
 
-    Collection<TrainingExample<Double, Double>> samples = createSamples(1000, initialWeights[0].getColumnDimension() - 1);
+    Collection<TrainingExample<Double, Double>> samples = createSamples(1000, initialWeights[0].getColumnDimension() - 1, 10);
     TrainingSet<Double, Double> trainingSet = new TrainingSet<Double, Double>(samples);
     RealMatrix[] learntWeights = backPropagationLearningStrategy.learnWeights(initialWeights, trainingSet);
     assertNotNull(learntWeights);
@@ -82,7 +82,7 @@ public class BackPropagationLearningStra
     }
   }
 
-  private RealMatrix[] createRandomWeights() {
+  private RealMatrix[] createRandomWeights(int outputSize) {
     Random r = new Random();
     int weightsCount = (Math.abs(r.nextInt()) % 5) + 2;
 
@@ -95,7 +95,7 @@ public class BackPropagationLearningStra
       } else {
         cols = initialWeights[i - 1].getRowDimension();
         if (i == weightsCount - 1) {
-          rows = 1;
+          rows = outputSize;
         }
       }
       double[][] d = new double[rows][cols];
@@ -137,7 +137,7 @@ public class BackPropagationLearningStra
     initialWeights[1] = new Array2DRowRealMatrix(new double[][]{{0d, 0d, 0d, 0d}, {1d, 0.5d, 1d, 0.5d}, {1d, 0.1d, 8d, 0.1d}, {1d, 0.1d, 8d, 0.2d}}); // 4 x 4
     initialWeights[2] = new Array2DRowRealMatrix(new double[][]{{1d, 2d, 0.3d, 0.5d}}); // 1 x 4
 
-    Collection<TrainingExample<Double, Double>> samples = createSamples(10000, 2);
+    Collection<TrainingExample<Double, Double>> samples = createSamples(10000, 2, 1);
     TrainingSet<Double, Double> trainingSet = new TrainingSet<Double, Double>(samples);
     RealMatrix[] learntWeights = backPropagationLearningStrategy.learnWeights(initialWeights, trainingSet);
     assertNotNull(learntWeights);
@@ -169,7 +169,7 @@ public class BackPropagationLearningStra
     initialWeights[1] = new Array2DRowRealMatrix(new double[][]{{0d, 0d, 0d, 0d}, {1d, 0.5d, 1d, 0.5d}, {1d, 0.1d, 8d, 0.1d}, {1d, 0.1d, 8d, 0.2d}});
     initialWeights[2] = new Array2DRowRealMatrix(new double[][]{{1d, 2d, 0.3d, 0.5d}});
 
-    Collection<TrainingExample<Double, Double>> samples = createSamples(10000, 2);
+    Collection<TrainingExample<Double, Double>> samples = createSamples(10000, 2, 1);
     TrainingSet<Double, Double> trainingSet = new TrainingSet<Double, Double>(samples);
     RealMatrix[] learntWeights = backPropagationLearningStrategy.learnWeights(initialWeights, trainingSet);
     assertNotNull(learntWeights);
@@ -201,7 +201,7 @@ public class BackPropagationLearningStra
             {1d, Math.random(), Math.random(), Math.random()}
     });
 
-    Collection<TrainingExample<Double, Double>> samples = createSamples(1000000, 2);
+    Collection<TrainingExample<Double, Double>> samples = createSamples(1000000, 2, 1);
     TrainingSet<Double, Double> trainingSet = new TrainingSet<Double, Double>(samples);
     RealMatrix[] learntWeights = backPropagationLearningStrategy.learnWeights(initialWeights, trainingSet);
     assertNotNull(learntWeights);
@@ -211,14 +211,18 @@ public class BackPropagationLearningStra
     assertFalse(learntWeights[2].equals(initialWeights[2]));
   }
 
-  private Collection<TrainingExample<Double, Double>> createSamples(int size, int noOfFeatures) {
+  private Collection<TrainingExample<Double, Double>> createSamples(int size, int noOfFeatures, int noOfOutputs) {
     Collection<TrainingExample<Double, Double>> trainingExamples = new ArrayList<TrainingExample<Double, Double>>(size);
     for (int i = 0; i < size; i++) {
       Double[] featureValues = new Double[noOfFeatures];
       for (int j = 0; j < noOfFeatures; j++) {
         featureValues[j] = Math.random();
       }
-      trainingExamples.add(ExamplesFactory.createDoubleTrainingExample(1d, featureValues));
+      Double[] outputs = new Double[noOfOutputs];
+      for (int j = 0; j < outputs.length; j++) {
+        outputs[j] = Math.random();
+      }
+      trainingExamples.add(ExamplesFactory.createDoubleArrayTrainingExample(outputs, featureValues));
     }
     return trainingExamples;
   }

Modified: labs/yay/trunk/core/src/test/java/org/apache/yay/core/BasicPerceptronTest.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/core/BasicPerceptronTest.java?rev=1706651&r1=1706650&r2=1706651&view=diff
==============================================================================
--- labs/yay/trunk/core/src/test/java/org/apache/yay/core/BasicPerceptronTest.java (original)
+++ labs/yay/trunk/core/src/test/java/org/apache/yay/core/BasicPerceptronTest.java Sun Oct  4 06:21:16 2015
@@ -86,7 +86,7 @@ public class BasicPerceptronTest {
             r.nextDouble(), r.nextDouble(), r.nextDouble(), r.nextDouble(),
             r.nextDouble(), r.nextDouble());
     basicPerceptron.learn(bigDataset);
-    Double output = basicPerceptron.predict(createInput(r));
+    Double output = basicPerceptron.predict(createInput(r))[0];
     assertTrue(output == 0d || output == 1d);
   }
 
@@ -102,7 +102,7 @@ public class BasicPerceptronTest {
             r.nextDouble(), r.nextDouble());
     basicPerceptron.learn(bigDataset);
     TrainingExample<Double, Double> input = createInput(r);
-    Double output = basicPerceptron.predict(input);
+    Double output = basicPerceptron.predict(input)[0];
     assertTrue(output == 0d || output == 1d);
     basicPerceptron.learn(createTrainingExample(1d, r.nextDouble(),
             r.nextDouble(), r.nextDouble(), r.nextDouble(), r.nextDouble(),
@@ -110,7 +110,7 @@ public class BasicPerceptronTest {
             r.nextDouble(), r.nextDouble(), r.nextDouble(), r.nextDouble(),
             r.nextDouble(), r.nextDouble(), r.nextDouble(), r.nextDouble(),
             r.nextDouble(), r.nextDouble()));
-    Double secondOutput = basicPerceptron.predict(input);
+    Double secondOutput = basicPerceptron.predict(input)[0];
     assertTrue(secondOutput == 0d || secondOutput == 1d);
   }
 
@@ -135,7 +135,7 @@ public class BasicPerceptronTest {
   public void testLearnAndPredictWithSmallDataset() throws Exception {
     BasicPerceptron basicPerceptron = new BasicPerceptron(1d, 2d, 3d, 4d);
     basicPerceptron.learn(smallDataset);
-    Double output = basicPerceptron.predict(createTrainingExample(null, 1d, 6d, 0.4d));
+    Double output = basicPerceptron.predict(createTrainingExample(null, 1d, 6d, 0.4d))[0];
     assertEquals(Double.valueOf(1d), output);
   }
 
@@ -157,8 +157,8 @@ public class BasicPerceptronTest {
       }
 
       @Override
-      public Double getOutput() {
-        return output;
+      public Double[] getOutput() {
+        return new Double[]{output};
       }
     };
   }

Modified: labs/yay/trunk/core/src/test/java/org/apache/yay/core/LogisticRegressionCostFunctionTest.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/core/LogisticRegressionCostFunctionTest.java?rev=1706651&r1=1706650&r2=1706651&view=diff
==============================================================================
--- labs/yay/trunk/core/src/test/java/org/apache/yay/core/LogisticRegressionCostFunctionTest.java (original)
+++ labs/yay/trunk/core/src/test/java/org/apache/yay/core/LogisticRegressionCostFunctionTest.java Sun Oct  4 06:21:16 2015
@@ -64,8 +64,7 @@ public class LogisticRegressionCostFunct
     final RealMatrix[] orWeightsMatrixSet = new RealMatrix[]{singleOrLayerWeights};
 
     final NeuralNetwork neuralNetwork = NeuralNetworkFactory.create(orWeightsMatrixSet,
-            new VoidLearningStrategy<Double, Double>(), new FeedForwardStrategy(
-                    new SigmoidFunction()), new MaxSelectionFunction<Double>());
+            new VoidLearningStrategy<Double, Double>(), new FeedForwardStrategy(new SigmoidFunction()));
 
     Double cost = costFunction.calculateAggregatedCost(trainingSet, neuralNetwork);
     assertTrue("cost should not be negative", cost > 0d);

Copied: labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkIntegrationTest.java (from r1705721, labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkFactoryTest.java)
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkIntegrationTest.java?p2=labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkIntegrationTest.java&p1=labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkFactoryTest.java&r1=1705721&r2=1706651&rev=1706651&view=diff
==============================================================================
--- labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkFactoryTest.java (original)
+++ labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkIntegrationTest.java Sun Oct  4 06:21:16 2015
@@ -19,31 +19,39 @@
 package org.apache.yay.core;
 
 import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Random;
+
 import org.apache.commons.math3.linear.Array2DRowRealMatrix;
 import org.apache.commons.math3.linear.RealMatrix;
 import org.apache.yay.CreationException;
 import org.apache.yay.Feature;
 import org.apache.yay.Input;
+import org.apache.yay.LearningStrategy;
 import org.apache.yay.NeuralNetwork;
+import org.apache.yay.TrainingExample;
+import org.apache.yay.TrainingSet;
+import org.apache.yay.core.utils.ExamplesFactory;
 import org.junit.Test;
 
 import static org.junit.Assert.assertEquals;
 
 /**
- * Testcase for {@link org.apache.yay.core.NeuralNetworkFactory}
+ * Integration test for NN
  */
-public class NeuralNetworkFactoryTest {
+public class NeuralNetworkIntegrationTest {
 
   @Test
   public void andNNCreationTest() throws Exception {
     double[][] weights = {{-30d, 20d, 20d}};
     RealMatrix singleAndLayerWeights = new Array2DRowRealMatrix(weights);
     RealMatrix[] andRealMatrixSet = new RealMatrix[]{singleAndLayerWeights};
-    NeuralNetwork andNN = createFFNN(andRealMatrixSet);
-    assertEquals(0l, Math.round(andNN.predict(createSample(1d, 0d))));
-    assertEquals(0l, Math.round(andNN.predict(createSample(0d, 1d))));
-    assertEquals(0l, Math.round(andNN.predict(createSample(0d, 0d))));
-    assertEquals(1l, Math.round(andNN.predict(createSample(1d, 1d))));
+    NeuralNetwork andNN = createNN(andRealMatrixSet, new VoidLearningStrategy<Double, Double>());
+    assertEquals(0l, Math.round(andNN.predict(createSample(1d, 0d))[0]));
+    assertEquals(0l, Math.round(andNN.predict(createSample(0d, 1d))[0]));
+    assertEquals(0l, Math.round(andNN.predict(createSample(0d, 0d))[0]));
+    assertEquals(1l, Math.round(andNN.predict(createSample(1d, 1d))[0]));
   }
 
   @Test
@@ -51,11 +59,11 @@ public class NeuralNetworkFactoryTest {
     double[][] weights = {{-10d, 20d, 20d}};
     RealMatrix singleOrLayerWeights = new Array2DRowRealMatrix(weights);
     RealMatrix[] orRealMatrixSet = new RealMatrix[]{singleOrLayerWeights};
-    NeuralNetwork orNN = createFFNN(orRealMatrixSet);
-    assertEquals(1l, Math.round(orNN.predict(createSample(1d, 0d))));
-    assertEquals(1l, Math.round(orNN.predict(createSample(0d, 1d))));
-    assertEquals(0l, Math.round(orNN.predict(createSample(0d, 0d))));
-    assertEquals(1l, Math.round(orNN.predict(createSample(1d, 1d))));
+    NeuralNetwork orNN = createNN(orRealMatrixSet, new VoidLearningStrategy<Double, Double>());
+    assertEquals(1l, Math.round(orNN.predict(createSample(1d, 0d))[0]));
+    assertEquals(1l, Math.round(orNN.predict(createSample(0d, 1d))[0]));
+    assertEquals(0l, Math.round(orNN.predict(createSample(0d, 0d))[0]));
+    assertEquals(1l, Math.round(orNN.predict(createSample(1d, 1d))[0]));
   }
 
   @Test
@@ -63,9 +71,9 @@ public class NeuralNetworkFactoryTest {
     double[][] weights = {{10d, -20d}};
     RealMatrix singleNotLayerWeights = new Array2DRowRealMatrix(weights);
     RealMatrix[] notRealMatrixSet = new RealMatrix[]{singleNotLayerWeights};
-    NeuralNetwork orNN = createFFNN(notRealMatrixSet);
-    assertEquals(1l, Math.round(orNN.predict(createSample(0d))));
-    assertEquals(0l, Math.round(orNN.predict(createSample(1d))));
+    NeuralNetwork orNN = createNN(notRealMatrixSet, new VoidLearningStrategy<Double, Double>());
+    assertEquals(1l, Math.round(orNN.predict(createSample(0d))[0]));
+    assertEquals(0l, Math.round(orNN.predict(createSample(1d))[0]));
   }
 
   @Test
@@ -73,11 +81,11 @@ public class NeuralNetworkFactoryTest {
     RealMatrix firstNorLayerWeights = new Array2DRowRealMatrix(new double[][]{{0, 0, 0}, {-30d, 20d, 20d}, {10d, -20d, -20d}});
     RealMatrix secondNorLayerWeights = new Array2DRowRealMatrix(new double[][]{{-10d, 20d, 20d}});
     RealMatrix[] norRealMatrixSet = new RealMatrix[]{firstNorLayerWeights, secondNorLayerWeights};
-    NeuralNetwork norNN = createFFNN(norRealMatrixSet);
-    assertEquals(0l, Math.round(norNN.predict(createSample(1d, 0d))));
-    assertEquals(0l, Math.round(norNN.predict(createSample(0d, 1d))));
-    assertEquals(1l, Math.round(norNN.predict(createSample(0d, 0d))));
-    assertEquals(1l, Math.round(norNN.predict(createSample(1d, 1d))));
+    NeuralNetwork norNN = createNN(norRealMatrixSet, new VoidLearningStrategy<Double, Double>());
+    assertEquals(0l, Math.round(norNN.predict(createSample(1d, 0d))[0]));
+    assertEquals(0l, Math.round(norNN.predict(createSample(0d, 1d))[0]));
+    assertEquals(1l, Math.round(norNN.predict(createSample(0d, 0d))[0]));
+    assertEquals(1l, Math.round(norNN.predict(createSample(1d, 1d))[0]));
   }
 
   @Test
@@ -87,17 +95,17 @@ public class NeuralNetworkFactoryTest {
 
     RealMatrix[] RealMatrixes = new RealMatrix[]{firstLayer, secondLayer};
 
-    NeuralNetwork neuralNetwork = createFFNN(RealMatrixes);
+    NeuralNetwork neuralNetwork = createNN(RealMatrixes, new VoidLearningStrategy<Double, Double>());
 
-    Double prdictedValue = neuralNetwork.predict(createSample(5d, 6d, 7d));
+    Double prdictedValue = neuralNetwork.predict(createSample(5d, 6d, 7d))[0];
     assertEquals(1l, Math.round(prdictedValue));
     assertEquals(Double.valueOf(0.9975273768433653d), prdictedValue);
   }
 
-  private NeuralNetwork createFFNN(RealMatrix[] realMatrixes)
+  /**
+   * Creates a network from the given weight matrices through {@link NeuralNetworkFactory#create},
+   * using the supplied learning strategy and a {@link FeedForwardStrategy} backed by a
+   * {@link SigmoidFunction}.
+   *
+   * @param realMatrixes the weight matrices, one per layer
+   * @param learningStrategy the learning strategy handed to the factory
+   * @return the created network
+   * @throws CreationException propagated from {@link NeuralNetworkFactory#create}
+   */
+  private NeuralNetwork createNN(RealMatrix[] realMatrixes, LearningStrategy<Double, Double> learningStrategy)
           throws CreationException {
-    return NeuralNetworkFactory.create(realMatrixes, new VoidLearningStrategy<Double, Double>(),
-            new FeedForwardStrategy(new SigmoidFunction()), new MaxSelectionFunction<Double>());
+    return NeuralNetworkFactory.create(realMatrixes, learningStrategy,
+            new FeedForwardStrategy(new SigmoidFunction()));
   }
 
   private Input<Double> createSample(final Double... params) {
@@ -117,4 +125,87 @@ public class NeuralNetworkFactoryTest {
       }
     };
   }
+
+  /**
+   * Trains a network built from random weights with a {@link BackPropagationLearningStrategy} on
+   * randomly generated samples, then runs a prediction for every sample.
+   * <p>
+   * The sample targets are random, so predictions are not expected to reproduce them exactly;
+   * the test only checks that a prediction is produced for each sample.
+   */
+  @Test
+  public void testRandomNNEvaluation() throws Exception {
+    int noOfOutputs = 10;
+    // NOTE(review): one million samples makes this a slow test; consider a smaller set.
+    int noOfSamples = 1000000;
+    RealMatrix[] randomWeights = createRandomWeights(noOfOutputs);
+    NeuralNetwork nn = createNN(randomWeights, new BackPropagationLearningStrategy());
+    // the first weight matrix has one column more than there are input features
+    // (presumably a bias term — TODO confirm against FeedForwardStrategy)
+    int noOfFeatures = randomWeights[0].getColumnDimension() - 1;
+    Collection<TrainingExample<Double, Double>> samples = createSamples(noOfSamples, noOfFeatures, noOfOutputs);
+    nn.learn(new TrainingSet<Double, Double>(samples));
+    for (TrainingExample<Double, Double> sample : samples) {
+      Double[] predictedOutput = nn.predict(sample);
+      org.junit.Assert.assertNotNull("no prediction produced for a training sample", predictedOutput);
+    }
+  }
+
+  /**
+   * Builds {@code size} training examples with random feature and output values.
+   * <p>
+   * Every value is {@code nextInt(100) / 101d}, i.e. one of {@code 0, 1/101, ..., 99/101}.
+   * For each example the feature values are drawn before the output values.
+   *
+   * @param size number of examples to create
+   * @param noOfFeatures number of feature values per example
+   * @param noOfOutputs number of output values per example
+   * @return a new collection holding the generated examples
+   */
+  private Collection<TrainingExample<Double, Double>> createSamples(int size, int noOfFeatures, int noOfOutputs) {
+    Random random = new Random();
+    Collection<TrainingExample<Double, Double>> examples = new ArrayList<TrainingExample<Double, Double>>(size);
+    for (int n = 0; n < size; n++) {
+      Double[] features = randomFractions(random, noOfFeatures);
+      Double[] targets = randomFractions(random, noOfOutputs);
+      examples.add(ExamplesFactory.createDoubleArrayTrainingExample(targets, features));
+    }
+    return examples;
+  }
+
+  /** Returns {@code length} values, each computed as {@code random.nextInt(100) / 101d}. */
+  private static Double[] randomFractions(Random random, int length) {
+    Double[] values = new Double[length];
+    for (int idx = 0; idx < values.length; idx++) {
+      values[idx] = random.nextInt(100) / 101d;
+    }
+    return values;
+  }
+
+  /**
+   * Creates a random chain of weight matrices ending in {@code outputSize} rows.
+   * <p>
+   * Between 2 and 6 matrices are produced:
+   * <ul>
+   *   <li>The first matrix has 2 to 5 columns.</li>
+   *   <li>Every later matrix has as many columns as the previous matrix has rows.</li>
+   *   <li>The final matrix has {@code outputSize} rows; all others have 2 to 5 rows.</li>
+   *   <li>In every non-final matrix the first row is all zeros.</li>
+   *   <li>Every other row, including the first row of the final matrix, is {@code 1} followed by
+   *   {@link Random#nextDouble()} values.</li>
+   * </ul>
+   *
+   * @param outputSize number of rows of the final matrix; must be at least 1
+   * @return the generated matrices, ordered from the first layer to the last
+   * @throws IllegalArgumentException if {@code outputSize < 1}
+   */
+  private RealMatrix[] createRandomWeights(int outputSize) {
+    if (outputSize < 1) {
+      throw new IllegalArgumentException("outputSize must be positive, was " + outputSize);
+    }
+    Random r = new Random();
+    // Random#nextInt(bound) is always in [0, bound); Math.abs(r.nextInt()) % bound is negative
+    // when nextInt() returns Integer.MIN_VALUE, which could produce a negative array size.
+    int weightsCount = r.nextInt(5) + 2;
+
+    RealMatrix[] initialWeights = new RealMatrix[weightsCount];
+    for (int i = 0; i < weightsCount; i++) {
+      boolean lastLayer = i == weightsCount - 1;
+      int rows = lastLayer ? outputSize : r.nextInt(4) + 2;
+      // chain the matrices: columns of this one match rows of the previous one
+      int cols = i == 0 ? r.nextInt(4) + 2 : initialWeights[i - 1].getRowDimension();
+
+      double[][] d = new double[rows][cols];
+      for (int k = 0; k < rows; k++) {
+        if (k == 0 && !lastLayer) {
+          // the first row of every non-final matrix stays all zeros
+          continue;
+        }
+        d[k][0] = 1d;
+        for (int j = 1; j < cols; j++) {
+          d[k][j] = r.nextDouble();
+        }
+      }
+      initialWeights[i] = new Array2DRowRealMatrix(d);
+    }
+    return initialWeights;
+  }
 }



---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@labs.apache.org
For additional commands, e-mail: commits-help@labs.apache.org