You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@labs.apache.org by to...@apache.org on 2015/10/08 17:03:49 UTC
svn commit: r1707564 - in /labs/yay/trunk/core/src:
main/java/org/apache/yay/core/ test/java/org/apache/yay/core/
Author: tommaso
Date: Thu Oct 8 15:03:49 2015
New Revision: 1707564
URL: http://svn.apache.org/viewvc?rev=1707564&view=rev
Log:
Reduced boilerplate code in the feed-forward strategy for applying the activation function, added layer-specific activation functions, and improved the error derivative calculation.
Modified:
labs/yay/trunk/core/src/main/java/org/apache/yay/core/DefaultDerivativeUpdateFunction.java
labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java
labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkIntegrationTest.java
labs/yay/trunk/core/src/test/java/org/apache/yay/core/Word2VecTest.java
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/DefaultDerivativeUpdateFunction.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/DefaultDerivativeUpdateFunction.java?rev=1707564&r1=1707563&r2=1707564&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/DefaultDerivativeUpdateFunction.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/DefaultDerivativeUpdateFunction.java Thu Oct 8 15:03:49 2015
@@ -18,6 +18,8 @@
*/
package org.apache.yay.core;
+import java.util.Arrays;
+
import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
@@ -107,19 +109,13 @@ class DefaultDerivativeUpdateFunction im
private RealVector calculateOutputError(TrainingExample<Double, Double> trainingExample, RealVector[] activations) {
RealVector output = activations[activations.length - 1];
-// Double[] sampleOutput = new Double[output.getDimension()];
Double[] actualOutput = trainingExample.getOutput();
-// int sampleOutputIntValue = actualOutput.intValue();
-// if (sampleOutputIntValue < sampleOutput.length) {
-// sampleOutput[sampleOutputIntValue] = 1d;
-// } else if (sampleOutput.length == 1) {
-// sampleOutput[0] = actualOutput;
-// } else {
-// throw new RuntimeException("problem with multiclass output mapping");
-// }
RealVector learnedOutputRealVector = new ArrayRealVector(actualOutput); // turn example output to a vector
- // TODO : improve error calculation -> this could be er_a = out_a * (1 - out_a) * (tgt_a - out_a)
- return output.subtract(learnedOutputRealVector);
+ double[] ones = new double[output.getDimension()];
+ Arrays.fill(ones, 1d);
+
+ // error calculation -> er_a = out_a * (1 - out_a) * (out_a - tgt_a) (was: output.subtract(learnedOutputRealVector))
+ return output.ebeMultiply(new ArrayRealVector(ones).subtract(output)).ebeMultiply(output.subtract(learnedOutputRealVector));
}
}
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java?rev=1707564&r1=1707563&r2=1707564&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java Thu Oct 8 15:03:49 2015
@@ -20,10 +20,14 @@ package org.apache.yay.core;
import java.util.ArrayList;
import java.util.Collection;
+import java.util.HashMap;
+import java.util.Map;
+
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.collections.Transformer;
import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.commons.math3.linear.RealMatrix;
+import org.apache.commons.math3.linear.RealMatrixChangingVisitor;
import org.apache.commons.math3.linear.RealVector;
import org.apache.yay.ActivationFunction;
import org.apache.yay.PredictionStrategy;
@@ -40,10 +44,15 @@ import org.apache.yay.core.utils.Convers
*/
public class FeedForwardStrategy implements PredictionStrategy<Double, Double> {
- private final ActivationFunction<Double> activationFunction;
+ private final Map<Integer, ActivationFunction<Double>> activationFunctionMap;
public FeedForwardStrategy(ActivationFunction<Double> activationFunction) {
- this.activationFunction = activationFunction;
+ this.activationFunctionMap = new HashMap<Integer, ActivationFunction<Double>>();
+ this.activationFunctionMap.put(0, activationFunction);
+ }
+
+ public FeedForwardStrategy(Map<Integer, ActivationFunction<Double>> activationFunctionMap) {
+ this.activationFunctionMap = activationFunctionMap;
}
@Override
@@ -69,32 +78,27 @@ public class FeedForwardStrategy impleme
x = x.multiply(currentWeightsMatrix.transpose());
// apply the activation function to each element in the matrix
- for (int i = 0; i < x.getRowDimension(); i++) {
- double[] doubles = x.getRow(i);
- final ArrayList<Double> row = new ArrayList<Double>(doubles.length);
- for (int j = 0; j < doubles.length; j++) {
- row.add(j, doubles[j]);
+ int idx = activationFunctionMap.size() == realMatrixSet.length ? w : 0;
+ final ActivationFunction<Double> af = activationFunctionMap.get(idx);
+ x.walkInRowOrder(new RealMatrixChangingVisitor() {
+ @Override
+ public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) {
+
+ }
+
+ @Override
+ public double visit(int row, int column, double value) {
+ return af.apply(value);
}
- // TODO : see if bias term is handled correctly here
- CollectionUtils.transform(row, new ActivationRowTransformer());
- double[] finRow = new double[row.size()];
- for (int h = 0; h < finRow.length; h++) {
- finRow[h] = row.get(h);
+
+ @Override
+ public double end() {
+ return 0;
}
- x.setRow(i, finRow);
- }
+ });
debugOutput[w] = x.getRowVector(0);
}
return debugOutput;
}
- private class ActivationRowTransformer implements Transformer {
- @Override
- public Object transform(Object input) {
- assert input instanceof Double;
- final Double d = (Double) input;
- return activationFunction.apply(d);
- }
- }
-
}
\ No newline at end of file
Modified: labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkIntegrationTest.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkIntegrationTest.java?rev=1707564&r1=1707563&r2=1707564&view=diff
==============================================================================
--- labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkIntegrationTest.java (original)
+++ labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkIntegrationTest.java Thu Oct 8 15:03:49 2015
@@ -18,24 +18,21 @@
*/
package org.apache.yay.core;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.Collection;
-import java.util.Random;
-
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.RealMatrix;
-import org.apache.yay.CreationException;
-import org.apache.yay.Feature;
-import org.apache.yay.Input;
-import org.apache.yay.LearningStrategy;
-import org.apache.yay.NeuralNetwork;
-import org.apache.yay.TrainingExample;
-import org.apache.yay.TrainingSet;
+import org.apache.commons.math3.ml.distance.CanberraDistance;
+import org.apache.commons.math3.ml.distance.DistanceMeasure;
+import org.apache.commons.math3.ml.distance.EuclideanDistance;
+import org.apache.yay.*;
import org.apache.yay.core.utils.ExamplesFactory;
import org.junit.Test;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Random;
+
import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
/**
* Integration test for NN
@@ -134,15 +131,20 @@ public class NeuralNetworkIntegrationTes
int noOfFeatures = randomWeights[0].getColumnDimension() - 1;
Collection<TrainingExample<Double, Double>> samples = createSamples(1000000, noOfFeatures, noOfOutputs);
nn.learn(new TrainingSet<Double, Double>(samples));
+ DistanceMeasure distanceMeasure = new CanberraDistance();
for (TrainingExample<Double, Double> sample : samples) {
Double[] predictedOutput = nn.predict(sample);
+ double[] a1 = new double[predictedOutput.length];
+ for (int i = 0; i < a1.length; i++) {
+ a1[i] = predictedOutput[i];
+ }
Double[] expectedOutput = sample.getOutput();
- boolean equals = Arrays.equals(expectedOutput, predictedOutput);
-// if (!equals) {
-// System.err.println(Arrays.toString(expectedOutput) + " vs " + Arrays.toString(predictedOutput));
-// } else {
-// System.err.println("equals!");
-// }
+ double[] a2 = new double[expectedOutput.length];
+ for (int i = 0; i < a2.length; i++) {
+ a2[i] = expectedOutput[i];
+ }
+ double dist = distanceMeasure.compute(a1, a2);
+ assertTrue("expected and actual outputs are distant " + dist, dist < 10d);
}
}
Modified: labs/yay/trunk/core/src/test/java/org/apache/yay/core/Word2VecTest.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/core/Word2VecTest.java?rev=1707564&r1=1707563&r2=1707564&view=diff
==============================================================================
--- labs/yay/trunk/core/src/test/java/org/apache/yay/core/Word2VecTest.java (original)
+++ labs/yay/trunk/core/src/test/java/org/apache/yay/core/Word2VecTest.java Thu Oct 8 15:03:49 2015
@@ -31,21 +31,21 @@ import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
+import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
+import java.util.Map;
import java.util.Random;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.RealMatrix;
-import org.apache.commons.math3.linear.SingularValueDecomposition;
+import org.apache.yay.ActivationFunction;
import org.apache.yay.Feature;
import org.apache.yay.Input;
import org.apache.yay.NeuralNetwork;
import org.apache.yay.TrainingExample;
import org.apache.yay.TrainingSet;
-import org.apache.yay.core.utils.ConversionUtils;
-import org.apache.yay.core.utils.ExamplesFactory;
import org.junit.Test;
import static org.junit.Assert.*;
@@ -70,13 +70,17 @@ public class Word2VecTest {
TrainingExample<Double, Double> next = trainingSet.iterator().next();
int inputSize = next.getFeatures().size() ;
int outputSize = next.getOutput().length;
- int n = new Random().nextInt(20);
+ int n = new Random().nextInt(20) + 5;
RealMatrix[] randomWeights = createRandomWeights(inputSize, n, outputSize);
- FeedForwardStrategy predictionStrategy = new FeedForwardStrategy(new IdentityActivationFunction<Double>());
- BackPropagationLearningStrategy learningStrategy = new BackPropagationLearningStrategy(0.0005d, -1,
+ Map<Integer, ActivationFunction<Double>> activationFunctions = new HashMap<Integer, ActivationFunction<Double>>();
+ activationFunctions.put(0, new IdentityActivationFunction<Double>());
+ // TODO : place a softmax activation for the output layer
+ activationFunctions.put(0, new IdentityActivationFunction<Double>());
+ FeedForwardStrategy predictionStrategy = new FeedForwardStrategy(activationFunctions);
+ BackPropagationLearningStrategy learningStrategy = new BackPropagationLearningStrategy(0.05d, 10,
BackPropagationLearningStrategy.DEFAULT_THRESHOLD, predictionStrategy, new LMSCostFunction(),
- 30);
+ 80);
NeuralNetwork neuralNetwork = NeuralNetworkFactory.create(randomWeights, learningStrategy, predictionStrategy);
neuralNetwork.learn(trainingSet);
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@labs.apache.org
For additional commands, e-mail: commits-help@labs.apache.org