You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@labs.apache.org by to...@apache.org on 2012/08/03 08:34:45 UTC

svn commit: r1368803 - /labs/yay/trunk/core/src/main/java/org/apache/yay/BackPropagationLearningStrategy.java

Author: tommaso
Date: Fri Aug  3 06:34:45 2012
New Revision: 1368803

URL: http://svn.apache.org/viewvc?rev=1368803&view=rev
Log:
improved first phase of backprop

Modified:
    labs/yay/trunk/core/src/main/java/org/apache/yay/BackPropagationLearningStrategy.java

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/BackPropagationLearningStrategy.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/BackPropagationLearningStrategy.java?rev=1368803&r1=1368802&r2=1368803&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/BackPropagationLearningStrategy.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/BackPropagationLearningStrategy.java Fri Aug  3 06:34:45 2012
@@ -20,6 +20,7 @@ package org.apache.yay;
 
 import java.util.Collection;
 
+import org.apache.commons.math3.linear.Array2DRowRealMatrix;
 import org.apache.commons.math3.linear.ArrayRealVector;
 import org.apache.commons.math3.linear.RealMatrix;
 import org.apache.commons.math3.linear.RealVector;
@@ -29,42 +30,65 @@ import org.apache.yay.utils.ConversionUt
  * Back propagation learning algorithm for neural networks implementation (see
  * <code>http://en.wikipedia.org/wiki/Backpropagation</code>).
  */
-public class BackPropagationLearningStrategy implements LearningStrategy<Double, double[]> {
+public class BackPropagationLearningStrategy implements LearningStrategy<Double, Double[]> {
 
-  private PredictionStrategy<Double, double[]> predictionStrategy;
+  private PredictionStrategy<Double, Double[]> predictionStrategy;
 
-  public BackPropagationLearningStrategy(PredictionStrategy<Double, double[]> predictionStrategy) {
+  public BackPropagationLearningStrategy(PredictionStrategy<Double, Double[]> predictionStrategy) {
     this.predictionStrategy = predictionStrategy;
   }
 
   @Override
-  public WeightsMatrix[] learnWeights(WeightsMatrix[] weightsMatrixSet, Collection<TrainingExample<Double, double[]>> trainingExamples) throws WeightLearningException {
-    for (TrainingExample<Double, double[]> trainingExample : trainingExamples) {
+  public WeightsMatrix[] learnWeights(WeightsMatrix[] weightsMatrixSet, Collection<TrainingExample<Double, Double[]>> trainingExamples) throws WeightLearningException {
+    // set up the accumulator matrix(es)
+    RealMatrix[] deltas = new RealMatrix[weightsMatrixSet.length];
+    for (TrainingExample<Double, Double[]> trainingExample : trainingExamples) {
       try {
-        
-        RealMatrix[] deltas = new RealMatrix[weightsMatrixSet.length - 1];
+        // contains activation errors for the current training example
+        RealMatrix[] activationErrors = new RealMatrix[weightsMatrixSet.length - 1];
         
         RealMatrix output = predictionStrategy.debugOutput(ConversionUtils.toVector(trainingExample.getFeatureVector()), weightsMatrixSet);
-        double[] learnedOutput = trainingExample.getOutput();
+        Double[] learnedOutput = trainingExample.getOutput();
         RealVector predictedOutputVector = new ArrayRealVector(output.getColumn(output.getColumnDimension() - 1));
         RealVector learnedOutputRealVector = new ArrayRealVector(learnedOutput);
 
-        RealVector error = predictedOutputVector.subtract(learnedOutputRealVector);
+        RealVector error = predictedOutputVector.subtract(learnedOutputRealVector); // final layer error
+        activationErrors[activationErrors.length - 1] = new Array2DRowRealMatrix(error.toArray());
+
+        RealVector nextLayerDelta = error;
 
-        // TODO : back prop the error and update the weights accordingly
-        for (int i = weightsMatrixSet.length - 1; i > 0; i--) {
-          WeightsMatrix currentMatrix = weightsMatrixSet[i];
-          ArrayRealVector realVector = new ArrayRealVector(output.getColumn(i));
+        // back prop the error and update the activationErrors accordingly
+        // TODO : remove the bias term from the error calculations
+        for (int l = weightsMatrixSet.length - 2; l > 0; l--) {
+          WeightsMatrix currentMatrix = weightsMatrixSet[l];
+          ArrayRealVector realVector = new ArrayRealVector(output.getColumn(l));
           ArrayRealVector identity = new ArrayRealVector(realVector.getDimension(), 1d);
-          RealVector gz = realVector.ebeMultiply(identity.subtract(realVector)); // = a^i .* (1-a^i)
-          RealVector resultingDeltaVector = currentMatrix.transpose().preMultiply(error).ebeMultiply(gz);
-//          deltas[i] = deltas[i].add()
+          RealVector gz = realVector.ebeMultiply(identity.subtract(realVector)); // = a^l .* (1-a^l)
+          RealVector resultingDeltaVector = currentMatrix.transpose().preMultiply(nextLayerDelta).ebeMultiply(gz);
+          if (activationErrors[l] == null) {
+            activationErrors[l] = new Array2DRowRealMatrix(new ArrayRealVector(resultingDeltaVector.getDimension(), 1d).toArray());
+          }
+          activationErrors[l] = new Array2DRowRealMatrix(resultingDeltaVector.toArray());
+        }
+
+        // update the accumulator matrix
+        for (int l = 0; l < deltas.length - 1; l++) {
+          if (deltas[l] == null) {
+            deltas[l] = new Array2DRowRealMatrix(weightsMatrixSet[l].getColumnDimension(), weightsMatrixSet[l].getRowDimension());
+          }
+          deltas[l] = deltas[l].add(deltas[l + 1]).multiply(weightsMatrixSet[l].transpose());
         }
 
       } catch (Exception e) {
         throw new WeightLearningException("error during phase 1 of back-propagation algorithm", e);
       }
     }
+    for (int i = 0; i < deltas.length; i++) {
+      deltas[i] = deltas[i].scalarMultiply(1 / trainingExamples.size());
+    }
+
+    // now apply gradient descent (or other optimization/minimization algorithms) with these derivative terms and the LRCF
+
     return null;
   }
 }



---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@labs.apache.org
For additional commands, e-mail: commits-help@labs.apache.org