You are viewing a plain text version of this content. The canonical link to it appeared as a hyperlink in the original page and is not preserved in this text version.
Posted to commits@labs.apache.org by to...@apache.org on 2013/02/25 08:48:17 UTC

svn commit: r1449612 - /labs/yay/trunk/core/src/main/java/org/apache/yay/BackPropagationLearningStrategy.java

Author: tommaso
Date: Mon Feb 25 07:48:17 2013
New Revision: 1449612

URL: http://svn.apache.org/r1449612
Log:
started refactoring backprop

Modified:
    labs/yay/trunk/core/src/main/java/org/apache/yay/BackPropagationLearningStrategy.java

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/BackPropagationLearningStrategy.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/BackPropagationLearningStrategy.java?rev=1449612&r1=1449611&r2=1449612&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/BackPropagationLearningStrategy.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/BackPropagationLearningStrategy.java Mon Feb 25 07:48:17 2013
@@ -32,70 +32,81 @@ import java.util.Collection;
  */
 public class BackPropagationLearningStrategy implements LearningStrategy<Double, Double[]> {
 
-  private final PredictionStrategy<Double, Double[]> predictionStrategy;
-  private CostFunction<RealMatrix, Double> costFunction;
+    private final PredictionStrategy<Double, Double[]> predictionStrategy;
+    private CostFunction<RealMatrix, Double> costFunction;
 
-  public BackPropagationLearningStrategy(PredictionStrategy<Double, Double[]> predictionStrategy, CostFunction<RealMatrix, Double> costFunction) {
-    this.predictionStrategy = predictionStrategy;
-    this.costFunction = costFunction;
-  }
-
-  @Override
-  public RealMatrix[] learnWeights(RealMatrix[] weightsMatrixSet, Collection<TrainingExample<Double, Double[]>> trainingExamples) throws WeightLearningException {
-    // set up the accumulator matrix(es)
-    RealMatrix[] triangle = new RealMatrix[weightsMatrixSet.length];
-    for (TrainingExample<Double, Double[]> trainingExample : trainingExamples) {
-      try {
-        // contains activation errors for the current training example
-        // TODO : check if this should be RealVector[] < probably yes
-        RealMatrix[] activationErrors = new RealMatrix[weightsMatrixSet.length - 1];
-
-        // feed forward propagation
-        RealMatrix[] activations = predictionStrategy.debugOutput(ConversionUtils.toValuesCollection(trainingExample.getFeatures()), weightsMatrixSet);
-        RealMatrix output = activations[activations.length - 1];
-        Double[] learnedOutput = trainingExample.getOutput(); // training example output
-        RealVector predictedOutputVector = new ArrayRealVector(output.getColumn(output.getColumnDimension() - 1)); // turn output to vector
-        RealVector learnedOutputRealVector = new ArrayRealVector(learnedOutput); // turn example output to a vector
-
-        RealVector error = predictedOutputVector.subtract(learnedOutputRealVector); // final layer error vector
-        activationErrors[activationErrors.length - 1] = new Array2DRowRealMatrix(error.toArray());
-
-        RealVector nextLayerDelta = new ArrayRealVector(error);
+    public BackPropagationLearningStrategy(PredictionStrategy<Double, Double[]> predictionStrategy, CostFunction<RealMatrix, Double> costFunction) {
+        this.predictionStrategy = predictionStrategy;
+        this.costFunction = costFunction;
+    }
 
-        // back prop the error and update the activationErrors accordingly
-        // TODO : remove the bias term from the error calculations
-        for (int l = weightsMatrixSet.length - 2; l >= 0; l--) {
-          RealMatrix thetaL = weightsMatrixSet[l];
-          ArrayRealVector activationsVector = new ArrayRealVector(activations[l].getRowVector(0)); // get l-th nn layer activations
-          ArrayRealVector identity = new ArrayRealVector(activationsVector.getDimension(), 1d);
-          RealVector gz = activationsVector.ebeMultiply(identity.subtract(activationsVector)); // = a^l .* (1-a^l)
-          RealVector resultingDeltaVector = thetaL.transpose().preMultiply(nextLayerDelta).ebeMultiply(gz);
-          if (activationErrors[l] == null) {
-            activationErrors[l] = new Array2DRowRealMatrix(new ArrayRealVector(resultingDeltaVector.getDimension(), 1d).toArray());
-          }
-          activationErrors[l] = new Array2DRowRealMatrix(resultingDeltaVector.toArray());
-          nextLayerDelta = resultingDeltaVector;
+    @Override
+    public RealMatrix[] learnWeights(RealMatrix[] weightsMatrixSet, Collection<TrainingExample<Double, Double[]>> trainingExamples) throws WeightLearningException {
+        // set up the accumulator matrix(es)
+        RealMatrix[] triangle = new RealMatrix[weightsMatrixSet.length];
+        for (TrainingExample<Double, Double[]> trainingExample : trainingExamples) {
+            try {
+                // contains activation errors for the current training example
+                // TODO : check if this should be RealVector[] < probably yes
+                RealMatrix[] activationErrors = new RealMatrix[weightsMatrixSet.length - 1];
+
+                // feed forward propagation
+                RealMatrix[] activations = predictionStrategy.debugOutput(ConversionUtils.toValuesCollection(trainingExample.getFeatures()), weightsMatrixSet);
+
+                // calculate output error
+                RealVector error = calculateOutputError(trainingExample, activations);
+
+                activationErrors[activationErrors.length - 1] = new Array2DRowRealMatrix(error.toArray());
+
+                RealVector nextLayerDelta = new ArrayRealVector(error);
+
+                // back prop the error and update the activationErrors accordingly
+                // TODO : eventually remove the bias term from the error calculations
+                for (int l = weightsMatrixSet.length - 2; l >= 0; l--) {
+                    RealVector resultingDeltaVector = calculateDeltaVector(weightsMatrixSet[l], activations[l], nextLayerDelta);
+                    if (activationErrors[l] == null) {
+                        activationErrors[l] = new Array2DRowRealMatrix(new ArrayRealVector(resultingDeltaVector.getDimension(), 1d).toArray());
+                    }
+                    activationErrors[l] = new Array2DRowRealMatrix(resultingDeltaVector.toArray());
+                    nextLayerDelta = resultingDeltaVector;
+                }
+
+                // update the accumulator matrix
+                for (int l = 0; l < triangle.length - 1; l++) {
+                    if (triangle[l] == null) {
+                        triangle[l] = new Array2DRowRealMatrix(weightsMatrixSet[l].getRowDimension(), weightsMatrixSet[l].getColumnDimension());
+                    }
+                    triangle[l] = triangle[l].add(activationErrors[l + 1].getRowVector(0).outerProduct(activations[l].getRowVector(0)));
+                }
+
+            } catch (Exception e) {
+                throw new WeightLearningException("error during phase 1 of back-propagation algorithm", e);
+            }
         }
-
-        // update the accumulator matrix
-        for (int l = 0; l < triangle.length - 1; l++) {
-          if (triangle[l] == null) {
-            triangle[l] = new Array2DRowRealMatrix(weightsMatrixSet[l].getRowDimension(), weightsMatrixSet[l].getColumnDimension());
-          }
-          triangle[l] = triangle[l].add(activationErrors[l+1].getRowVector(0).outerProduct(activations[l].getRowVector(0)));
+        for (int i = 0; i < triangle.length; i++) {
+            // TODO : introduce regularization diversification on bias term (currently not regularized)
+            triangle[i] = triangle[i].scalarMultiply(1 / trainingExamples.size());
         }
 
-      } catch (Exception e) {
-        throw new WeightLearningException("error during phase 1 of back-propagation algorithm", e);
-      }
+        // TODO : now apply gradient descent (or other optimization/minimization algorithms) with this derivative terms and the cost function
+
+        return null;
     }
-    for (int i = 0; i < triangle.length; i++) {
-      // TODO : introduce regularization diversification on bias term (currently not regularized)
-      triangle[i] = triangle[i].scalarMultiply(1 / trainingExamples.size());
+
+    private RealVector calculateDeltaVector(RealMatrix thetaL, RealMatrix activation, RealVector nextLayerDelta) {
+        ArrayRealVector activationsVector = new ArrayRealVector(activation.getRowVector(0)); // get l-th nn layer activations
+        ArrayRealVector identity = new ArrayRealVector(activationsVector.getDimension(), 1d);
+        RealVector gz = activationsVector.ebeMultiply(identity.subtract(activationsVector)); // = a^l .* (1-a^l)
+        return thetaL.transpose().preMultiply(nextLayerDelta).ebeMultiply(gz);
     }
 
-    // TODO : now apply gradient descent (or other optimization/minimization algorithms) with this derivative terms and the cost function
+    private RealVector calculateOutputError(TrainingExample<Double, Double[]> trainingExample, RealMatrix[] activations) {
+        RealMatrix output = activations[activations.length - 1];
+        Double[] learnedOutput = trainingExample.getOutput(); // training example output
+        RealVector predictedOutputVector = new ArrayRealVector(output.getColumn(output.getColumnDimension() - 1)); // turn output to vector
+        RealVector learnedOutputRealVector = new ArrayRealVector(learnedOutput); // turn example output to a vector
 
-    return null;
-  }
+        // TODO : improve error calculation > this should be er_a = out_a * (1 - out_a) * (tgt_a - out_a)
+        return predictedOutputVector.subtract(learnedOutputRealVector);
+    }
 }



---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@labs.apache.org
For additional commands, e-mail: commits-help@labs.apache.org