You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@labs.apache.org by to...@apache.org on 2015/09/28 18:49:57 UTC

svn commit: r1705721 - in /labs/yay/trunk: api/src/main/java/org/apache/yay/ core/src/main/java/org/apache/yay/core/ core/src/main/java/org/apache/yay/core/utils/ core/src/test/java/org/apache/yay/core/

Author: tommaso
Date: Mon Sep 28 16:49:57 2015
New Revision: 1705721

URL: http://svn.apache.org/viewvc?rev=1705721&view=rev
Log:
switch from batch to stochastic GD in backprop, abstracted derivative update function

Added:
    labs/yay/trunk/api/src/main/java/org/apache/yay/DerivativeUpdateFunction.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/core/DefaultDerivativeUpdateFunction.java
Removed:
    labs/yay/trunk/api/src/main/java/org/apache/yay/HypothesisFactory.java
Modified:
    labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetwork.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/core/BasicPerceptron.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/core/NeuralNetworkFactory.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ExamplesFactory.java
    labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkFactoryTest.java

Added: labs/yay/trunk/api/src/main/java/org/apache/yay/DerivativeUpdateFunction.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/api/src/main/java/org/apache/yay/DerivativeUpdateFunction.java?rev=1705721&view=auto
==============================================================================
--- labs/yay/trunk/api/src/main/java/org/apache/yay/DerivativeUpdateFunction.java (added)
+++ labs/yay/trunk/api/src/main/java/org/apache/yay/DerivativeUpdateFunction.java Mon Sep 28 16:49:57 2015
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *  http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.yay;
+
+import org.apache.commons.math3.linear.RealMatrix;
+
+/**
+ * Function calculating, from a training set, the derivatives used to update the weights
+ * matrices of a model.
+ *
+ * @param <F> the type of the feature values of the training examples
+ * @param <O> the type of the outputs of the training examples
+ */
+public interface DerivativeUpdateFunction<F, O> {
+
+  /**
+   * Calculate the derivatives of the given weights over the given training set.
+   * <p>
+   * Despite its name this method does not return updated weights: it returns the derivatives,
+   * one matrix per weights matrix, which the caller then applies to {@code weights}.
+   *
+   * @param weights     the current weights matrices
+   * @param trainingSet the training examples to calculate the derivatives on
+   * @return the derivatives, one matrix per element of {@code weights}
+   */
+  RealMatrix[] updateParameters(RealMatrix[] weights, TrainingSet<F, O> trainingSet);
+}

Modified: labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetwork.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetwork.java?rev=1705721&r1=1705720&r2=1705721&view=diff
==============================================================================
--- labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetwork.java (original)
+++ labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetwork.java Mon Sep 28 16:49:57 2015
@@ -21,8 +21,19 @@ package org.apache.yay;
 import org.apache.commons.math3.linear.RealMatrix;
 
 /**
- * A neural network is a layered connected graph of elaboration units
+ * A Neural Network is a layered connected graph of elaboration units.
+ *
+ * It takes a double vector as input and produces a double vector as output.
  */
 public interface NeuralNetwork extends Hypothesis<RealMatrix, Double, Double> {
 
+  /**
+   * Compute the raw output vector of the network for the given input; implementations
+   * typically derive the single value returned by {@code predict} from this vector.
+   *
+   * @param input the input to evaluate
+   * @return the raw output vector
+   * @throws PredictionException if any error occurs during the prediction phase
+   */
+  Double[] getOutputVector(Input<Double> input) throws PredictionException;
+
 }

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java?rev=1705721&r1=1705720&r2=1705721&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java Mon Sep 28 16:49:57 2015
@@ -19,6 +19,7 @@
 package org.apache.yay.core;
 
 import java.util.Arrays;
+import java.util.DoubleSummaryStatistics;
 import java.util.Iterator;
 
 import org.apache.commons.math3.linear.Array2DRowRealMatrix;
@@ -26,6 +27,7 @@ import org.apache.commons.math3.linear.A
 import org.apache.commons.math3.linear.RealMatrix;
 import org.apache.commons.math3.linear.RealVector;
 import org.apache.yay.CostFunction;
