You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@labs.apache.org by to...@apache.org on 2015/09/28 18:49:57 UTC
svn commit: r1705721 - in /labs/yay/trunk: api/src/main/java/org/apache/yay/
core/src/main/java/org/apache/yay/core/
core/src/main/java/org/apache/yay/core/utils/
core/src/test/java/org/apache/yay/core/
Author: tommaso
Date: Mon Sep 28 16:49:57 2015
New Revision: 1705721
URL: http://svn.apache.org/viewvc?rev=1705721&view=rev
Log:
switch from batch to stochastic GD in backprop, abstracted derivative update function
Added:
labs/yay/trunk/api/src/main/java/org/apache/yay/DerivativeUpdateFunction.java
labs/yay/trunk/core/src/main/java/org/apache/yay/core/DefaultDerivativeUpdateFunction.java
Removed:
labs/yay/trunk/api/src/main/java/org/apache/yay/HypothesisFactory.java
Modified:
labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetwork.java
labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java
labs/yay/trunk/core/src/main/java/org/apache/yay/core/BasicPerceptron.java
labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java
labs/yay/trunk/core/src/main/java/org/apache/yay/core/NeuralNetworkFactory.java
labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ExamplesFactory.java
labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkFactoryTest.java
Added: labs/yay/trunk/api/src/main/java/org/apache/yay/DerivativeUpdateFunction.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/api/src/main/java/org/apache/yay/DerivativeUpdateFunction.java?rev=1705721&view=auto
==============================================================================
--- labs/yay/trunk/api/src/main/java/org/apache/yay/DerivativeUpdateFunction.java (added)
+++ labs/yay/trunk/api/src/main/java/org/apache/yay/DerivativeUpdateFunction.java Mon Sep 28 16:49:57 2015
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.yay;
+
+import org.apache.commons.math3.linear.RealMatrix;
+import org.apache.yay.TrainingSet;
+
+/**
+ * Computes the derivatives (gradients) used to adjust the weights of a layered network.
+ *
+ * <p>The returned matrices are derivatives, not updated weights: callers such as
+ * {@code BackPropagationLearningStrategy} apply them to the weights themselves.
+ *
+ * @param <F> the feature value type of the training examples
+ * @param <O> the output type of the training examples
+ */
+public interface DerivativeUpdateFunction<F, O> {
+
+  /**
+   * Computes one derivative matrix per weight matrix from the given training set.
+   *
+   * @param weights the current weight matrices, one per layer
+   * @param trainingSet the examples the derivatives are computed from
+   * @return the derivative matrices, index-aligned with {@code weights}
+   */
+  RealMatrix[] updateParameters(RealMatrix[] weights, TrainingSet<F, O> trainingSet);
+}
Modified: labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetwork.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetwork.java?rev=1705721&r1=1705720&r2=1705721&view=diff
==============================================================================
--- labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetwork.java (original)
+++ labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetwork.java Mon Sep 28 16:49:57 2015
@@ -21,8 +21,19 @@ package org.apache.yay;
import org.apache.commons.math3.linear.RealMatrix;
/**
- * A neural network is a layered connected graph of elaboration units
+ * A Neural Network is a layered connected graph of elaboration units.
+ *
+ * It takes a double vector as input and produces a double vector as output.
*/
public interface NeuralNetwork extends Hypothesis<RealMatrix, Double, Double> {
+ /**
+ * Predict the output for a given input
+ *
+ * @param input the input to evaluate
+ * @return the predicted output
+ * @throws PredictionException if any error occurs during the prediction phase
+ */
+ Double[] getOutputVector(Input<Double> input) throws PredictionException;
+
}
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java?rev=1705721&r1=1705720&r2=1705721&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java Mon Sep 28 16:49:57 2015
@@ -19,6 +19,7 @@
package org.apache.yay.core;
import java.util.Arrays;
+import java.util.DoubleSummaryStatistics;
import java.util.Iterator;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
@@ -26,6 +27,7 @@ import org.apache.commons.math3.linear.A
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.apache.yay.CostFunction;
+import org.apache.yay.DerivativeUpdateFunction;
import org.apache.yay.LearningStrategy;
import org.apache.yay.NeuralNetwork;
import org.apache.yay.PredictionStrategy;
@@ -46,6 +48,7 @@ public class BackPropagationLearningStra
private final PredictionStrategy<Double, Double> predictionStrategy;
private final CostFunction<RealMatrix, Double, Double> costFunction;
+ private final DerivativeUpdateFunction<Double, Double> derivativeUpdateFunction;
private final double alpha;
private final double threshold;
private final int batch;
@@ -63,6 +66,7 @@ public class BackPropagationLearningStra
this.alpha = alpha;
this.threshold = threshold;
this.batch = batch;
+ this.derivativeUpdateFunction = new DefaultDerivativeUpdateFunction(predictionStrategy);
}
public BackPropagationLearningStrategy() {
@@ -72,6 +76,7 @@ public class BackPropagationLearningStra
this.alpha = DEFAULT_ALPHA;
this.threshold = DEFAULT_THRESHOLD;
this.batch = 1;
+ this.derivativeUpdateFunction = new DefaultDerivativeUpdateFunction(predictionStrategy);
}
@Override
@@ -114,7 +119,7 @@ public class BackPropagationLearningStra
cost = newCost;
// calculate the derivatives to update the parameters
- RealMatrix[] derivatives = calculateDerivatives(weightsMatrixSet, samples);
+ RealMatrix[] derivatives = derivativeUpdateFunction.updateParameters(weightsMatrixSet, samples);
// calculate the updated parameters
updatedWeights = updateWeights(weightsMatrixSet, derivatives, alpha);
@@ -131,48 +136,6 @@ public class BackPropagationLearningStra
return updatedWeights;
}
- private RealMatrix[] calculateDerivatives(RealMatrix[] weightsMatrixSet, TrainingSet<Double, Double> trainingExamples) throws WeightLearningException {
- // set up the accumulator matrix(es)
- RealMatrix[] triangle = new RealMatrix[weightsMatrixSet.length];
- RealVector[] deltaVectors = new RealVector[weightsMatrixSet.length];
-
- int noOfMatrixes = weightsMatrixSet.length - 1;
- double count = 0;
- for (TrainingExample<Double, Double> trainingExample : trainingExamples) {
- try {
- // get activations from feed forward propagation
- RealVector[] activations = predictionStrategy.debugOutput(ConversionUtils.toValuesCollection(trainingExample.getFeatures()), weightsMatrixSet);
-
- // calculate output error (corresponding to the last delta^l)
- RealVector nextLayerDelta = calculateOutputError(trainingExample, activations);
-
- deltaVectors[noOfMatrixes] = nextLayerDelta;
-
- // back prop the error and update the deltas accordingly
- for (int l = noOfMatrixes; l > 0; l--) {
- RealVector currentActivationsVector = activations[l - 1];
- nextLayerDelta = calculateDeltaVector(weightsMatrixSet[l], currentActivationsVector, nextLayerDelta);
-
- // collect delta vectors for this example
- deltaVectors[l - 1] = nextLayerDelta;
- }
-
- RealVector[] newActivations = new RealVector[activations.length];
- newActivations[0] = ConversionUtils.toRealVector(ConversionUtils.toValuesCollection(trainingExample.getFeatures()));
- System.arraycopy(activations, 0, newActivations, 1, activations.length - 1);
-
- // update triangle (big delta matrix)
- updateTriangle(triangle, newActivations, deltaVectors, weightsMatrixSet);
-
- } catch (Exception e) {
- throw new WeightLearningException("error during derivatives calculation", e);
- }
- count++;
- }
-
- return createDerivatives(triangle, count);
- }
-
private RealMatrix[] updateWeights(RealMatrix[] weightsMatrixSet, RealMatrix[] derivatives, double alpha) {
RealMatrix[] updatedParameters = new RealMatrix[weightsMatrixSet.length];
for (int l = 0; l < weightsMatrixSet.length; l++) {
@@ -187,48 +150,4 @@ public class BackPropagationLearningStra
return updatedParameters;
}
- private RealMatrix[] createDerivatives(RealMatrix[] triangle, double count) {
- RealMatrix[] derivatives = new RealMatrix[triangle.length];
- for (int i = 0; i < triangle.length; i++) {
- // TODO : introduce regularization diversification on bias term (currently not regularized)
- derivatives[i] = triangle[i].scalarMultiply(1d / count);
- }
- return derivatives;
- }
-
- private void updateTriangle(RealMatrix[] triangle, RealVector[] activations, RealVector[] deltaVectors, RealMatrix[] weightsMatrixSet) {
- for (int l = weightsMatrixSet.length - 1; l >= 0; l--) {
- RealMatrix realMatrix = deltaVectors[l].outerProduct(activations[l]);
- if (triangle[l] == null) {
- triangle[l] = realMatrix;
- } else {
- triangle[l] = triangle[l].add(realMatrix);
- }
- }
- }
-
- private RealVector calculateDeltaVector(RealMatrix thetaL, RealVector activationsVector, RealVector nextLayerDelta) {
- // TODO : remove the bias term from the error calculations
- ArrayRealVector identity = new ArrayRealVector(activationsVector.getDimension(), 1d);
- RealVector gz = activationsVector.ebeMultiply(identity.subtract(activationsVector)); // = a^l .* (1-a^l)
- return thetaL.preMultiply(nextLayerDelta).ebeMultiply(gz);
- }
-
- private RealVector calculateOutputError(TrainingExample<Double, Double> trainingExample, RealVector[] activations) {
- RealVector output = activations[activations.length - 1];
-
- Double[] sampleOutput = new Double[output.getDimension()];
- int sampleOutputIntValue = trainingExample.getOutput().intValue();
- if (sampleOutputIntValue < sampleOutput.length) {
- sampleOutput[sampleOutputIntValue] = 1d;
- } else if (sampleOutput.length == 1) {
- sampleOutput[0] = trainingExample.getOutput();
- } else {
- throw new RuntimeException("problem with multiclass output mapping");
- }
- RealVector learnedOutputRealVector = new ArrayRealVector(sampleOutput); // turn example output to a vector
-
- // TODO : improve error calculation -> this could be er_a = out_a * (1 - out_a) * (tgt_a - out_a)
- return output.subtract(learnedOutputRealVector);
- }
}
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/BasicPerceptron.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/BasicPerceptron.java?rev=1705721&r1=1705720&r2=1705721&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/BasicPerceptron.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/BasicPerceptron.java Mon Sep 28 16:49:57 2015
@@ -94,4 +94,13 @@ public class BasicPerceptron implements
return perceptronNeuron.elaborate(ConversionUtils.toValuesCollection(input.getFeatures()).toArray(
new Double[input.getFeatures().size()]));
}
+
+  /**
+   * Evaluates the single perceptron neuron on the input features and returns its result as a
+   * one-element array.
+   */
+  @Override
+  public Double[] getOutputVector(Input<Double> input) throws PredictionException {
+    return new Double[] {
+        perceptronNeuron.elaborate(ConversionUtils.toValuesCollection(input.getFeatures()).toArray(
+            new Double[input.getFeatures().size()]))
+    };
+  }
}
Added: labs/yay/trunk/core/src/main/java/org/apache/yay/core/DefaultDerivativeUpdateFunction.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/DefaultDerivativeUpdateFunction.java?rev=1705721&view=auto
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/DefaultDerivativeUpdateFunction.java (added)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/DefaultDerivativeUpdateFunction.java Mon Sep 28 16:49:57 2015
@@ -0,0 +1,128 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.yay.core;
+
+import org.apache.commons.math3.linear.ArrayRealVector;
+import org.apache.commons.math3.linear.RealMatrix;
+import org.apache.commons.math3.linear.RealVector;
+import org.apache.yay.DerivativeUpdateFunction;
+import org.apache.yay.PredictionStrategy;
+import org.apache.yay.TrainingExample;
+import org.apache.yay.TrainingSet;
+import org.apache.yay.core.utils.ConversionUtils;
+
+/**
+ * Default {@link DerivativeUpdateFunction}: computes backpropagation derivatives.
+ *
+ * <p>For every example in the given training set the per-layer activations are obtained from
+ * {@code PredictionStrategy.debugOutput}, the output error is propagated backwards through the
+ * weight matrices, and the outer products of the deltas with the layer inputs are accumulated.
+ * The result is the accumulated sum divided by the number of examples, i.e. the gradient is
+ * averaged over the whole set passed in (a batch gradient). Pass a single-example set to obtain
+ * a per-example (stochastic) gradient.
+ */
+public class DefaultDerivativeUpdateFunction implements DerivativeUpdateFunction<Double, Double> {
+
+  private final PredictionStrategy<Double, Double> predictionStrategy;
+
+  /**
+   * @param predictionStrategy strategy used for the forward pass; its {@code debugOutput} must
+   *     expose the activations of every layer
+   */
+  public DefaultDerivativeUpdateFunction(PredictionStrategy<Double, Double> predictionStrategy) {
+    this.predictionStrategy = predictionStrategy;
+  }
+
+  /**
+   * Computes the example-averaged derivatives of the cost with respect to each weight matrix.
+   *
+   * @param weightsMatrixSet the current weight matrices, one per layer
+   * @param trainingExamples the examples to average over; must yield at least one example
+   * @return one derivative matrix per entry of {@code weightsMatrixSet}, in the same order
+   * @throws IllegalArgumentException if {@code trainingExamples} yields no examples
+   * @throws RuntimeException wrapping any failure raised while processing a single example
+   */
+  @Override
+  public RealMatrix[] updateParameters(RealMatrix[] weightsMatrixSet, TrainingSet<Double, Double> trainingExamples) {
+    // set up the accumulator matrix(es)
+    RealMatrix[] triangle = new RealMatrix[weightsMatrixSet.length];
+    RealVector[] deltaVectors = new RealVector[weightsMatrixSet.length];
+
+    int noOfMatrixes = weightsMatrixSet.length - 1;
+    int count = 0;
+    for (TrainingExample<Double, Double> trainingExample : trainingExamples) {
+      try {
+        // get activations from feed forward propagation (last entry is the network output)
+        RealVector[] activations = predictionStrategy.debugOutput(ConversionUtils.toValuesCollection(trainingExample.getFeatures()), weightsMatrixSet);
+
+        // calculate output error (corresponding to the last delta^l)
+        RealVector nextLayerDelta = calculateOutputError(trainingExample, activations);
+        deltaVectors[noOfMatrixes] = nextLayerDelta;
+
+        // back prop the error and collect the delta vector of every layer for this example
+        for (int l = noOfMatrixes; l > 0; l--) {
+          nextLayerDelta = calculateDeltaVector(weightsMatrixSet[l], activations[l - 1], nextLayerDelta);
+          deltaVectors[l - 1] = nextLayerDelta;
+        }
+
+        // shift the activations by one so that index l holds the input fed to weight matrix l:
+        // the raw features for l = 0, the previous layer's activations otherwise
+        RealVector[] layerInputs = new RealVector[activations.length];
+        layerInputs[0] = ConversionUtils.toRealVector(ConversionUtils.toValuesCollection(trainingExample.getFeatures()));
+        System.arraycopy(activations, 0, layerInputs, 1, activations.length - 1);
+
+        // update triangle (big delta matrix)
+        updateTriangle(triangle, layerInputs, deltaVectors, weightsMatrixSet);
+      } catch (Exception e) {
+        throw new RuntimeException("error during derivatives calculation", e);
+      }
+      count++;
+    }
+
+    // without examples the accumulators are still null and there is nothing to average
+    if (count == 0) {
+      throw new IllegalArgumentException("cannot compute derivatives from an empty training set");
+    }
+    return createDerivatives(triangle, count);
+  }
+
+  /** Divides each accumulated per-layer matrix by the number of processed examples. */
+  private RealMatrix[] createDerivatives(RealMatrix[] triangle, int count) {
+    RealMatrix[] derivatives = new RealMatrix[triangle.length];
+    for (int i = 0; i < triangle.length; i++) {
+      // TODO : introduce regularization diversification on bias term (currently not regularized)
+      derivatives[i] = triangle[i].scalarMultiply(1d / count);
+    }
+    return derivatives;
+  }
+
+  /**
+   * Adds {@code deltaVectors[l] (outer) activations[l]} to {@code triangle[l]} for every layer,
+   * creating the accumulator on first use.
+   */
+  private void updateTriangle(RealMatrix[] triangle, RealVector[] activations, RealVector[] deltaVectors, RealMatrix[] weightsMatrixSet) {
+    for (int l = weightsMatrixSet.length - 1; l >= 0; l--) {
+      RealMatrix realMatrix = deltaVectors[l].outerProduct(activations[l]);
+      if (triangle[l] == null) {
+        triangle[l] = realMatrix;
+      } else {
+        triangle[l] = triangle[l].add(realMatrix);
+      }
+    }
+  }
+
+  /**
+   * Propagates {@code nextLayerDelta} back through {@code thetaL}:
+   * {@code (delta^T * theta) .* a .* (1 - a)}.
+   *
+   * <p>NOTE(review): the {@code a .* (1 - a)} term is the derivative of a logistic activation;
+   * confirm the prediction strategy uses sigmoid units.
+   */
+  private RealVector calculateDeltaVector(RealMatrix thetaL, RealVector activationsVector, RealVector nextLayerDelta) {
+    // TODO : remove the bias term from the error calculations
+    ArrayRealVector identity = new ArrayRealVector(activationsVector.getDimension(), 1d);
+    RealVector gz = activationsVector.ebeMultiply(identity.subtract(activationsVector)); // = a^l .* (1-a^l)
+    return thetaL.preMultiply(nextLayerDelta).ebeMultiply(gz);
+  }
+
+  /**
+   * Computes the output-layer error {@code output - target}.
+   *
+   * <p>With a single output unit the example's output value is used directly as the target.
+   * With several output units the value is interpreted as a class index and one-hot encoded.
+   *
+   * @throws RuntimeException if the class index is outside the output layer
+   */
+  private RealVector calculateOutputError(TrainingExample<Double, Double> trainingExample, RealVector[] activations) {
+    RealVector output = activations[activations.length - 1];
+
+    // primitive array: every non-target entry is 0.0 (a Double[] would hold nulls and NPE on unboxing)
+    double[] target = new double[output.getDimension()];
+    if (target.length == 1) {
+      // must be checked first, otherwise a label of 0 would be mapped to the target 1
+      target[0] = trainingExample.getOutput();
+    } else {
+      int classIndex = trainingExample.getOutput().intValue();
+      if (classIndex < 0 || classIndex >= target.length) {
+        throw new RuntimeException("problem with multiclass output mapping");
+      }
+      target[classIndex] = 1d;
+    }
+
+    // TODO : improve error calculation -> this could be er_a = out_a * (1 - out_a) * (tgt_a - out_a)
+    return output.subtract(new ArrayRealVector(target));
+  }
+}
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java?rev=1705721&r1=1705720&r2=1705721&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java Mon Sep 28 16:49:57 2015
@@ -16,7 +16,6 @@
* specific language governing permissions and limitations
* under the License.
*/
-
package org.apache.yay.core;
import java.util.ArrayList;
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/NeuralNetworkFactory.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/NeuralNetworkFactory.java?rev=1705721&r1=1705720&r2=1705721&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/NeuralNetworkFactory.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/NeuralNetworkFactory.java Mon Sep 28 16:49:57 2015
@@ -18,9 +18,11 @@
*/
package org.apache.yay.core;
+import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import org.apache.commons.math3.linear.RealMatrix;
+import org.apache.commons.math3.linear.RealVector;
import org.apache.yay.CreationException;
import org.apache.yay.Input;
import org.apache.yay.LearningException;
@@ -53,6 +55,12 @@ public class NeuralNetworkFactory {
final SelectionFunction<Collection<Double>, Double> selectionFunction) throws CreationException {
return new NeuralNetwork() {
+      @Override
+      public Double[] getOutputVector(Input<Double> input) throws PredictionException {
+        // forward pass over this instance's current updatedRealMatrixSet weights
+        return predictionStrategy.predictOutput(ConversionUtils.toValuesCollection(input.getFeatures()), updatedRealMatrixSet);
+      }
+
private RealMatrix[] updatedRealMatrixSet = realMatrixSet;
@Override
@@ -77,8 +85,7 @@ public class NeuralNetworkFactory {
@Override
public Double predict(Input<Double> input) throws PredictionException {
try {
- Collection<Double> inputVector = ConversionUtils.toValuesCollection(input.getFeatures());
- Double[] doubles = predictionStrategy.predictOutput(inputVector, updatedRealMatrixSet);
+ Double[] doubles = getOutputVector(input);
return selectionFunction.selectOutput(Arrays.asList(doubles));
} catch (Exception e) {
throw new PredictionException(e);
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ExamplesFactory.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ExamplesFactory.java?rev=1705721&r1=1705720&r2=1705721&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ExamplesFactory.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ExamplesFactory.java Mon Sep 28 16:49:57 2015
@@ -19,6 +19,8 @@
package org.apache.yay.core.utils;
import java.util.ArrayList;
+import java.util.Collection;
+
import org.apache.yay.Feature;
import org.apache.yay.Input;
import org.apache.yay.TrainingExample;
@@ -41,6 +43,21 @@ public class ExamplesFactory {
return output;
}
};
+ }
+
+  /**
+   * Creates a training example whose output is a collection of {@code Double[]} values.
+   *
+   * @param output the example output, returned as-is by {@code getOutput()}
+   * @param featuresValues the feature values, converted into features on every
+   *     {@code getFeatures()} call
+   * @return a training example backed by the given arguments
+   */
+  public static TrainingExample<Double, Collection<Double[]>> createSGMExample(
+      final Collection<Double[]> output, final Double... featuresValues) {
+    return new TrainingExample<Double, Collection<Double[]>>() {
+      @Override
+      public Collection<Double[]> getOutput() {
+        return output;
+      }
+
+      @Override
+      public ArrayList<Feature<Double>> getFeatures() {
+        return doublesToFeatureVector(featuresValues);
+      }
+    };
 }
public static Input<Double> createDoubleInput(final Double... featuresValues) {
Modified: labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkFactoryTest.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkFactoryTest.java?rev=1705721&r1=1705720&r2=1705721&view=diff
==============================================================================
--- labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkFactoryTest.java (original)
+++ labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkFactoryTest.java Mon Sep 28 16:49:57 2015
@@ -84,8 +84,11 @@ public class NeuralNetworkFactoryTest {
public void sampleCreationTest() throws Exception {
RealMatrix firstLayer = new Array2DRowRealMatrix(new double[][]{{1d, 1d, 2d, 3d}, {1d, 1d, 2d, 3d}, {1d, 1d, 2d, 3d}});
RealMatrix secondLayer = new Array2DRowRealMatrix(new double[][]{{1d, 2d, 3d}});
+
RealMatrix[] RealMatrixes = new RealMatrix[]{firstLayer, secondLayer};
+
NeuralNetwork neuralNetwork = createFFNN(RealMatrixes);
+
Double prdictedValue = neuralNetwork.predict(createSample(5d, 6d, 7d));
assertEquals(1l, Math.round(prdictedValue));
assertEquals(Double.valueOf(0.9975273768433653d), prdictedValue);
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@labs.apache.org
For additional commands, e-mail: commits-help@labs.apache.org