Posted to commits@labs.apache.org by to...@apache.org on 2013/04/17 16:48:26 UTC

svn commit: r1468944 - in /labs/yay/trunk/core/src: main/java/org/apache/yay/core/BackPropagationLearningStrategy.java test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java

Author: tommaso
Date: Wed Apr 17 14:48:26 2013
New Revision: 1468944

URL: http://svn.apache.org/r1468944
Log:
started cleaning backpropagation phase 1
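
For reference, these are the per-example updates the code below works toward, in the Ng-style notation the source comments already hint at (a sketch of the standard equations, not taken verbatim from the code: \Theta^l are the layer-l weights, a^l the activations, y the one-hot target, m the number of training examples; the 'triangle' accumulator plays the role of \Delta):

    \delta^{L} = a^{L} - y
    \delta^{l} = (\Theta^{l})^{T}\,\delta^{l+1} \odot a^{l} \odot (1 - a^{l})
    \Delta^{l} \leftarrow \Delta^{l} + \delta^{l+1}\,(a^{l})^{T}
    D^{l} = \frac{1}{m}\,\Delta^{l}

The diff below computes the first three steps once per training example and does the 1/m averaging afterwards.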

Added:
    labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java   (with props)
Modified:
    labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java?rev=1468944&r1=1468943&r2=1468944&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java Wed Apr 17 14:48:26 2013
@@ -53,36 +53,28 @@ public class BackPropagationLearningStra
     for (TrainingExample<Double, Double> trainingExample : trainingExamples) {
       try {
         // contains activation errors for the current training example
-        // TODO : check if this should be RealVector[] < probably yes
-        RealMatrix[] activationErrors = new RealMatrix[weightsMatrixSet.length - 1];
+        int noOfMatrixes = weightsMatrixSet.length - 1;
 
         // feed forward propagation
-        RealMatrix[] activations = predictionStrategy.debugOutput(ConversionUtils.toValuesCollection(trainingExample.getFeatures()), weightsMatrixSet);
+        RealVector[] activations = predictionStrategy.debugOutput(ConversionUtils.toValuesCollection(trainingExample.getFeatures()), weightsMatrixSet);
 
         // calculate output error
         RealVector error = calculateOutputError(trainingExample, activations);
 
-        activationErrors[activationErrors.length - 1] = new Array2DRowRealMatrix(error.toArray());
+        RealVector nextLayerDelta = error;
 
-        RealVector nextLayerDelta = new ArrayRealVector(error);
+        triangle[noOfMatrixes] = new Array2DRowRealMatrix(weightsMatrixSet[noOfMatrixes].getColumnDimension(), weightsMatrixSet[noOfMatrixes].getRowDimension());
+        triangle[noOfMatrixes] = triangle[noOfMatrixes].add(activations[noOfMatrixes - 1].outerProduct(error));
 
-        // back prop the error and update the activationErrors accordingly
+        // back prop the error and update the deltas accordingly
         // TODO : eventually remove the bias term from the error calculations
         for (int l = weightsMatrixSet.length - 2; l >= 0; l--) {
           RealVector resultingDeltaVector = calculateDeltaVector(weightsMatrixSet[l], activations[l], nextLayerDelta);
-          if (activationErrors[l] == null) {
-            activationErrors[l] = new Array2DRowRealMatrix(new ArrayRealVector(resultingDeltaVector.getDimension(), 1d).toArray());
-          }
-          activationErrors[l] = new Array2DRowRealMatrix(resultingDeltaVector.toArray());
-          nextLayerDelta = resultingDeltaVector;
-        }
 
-        // update the accumulator matrix
-        for (int l = 0; l < triangle.length - 1; l++) {
           if (triangle[l] == null) {
             triangle[l] = new Array2DRowRealMatrix(weightsMatrixSet[l].getRowDimension(), weightsMatrixSet[l].getColumnDimension());
           }
-          triangle[l] = triangle[l].add(activationErrors[l + 1].getRowVector(0).outerProduct(activations[l].getRowVector(0)));
+          triangle[l] = triangle[l].add(resultingDeltaVector.outerProduct(activations[l-1]));
         }
 
       } catch (Exception e) {
@@ -95,27 +87,34 @@ public class BackPropagationLearningStra
       triangle[i] = triangle[i].scalarMultiply(1 / count);
     }
 
-    // TODO : now apply gradient descent (or other optimization/minimization algorithms) with this derivative terms and the cost function
+    // TODO : now apply gradient descent (or other optimization/minimization algorithms) with 'triangle' derivative terms and the cost function
 
     return null;
   }
 
-  private RealVector calculateDeltaVector(RealMatrix thetaL, RealMatrix activation, RealVector nextLayerDelta) {
-    ArrayRealVector activationsVector = new ArrayRealVector(activation.getRowVector(0)); // get l-th nn layer activations
+  private RealVector calculateDeltaVector(RealMatrix thetaL, RealVector activationsVector, RealVector nextLayerDelta) {
     ArrayRealVector identity = new ArrayRealVector(activationsVector.getDimension(), 1d);
     RealVector gz = activationsVector.ebeMultiply(identity.subtract(activationsVector)); // = a^l .* (1-a^l)
-    return thetaL.transpose().preMultiply(nextLayerDelta).ebeMultiply(gz);
+    return thetaL.transpose().preMultiply(nextLayerDelta).ebeMultiply(gz); // TODO : fix this as it seems wrong
   }
 
-  private RealVector calculateOutputError(TrainingExample<Double, Double> trainingExample, RealMatrix[] activations) {
-    RealMatrix output = activations[activations.length - 1];
-    RealVector predictedOutputVector = new ArrayRealVector(output.getColumn(output.getColumnDimension() - 1)); // turn output to vector
-
-    Double[] learnedOutput = new Double[predictedOutputVector.getDimension()]; // training example output
-    learnedOutput[trainingExample.getOutput().intValue()] = 1d;
-    RealVector learnedOutputRealVector = new ArrayRealVector(learnedOutput); // turn example output to a vector
+  private RealVector calculateOutputError(TrainingExample<Double, Double> trainingExample, RealVector[] activations) {
+    RealVector output = activations[activations.length - 1];
+
+    Double[] sampleOutput = new Double[output.getDimension()];
+    int sampleOutputIntValue = trainingExample.getOutput().intValue();
+    if (sampleOutputIntValue < sampleOutput.length) {
+      sampleOutput[sampleOutputIntValue] = 1d;
+    }
+    else if (sampleOutput.length == 1){
+      sampleOutput[0] = trainingExample.getOutput();
+    }
+    else {
+      throw new RuntimeException("problem with multiclass output mapping");
+    }
+    RealVector learnedOutputRealVector = new ArrayRealVector(sampleOutput); // turn example output to a vector
 
-    // TODO : improve error calculation > this should be er_a = out_a * (1 - out_a) * (tgt_a - out_a)
-    return predictedOutputVector.subtract(learnedOutputRealVector);
+    // TODO : improve error calculation > this could be er_a = out_a * (1 - out_a) * (tgt_a - out_a)
+    return output.subtract(learnedOutputRealVector);
   }
 }
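
A minimal sketch (not part of the commit) of the hidden-layer delta that calculateDeltaVector aims to compute, using the same commons-math3 types as the code above. In commons-math3, RealMatrix.preMultiply(RealVector v) returns v premultiplying the matrix, i.e. M^T v, so calling it on thetaL.transpose() yields \Theta^l \delta^{l+1} rather than the conventional (\Theta^l)^T \delta^{l+1}; that mismatch is likely what the trailing "fix this as it seems wrong" TODO refers to. Note also that, as committed, the backward loop never reassigns nextLayerDelta and reads activations[l-1] at l == 0, which indexes out of bounds; both land in the surrounding catch block.

    import org.apache.commons.math3.linear.ArrayRealVector;
    import org.apache.commons.math3.linear.RealMatrix;
    import org.apache.commons.math3.linear.RealVector;

    // Sketch only: delta^l = (Theta^l)^T delta^{l+1} .* a^l .* (1 - a^l)
    class DeltaVectorSketch {
      static RealVector calculateDeltaVector(RealMatrix thetaL, RealVector activationsVector,
                                             RealVector nextLayerDelta) {
        RealVector ones = new ArrayRealVector(activationsVector.getDimension(), 1d);
        // sigmoid derivative expressed through the activations: g'(z^l) = a^l .* (1 - a^l)
        RealVector gz = activationsVector.ebeMultiply(ones.subtract(activationsVector));
        // operate(...) on the transpose computes (Theta^l)^T * delta^{l+1};
        // preMultiply(...) on the transpose would compute Theta^l * delta^{l+1} instead
        return thetaL.transpose().operate(nextLayerDelta).ebeMultiply(gz);
      }
    }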

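One more hedged note on the averaging step above: the declaration of count is not visible in these hunks, but if it is an integral type, 1 / count is integer division and evaluates to 0 whenever count > 1, zeroing every accumulated gradient. Forcing the division into floating point avoids that:

    // Hypothetical variant of the averaging loop; 'triangle' and 'count' as in the diff above.
    for (int i = 0; i < triangle.length; i++) {
      triangle[i] = triangle[i].scalarMultiply(1d / count); // 1d, not 1, so the division happens in double
    }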
Added: labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java?rev=1468944&view=auto
==============================================================================
--- labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java (added)
+++ labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java Wed Apr 17 14:48:26 2013
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *  http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.yay.core;
+
+import java.util.ArrayList;
+import java.util.Collection;
+
+import org.apache.commons.math3.linear.Array2DRowRealMatrix;
+import org.apache.commons.math3.linear.RealMatrix;
+import org.apache.yay.PredictionStrategy;
+import org.apache.yay.TrainingExample;
+import org.apache.yay.TrainingSet;
+import org.apache.yay.core.utils.ExamplesFactory;
+import org.junit.Test;
+
+import static junit.framework.Assert.assertNotNull;
+
+/**
+ * Testcase for {@link org.apache.yay.core.BackPropagationLearningStrategy}
+ */
+public class BackPropagationLearningStrategyTest {
+
+  @Test
+  public void testLearningWithRandomSamples() throws Exception {
+    PredictionStrategy<Double, Double> predictionStrategy = new FeedForwardStrategy(new SigmoidFunction());
+    BackPropagationLearningStrategy backPropagationLearningStrategy =
+            new BackPropagationLearningStrategy(predictionStrategy, new LogisticRegressionCostFunction(0.4d));
+
+    RealMatrix[] initialWeights = new RealMatrix[3];
+    initialWeights[0] = new Array2DRowRealMatrix(new double[][]{{0d, 0d, 0d}, {1d, 0.6d, 3d}, {1d, 2d, 2d}, {1d, 0.6d, 3d}});
+    initialWeights[1] = new Array2DRowRealMatrix(new double[][]{{0d, 0d, 0d, 0d}, {0d, 0.5d, 1d, 0.5d}, {0d, 0.1d, 8d, 0.1d}, {0d, 0.1d, 8d, 0.2d}});
+    initialWeights[2] = new Array2DRowRealMatrix(new double[][]{{0.7d, 2d, 0.3d, 0.5d}});
+
+    Collection<TrainingExample<Double, Double>> samples = createSamples(100, 2);
+    TrainingSet<Double, Double> trainingSet = new TrainingSet<Double, Double>(samples);
+    RealMatrix[] weights = backPropagationLearningStrategy.learnWeights(initialWeights, trainingSet);
+    assertNotNull(weights);
+  }
+
+  private Collection<TrainingExample<Double, Double>> createSamples(int size, int noOfFeatures) {
+    Collection<TrainingExample<Double, Double>> trainingExamples = new ArrayList<TrainingExample<Double, Double>>(size);
+    Double[] featureValues = new Double[noOfFeatures];
+    for (int i = 0; i < size; i++) {
+        for (int j = 0; j < noOfFeatures; j++) {
+            featureValues[j] = Math.random();
+        }
+        trainingExamples.add(ExamplesFactory.createDoubleTrainingExample(1d, featureValues));
+    }
+    return trainingExamples;
+  }
+}
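
A side note on the test helper just added: featureValues is allocated once and refilled on every iteration, so if ExamplesFactory.createDoubleTrainingExample keeps a reference to the array rather than copying it (not visible in this commit), all one hundred examples end up sharing the last sample's features. A minimal variant under that assumption, using the same imports as the test above, allocates per example:

    // Sketch assuming createDoubleTrainingExample stores the array by reference.
    private Collection<TrainingExample<Double, Double>> createSamples(int size, int noOfFeatures) {
      Collection<TrainingExample<Double, Double>> trainingExamples =
              new ArrayList<TrainingExample<Double, Double>>(size);
      for (int i = 0; i < size; i++) {
        Double[] featureValues = new Double[noOfFeatures]; // fresh array per example
        for (int j = 0; j < noOfFeatures; j++) {
          featureValues[j] = Math.random();
        }
        trainingExamples.add(ExamplesFactory.createDoubleTrainingExample(1d, featureValues));
      }
      return trainingExamples;
    }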

Propchange: labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java
------------------------------------------------------------------------------
    svn:eol-style = native