+import org.apache.yay.DerivativeUpdateFunction;
 import org.apache.yay.LearningStrategy;
 import org.apache.yay.NeuralNetwork;
 import org.apache.yay.PredictionStrategy;
@@ -46,6 +48,7 @@ public class BackPropagationLearningStra
 
   private final PredictionStrategy<Double, Double> predictionStrategy;
   private final CostFunction<RealMatrix, Double, Double> costFunction;
+  private final DerivativeUpdateFunction<Double, Double> derivativeUpdateFunction;
   private final double alpha;
   private final double threshold;
   private final int batch;
@@ -63,6 +66,7 @@ public class BackPropagationLearningStra
     this.alpha = alpha;
     this.threshold = threshold;
     this.batch = batch;
+    this.derivativeUpdateFunction = new DefaultDerivativeUpdateFunction(predictionStrategy);
   }
 
   public BackPropagationLearningStrategy() {
@@ -72,6 +76,7 @@ public class BackPropagationLearningStra
     this.alpha = DEFAULT_ALPHA;
     this.threshold = DEFAULT_THRESHOLD;
     this.batch = 1;
+    this.derivativeUpdateFunction = new DefaultDerivativeUpdateFunction(predictionStrategy);
   }
 
   @Override
@@ -114,7 +119,7 @@ public class BackPropagationLearningStra
         cost = newCost;
 
         // calculate the derivatives to update the parameters
-        RealMatrix[] derivatives = calculateDerivatives(weightsMatrixSet, samples);
+        RealMatrix[] derivatives = derivativeUpdateFunction.updateParameters(weightsMatrixSet, samples);
 
         // calculate the updated parameters
         updatedWeights = updateWeights(weightsMatrixSet, derivatives, alpha);
@@ -131,48 +136,6 @@ public class BackPropagationLearningStra
     return updatedWeights;
   }
 
