You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@labs.apache.org by to...@apache.org on 2016/03/02 12:13:32 UTC
svn commit: r1733257 - in /labs/yay/trunk/core/src:
main/java/org/apache/yay/HotEncodedSample.java
main/java/org/apache/yay/SkipGramNetwork.java
test/java/org/apache/yay/SkipGramNetworkTest.java
Author: tommaso
Date: Wed Mar 2 11:13:32 2016
New Revision: 1733257
URL: http://svn.apache.org/viewvc?rev=1733257&view=rev
Log:
biases initialized to 0.01, per output word softmax, cached expanded representation in hot encoded samples, 0.5 decay on learning rate
Modified:
labs/yay/trunk/core/src/main/java/org/apache/yay/HotEncodedSample.java
labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java
labs/yay/trunk/core/src/test/java/org/apache/yay/SkipGramNetworkTest.java
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/HotEncodedSample.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/HotEncodedSample.java?rev=1733257&r1=1733256&r2=1733257&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/HotEncodedSample.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/HotEncodedSample.java Wed Mar 2 11:13:32 2016
@@ -25,6 +25,9 @@ import java.util.Arrays;
*/
public class HotEncodedSample extends Sample {
+ private double[] expandedInputs = null;
+ private double[] expandedOutputs = null;
+
private final int vocabularySize;
public HotEncodedSample(double[] inputs, double[] outputs, int vocabularySize) {
@@ -34,26 +37,32 @@ public class HotEncodedSample extends Sa
@Override
public double[] getInputs() {
- double[] inputs = new double[this.inputs.length * vocabularySize];
- int i = 0;
- for (double d : this.inputs) {
- double[] currentInput = hotEncode((int) d);
- System.arraycopy(currentInput, 0, inputs, i, currentInput.length);
- i += vocabularySize;
+ if (expandedInputs == null) {
+ double[] inputs = new double[this.inputs.length * vocabularySize];
+ int i = 0;
+ for (double d : this.inputs) {
+ double[] currentInput = hotEncode((int) d);
+ System.arraycopy(currentInput, 0, inputs, i, currentInput.length);
+ i += vocabularySize;
+ }
+ expandedInputs = inputs;
}
- return inputs;
+ return expandedInputs;
}
@Override
public double[] getOutputs() {
- double[] outputs = new double[this.outputs.length * vocabularySize];
- int i = 0;
- for (double d : this.outputs) {
- double[] currentOutput = hotEncode((int) d);
- System.arraycopy(currentOutput, 0, outputs, i, currentOutput.length);
- i += vocabularySize;
+ if (expandedOutputs == null) {
+ double[] outputs = new double[this.outputs.length * vocabularySize];
+ int i = 0;
+ for (double d : this.outputs) {
+ double[] currentOutput = hotEncode((int) d);
+ System.arraycopy(currentOutput, 0, outputs, i, currentOutput.length);
+ i += vocabularySize;
+ }
+ expandedOutputs = outputs;
}
- return outputs;
+ return expandedOutputs;
}
private double[] hotEncode(int index) {
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java?rev=1733257&r1=1733256&r2=1733257&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java Wed Mar 2 11:13:32 2016
@@ -75,26 +75,18 @@ public class SkipGramNetwork {
private SkipGramNetwork(Configuration configuration) {
this.configuration = configuration;
- this.weights = createRandomWeights();
- this.biases = createRandomBiases();
+ this.weights = initWeights();
+ this.biases = initBiases();
}
- private RealMatrix[] createRandomBiases() {
+ private RealMatrix[] initBiases() {
RealMatrix[] initialBiases = new RealMatrix[weights.length];
for (int i = 0; i < initialBiases.length; i++) {
- RealMatrix matrix = MatrixUtils.createRealMatrix(1, weights[i].getRowDimension());
-
- UniformRealDistribution uniformRealDistribution = new UniformRealDistribution();
- double[] vs = uniformRealDistribution.sample(matrix.getRowDimension() * matrix.getColumnDimension());
- int r = 0;
- int c = 0;
- for (double v : vs) {
- matrix.setEntry(r % matrix.getRowDimension(), c % matrix.getColumnDimension(), v);
- r++;
- c++;
- }
+ double[] data = new double[weights[i].getRowDimension()];
+ Arrays.fill(data, 0.01d);
+ RealMatrix matrix = MatrixUtils.createRowRealMatrix(data);
initialBiases[i] = matrix;
}
@@ -113,13 +105,21 @@ public class SkipGramNetwork {
RealMatrix hidden = rectifierFunction.applyMatrix(MatrixUtils.createRowRealMatrix(input).multiply(weights[0].transpose()).
add(biases[0]));
- RealMatrix pscores = softmaxActivationFunction.applyMatrix(hidden.multiply(weights[1].transpose()).add(biases[1]));
+ RealMatrix scores = hidden.multiply(weights[1].transpose()).add(biases[1]);
- RealVector d = pscores.getRowVector(0);
+ RealMatrix probs = scores.copy();
+ int len = scores.getColumnDimension() - 1;
+ for (int d = 0; d < configuration.window - 1; d++) {
+ int startColumn = d * len / (configuration.window - 1);
+ RealMatrix subMatrix = scores.getSubMatrix(0, scores.getRowDimension() - 1, startColumn, startColumn + input.length);
+ probs.setSubMatrix(softmaxActivationFunction.applyMatrix(subMatrix).getData(), 0, startColumn);
+ }
+
+ RealVector d = probs.getRowVector(0);
return d.toArray();
}
- private RealMatrix[] createRandomWeights() {
+ private RealMatrix[] initWeights() {
int[] conf = new int[]{configuration.inputs, configuration.vectorSize, configuration.outputs};
int[] layers = new int[conf.length];
System.arraycopy(conf, 0, layers, 0, layers.length);
@@ -146,9 +146,10 @@ public class SkipGramNetwork {
}
- static void evaluate(SkipGramNetwork network, int window) throws Exception {
+ static double evaluate(SkipGramNetwork network) throws Exception {
double cc = 0;
double wc = 0;
+ int window = network.configuration.window;
for (Sample sample : network.samples) {
Collection<Integer> exps = new ArrayList<>(window - 1);
Collection<Integer> acts = new ArrayList<>(window - 1);
@@ -184,9 +185,7 @@ public class SkipGramNetwork {
}
}
- if (cc > 0) {
- System.out.println("accuracy: " + (cc / (wc + cc)));
- }
+ return (cc / (wc + cc));
}
private static int getMaxIndex(double[] array, int start, int end) {
@@ -240,12 +239,8 @@ public class SkipGramNetwork {
System.out.println("cost is " + cost + " after " + iterations + " iterations in " + (time / 60) + " minutes (" + ((double) iterations / time) + " ips)");
}
}
- if (iterations % 1000 == 0) {
- evaluate(this, this.configuration.window);
- System.out.println("cost: " + cost);
- }
-// configuration.alpha = configuration.alpha * 0.999;
+ configuration.alpha = configuration.alpha * 0.5;
RealMatrix w0t = weights[0].transpose();
final RealMatrix w1t = weights[1].transpose();
@@ -285,7 +280,13 @@ public class SkipGramNetwork {
}
});
- RealMatrix probs = softmaxActivationFunction.applyMatrix(scores);
+ RealMatrix probs = scores.copy();
+ int len = scores.getColumnDimension() - 1;
+ for (int d = 0; d < configuration.window - 1; d++) {
+ int startColumn = d * len / (configuration.window - 1);
+ RealMatrix subMatrix = scores.getSubMatrix(0, scores.getRowDimension() - 1, startColumn, startColumn + x.getColumnDimension());
+ probs.setSubMatrix(softmaxActivationFunction.applyMatrix(subMatrix).getData(), 0, startColumn);
+ }
RealMatrix correctLogProbs = MatrixUtils.createRealMatrix(x.getRowDimension(), 1);
correctLogProbs.walkInOptimizedOrder(new RealMatrixChangingVisitor() {
@@ -562,7 +563,7 @@ public class SkipGramNetwork {
@Override
public double visit(int row, int column, double value) {
- return configuration.mu * value - configuration.alpha + dWt2.getEntry(row, column);
+ return configuration.mu * value + configuration.alpha - dWt2.getEntry(row, column);
}
@Override
@@ -813,6 +814,11 @@ public class SkipGramNetwork {
return this;
}
+ public Builder withMaxIterations(int iterations) {
+ this.configuration.maxIterations = iterations;
+ return this;
+ }
+
public SkipGramNetwork build() throws Exception {
System.out.println("reading fragments");
Queue<List<byte[]>> fragments = getFragments(this.configuration.path, this.configuration.window);
@@ -825,7 +831,9 @@ public class SkipGramNetwork {
System.out.println("creating training set");
Collection<HotEncodedSample> trainingSet = createTrainingSet(vocabulary, fragments, this.configuration.window);
fragments.clear();
- this.configuration.maxIterations = trainingSet.size() * 100000;
+ if (this.configuration.maxIterations == 0) {
+ this.configuration.maxIterations = trainingSet.size() * 100000;
+ }
HotEncodedSample next = trainingSet.iterator().next();
Modified: labs/yay/trunk/core/src/test/java/org/apache/yay/SkipGramNetworkTest.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/SkipGramNetworkTest.java?rev=1733257&r1=1733256&r2=1733257&view=diff
==============================================================================
--- labs/yay/trunk/core/src/test/java/org/apache/yay/SkipGramNetworkTest.java (original)
+++ labs/yay/trunk/core/src/test/java/org/apache/yay/SkipGramNetworkTest.java Wed Mar 2 11:13:32 2016
@@ -43,47 +43,52 @@ public class SkipGramNetworkTest {
@Test
public void testWordVectorsLearningOnAbstracts() throws Exception {
Path path = Paths.get(getClass().getResource("/word2vec/abstracts.txt").getFile());
- int window = 3;
SkipGramNetwork network = SkipGramNetwork.newModel().
- withWindow(window).
+ withWindow(3).
fromTextAt(path).
- withDimension(2).
- withAlpha(0.003).
- withLambda(0.00003).
+ withDimension(10).
+ withAlpha(1).
+ withLambda(0.003).
+ withMaxIterations(500).
build();
RealMatrix wv = network.getWeights()[0];
List<String> vocabulary = network.getVocabulary();
serialize(vocabulary, wv);
- SkipGramNetwork.evaluate(network, window);
+ System.err.println("accuracy: " + SkipGramNetwork.evaluate(network));
+ measure(vocabulary, wv);
}
@Test
public void testWordVectorsLearningOnSentences() throws Exception {
Path path = Paths.get(getClass().getResource("/word2vec/sentences.txt").getFile());
- int window = 3;
SkipGramNetwork network = SkipGramNetwork.newModel().
- withWindow(window).
+ withWindow(3).
fromTextAt(path).
- withDimension(10).build();
+ withDimension(10).
+ withAlpha(1).
+ withLambda(0.03).
+ withMaxIterations(500).
+ build();
RealMatrix wv = network.getWeights()[0];
List<String> vocabulary = network.getVocabulary();
serialize(vocabulary, wv);
- SkipGramNetwork.evaluate(network, window);
+ System.err.println("accuracy: " + SkipGramNetwork.evaluate(network));
+ measure(vocabulary, wv);
}
@Test
public void testWordVectorsLearningOnTestData() throws Exception {
Path path = Paths.get(getClass().getResource("/word2vec/test.txt").getFile());
- int window = 3;
SkipGramNetwork network = SkipGramNetwork.newModel().
- withWindow(window).
+ withWindow(3).
fromTextAt(path).
withDimension(2).
- withAlpha(0.00002).
+ withAlpha(1).
withLambda(0.03).
- withThreshold(0.00000000003).
+ withThreshold(0.000003).
+ withMaxIterations(1000).
build();
- SkipGramNetwork.evaluate(network, window);
+ System.err.println("accuracy: " + SkipGramNetwork.evaluate(network));
RealMatrix wv = network.getWeights()[0];
List<String> vocabulary = network.getVocabulary();
serialize(vocabulary, wv);
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@labs.apache.org
For additional commands, e-mail: commits-help@labs.apache.org