Posted to commits@labs.apache.org by to...@apache.org on 2013/02/25 08:48:17 UTC
svn commit: r1449612 -
/labs/yay/trunk/core/src/main/java/org/apache/yay/BackPropagationLearningStrategy.java
Author: tommaso
Date: Mon Feb 25 07:48:17 2013
New Revision: 1449612
URL: http://svn.apache.org/r1449612
Log:
started refactoring backprop
Modified:
labs/yay/trunk/core/src/main/java/org/apache/yay/BackPropagationLearningStrategy.java
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/BackPropagationLearningStrategy.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/BackPropagationLearningStrategy.java?rev=1449612&r1=1449611&r2=1449612&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/BackPropagationLearningStrategy.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/BackPropagationLearningStrategy.java Mon Feb 25 07:48:17 2013
@@ -32,70 +32,81 @@ import java.util.Collection;
*/
public class BackPropagationLearningStrategy implements LearningStrategy<Double, Double[]> {
-  private final PredictionStrategy<Double, Double[]> predictionStrategy;
-  private CostFunction<RealMatrix, Double> costFunction;
+  private final PredictionStrategy<Double, Double[]> predictionStrategy;
+  private CostFunction<RealMatrix, Double> costFunction;
-  public BackPropagationLearningStrategy(PredictionStrategy<Double, Double[]> predictionStrategy, CostFunction<RealMatrix, Double> costFunction) {
-    this.predictionStrategy = predictionStrategy;
-    this.costFunction = costFunction;
-  }
-
-  @Override
-  public RealMatrix[] learnWeights(RealMatrix[] weightsMatrixSet, Collection<TrainingExample<Double, Double[]>> trainingExamples) throws WeightLearningException {
-    // set up the accumulator matrix(es)
-    RealMatrix[] triangle = new RealMatrix[weightsMatrixSet.length];
-    for (TrainingExample<Double, Double[]> trainingExample : trainingExamples) {
-      try {
-        // contains activation errors for the current training example
-        // TODO : check if this should be RealVector[] < probably yes
-        RealMatrix[] activationErrors = new RealMatrix[weightsMatrixSet.length - 1];
-
-        // feed forward propagation
-        RealMatrix[] activations = predictionStrategy.debugOutput(ConversionUtils.toValuesCollection(trainingExample.getFeatures()), weightsMatrixSet);
-        RealMatrix output = activations[activations.length - 1];
-        Double[] learnedOutput = trainingExample.getOutput(); // training example output
-        RealVector predictedOutputVector = new ArrayRealVector(output.getColumn(output.getColumnDimension() - 1)); // turn output to vector
-        RealVector learnedOutputRealVector = new ArrayRealVector(learnedOutput); // turn example output to a vector
-
-        RealVector error = predictedOutputVector.subtract(learnedOutputRealVector); // final layer error vector
-        activationErrors[activationErrors.length - 1] = new Array2DRowRealMatrix(error.toArray());
-
-        RealVector nextLayerDelta = new ArrayRealVector(error);
+  public BackPropagationLearningStrategy(PredictionStrategy<Double, Double[]> predictionStrategy, CostFunction<RealMatrix, Double> costFunction) {
+    this.predictionStrategy = predictionStrategy;
+    this.costFunction = costFunction;
+  }
-        // back prop the error and update the activationErrors accordingly
-        // TODO : remove the bias term from the error calculations
-        for (int l = weightsMatrixSet.length - 2; l >= 0; l--) {
-          RealMatrix thetaL = weightsMatrixSet[l];
-          ArrayRealVector activationsVector = new ArrayRealVector(activations[l].getRowVector(0)); // get l-th nn layer activations
-          ArrayRealVector identity = new ArrayRealVector(activationsVector.getDimension(), 1d);
-          RealVector gz = activationsVector.ebeMultiply(identity.subtract(activationsVector)); // = a^l .* (1-a^l)
-          RealVector resultingDeltaVector = thetaL.transpose().preMultiply(nextLayerDelta).ebeMultiply(gz);
-          if (activationErrors[l] == null) {
-            activationErrors[l] = new Array2DRowRealMatrix(new ArrayRealVector(resultingDeltaVector.getDimension(), 1d).toArray());
-          }
-          activationErrors[l] = new Array2DRowRealMatrix(resultingDeltaVector.toArray());
-          nextLayerDelta = resultingDeltaVector;
+  @Override
+  public RealMatrix[] learnWeights(RealMatrix[] weightsMatrixSet, Collection<TrainingExample<Double, Double[]>> trainingExamples) throws WeightLearningException {
+    // set up the accumulator matrix(es)
+    RealMatrix[] triangle = new RealMatrix[weightsMatrixSet.length];
+    for (TrainingExample<Double, Double[]> trainingExample : trainingExamples) {
+      try {
+        // contains activation errors for the current training example
+        // TODO : check if this should be RealVector[] < probably yes
+        RealMatrix[] activationErrors = new RealMatrix[weightsMatrixSet.length - 1];
+
+        // feed forward propagation
+        RealMatrix[] activations = predictionStrategy.debugOutput(ConversionUtils.toValuesCollection(trainingExample.getFeatures()), weightsMatrixSet);
+
+        // calculate output error
+        RealVector error = calculateOutputError(trainingExample, activations);
+
+        activationErrors[activationErrors.length - 1] = new Array2DRowRealMatrix(error.toArray());
+
+        RealVector nextLayerDelta = new ArrayRealVector(error);
+
+        // back prop the error and update the activationErrors accordingly
+        // TODO : eventually remove the bias term from the error calculations
+        for (int l = weightsMatrixSet.length - 2; l >= 0; l--) {
+          RealVector resultingDeltaVector = calculateDeltaVector(weightsMatrixSet[l], activations[l], nextLayerDelta);
+          if (activationErrors[l] == null) {
+            activationErrors[l] = new Array2DRowRealMatrix(new ArrayRealVector(resultingDeltaVector.getDimension(), 1d).toArray());
+          }
+          activationErrors[l] = new Array2DRowRealMatrix(resultingDeltaVector.toArray());
+          nextLayerDelta = resultingDeltaVector;
+        }
+
+        // update the accumulator matrix
+        for (int l = 0; l < triangle.length - 1; l++) {
+          if (triangle[l] == null) {
+            triangle[l] = new Array2DRowRealMatrix(weightsMatrixSet[l].getRowDimension(), weightsMatrixSet[l].getColumnDimension());
+          }
+          triangle[l] = triangle[l].add(activationErrors[l + 1].getRowVector(0).outerProduct(activations[l].getRowVector(0)));
+        }
+
+      } catch (Exception e) {
+        throw new WeightLearningException("error during phase 1 of back-propagation algorithm", e);
+      }
    }
-
-        // update the accumulator matrix
-        for (int l = 0; l < triangle.length - 1; l++) {
-          if (triangle[l] == null) {
-            triangle[l] = new Array2DRowRealMatrix(weightsMatrixSet[l].getRowDimension(), weightsMatrixSet[l].getColumnDimension());
-          }
-          triangle[l] = triangle[l].add(activationErrors[l+1].getRowVector(0).outerProduct(activations[l].getRowVector(0)));
+    for (int i = 0; i < triangle.length; i++) {
+      // TODO : introduce regularization diversification on bias term (currently not regularized)
+      triangle[i] = triangle[i].scalarMultiply(1 / trainingExamples.size());
    }
-      } catch (Exception e) {
-        throw new WeightLearningException("error during phase 1 of back-propagation algorithm", e);
-      }
+    // TODO : now apply gradient descent (or other optimization/minimization algorithms) with these derivative terms and the cost function
+
+    return null;
  }
-    for (int i = 0; i < triangle.length; i++) {
-      // TODO : introduce regularization diversification on bias term (currently not regularized)
-      triangle[i] = triangle[i].scalarMultiply(1 / trainingExamples.size());
+
+  private RealVector calculateDeltaVector(RealMatrix thetaL, RealMatrix activation, RealVector nextLayerDelta) {
+    ArrayRealVector activationsVector = new ArrayRealVector(activation.getRowVector(0)); // get l-th nn layer activations
+    ArrayRealVector identity = new ArrayRealVector(activationsVector.getDimension(), 1d);
+    RealVector gz = activationsVector.ebeMultiply(identity.subtract(activationsVector)); // = a^l .* (1-a^l)
+    return thetaL.transpose().preMultiply(nextLayerDelta).ebeMultiply(gz);
  }
-    // TODO : now apply gradient descent (or other optimization/minimization algorithms) with these derivative terms and the cost function
+  private RealVector calculateOutputError(TrainingExample<Double, Double[]> trainingExample, RealMatrix[] activations) {
+    RealMatrix output = activations[activations.length - 1];
+    Double[] learnedOutput = trainingExample.getOutput(); // training example output
+    RealVector predictedOutputVector = new ArrayRealVector(output.getColumn(output.getColumnDimension() - 1)); // turn output to vector
+    RealVector learnedOutputRealVector = new ArrayRealVector(learnedOutput); // turn example output to a vector
-    return null;
-  }
+    // TODO : improve error calculation > this should be er_a = out_a * (1 - out_a) * (tgt_a - out_a)
+    return predictedOutputVector.subtract(learnedOutputRealVector);
+  }
}
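
For readers following the refactoring: the commit is reorganizing the standard batch accumulation step of back-propagation. Per training example, feed forward, compute the output delta (predicted minus target), propagate deltas backwards through delta^l = (Theta^l)^T delta^(l+1) .* a^l .* (1 - a^l), accumulate delta^(l+1) (a^l)^T into the "triangle" matrices, then average over the batch. The sketch below restates that scheme in a self-contained form against the Commons Math 3 linear algebra API. It is illustrative only: the class and method names (BackPropSketch, accumulateGradients, sigmoid, delta) are not part of the yay codebase, bias units are omitted, and the weight layout (weights[l] is (size of layer l+1) x (size of layer l)) is an assumed convention, not necessarily yay's. One Java pitfall worth flagging in the committed code: 1 / trainingExamples.size() is integer division and truncates to 0 whenever there is more than one example, so the sketch uses 1d / m to keep the division in floating point.

    import org.apache.commons.math3.linear.Array2DRowRealMatrix;
    import org.apache.commons.math3.linear.ArrayRealVector;
    import org.apache.commons.math3.linear.RealMatrix;
    import org.apache.commons.math3.linear.RealVector;

    public class BackPropSketch {

      // sigmoid applied element-wise: g(z) = 1 / (1 + e^-z)
      private static RealVector sigmoid(RealVector z) {
        double[] out = z.toArray();
        for (int i = 0; i < out.length; i++) {
          out[i] = 1d / (1d + Math.exp(-out[i]));
        }
        return new ArrayRealVector(out, false);
      }

      // delta^l = (Theta^l)^T delta^(l+1) .* a^l .* (1 - a^l)
      private static RealVector delta(RealMatrix thetaL, RealVector aL, RealVector nextDelta) {
        RealVector ones = new ArrayRealVector(aL.getDimension(), 1d);
        RealVector gPrime = aL.ebeMultiply(ones.subtract(aL)); // sigmoid derivative expressed via the activations
        // preMultiply(v) computes v^T * Theta = (Theta^T v)^T, so no explicit transpose is needed
        return thetaL.preMultiply(nextDelta).ebeMultiply(gPrime);
      }

      // one batch of gradient accumulation; weights[l] is (size of layer l+1) x (size of layer l),
      // x[i] holds one example's features and y[i] the matching target
      public static RealMatrix[] accumulateGradients(RealMatrix[] weights, double[][] x, double[][] y) {
        RealMatrix[] triangle = new RealMatrix[weights.length];
        for (int l = 0; l < weights.length; l++) {
          triangle[l] = new Array2DRowRealMatrix(weights[l].getRowDimension(), weights[l].getColumnDimension());
        }
        int m = x.length;
        for (int i = 0; i < m; i++) {
          // feed forward, keeping every layer's activations
          RealVector[] a = new RealVector[weights.length + 1];
          a[0] = new ArrayRealVector(x[i]);
          for (int l = 0; l < weights.length; l++) {
            a[l + 1] = sigmoid(weights[l].operate(a[l]));
          }
          // output error, as in calculateOutputError: predicted minus target
          RealVector d = a[weights.length].subtract(new ArrayRealVector(y[i]));
          // walk the layers backwards, accumulating delta^(l+1) (a^l)^T
          for (int l = weights.length - 1; l >= 0; l--) {
            triangle[l] = triangle[l].add(d.outerProduct(a[l]));
            if (l > 0) {
              d = delta(weights[l], a[l], d);
            }
          }
        }
        for (int l = 0; l < triangle.length; l++) {
          // 1d / m keeps the division in floating point; 1 / m would truncate to 0
          triangle[l] = triangle[l].scalarMultiply(1d / m);
        }
        return triangle;
      }
    }

Calling accumulateGradients(weights, x, y) on a batch yields the averaged partial derivatives that the final TODO in the diff would hand to gradient descent or another minimization routine.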