-  private RealMatrix[] calculateDerivatives(RealMatrix[] weightsMatrixSet, TrainingSet<Double, Double> trainingExamples) throws WeightLearningException {
-    // set up the accumulator matrix(es)
-    RealMatrix[] triangle = new RealMatrix[weightsMatrixSet.length];
-    RealVector[] deltaVectors = new RealVector[weightsMatrixSet.length];
-
-    int noOfMatrixes = weightsMatrixSet.length - 1;
-    double count = 0;
-    for (TrainingExample<Double, Double> trainingExample : trainingExamples) {
-      try {
-        // get activations from feed forward propagation
-        RealVector[] activations = predictionStrategy.debugOutput(ConversionUtils.toValuesCollection(trainingExample.getFeatures()), weightsMatrixSet);
-
-        // calculate output error (corresponding to the last delta^l)
-        RealVector nextLayerDelta = calculateOutputError(trainingExample, activations);
-
-        deltaVectors[noOfMatrixes] = nextLayerDelta;
-
-        // back prop the error and update the deltas accordingly
-        for (int l = noOfMatrixes; l > 0; l--) {
-          RealVector currentActivationsVector = activations[l - 1];
-          nextLayerDelta = calculateDeltaVector(weightsMatrixSet[l], currentActivationsVector, nextLayerDelta);
-
-          // collect delta vectors for this example
-          deltaVectors[l - 1] = nextLayerDelta;
-        }
-
-        RealVector[] newActivations = new RealVector[activations.length];
-        newActivations[0] = ConversionUtils.toRealVector(ConversionUtils.toValuesCollection(trainingExample.getFeatures()));
-        System.arraycopy(activations, 0, newActivations, 1, activations.length - 1);
-
-        // update triangle (big delta matrix)
-        updateTriangle(triangle, newActivations, deltaVectors, weightsMatrixSet);
-
-      } catch (Exception e) {
-        throw new WeightLearningException("error during derivatives calculation", e);
-      }
-      count++;
-    }
-
-    return createDerivatives(triangle, count);
-  }
-
   private RealMatrix[] updateWeights(RealMatrix[] weightsMatrixSet, RealMatrix[] derivatives, double alpha) {
     RealMatrix[] updatedParameters = new RealMatrix[weightsMatrixSet.length];
     for (int l = 0; l < weightsMatrixSet.length; l++) {
@@ -187,48 +150,4 @@ public class BackPropagationLearningStra
     return updatedParameters;
   }
 
-  private RealMatrix[] createDerivatives(RealMatrix[] triangle, double count) {
-    RealMatrix[] derivatives = new RealMatrix[triangle.length];
-    for (int i = 0; i < triangle.length; i++) {
-      // TODO : introduce regularization diversification on bias term (currently not regularized)
-      derivatives[i] = triangle[i].scalarMultiply(1d / count);
-    }
-    return derivatives;
-  }
-
-  private void updateTriangle(RealMatrix[] triangle, RealVector[] activations, RealVector[] deltaVectors, RealMatrix[] weightsMatrixSet) {
-    for (int l = weightsMatrixSet.length - 1; l >= 0; l--) {
-      RealMatrix realMatrix = deltaVectors[l].outerProduct(activations[l]);
-      if (triangle[l] == null) {
-        triangle[l] = realMatrix;
-      } else {
-        triangle[l] = triangle[l].add(realMatrix);
-      }
-    }
-  }
-
-  private RealVector calculateDeltaVector(RealMatrix thetaL, RealVector activationsVector, RealVector nextLayerDelta) {
-    // TODO : remove the bias term from the error calculations
-    ArrayRealVector identity = new ArrayRealVector(activationsVector.getDimension(), 1d);
-    RealVector gz = activationsVector.ebeMultiply(identity.subtract(activationsVector)); // = a^l .* (1-a^l)
-    return thetaL.preMultiply(nextLayerDelta).ebeMultiply(gz);
-  }
-
-  private RealVector calculateOutputError(TrainingExample<Double, Double> trainingExample, RealVector[] activations) {
-    RealVector output = activations[activations.length - 1];
-
-    Double[] sampleOutput = new Double[output.getDimension()];
-    int sampleOutputIntValue = trainingExample.getOutput().intValue();
-    if (sampleOutputIntValue < sampleOutput.length) {
-      sampleOutput[sampleOutputIntValue] = 1d;
-    } else if (sampleOutput.length == 1) {
-      sampleOutput[0] = trainingExample.getOutput();
-    } else {
-      throw new RuntimeException("problem with multiclass output mapping");
-    }
-    RealVector learnedOutputRealVector = new ArrayRealVector(sampleOutput); // turn example output to a vector
-
-    // TODO : improve error calculation -> this could be er_a = out_a * (1 - out_a) * (tgt_a - out_a)
-    return output.subtract(learnedOutputRealVector);
-  }
 }

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/BasicPerceptron.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/BasicPerceptron.java?rev=1705721&r1=1705720&r2=1705721&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/BasicPerceptron.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/BasicPerceptron.java Mon Sep 28 16:49:57 2015
@@ -94,4 +94,13 @@ public class BasicPerceptron implements
     return perceptronNeuron.elaborate(ConversionUtils.toValuesCollection(input.getFeatures()).toArray(
             new Double[input.getFeatures().size()]));
   }
+
+  @Override
+  public Double[] getOutputVector(Input<Double> input) throws PredictionException {
+    // a perceptron is a single neuron, hence its output vector holds exactly one value
+    Double[] featureValues = ConversionUtils.toValuesCollection(input.getFeatures())
+            .toArray(new Double[input.getFeatures().size()]);
+    Double output = perceptronNeuron.elaborate(featureValues);
+    return new Double[]{output};
+  }
 }

