You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@labs.apache.org by to...@apache.org on 2012/08/03 08:34:45 UTC
svn commit: r1368803 -
/labs/yay/trunk/core/src/main/java/org/apache/yay/BackPropagationLearningStrategy.java
Author: tommaso
Date: Fri Aug 3 06:34:45 2012
New Revision: 1368803
URL: http://svn.apache.org/viewvc?rev=1368803&view=rev
Log:
improved first phase of backprop
Modified:
labs/yay/trunk/core/src/main/java/org/apache/yay/BackPropagationLearningStrategy.java
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/BackPropagationLearningStrategy.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/BackPropagationLearningStrategy.java?rev=1368803&r1=1368802&r2=1368803&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/BackPropagationLearningStrategy.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/BackPropagationLearningStrategy.java Fri Aug 3 06:34:45 2012
@@ -20,6 +20,7 @@ package org.apache.yay;
import java.util.Collection;
+import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
@@ -29,42 +30,65 @@ import org.apache.yay.utils.ConversionUt
* Back propagation learning algorithm for neural networks implementation (see
* <code>http://en.wikipedia.org/wiki/Backpropagation</code>).
*/
-public class BackPropagationLearningStrategy implements LearningStrategy<Double, double[]> {
+public class BackPropagationLearningStrategy implements LearningStrategy<Double, Double[]> {
- private PredictionStrategy<Double, double[]> predictionStrategy;
+ private PredictionStrategy<Double, Double[]> predictionStrategy;
- public BackPropagationLearningStrategy(PredictionStrategy<Double, double[]> predictionStrategy) {
+ public BackPropagationLearningStrategy(PredictionStrategy<Double, Double[]> predictionStrategy) {
this.predictionStrategy = predictionStrategy;
}
@Override
- public WeightsMatrix[] learnWeights(WeightsMatrix[] weightsMatrixSet, Collection<TrainingExample<Double, double[]>> trainingExamples) throws WeightLearningException {
- for (TrainingExample<Double, double[]> trainingExample : trainingExamples) {
+ public WeightsMatrix[] learnWeights(WeightsMatrix[] weightsMatrixSet, Collection<TrainingExample<Double, Double[]>> trainingExamples) throws WeightLearningException {
+ // set up the accumulator matrix(es)
+ RealMatrix[] deltas = new RealMatrix[weightsMatrixSet.length];
+ for (TrainingExample<Double, Double[]> trainingExample : trainingExamples) {
try {
-
- RealMatrix[] deltas = new RealMatrix[weightsMatrixSet.length - 1];
+ // contains activation errors for the current training example
+ RealMatrix[] activationErrors = new RealMatrix[weightsMatrixSet.length - 1];
RealMatrix output = predictionStrategy.debugOutput(ConversionUtils.toVector(trainingExample.getFeatureVector()), weightsMatrixSet);
- double[] learnedOutput = trainingExample.getOutput();
+ Double[] learnedOutput = trainingExample.getOutput();
RealVector predictedOutputVector = new ArrayRealVector(output.getColumn(output.getColumnDimension() - 1));
RealVector learnedOutputRealVector = new ArrayRealVector(learnedOutput);
- RealVector error = predictedOutputVector.subtract(learnedOutputRealVector);
+ RealVector error = predictedOutputVector.subtract(learnedOutputRealVector); // final layer error
+ activationErrors[activationErrors.length - 1] = new Array2DRowRealMatrix(error.toArray());
+
+ RealVector nextLayerDelta = error;
- // TODO : back prop the error and update the weights accordingly
- for (int i = weightsMatrixSet.length - 1; i > 0; i--) {
- WeightsMatrix currentMatrix = weightsMatrixSet[i];
- ArrayRealVector realVector = new ArrayRealVector(output.getColumn(i));
+ // back prop the error and update the activationErrors accordingly
+ // TODO : remove the bias term from the error calculations
+ for (int l = weightsMatrixSet.length - 2; l > 0; l--) {
+ WeightsMatrix currentMatrix = weightsMatrixSet[l];
+ ArrayRealVector realVector = new ArrayRealVector(output.getColumn(l));
ArrayRealVector identity = new ArrayRealVector(realVector.getDimension(), 1d);
- RealVector gz = realVector.ebeMultiply(identity.subtract(realVector)); // = a^i .* (1-a^i)
- RealVector resultingDeltaVector = currentMatrix.transpose().preMultiply(error).ebeMultiply(gz);
-// deltas[i] = deltas[i].add()
+ RealVector gz = realVector.ebeMultiply(identity.subtract(realVector)); // = a^l .* (1-a^l)
+ RealVector resultingDeltaVector = currentMatrix.transpose().preMultiply(nextLayerDelta).ebeMultiply(gz);
+ if (activationErrors[l] == null) {
+ activationErrors[l] = new Array2DRowRealMatrix(new ArrayRealVector(resultingDeltaVector.getDimension(), 1d).toArray());
+ }
+ activationErrors[l] = new Array2DRowRealMatrix(resultingDeltaVector.toArray());
+ }
+
+ // update the accumulator matrix
+ for (int l = 0; l < deltas.length - 1; l++) {
+ if (deltas[l] == null) {
+ deltas[l] = new Array2DRowRealMatrix(weightsMatrixSet[l].getColumnDimension(), weightsMatrixSet[l].getRowDimension());
+ }
+ deltas[l] = deltas[l].add(deltas[l + 1]).multiply(weightsMatrixSet[l].transpose());
}
} catch (Exception e) {
throw new WeightLearningException("error during phase 1 of back-propagation algorithm", e);
}
}
+ for (int i = 0; i < deltas.length; i++) {
+ deltas[i] = deltas[i].scalarMultiply(1 / trainingExamples.size());
+ }
+
+ // now apply gradient descent (or other optimization/minimization algorithms) with these derivative terms and the LRCF
+
return null;
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@labs.apache.org
For additional commands, e-mail: commits-help@labs.apache.org