You are viewing a plain-text version of this content; the link to the canonical version is not preserved in this rendering.
Posted to commits@labs.apache.org by to...@apache.org on 2013/04/17 16:48:26 UTC
svn commit: r1468944 - in /labs/yay/trunk/core/src:
main/java/org/apache/yay/core/BackPropagationLearningStrategy.java
test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java
Author: tommaso
Date: Wed Apr 17 14:48:26 2013
New Revision: 1468944
URL: http://svn.apache.org/r1468944
Log:
started cleaning backpropagation phase 1
Added:
labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java (with props)
Modified:
labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java?rev=1468944&r1=1468943&r2=1468944&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java Wed Apr 17 14:48:26 2013
@@ -53,36 +53,28 @@ public class BackPropagationLearningStra
for (TrainingExample<Double, Double> trainingExample : trainingExamples) {
try {
// contains activation errors for the current training example
- // TODO : check if this should be RealVector[] < probably yes
- RealMatrix[] activationErrors = new RealMatrix[weightsMatrixSet.length - 1];
+ int noOfMatrixes = weightsMatrixSet.length - 1;
// feed forward propagation
- RealMatrix[] activations = predictionStrategy.debugOutput(ConversionUtils.toValuesCollection(trainingExample.getFeatures()), weightsMatrixSet);
+ RealVector[] activations = predictionStrategy.debugOutput(ConversionUtils.toValuesCollection(trainingExample.getFeatures()), weightsMatrixSet);
// calculate output error
RealVector error = calculateOutputError(trainingExample, activations);
- activationErrors[activationErrors.length - 1] = new Array2DRowRealMatrix(error.toArray());
+ RealVector nextLayerDelta = error;
- RealVector nextLayerDelta = new ArrayRealVector(error);
+ triangle[noOfMatrixes] = new Array2DRowRealMatrix(weightsMatrixSet[noOfMatrixes].getColumnDimension(), weightsMatrixSet[noOfMatrixes].getRowDimension());
+ triangle[noOfMatrixes] = triangle[noOfMatrixes].add(activations[noOfMatrixes - 1].outerProduct(error));
- // back prop the error and update the activationErrors accordingly
+ // back prop the error and update the deltas accordingly
// TODO : eventually remove the bias term from the error calculations
for (int l = weightsMatrixSet.length - 2; l >= 0; l--) {
RealVector resultingDeltaVector = calculateDeltaVector(weightsMatrixSet[l], activations[l], nextLayerDelta);
- if (activationErrors[l] == null) {
- activationErrors[l] = new Array2DRowRealMatrix(new ArrayRealVector(resultingDeltaVector.getDimension(), 1d).toArray());
- }
- activationErrors[l] = new Array2DRowRealMatrix(resultingDeltaVector.toArray());
- nextLayerDelta = resultingDeltaVector;
- }
- // update the accumulator matrix
- for (int l = 0; l < triangle.length - 1; l++) {
if (triangle[l] == null) {
triangle[l] = new Array2DRowRealMatrix(weightsMatrixSet[l].getRowDimension(), weightsMatrixSet[l].getColumnDimension());
}
- triangle[l] = triangle[l].add(activationErrors[l + 1].getRowVector(0).outerProduct(activations[l].getRowVector(0)));
+ triangle[l] = triangle[l].add(resultingDeltaVector.outerProduct(activations[l-1]));
}
} catch (Exception e) {
@@ -95,27 +87,34 @@ public class BackPropagationLearningStra
triangle[i] = triangle[i].scalarMultiply(1 / count);
}
- // TODO : now apply gradient descent (or other optimization/minimization algorithms) with this derivative terms and the cost function
+ // TODO : now apply gradient descent (or other optimization/minimization algorithms) with 'triangle' derivative terms and the cost function
return null;
}
- private RealVector calculateDeltaVector(RealMatrix thetaL, RealMatrix activation, RealVector nextLayerDelta) {
- ArrayRealVector activationsVector = new ArrayRealVector(activation.getRowVector(0)); // get l-th nn layer activations
+ private RealVector calculateDeltaVector(RealMatrix thetaL, RealVector activationsVector, RealVector nextLayerDelta) {
ArrayRealVector identity = new ArrayRealVector(activationsVector.getDimension(), 1d);
RealVector gz = activationsVector.ebeMultiply(identity.subtract(activationsVector)); // = a^l .* (1-a^l)
- return thetaL.transpose().preMultiply(nextLayerDelta).ebeMultiply(gz);
+ return thetaL.transpose().preMultiply(nextLayerDelta).ebeMultiply(gz); // TODO : fix this as it seems wrong
}
- private RealVector calculateOutputError(TrainingExample<Double, Double> trainingExample, RealMatrix[] activations) {
- RealMatrix output = activations[activations.length - 1];
- RealVector predictedOutputVector = new ArrayRealVector(output.getColumn(output.getColumnDimension() - 1)); // turn output to vector
-
- Double[] learnedOutput = new Double[predictedOutputVector.getDimension()]; // training example output
- learnedOutput[trainingExample.getOutput().intValue()] = 1d;
- RealVector learnedOutputRealVector = new ArrayRealVector(learnedOutput); // turn example output to a vector
+  /**
+   * Computes the output-layer error for a single training example as {@code a - y}. Here {@code a} is the
+   * last activation vector (the network output) and {@code y} is the expected output vector.
+   * <p>
+   * With a single output unit, the example's output value is used directly as the target. With several
+   * output units, the output value is read as a class index and one-hot encoded.
+   *
+   * @param trainingExample the example holding the expected output
+   * @param activations     per-layer activations from the feed forward pass; the last one is the network output
+   * @return the output error vector, with the same dimension as the network output
+   * @throws IllegalArgumentException if, for a multi-output network, the class index is outside
+   *                                  {@code [0, number of output units)}
+   */
+  private RealVector calculateOutputError(TrainingExample<Double, Double> trainingExample, RealVector[] activations) {
+    RealVector output = activations[activations.length - 1];
+
+    // primitive array: entries that are not set stay 0d. A Double[] would hold nulls there and make the
+    // ArrayRealVector constructor fail while unboxing them.
+    double[] sampleOutput = new double[output.getDimension()];
+    Double expectedOutput = trainingExample.getOutput();
+    if (sampleOutput.length == 1) {
+      // single output unit: the expected value is the target itself. This is checked first so that an
+      // expected output of 0 is not one-hot encoded into a target of 1.
+      sampleOutput[0] = expectedOutput;
+    } else {
+      int sampleOutputIntValue = expectedOutput.intValue();
+      if (sampleOutputIntValue < 0 || sampleOutputIntValue >= sampleOutput.length) {
+        throw new IllegalArgumentException("problem with multiclass output mapping: class index "
+            + sampleOutputIntValue + " is not in [0, " + sampleOutput.length + ")");
+      }
+      sampleOutput[sampleOutputIntValue] = 1d;
+    }
+    RealVector learnedOutputRealVector = new ArrayRealVector(sampleOutput); // turn example output to a vector
- // TODO : improve error calculation > this should be er_a = out_a * (1 - out_a) * (tgt_a - out_a)
- return predictedOutputVector.subtract(learnedOutputRealVector);
+
+    // TODO : improve error calculation > this could be er_a = out_a * (1 - out_a) * (tgt_a - out_a)
+    return output.subtract(learnedOutputRealVector);
}
}
Added: labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java?rev=1468944&view=auto
==============================================================================
--- labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java (added)
+++ labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java Wed Apr 17 14:48:26 2013
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.yay.core;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Random;
+
+import org.apache.commons.math3.linear.Array2DRowRealMatrix;
+import org.apache.commons.math3.linear.RealMatrix;
+import org.apache.yay.PredictionStrategy;
+import org.apache.yay.TrainingExample;
+import org.apache.yay.TrainingSet;
+import org.apache.yay.core.utils.ExamplesFactory;
+import org.junit.Test;
+
+import static junit.framework.Assert.assertNotNull;
+
+/**
+ * Testcase for {@link org.apache.yay.core.BackPropagationLearningStrategy}. It runs back propagation
+ * over a small fixed-weight network, using randomly generated but reproducible training samples.
+ */
+public class BackPropagationLearningStrategyTest {
+
+  /** Fixed seed, so that a failing run can be reproduced with exactly the same samples. */
+  private static final long SAMPLES_SEED = 1468944L;
+
+  @Test
+  public void testLearningWithRandomSamples() throws Exception {
+    PredictionStrategy<Double, Double> predictionStrategy = new FeedForwardStrategy(new SigmoidFunction());
+    BackPropagationLearningStrategy backPropagationLearningStrategy =
+        new BackPropagationLearningStrategy(predictionStrategy, new LogisticRegressionCostFunction(0.4d));
+
+    // Weights for a network with layer sizes 3, 4, 4, 1. Each matrix has one row per unit of the next
+    // layer and one column per unit of the previous layer.
+    RealMatrix[] initialWeights = new RealMatrix[3];
+    initialWeights[0] = new Array2DRowRealMatrix(new double[][]{{0d, 0d, 0d}, {1d, 0.6d, 3d}, {1d, 2d, 2d}, {1d, 0.6d, 3d}});
+    initialWeights[1] = new Array2DRowRealMatrix(new double[][]{{0d, 0d, 0d, 0d}, {0d, 0.5d, 1d, 0.5d}, {0d, 0.1d, 8d, 0.1d}, {0d, 0.1d, 8d, 0.2d}});
+    initialWeights[2] = new Array2DRowRealMatrix(new double[][]{{0.7d, 2d, 0.3d, 0.5d}});
+
+    Collection<TrainingExample<Double, Double>> samples = createSamples(100, 2);
+    TrainingSet<Double, Double> trainingSet = new TrainingSet<Double, Double>(samples);
+    RealMatrix[] weights = backPropagationLearningStrategy.learnWeights(initialWeights, trainingSet);
+    assertNotNull(weights);
+  }
+
+  /**
+   * Creates {@code size} training examples. Each has expected output {@code 1d} and {@code noOfFeatures}
+   * feature values drawn uniformly from [0, 1) by a seeded generator.
+   *
+   * @param size         number of examples to create
+   * @param noOfFeatures number of feature values per example
+   * @return the generated examples
+   */
+  private Collection<TrainingExample<Double, Double>> createSamples(int size, int noOfFeatures) {
+    Random random = new Random(SAMPLES_SEED);
+    Collection<TrainingExample<Double, Double>> trainingExamples = new ArrayList<TrainingExample<Double, Double>>(size);
+    for (int i = 0; i < size; i++) {
+      // A fresh array per example: the array itself is handed to the example factory. Reusing a single
+      // instance could leave every example seeing the values written in the last iteration.
+      Double[] featureValues = new Double[noOfFeatures];
+      for (int j = 0; j < noOfFeatures; j++) {
+        featureValues[j] = random.nextDouble();
+      }
+      trainingExamples.add(ExamplesFactory.createDoubleTrainingExample(1d, featureValues));
+    }
+    return trainingExamples;
+  }
+}
Propchange: labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java
------------------------------------------------------------------------------
svn:eol-style = native
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@labs.apache.org
For additional commands, e-mail: commits-help@labs.apache.org