You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@labs.apache.org by to...@apache.org on 2013/04/17 09:53:12 UTC

svn commit: r1468788 - in /labs/yay/trunk: api/src/main/java/org/apache/yay/PredictionStrategy.java core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java core/src/test/java/org/apache/yay/core/FeedForwardStrategyTest.java

Author: tommaso
Date: Wed Apr 17 07:53:12 2013
New Revision: 1468788

URL: http://svn.apache.org/r1468788
Log:
fixed the PredictionStrategy debugging API (debugOutput) to return activation vectors instead of matrices

Modified:
    labs/yay/trunk/api/src/main/java/org/apache/yay/PredictionStrategy.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java
    labs/yay/trunk/core/src/test/java/org/apache/yay/core/FeedForwardStrategyTest.java

Modified: labs/yay/trunk/api/src/main/java/org/apache/yay/PredictionStrategy.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/api/src/main/java/org/apache/yay/PredictionStrategy.java?rev=1468788&r1=1468787&r2=1468788&view=diff
==============================================================================
--- labs/yay/trunk/api/src/main/java/org/apache/yay/PredictionStrategy.java (original)
+++ labs/yay/trunk/api/src/main/java/org/apache/yay/PredictionStrategy.java Wed Apr 17 07:53:12 2013
@@ -21,6 +21,7 @@ package org.apache.yay;
 import java.util.Collection;
 
 import org.apache.commons.math3.linear.RealMatrix;
+import org.apache.commons.math3.linear.RealVector;
 
 /**
  * A {@link PredictionStrategy} defines an algorithm for the prediction of outputs
@@ -45,6 +46,6 @@ public interface PredictionStrategy<I, O
    * @param weightsMatrixSet the initial set of weights defined by an array of matrix
    * @return the perturbed neural network state via its activations values
    */
-  public RealMatrix[] debugOutput(Collection<I> inputs, RealMatrix[] weightsMatrixSet);
+  public RealVector[] debugOutput(Collection<I> inputs, RealMatrix[] weightsMatrixSet);
 
 }

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java?rev=1468788&r1=1468787&r2=1468788&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java Wed Apr 17 07:53:12 2013
@@ -50,18 +50,17 @@ public class FeedForwardStrategy impleme
 
   @Override
   public Double[] predictOutput(Collection<Double> input, RealMatrix[] realMatrixSet) {
-    RealMatrix[] realMatrixes = applyFF(input, realMatrixSet);
-    RealMatrix x = realMatrixes[realMatrixes.length - 1];
-    double[] lastColumn = x.getColumn(x.getColumnDimension() - 1);
-    return ConversionUtils.toDoubleArray(lastColumn);
+    RealVector[] activations = applyFF(input, realMatrixSet);
+    RealVector x = activations[activations.length - 1];
+    return ConversionUtils.toDoubleArray(x.toArray());
   }
 
-  public RealMatrix[] debugOutput(Collection<Double> input, RealMatrix[] realMatrixSet) {
+  public RealVector[] debugOutput(Collection<Double> input, RealMatrix[] realMatrixSet) {
     return applyFF(input, realMatrixSet);
   }
 
-  private RealMatrix[] applyFF(Collection<Double> input, RealMatrix[] realMatrixSet) {
-    RealMatrix[] debugOutput = new RealMatrix[realMatrixSet.length];
+  private RealVector[] applyFF(Collection<Double> input, RealMatrix[] realMatrixSet) {
+    RealVector[] debugOutput = new ArrayRealVector[realMatrixSet.length];
 
     // TODO : fix this impl as it's very slow
     RealVector v = ConversionUtils.toRealVector(input);
@@ -85,7 +84,7 @@ public class FeedForwardStrategy impleme
         }
         x.setRow(i, finRow);
       }
-      debugOutput[w] = x;
+      debugOutput[w] = x.getRowVector(0);
     }
     return debugOutput;
   }

Modified: labs/yay/trunk/core/src/test/java/org/apache/yay/core/FeedForwardStrategyTest.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/core/FeedForwardStrategyTest.java?rev=1468788&r1=1468787&r2=1468788&view=diff
==============================================================================
--- labs/yay/trunk/core/src/test/java/org/apache/yay/core/FeedForwardStrategyTest.java (original)
+++ labs/yay/trunk/core/src/test/java/org/apache/yay/core/FeedForwardStrategyTest.java Wed Apr 17 07:53:12 2013
@@ -23,6 +23,7 @@ import java.util.LinkedList;
 
 import org.apache.commons.math3.linear.Array2DRowRealMatrix;
 import org.apache.commons.math3.linear.RealMatrix;
+import org.apache.commons.math3.linear.RealVector;
 import org.junit.Test;
 
 import static junit.framework.Assert.assertNotNull;
@@ -44,7 +45,7 @@ public class FeedForwardStrategyTest {
     inputs.add(2d);
     inputs.add(-5d);
     inputs.add(7d);
-    RealMatrix[] activations = feedForwardStrategy.debugOutput(inputs, weights);
+    RealVector[] activations = feedForwardStrategy.debugOutput(inputs, weights);
     assertNotNull(activations);
   }
 }



---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@labs.apache.org
For additional commands, e-mail: commits-help@labs.apache.org