Posted to commits@labs.apache.org by to...@apache.org on 2015/10/06 14:01:51 UTC

svn commit: r1707019 - in /labs/yay/trunk/core/src: main/java/org/apache/yay/core/ test/java/org/apache/yay/core/ test/resources/word2vec/

Author: tommaso
Date: Tue Oct  6 12:01:51 2015
New Revision: 1707019

URL: http://svn.apache.org/viewvc?rev=1707019&view=rev
Log:
draft word2vec test for skip-gram (SGM) network

Added:
    labs/yay/trunk/core/src/test/java/org/apache/yay/core/Word2VecTest.java
    labs/yay/trunk/core/src/test/resources/word2vec/
    labs/yay/trunk/core/src/test/resources/word2vec/sentences.txt
Modified:
    labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/core/LogisticRegressionCostFunction.java
    labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java?rev=1707019&r1=1707018&r2=1707019&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java Tue Oct  6 12:01:51 2015
@@ -39,7 +39,7 @@ import org.apache.yay.WeightLearningExce
 public class BackPropagationLearningStrategy implements LearningStrategy<Double, Double> {
 
   public static final double DEFAULT_THRESHOLD = 0.05;
-  private static final int MAX_ITERATIONS = 100000;
+  public static final int MAX_ITERATIONS = 100000;
   public static final double DEFAULT_ALPHA = 0.000003;
 
   private final PredictionStrategy<Double, Double> predictionStrategy;
@@ -48,31 +48,33 @@ public class BackPropagationLearningStra
   private final double alpha;
   private final double threshold;
   private final int batch;
-
+  private final int maxIterations;
 
   public BackPropagationLearningStrategy(double alpha, double threshold, PredictionStrategy<Double, Double> predictionStrategy,
                                          CostFunction<RealMatrix, Double, Double> costFunction) {
-    this(alpha, 1, threshold, predictionStrategy, costFunction);
+    this(alpha, 1, threshold, predictionStrategy, costFunction, MAX_ITERATIONS);
   }
 
   public BackPropagationLearningStrategy(double alpha, int batch, double threshold, PredictionStrategy<Double, Double> predictionStrategy,
-                                         CostFunction<RealMatrix, Double, Double> costFunction) {
+                                         CostFunction<RealMatrix, Double, Double> costFunction, int maxIterations) {
     this.predictionStrategy = predictionStrategy;
     this.costFunction = costFunction;
     this.alpha = alpha;
     this.threshold = threshold;
     this.batch = batch;
     this.derivativeUpdateFunction = new DefaultDerivativeUpdateFunction(predictionStrategy);
+    this.maxIterations = maxIterations;
   }
 
   public BackPropagationLearningStrategy() {
     // commonly used defaults
-    this.predictionStrategy = new FeedForwardStrategy(new TanhFunction());
+    this.predictionStrategy = new FeedForwardStrategy(new SigmoidFunction());
     this.costFunction = new LogisticRegressionCostFunction();
     this.alpha = DEFAULT_ALPHA;
     this.threshold = DEFAULT_THRESHOLD;
     this.batch = 1;
     this.derivativeUpdateFunction = new DefaultDerivativeUpdateFunction(predictionStrategy);
+    this.maxIterations = MAX_ITERATIONS;
   }
 
   @Override
@@ -106,7 +108,7 @@ public class BackPropagationLearningStra
 
         if (newCost > cost && batch == -1) {
           throw new RuntimeException("failed to converge at iteration " + iterations + " with alpha " + alpha + " : cost going from " + cost + " to " + newCost);
-        } else if (iterations > 1 && (cost == newCost || newCost < threshold || iterations > MAX_ITERATIONS)) {
+        } else if (iterations > 1 && (cost == newCost || newCost < threshold || iterations > maxIterations)) {
           System.out.println("successfully converged after " + (iterations - 1) + " iterations (alpha:" + alpha + ",threshold:" + threshold + ") with cost " + newCost + " and parameters " + Arrays.toString(hypothesis.getParameters()));
           break;
         }
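
For anyone picking up the widened constructor: a minimal sketch, assuming the package layout of this commit (org.apache.yay for the interfaces, org.apache.yay.core for the implementations); the alpha, batch and iteration values are illustrative only:

    import org.apache.commons.math3.linear.RealMatrix;
    import org.apache.yay.CostFunction;
    import org.apache.yay.LearningStrategy;
    import org.apache.yay.PredictionStrategy;
    import org.apache.yay.core.BackPropagationLearningStrategy;
    import org.apache.yay.core.FeedForwardStrategy;
    import org.apache.yay.core.LogisticRegressionCostFunction;
    import org.apache.yay.core.SigmoidFunction;

    public class MaxIterationsSketch {
      public static void main(String[] args) {
        PredictionStrategy<Double, Double> prediction =
            new FeedForwardStrategy(new SigmoidFunction());
        CostFunction<RealMatrix, Double, Double> cost =
            new LogisticRegressionCostFunction();
        // same defaults as the no-arg constructor, but capped at 500 iterations
        LearningStrategy<Double, Double> learning =
            new BackPropagationLearningStrategy(
                BackPropagationLearningStrategy.DEFAULT_ALPHA,
                1, // batch size; the tests also pass -1 here
                BackPropagationLearningStrategy.DEFAULT_THRESHOLD,
                prediction,
                cost,
                500); // new: explicit iteration cap
      }
    }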

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/LogisticRegressionCostFunction.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/LogisticRegressionCostFunction.java?rev=1707019&r1=1707018&r2=1707019&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/LogisticRegressionCostFunction.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/LogisticRegressionCostFunction.java Tue Oct  6 12:01:51 2015
@@ -77,8 +77,10 @@ public class LogisticRegressionCostFunct
       Double[] predictedOutput = hypothesis.predict(input);
       Double[] sampleOutput = input.getOutput();
       for (int i = 0; i < predictedOutput.length; i++) {
-        res += sampleOutput[i] * Math.log(predictedOutput[i]) + (1d - sampleOutput[i])
-                * Math.log(1d - predictedOutput[i]);
+        Double so = sampleOutput[i];
+        Double po = predictedOutput[i];
+        res += so * Math.log(po) + (1d - so)
+                * Math.log(1d - po);
       }
     }
     return (-1d / trainingExamples.length) * res;
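
The loop in this hunk accumulates the cross-entropy term of the logistic cost; with m = trainingExamples.length, so = y^{(t)}_i and po = h_\theta(x^{(t)})_i, the value returned is

    J(\theta) = -\frac{1}{m} \sum_{t=1}^{m} \sum_{i} \left[ y^{(t)}_i \log h_\theta(x^{(t)})_i + \left(1 - y^{(t)}_i\right) \log\left(1 - h_\theta(x^{(t)})_i\right) \right]

(the regularization term that the lambda constructor argument enables elsewhere is not part of this hunk).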

Modified: labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java?rev=1707019&r1=1707018&r2=1707019&view=diff
==============================================================================
--- labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java (original)
+++ labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java Tue Oct  6 12:01:51 2015
@@ -73,7 +73,8 @@ public class BackPropagationLearningStra
       assertFalse("weights have not been changed", learntWeights[i].equals(initialWeights[i]));
     }
 
-    backPropagationLearningStrategy = new BackPropagationLearningStrategy(alpha, -1, threshold, predictionStrategy, costFunction);
+    backPropagationLearningStrategy = new BackPropagationLearningStrategy(alpha, -1, threshold, predictionStrategy,
+            costFunction, BackPropagationLearningStrategy.MAX_ITERATIONS);
     learntWeights = backPropagationLearningStrategy.learnWeights(initialWeights, trainingSet);
     assertNotNull(learntWeights);
 
@@ -147,7 +148,8 @@ public class BackPropagationLearningStra
     assertFalse(learntWeights[2].equals(initialWeights[2]));
 
     backPropagationLearningStrategy = new BackPropagationLearningStrategy(BackPropagationLearningStrategy.DEFAULT_ALPHA, -1,
-            BackPropagationLearningStrategy.DEFAULT_THRESHOLD, new FeedForwardStrategy(new SigmoidFunction()), new LogisticRegressionCostFunction(0.5d));
+            BackPropagationLearningStrategy.DEFAULT_THRESHOLD, new FeedForwardStrategy(new SigmoidFunction()),
+            new LogisticRegressionCostFunction(0.5d), BackPropagationLearningStrategy.MAX_ITERATIONS);
 
     learntWeights = backPropagationLearningStrategy.learnWeights(initialWeights, trainingSet);
     assertNotNull(learntWeights);

Added: labs/yay/trunk/core/src/test/java/org/apache/yay/core/Word2VecTest.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/core/Word2VecTest.java?rev=1707019&view=auto
==============================================================================
--- labs/yay/trunk/core/src/test/java/org/apache/yay/core/Word2VecTest.java (added)
+++ labs/yay/trunk/core/src/test/java/org/apache/yay/core/Word2VecTest.java Tue Oct  6 12:01:51 2015
@@ -0,0 +1,286 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *  http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.yay.core;
+
+import java.io.BufferedReader;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.InputStreamReader;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Random;
+
+import org.apache.commons.math3.linear.Array2DRowRealMatrix;
+import org.apache.commons.math3.linear.RealMatrix;
+import org.apache.yay.Feature;
+import org.apache.yay.Input;
+import org.apache.yay.NeuralNetwork;
+import org.apache.yay.TrainingExample;
+import org.apache.yay.TrainingSet;
+import org.apache.yay.core.utils.ConversionUtils;
+import org.apache.yay.core.utils.ExamplesFactory;
+import org.junit.Test;
+
+import static org.junit.Assert.*;
+
+/**
+ * Integration test for using Yay to implement word vector algorithms (skip-gram).
+ */
+public class Word2VecTest {
+
+  @Test
+  public void testSGM() throws Exception {
+    Collection<String> sentences = getSentences();
+    assertFalse(sentences.isEmpty());
+    List<String> vocabulary = getVocabulary(sentences);
+    assertFalse(vocabulary.isEmpty());
+    Collections.sort(vocabulary);
+    Collection<String> fragments = getFragments(sentences, 4);
+    assertFalse(fragments.isEmpty());
+
+    TrainingSet<Double, Double> trainingSet = createTrainingSet(vocabulary, fragments);
+
+//    int n = new Random().nextInt(20);
+
+    TrainingExample<Double, Double> next = trainingSet.iterator().next();
+    int inputSize = next.getFeatures().size();
+    int outputSize = next.getOutput().length;
+    RealMatrix[] randomWeights = createRandomWeights(inputSize, inputSize, outputSize);
+
+    FeedForwardStrategy predictionStrategy = new FeedForwardStrategy(new IdentityActivationFunction<Double>());
+    BackPropagationLearningStrategy learningStrategy = new BackPropagationLearningStrategy(BackPropagationLearningStrategy.
+            DEFAULT_ALPHA, -1, BackPropagationLearningStrategy.DEFAULT_THRESHOLD, predictionStrategy, new LMSCostFunction(),
+            20);
+    NeuralNetwork neuralNetwork = NeuralNetworkFactory.create(randomWeights, learningStrategy, predictionStrategy);
+
+    neuralNetwork.learn(trainingSet);
+
+    String word = "paper";
+//    final Double[] doubles = ConversionUtils.toValuesCollection(next.getFeatures()).toArray(new Double[next.getFeatures().size()]);
+    final Double[] doubles = hotEncode(word, vocabulary);
+//    String word = hotDecode(doubles, vocabulary);
+
+//    TrainingExample<Double, Double> input = ExamplesFactory.createDoubleArrayTrainingExample(new Double[outputSize], doubles);
+    Input<Double> input = new TrainingExample<Double, Double>() {
+      @Override
+      public ArrayList<Feature<Double>> getFeatures() {
+        ArrayList<Feature<Double>> features = new ArrayList<Feature<Double>>();
+        for (Double d : doubles) {
+          Feature<Double> f = new Feature<Double>();
+          f.setValue(d);
+          features.add(f);
+        }
+        return features;
+      }
+
+      @Override
+      public Double[] getOutput() {
+        return new Double[0];
+      }
+    };
+    Double[] predict = neuralNetwork.predict(input);
+    assertNotNull(predict);
+
+    System.out.println(Arrays.toString(predict));
+
+    Double[] wordVec1 = Arrays.copyOfRange(predict, 0, vocabulary.size());
+    assertNotNull(wordVec1);
+    Double[] wordVec2 = Arrays.copyOfRange(predict, vocabulary.size(), 2 * vocabulary.size());
+    assertNotNull(wordVec2);
+    Double[] wordVec3 = Arrays.copyOfRange(predict, 2 * vocabulary.size(), 3 * vocabulary.size());
+    assertNotNull(wordVec3);
+
+    String word1 = hotDecode(wordVec1, vocabulary);
+    assertNotNull(word1);
+    assertTrue(vocabulary.contains(word1));
+    String word2 = hotDecode(wordVec2, vocabulary);
+    assertNotNull(word2);
+    assertTrue(vocabulary.contains(word2));
+    String word3 = hotDecode(wordVec3, vocabulary);
+    assertNotNull(word3);
+    assertTrue(vocabulary.contains(word3));
+
+    System.out.println(word + " -> " + word1 + " " + word2 + " " + word3);
+  }
+
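+  // arg-max decode: returns the vocabulary entry whose component in the vector is highest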
+  private String hotDecode(Double[] doubles, List<String> vocabulary) {
+    double max = -Double.MAX_VALUE;
+    int index = -1;
+    for (int i = 0; i < doubles.length; i++) {
+      Double aDouble = doubles[i];
+      if (aDouble > max) {
+        max = aDouble;
+        index = i;
+      }
+    }
+    return vocabulary.get(index);
+  }
+
+
+  private TrainingSet<Double, Double> createTrainingSet(List<String> vocabulary, Collection<String> fragments) {
+    Collection<TrainingExample<Double, Double>> samples = new LinkedList<TrainingExample<Double, Double>>();
+    for (String fragment : fragments) {
+      String[] tokens = fragment.split(" ");
+      String inputWord = null;
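+      // position i selects the input word; the remaining tokens of the fragment form the expected context output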
+      for (int i = 0; i < tokens.length; i++) {
+        List<String> outputWords = new LinkedList<String>();
+        for (int j = 0; j < tokens.length; j++) {
+          String token = tokens[j];
+          if (i == j) {
+            inputWord = token;
+          } else {
+            outputWords.add(token);
+          }
+        }
+
+        final Double[] input = hotEncode(inputWord, vocabulary);
+        final Double[] outputs = new Double[outputWords.size() * vocabulary.size()];
+        for (int k = 0; k < outputWords.size(); k++) {
+          Double[] doubles = hotEncode(outputWords.get(k), vocabulary);
+          for (int z = 0; z < doubles.length; z++) {
+            outputs[(k * doubles.length) + z] = doubles[z];
+          }
+        }
+        samples.add(new TrainingExample<Double, Double>() {
+          @Override
+          public Double[] getOutput() {
+            return outputs;
+          }
+
+          @Override
+          public ArrayList<Feature<Double>> getFeatures() {
+            ArrayList<Feature<Double>> features = new ArrayList<Feature<Double>>();
+            for (Double d : input) {
+              Feature<Double> e = new Feature<Double>();
+              e.setValue(d);
+              features.add(e);
+            }
+            return features;
+          }
+        });
+      }
+    }
+    return new TrainingSet<Double, Double>(samples);
+  }
+
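+  // one-hot encode: the vocabulary must be sorted, as Collections.binarySearch is used to find the index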
+  private Double[] hotEncode(String word, List<String> vocabulary) {
+    Double[] vector = new Double[vocabulary.size()];
+    int index = Collections.binarySearch(vocabulary, word);
+    Arrays.fill(vector, 0d);
+    vector[index] = 1d;
+    return vector;
+  }
+
+  private List<String> getVocabulary(Collection<String> sentences) {
+    List<String> vocabulary = new LinkedList<String>();
+    for (String sentence : sentences) {
+      for (String token : sentence.split(" ")) {
+        if (!vocabulary.contains(token)) {
+          vocabulary.add(token);
+        }
+      }
+    }
+    return vocabulary;
+  }
+
+  private Collection<String> getFragments(Collection<String> sentences, int w) {
+    Collection<String> fragments = new LinkedList<String>();
+    for (String sentence : sentences) {
+      while (sentence.length() > 0) {
+        int idx = 0;
+        for (int i = 0; i < w; i++) {
+          idx = sentence.indexOf(' ', idx + 1);
+        }
+        if (idx > 0) {
+          String fragment = sentence.substring(0, idx);
+          if (fragment.split(" ").length == 4) {
+            fragments.add(fragment);
+            sentence = sentence.substring(sentence.indexOf(' ') + 1);
+          }
+        } else {
+          if (sentence.split(" ").length == 4) {
+            fragments.add(sentence);
+            sentence = "";
+          }
+        }
+      }
+    }
+    return fragments;
+  }
+
+  private Collection<String> getSentences() throws IOException {
+    InputStream resourceAsStream = getClass().getResourceAsStream("/word2vec/sentences.txt");
+    BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(resourceAsStream));
+    Collection<String> sentences = new LinkedList<String>();
+    String line;
+    while ((line = bufferedReader.readLine()) != null) {
+      sentences.add(line);
+    }
+    bufferedReader.close();
+    return sentences;
+  }
+
+  private RealMatrix[] createRandomWeights(int inputSize, int hiddenSize, int outputSize) {
+    Random r = new Random();
+    int weightsCount = 2;
+
+    RealMatrix[] initialWeights = new RealMatrix[weightsCount];
+    for (int i = 0; i < weightsCount; i++) {
+      int rows = inputSize;
+      int cols;
+      if (i == 0) {
+        cols = hiddenSize;
+      } else {
+        cols = initialWeights[i - 1].getRowDimension();
+        if (i == weightsCount - 1) {
+          rows = outputSize;
+        }
+      }
+      double[][] d = new double[rows][cols];
+      for (int c = 0; c < cols; c++) {
+        if (i == weightsCount - 1) {
+          if (c == 0) {
+            d[0][c] = 1d;
+          } else {
+            d[0][c] = r.nextInt(100) / 101d;
+          }
+        } else {
+          d[0][c] = 0;
+        }
+      }
+
+      for (int k = 1; k < rows; k++) {
+        for (int j = 0; j < cols; j++) {
+          double val;
+          if (j == 0) {
+            val = 1d;
+          } else {
+            val = r.nextInt(100) / 101d;
+          }
+          d[k][j] = val;
+        }
+      }
+      initialWeights[i] = new Array2DRowRealMatrix(d);
+    }
+    return initialWeights;
+  }
+}
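
A note for readers following the new test: each training sample pairs the one-hot vector of an input word with the concatenation of the one-hot vectors of its context words. A tiny self-contained sketch of that encoding (plain JDK; the demo class name and the toy vocabulary are hypothetical, not part of this commit):

    import java.util.Arrays;
    import java.util.Collections;
    import java.util.List;

    public class Word2VecEncodingDemo {

      // one-hot encode a word against a sorted vocabulary (mirrors Word2VecTest#hotEncode)
      static double[] hotEncode(String word, List<String> vocabulary) {
        double[] v = new double[vocabulary.size()];
        v[Collections.binarySearch(vocabulary, word)] = 1d;
        return v;
      }

      public static void main(String[] args) {
        List<String> vocabulary = Arrays.asList("brown", "fox", "quick", "the");
        // fragment "the quick brown fox", input word "quick" (i = 1):
        // the expected output is the concatenation of the one-hot vectors
        // for the context words "the", "brown", "fox"
        double[] input = hotEncode("quick", vocabulary);
        String[] context = {"the", "brown", "fox"};
        double[] output = new double[context.length * vocabulary.size()];
        for (int k = 0; k < context.length; k++) {
          double[] oneHot = hotEncode(context[k], vocabulary);
          System.arraycopy(oneHot, 0, output, k * oneHot.length, oneHot.length);
        }
        System.out.println("input  = " + Arrays.toString(input));
        System.out.println("output = " + Arrays.toString(output));
      }
    }

With a window of 4 tokens each fragment has 3 context words per input word, so the output vector has 3 * |vocabulary| components; this is why testSGM slices the prediction into three vocabulary-sized chunks before decoding.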

Added: labs/yay/trunk/core/src/test/resources/word2vec/sentences.txt
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/resources/word2vec/sentences.txt?rev=1707019&view=auto
==============================================================================
--- labs/yay/trunk/core/src/test/resources/word2vec/sentences.txt (added)
+++ labs/yay/trunk/core/src/test/resources/word2vec/sentences.txt Tue Oct  6 12:01:51 2015
@@ -0,0 +1,8 @@
+The word2vec software of Tomas Mikolov and colleagues has gained a lot of traction lately and provides state-of-the-art word embeddings
+The learning models behind the software are described in two research papers.
+We found the description of the models in these papers to be somewhat cryptic and hard to follow
+While the motivations and presentation may be obvious to the neural-networks language-modeling crowd we had to struggle quite a bit to figure out the rationale behind the equations
+This note is an attempt to explain the negative sampling equation in “Distributed Representations of Words and Phrases and their Compositionality” by Tomas Mikolov, Ilya Sutskever, Kai Chen, Greg Corrado and Jeffrey Dean
+The departure point of the paper is the skip-gram model
+In this model we are given a corpus of words w and their contexts c
+We consider the conditional probabilities p(c|w) and given a corpus Text, the goal is to set the parameters θ of p(c|w; θ) so as to maximize the corpus probability
\ No newline at end of file
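
For context, the corpus-probability objective named in the last sentence is, roughly, and in the notation of the note these sentences are drawn from,

    \arg\max_{\theta} \prod_{w \in \text{Text}} \prod_{c \in C(w)} p(c \mid w; \theta)

where C(w) denotes the set of contexts of word w.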


