You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@labs.apache.org by to...@apache.org on 2016/03/03 12:00:33 UTC
svn commit: r1733443 - in /labs/yay/trunk/core/src:
main/java/org/apache/yay/SkipGramNetwork.java
test/java/org/apache/yay/SkipGramNetworkTest.java
Author: tommaso
Date: Thu Mar 3 11:00:32 2016
New Revision: 1733443
URL: http://svn.apache.org/viewvc?rev=1733443&view=rev
Log:
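Fixed the momentum update (velocity is now mu * v - alpha * gradient instead of mixing alpha and the gradient with + / -), and apply softmax to each row of the score submatrix separately instead of to the whole submatrix at once.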
Modified:
labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java
labs/yay/trunk/core/src/test/java/org/apache/yay/SkipGramNetworkTest.java
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java?rev=1733443&r1=1733442&r2=1733443&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java Thu Mar 3 11:00:32 2016
@@ -112,7 +112,9 @@ public class SkipGramNetwork {
for (int d = 0; d < configuration.window - 1; d++) {
int startColumn = d * len / (configuration.window - 1);
RealMatrix subMatrix = scores.getSubMatrix(0, scores.getRowDimension() - 1, startColumn, startColumn + input.length);
- probs.setSubMatrix(softmaxActivationFunction.applyMatrix(subMatrix).getData(), 0, startColumn);
+ for (int sm = 0; sm < subMatrix.getRowDimension(); sm++) {
+ probs.setSubMatrix(softmaxActivationFunction.applyMatrix(subMatrix.getRowMatrix(sm)).getData(), sm, startColumn);
+ }
}
RealVector d = probs.getRowVector(0);
@@ -240,7 +242,7 @@ public class SkipGramNetwork {
}
}
- configuration.alpha = configuration.alpha * 0.5;
+// configuration.alpha = configuration.alpha * 0.5;
RealMatrix w0t = weights[0].transpose();
final RealMatrix w1t = weights[1].transpose();
@@ -285,7 +287,9 @@ public class SkipGramNetwork {
for (int d = 0; d < configuration.window - 1; d++) {
int startColumn = d * len / (configuration.window - 1);
RealMatrix subMatrix = scores.getSubMatrix(0, scores.getRowDimension() - 1, startColumn, startColumn + x.getColumnDimension());
- probs.setSubMatrix(softmaxActivationFunction.applyMatrix(subMatrix).getData(), 0, startColumn);
+ for (int sm = 0; sm < subMatrix.getRowDimension(); sm++) {
+ probs.setSubMatrix(softmaxActivationFunction.applyMatrix(subMatrix.getRowMatrix(sm)).getData(), sm, startColumn);
+ }
}
RealMatrix correctLogProbs = MatrixUtils.createRealMatrix(x.getRowDimension(), 1);
@@ -510,7 +514,7 @@ public class SkipGramNetwork {
@Override
public double visit(int row, int column, double value) {
- return configuration.mu * value - configuration.alpha + db.getEntry(row, column);
+ return configuration.mu * value - configuration.alpha * db.getEntry(row, column);
}
@Override
@@ -527,7 +531,7 @@ public class SkipGramNetwork {
@Override
public double visit(int row, int column, double value) {
- return configuration.mu * value - configuration.alpha + db2.getEntry(row, column);
+ return configuration.mu * value - configuration.alpha * db2.getEntry(row, column);
}
@Override
@@ -545,7 +549,7 @@ public class SkipGramNetwork {
@Override
public double visit(int row, int column, double value) {
- return configuration.mu * value - configuration.alpha + dWt.getEntry(row, column);
+ return configuration.mu * value - configuration.alpha * dWt.getEntry(row, column);
}
@Override
@@ -563,7 +567,7 @@ public class SkipGramNetwork {
@Override
public double visit(int row, int column, double value) {
- return configuration.mu * value + configuration.alpha - dWt2.getEntry(row, column);
+ return configuration.mu * value - configuration.alpha * dWt2.getEntry(row, column);
}
@Override
Modified: labs/yay/trunk/core/src/test/java/org/apache/yay/SkipGramNetworkTest.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/SkipGramNetworkTest.java?rev=1733443&r1=1733442&r2=1733443&view=diff
==============================================================================
--- labs/yay/trunk/core/src/test/java/org/apache/yay/SkipGramNetworkTest.java (original)
+++ labs/yay/trunk/core/src/test/java/org/apache/yay/SkipGramNetworkTest.java Thu Mar 3 11:00:32 2016
@@ -44,11 +44,13 @@ public class SkipGramNetworkTest {
public void testWordVectorsLearningOnAbstracts() throws Exception {
Path path = Paths.get(getClass().getResource("/word2vec/abstracts.txt").getFile());
SkipGramNetwork network = SkipGramNetwork.newModel().
- withWindow(3).
+ withWindow(4).
fromTextAt(path).
withDimension(10).
- withAlpha(1).
- withLambda(0.003).
+ withAlpha(0.09).
+ withLambda(0.03).
+ useMomentum(true).
+ withMu(0.9).
withMaxIterations(500).
build();
RealMatrix wv = network.getWeights()[0];
@@ -62,11 +64,13 @@ public class SkipGramNetworkTest {
public void testWordVectorsLearningOnSentences() throws Exception {
Path path = Paths.get(getClass().getResource("/word2vec/sentences.txt").getFile());
SkipGramNetwork network = SkipGramNetwork.newModel().
- withWindow(3).
+ withWindow(4).
fromTextAt(path).
withDimension(10).
- withAlpha(1).
- withLambda(0.03).
+ withAlpha(0.001).
+ withLambda(0.003).
+ useMomentum(true).
+ withMu(0.9).
withMaxIterations(500).
build();
RealMatrix wv = network.getWeights()[0];
@@ -83,10 +87,12 @@ public class SkipGramNetworkTest {
withWindow(3).
fromTextAt(path).
withDimension(2).
- withAlpha(1).
+ withAlpha(0.0008).
withLambda(0.03).
- withThreshold(0.000003).
- withMaxIterations(1000).
+ useMomentum(true).
+ withMu(0.9).
+ withThreshold(0.00000000003).
+ withMaxIterations(10000).
build();
System.err.println("accuracy: " + SkipGramNetwork.evaluate(network));
RealMatrix wv = network.getWeights()[0];
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@labs.apache.org
For additional commands, e-mail: commits-help@labs.apache.org