You are viewing a plain-text version of this content. The canonical link to it was a hyperlink on the original page and is not preserved in this text.
Posted to commits@labs.apache.org by to...@apache.org on 2016/02/19 09:18:42 UTC

svn commit: r1731197 - /labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java

Author: tommaso
Date: Fri Feb 19 08:18:42 2016
New Revision: 1731197

URL: http://svn.apache.org/viewvc?rev=1731197&view=rev
Log:
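Avoid repeated matrix transpose calls by computing each transposed weight matrix once and reusing it. (This revision also raises the default alpha from 0.001 to 0.003 and changes one update coefficient from 0.03 to 0.3.)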

Modified:
    labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java?rev=1731197&r1=1731196&r2=1731197&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java Fri Feb 19 08:18:42 2016
@@ -158,8 +158,11 @@ public class SkipGramNetwork {
         i++;
       }
 
-      RealMatrix hidden = rectifierFunction.applyMatrix(x.multiply(weights[0].transpose()));
-      RealMatrix scores = hidden.multiply(weights[1].transpose());
+      RealMatrix w0t = weights[0].transpose();
+      final RealMatrix w1t = weights[1].transpose();
+
+      RealMatrix hidden = rectifierFunction.applyMatrix(x.multiply(w0t));
+      RealMatrix scores = hidden.multiply(w1t);
 
       RealMatrix probs = softmaxActivationFunction.applyMatrix(scores);
 
@@ -265,7 +268,7 @@ public class SkipGramNetwork {
         @Override
         public double visit(int row, int column, double value) {
           if (column != 0) {
-            return value + 0.03 * weights[1].transpose().getEntry(row, column);
+            return value + 0.3 * w1t.getEntry(row, column);
           } else {
             return value;
           }
@@ -289,7 +292,7 @@ public class SkipGramNetwork {
         @Override
         public double visit(int row, int column, double value) {
           if (column != 0) {
-            return value + 0.03 * weights[0].transpose().getEntry(row, column);
+            return value + 0.03 * w0t.getEntry(row, column);
           } else {
             return value;
           }
@@ -368,7 +371,7 @@ public class SkipGramNetwork {
     // user controlled parameters
     protected Path path;
     protected int maxIterations;
-    protected double alpha = 0.001d;
+    protected double alpha = 0.003d;
     protected double threshold = 0.004d;
     protected int vectorSize;
     protected int window;



---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@labs.apache.org
For additional commands, e-mail: commits-help@labs.apache.org