You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@labs.apache.org by to...@apache.org on 2016/03/02 12:13:32 UTC

svn commit: r1733257 - in /labs/yay/trunk/core/src: main/java/org/apache/yay/HotEncodedSample.java main/java/org/apache/yay/SkipGramNetwork.java test/java/org/apache/yay/SkipGramNetworkTest.java

Author: tommaso
Date: Wed Mar  2 11:13:32 2016
New Revision: 1733257

URL: http://svn.apache.org/viewvc?rev=1733257&view=rev
Log:
biases initialized to 0.01, per-output-word softmax, cached expanded representation in hot-encoded samples, 0.5 decay on learning rate

Modified:
    labs/yay/trunk/core/src/main/java/org/apache/yay/HotEncodedSample.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java
    labs/yay/trunk/core/src/test/java/org/apache/yay/SkipGramNetworkTest.java

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/HotEncodedSample.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/HotEncodedSample.java?rev=1733257&r1=1733256&r2=1733257&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/HotEncodedSample.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/HotEncodedSample.java Wed Mar  2 11:13:32 2016
@@ -25,6 +25,9 @@ import java.util.Arrays;
  */
 public class HotEncodedSample extends Sample {
 
+  private double[] expandedInputs = null;
+  private double[] expandedOutputs = null;
+
   private final int vocabularySize;
 
   public HotEncodedSample(double[] inputs, double[] outputs, int vocabularySize) {
@@ -34,26 +37,32 @@ public class HotEncodedSample extends Sa
 
   @Override
   public double[] getInputs() {
-    double[] inputs = new double[this.inputs.length * vocabularySize];
-    int i = 0;
-    for (double d : this.inputs) {
-      double[] currentInput = hotEncode((int) d);
-      System.arraycopy(currentInput, 0, inputs, i, currentInput.length);
-      i += vocabularySize;
+    if (expandedInputs == null) {
+      double[] inputs = new double[this.inputs.length * vocabularySize];
+      int i = 0;
+      for (double d : this.inputs) {
+        double[] currentInput = hotEncode((int) d);
+        System.arraycopy(currentInput, 0, inputs, i, currentInput.length);
+        i += vocabularySize;
+      }
+      expandedInputs = inputs;
     }
-    return inputs;
+    return expandedInputs;
   }
 
   @Override
   public double[] getOutputs() {
-    double[] outputs = new double[this.outputs.length * vocabularySize];
-    int i = 0;
-    for (double d : this.outputs) {
-      double[] currentOutput = hotEncode((int) d);
-      System.arraycopy(currentOutput, 0, outputs, i, currentOutput.length);
-      i += vocabularySize;
+    if (expandedOutputs == null) {
+      double[] outputs = new double[this.outputs.length * vocabularySize];
+      int i = 0;
+      for (double d : this.outputs) {
+        double[] currentOutput = hotEncode((int) d);
+        System.arraycopy(currentOutput, 0, outputs, i, currentOutput.length);
+        i += vocabularySize;
+      }
+      expandedOutputs = outputs;
     }
-    return outputs;
+    return expandedOutputs;
   }
 
   private double[] hotEncode(int index) {

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java?rev=1733257&r1=1733256&r2=1733257&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java Wed Mar  2 11:13:32 2016
@@ -75,26 +75,18 @@ public class SkipGramNetwork {
 
   private SkipGramNetwork(Configuration configuration) {
     this.configuration = configuration;
-    this.weights = createRandomWeights();
-    this.biases = createRandomBiases();
+    this.weights = initWeights();
+    this.biases = initBiases();
   }
 
-  private RealMatrix[] createRandomBiases() {
+  private RealMatrix[] initBiases() {
 
     RealMatrix[] initialBiases = new RealMatrix[weights.length];
 
     for (int i = 0; i < initialBiases.length; i++) {
-      RealMatrix matrix = MatrixUtils.createRealMatrix(1, weights[i].getRowDimension());
-
-      UniformRealDistribution uniformRealDistribution = new UniformRealDistribution();
-      double[] vs = uniformRealDistribution.sample(matrix.getRowDimension() * matrix.getColumnDimension());
-      int r = 0;
-      int c = 0;
-      for (double v : vs) {
-        matrix.setEntry(r % matrix.getRowDimension(), c % matrix.getColumnDimension(), v);
-        r++;
-        c++;
-      }
+      double[] data = new double[weights[i].getRowDimension()];
+      Arrays.fill(data, 0.01d);
+      RealMatrix matrix = MatrixUtils.createRowRealMatrix(data);
 
       initialBiases[i] = matrix;
     }
@@ -113,13 +105,21 @@ public class SkipGramNetwork {
 
     RealMatrix hidden = rectifierFunction.applyMatrix(MatrixUtils.createRowRealMatrix(input).multiply(weights[0].transpose()).
             add(biases[0]));
-    RealMatrix pscores = softmaxActivationFunction.applyMatrix(hidden.multiply(weights[1].transpose()).add(biases[1]));
+    RealMatrix scores = hidden.multiply(weights[1].transpose()).add(biases[1]);
 
-    RealVector d = pscores.getRowVector(0);
+    RealMatrix probs = scores.copy();
+    int len = scores.getColumnDimension() - 1;
+    for (int d = 0; d < configuration.window - 1; d++) {
+      int startColumn = d * len / (configuration.window - 1);
+      RealMatrix subMatrix = scores.getSubMatrix(0, scores.getRowDimension() - 1, startColumn, startColumn + input.length);
+      probs.setSubMatrix(softmaxActivationFunction.applyMatrix(subMatrix).getData(), 0, startColumn);
+    }
+
+    RealVector d = probs.getRowVector(0);
     return d.toArray();
   }
 
-  private RealMatrix[] createRandomWeights() {
+  private RealMatrix[] initWeights() {
     int[] conf = new int[]{configuration.inputs, configuration.vectorSize, configuration.outputs};
     int[] layers = new int[conf.length];
     System.arraycopy(conf, 0, layers, 0, layers.length);
@@ -146,9 +146,10 @@ public class SkipGramNetwork {
 
   }
 
-  static void evaluate(SkipGramNetwork network, int window) throws Exception {
+  static double evaluate(SkipGramNetwork network) throws Exception {
     double cc = 0;
     double wc = 0;
+    int window = network.configuration.window;
     for (Sample sample : network.samples) {
       Collection<Integer> exps = new ArrayList<>(window - 1);
       Collection<Integer> acts = new ArrayList<>(window - 1);
@@ -184,9 +185,7 @@ public class SkipGramNetwork {
       }
 
     }
-    if (cc > 0) {
-      System.out.println("accuracy: " + (cc / (wc + cc)));
-    }
+    return (cc / (wc + cc));
   }
 
   private static int getMaxIndex(double[] array, int start, int end) {
@@ -240,12 +239,8 @@ public class SkipGramNetwork {
           System.out.println("cost is " + cost + " after " + iterations + " iterations in " + (time / 60) + " minutes (" + ((double) iterations / time) + " ips)");
         }
       }
-      if (iterations % 1000 == 0) {
-        evaluate(this, this.configuration.window);
-        System.out.println("cost: " + cost);
-      }
 
-//      configuration.alpha = configuration.alpha * 0.999;
+      configuration.alpha = configuration.alpha * 0.5;
 
       RealMatrix w0t = weights[0].transpose();
       final RealMatrix w1t = weights[1].transpose();
@@ -285,7 +280,13 @@ public class SkipGramNetwork {
         }
       });
 
-      RealMatrix probs = softmaxActivationFunction.applyMatrix(scores);
+      RealMatrix probs = scores.copy();
+      int len = scores.getColumnDimension() - 1;
+      for (int d = 0; d < configuration.window - 1; d++) {
+        int startColumn = d * len / (configuration.window - 1);
+        RealMatrix subMatrix = scores.getSubMatrix(0, scores.getRowDimension() - 1, startColumn, startColumn + x.getColumnDimension());
+        probs.setSubMatrix(softmaxActivationFunction.applyMatrix(subMatrix).getData(), 0, startColumn);
+      }
 
       RealMatrix correctLogProbs = MatrixUtils.createRealMatrix(x.getRowDimension(), 1);
       correctLogProbs.walkInOptimizedOrder(new RealMatrixChangingVisitor() {
@@ -562,7 +563,7 @@ public class SkipGramNetwork {
 
           @Override
           public double visit(int row, int column, double value) {
-            return configuration.mu * value - configuration.alpha + dWt2.getEntry(row, column);
+            return configuration.mu * value + configuration.alpha - dWt2.getEntry(row, column);
           }
 
           @Override
@@ -813,6 +814,11 @@ public class SkipGramNetwork {
       return this;
     }
 
+    public Builder withMaxIterations(int iterations) {
+      this.configuration.maxIterations = iterations;
+      return this;
+    }
+
     public SkipGramNetwork build() throws Exception {
       System.out.println("reading fragments");
       Queue<List<byte[]>> fragments = getFragments(this.configuration.path, this.configuration.window);
@@ -825,7 +831,9 @@ public class SkipGramNetwork {
       System.out.println("creating training set");
       Collection<HotEncodedSample> trainingSet = createTrainingSet(vocabulary, fragments, this.configuration.window);
       fragments.clear();
-      this.configuration.maxIterations = trainingSet.size() * 100000;
+      if (this.configuration.maxIterations == 0) {
+        this.configuration.maxIterations = trainingSet.size() * 100000;
+      }
 
       HotEncodedSample next = trainingSet.iterator().next();
 

Modified: labs/yay/trunk/core/src/test/java/org/apache/yay/SkipGramNetworkTest.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/SkipGramNetworkTest.java?rev=1733257&r1=1733256&r2=1733257&view=diff
==============================================================================
--- labs/yay/trunk/core/src/test/java/org/apache/yay/SkipGramNetworkTest.java (original)
+++ labs/yay/trunk/core/src/test/java/org/apache/yay/SkipGramNetworkTest.java Wed Mar  2 11:13:32 2016
@@ -43,47 +43,52 @@ public class SkipGramNetworkTest {
   @Test
   public void testWordVectorsLearningOnAbstracts() throws Exception {
     Path path = Paths.get(getClass().getResource("/word2vec/abstracts.txt").getFile());
-    int window = 3;
     SkipGramNetwork network = SkipGramNetwork.newModel().
-            withWindow(window).
+            withWindow(3).
             fromTextAt(path).
-            withDimension(2).
-            withAlpha(0.003).
-            withLambda(0.00003).
+            withDimension(10).
+            withAlpha(1).
+            withLambda(0.003).
+            withMaxIterations(500).
             build();
     RealMatrix wv = network.getWeights()[0];
     List<String> vocabulary = network.getVocabulary();
     serialize(vocabulary, wv);
-    SkipGramNetwork.evaluate(network, window);
+    System.err.println("accuracy: " + SkipGramNetwork.evaluate(network));
+    measure(vocabulary, wv);
   }
 
   @Test
   public void testWordVectorsLearningOnSentences() throws Exception {
     Path path = Paths.get(getClass().getResource("/word2vec/sentences.txt").getFile());
-    int window = 3;
     SkipGramNetwork network = SkipGramNetwork.newModel().
-            withWindow(window).
+            withWindow(3).
             fromTextAt(path).
-            withDimension(10).build();
+            withDimension(10).
+            withAlpha(1).
+            withLambda(0.03).
+            withMaxIterations(500).
+            build();
     RealMatrix wv = network.getWeights()[0];
     List<String> vocabulary = network.getVocabulary();
     serialize(vocabulary, wv);
-    SkipGramNetwork.evaluate(network, window);
+    System.err.println("accuracy: " + SkipGramNetwork.evaluate(network));
+    measure(vocabulary, wv);
   }
 
   @Test
   public void testWordVectorsLearningOnTestData() throws Exception {
     Path path = Paths.get(getClass().getResource("/word2vec/test.txt").getFile());
-    int window = 3;
     SkipGramNetwork network = SkipGramNetwork.newModel().
-            withWindow(window).
+            withWindow(3).
             fromTextAt(path).
             withDimension(2).
-            withAlpha(0.00002).
+            withAlpha(1).
             withLambda(0.03).
-            withThreshold(0.00000000003).
+            withThreshold(0.000003).
+            withMaxIterations(1000).
             build();
-    SkipGramNetwork.evaluate(network, window);
+    System.err.println("accuracy: " + SkipGramNetwork.evaluate(network));
     RealMatrix wv = network.getWeights()[0];
     List<String> vocabulary = network.getVocabulary();
     serialize(vocabulary, wv);



---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@labs.apache.org
For additional commands, e-mail: commits-help@labs.apache.org