You are viewing a plain-text version of this content. (The hyperlink to the canonical version was not preserved in this copy.)
Posted to commits@labs.apache.org by to...@apache.org on 2012/07/31 00:02:09 UTC
svn commit: r1367334 - in /labs/yay/trunk/core/src/main/java/org/apache/yay:
BackPropagationLearningStrategy.java PredictionStrategy.java
Author: tommaso
Date: Mon Jul 30 22:02:09 2012
New Revision: 1367334
URL: http://svn.apache.org/viewvc?rev=1367334&view=rev
Log:
added debug method to PredictionStrategy API
Modified:
labs/yay/trunk/core/src/main/java/org/apache/yay/BackPropagationLearningStrategy.java
labs/yay/trunk/core/src/main/java/org/apache/yay/PredictionStrategy.java
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/BackPropagationLearningStrategy.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/BackPropagationLearningStrategy.java?rev=1367334&r1=1367333&r2=1367334&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/BackPropagationLearningStrategy.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/BackPropagationLearningStrategy.java Mon Jul 30 22:02:09 2012
@@ -20,34 +20,46 @@ package org.apache.yay;
import java.util.Collection;
+import org.apache.commons.math.linear.ArrayRealVector;
+import org.apache.commons.math.linear.RealMatrix;
+import org.apache.commons.math.linear.RealVector;
+import org.apache.yay.utils.ConversionUtils;
+
/**
- * Backpropagation learning algorithm for neural networks implementation (see
+ * Back propagation learning algorithm for neural networks implementation (see
* <code>http://en.wikipedia.org/wiki/Backpropagation</code>).
*/
-public class BackPropagationLearningStrategy implements LearningStrategy<Double, Double> {
+public class BackPropagationLearningStrategy implements LearningStrategy<Double, double[]> {
- private PredictionStrategy<Double,Double> predictionStrategy;
+ private PredictionStrategy<Double, double[]> predictionStrategy;
- public BackPropagationLearningStrategy(PredictionStrategy<Double,Double> predictionStrategy) {
+ public BackPropagationLearningStrategy(PredictionStrategy<Double, double[]> predictionStrategy) {
this.predictionStrategy = predictionStrategy;
}
@Override
- public WeightsMatrix[] learnWeights(WeightsMatrix[] weightsMatrixSet, Collection<TrainingExample<Double, Double>> trainingExamples) throws WeightLearningException {
- for (TrainingExample<Double, Double> trainingExample : trainingExamples) {
-// try {
-// RealMatrix output = predictionStrategy.debugOutput(trainingExample, weightsMatrixSet);
- Double learnedOutput = trainingExample.getOutput();
-// Long error = learnedOutput - output;
+ public WeightsMatrix[] learnWeights(WeightsMatrix[] weightsMatrixSet, Collection<TrainingExample<Double, double[]>> trainingExamples) throws WeightLearningException {
+ for (TrainingExample<Double, double[]> trainingExample : trainingExamples) {
+ try {
+ RealMatrix output = predictionStrategy.debugOutput(ConversionUtils.toVector(trainingExample.getFeatureVector()), weightsMatrixSet);
+ double[] learnedOutput = trainingExample.getOutput();
+ RealVector predictedOutputVector = new ArrayRealVector(output.getColumn(output.getColumnDimension() - 1));
+ RealVector learnedOutputRealVector = new ArrayRealVector(learnedOutput);
+
+ RealVector error = predictedOutputVector.subtract(learnedOutputRealVector);
+
// TODO : back prop the error and update the weights accordingly
- for (int i = weightsMatrixSet.length; i > 0; i--) {
+ for (int i = weightsMatrixSet.length - 1; i > 0; i--) {
WeightsMatrix currentMatrix = weightsMatrixSet[i];
-// currentMatrix.transpose().multiply()
+ ArrayRealVector realVector = new ArrayRealVector(output.getColumn(i));
+ ArrayRealVector identity = new ArrayRealVector(realVector.getDimension(), 1d);
+ RealVector gz = realVector.ebeMultiply(identity.subtract(realVector)); // = a^i .* (1-a^i)
+ RealVector resultingLambdaVector = currentMatrix.transpose().preMultiply(error).ebeMultiply(gz);
}
-// } catch (PredictionException e) {
-// throw new WeightLearningException("error during phase 1 of backpropagation algorithm", e);
-// }
+ } catch (Exception e) {
+ throw new WeightLearningException("error during phase 1 of back-propagation algorithm", e);
+ }
}
return null;
}
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/PredictionStrategy.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/PredictionStrategy.java?rev=1367334&r1=1367333&r2=1367334&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/PredictionStrategy.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/PredictionStrategy.java Mon Jul 30 22:02:09 2012
@@ -20,6 +20,8 @@ package org.apache.yay;
import java.util.Vector;
+import org.apache.commons.math.linear.RealMatrix;
+
/**
* A {@link PredictionStrategy} defines an algorithm for the prediction of an output of type O given an input of type I
*/
@@ -27,4 +29,6 @@ public interface PredictionStrategy<I, O
public O predictOutput(Vector<I> input, WeightsMatrix[] weightsMatrixSet);
+ public RealMatrix debugOutput(Vector<Double> input, WeightsMatrix[] weightsMatrixSet);
+
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@labs.apache.org
For additional commands, e-mail: commits-help@labs.apache.org