You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@labs.apache.org by to...@apache.org on 2013/03/25 15:18:32 UTC

svn commit: r1460671 - in /labs/yay/trunk: api/src/main/java/org/apache/yay/ core/src/main/java/org/apache/yay/ core/src/test/java/org/apache/yay/

Author: tommaso
Date: Mon Mar 25 14:18:31 2013
New Revision: 1460671

URL: http://svn.apache.org/r1460671
Log:
refactoring api, removed inconsistent generics semantics

Added:
    labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetworkCostFunction.java   (with props)
Modified:
    labs/yay/trunk/api/src/main/java/org/apache/yay/LearningStrategy.java
    labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetwork.java
    labs/yay/trunk/api/src/main/java/org/apache/yay/PredictionStrategy.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/BackPropagationLearningStrategy.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/BasicPerceptron.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/FeedForwardStrategy.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/LogisticRegressionCostFunction.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/MaxSelectionFunction.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/NeuralNetworkFactory.java
    labs/yay/trunk/core/src/test/java/org/apache/yay/LogisticRegressionCostFunctionTest.java
    labs/yay/trunk/core/src/test/java/org/apache/yay/NeuralNetworkFactoryTest.java

Modified: labs/yay/trunk/api/src/main/java/org/apache/yay/LearningStrategy.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/api/src/main/java/org/apache/yay/LearningStrategy.java?rev=1460671&r1=1460670&r2=1460671&view=diff
==============================================================================
--- labs/yay/trunk/api/src/main/java/org/apache/yay/LearningStrategy.java (original)
+++ labs/yay/trunk/api/src/main/java/org/apache/yay/LearningStrategy.java Mon Mar 25 14:18:31 2013
@@ -21,7 +21,8 @@ package org.apache.yay;
 import org.apache.commons.math3.linear.RealMatrix;
 
 /**
- * A {@link LearningStrategy}<F,O> defines a learning algorithm to learn the weights of the neural network's layer
+ * A {@link LearningStrategy} defines a learning algorithm to learn the weights
+ * of the neural network's layers.
  */
 public interface LearningStrategy<F, O> {
 

Modified: labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetwork.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetwork.java?rev=1460671&r1=1460670&r2=1460671&view=diff
==============================================================================
--- labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetwork.java (original)
+++ labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetwork.java Mon Mar 25 14:18:31 2013
@@ -23,6 +23,6 @@ import org.apache.commons.math3.linear.R
 /**
  * A neural network is a layered connected graph of elaboration units
  */
-public interface NeuralNetwork<I, O> extends Hypothesis<RealMatrix, I, O>{
+public interface NeuralNetwork extends Hypothesis<RealMatrix, Double, Double>{
 
 }

Added: labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetworkCostFunction.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetworkCostFunction.java?rev=1460671&view=auto
==============================================================================
--- labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetworkCostFunction.java (added)
+++ labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetworkCostFunction.java Mon Mar 25 14:18:31 2013
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *  http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.yay;
+
+import org.apache.commons.math3.linear.RealMatrix;
+
+/**
+ * A generic {@link CostFunction} for {@link NeuralNetwork}s which is parametrized
+ * by its {@link RealMatrix} weights (one per layer).
+ */
+public abstract class NeuralNetworkCostFunction<I, O> implements CostFunction<RealMatrix,I, O> {
+}

Propchange: labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetworkCostFunction.java
------------------------------------------------------------------------------
    svn:eol-style = native

Modified: labs/yay/trunk/api/src/main/java/org/apache/yay/PredictionStrategy.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/api/src/main/java/org/apache/yay/PredictionStrategy.java?rev=1460671&r1=1460670&r2=1460671&view=diff
==============================================================================
--- labs/yay/trunk/api/src/main/java/org/apache/yay/PredictionStrategy.java (original)
+++ labs/yay/trunk/api/src/main/java/org/apache/yay/PredictionStrategy.java Mon Mar 25 14:18:31 2013
@@ -18,18 +18,33 @@
  */
 package org.apache.yay;
 
-import org.apache.commons.math3.linear.RealMatrix;
-
 import java.util.Collection;
 
+import org.apache.commons.math3.linear.RealMatrix;
+
 /**
- * A {@link PredictionStrategy} defines an algorithm for the prediction of an output
- * <code>O</code> given an input <code>I</code>.
+ * A {@link PredictionStrategy} defines an algorithm for the prediction of outputs
+ * of type <code>O</code> given inputs of type <code>I</code>.
  */
 public interface PredictionStrategy<I, O> {
 
-  public O predictOutput(Collection<I> inputs, RealMatrix[] weightsMatrixSet);
-
+  /**
+   * Performs a prediction and returns an array containing the outputs
+   *
+   * @param inputs           a collection of input values
+   * @param weightsMatrixSet the initial set of weights defined by an array of matrices
+   * @return the array containing the last layer's outputs
+   */
+  public O[] predictOutput(Collection<I> inputs, RealMatrix[] weightsMatrixSet);
+
+  /**
+   * Performs a prediction on the given input values and weights settings, returning
+   * a debug output.
+   *
+   * @param inputs           a collection of input values
+   * @param weightsMatrixSet the initial set of weights defined by an array of matrices
+   * @return the perturbed neural network state via its weights matrix array
+   */
   public RealMatrix[] debugOutput(Collection<I> inputs, RealMatrix[] weightsMatrixSet);
 
 }

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/BackPropagationLearningStrategy.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/BackPropagationLearningStrategy.java?rev=1460671&r1=1460670&r2=1460671&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/BackPropagationLearningStrategy.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/BackPropagationLearningStrategy.java Mon Mar 25 14:18:31 2013
@@ -28,23 +28,23 @@ import org.apache.yay.utils.ConversionUt
  * Back propagation learning algorithm for neural networks implementation (see
  * <code>http://en.wikipedia.org/wiki/Backpropagation</code>).
  */
-public class BackPropagationLearningStrategy implements LearningStrategy<Double, Double[]> {
+public class BackPropagationLearningStrategy implements LearningStrategy<Double, Double> {
 
-  private final PredictionStrategy<Double, Double[]> predictionStrategy;
+  private final PredictionStrategy<Double, Double> predictionStrategy;
   private CostFunction<RealMatrix, Double, Double> costFunction;
 
-  public BackPropagationLearningStrategy(PredictionStrategy<Double, Double[]> predictionStrategy, CostFunction<RealMatrix, Double, Double> costFunction) {
+  public BackPropagationLearningStrategy(PredictionStrategy<Double, Double> predictionStrategy, CostFunction<RealMatrix, Double, Double> costFunction) {
     this.predictionStrategy = predictionStrategy;
     this.costFunction = costFunction;
   }
 
   @Override
-  public RealMatrix[] learnWeights(RealMatrix[] weightsMatrixSet, TrainingSet<Double, Double[]> trainingExamples) throws WeightLearningException {
+  public RealMatrix[] learnWeights(RealMatrix[] weightsMatrixSet, TrainingSet<Double, Double> trainingExamples) throws WeightLearningException {
     // set up the accumulator matrix(es)
     RealMatrix[] triangle = new RealMatrix[weightsMatrixSet.length];
 
     int count = 0;
-    for (TrainingExample<Double, Double[]> trainingExample : trainingExamples) {
+    for (TrainingExample<Double, Double> trainingExample : trainingExamples) {
       try {
         // contains activation errors for the current training example
         // TODO : check if this should be RealVector[] < probably yes
@@ -101,10 +101,12 @@ public class BackPropagationLearningStra
     return thetaL.transpose().preMultiply(nextLayerDelta).ebeMultiply(gz);
   }
 
-  private RealVector calculateOutputError(TrainingExample<Double, Double[]> trainingExample, RealMatrix[] activations) {
+  private RealVector calculateOutputError(TrainingExample<Double, Double> trainingExample, RealMatrix[] activations) {
     RealMatrix output = activations[activations.length - 1];
-    Double[] learnedOutput = trainingExample.getOutput(); // training example output
     RealVector predictedOutputVector = new ArrayRealVector(output.getColumn(output.getColumnDimension() - 1)); // turn output to vector
+
+    Double[] learnedOutput = new Double[predictedOutputVector.getDimension()]; // training example output
+    learnedOutput[trainingExample.getOutput().intValue()] = 1d;
     RealVector learnedOutputRealVector = new ArrayRealVector(learnedOutput); // turn example output to a vector
 
     // TODO : improve error calculation > this should be er_a = out_a * (1 - out_a) * (tgt_a - out_a)

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/BasicPerceptron.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/BasicPerceptron.java?rev=1460671&r1=1460670&r2=1460671&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/BasicPerceptron.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/BasicPerceptron.java Mon Mar 25 14:18:31 2013
@@ -29,7 +29,7 @@ import java.util.Collection;
  * A perceptron {@link NeuralNetwork} implementation based on
  * {@link org.apache.yay.neuron.BinaryThresholdNeuron}s
  */
-public class BasicPerceptron implements NeuralNetwork<Double, Double> {
+public class BasicPerceptron implements NeuralNetwork {
 
   private final BinaryThresholdNeuron perceptronNeuron;
 

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/FeedForwardStrategy.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/FeedForwardStrategy.java?rev=1460671&r1=1460670&r2=1460671&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/FeedForwardStrategy.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/FeedForwardStrategy.java Mon Mar 25 14:18:31 2013
@@ -43,18 +43,18 @@ import java.util.Collections;
  */
 public class FeedForwardStrategy implements PredictionStrategy<Double, Double> {
 
-  private final ActivationFunction<Double> hypothesis;
+  private final ActivationFunction<Double> activationFunction;
 
-  public FeedForwardStrategy(ActivationFunction<Double> hypothesis) {
-    this.hypothesis = hypothesis;
+  public FeedForwardStrategy(ActivationFunction<Double> activationFunction) {
+    this.activationFunction = activationFunction;
   }
 
   @Override
-  public Double predictOutput(Collection<Double> input, RealMatrix[] RealMatrixSet) {
+  public Double[] predictOutput(Collection<Double> input, RealMatrix[] RealMatrixSet) {
     RealMatrix[] realMatrixes = applyFF(input, RealMatrixSet);
     RealMatrix x = realMatrixes[realMatrixes.length - 1];
     double[] lastColumn = x.getColumn(x.getColumnDimension() - 1);
-    return Collections.max(Arrays.asList(ArrayUtils.toObject(lastColumn)));
+    return ConversionUtils.toDoubleArray(lastColumn);
   }
 
   public RealMatrix[] debugOutput(Collection<Double> input, RealMatrix[] RealMatrixSet) {
@@ -96,7 +96,7 @@ public class FeedForwardStrategy impleme
     public Object transform(Object input) {
       assert input instanceof Double;
       final Double d = (Double) input;
-      return hypothesis.apply(d);
+      return activationFunction.apply(d);
     }
   }
 

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/LogisticRegressionCostFunction.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/LogisticRegressionCostFunction.java?rev=1460671&r1=1460670&r2=1460671&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/LogisticRegressionCostFunction.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/LogisticRegressionCostFunction.java Mon Mar 25 14:18:31 2013
@@ -23,7 +23,7 @@ import org.apache.commons.math3.linear.R
 /**
  * This calculates the logistic regression cost function for neural networks
  */
-public class LogisticRegressionCostFunction implements CostFunction<RealMatrix, Double, Double> {
+public class LogisticRegressionCostFunction extends NeuralNetworkCostFunction<Double, Double> {
 
   private final Double lambda;
 

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/MaxSelectionFunction.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/MaxSelectionFunction.java?rev=1460671&r1=1460670&r2=1460671&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/MaxSelectionFunction.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/MaxSelectionFunction.java Mon Mar 25 14:18:31 2013
@@ -22,12 +22,12 @@ import java.util.Collection;
 import java.util.Collections;
 
 /**
- * Selects the max value from a {@link Collection} of outputs
+ * Selects the max value from a {@link Collection} of {@link Comparable} outputs.
  */
-public class MaxSelectionFunction implements SelectionFunction<Collection<Comparable>, Comparable> {
+public class MaxSelectionFunction<T extends Comparable<T>> implements SelectionFunction<Collection<T>, T> {
 
   @Override
-  public Comparable selectOutput(Collection<Comparable> neuralNetworkOutput) {
+  public T selectOutput(Collection<T> neuralNetworkOutput) {
     return Collections.max(neuralNetworkOutput);
   }
 }

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/NeuralNetworkFactory.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/NeuralNetworkFactory.java?rev=1460671&r1=1460670&r2=1460671&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/NeuralNetworkFactory.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/NeuralNetworkFactory.java Mon Mar 25 14:18:31 2013
@@ -21,6 +21,7 @@ package org.apache.yay;
 import org.apache.commons.math3.linear.RealMatrix;
 import org.apache.yay.utils.ConversionUtils;
 
+import java.util.Arrays;
 import java.util.Collection;
 
 /**
@@ -38,9 +39,10 @@ public class NeuralNetworkFactory {
    * @return a NeuralNetwork instance
    * @throws CreationException
    */
-  public static NeuralNetwork<Double, Double> create(final RealMatrix[] realMatrixSet, final LearningStrategy<Double, Double> learningStrategy,
-                                                     final PredictionStrategy<Double, Double> predictionStrategy) throws CreationException {
-    return new NeuralNetwork<Double, Double>() {
+  public static NeuralNetwork create(final RealMatrix[] realMatrixSet, final LearningStrategy<Double, Double> learningStrategy,
+                                                     final PredictionStrategy<Double, Double> predictionStrategy,
+                                                     final SelectionFunction<Collection<Double>, Double> selectionFunction) throws CreationException {
+    return new NeuralNetwork() {
 
       private RealMatrix[] updatedRealMatrixSet = realMatrixSet;
 
@@ -67,7 +69,8 @@ public class NeuralNetworkFactory {
       public Double predict(Input<Double> input) throws PredictionException {
         try {
           Collection<Double> inputVector = ConversionUtils.toValuesCollection(input.getFeatures());
-          return predictionStrategy.predictOutput(inputVector, updatedRealMatrixSet);
+          Double[] doubles = predictionStrategy.predictOutput(inputVector, updatedRealMatrixSet);
+          return selectionFunction.selectOutput(Arrays.asList(doubles));
         } catch (Exception e) {
           throw new PredictionException(e);
         }

Modified: labs/yay/trunk/core/src/test/java/org/apache/yay/LogisticRegressionCostFunctionTest.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/LogisticRegressionCostFunctionTest.java?rev=1460671&r1=1460670&r2=1460671&view=diff
==============================================================================
--- labs/yay/trunk/core/src/test/java/org/apache/yay/LogisticRegressionCostFunctionTest.java (original)
+++ labs/yay/trunk/core/src/test/java/org/apache/yay/LogisticRegressionCostFunctionTest.java Mon Mar 25 14:18:31 2013
@@ -35,7 +35,7 @@ import static org.junit.Assert.assertTru
 public class LogisticRegressionCostFunctionTest {
 
   private CostFunction<RealMatrix,Double,Double> costFunction;
-  private TrainingSet trainingSet;
+  private TrainingSet<Double, Double> trainingSet;
 
   @Before
   public void setUp() throws Exception {
@@ -50,7 +50,7 @@ public class LogisticRegressionCostFunct
     trainingExamples.add(example2);
     trainingExamples.add(example3);
     trainingExamples.add(example4);
-    trainingSet = new TrainingSet(trainingExamples);
+    trainingSet = new TrainingSet<Double, Double>(trainingExamples);
 
 
   }
@@ -61,9 +61,9 @@ public class LogisticRegressionCostFunct
     RealMatrix singleOrLayerWeights = new Array2DRowRealMatrix(weights);
     final RealMatrix[] orWeightsMatrixSet = new RealMatrix[]{singleOrLayerWeights};
 
-    final NeuralNetwork<Double,Double> neuralNetwork = NeuralNetworkFactory.create(orWeightsMatrixSet,
+    final NeuralNetwork neuralNetwork = NeuralNetworkFactory.create(orWeightsMatrixSet,
             new VoidLearningStrategy<Double, Double>(), new FeedForwardStrategy(
-            new SigmoidFunction()));
+            new SigmoidFunction()), new MaxSelectionFunction<Double>());
 
 
     Double cost = costFunction.calculateAggregatedCost(trainingSet, neuralNetwork);

Modified: labs/yay/trunk/core/src/test/java/org/apache/yay/NeuralNetworkFactoryTest.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/NeuralNetworkFactoryTest.java?rev=1460671&r1=1460670&r2=1460671&view=diff
==============================================================================
--- labs/yay/trunk/core/src/test/java/org/apache/yay/NeuralNetworkFactoryTest.java (original)
+++ labs/yay/trunk/core/src/test/java/org/apache/yay/NeuralNetworkFactoryTest.java Mon Mar 25 14:18:31 2013
@@ -36,7 +36,7 @@ public class NeuralNetworkFactoryTest {
     double[][] weights = {{-30d, 20d, 20d}};
     RealMatrix singleAndLayerWeights = new Array2DRowRealMatrix(weights);
     RealMatrix[] andRealMatrixSet = new RealMatrix[]{singleAndLayerWeights};
-    NeuralNetwork<Double, Double> andNN = createFFNN(andRealMatrixSet);
+    NeuralNetwork andNN = createFFNN(andRealMatrixSet);
     assertEquals(0l, Math.round(andNN.predict(createSample(1d, 0d))));
     assertEquals(0l, Math.round(andNN.predict(createSample(0d, 1d))));
     assertEquals(0l, Math.round(andNN.predict(createSample(0d, 0d))));
@@ -48,7 +48,7 @@ public class NeuralNetworkFactoryTest {
     double[][] weights = {{-10d, 20d, 20d}};
     RealMatrix singleOrLayerWeights = new Array2DRowRealMatrix(weights);
     RealMatrix[] orRealMatrixSet = new RealMatrix[]{singleOrLayerWeights};
-    NeuralNetwork<Double, Double> orNN = createFFNN(orRealMatrixSet);
+    NeuralNetwork orNN = createFFNN(orRealMatrixSet);
     assertEquals(1l, Math.round(orNN.predict(createSample(1d, 0d))));
     assertEquals(1l, Math.round(orNN.predict(createSample(0d, 1d))));
     assertEquals(0l, Math.round(orNN.predict(createSample(0d, 0d))));
@@ -60,7 +60,7 @@ public class NeuralNetworkFactoryTest {
     double[][] weights = {{10d, -20d}};
     RealMatrix singleNotLayerWeights = new Array2DRowRealMatrix(weights);
     RealMatrix[] notRealMatrixSet = new RealMatrix[]{singleNotLayerWeights};
-    NeuralNetwork<Double, Double> orNN = createFFNN(notRealMatrixSet);
+    NeuralNetwork orNN = createFFNN(notRealMatrixSet);
     assertEquals(1l, Math.round(orNN.predict(createSample(0d))));
     assertEquals(0l, Math.round(orNN.predict(createSample(1d))));
   }
@@ -70,7 +70,7 @@ public class NeuralNetworkFactoryTest {
     RealMatrix firstNorLayerWeights = new Array2DRowRealMatrix(new double[][]{{0, 0, 0}, {-30d, 20d, 20d}, {10d, -20d, -20d}});
     RealMatrix secondNorLayerWeights = new Array2DRowRealMatrix(new double[][]{{-10d, 20d, 20d}});
     RealMatrix[] norRealMatrixSet = new RealMatrix[]{firstNorLayerWeights, secondNorLayerWeights};
-    NeuralNetwork<Double, Double> norNN = createFFNN(norRealMatrixSet);
+    NeuralNetwork norNN = createFFNN(norRealMatrixSet);
     assertEquals(0l, Math.round(norNN.predict(createSample(1d, 0d))));
     assertEquals(0l, Math.round(norNN.predict(createSample(0d, 1d))));
     assertEquals(1l, Math.round(norNN.predict(createSample(0d, 0d))));
@@ -82,16 +82,16 @@ public class NeuralNetworkFactoryTest {
     RealMatrix firstLayer = new Array2DRowRealMatrix(new double[][]{{1d, 1d, 2d, 3d}, {1d, 1d, 2d, 3d}, {1d, 1d, 2d, 3d}});
     RealMatrix secondLayer = new Array2DRowRealMatrix(new double[][]{{1d, 2d, 3d}});
     RealMatrix[] RealMatrixes = new RealMatrix[]{firstLayer, secondLayer};
-    NeuralNetwork<Double, Double> neuralNetwork = createFFNN(RealMatrixes);
+    NeuralNetwork neuralNetwork = createFFNN(RealMatrixes);
     Double prdictedValue = neuralNetwork.predict(createSample(5d, 6d, 7d));
     assertEquals(1l, Math.round(prdictedValue));
     assertEquals(Double.valueOf(0.9975273768433653d), prdictedValue);
   }
 
-  private NeuralNetwork<Double, Double> createFFNN(RealMatrix[] realMatrixes)
+  private NeuralNetwork createFFNN(RealMatrix[] realMatrixes)
           throws CreationException {
     return NeuralNetworkFactory.create(realMatrixes, new VoidLearningStrategy<Double, Double>(),
-            new FeedForwardStrategy(new SigmoidFunction()));
+            new FeedForwardStrategy(new SigmoidFunction()), new MaxSelectionFunction<Double>());
   }
 
   private Input<Double> createSample(final Double... params) {



---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@labs.apache.org
For additional commands, e-mail: commits-help@labs.apache.org