You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@labs.apache.org by to...@apache.org on 2015/11/13 12:40:47 UTC
svn commit: r1714192 - in /labs/yay/trunk: api/src/main/java/org/apache/yay/
core/src/main/java/org/apache/yay/core/
core/src/main/java/org/apache/yay/core/neuron/
core/src/main/java/org/apache/yay/core/utils/
core/src/test/java/org/apache/yay/core/
Author: tommaso
Date: Fri Nov 13 11:40:47 2015
New Revision: 1714192
URL: http://svn.apache.org/viewvc?rev=1714192&view=rev
Log:
various performance improvements
Modified:
labs/yay/trunk/api/src/main/java/org/apache/yay/ActivationFunction.java
labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java
labs/yay/trunk/core/src/main/java/org/apache/yay/core/BasicPerceptron.java
labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java
labs/yay/trunk/core/src/main/java/org/apache/yay/core/IdentityActivationFunction.java
labs/yay/trunk/core/src/main/java/org/apache/yay/core/SigmoidFunction.java
labs/yay/trunk/core/src/main/java/org/apache/yay/core/SoftmaxActivationFunction.java
labs/yay/trunk/core/src/main/java/org/apache/yay/core/StepActivationFunction.java
labs/yay/trunk/core/src/main/java/org/apache/yay/core/TanhFunction.java
labs/yay/trunk/core/src/main/java/org/apache/yay/core/neuron/RectifiedLinearNeuron.java
labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ConversionUtils.java
labs/yay/trunk/core/src/test/java/org/apache/yay/core/WordVectorsTest.java
Modified: labs/yay/trunk/api/src/main/java/org/apache/yay/ActivationFunction.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/api/src/main/java/org/apache/yay/ActivationFunction.java?rev=1714192&r1=1714191&r2=1714192&view=diff
==============================================================================
--- labs/yay/trunk/api/src/main/java/org/apache/yay/ActivationFunction.java (original)
+++ labs/yay/trunk/api/src/main/java/org/apache/yay/ActivationFunction.java Fri Nov 13 11:40:47 2015
@@ -38,4 +38,13 @@ public interface ActivationFunction<T> {
*/
T apply(RealMatrix weights, T signal);
+ /**
+ * Apply this <code>ActivationFunction</code> to every element of the given matrix of signals, returning the matrix
+ * of transformed signals (implementations may return a copy or the given matrix itself).
+ *
+ * @param weights the matrix of input signals the activation should be applied to
+ * @return the matrix of transformed output signals
+ */
+ RealMatrix applyMatrix(RealMatrix weights);
+
}
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java?rev=1714192&r1=1714191&r2=1714192&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java Fri Nov 13 11:40:47 2015
@@ -83,11 +83,12 @@ public class BackPropagationLearningStra
try {
int iterations = 0;
- NeuralNetwork hypothesis = NeuralNetworkFactory.create(weightsMatrixSet, new VoidLearningStrategy<Double, Double>(), predictionStrategy);
+ NeuralNetwork neuralNetwork = NeuralNetworkFactory.create(weightsMatrixSet, new VoidLearningStrategy<>(), predictionStrategy);
Iterator<TrainingExample<Double, Double>> iterator = trainingExamples.iterator();
double cost = Double.MAX_VALUE;
while (true) {
+ System.err.println(iterations);
TrainingSet<Double, Double> samples;
if (batch == -1) {
samples = trainingExamples;
@@ -103,12 +104,12 @@ public class BackPropagationLearningStra
}
// calculate cost
- double newCost = costFunction.calculateAggregatedCost(samples, hypothesis);
+ double newCost = costFunction.calculateAggregatedCost(samples, neuralNetwork);
if (Double.POSITIVE_INFINITY == newCost || newCost > cost && batch == -1) {
throw new RuntimeException("failed to converge at iteration " + iterations + " with alpha " + alpha + " : cost going from " + cost + " to " + newCost);
} else if (iterations > 1 && (cost == newCost || newCost < threshold || iterations > maxIterations)) {
- System.out.println("successfully converged after " + (iterations - 1) + " iterations (alpha:" + alpha + ",threshold:" + threshold + ") with cost " + newCost + " and parameters " + Arrays.toString(hypothesis.getParameters()));
+ System.out.println("successfully converged after " + (iterations - 1) + " iterations (alpha:" + alpha + ",threshold:" + threshold + ") with cost " + newCost + " and parameters " + Arrays.toString(neuralNetwork.getParameters()));
break;
} else if (Double.isNaN(newCost)) {
throw new RuntimeException("failed to converge at iteration " + iterations + " with alpha " + alpha + " : cost calculation underflow");
@@ -124,7 +125,7 @@ public class BackPropagationLearningStra
updatedWeights = updateWeights(updatedWeights, derivatives, alpha);
// update parameters in the neural network
- hypothesis.setParameters(updatedWeights);
+ neuralNetwork.setParameters(updatedWeights);
iterations++;
}
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/BasicPerceptron.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/BasicPerceptron.java?rev=1714192&r1=1714191&r2=1714192&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/BasicPerceptron.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/BasicPerceptron.java Fri Nov 13 11:40:47 2015
@@ -18,8 +18,7 @@
*/
package org.apache.yay.core;
-import java.util.Collection;
-import org.apache.commons.math3.linear.Array2DRowRealMatrix;
+import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.yay.Input;
import org.apache.yay.LearningException;
@@ -30,6 +29,8 @@ import org.apache.yay.TrainingSet;
import org.apache.yay.core.neuron.BinaryThresholdNeuron;
import org.apache.yay.core.utils.ConversionUtils;
+import java.util.Collection;
+
/**
* A perceptron {@link org.apache.yay.NeuralNetwork} implementation based on
* {@link org.apache.yay.core.neuron.BinaryThresholdNeuron}s
@@ -56,7 +57,7 @@ public class BasicPerceptron implements
for (TrainingExample<Double, Double> example : trainingExamples) {
learn(example);
}
- return new RealMatrix[]{new Array2DRowRealMatrix(currentWeights)};
+ return new RealMatrix[]{MatrixUtils.createRowRealMatrix(currentWeights)};
}
public void learn(TrainingExample<Double, Double> example) {
@@ -87,7 +88,7 @@ public class BasicPerceptron implements
@Override
public RealMatrix[] getParameters() {
- return new RealMatrix[]{new Array2DRowRealMatrix(currentWeights)};
+ return new RealMatrix[]{MatrixUtils.createRowRealMatrix(currentWeights)};
}
@Override
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java?rev=1714192&r1=1714191&r2=1714192&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java Fri Nov 13 11:40:47 2015
@@ -18,9 +18,8 @@
*/
package org.apache.yay.core;
-import org.apache.commons.math3.linear.ArrayRealVector;
+import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.RealMatrix;
-import org.apache.commons.math3.linear.RealMatrixChangingVisitor;
import org.apache.commons.math3.linear.RealVector;
import org.apache.yay.ActivationFunction;
import org.apache.yay.PredictionStrategy;
@@ -29,6 +28,7 @@ import org.apache.yay.core.utils.Convers
import java.util.Collection;
import java.util.HashMap;
import java.util.Map;
+import java.util.stream.Stream;
/**
* Octave code for FF to be converted :
@@ -64,58 +64,23 @@ public class FeedForwardStrategy impleme
}
private RealVector[] applyFF(Collection<Double> input, RealMatrix[] realMatrixSet) {
- RealVector[] debugOutput = new ArrayRealVector[realMatrixSet.length];
+ RealVector[] debugOutput = new RealVector[realMatrixSet.length];
- // TODO : fix this impl as it's very slow
- RealVector v = ConversionUtils.toRealVector(input);
- RealMatrix x = v.outerProduct(new ArrayRealVector(new Double[]{1d})).transpose(); // a 1xN matrix
+ Double[] doubles = input.toArray(new Double[input.size()]);
+ RealMatrix x = MatrixUtils.createRowRealMatrix(Stream.of(doubles).mapToDouble(Double::doubleValue).toArray());
for (int w = 0; w < realMatrixSet.length; w++) {
final RealMatrix currentWeightsMatrix = realMatrixSet[w];
// compute matrix multiplication
x = x.multiply(currentWeightsMatrix.transpose());
- final RealMatrix cm = x.getRowMatrix(0);
-
// apply the activation function to each element in the matrix
+ final RealMatrix cm = x.getRowMatrix(0);
int idx = activationFunctionMap.size() == realMatrixSet.length ? w : 0;
final ActivationFunction<Double> af = activationFunctionMap.get(idx);
-
- if (af instanceof SoftmaxActivationFunction) {
- x = ((SoftmaxActivationFunction) af).applyMatrix(x);
- } else {
- x.walkInOptimizedOrder(new ActivationFunctionVisitor(af, cm));
- }
+ x = af.applyMatrix(cm);
debugOutput[w] = x.getRowVector(0);
}
return debugOutput;
}
- private static class ActivationFunctionVisitor implements RealMatrixChangingVisitor {
-
- private final ActivationFunction<Double> af;
- private final RealMatrix matrix;
-
- ActivationFunctionVisitor(ActivationFunction<Double> af, RealMatrix matrix) {
- this.af = af;
- this.matrix = matrix;
- }
-
- @Override
- public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) {
-
- }
-
- @Override
- public double visit(int row, int column, double value) {
- return af.apply(matrix, value);
- }
-
- @Override
- public double end() {
- return 0;
- }
-
-
- }
-
}
\ No newline at end of file
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/IdentityActivationFunction.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/IdentityActivationFunction.java?rev=1714192&r1=1714191&r2=1714192&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/IdentityActivationFunction.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/IdentityActivationFunction.java Fri Nov 13 11:40:47 2015
@@ -31,4 +31,9 @@ public class IdentityActivationFunction<
return signal;
}
+ @Override
+ public RealMatrix applyMatrix(RealMatrix weights) {
+ return weights;
+ }
+
}
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/SigmoidFunction.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/SigmoidFunction.java?rev=1714192&r1=1714191&r2=1714192&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/SigmoidFunction.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/SigmoidFunction.java Fri Nov 13 11:40:47 2015
@@ -19,6 +19,7 @@
package org.apache.yay.core;
import org.apache.commons.math3.linear.RealMatrix;
+import org.apache.commons.math3.linear.RealMatrixChangingVisitor;
import org.apache.yay.ActivationFunction;
/**
@@ -27,7 +28,32 @@ import org.apache.yay.ActivationFunction
public class SigmoidFunction implements ActivationFunction<Double> {
public Double apply(RealMatrix matrix, final Double input) {
+ return sigmoid(input);
+ }
+
+ private double sigmoid(Double input) {
return 1d / (1d + Math.exp(-1d * input));
}
+ @Override
+ public RealMatrix applyMatrix(RealMatrix weights) {
+ RealMatrix matrix = weights.copy();
+ matrix.walkInOptimizedOrder(new RealMatrixChangingVisitor() {
+ @Override
+ public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) {
+
+ }
+
+ @Override
+ public double visit(int row, int column, double value) {
+ return sigmoid(value);
+ }
+
+ @Override
+ public double end() {
+ return 0;
+ }
+ });
+ return matrix;
+ }
}
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/SoftmaxActivationFunction.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/SoftmaxActivationFunction.java?rev=1714192&r1=1714191&r2=1714192&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/SoftmaxActivationFunction.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/SoftmaxActivationFunction.java Fri Nov 13 11:40:47 2015
@@ -21,50 +21,23 @@ package org.apache.yay.core;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealMatrixChangingVisitor;
import org.apache.commons.math3.linear.RealVector;
-import org.apache.commons.math3.stat.descriptive.rank.Max;
import org.apache.yay.ActivationFunction;
-import java.util.Map;
-import java.util.WeakHashMap;
-
/**
* Softmax activation function
*/
public class SoftmaxActivationFunction implements ActivationFunction<Double> {
- private static final Map<RealMatrix, Double> cache = new WeakHashMap<RealMatrix, Double>();
-
- private static final Max m = new Max();
-
- private static final RealMatrixChangingVisitor expVisitor = new RealMatrixChangingVisitor() {
- @Override
- public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) {
-
- }
-
- @Override
- public double visit(int row, int column, double value) {
- return Math.exp(value);
- }
-
- @Override
- public double end() {
- return 0;
- }
- };
-
@Override
public Double apply(RealMatrix weights, Double signal) {
double num = Math.exp(signal);
- double den = getDen(weights);
+ double den = expDen(weights);
return num / den;
}
public RealMatrix applyMatrix(RealMatrix weights) {
-
RealMatrix matrix = weights.copy();
- double d = expDen(matrix);
- final double finalD = d;
+ final double finalD = expDen(matrix);
matrix.walkInOptimizedOrder(new RealMatrixChangingVisitor() {
@Override
public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) {
@@ -96,14 +69,4 @@ public class SoftmaxActivationFunction i
return d;
}
- private double getDen(RealMatrix weights) {
- Double d = cache.get(weights);
- synchronized (cache) {
- if (d == null) {
- d = expDen(weights.copy());
- cache.put(weights, d);
- }
- }
- return d;
- }
}
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/StepActivationFunction.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/StepActivationFunction.java?rev=1714192&r1=1714191&r2=1714192&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/StepActivationFunction.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/StepActivationFunction.java Fri Nov 13 11:40:47 2015
@@ -19,6 +19,7 @@
package org.apache.yay.core;
import org.apache.commons.math3.linear.RealMatrix;
+import org.apache.commons.math3.linear.RealMatrixChangingVisitor;
import org.apache.yay.ActivationFunction;
/**
@@ -34,7 +35,32 @@ public class StepActivationFunction impl
@Override
public Double apply(RealMatrix matrix, Double signal) {
+ return step(signal);
+ }
+
+ private double step(Double signal) {
return signal >= center ? 1d : 0d;
}
+ @Override
+ public RealMatrix applyMatrix(RealMatrix weights) {
+ RealMatrix matrix = weights.copy();
+ matrix.walkInOptimizedOrder(new RealMatrixChangingVisitor() {
+ @Override
+ public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) {
+
+ }
+
+ @Override
+ public double visit(int row, int column, double value) {
+ return step(value);
+ }
+
+ @Override
+ public double end() {
+ return 0;
+ }
+ });
+ return matrix;
+ }
}
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/TanhFunction.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/TanhFunction.java?rev=1714192&r1=1714191&r2=1714192&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/TanhFunction.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/TanhFunction.java Fri Nov 13 11:40:47 2015
@@ -19,6 +19,7 @@
package org.apache.yay.core;
import org.apache.commons.math3.linear.RealMatrix;
+import org.apache.commons.math3.linear.RealMatrixChangingVisitor;
import org.apache.yay.ActivationFunction;
/**
@@ -29,4 +30,26 @@ public class TanhFunction implements Act
public Double apply(RealMatrix matrix, Double signal) {
return Math.tanh(signal);
}
+
+ @Override
+ public RealMatrix applyMatrix(RealMatrix weights) {
+ RealMatrix matrix = weights.copy();
+ matrix.walkInOptimizedOrder(new RealMatrixChangingVisitor() {
+ @Override
+ public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) {
+
+ }
+
+ @Override
+ public double visit(int row, int column, double value) {
+ return Math.tanh(value);
+ }
+
+ @Override
+ public double end() {
+ return 0;
+ }
+ });
+ return matrix;
+ }
}
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/neuron/RectifiedLinearNeuron.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/neuron/RectifiedLinearNeuron.java?rev=1714192&r1=1714191&r2=1714192&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/neuron/RectifiedLinearNeuron.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/neuron/RectifiedLinearNeuron.java Fri Nov 13 11:40:47 2015
@@ -19,6 +19,7 @@
package org.apache.yay.core.neuron;
import org.apache.commons.math3.linear.RealMatrix;
+import org.apache.commons.math3.linear.RealMatrixChangingVisitor;
import org.apache.yay.ActivationFunction;
/**
@@ -32,8 +33,34 @@ class RectifiedLinearNeuron extends Line
this.activationFunction = new ActivationFunction<Double>() {
@Override
public Double apply(RealMatrix matrix, Double signal) {
- return signal > 0 ? signal : 0;
+ return rect(signal);
+ }
+
+ @Override
+ public RealMatrix applyMatrix(RealMatrix weights) {
+ RealMatrix matrix = weights.copy();
+ matrix.walkInOptimizedOrder(new RealMatrixChangingVisitor() {
+ @Override
+ public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) {
+
+ }
+
+ @Override
+ public double visit(int row, int column, double value) {
+ return rect(value);
+ }
+
+ @Override
+ public double end() {
+ return 0;
+ }
+ });
+ return matrix;
}
};
}
+
+ private double rect(Double signal) {
+ return signal > 0 ? signal : 0;
+ }
}
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ConversionUtils.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ConversionUtils.java?rev=1714192&r1=1714191&r2=1714192&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ConversionUtils.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ConversionUtils.java Fri Nov 13 11:40:47 2015
@@ -18,7 +18,7 @@
*/
package org.apache.yay.core.utils;
-import org.apache.commons.math3.linear.Array2DRowRealMatrix;
+import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.OpenMapRealVector;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
@@ -56,7 +56,7 @@ public class ConversionUtils {
i++;
}
- return new Array2DRowRealMatrix(matrixData);
+ return MatrixUtils.createRealMatrix(matrixData);
}
/**
Modified: labs/yay/trunk/core/src/test/java/org/apache/yay/core/WordVectorsTest.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/core/WordVectorsTest.java?rev=1714192&r1=1714191&r2=1714192&view=diff
==============================================================================
--- labs/yay/trunk/core/src/test/java/org/apache/yay/core/WordVectorsTest.java (original)
+++ labs/yay/trunk/core/src/test/java/org/apache/yay/core/WordVectorsTest.java Fri Nov 13 11:40:47 2015
@@ -19,7 +19,7 @@
package org.apache.yay.core;
import com.google.common.base.Splitter;
-import org.apache.commons.math3.linear.Array2DRowRealMatrix;
+import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.ml.distance.CanberraDistance;
import org.apache.commons.math3.ml.distance.ChebyshevDistance;
@@ -445,7 +445,7 @@ public class WordVectorsTest {
d[k][j] = val;
}
}
- initialWeights[i] = new Array2DRowRealMatrix(d);
+ initialWeights[i] = MatrixUtils.createRealMatrix(d);
}
return initialWeights;
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@labs.apache.org
For additional commands, e-mail: commits-help@labs.apache.org