You are viewing a plain-text version of this content. The canonical version is the original mailing-list archive page (its hyperlink was lost in the plain-text conversion).
Posted to commits@labs.apache.org by to...@apache.org on 2013/04/17 09:53:12 UTC
svn commit: r1468788 - in /labs/yay/trunk:
api/src/main/java/org/apache/yay/PredictionStrategy.java
core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java
core/src/test/java/org/apache/yay/core/FeedForwardStrategyTest.java
Author: tommaso
Date: Wed Apr 17 07:53:12 2013
New Revision: 1468788
URL: http://svn.apache.org/r1468788
Log:
fixed the prediction strategy debugging API to return activation vectors instead of matrices
Modified:
labs/yay/trunk/api/src/main/java/org/apache/yay/PredictionStrategy.java
labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java
labs/yay/trunk/core/src/test/java/org/apache/yay/core/FeedForwardStrategyTest.java
Modified: labs/yay/trunk/api/src/main/java/org/apache/yay/PredictionStrategy.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/api/src/main/java/org/apache/yay/PredictionStrategy.java?rev=1468788&r1=1468787&r2=1468788&view=diff
==============================================================================
--- labs/yay/trunk/api/src/main/java/org/apache/yay/PredictionStrategy.java (original)
+++ labs/yay/trunk/api/src/main/java/org/apache/yay/PredictionStrategy.java Wed Apr 17 07:53:12 2013
@@ -21,6 +21,7 @@ package org.apache.yay;
import java.util.Collection;
import org.apache.commons.math3.linear.RealMatrix;
+import org.apache.commons.math3.linear.RealVector;
/**
* A {@link PredictionStrategy} defines an algorithm for the prediction of outputs
@@ -45,6 +46,6 @@ public interface PredictionStrategy<I, O
* @param weightsMatrixSet the initial set of weights defined by an array of matrix
* @return the perturbed neural network state via its activations values
*/
- public RealMatrix[] debugOutput(Collection<I> inputs, RealMatrix[] weightsMatrixSet);
+ public RealVector[] debugOutput(Collection<I> inputs, RealMatrix[] weightsMatrixSet);
}
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java?rev=1468788&r1=1468787&r2=1468788&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java Wed Apr 17 07:53:12 2013
@@ -50,18 +50,17 @@ public class FeedForwardStrategy impleme
@Override
public Double[] predictOutput(Collection<Double> input, RealMatrix[] realMatrixSet) {
- RealMatrix[] realMatrixes = applyFF(input, realMatrixSet);
- RealMatrix x = realMatrixes[realMatrixes.length - 1];
- double[] lastColumn = x.getColumn(x.getColumnDimension() - 1);
- return ConversionUtils.toDoubleArray(lastColumn);
+ RealVector[] activations = applyFF(input, realMatrixSet);
+ RealVector x = activations[activations.length - 1];
+ return ConversionUtils.toDoubleArray(x.toArray());
}
- public RealMatrix[] debugOutput(Collection<Double> input, RealMatrix[] realMatrixSet) {
+ public RealVector[] debugOutput(Collection<Double> input, RealMatrix[] realMatrixSet) {
return applyFF(input, realMatrixSet);
}
- private RealMatrix[] applyFF(Collection<Double> input, RealMatrix[] realMatrixSet) {
- RealMatrix[] debugOutput = new RealMatrix[realMatrixSet.length];
+ private RealVector[] applyFF(Collection<Double> input, RealMatrix[] realMatrixSet) {
+ RealVector[] debugOutput = new ArrayRealVector[realMatrixSet.length];
// TODO : fix this impl as it's very slow
RealVector v = ConversionUtils.toRealVector(input);
@@ -85,7 +84,7 @@ public class FeedForwardStrategy impleme
}
x.setRow(i, finRow);
}
- debugOutput[w] = x;
+ debugOutput[w] = x.getRowVector(0);
}
return debugOutput;
}
Modified: labs/yay/trunk/core/src/test/java/org/apache/yay/core/FeedForwardStrategyTest.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/core/FeedForwardStrategyTest.java?rev=1468788&r1=1468787&r2=1468788&view=diff
==============================================================================
--- labs/yay/trunk/core/src/test/java/org/apache/yay/core/FeedForwardStrategyTest.java (original)
+++ labs/yay/trunk/core/src/test/java/org/apache/yay/core/FeedForwardStrategyTest.java Wed Apr 17 07:53:12 2013
@@ -23,6 +23,7 @@ import java.util.LinkedList;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.RealMatrix;
+import org.apache.commons.math3.linear.RealVector;
import org.junit.Test;
import static junit.framework.Assert.assertNotNull;
@@ -44,7 +45,7 @@ public class FeedForwardStrategyTest {
inputs.add(2d);
inputs.add(-5d);
inputs.add(7d);
- RealMatrix[] activations = feedForwardStrategy.debugOutput(inputs, weights);
+ RealVector[] activations = feedForwardStrategy.debugOutput(inputs, weights);
assertNotNull(activations);
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@labs.apache.org
For additional commands, e-mail: commits-help@labs.apache.org