You are viewing a plain-text version of this content. The canonical link to it was a hyperlink in the original page and is not preserved in this rendering.
Posted to commits@labs.apache.org by to...@apache.org on 2016/03/03 12:00:33 UTC

svn commit: r1733443 - in /labs/yay/trunk/core/src: main/java/org/apache/yay/SkipGramNetwork.java test/java/org/apache/yay/SkipGramNetworkTest.java

Author: tommaso
Date: Thu Mar  3 11:00:32 2016
New Revision: 1733443

URL: http://svn.apache.org/viewvc?rev=1733443&view=rev
Log:
Fixed the momentum update (velocity is now mu * v - alpha * gradient) and applied softmax to each row of the score matrix separately.

Modified:
    labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java
    labs/yay/trunk/core/src/test/java/org/apache/yay/SkipGramNetworkTest.java

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java?rev=1733443&r1=1733442&r2=1733443&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java Thu Mar  3 11:00:32 2016
@@ -112,7 +112,9 @@ public class SkipGramNetwork {
     for (int d = 0; d < configuration.window - 1; d++) {
       int startColumn = d * len / (configuration.window - 1);
       RealMatrix subMatrix = scores.getSubMatrix(0, scores.getRowDimension() - 1, startColumn, startColumn + input.length);
-      probs.setSubMatrix(softmaxActivationFunction.applyMatrix(subMatrix).getData(), 0, startColumn);
+      for (int sm = 0; sm < subMatrix.getRowDimension(); sm++) {
+        probs.setSubMatrix(softmaxActivationFunction.applyMatrix(subMatrix.getRowMatrix(sm)).getData(), sm, startColumn);
+      }
     }
 
     RealVector d = probs.getRowVector(0);
@@ -240,7 +242,7 @@ public class SkipGramNetwork {
         }
       }
 
-      configuration.alpha = configuration.alpha * 0.5;
+//      configuration.alpha = configuration.alpha * 0.5;
 
       RealMatrix w0t = weights[0].transpose();
       final RealMatrix w1t = weights[1].transpose();
@@ -285,7 +287,9 @@ public class SkipGramNetwork {
       for (int d = 0; d < configuration.window - 1; d++) {
         int startColumn = d * len / (configuration.window - 1);
         RealMatrix subMatrix = scores.getSubMatrix(0, scores.getRowDimension() - 1, startColumn, startColumn + x.getColumnDimension());
-        probs.setSubMatrix(softmaxActivationFunction.applyMatrix(subMatrix).getData(), 0, startColumn);
+        for (int sm = 0; sm < subMatrix.getRowDimension(); sm++) {
+          probs.setSubMatrix(softmaxActivationFunction.applyMatrix(subMatrix.getRowMatrix(sm)).getData(), sm, startColumn);
+        }
       }
 
       RealMatrix correctLogProbs = MatrixUtils.createRealMatrix(x.getRowDimension(), 1);
@@ -510,7 +514,7 @@ public class SkipGramNetwork {
 
           @Override
           public double visit(int row, int column, double value) {
-            return configuration.mu * value - configuration.alpha + db.getEntry(row, column);
+            return configuration.mu * value - configuration.alpha * db.getEntry(row, column);
           }
 
           @Override
@@ -527,7 +531,7 @@ public class SkipGramNetwork {
 
           @Override
           public double visit(int row, int column, double value) {
-            return configuration.mu * value - configuration.alpha + db2.getEntry(row, column);
+            return configuration.mu * value - configuration.alpha * db2.getEntry(row, column);
           }
 
           @Override
@@ -545,7 +549,7 @@ public class SkipGramNetwork {
 
           @Override
           public double visit(int row, int column, double value) {
-            return configuration.mu * value - configuration.alpha + dWt.getEntry(row, column);
+            return configuration.mu * value - configuration.alpha * dWt.getEntry(row, column);
           }
 
           @Override
@@ -563,7 +567,7 @@ public class SkipGramNetwork {
 
           @Override
           public double visit(int row, int column, double value) {
-            return configuration.mu * value + configuration.alpha - dWt2.getEntry(row, column);
+            return configuration.mu * value - configuration.alpha * dWt2.getEntry(row, column);
           }
 
           @Override

Modified: labs/yay/trunk/core/src/test/java/org/apache/yay/SkipGramNetworkTest.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/SkipGramNetworkTest.java?rev=1733443&r1=1733442&r2=1733443&view=diff
==============================================================================
--- labs/yay/trunk/core/src/test/java/org/apache/yay/SkipGramNetworkTest.java (original)
+++ labs/yay/trunk/core/src/test/java/org/apache/yay/SkipGramNetworkTest.java Thu Mar  3 11:00:32 2016
@@ -44,11 +44,13 @@ public class SkipGramNetworkTest {
   public void testWordVectorsLearningOnAbstracts() throws Exception {
     Path path = Paths.get(getClass().getResource("/word2vec/abstracts.txt").getFile());
     SkipGramNetwork network = SkipGramNetwork.newModel().
-            withWindow(3).
+            withWindow(4).
             fromTextAt(path).
             withDimension(10).
-            withAlpha(1).
-            withLambda(0.003).
+            withAlpha(0.09).
+            withLambda(0.03).
+            useMomentum(true).
+            withMu(0.9).
             withMaxIterations(500).
             build();
     RealMatrix wv = network.getWeights()[0];
@@ -62,11 +64,13 @@ public class SkipGramNetworkTest {
   public void testWordVectorsLearningOnSentences() throws Exception {
     Path path = Paths.get(getClass().getResource("/word2vec/sentences.txt").getFile());
     SkipGramNetwork network = SkipGramNetwork.newModel().
-            withWindow(3).
+            withWindow(4).
             fromTextAt(path).
             withDimension(10).
-            withAlpha(1).
-            withLambda(0.03).
+            withAlpha(0.001).
+            withLambda(0.003).
+            useMomentum(true).
+            withMu(0.9).
             withMaxIterations(500).
             build();
     RealMatrix wv = network.getWeights()[0];
@@ -83,10 +87,12 @@ public class SkipGramNetworkTest {
             withWindow(3).
             fromTextAt(path).
             withDimension(2).
-            withAlpha(1).
+            withAlpha(0.0008).
             withLambda(0.03).
-            withThreshold(0.000003).
-            withMaxIterations(1000).
+            useMomentum(true).
+            withMu(0.9).
+            withThreshold(0.00000000003).
+            withMaxIterations(10000).
             build();
     System.err.println("accuracy: " + SkipGramNetwork.evaluate(network));
     RealMatrix wv = network.getWeights()[0];



---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@labs.apache.org
For additional commands, e-mail: commits-help@labs.apache.org