Added: labs/yay/trunk/core/src/main/java/org/apache/yay/core/DefaultDerivativeUpdateFunction.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/DefaultDerivativeUpdateFunction.java?rev=1705721&view=auto
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/DefaultDerivativeUpdateFunction.java (added)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/DefaultDerivativeUpdateFunction.java Mon Sep 28 16:49:57 2015
@@ -0,0 +1,183 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *  http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.yay.core;
+
+import org.apache.commons.math3.linear.ArrayRealVector;
+import org.apache.commons.math3.linear.RealMatrix;
+import org.apache.commons.math3.linear.RealVector;
+import org.apache.yay.DerivativeUpdateFunction;
+import org.apache.yay.PredictionStrategy;
+import org.apache.yay.TrainingExample;
+import org.apache.yay.TrainingSet;
+import org.apache.yay.core.utils.ConversionUtils;
+
+/**
+ * Default derivatives update function.
+ * <p>
+ * For each example of a training set the features are fed forward, the output error is back
+ * propagated through the layers and the outer products of the resulting delta vectors with
+ * the layer inputs are accumulated; the returned derivatives are these accumulators averaged
+ * over the number of examples.
+ */
+public class DefaultDerivativeUpdateFunction implements DerivativeUpdateFunction<Double, Double> {
+
+  private final PredictionStrategy<Double, Double> predictionStrategy;
+
+  /**
+   * @param predictionStrategy strategy used to obtain the per-layer activations of each example
+   */
+  public DefaultDerivativeUpdateFunction(PredictionStrategy<Double, Double> predictionStrategy) {
+    this.predictionStrategy = predictionStrategy;
+  }
+
+  /**
+   * Calculate the derivatives of the given weights averaged over the given training examples.
+   * <p>
+   * Despite the method name, the returned matrices are derivatives: applying them
+   * to the weights is left to the caller.
+   *
+   * @param weightsMatrixSet the current weights, one matrix per layer
+   * @param trainingExamples the examples to back propagate, must contain at least one example
+   * @return the derivatives, one matrix per element of {@code weightsMatrixSet}
+   * @throws IllegalArgumentException if {@code trainingExamples} contains no examples
+   * @throws RuntimeException wrapping any error raised while processing a single example
+   */
+  @Override
+  public RealMatrix[] updateParameters(RealMatrix[] weightsMatrixSet, TrainingSet<Double, Double> trainingExamples) {
+    // set up the accumulator matrix(es)
+    RealMatrix[] triangle = new RealMatrix[weightsMatrixSet.length];
+    // per-example delta vectors: every slot is overwritten while processing each example
+    RealVector[] deltaVectors = new RealVector[weightsMatrixSet.length];
+
+    int noOfMatrixes = weightsMatrixSet.length - 1;
+    double count = 0;
+    for (TrainingExample<Double, Double> trainingExample : trainingExamples) {
+      try {
+        // get activations from feed forward propagation
+        RealVector[] activations = predictionStrategy.debugOutput(ConversionUtils.toValuesCollection(trainingExample.getFeatures()), weightsMatrixSet);
+
+        // calculate output error (corresponding to the last delta^l)
+        RealVector nextLayerDelta = calculateOutputError(trainingExample, activations);
+
+        deltaVectors[noOfMatrixes] = nextLayerDelta;
+
+        // back prop the error and update the deltas accordingly
+        for (int l = noOfMatrixes; l > 0; l--) {
+          RealVector currentActivationsVector = activations[l - 1];
+          nextLayerDelta = calculateDeltaVector(weightsMatrixSet[l], currentActivationsVector, nextLayerDelta);
+
+          // collect delta vectors for this example
+          deltaVectors[l - 1] = nextLayerDelta;
+        }
+
+        // prepend the input features and drop the last (output) activations, so that
+        // newActivations[l] pairs with deltaVectors[l] in updateTriangle
+        RealVector[] newActivations = new RealVector[activations.length];
+        newActivations[0] = ConversionUtils.toRealVector(ConversionUtils.toValuesCollection(trainingExample.getFeatures()));
+        System.arraycopy(activations, 0, newActivations, 1, activations.length - 1);
+
+        // update triangle (big delta matrix)
+        updateTriangle(triangle, newActivations, deltaVectors, weightsMatrixSet);
+
+      } catch (Exception e) {
+        throw new RuntimeException("error during derivatives calculation", e);
+      }
+      count++;
+    }
+
+    if (count == 0) {
+      // no accumulator was initialized: fail explicitly rather than with a NullPointerException
+      throw new IllegalArgumentException("cannot calculate derivatives on an empty training set");
+    }
+
+    return createDerivatives(triangle, count);
+  }
+
+  /**
+   * Average the accumulated matrices over the number of processed examples.
+   */
+  private RealMatrix[] createDerivatives(RealMatrix[] triangle, double count) {
+    RealMatrix[] derivatives = new RealMatrix[triangle.length];
+    for (int i = 0; i < triangle.length; i++) {
+      // TODO : introduce regularization diversification on bias term (currently not regularized)
+      derivatives[i] = triangle[i].scalarMultiply(1d / count);
+    }
+    return derivatives;
+  }
+
+  /**
+   * Add the outer product deltaVectors[l] x activations[l] of the current example to triangle[l],
+   * for every layer l.
+   */
+  private void updateTriangle(RealMatrix[] triangle, RealVector[] activations, RealVector[] deltaVectors, RealMatrix[] weightsMatrixSet) {
+    for (int l = weightsMatrixSet.length - 1; l >= 0; l--) {
+      RealMatrix realMatrix = deltaVectors[l].outerProduct(activations[l]);
+      if (triangle[l] == null) {
+        triangle[l] = realMatrix;
+      } else {
+        triangle[l] = triangle[l].add(realMatrix);
+      }
+    }
+  }
+
+  /**
+   * Back propagate nextLayerDelta through thetaL and scale it element-wise by a .* (1 - a),
+   * a being the given activations (the derivative of a sigmoid activation; assumes sigmoid
+   * units - TODO confirm).
+   */
+  private RealVector calculateDeltaVector(RealMatrix thetaL, RealVector activationsVector, RealVector nextLayerDelta) {
+    // TODO : remove the bias term from the error calculations
+    ArrayRealVector ones = new ArrayRealVector(activationsVector.getDimension(), 1d);
+    RealVector gz = activationsVector.ebeMultiply(ones.subtract(activationsVector)); // = a^l .* (1-a^l)
+    return thetaL.preMultiply(nextLayerDelta).ebeMultiply(gz);
+  }
+
+  /**
+   * Calculate the output layer error, i.e. the network output minus the expected output.
+   * <p>
+   * With a single output unit the expected output is the example output itself; with several
+   * output units the example output is read as a class index and one-hot encoded.
+   *
+   * @throws IllegalArgumentException if a multiclass example output is not a valid class index
+   */
+  private RealVector calculateOutputError(TrainingExample<Double, Double> trainingExample, RealVector[] activations) {
+    RealVector output = activations[activations.length - 1];
+
+    // primitive array so that non-target entries default to 0 (a Double[] would hold nulls,
+    // which ArrayRealVector's Double[] constructor cannot unbox)
+    double[] sampleOutput = new double[output.getDimension()];
+    Double exampleOutput = trainingExample.getOutput();
+    if (sampleOutput.length == 1) {
+      // handle the single output case first: otherwise any output whose integer part is 0
+      // (e.g. 0.0 or 0.5) would be mapped to class index 0 and turned into a target of 1
+      sampleOutput[0] = exampleOutput;
+    } else {
+      int sampleOutputIntValue = exampleOutput.intValue();
+      if (sampleOutputIntValue < 0 || sampleOutputIntValue >= sampleOutput.length) {
+        throw new IllegalArgumentException("problem with multiclass output mapping: output " + exampleOutput
+            + " does not match any of the " + sampleOutput.length + " output units");
+      }
+      sampleOutput[sampleOutputIntValue] = 1d;
+    }
+    RealVector learnedOutputRealVector = new ArrayRealVector(sampleOutput); // turn example output to a vector
+
+    // TODO : improve error calculation -> this could be er_a = out_a * (1 - out_a) * (tgt_a - out_a)
+    return output.subtract(learnedOutputRealVector);
+  }
+}

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java?rev=1705721&r1=1705720&r2=1705721&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java Mon Sep 28 16:49:57 2015
@@ -16,7 +16,6 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-
 package org.apache.yay.core;
 
 import java.util.ArrayList;

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/NeuralNetworkFactory.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/NeuralNetworkFactory.java?rev=1705721&r1=1705720&r2=1705721&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/NeuralNetworkFactory.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/NeuralNetworkFactory.java Mon Sep 28 16:49:57 2015
@@ -18,9 +18,11 @@
  */
 package org.apache.yay.core;
 
