You are viewing a plain-text version of this content. The link to its canonical version was not preserved in this plain-text rendering.
Posted to commits@labs.apache.org by to...@apache.org on 2012/07/31 00:02:09 UTC

svn commit: r1367334 - in /labs/yay/trunk/core/src/main/java/org/apache/yay: BackPropagationLearningStrategy.java PredictionStrategy.java

Author: tommaso
Date: Mon Jul 30 22:02:09 2012
New Revision: 1367334

URL: http://svn.apache.org/viewvc?rev=1367334&view=rev
Log:
added debug method to PredictionStrategy API

Modified:
    labs/yay/trunk/core/src/main/java/org/apache/yay/BackPropagationLearningStrategy.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/PredictionStrategy.java

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/BackPropagationLearningStrategy.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/BackPropagationLearningStrategy.java?rev=1367334&r1=1367333&r2=1367334&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/BackPropagationLearningStrategy.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/BackPropagationLearningStrategy.java Mon Jul 30 22:02:09 2012
@@ -20,34 +20,46 @@ package org.apache.yay;
 
 import java.util.Collection;
 
+import org.apache.commons.math.linear.ArrayRealVector;
+import org.apache.commons.math.linear.RealMatrix;
+import org.apache.commons.math.linear.RealVector;
+import org.apache.yay.utils.ConversionUtils;
+
 /**
- * Backpropagation learning algorithm for neural networks implementation (see
+ * Back propagation learning algorithm for neural networks implementation (see
  * <code>http://en.wikipedia.org/wiki/Backpropagation</code>).
  */
-public class BackPropagationLearningStrategy implements LearningStrategy<Double, Double> {
+public class BackPropagationLearningStrategy implements LearningStrategy<Double, double[]> {
 
-  private PredictionStrategy<Double,Double> predictionStrategy;
+  private PredictionStrategy<Double, double[]> predictionStrategy;
 
-  public BackPropagationLearningStrategy(PredictionStrategy<Double,Double> predictionStrategy) {
+  public BackPropagationLearningStrategy(PredictionStrategy<Double, double[]> predictionStrategy) {
     this.predictionStrategy = predictionStrategy;
   }
 
   @Override
-  public WeightsMatrix[] learnWeights(WeightsMatrix[] weightsMatrixSet, Collection<TrainingExample<Double, Double>> trainingExamples) throws WeightLearningException {
-    for (TrainingExample<Double, Double> trainingExample : trainingExamples) {
-//      try {
-//        RealMatrix output = predictionStrategy.debugOutput(trainingExample, weightsMatrixSet);
-        Double learnedOutput = trainingExample.getOutput();
-//        Long error = learnedOutput - output;
+  public WeightsMatrix[] learnWeights(WeightsMatrix[] weightsMatrixSet, Collection<TrainingExample<Double, double[]>> trainingExamples) throws WeightLearningException {
+    for (TrainingExample<Double, double[]> trainingExample : trainingExamples) {
+      try {
+        RealMatrix output = predictionStrategy.debugOutput(ConversionUtils.toVector(trainingExample.getFeatureVector()), weightsMatrixSet);
+        double[] learnedOutput = trainingExample.getOutput();
+        RealVector predictedOutputVector = new ArrayRealVector(output.getColumn(output.getColumnDimension() - 1));
+        RealVector learnedOutputRealVector = new ArrayRealVector(learnedOutput);
+
+        RealVector error = predictedOutputVector.subtract(learnedOutputRealVector);
+
         // TODO : back prop the error and update the weights accordingly
-        for (int i = weightsMatrixSet.length; i > 0; i--) {
+        for (int i = weightsMatrixSet.length - 1; i > 0; i--) {
           WeightsMatrix currentMatrix = weightsMatrixSet[i];
-//          currentMatrix.transpose().multiply()
+          ArrayRealVector realVector = new ArrayRealVector(output.getColumn(i));
+          ArrayRealVector identity = new ArrayRealVector(realVector.getDimension(), 1d);
+          RealVector gz = realVector.ebeMultiply(identity.subtract(realVector)); // = a^i .* (1-a^i)
+          RealVector resultingLambdaVector = currentMatrix.transpose().preMultiply(error).ebeMultiply(gz);
         }
 
-//      } catch (PredictionException e) {
-//        throw new WeightLearningException("error during phase 1 of backpropagation algorithm", e);
-//      }
+      } catch (Exception e) {
+        throw new WeightLearningException("error during phase 1 of back-propagation algorithm", e);
+      }
     }
     return null;
   }

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/PredictionStrategy.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/PredictionStrategy.java?rev=1367334&r1=1367333&r2=1367334&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/PredictionStrategy.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/PredictionStrategy.java Mon Jul 30 22:02:09 2012
@@ -20,6 +20,8 @@ package org.apache.yay;
 
 import java.util.Vector;
 
+import org.apache.commons.math.linear.RealMatrix;
+
 /**
  * A {@link PredictionStrategy} defines an algorithm for the prediction of an output of type O given an input of type I
  */
@@ -27,4 +29,6 @@ public interface PredictionStrategy<I, O
 
   public O predictOutput(Vector<I> input, WeightsMatrix[] weightsMatrixSet);
 
+  public RealMatrix debugOutput(Vector<Double> input, WeightsMatrix[] weightsMatrixSet);
+
 }



---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@labs.apache.org
For additional commands, e-mail: commits-help@labs.apache.org