+import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collection;
 import org.apache.commons.math3.linear.RealMatrix;
+import org.apache.commons.math3.linear.RealVector;
 import org.apache.yay.CreationException;
 import org.apache.yay.Input;
 import org.apache.yay.LearningException;
@@ -53,6 +55,12 @@ public class NeuralNetworkFactory {
                                      final SelectionFunction<Collection<Double>, Double> selectionFunction) throws CreationException {
     return new NeuralNetwork() {
 
+      @Override
+      public Double[] getOutputVector(Input<Double> input) throws PredictionException {
+        // raw values from the prediction strategy; predict() applies the selection function on top
+        return predictionStrategy.predictOutput(ConversionUtils.toValuesCollection(input.getFeatures()), updatedRealMatrixSet);
+      }
+
       private RealMatrix[] updatedRealMatrixSet = realMatrixSet;
 
       @Override
@@ -77,8 +85,7 @@ public class NeuralNetworkFactory {
       @Override
       public Double predict(Input<Double> input) throws PredictionException {
         try {
-          Collection<Double> inputVector = ConversionUtils.toValuesCollection(input.getFeatures());
-          Double[] doubles = predictionStrategy.predictOutput(inputVector, updatedRealMatrixSet);
+          Double[] doubles = getOutputVector(input);
           return selectionFunction.selectOutput(Arrays.asList(doubles));
         } catch (Exception e) {
           throw new PredictionException(e);

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ExamplesFactory.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ExamplesFactory.java?rev=1705721&r1=1705720&r2=1705721&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ExamplesFactory.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ExamplesFactory.java Mon Sep 28 16:49:57 2015
@@ -19,6 +19,8 @@
 package org.apache.yay.core.utils;
 
 import java.util.ArrayList;
+import java.util.Collection;
+
 import org.apache.yay.Feature;
 import org.apache.yay.Input;
 import org.apache.yay.TrainingExample;
@@ -41,6 +43,21 @@ public class ExamplesFactory {
         return output;
       }
     };
+  }
+
+  public static TrainingExample<Double, Collection<Double[]>> createSGMExample(final Collection<Double[]> output,
+                                                                            final Double... featuresValues) {
+    // NOTE(review): meaning of "SGM" not stated - confirm; output is kept by reference, features converted per getFeatures() call
+    return new TrainingExample<Double, Collection<Double[]>>() {
+      @Override
+      public ArrayList<Feature<Double>> getFeatures() {
+        return doublesToFeatureVector(featuresValues);
+      }
+      @Override
+      public Collection<Double[]> getOutput() {
+        return output;
+      }
+    };
   }
 
   public static Input<Double> createDoubleInput(final Double... featuresValues) {

Modified: labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkFactoryTest.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkFactoryTest.java?rev=1705721&r1=1705720&r2=1705721&view=diff
==============================================================================
--- labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkFactoryTest.java (original)
+++ labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkFactoryTest.java Mon Sep 28 16:49:57 2015
@@ -84,8 +84,11 @@ public class NeuralNetworkFactoryTest {
   public void sampleCreationTest() throws Exception {
     RealMatrix firstLayer = new Array2DRowRealMatrix(new double[][]{{1d, 1d, 2d, 3d}, {1d, 1d, 2d, 3d}, {1d, 1d, 2d, 3d}});
     RealMatrix secondLayer = new Array2DRowRealMatrix(new double[][]{{1d, 2d, 3d}});
+
     RealMatrix[] RealMatrixes = new RealMatrix[]{firstLayer, secondLayer};
+
     NeuralNetwork neuralNetwork = createFFNN(RealMatrixes);
+
     Double prdictedValue = neuralNetwork.predict(createSample(5d, 6d, 7d));
     assertEquals(1l, Math.round(prdictedValue));
     assertEquals(Double.valueOf(0.9975273768433653d), prdictedValue);



---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@labs.apache.org
For additional commands, e-mail: commits-help@labs.apache.org